1/*
2 * Copyright (c) 2011, 2016, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23package org.graalvm.compiler.phases.common;
24
25import static org.graalvm.compiler.core.common.LocationIdentity.any;
26import static org.graalvm.compiler.graph.Graph.NodeEvent.NODE_ADDED;
27import static org.graalvm.compiler.graph.Graph.NodeEvent.ZERO_USAGES;
28
29import java.util.Collection;
30import java.util.EnumSet;
31import java.util.HashMap;
32import java.util.Iterator;
33import java.util.List;
34import java.util.Map;
35import java.util.Set;
36
37import org.graalvm.compiler.core.common.CollectionsFactory;
38import org.graalvm.compiler.core.common.LocationIdentity;
39import org.graalvm.compiler.core.common.cfg.Loop;
40import org.graalvm.compiler.debug.DebugCloseable;
41import org.graalvm.compiler.graph.Graph.NodeEventScope;
42import org.graalvm.compiler.graph.Node;
43import org.graalvm.compiler.nodes.AbstractBeginNode;
44import org.graalvm.compiler.nodes.AbstractMergeNode;
45import org.graalvm.compiler.nodes.FixedNode;
46import org.graalvm.compiler.nodes.InvokeWithExceptionNode;
47import org.graalvm.compiler.nodes.LoopBeginNode;
48import org.graalvm.compiler.nodes.LoopEndNode;
49import org.graalvm.compiler.nodes.LoopExitNode;
50import org.graalvm.compiler.nodes.PhiNode;
51import org.graalvm.compiler.nodes.ReturnNode;
52import org.graalvm.compiler.nodes.StartNode;
53import org.graalvm.compiler.nodes.StructuredGraph;
54import org.graalvm.compiler.nodes.ValueNodeUtil;
55import org.graalvm.compiler.nodes.calc.FloatingNode;
56import org.graalvm.compiler.nodes.cfg.Block;
57import org.graalvm.compiler.nodes.cfg.ControlFlowGraph;
58import org.graalvm.compiler.nodes.cfg.HIRLoop;
59import org.graalvm.compiler.nodes.extended.GuardingNode;
60import org.graalvm.compiler.nodes.extended.ValueAnchorNode;
61import org.graalvm.compiler.nodes.memory.FloatableAccessNode;
62import org.graalvm.compiler.nodes.memory.FloatingAccessNode;
63import org.graalvm.compiler.nodes.memory.FloatingReadNode;
64import org.graalvm.compiler.nodes.memory.MemoryAccess;
65import org.graalvm.compiler.nodes.memory.MemoryAnchorNode;
66import org.graalvm.compiler.nodes.memory.MemoryCheckpoint;
67import org.graalvm.compiler.nodes.memory.MemoryMap;
68import org.graalvm.compiler.nodes.memory.MemoryMapNode;
69import org.graalvm.compiler.nodes.memory.MemoryNode;
70import org.graalvm.compiler.nodes.memory.MemoryPhiNode;
71import org.graalvm.compiler.nodes.memory.ReadNode;
72import org.graalvm.compiler.nodes.util.GraphUtil;
73import org.graalvm.compiler.phases.Phase;
74import org.graalvm.compiler.phases.common.util.HashSetNodeEventListener;
75import org.graalvm.compiler.phases.graph.ReentrantNodeIterator;
76import org.graalvm.compiler.phases.graph.ReentrantNodeIterator.LoopInfo;
77import org.graalvm.compiler.phases.graph.ReentrantNodeIterator.NodeIteratorClosure;
78
79public class FloatingReadPhase extends Phase {
80
81    private boolean createFloatingReads;
82    private boolean createMemoryMapNodes;
83
84    public static class MemoryMapImpl implements MemoryMap {
85
86        private final Map<LocationIdentity, MemoryNode> lastMemorySnapshot;
87
88        public MemoryMapImpl(MemoryMapImpl memoryMap) {
89            lastMemorySnapshot = CollectionsFactory.newMap(memoryMap.lastMemorySnapshot);
90        }
91
92        public MemoryMapImpl(StartNode start) {
93            lastMemorySnapshot = CollectionsFactory.newMap();
94            lastMemorySnapshot.put(any(), start);
95        }
96
97        public MemoryMapImpl() {
98            lastMemorySnapshot = CollectionsFactory.newMap();
99        }
100
101        @Override
102        public MemoryNode getLastLocationAccess(LocationIdentity locationIdentity) {
103            MemoryNode lastLocationAccess;
104            if (locationIdentity.isImmutable()) {
105                return null;
106            } else {
107                lastLocationAccess = lastMemorySnapshot.get(locationIdentity);
108                if (lastLocationAccess == null) {
109                    lastLocationAccess = lastMemorySnapshot.get(any());
110                    assert lastLocationAccess != null;
111                }
112                return lastLocationAccess;
113            }
114        }
115
116        @Override
117        public Collection<LocationIdentity> getLocations() {
118            return lastMemorySnapshot.keySet();
119        }
120
121        public Map<LocationIdentity, MemoryNode> getMap() {
122            return lastMemorySnapshot;
123        }
124    }
125
126    public FloatingReadPhase() {
127        this(true, false);
128    }
129
130    /**
131     * @param createFloatingReads specifies whether {@link FloatableAccessNode}s like
132     *            {@link ReadNode} should be converted into floating nodes (e.g.,
133     *            {@link FloatingReadNode}s) where possible
134     * @param createMemoryMapNodes a {@link MemoryMapNode} will be created for each return if this
135     *            is true
136     */
137    public FloatingReadPhase(boolean createFloatingReads, boolean createMemoryMapNodes) {
138        this.createFloatingReads = createFloatingReads;
139        this.createMemoryMapNodes = createMemoryMapNodes;
140    }
141
142    @Override
143    public float codeSizeIncrease() {
144        return 1.25f;
145    }
146
147    /**
148     * Removes nodes from a given set that (transitively) have a usage outside the set.
149     */
150    private static Set<Node> removeExternallyUsedNodes(Set<Node> set) {
151        boolean change;
152        do {
153            change = false;
154            for (Iterator<Node> iter = set.iterator(); iter.hasNext();) {
155                Node node = iter.next();
156                for (Node usage : node.usages()) {
157                    if (!set.contains(usage)) {
158                        change = true;
159                        iter.remove();
160                        break;
161                    }
162                }
163            }
164        } while (change);
165        return set;
166    }
167
168    protected void processNode(FixedNode node, Set<LocationIdentity> currentState) {
169        if (node instanceof MemoryCheckpoint.Single) {
170            currentState.add(((MemoryCheckpoint.Single) node).getLocationIdentity());
171        } else if (node instanceof MemoryCheckpoint.Multi) {
172            for (LocationIdentity identity : ((MemoryCheckpoint.Multi) node).getLocationIdentities()) {
173                currentState.add(identity);
174            }
175        }
176    }
177
178    protected void processBlock(Block b, Set<LocationIdentity> currentState) {
179        for (FixedNode n : b.getNodes()) {
180            processNode(n, currentState);
181        }
182    }
183
184    private Set<LocationIdentity> processLoop(HIRLoop loop, Map<LoopBeginNode, Set<LocationIdentity>> modifiedInLoops) {
185        LoopBeginNode loopBegin = (LoopBeginNode) loop.getHeader().getBeginNode();
186        Set<LocationIdentity> result = modifiedInLoops.get(loopBegin);
187        if (result != null) {
188            return result;
189        }
190
191        result = CollectionsFactory.newSet();
192        for (Loop<Block> inner : loop.getChildren()) {
193            result.addAll(processLoop((HIRLoop) inner, modifiedInLoops));
194        }
195
196        for (Block b : loop.getBlocks()) {
197            if (b.getLoop() == loop) {
198                processBlock(b, result);
199            }
200        }
201
202        modifiedInLoops.put(loopBegin, result);
203        return result;
204    }
205
206    @Override
207    @SuppressWarnings("try")
208    protected void run(StructuredGraph graph) {
209        Map<LoopBeginNode, Set<LocationIdentity>> modifiedInLoops = null;
210        if (graph.hasLoops()) {
211            modifiedInLoops = new HashMap<>();
212            ControlFlowGraph cfg = ControlFlowGraph.compute(graph, true, true, false, false);
213            for (Loop<?> l : cfg.getLoops()) {
214                HIRLoop loop = (HIRLoop) l;
215                processLoop(loop, modifiedInLoops);
216            }
217        }
218
219        HashSetNodeEventListener listener = new HashSetNodeEventListener(EnumSet.of(NODE_ADDED, ZERO_USAGES));
220        try (NodeEventScope nes = graph.trackNodeEvents(listener)) {
221            ReentrantNodeIterator.apply(new FloatingReadClosure(modifiedInLoops, createFloatingReads, createMemoryMapNodes), graph.start(), new MemoryMapImpl(graph.start()));
222        }
223
224        for (Node n : removeExternallyUsedNodes(listener.getNodes())) {
225            if (n.isAlive() && n instanceof FloatingNode) {
226                n.replaceAtUsages(null);
227                GraphUtil.killWithUnusedFloatingInputs(n);
228            }
229        }
230        if (createFloatingReads) {
231            assert !graph.isAfterFloatingReadPhase();
232            graph.setAfterFloatingReadPhase(true);
233        }
234    }
235
236    public static MemoryMapImpl mergeMemoryMaps(AbstractMergeNode merge, List<? extends MemoryMap> states) {
237        MemoryMapImpl newState = new MemoryMapImpl();
238
239        Set<LocationIdentity> keys = CollectionsFactory.newSet();
240        for (MemoryMap other : states) {
241            keys.addAll(other.getLocations());
242        }
243        assert checkNoImmutableLocations(keys);
244
245        for (LocationIdentity key : keys) {
246            int mergedStatesCount = 0;
247            boolean isPhi = false;
248            MemoryNode merged = null;
249            for (MemoryMap state : states) {
250                MemoryNode last = state.getLastLocationAccess(key);
251                if (isPhi) {
252                    ((MemoryPhiNode) merged).addInput(ValueNodeUtil.asNode(last));
253                } else {
254                    if (merged == last) {
255                        // nothing to do
256                    } else if (merged == null) {
257                        merged = last;
258                    } else {
259                        MemoryPhiNode phi = merge.graph().addWithoutUnique(new MemoryPhiNode(merge, key));
260                        for (int j = 0; j < mergedStatesCount; j++) {
261                            phi.addInput(ValueNodeUtil.asNode(merged));
262                        }
263                        phi.addInput(ValueNodeUtil.asNode(last));
264                        merged = phi;
265                        isPhi = true;
266                    }
267                }
268                mergedStatesCount++;
269            }
270            newState.lastMemorySnapshot.put(key, merged);
271        }
272        return newState;
273
274    }
275
276    private static boolean checkNoImmutableLocations(Set<LocationIdentity> keys) {
277        keys.forEach(t -> {
278            assert !t.isImmutable();
279        });
280        return true;
281    }
282
283    public static class FloatingReadClosure extends NodeIteratorClosure<MemoryMapImpl> {
284
285        private final Map<LoopBeginNode, Set<LocationIdentity>> modifiedInLoops;
286        private boolean createFloatingReads;
287        private boolean createMemoryMapNodes;
288
289        public FloatingReadClosure(Map<LoopBeginNode, Set<LocationIdentity>> modifiedInLoops, boolean createFloatingReads, boolean createMemoryMapNodes) {
290            this.modifiedInLoops = modifiedInLoops;
291            this.createFloatingReads = createFloatingReads;
292            this.createMemoryMapNodes = createMemoryMapNodes;
293        }
294
295        @Override
296        protected MemoryMapImpl processNode(FixedNode node, MemoryMapImpl state) {
297            if (node instanceof MemoryAnchorNode) {
298                processAnchor((MemoryAnchorNode) node, state);
299                return state;
300            }
301
302            if (node instanceof MemoryAccess) {
303                processAccess((MemoryAccess) node, state);
304            }
305
306            if (createFloatingReads & node instanceof FloatableAccessNode) {
307                processFloatable((FloatableAccessNode) node, state);
308            } else if (node instanceof MemoryCheckpoint.Single) {
309                processCheckpoint((MemoryCheckpoint.Single) node, state);
310            } else if (node instanceof MemoryCheckpoint.Multi) {
311                processCheckpoint((MemoryCheckpoint.Multi) node, state);
312            }
313            assert MemoryCheckpoint.TypeAssertion.correctType(node) : node;
314
315            if (createMemoryMapNodes && node instanceof ReturnNode) {
316                ((ReturnNode) node).setMemoryMap(node.graph().unique(new MemoryMapNode(state.lastMemorySnapshot)));
317            }
318            return state;
319        }
320
321        /**
322         * Improve the memory graph by re-wiring all usages of a {@link MemoryAnchorNode} to the
323         * real last access location.
324         */
325        private static void processAnchor(MemoryAnchorNode anchor, MemoryMapImpl state) {
326            for (Node node : anchor.usages().snapshot()) {
327                if (node instanceof MemoryAccess) {
328                    MemoryAccess access = (MemoryAccess) node;
329                    if (access.getLastLocationAccess() == anchor) {
330                        MemoryNode lastLocationAccess = state.getLastLocationAccess(access.getLocationIdentity());
331                        assert lastLocationAccess != null;
332                        access.setLastLocationAccess(lastLocationAccess);
333                    }
334                }
335            }
336
337            if (anchor.hasNoUsages()) {
338                anchor.graph().removeFixed(anchor);
339            }
340        }
341
342        private static void processAccess(MemoryAccess access, MemoryMapImpl state) {
343            LocationIdentity locationIdentity = access.getLocationIdentity();
344            if (!locationIdentity.equals(LocationIdentity.any())) {
345                MemoryNode lastLocationAccess = state.getLastLocationAccess(locationIdentity);
346                access.setLastLocationAccess(lastLocationAccess);
347            }
348        }
349
350        private static void processCheckpoint(MemoryCheckpoint.Single checkpoint, MemoryMapImpl state) {
351            processIdentity(checkpoint.getLocationIdentity(), checkpoint, state);
352        }
353
354        private static void processCheckpoint(MemoryCheckpoint.Multi checkpoint, MemoryMapImpl state) {
355            for (LocationIdentity identity : checkpoint.getLocationIdentities()) {
356                processIdentity(identity, checkpoint, state);
357            }
358        }
359
360        private static void processIdentity(LocationIdentity identity, MemoryCheckpoint checkpoint, MemoryMapImpl state) {
361            if (identity.isAny()) {
362                state.lastMemorySnapshot.clear();
363            }
364            state.lastMemorySnapshot.put(identity, checkpoint);
365        }
366
367        @SuppressWarnings("try")
368        private static void processFloatable(FloatableAccessNode accessNode, MemoryMapImpl state) {
369            StructuredGraph graph = accessNode.graph();
370            LocationIdentity locationIdentity = accessNode.getLocationIdentity();
371            if (accessNode.canFloat()) {
372                assert accessNode.getNullCheck() == false;
373                MemoryNode lastLocationAccess = state.getLastLocationAccess(locationIdentity);
374                try (DebugCloseable position = accessNode.withNodeSourcePosition()) {
375                    FloatingAccessNode floatingNode = accessNode.asFloatingNode(lastLocationAccess);
376                    ValueAnchorNode anchor = null;
377                    GuardingNode guard = accessNode.getGuard();
378                    if (guard != null) {
379                        anchor = graph.add(new ValueAnchorNode(guard.asNode()));
380                        graph.addAfterFixed(accessNode, anchor);
381                    }
382                    graph.replaceFixedWithFloating(accessNode, floatingNode);
383                }
384            }
385        }
386
387        @Override
388        protected MemoryMapImpl merge(AbstractMergeNode merge, List<MemoryMapImpl> states) {
389            return mergeMemoryMaps(merge, states);
390        }
391
392        @Override
393        protected MemoryMapImpl afterSplit(AbstractBeginNode node, MemoryMapImpl oldState) {
394            MemoryMapImpl result = new MemoryMapImpl(oldState);
395            if (node.predecessor() instanceof InvokeWithExceptionNode) {
396                /*
397                 * InvokeWithException cannot be the lastLocationAccess for a FloatingReadNode.
398                 * Since it is both the invoke and a control flow split, the scheduler cannot
399                 * schedule anything immediately after the invoke. It can only schedule in the
400                 * normal or exceptional successor - and we have to tell the scheduler here which
401                 * side it needs to choose by putting in the location identity on both successors.
402                 */
403                InvokeWithExceptionNode invoke = (InvokeWithExceptionNode) node.predecessor();
404                result.lastMemorySnapshot.put(invoke.getLocationIdentity(), (MemoryCheckpoint) node);
405            }
406            return result;
407        }
408
409        @Override
410        protected Map<LoopExitNode, MemoryMapImpl> processLoop(LoopBeginNode loop, MemoryMapImpl initialState) {
411            Set<LocationIdentity> modifiedLocations = modifiedInLoops.get(loop);
412            Map<LocationIdentity, MemoryPhiNode> phis = CollectionsFactory.newMap();
413            if (modifiedLocations.contains(LocationIdentity.any())) {
414                // create phis for all locations if ANY is modified in the loop
415                modifiedLocations = CollectionsFactory.newSet(modifiedLocations);
416                modifiedLocations.addAll(initialState.lastMemorySnapshot.keySet());
417            }
418
419            for (LocationIdentity location : modifiedLocations) {
420                createMemoryPhi(loop, initialState, phis, location);
421            }
422            for (Map.Entry<LocationIdentity, MemoryPhiNode> entry : phis.entrySet()) {
423                initialState.lastMemorySnapshot.put(entry.getKey(), entry.getValue());
424            }
425
426            LoopInfo<MemoryMapImpl> loopInfo = ReentrantNodeIterator.processLoop(this, loop, initialState);
427
428            for (Map.Entry<LoopEndNode, MemoryMapImpl> entry : loopInfo.endStates.entrySet()) {
429                int endIndex = loop.phiPredecessorIndex(entry.getKey());
430                for (Map.Entry<LocationIdentity, MemoryPhiNode> phiEntry : phis.entrySet()) {
431                    LocationIdentity key = phiEntry.getKey();
432                    PhiNode phi = phiEntry.getValue();
433                    phi.initializeValueAt(endIndex, ValueNodeUtil.asNode(entry.getValue().getLastLocationAccess(key)));
434                }
435            }
436            return loopInfo.exitStates;
437        }
438
439        private static void createMemoryPhi(LoopBeginNode loop, MemoryMapImpl initialState, Map<LocationIdentity, MemoryPhiNode> phis, LocationIdentity location) {
440            MemoryPhiNode phi = loop.graph().addWithoutUnique(new MemoryPhiNode(loop, location));
441            phi.addInput(ValueNodeUtil.asNode(initialState.getLastLocationAccess(location)));
442            phis.put(location, phi);
443        }
444    }
445}
446