1/*
2 * Copyright (c) 2011, 2016, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23package org.graalvm.compiler.phases.common;
24
25import static org.graalvm.compiler.graph.Graph.NodeEvent.NODE_ADDED;
26import static org.graalvm.compiler.graph.Graph.NodeEvent.ZERO_USAGES;
27import static org.graalvm.word.LocationIdentity.any;
28
29import java.util.EnumSet;
30import java.util.Iterator;
31import java.util.List;
32
33import org.graalvm.compiler.core.common.cfg.Loop;
34import org.graalvm.compiler.debug.DebugCloseable;
35import org.graalvm.compiler.graph.Graph.NodeEventScope;
36import org.graalvm.compiler.graph.Node;
37import org.graalvm.compiler.nodes.AbstractBeginNode;
38import org.graalvm.compiler.nodes.AbstractMergeNode;
39import org.graalvm.compiler.nodes.FixedNode;
40import org.graalvm.compiler.nodes.InvokeWithExceptionNode;
41import org.graalvm.compiler.nodes.LoopBeginNode;
42import org.graalvm.compiler.nodes.LoopEndNode;
43import org.graalvm.compiler.nodes.LoopExitNode;
44import org.graalvm.compiler.nodes.PhiNode;
45import org.graalvm.compiler.nodes.ReturnNode;
46import org.graalvm.compiler.nodes.StartNode;
47import org.graalvm.compiler.nodes.StructuredGraph;
48import org.graalvm.compiler.nodes.ValueNodeUtil;
49import org.graalvm.compiler.nodes.calc.FloatingNode;
50import org.graalvm.compiler.nodes.cfg.Block;
51import org.graalvm.compiler.nodes.cfg.ControlFlowGraph;
52import org.graalvm.compiler.nodes.cfg.HIRLoop;
53import org.graalvm.compiler.nodes.memory.FloatableAccessNode;
54import org.graalvm.compiler.nodes.memory.FloatingAccessNode;
55import org.graalvm.compiler.nodes.memory.FloatingReadNode;
56import org.graalvm.compiler.nodes.memory.MemoryAccess;
57import org.graalvm.compiler.nodes.memory.MemoryAnchorNode;
58import org.graalvm.compiler.nodes.memory.MemoryCheckpoint;
59import org.graalvm.compiler.nodes.memory.MemoryMap;
60import org.graalvm.compiler.nodes.memory.MemoryMapNode;
61import org.graalvm.compiler.nodes.memory.MemoryNode;
62import org.graalvm.compiler.nodes.memory.MemoryPhiNode;
63import org.graalvm.compiler.nodes.memory.ReadNode;
64import org.graalvm.compiler.nodes.util.GraphUtil;
65import org.graalvm.compiler.phases.Phase;
66import org.graalvm.compiler.phases.common.util.HashSetNodeEventListener;
67import org.graalvm.compiler.phases.graph.ReentrantNodeIterator;
68import org.graalvm.compiler.phases.graph.ReentrantNodeIterator.LoopInfo;
69import org.graalvm.compiler.phases.graph.ReentrantNodeIterator.NodeIteratorClosure;
70import org.graalvm.util.Equivalence;
71import org.graalvm.util.EconomicMap;
72import org.graalvm.util.EconomicSet;
73import org.graalvm.util.UnmodifiableMapCursor;
74import org.graalvm.word.LocationIdentity;
75
76public class FloatingReadPhase extends Phase {
77
78    private boolean createFloatingReads;
79    private boolean createMemoryMapNodes;
80
81    public static class MemoryMapImpl implements MemoryMap {
82
83        private final EconomicMap<LocationIdentity, MemoryNode> lastMemorySnapshot;
84
85        public MemoryMapImpl(MemoryMapImpl memoryMap) {
86            lastMemorySnapshot = EconomicMap.create(Equivalence.DEFAULT, memoryMap.lastMemorySnapshot);
87        }
88
89        public MemoryMapImpl(StartNode start) {
90            this();
91            lastMemorySnapshot.put(any(), start);
92        }
93
94        public MemoryMapImpl() {
95            lastMemorySnapshot = EconomicMap.create(Equivalence.DEFAULT);
96        }
97
98        @Override
99        public MemoryNode getLastLocationAccess(LocationIdentity locationIdentity) {
100            MemoryNode lastLocationAccess;
101            if (locationIdentity.isImmutable()) {
102                return null;
103            } else {
104                lastLocationAccess = lastMemorySnapshot.get(locationIdentity);
105                if (lastLocationAccess == null) {
106                    lastLocationAccess = lastMemorySnapshot.get(any());
107                    assert lastLocationAccess != null;
108                }
109                return lastLocationAccess;
110            }
111        }
112
113        @Override
114        public Iterable<LocationIdentity> getLocations() {
115            return lastMemorySnapshot.getKeys();
116        }
117
118        public EconomicMap<LocationIdentity, MemoryNode> getMap() {
119            return lastMemorySnapshot;
120        }
121    }
122
123    public FloatingReadPhase() {
124        this(true, false);
125    }
126
127    /**
128     * @param createFloatingReads specifies whether {@link FloatableAccessNode}s like
129     *            {@link ReadNode} should be converted into floating nodes (e.g.,
130     *            {@link FloatingReadNode}s) where possible
131     * @param createMemoryMapNodes a {@link MemoryMapNode} will be created for each return if this
132     *            is true
133     */
134    public FloatingReadPhase(boolean createFloatingReads, boolean createMemoryMapNodes) {
135        this.createFloatingReads = createFloatingReads;
136        this.createMemoryMapNodes = createMemoryMapNodes;
137    }
138
139    @Override
140    public float codeSizeIncrease() {
141        return 1.25f;
142    }
143
144    /**
145     * Removes nodes from a given set that (transitively) have a usage outside the set.
146     */
147    private static EconomicSet<Node> removeExternallyUsedNodes(EconomicSet<Node> set) {
148        boolean change;
149        do {
150            change = false;
151            for (Iterator<Node> iter = set.iterator(); iter.hasNext();) {
152                Node node = iter.next();
153                for (Node usage : node.usages()) {
154                    if (!set.contains(usage)) {
155                        change = true;
156                        iter.remove();
157                        break;
158                    }
159                }
160            }
161        } while (change);
162        return set;
163    }
164
165    protected void processNode(FixedNode node, EconomicSet<LocationIdentity> currentState) {
166        if (node instanceof MemoryCheckpoint.Single) {
167            processIdentity(currentState, ((MemoryCheckpoint.Single) node).getLocationIdentity());
168        } else if (node instanceof MemoryCheckpoint.Multi) {
169            for (LocationIdentity identity : ((MemoryCheckpoint.Multi) node).getLocationIdentities()) {
170                processIdentity(currentState, identity);
171            }
172        }
173    }
174
175    private static void processIdentity(EconomicSet<LocationIdentity> currentState, LocationIdentity identity) {
176        if (identity.isMutable()) {
177            currentState.add(identity);
178        }
179    }
180
181    protected void processBlock(Block b, EconomicSet<LocationIdentity> currentState) {
182        for (FixedNode n : b.getNodes()) {
183            processNode(n, currentState);
184        }
185    }
186
187    private EconomicSet<LocationIdentity> processLoop(HIRLoop loop, EconomicMap<LoopBeginNode, EconomicSet<LocationIdentity>> modifiedInLoops) {
188        LoopBeginNode loopBegin = (LoopBeginNode) loop.getHeader().getBeginNode();
189        EconomicSet<LocationIdentity> result = modifiedInLoops.get(loopBegin);
190        if (result != null) {
191            return result;
192        }
193
194        result = EconomicSet.create(Equivalence.DEFAULT);
195        for (Loop<Block> inner : loop.getChildren()) {
196            result.addAll(processLoop((HIRLoop) inner, modifiedInLoops));
197        }
198
199        for (Block b : loop.getBlocks()) {
200            if (b.getLoop() == loop) {
201                processBlock(b, result);
202            }
203        }
204
205        modifiedInLoops.put(loopBegin, result);
206        return result;
207    }
208
209    @Override
210    @SuppressWarnings("try")
211    protected void run(StructuredGraph graph) {
212        EconomicMap<LoopBeginNode, EconomicSet<LocationIdentity>> modifiedInLoops = null;
213        if (graph.hasLoops()) {
214            modifiedInLoops = EconomicMap.create(Equivalence.IDENTITY);
215            ControlFlowGraph cfg = ControlFlowGraph.compute(graph, true, true, false, false);
216            for (Loop<?> l : cfg.getLoops()) {
217                HIRLoop loop = (HIRLoop) l;
218                processLoop(loop, modifiedInLoops);
219            }
220        }
221
222        HashSetNodeEventListener listener = new HashSetNodeEventListener(EnumSet.of(NODE_ADDED, ZERO_USAGES));
223        try (NodeEventScope nes = graph.trackNodeEvents(listener)) {
224            ReentrantNodeIterator.apply(new FloatingReadClosure(modifiedInLoops, createFloatingReads, createMemoryMapNodes), graph.start(), new MemoryMapImpl(graph.start()));
225        }
226
227        for (Node n : removeExternallyUsedNodes(listener.getNodes())) {
228            if (n.isAlive() && n instanceof FloatingNode) {
229                n.replaceAtUsages(null);
230                GraphUtil.killWithUnusedFloatingInputs(n);
231            }
232        }
233        if (createFloatingReads) {
234            assert !graph.isAfterFloatingReadPhase();
235            graph.setAfterFloatingReadPhase(true);
236        }
237    }
238
239    public static MemoryMapImpl mergeMemoryMaps(AbstractMergeNode merge, List<? extends MemoryMap> states) {
240        MemoryMapImpl newState = new MemoryMapImpl();
241
242        EconomicSet<LocationIdentity> keys = EconomicSet.create(Equivalence.DEFAULT);
243        for (MemoryMap other : states) {
244            keys.addAll(other.getLocations());
245        }
246        assert checkNoImmutableLocations(keys);
247
248        for (LocationIdentity key : keys) {
249            int mergedStatesCount = 0;
250            boolean isPhi = false;
251            MemoryNode merged = null;
252            for (MemoryMap state : states) {
253                MemoryNode last = state.getLastLocationAccess(key);
254                if (isPhi) {
255                    ((MemoryPhiNode) merged).addInput(ValueNodeUtil.asNode(last));
256                } else {
257                    if (merged == last) {
258                        // nothing to do
259                    } else if (merged == null) {
260                        merged = last;
261                    } else {
262                        MemoryPhiNode phi = merge.graph().addWithoutUnique(new MemoryPhiNode(merge, key));
263                        for (int j = 0; j < mergedStatesCount; j++) {
264                            phi.addInput(ValueNodeUtil.asNode(merged));
265                        }
266                        phi.addInput(ValueNodeUtil.asNode(last));
267                        merged = phi;
268                        isPhi = true;
269                    }
270                }
271                mergedStatesCount++;
272            }
273            newState.lastMemorySnapshot.put(key, merged);
274        }
275        return newState;
276
277    }
278
279    private static boolean checkNoImmutableLocations(EconomicSet<LocationIdentity> keys) {
280        keys.forEach(t -> {
281            assert t.isMutable();
282        });
283        return true;
284    }
285
286    public static class FloatingReadClosure extends NodeIteratorClosure<MemoryMapImpl> {
287
288        private final EconomicMap<LoopBeginNode, EconomicSet<LocationIdentity>> modifiedInLoops;
289        private boolean createFloatingReads;
290        private boolean createMemoryMapNodes;
291
292        public FloatingReadClosure(EconomicMap<LoopBeginNode, EconomicSet<LocationIdentity>> modifiedInLoops, boolean createFloatingReads, boolean createMemoryMapNodes) {
293            this.modifiedInLoops = modifiedInLoops;
294            this.createFloatingReads = createFloatingReads;
295            this.createMemoryMapNodes = createMemoryMapNodes;
296        }
297
298        @Override
299        protected MemoryMapImpl processNode(FixedNode node, MemoryMapImpl state) {
300            if (node instanceof MemoryAnchorNode) {
301                processAnchor((MemoryAnchorNode) node, state);
302                return state;
303            }
304
305            if (node instanceof MemoryAccess) {
306                processAccess((MemoryAccess) node, state);
307            }
308
309            if (createFloatingReads & node instanceof FloatableAccessNode) {
310                processFloatable((FloatableAccessNode) node, state);
311            } else if (node instanceof MemoryCheckpoint.Single) {
312                processCheckpoint((MemoryCheckpoint.Single) node, state);
313            } else if (node instanceof MemoryCheckpoint.Multi) {
314                processCheckpoint((MemoryCheckpoint.Multi) node, state);
315            }
316            assert MemoryCheckpoint.TypeAssertion.correctType(node) : node;
317
318            if (createMemoryMapNodes && node instanceof ReturnNode) {
319                ((ReturnNode) node).setMemoryMap(node.graph().unique(new MemoryMapNode(state.lastMemorySnapshot)));
320            }
321            return state;
322        }
323
324        /**
325         * Improve the memory graph by re-wiring all usages of a {@link MemoryAnchorNode} to the
326         * real last access location.
327         */
328        private static void processAnchor(MemoryAnchorNode anchor, MemoryMapImpl state) {
329            for (Node node : anchor.usages().snapshot()) {
330                if (node instanceof MemoryAccess) {
331                    MemoryAccess access = (MemoryAccess) node;
332                    if (access.getLastLocationAccess() == anchor) {
333                        MemoryNode lastLocationAccess = state.getLastLocationAccess(access.getLocationIdentity());
334                        assert lastLocationAccess != null;
335                        access.setLastLocationAccess(lastLocationAccess);
336                    }
337                }
338            }
339
340            if (anchor.hasNoUsages()) {
341                anchor.graph().removeFixed(anchor);
342            }
343        }
344
345        private static void processAccess(MemoryAccess access, MemoryMapImpl state) {
346            LocationIdentity locationIdentity = access.getLocationIdentity();
347            if (!locationIdentity.equals(LocationIdentity.any())) {
348                MemoryNode lastLocationAccess = state.getLastLocationAccess(locationIdentity);
349                access.setLastLocationAccess(lastLocationAccess);
350            }
351        }
352
353        private static void processCheckpoint(MemoryCheckpoint.Single checkpoint, MemoryMapImpl state) {
354            processIdentity(checkpoint.getLocationIdentity(), checkpoint, state);
355        }
356
357        private static void processCheckpoint(MemoryCheckpoint.Multi checkpoint, MemoryMapImpl state) {
358            for (LocationIdentity identity : checkpoint.getLocationIdentities()) {
359                processIdentity(identity, checkpoint, state);
360            }
361        }
362
363        private static void processIdentity(LocationIdentity identity, MemoryCheckpoint checkpoint, MemoryMapImpl state) {
364            if (identity.isAny()) {
365                state.lastMemorySnapshot.clear();
366            }
367            if (identity.isMutable()) {
368                state.lastMemorySnapshot.put(identity, checkpoint);
369            }
370        }
371
372        @SuppressWarnings("try")
373        private static void processFloatable(FloatableAccessNode accessNode, MemoryMapImpl state) {
374            StructuredGraph graph = accessNode.graph();
375            LocationIdentity locationIdentity = accessNode.getLocationIdentity();
376            if (accessNode.canFloat()) {
377                assert accessNode.getNullCheck() == false;
378                MemoryNode lastLocationAccess = state.getLastLocationAccess(locationIdentity);
379                try (DebugCloseable position = accessNode.withNodeSourcePosition()) {
380                    FloatingAccessNode floatingNode = accessNode.asFloatingNode(lastLocationAccess);
381                    graph.replaceFixedWithFloating(accessNode, floatingNode);
382                }
383            }
384        }
385
386        @Override
387        protected MemoryMapImpl merge(AbstractMergeNode merge, List<MemoryMapImpl> states) {
388            return mergeMemoryMaps(merge, states);
389        }
390
391        @Override
392        protected MemoryMapImpl afterSplit(AbstractBeginNode node, MemoryMapImpl oldState) {
393            MemoryMapImpl result = new MemoryMapImpl(oldState);
394            if (node.predecessor() instanceof InvokeWithExceptionNode) {
395                /*
396                 * InvokeWithException cannot be the lastLocationAccess for a FloatingReadNode.
397                 * Since it is both the invoke and a control flow split, the scheduler cannot
398                 * schedule anything immediately after the invoke. It can only schedule in the
399                 * normal or exceptional successor - and we have to tell the scheduler here which
400                 * side it needs to choose by putting in the location identity on both successors.
401                 */
402                InvokeWithExceptionNode invoke = (InvokeWithExceptionNode) node.predecessor();
403                result.lastMemorySnapshot.put(invoke.getLocationIdentity(), (MemoryCheckpoint) node);
404            }
405            return result;
406        }
407
408        @Override
409        protected EconomicMap<LoopExitNode, MemoryMapImpl> processLoop(LoopBeginNode loop, MemoryMapImpl initialState) {
410            EconomicSet<LocationIdentity> modifiedLocations = modifiedInLoops.get(loop);
411            EconomicMap<LocationIdentity, MemoryPhiNode> phis = EconomicMap.create(Equivalence.DEFAULT);
412            if (modifiedLocations.contains(LocationIdentity.any())) {
413                // create phis for all locations if ANY is modified in the loop
414                modifiedLocations = EconomicSet.create(Equivalence.DEFAULT, modifiedLocations);
415                modifiedLocations.addAll(initialState.lastMemorySnapshot.getKeys());
416            }
417
418            for (LocationIdentity location : modifiedLocations) {
419                createMemoryPhi(loop, initialState, phis, location);
420            }
421            initialState.lastMemorySnapshot.putAll(phis);
422
423            LoopInfo<MemoryMapImpl> loopInfo = ReentrantNodeIterator.processLoop(this, loop, initialState);
424
425            UnmodifiableMapCursor<LoopEndNode, MemoryMapImpl> endStateCursor = loopInfo.endStates.getEntries();
426            while (endStateCursor.advance()) {
427                int endIndex = loop.phiPredecessorIndex(endStateCursor.getKey());
428                UnmodifiableMapCursor<LocationIdentity, MemoryPhiNode> phiCursor = phis.getEntries();
429                while (phiCursor.advance()) {
430                    LocationIdentity key = phiCursor.getKey();
431                    PhiNode phi = phiCursor.getValue();
432                    phi.initializeValueAt(endIndex, ValueNodeUtil.asNode(endStateCursor.getValue().getLastLocationAccess(key)));
433                }
434            }
435            return loopInfo.exitStates;
436        }
437
438        private static void createMemoryPhi(LoopBeginNode loop, MemoryMapImpl initialState, EconomicMap<LocationIdentity, MemoryPhiNode> phis, LocationIdentity location) {
439            MemoryPhiNode phi = loop.graph().addWithoutUnique(new MemoryPhiNode(loop, location));
440            phi.addInput(ValueNodeUtil.asNode(initialState.getLastLocationAccess(location)));
441            phis.put(location, phi);
442        }
443    }
444}
445