Details | Last modification | View Log | RSS feed
Rev | Author | Line No. | Line |
---|---|---|---|
14 | pmbaty | 1 | //===- GenericUniformAnalysis.cpp --------------------*- C++ -*------------===// |
2 | // |
||
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
||
4 | // See https://llvm.org/LICENSE.txt for license information. |
||
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
||
6 | // |
||
7 | //===----------------------------------------------------------------------===// |
||
8 | // |
||
9 | // This template implementation resides in a separate file so that it |
||
10 | // does not get injected into every .cpp file that includes the |
||
11 | // generic header. |
||
12 | // |
||
13 | // DO NOT INCLUDE THIS FILE WHEN MERELY USING UNIFORMITYINFO. |
||
14 | // |
||
15 | // This file should only be included by files that implement a |
||
16 | // specialization of the relevant templates. Currently these are: |
||
17 | // - UniformityAnalysis.cpp |
||
18 | // |
||
19 | // Note: The DEBUG_TYPE macro should be defined before using this |
||
20 | // file so that any use of LLVM_DEBUG is associated with the |
||
21 | // including file rather than this file. |
||
22 | // |
||
23 | //===----------------------------------------------------------------------===// |
||
24 | /// |
||
25 | /// \file |
||
26 | /// \brief Implementation of uniformity analysis. |
||
27 | /// |
||
28 | /// The algorithm is a fixed point iteration that starts with the assumption |
||
29 | /// that all control flow and all values are uniform. Starting from sources of |
||
30 | /// divergence (whose discovery must be implemented by a CFG- or even |
||
31 | /// target-specific derived class), divergence of values is propagated from |
||
32 | /// definition to uses in a straight-forward way. The main complexity lies in |
||
33 | /// the propagation of the impact of divergent control flow on the divergence of |
||
34 | /// values (sync dependencies). |
||
35 | /// |
||
36 | //===----------------------------------------------------------------------===// |
||
37 | |||
38 | #ifndef LLVM_ADT_GENERICUNIFORMITYIMPL_H |
||
39 | #define LLVM_ADT_GENERICUNIFORMITYIMPL_H |
||
40 | |||
41 | #include "llvm/ADT/GenericUniformityInfo.h" |
||
42 | |||
43 | #include "llvm/ADT/SmallPtrSet.h" |
||
44 | #include "llvm/ADT/SparseBitVector.h" |
||
45 | #include "llvm/ADT/StringExtras.h" |
||
46 | #include "llvm/Support/raw_ostream.h" |
||
47 | |||
48 | #include <set> |
||
49 | |||
50 | #define DEBUG_TYPE "uniformity" |
||
51 | |||
52 | using namespace llvm; |
||
53 | |||
54 | namespace llvm { |
||
55 | |||
/// Range form of std::unique: collapses each run of consecutive equal
/// elements in \p R to a single element and returns the iterator marking the
/// new logical end of the range. The elements between that iterator and the
/// physical end are not erased; callers that need a shorter container must
/// erase them explicitly.
template <typename Range> auto unique(Range &&R) {
  auto RangeBegin = adl_begin(R);
  auto RangeEnd = adl_end(R);
  return std::unique(RangeBegin, RangeEnd);
}
||
59 | |||
60 | /// Construct a specially modified post-order traversal of cycles. |
||
61 | /// |
||
62 | /// The ModifiedPO is constructed using a virtually modified CFG as follows: |
||
63 | /// |
||
64 | /// 1. The successors of pre-entry nodes (predecessors of a cycle |
||
65 | /// entry that are outside the cycle) are replaced by the |
||
66 | /// successors of the successors of the header. |
||
67 | /// 2. Successors of the cycle header are replaced by the exit blocks |
||
68 | /// of the cycle. |
||
69 | /// |
||
70 | /// Effectively, we produce a depth-first numbering with the following |
||
71 | /// properties: |
||
72 | /// |
||
73 | /// 1. Nodes after a cycle are numbered earlier than the cycle header. |
||
74 | /// 2. The header is numbered earlier than the nodes in the cycle. |
||
75 | /// 3. The numbering of the nodes within the cycle forms an interval |
||
76 | /// starting with the header. |
||
77 | /// |
||
78 | /// Effectively, the virtual modification arranges the nodes in a |
||
79 | /// cycle as a DAG with the header as the sole leaf, and successors of |
||
80 | /// the header as the roots. A reverse traversal of this numbering has |
||
81 | /// the following invariant on the unmodified original CFG: |
||
82 | /// |
||
83 | /// Each node is visited after all its predecessors, except if that |
||
84 | /// predecessor is the cycle header. |
||
85 | /// |
||
86 | template <typename ContextT> class ModifiedPostOrder { |
||
87 | public: |
||
88 | using BlockT = typename ContextT::BlockT; |
||
89 | using FunctionT = typename ContextT::FunctionT; |
||
90 | using DominatorTreeT = typename ContextT::DominatorTreeT; |
||
91 | |||
92 | using CycleInfoT = GenericCycleInfo<ContextT>; |
||
93 | using CycleT = typename CycleInfoT::CycleT; |
||
94 | using const_iterator = typename std::vector<BlockT *>::const_iterator; |
||
95 | |||
96 | ModifiedPostOrder(const ContextT &C) : Context(C) {} |
||
97 | |||
98 | bool empty() const { return m_order.empty(); } |
||
99 | size_t size() const { return m_order.size(); } |
||
100 | |||
101 | void clear() { m_order.clear(); } |
||
102 | void compute(const CycleInfoT &CI); |
||
103 | |||
104 | unsigned count(BlockT *BB) const { return POIndex.count(BB); } |
||
105 | const BlockT *operator[](size_t idx) const { return m_order[idx]; } |
||
106 | |||
107 | void appendBlock(const BlockT &BB, bool isReducibleCycleHeader = false) { |
||
108 | POIndex[&BB] = m_order.size(); |
||
109 | m_order.push_back(&BB); |
||
110 | LLVM_DEBUG(dbgs() << "ModifiedPO(" << POIndex[&BB] |
||
111 | << "): " << Context.print(&BB) << "\n"); |
||
112 | if (isReducibleCycleHeader) |
||
113 | ReducibleCycleHeaders.insert(&BB); |
||
114 | } |
||
115 | |||
116 | unsigned getIndex(const BlockT *BB) const { |
||
117 | assert(POIndex.count(BB)); |
||
118 | return POIndex.lookup(BB); |
||
119 | } |
||
120 | |||
121 | bool isReducibleCycleHeader(const BlockT *BB) const { |
||
122 | return ReducibleCycleHeaders.contains(BB); |
||
123 | } |
||
124 | |||
125 | private: |
||
126 | SmallVector<const BlockT *> m_order; |
||
127 | DenseMap<const BlockT *, unsigned> POIndex; |
||
128 | SmallPtrSet<const BlockT *, 32> ReducibleCycleHeaders; |
||
129 | const ContextT &Context; |
||
130 | |||
131 | void computeCyclePO(const CycleInfoT &CI, const CycleT *Cycle, |
||
132 | SmallPtrSetImpl<BlockT *> &Finalized); |
||
133 | |||
134 | void computeStackPO(SmallVectorImpl<BlockT *> &Stack, const CycleInfoT &CI, |
||
135 | const CycleT *Cycle, |
||
136 | SmallPtrSetImpl<BlockT *> &Finalized); |
||
137 | }; |
||
138 | |||
139 | template <typename> class DivergencePropagator; |
||
140 | |||
141 | /// \class GenericSyncDependenceAnalysis |
||
142 | /// |
||
143 | /// \brief Locate join blocks for disjoint paths starting at a divergent branch. |
||
144 | /// |
||
145 | /// An analysis per divergent branch that returns the set of basic |
||
146 | /// blocks whose phi nodes become divergent due to divergent control. |
||
147 | /// These are the blocks that are reachable by two disjoint paths from |
||
148 | /// the branch, or cycle exits reachable along a path that is disjoint |
||
149 | /// from a path to the cycle latch. |
||
150 | |||
151 | // --- Above line is not a doxygen comment; intentionally left blank --- |
||
152 | // |
||
153 | // Originally implemented in SyncDependenceAnalysis.cpp for DivergenceAnalysis. |
||
154 | // |
||
155 | // The SyncDependenceAnalysis is used in the UniformityAnalysis to model |
||
156 | // control-induced divergence in phi nodes. |
||
157 | // |
||
158 | // -- Reference -- |
||
159 | // The algorithm is an extension of Section 5 of |
||
160 | // |
||
161 | // An abstract interpretation for SPMD divergence |
||
162 | // on reducible control flow graphs. |
||
163 | // Julian Rosemann, Simon Moll and Sebastian Hack |
||
164 | // POPL '21 |
||
165 | // |
||
166 | // |
||
167 | // -- Sync dependence -- |
||
168 | // Sync dependence characterizes the control flow aspect of the |
||
169 | // propagation of branch divergence. For example, |
||
170 | // |
||
171 | // %cond = icmp slt i32 %tid, 10 |
||
172 | // br i1 %cond, label %then, label %else |
||
173 | // then: |
||
174 | // br label %merge |
||
175 | // else: |
||
176 | // br label %merge |
||
177 | // merge: |
||
178 | // %a = phi i32 [ 0, %then ], [ 1, %else ] |
||
179 | // |
||
180 | // Suppose %tid holds the thread ID. Although %a is not data dependent on %tid |
||
181 | // because %tid is not on its use-def chains, %a is sync dependent on %tid |
||
182 | // because the branch "br i1 %cond" depends on %tid and affects which value %a |
||
183 | // is assigned to. |
||
184 | // |
||
185 | // |
||
186 | // -- Reduction to SSA construction -- |
||
187 | // There are two disjoint paths from A to X, if a certain variant of SSA |
||
188 | // construction places a phi node in X under the following set-up scheme. |
||
189 | // |
||
190 | // This variant of SSA construction ignores incoming undef values. |
||
191 | // That is, paths from the entry without a definition do not result in |
||
192 | // phi nodes. |
||
193 | // |
||
194 | // entry |
||
195 | // / \ |
||
196 | // A \ |
||
197 | // / \ Y |
||
198 | // B C / |
||
199 | // \ / \ / |
||
200 | // D E |
||
201 | // \ / |
||
202 | // F |
||
203 | // |
||
204 | // Assume that A contains a divergent branch. We are interested |
||
205 | // in the set of all blocks where each block is reachable from A |
||
206 | // via two disjoint paths. This would be the set {D, F} in this |
||
207 | // case. |
||
208 | // To generally reduce this query to SSA construction we introduce |
||
209 | // a virtual variable x and assign to x different values in each |
||
210 | // successor block of A. |
||
211 | // |
||
212 | // entry |
||
213 | // / \ |
||
214 | // A \ |
||
215 | // / \ Y |
||
216 | // x = 0 x = 1 / |
||
217 | // \ / \ / |
||
218 | // D E |
||
219 | // \ / |
||
220 | // F |
||
221 | // |
||
222 | // Our flavor of SSA construction for x will construct the following |
||
223 | // |
||
224 | // entry |
||
225 | // / \ |
||
226 | // A \ |
||
227 | // / \ Y |
||
228 | // x0 = 0 x1 = 1 / |
||
229 | // \ / \ / |
||
230 | // x2 = phi E |
||
231 | // \ / |
||
232 | // x3 = phi |
||
233 | // |
||
234 | // The blocks D and F contain phi nodes and are thus each reachable |
||
235 | // by two disjoins paths from A. |
||
236 | // |
||
237 | // -- Remarks -- |
||
238 | // * In case of cycle exits we need to check for temporal divergence. |
||
239 | // To this end, we check whether the definition of x differs between the |
||
240 | // cycle exit and the cycle header (_after_ SSA construction). |
||
241 | // |
||
242 | // * In the presence of irreducible control flow, the fixed point is |
||
243 | // reached only after multiple iterations. This is because labels |
||
244 | // reaching the header of a cycle must be repropagated through the |
||
245 | // cycle. This is true even in a reducible cycle, since the labels |
||
246 | // may have been produced by a nested irreducible cycle. |
||
247 | // |
||
248 | // * Note that SyncDependenceAnalysis is not concerned with the points |
||
249 | // of convergence in an irreducible cycle. Its only purpose is to |
||
250 | // identify join blocks. The "diverged entry" criterion is |
||
251 | // separately applied on join blocks to determine if an entire |
||
252 | // irreducible cycle is assumed to be divergent. |
||
253 | // |
||
254 | // * Relevant related work: |
||
255 | // A simple algorithm for global data flow analysis problems. |
||
256 | // Matthew S. Hecht and Jeffrey D. Ullman. |
||
257 | // SIAM Journal on Computing, 4(4):519–532, December 1975. |
||
258 | // |
||
259 | template <typename ContextT> class GenericSyncDependenceAnalysis { |
||
260 | public: |
||
261 | using BlockT = typename ContextT::BlockT; |
||
262 | using DominatorTreeT = typename ContextT::DominatorTreeT; |
||
263 | using FunctionT = typename ContextT::FunctionT; |
||
264 | using ValueRefT = typename ContextT::ValueRefT; |
||
265 | using InstructionT = typename ContextT::InstructionT; |
||
266 | |||
267 | using CycleInfoT = GenericCycleInfo<ContextT>; |
||
268 | using CycleT = typename CycleInfoT::CycleT; |
||
269 | |||
270 | using ConstBlockSet = SmallPtrSet<const BlockT *, 4>; |
||
271 | using ModifiedPO = ModifiedPostOrder<ContextT>; |
||
272 | |||
273 | // * if BlockLabels[B] == C then C is the dominating definition at |
||
274 | // block B |
||
275 | // * if BlockLabels[B] == nullptr then we haven't seen B yet |
||
276 | // * if BlockLabels[B] == B then: |
||
277 | // - B is a join point of disjoint paths from X, or, |
||
278 | // - B is an immediate successor of X (initial value), or, |
||
279 | // - B is X |
||
280 | using BlockLabelMap = DenseMap<const BlockT *, const BlockT *>; |
||
281 | |||
282 | /// Information discovered by the sync dependence analysis for each |
||
283 | /// divergent branch. |
||
284 | struct DivergenceDescriptor { |
||
285 | // Join points of diverged paths. |
||
286 | ConstBlockSet JoinDivBlocks; |
||
287 | // Divergent cycle exits |
||
288 | ConstBlockSet CycleDivBlocks; |
||
289 | // Labels assigned to blocks on diverged paths. |
||
290 | BlockLabelMap BlockLabels; |
||
291 | }; |
||
292 | |||
293 | using DivergencePropagatorT = DivergencePropagator<ContextT>; |
||
294 | |||
295 | GenericSyncDependenceAnalysis(const ContextT &Context, |
||
296 | const DominatorTreeT &DT, const CycleInfoT &CI); |
||
297 | |||
298 | /// \brief Computes divergent join points and cycle exits caused by branch |
||
299 | /// divergence in \p Term. |
||
300 | /// |
||
301 | /// This returns a pair of sets: |
||
302 | /// * The set of blocks which are reachable by disjoint paths from |
||
303 | /// \p Term. |
||
304 | /// * The set also contains cycle exits if there two disjoint paths: |
||
305 | /// one from \p Term to the cycle exit and another from \p Term to |
||
306 | /// the cycle header. |
||
307 | const DivergenceDescriptor &getJoinBlocks(const BlockT *DivTermBlock); |
||
308 | |||
309 | private: |
||
310 | static DivergenceDescriptor EmptyDivergenceDesc; |
||
311 | |||
312 | ModifiedPO CyclePO; |
||
313 | |||
314 | const DominatorTreeT &DT; |
||
315 | const CycleInfoT &CI; |
||
316 | |||
317 | DenseMap<const BlockT *, std::unique_ptr<DivergenceDescriptor>> |
||
318 | CachedControlDivDescs; |
||
319 | }; |
||
320 | |||
321 | /// \brief Analysis that identifies uniform values in a data-parallel |
||
322 | /// execution. |
||
323 | /// |
||
324 | /// This analysis propagates divergence in a data-parallel context |
||
325 | /// from sources of divergence to all users. It can be instantiated |
||
326 | /// for an IR that provides a suitable SSAContext. |
||
327 | template <typename ContextT> class GenericUniformityAnalysisImpl { |
||
328 | public: |
||
329 | using BlockT = typename ContextT::BlockT; |
||
330 | using FunctionT = typename ContextT::FunctionT; |
||
331 | using ValueRefT = typename ContextT::ValueRefT; |
||
332 | using ConstValueRefT = typename ContextT::ConstValueRefT; |
||
333 | using InstructionT = typename ContextT::InstructionT; |
||
334 | using DominatorTreeT = typename ContextT::DominatorTreeT; |
||
335 | |||
336 | using CycleInfoT = GenericCycleInfo<ContextT>; |
||
337 | using CycleT = typename CycleInfoT::CycleT; |
||
338 | |||
339 | using SyncDependenceAnalysisT = GenericSyncDependenceAnalysis<ContextT>; |
||
340 | using DivergenceDescriptorT = |
||
341 | typename SyncDependenceAnalysisT::DivergenceDescriptor; |
||
342 | using BlockLabelMapT = typename SyncDependenceAnalysisT::BlockLabelMap; |
||
343 | |||
344 | GenericUniformityAnalysisImpl(const FunctionT &F, const DominatorTreeT &DT, |
||
345 | const CycleInfoT &CI, |
||
346 | const TargetTransformInfo *TTI) |
||
347 | : Context(CI.getSSAContext()), F(F), CI(CI), TTI(TTI), DT(DT), |
||
348 | SDA(Context, DT, CI) {} |
||
349 | |||
350 | void initialize(); |
||
351 | |||
352 | const FunctionT &getFunction() const { return F; } |
||
353 | |||
354 | /// \brief Mark \p UniVal as a value that is always uniform. |
||
355 | void addUniformOverride(const InstructionT &Instr); |
||
356 | |||
357 | /// \brief Mark \p DivVal as a value that is always divergent. |
||
358 | /// \returns Whether the tracked divergence state of \p DivVal changed. |
||
359 | bool markDivergent(const InstructionT &I); |
||
360 | bool markDivergent(ConstValueRefT DivVal); |
||
361 | bool markDefsDivergent(const InstructionT &Instr, |
||
362 | bool AllDefsDivergent = true); |
||
363 | |||
364 | /// \brief Propagate divergence to all instructions in the region. |
||
365 | /// Divergence is seeded by calls to \p markDivergent. |
||
366 | void compute(); |
||
367 | |||
368 | /// \brief Whether any value was marked or analyzed to be divergent. |
||
369 | bool hasDivergence() const { return !DivergentValues.empty(); } |
||
370 | |||
371 | /// \brief Whether \p Val will always return a uniform value regardless of its |
||
372 | /// operands |
||
373 | bool isAlwaysUniform(const InstructionT &Instr) const; |
||
374 | |||
375 | bool hasDivergentDefs(const InstructionT &I) const; |
||
376 | |||
377 | bool isDivergent(const InstructionT &I) const { |
||
378 | if (I.isTerminator()) { |
||
379 | return DivergentTermBlocks.contains(I.getParent()); |
||
380 | } |
||
381 | return hasDivergentDefs(I); |
||
382 | }; |
||
383 | |||
384 | /// \brief Whether \p Val is divergent at its definition. |
||
385 | bool isDivergent(ConstValueRefT V) const { return DivergentValues.count(V); } |
||
386 | |||
387 | bool hasDivergentTerminator(const BlockT &B) const { |
||
388 | return DivergentTermBlocks.contains(&B); |
||
389 | } |
||
390 | |||
391 | void print(raw_ostream &out) const; |
||
392 | |||
393 | protected: |
||
394 | /// \brief Value/block pair representing a single phi input. |
||
395 | struct PhiInput { |
||
396 | ConstValueRefT value; |
||
397 | BlockT *predBlock; |
||
398 | |||
399 | PhiInput(ConstValueRefT value, BlockT *predBlock) |
||
400 | : value(value), predBlock(predBlock) {} |
||
401 | }; |
||
402 | |||
403 | const ContextT &Context; |
||
404 | const FunctionT &F; |
||
405 | const CycleInfoT &CI; |
||
406 | const TargetTransformInfo *TTI = nullptr; |
||
407 | |||
408 | // Detected/marked divergent values. |
||
409 | std::set<ConstValueRefT> DivergentValues; |
||
410 | SmallPtrSet<const BlockT *, 32> DivergentTermBlocks; |
||
411 | |||
412 | // Internal worklist for divergence propagation. |
||
413 | std::vector<const InstructionT *> Worklist; |
||
414 | |||
415 | /// \brief Mark \p Term as divergent and push all Instructions that become |
||
416 | /// divergent as a result on the worklist. |
||
417 | void analyzeControlDivergence(const InstructionT &Term); |
||
418 | |||
419 | private: |
||
420 | const DominatorTreeT &DT; |
||
421 | |||
422 | // Recognized cycles with divergent exits. |
||
423 | SmallPtrSet<const CycleT *, 16> DivergentExitCycles; |
||
424 | |||
425 | // Cycles assumed to be divergent. |
||
426 | // |
||
427 | // We don't use a set here because every insertion needs an explicit |
||
428 | // traversal of all existing members. |
||
429 | SmallVector<const CycleT *> AssumedDivergent; |
||
430 | |||
431 | // The SDA links divergent branches to divergent control-flow joins. |
||
432 | SyncDependenceAnalysisT SDA; |
||
433 | |||
434 | // Set of known-uniform values. |
||
435 | SmallPtrSet<const InstructionT *, 32> UniformOverrides; |
||
436 | |||
437 | /// \brief Mark all nodes in \p JoinBlock as divergent and push them on |
||
438 | /// the worklist. |
||
439 | void taintAndPushAllDefs(const BlockT &JoinBlock); |
||
440 | |||
441 | /// \brief Mark all phi nodes in \p JoinBlock as divergent and push them on |
||
442 | /// the worklist. |
||
443 | void taintAndPushPhiNodes(const BlockT &JoinBlock); |
||
444 | |||
445 | /// \brief Identify all Instructions that become divergent because \p DivExit |
||
446 | /// is a divergent cycle exit of \p DivCycle. Mark those instructions as |
||
447 | /// divergent and push them on the worklist. |
||
448 | void propagateCycleExitDivergence(const BlockT &DivExit, |
||
449 | const CycleT &DivCycle); |
||
450 | |||
451 | /// \brief Internal implementation function for propagateCycleExitDivergence. |
||
452 | void analyzeCycleExitDivergence(const CycleT &OuterDivCycle); |
||
453 | |||
454 | /// \brief Mark all instruction as divergent that use a value defined in \p |
||
455 | /// OuterDivCycle. Push their users on the worklist. |
||
456 | void analyzeTemporalDivergence(const InstructionT &I, |
||
457 | const CycleT &OuterDivCycle); |
||
458 | |||
459 | /// \brief Push all users of \p Val (in the region) to the worklist. |
||
460 | void pushUsers(const InstructionT &I); |
||
461 | void pushUsers(ConstValueRefT V); |
||
462 | |||
463 | bool usesValueFromCycle(const InstructionT &I, const CycleT &DefCycle) const; |
||
464 | |||
465 | /// \brief Whether \p Val is divergent when read in \p ObservingBlock. |
||
466 | bool isTemporalDivergent(const BlockT &ObservingBlock, |
||
467 | ConstValueRefT Val) const; |
||
468 | }; |
||
469 | |||
470 | template <typename ImplT> |
||
471 | void GenericUniformityAnalysisImplDeleter<ImplT>::operator()(ImplT *Impl) { |
||
472 | delete Impl; |
||
473 | } |
||
474 | |||
475 | /// Compute divergence starting with a divergent branch. |
||
476 | template <typename ContextT> class DivergencePropagator { |
||
477 | public: |
||
478 | using BlockT = typename ContextT::BlockT; |
||
479 | using DominatorTreeT = typename ContextT::DominatorTreeT; |
||
480 | using FunctionT = typename ContextT::FunctionT; |
||
481 | using ValueRefT = typename ContextT::ValueRefT; |
||
482 | |||
483 | using CycleInfoT = GenericCycleInfo<ContextT>; |
||
484 | using CycleT = typename CycleInfoT::CycleT; |
||
485 | |||
486 | using ModifiedPO = ModifiedPostOrder<ContextT>; |
||
487 | using SyncDependenceAnalysisT = GenericSyncDependenceAnalysis<ContextT>; |
||
488 | using DivergenceDescriptorT = |
||
489 | typename SyncDependenceAnalysisT::DivergenceDescriptor; |
||
490 | using BlockLabelMapT = typename SyncDependenceAnalysisT::BlockLabelMap; |
||
491 | |||
492 | const ModifiedPO &CyclePOT; |
||
493 | const DominatorTreeT &DT; |
||
494 | const CycleInfoT &CI; |
||
495 | const BlockT &DivTermBlock; |
||
496 | const ContextT &Context; |
||
497 | |||
498 | // Track blocks that receive a new label. Every time we relabel a |
||
499 | // cycle header, we another pass over the modified post-order in |
||
500 | // order to propagate the header label. The bit vector also allows |
||
501 | // us to skip labels that have not changed. |
||
502 | SparseBitVector<> FreshLabels; |
||
503 | |||
504 | // divergent join and cycle exit descriptor. |
||
505 | std::unique_ptr<DivergenceDescriptorT> DivDesc; |
||
506 | BlockLabelMapT &BlockLabels; |
||
507 | |||
508 | DivergencePropagator(const ModifiedPO &CyclePOT, const DominatorTreeT &DT, |
||
509 | const CycleInfoT &CI, const BlockT &DivTermBlock) |
||
510 | : CyclePOT(CyclePOT), DT(DT), CI(CI), DivTermBlock(DivTermBlock), |
||
511 | Context(CI.getSSAContext()), DivDesc(new DivergenceDescriptorT), |
||
512 | BlockLabels(DivDesc->BlockLabels) {} |
||
513 | |||
514 | void printDefs(raw_ostream &Out) { |
||
515 | Out << "Propagator::BlockLabels {\n"; |
||
516 | for (int BlockIdx = (int)CyclePOT.size() - 1; BlockIdx >= 0; --BlockIdx) { |
||
517 | const auto *Block = CyclePOT[BlockIdx]; |
||
518 | const auto *Label = BlockLabels[Block]; |
||
519 | Out << Context.print(Block) << "(" << BlockIdx << ") : "; |
||
520 | if (!Label) { |
||
521 | Out << "<null>\n"; |
||
522 | } else { |
||
523 | Out << Context.print(Label) << "\n"; |
||
524 | } |
||
525 | } |
||
526 | Out << "}\n"; |
||
527 | } |
||
528 | |||
529 | // Push a definition (\p PushedLabel) to \p SuccBlock and return whether this |
||
530 | // causes a divergent join. |
||
531 | bool computeJoin(const BlockT &SuccBlock, const BlockT &PushedLabel) { |
||
532 | const auto *OldLabel = BlockLabels[&SuccBlock]; |
||
533 | |||
534 | LLVM_DEBUG(dbgs() << "labeling " << Context.print(&SuccBlock) << ":\n" |
||
535 | << "\tpushed label: " << Context.print(&PushedLabel) |
||
536 | << "\n" |
||
537 | << "\told label: " << Context.print(OldLabel) << "\n"); |
||
538 | |||
539 | // Early exit if there is no change in the label. |
||
540 | if (OldLabel == &PushedLabel) |
||
541 | return false; |
||
542 | |||
543 | if (OldLabel != &SuccBlock) { |
||
544 | auto SuccIdx = CyclePOT.getIndex(&SuccBlock); |
||
545 | // Assigning a new label, mark this in FreshLabels. |
||
546 | LLVM_DEBUG(dbgs() << "\tfresh label: " << SuccIdx << "\n"); |
||
547 | FreshLabels.set(SuccIdx); |
||
548 | } |
||
549 | |||
550 | // This is not a join if the succ was previously unlabeled. |
||
551 | if (!OldLabel) { |
||
552 | LLVM_DEBUG(dbgs() << "\tnew label: " << Context.print(&PushedLabel) |
||
553 | << "\n"); |
||
554 | BlockLabels[&SuccBlock] = &PushedLabel; |
||
555 | return false; |
||
556 | } |
||
557 | |||
558 | // This is a new join. Label the join block as itself, and not as |
||
559 | // the pushed label. |
||
560 | LLVM_DEBUG(dbgs() << "\tnew label: " << Context.print(&SuccBlock) << "\n"); |
||
561 | BlockLabels[&SuccBlock] = &SuccBlock; |
||
562 | |||
563 | return true; |
||
564 | } |
||
565 | |||
566 | // visiting a virtual cycle exit edge from the cycle header --> temporal |
||
567 | // divergence on join |
||
568 | bool visitCycleExitEdge(const BlockT &ExitBlock, const BlockT &Label) { |
||
569 | if (!computeJoin(ExitBlock, Label)) |
||
570 | return false; |
||
571 | |||
572 | // Identified a divergent cycle exit |
||
573 | DivDesc->CycleDivBlocks.insert(&ExitBlock); |
||
574 | LLVM_DEBUG(dbgs() << "\tDivergent cycle exit: " << Context.print(&ExitBlock) |
||
575 | << "\n"); |
||
576 | return true; |
||
577 | } |
||
578 | |||
579 | // process \p SuccBlock with reaching definition \p Label |
||
580 | bool visitEdge(const BlockT &SuccBlock, const BlockT &Label) { |
||
581 | if (!computeJoin(SuccBlock, Label)) |
||
582 | return false; |
||
583 | |||
584 | // Divergent, disjoint paths join. |
||
585 | DivDesc->JoinDivBlocks.insert(&SuccBlock); |
||
586 | LLVM_DEBUG(dbgs() << "\tDivergent join: " << Context.print(&SuccBlock) |
||
587 | << "\n"); |
||
588 | return true; |
||
589 | } |
||
590 | |||
591 | std::unique_ptr<DivergenceDescriptorT> computeJoinPoints() { |
||
592 | assert(DivDesc); |
||
593 | |||
594 | LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints: " |
||
595 | << Context.print(&DivTermBlock) << "\n"); |
||
596 | |||
597 | // Early stopping criterion |
||
598 | int FloorIdx = CyclePOT.size() - 1; |
||
599 | const BlockT *FloorLabel = nullptr; |
||
600 | int DivTermIdx = CyclePOT.getIndex(&DivTermBlock); |
||
601 | |||
602 | // Bootstrap with branch targets |
||
603 | auto const *DivTermCycle = CI.getCycle(&DivTermBlock); |
||
604 | for (const auto *SuccBlock : successors(&DivTermBlock)) { |
||
605 | if (DivTermCycle && !DivTermCycle->contains(SuccBlock)) { |
||
606 | // If DivTerm exits the cycle immediately, computeJoin() might |
||
607 | // not reach SuccBlock with a different label. We need to |
||
608 | // check for this exit now. |
||
609 | DivDesc->CycleDivBlocks.insert(SuccBlock); |
||
610 | LLVM_DEBUG(dbgs() << "\tImmediate divergent cycle exit: " |
||
611 | << Context.print(SuccBlock) << "\n"); |
||
612 | } |
||
613 | auto SuccIdx = CyclePOT.getIndex(SuccBlock); |
||
614 | visitEdge(*SuccBlock, *SuccBlock); |
||
615 | FloorIdx = std::min<int>(FloorIdx, SuccIdx); |
||
616 | } |
||
617 | |||
618 | while (true) { |
||
619 | auto BlockIdx = FreshLabels.find_last(); |
||
620 | if (BlockIdx == -1 || BlockIdx < FloorIdx) |
||
621 | break; |
||
622 | |||
623 | LLVM_DEBUG(dbgs() << "Current labels:\n"; printDefs(dbgs())); |
||
624 | |||
625 | FreshLabels.reset(BlockIdx); |
||
626 | if (BlockIdx == DivTermIdx) { |
||
627 | LLVM_DEBUG(dbgs() << "Skipping DivTermBlock\n"); |
||
628 | continue; |
||
629 | } |
||
630 | |||
631 | const auto *Block = CyclePOT[BlockIdx]; |
||
632 | LLVM_DEBUG(dbgs() << "visiting " << Context.print(Block) << " at index " |
||
633 | << BlockIdx << "\n"); |
||
634 | |||
635 | const auto *Label = BlockLabels[Block]; |
||
636 | assert(Label); |
||
637 | |||
638 | bool CausedJoin = false; |
||
639 | int LoweredFloorIdx = FloorIdx; |
||
640 | |||
641 | // If the current block is the header of a reducible cycle that |
||
642 | // contains the divergent branch, then the label should be |
||
643 | // propagated to the cycle exits. Such a header is the "last |
||
644 | // possible join" of any disjoint paths within this cycle. This |
||
645 | // prevents detection of spurious joins at the entries of any |
||
646 | // irreducible child cycles. |
||
647 | // |
||
648 | // This conclusion about the header is true for any choice of DFS: |
||
649 | // |
||
650 | // If some DFS has a reducible cycle C with header H, then for |
||
651 | // any other DFS, H is the header of a cycle C' that is a |
||
652 | // superset of C. For a divergent branch inside the subgraph |
||
653 | // C, any join node inside C is either H, or some node |
||
654 | // encountered without passing through H. |
||
655 | // |
||
656 | auto getReducibleParent = [&](const BlockT *Block) -> const CycleT * { |
||
657 | if (!CyclePOT.isReducibleCycleHeader(Block)) |
||
658 | return nullptr; |
||
659 | const auto *BlockCycle = CI.getCycle(Block); |
||
660 | if (BlockCycle->contains(&DivTermBlock)) |
||
661 | return BlockCycle; |
||
662 | return nullptr; |
||
663 | }; |
||
664 | |||
665 | if (const auto *BlockCycle = getReducibleParent(Block)) { |
||
666 | SmallVector<BlockT *, 4> BlockCycleExits; |
||
667 | BlockCycle->getExitBlocks(BlockCycleExits); |
||
668 | for (auto *BlockCycleExit : BlockCycleExits) { |
||
669 | CausedJoin |= visitCycleExitEdge(*BlockCycleExit, *Label); |
||
670 | LoweredFloorIdx = |
||
671 | std::min<int>(LoweredFloorIdx, CyclePOT.getIndex(BlockCycleExit)); |
||
672 | } |
||
673 | } else { |
||
674 | for (const auto *SuccBlock : successors(Block)) { |
||
675 | CausedJoin |= visitEdge(*SuccBlock, *Label); |
||
676 | LoweredFloorIdx = |
||
677 | std::min<int>(LoweredFloorIdx, CyclePOT.getIndex(SuccBlock)); |
||
678 | } |
||
679 | } |
||
680 | |||
681 | // Floor update |
||
682 | if (CausedJoin) { |
||
683 | // 1. Different labels pushed to successors |
||
684 | FloorIdx = LoweredFloorIdx; |
||
685 | } else if (FloorLabel != Label) { |
||
686 | // 2. No join caused BUT we pushed a label that is different than the |
||
687 | // last pushed label |
||
688 | FloorIdx = LoweredFloorIdx; |
||
689 | FloorLabel = Label; |
||
690 | } |
||
691 | } |
||
692 | |||
693 | LLVM_DEBUG(dbgs() << "Final labeling:\n"; printDefs(dbgs())); |
||
694 | |||
695 | // Check every cycle containing DivTermBlock for exit divergence. |
||
696 | // A cycle has exit divergence if the label of an exit block does |
||
697 | // not match the label of its header. |
||
698 | for (const auto *Cycle = CI.getCycle(&DivTermBlock); Cycle; |
||
699 | Cycle = Cycle->getParentCycle()) { |
||
700 | if (Cycle->isReducible()) { |
||
701 | // The exit divergence of a reducible cycle is recorded while |
||
702 | // propagating labels. |
||
703 | continue; |
||
704 | } |
||
705 | SmallVector<BlockT *> Exits; |
||
706 | Cycle->getExitBlocks(Exits); |
||
707 | auto *Header = Cycle->getHeader(); |
||
708 | auto *HeaderLabel = BlockLabels[Header]; |
||
709 | for (const auto *Exit : Exits) { |
||
710 | if (BlockLabels[Exit] != HeaderLabel) { |
||
711 | // Identified a divergent cycle exit |
||
712 | DivDesc->CycleDivBlocks.insert(Exit); |
||
713 | LLVM_DEBUG(dbgs() << "\tDivergent cycle exit: " << Context.print(Exit) |
||
714 | << "\n"); |
||
715 | } |
||
716 | } |
||
717 | } |
||
718 | |||
719 | return std::move(DivDesc); |
||
720 | } |
||
721 | }; |
||
722 | |||
723 | template <typename ContextT> |
||
724 | typename llvm::GenericSyncDependenceAnalysis<ContextT>::DivergenceDescriptor |
||
725 | llvm::GenericSyncDependenceAnalysis<ContextT>::EmptyDivergenceDesc; |
||
726 | |||
727 | template <typename ContextT> |
||
728 | llvm::GenericSyncDependenceAnalysis<ContextT>::GenericSyncDependenceAnalysis( |
||
729 | const ContextT &Context, const DominatorTreeT &DT, const CycleInfoT &CI) |
||
730 | : CyclePO(Context), DT(DT), CI(CI) { |
||
731 | CyclePO.compute(CI); |
||
732 | } |
||
733 | |||
734 | template <typename ContextT> |
||
735 | auto llvm::GenericSyncDependenceAnalysis<ContextT>::getJoinBlocks( |
||
736 | const BlockT *DivTermBlock) -> const DivergenceDescriptor & { |
||
737 | // trivial case |
||
738 | if (succ_size(DivTermBlock) <= 1) { |
||
739 | return EmptyDivergenceDesc; |
||
740 | } |
||
741 | |||
742 | // already available in cache? |
||
743 | auto ItCached = CachedControlDivDescs.find(DivTermBlock); |
||
744 | if (ItCached != CachedControlDivDescs.end()) |
||
745 | return *ItCached->second; |
||
746 | |||
747 | // compute all join points |
||
748 | DivergencePropagatorT Propagator(CyclePO, DT, CI, *DivTermBlock); |
||
749 | auto DivDesc = Propagator.computeJoinPoints(); |
||
750 | |||
751 | auto printBlockSet = [&](ConstBlockSet &Blocks) { |
||
752 | return Printable([&](raw_ostream &Out) { |
||
753 | Out << "["; |
||
754 | ListSeparator LS; |
||
755 | for (const auto *BB : Blocks) { |
||
756 | Out << LS << CI.getSSAContext().print(BB); |
||
757 | } |
||
758 | Out << "]\n"; |
||
759 | }); |
||
760 | }; |
||
761 | |||
762 | LLVM_DEBUG( |
||
763 | dbgs() << "\nResult (" << CI.getSSAContext().print(DivTermBlock) |
||
764 | << "):\n JoinDivBlocks: " << printBlockSet(DivDesc->JoinDivBlocks) |
||
765 | << " CycleDivBlocks: " << printBlockSet(DivDesc->CycleDivBlocks) |
||
766 | << "\n"); |
||
767 | (void)printBlockSet; |
||
768 | |||
769 | auto ItInserted = |
||
770 | CachedControlDivDescs.try_emplace(DivTermBlock, std::move(DivDesc)); |
||
771 | assert(ItInserted.second); |
||
772 | return *ItInserted.first->second; |
||
773 | } |
||
774 | |||
775 | template <typename ContextT> |
||
776 | bool GenericUniformityAnalysisImpl<ContextT>::markDivergent( |
||
777 | const InstructionT &I) { |
||
778 | if (I.isTerminator()) { |
||
779 | if (DivergentTermBlocks.insert(I.getParent()).second) { |
||
780 | LLVM_DEBUG(dbgs() << "marked divergent term block: " |
||
781 | << Context.print(I.getParent()) << "\n"); |
||
782 | return true; |
||
783 | } |
||
784 | return false; |
||
785 | } |
||
786 | |||
787 | return markDefsDivergent(I); |
||
788 | } |
||
789 | |||
790 | template <typename ContextT> |
||
791 | bool GenericUniformityAnalysisImpl<ContextT>::markDivergent( |
||
792 | ConstValueRefT Val) { |
||
793 | if (DivergentValues.insert(Val).second) { |
||
794 | LLVM_DEBUG(dbgs() << "marked divergent: " << Context.print(Val) << "\n"); |
||
795 | return true; |
||
796 | } |
||
797 | return false; |
||
798 | } |
||
799 | |||
800 | template <typename ContextT> |
||
801 | void GenericUniformityAnalysisImpl<ContextT>::addUniformOverride( |
||
802 | const InstructionT &Instr) { |
||
803 | UniformOverrides.insert(&Instr); |
||
804 | } |
||
805 | |||
806 | template <typename ContextT> |
||
807 | void GenericUniformityAnalysisImpl<ContextT>::analyzeTemporalDivergence( |
||
808 | const InstructionT &I, const CycleT &OuterDivCycle) { |
||
809 | if (isDivergent(I)) |
||
810 | return; |
||
811 | |||
812 | LLVM_DEBUG(dbgs() << "Analyze temporal divergence: " << Context.print(&I) |
||
813 | << "\n"); |
||
814 | if (!usesValueFromCycle(I, OuterDivCycle)) |
||
815 | return; |
||
816 | |||
817 | if (isAlwaysUniform(I)) |
||
818 | return; |
||
819 | |||
820 | if (markDivergent(I)) |
||
821 | Worklist.push_back(&I); |
||
822 | } |
||
823 | |||
824 | // Mark all external users of values defined inside \param |
||
825 | // OuterDivCycle as divergent. |
||
826 | // |
||
827 | // This follows all live out edges wherever they may lead. Potential |
||
828 | // users of values defined inside DivCycle could be anywhere in the |
||
829 | // dominance region of DivCycle (including its fringes for phi nodes). |
||
830 | // A cycle C dominates a block B iff every path from the entry block |
||
831 | // to B must pass through a block contained in C. If C is a reducible |
||
832 | // cycle (or natural loop), C dominates B iff the header of C |
||
833 | // dominates B. But in general, we iteratively examine cycle cycle |
||
834 | // exits and their successors. |
||
835 | template <typename ContextT> |
||
836 | void GenericUniformityAnalysisImpl<ContextT>::analyzeCycleExitDivergence( |
||
837 | const CycleT &OuterDivCycle) { |
||
838 | // Set of blocks that are dominated by the cycle, i.e., each is only |
||
839 | // reachable from paths that pass through the cycle. |
||
840 | SmallPtrSet<BlockT *, 16> DomRegion; |
||
841 | |||
842 | // The boundary of DomRegion, formed by blocks that are not |
||
843 | // dominated by the cycle. |
||
844 | SmallVector<BlockT *> DomFrontier; |
||
845 | OuterDivCycle.getExitBlocks(DomFrontier); |
||
846 | |||
847 | // Returns true if BB is dominated by the cycle. |
||
848 | auto isInDomRegion = [&](BlockT *BB) { |
||
849 | for (auto *P : predecessors(BB)) { |
||
850 | if (OuterDivCycle.contains(P)) |
||
851 | continue; |
||
852 | if (DomRegion.count(P)) |
||
853 | continue; |
||
854 | return false; |
||
855 | } |
||
856 | return true; |
||
857 | }; |
||
858 | |||
859 | // Keep advancing the frontier along successor edges, while |
||
860 | // promoting blocks to DomRegion. |
||
861 | while (true) { |
||
862 | bool Promoted = false; |
||
863 | SmallVector<BlockT *> Temp; |
||
864 | for (auto *W : DomFrontier) { |
||
865 | if (!isInDomRegion(W)) { |
||
866 | Temp.push_back(W); |
||
867 | continue; |
||
868 | } |
||
869 | DomRegion.insert(W); |
||
870 | Promoted = true; |
||
871 | for (auto *Succ : successors(W)) { |
||
872 | if (DomRegion.contains(Succ)) |
||
873 | continue; |
||
874 | Temp.push_back(Succ); |
||
875 | } |
||
876 | } |
||
877 | if (!Promoted) |
||
878 | break; |
||
879 | DomFrontier = Temp; |
||
880 | } |
||
881 | |||
882 | // At DomFrontier, only the PHI nodes are affected by temporal |
||
883 | // divergence. |
||
884 | for (const auto *UserBlock : DomFrontier) { |
||
885 | LLVM_DEBUG(dbgs() << "Analyze phis after cycle exit: " |
||
886 | << Context.print(UserBlock) << "\n"); |
||
887 | for (const auto &Phi : UserBlock->phis()) { |
||
888 | LLVM_DEBUG(dbgs() << " " << Context.print(&Phi) << "\n"); |
||
889 | analyzeTemporalDivergence(Phi, OuterDivCycle); |
||
890 | } |
||
891 | } |
||
892 | |||
893 | // All instructions inside the dominance region are affected by |
||
894 | // temporal divergence. |
||
895 | for (const auto *UserBlock : DomRegion) { |
||
896 | LLVM_DEBUG(dbgs() << "Analyze non-phi users after cycle exit: " |
||
897 | << Context.print(UserBlock) << "\n"); |
||
898 | for (const auto &I : *UserBlock) { |
||
899 | LLVM_DEBUG(dbgs() << " " << Context.print(&I) << "\n"); |
||
900 | analyzeTemporalDivergence(I, OuterDivCycle); |
||
901 | } |
||
902 | } |
||
903 | } |
||
904 | |||
905 | template <typename ContextT> |
||
906 | void GenericUniformityAnalysisImpl<ContextT>::propagateCycleExitDivergence( |
||
907 | const BlockT &DivExit, const CycleT &InnerDivCycle) { |
||
908 | LLVM_DEBUG(dbgs() << "\tpropCycleExitDiv " << Context.print(&DivExit) |
||
909 | << "\n"); |
||
910 | auto *DivCycle = &InnerDivCycle; |
||
911 | auto *OuterDivCycle = DivCycle; |
||
912 | auto *ExitLevelCycle = CI.getCycle(&DivExit); |
||
913 | const unsigned CycleExitDepth = |
||
914 | ExitLevelCycle ? ExitLevelCycle->getDepth() : 0; |
||
915 | |||
916 | // Find outer-most cycle that does not contain \p DivExit |
||
917 | while (DivCycle && DivCycle->getDepth() > CycleExitDepth) { |
||
918 | LLVM_DEBUG(dbgs() << " Found exiting cycle: " |
||
919 | << Context.print(DivCycle->getHeader()) << "\n"); |
||
920 | OuterDivCycle = DivCycle; |
||
921 | DivCycle = DivCycle->getParentCycle(); |
||
922 | } |
||
923 | LLVM_DEBUG(dbgs() << "\tOuter-most exiting cycle: " |
||
924 | << Context.print(OuterDivCycle->getHeader()) << "\n"); |
||
925 | |||
926 | if (!DivergentExitCycles.insert(OuterDivCycle).second) |
||
927 | return; |
||
928 | |||
929 | // Exit divergence does not matter if the cycle itself is assumed to |
||
930 | // be divergent. |
||
931 | for (const auto *C : AssumedDivergent) { |
||
932 | if (C->contains(OuterDivCycle)) |
||
933 | return; |
||
934 | } |
||
935 | |||
936 | analyzeCycleExitDivergence(*OuterDivCycle); |
||
937 | } |
||
938 | |||
939 | template <typename ContextT> |
||
940 | void GenericUniformityAnalysisImpl<ContextT>::taintAndPushAllDefs( |
||
941 | const BlockT &BB) { |
||
942 | LLVM_DEBUG(dbgs() << "taintAndPushAllDefs " << Context.print(&BB) << "\n"); |
||
943 | for (const auto &I : instrs(BB)) { |
||
944 | // Terminators do not produce values; they are divergent only if |
||
945 | // the condition is divergent. That is handled when the divergent |
||
946 | // condition is placed in the worklist. |
||
947 | if (I.isTerminator()) |
||
948 | break; |
||
949 | |||
950 | // Mark this as divergent. We don't check if the instruction is |
||
951 | // always uniform. In a cycle where the thread convergence is not |
||
952 | // statically known, the instruction is not statically converged, |
||
953 | // and its outputs cannot be statically uniform. |
||
954 | if (markDivergent(I)) |
||
955 | Worklist.push_back(&I); |
||
956 | } |
||
957 | } |
||
958 | |||
959 | /// Mark divergent phi nodes in a join block |
||
960 | template <typename ContextT> |
||
961 | void GenericUniformityAnalysisImpl<ContextT>::taintAndPushPhiNodes( |
||
962 | const BlockT &JoinBlock) { |
||
963 | LLVM_DEBUG(dbgs() << "taintAndPushPhiNodes in " << Context.print(&JoinBlock) |
||
964 | << "\n"); |
||
965 | for (const auto &Phi : JoinBlock.phis()) { |
||
966 | if (ContextT::isConstantValuePhi(Phi)) |
||
967 | continue; |
||
968 | if (markDivergent(Phi)) |
||
969 | Worklist.push_back(&Phi); |
||
970 | } |
||
971 | } |
||
972 | |||
973 | /// Add \p Candidate to \p Cycles if it is not already contained in \p Cycles. |
||
974 | /// |
||
975 | /// \return true iff \p Candidate was added to \p Cycles. |
||
976 | template <typename CycleT> |
||
977 | static bool insertIfNotContained(SmallVector<CycleT *> &Cycles, |
||
978 | CycleT *Candidate) { |
||
979 | if (llvm::any_of(Cycles, |
||
980 | [Candidate](CycleT *C) { return C->contains(Candidate); })) |
||
981 | return false; |
||
982 | Cycles.push_back(Candidate); |
||
983 | return true; |
||
984 | } |
||
985 | |||
986 | /// Return the outermost cycle made divergent by branch outside it. |
||
987 | /// |
||
988 | /// If two paths that diverged outside an irreducible cycle join |
||
989 | /// inside that cycle, then that whole cycle is assumed to be |
||
990 | /// divergent. This does not apply if the cycle is reducible. |
||
991 | template <typename CycleT, typename BlockT> |
||
992 | static const CycleT *getExtDivCycle(const CycleT *Cycle, |
||
993 | const BlockT *DivTermBlock, |
||
994 | const BlockT *JoinBlock) { |
||
995 | assert(Cycle); |
||
996 | assert(Cycle->contains(JoinBlock)); |
||
997 | |||
998 | if (Cycle->contains(DivTermBlock)) |
||
999 | return nullptr; |
||
1000 | |||
1001 | if (Cycle->isReducible()) { |
||
1002 | assert(Cycle->getHeader() == JoinBlock); |
||
1003 | return nullptr; |
||
1004 | } |
||
1005 | |||
1006 | const auto *Parent = Cycle->getParentCycle(); |
||
1007 | while (Parent && !Parent->contains(DivTermBlock)) { |
||
1008 | // If the join is inside a child, then the parent must be |
||
1009 | // irreducible. The only join in a reducible cyle is its own |
||
1010 | // header. |
||
1011 | assert(!Parent->isReducible()); |
||
1012 | Cycle = Parent; |
||
1013 | Parent = Cycle->getParentCycle(); |
||
1014 | } |
||
1015 | |||
1016 | LLVM_DEBUG(dbgs() << "cycle made divergent by external branch\n"); |
||
1017 | return Cycle; |
||
1018 | } |
||
1019 | |||
1020 | /// Return the outermost cycle made divergent by branch inside it. |
||
1021 | /// |
||
1022 | /// This checks the "diverged entry" criterion defined in the |
||
1023 | /// docs/ConvergenceAnalysis.html. |
||
1024 | template <typename ContextT, typename CycleT, typename BlockT, |
||
1025 | typename DominatorTreeT> |
||
1026 | static const CycleT * |
||
1027 | getIntDivCycle(const CycleT *Cycle, const BlockT *DivTermBlock, |
||
1028 | const BlockT *JoinBlock, const DominatorTreeT &DT, |
||
1029 | ContextT &Context) { |
||
1030 | LLVM_DEBUG(dbgs() << "examine join " << Context.print(JoinBlock) |
||
1031 | << "for internal branch " << Context.print(DivTermBlock) |
||
1032 | << "\n"); |
||
1033 | if (DT.properlyDominates(DivTermBlock, JoinBlock)) |
||
1034 | return nullptr; |
||
1035 | |||
1036 | // Find the smallest common cycle, if one exists. |
||
1037 | assert(Cycle && Cycle->contains(JoinBlock)); |
||
1038 | while (Cycle && !Cycle->contains(DivTermBlock)) { |
||
1039 | Cycle = Cycle->getParentCycle(); |
||
1040 | } |
||
1041 | if (!Cycle || Cycle->isReducible()) |
||
1042 | return nullptr; |
||
1043 | |||
1044 | if (DT.properlyDominates(Cycle->getHeader(), JoinBlock)) |
||
1045 | return nullptr; |
||
1046 | |||
1047 | LLVM_DEBUG(dbgs() << " header " << Context.print(Cycle->getHeader()) |
||
1048 | << " does not dominate join\n"); |
||
1049 | |||
1050 | const auto *Parent = Cycle->getParentCycle(); |
||
1051 | while (Parent && !DT.properlyDominates(Parent->getHeader(), JoinBlock)) { |
||
1052 | LLVM_DEBUG(dbgs() << " header " << Context.print(Parent->getHeader()) |
||
1053 | << " does not dominate join\n"); |
||
1054 | Cycle = Parent; |
||
1055 | Parent = Parent->getParentCycle(); |
||
1056 | } |
||
1057 | |||
1058 | LLVM_DEBUG(dbgs() << " cycle made divergent by internal branch\n"); |
||
1059 | return Cycle; |
||
1060 | } |
||
1061 | |||
/// Returns the outermost cycle that the join at \p JoinBlock, reached by
/// disjoint paths from \p DivTermBlock, makes divergent. Returns nullptr if
/// there is none.
///
/// A cycle made divergent by an internal branch takes precedence over one
/// made divergent by an external branch.
template <typename ContextT, typename CycleT, typename BlockT,
          typename DominatorTreeT>
static const CycleT *
getOutermostDivergentCycle(const CycleT *Cycle, const BlockT *DivTermBlock,
                           const BlockT *JoinBlock, const DominatorTreeT &DT,
                           ContextT &Context) {
  if (!Cycle)
    return nullptr;

  // Both queries always run, in this order, so debug output is unchanged.
  // External: largest cycle containing the join but not the branch.
  const auto *External = getExtDivCycle(Cycle, DivTermBlock, JoinBlock);
  // Internal: largest cycle containing both.
  const auto *Internal =
      getIntDivCycle(Cycle, DivTermBlock, JoinBlock, DT, Context);

  return Internal ? Internal : External;
}
||
1082 | |||
1083 | template <typename ContextT> |
||
1084 | void GenericUniformityAnalysisImpl<ContextT>::analyzeControlDivergence( |
||
1085 | const InstructionT &Term) { |
||
1086 | const auto *DivTermBlock = Term.getParent(); |
||
1087 | DivergentTermBlocks.insert(DivTermBlock); |
||
1088 | LLVM_DEBUG(dbgs() << "analyzeControlDiv " << Context.print(DivTermBlock) |
||
1089 | << "\n"); |
||
1090 | |||
1091 | // Don't propagate divergence from unreachable blocks. |
||
1092 | if (!DT.isReachableFromEntry(DivTermBlock)) |
||
1093 | return; |
||
1094 | |||
1095 | const auto &DivDesc = SDA.getJoinBlocks(DivTermBlock); |
||
1096 | SmallVector<const CycleT *> DivCycles; |
||
1097 | |||
1098 | // Iterate over all blocks now reachable by a disjoint path join |
||
1099 | for (const auto *JoinBlock : DivDesc.JoinDivBlocks) { |
||
1100 | const auto *Cycle = CI.getCycle(JoinBlock); |
||
1101 | LLVM_DEBUG(dbgs() << "visiting join block " << Context.print(JoinBlock) |
||
1102 | << "\n"); |
||
1103 | if (const auto *Outermost = getOutermostDivergentCycle( |
||
1104 | Cycle, DivTermBlock, JoinBlock, DT, Context)) { |
||
1105 | LLVM_DEBUG(dbgs() << "found divergent cycle\n"); |
||
1106 | DivCycles.push_back(Outermost); |
||
1107 | continue; |
||
1108 | } |
||
1109 | taintAndPushPhiNodes(*JoinBlock); |
||
1110 | } |
||
1111 | |||
1112 | // Sort by order of decreasing depth. This allows later cycles to be skipped |
||
1113 | // because they are already contained in earlier ones. |
||
1114 | llvm::sort(DivCycles, [](const CycleT *A, const CycleT *B) { |
||
1115 | return A->getDepth() > B->getDepth(); |
||
1116 | }); |
||
1117 | |||
1118 | // Cycles that are assumed divergent due to the diverged entry |
||
1119 | // criterion potentially contain temporal divergence depending on |
||
1120 | // the DFS chosen. Conservatively, all values produced in such a |
||
1121 | // cycle are assumed divergent. "Cycle invariant" values may be |
||
1122 | // assumed uniform, but that requires further analysis. |
||
1123 | for (auto *C : DivCycles) { |
||
1124 | if (!insertIfNotContained(AssumedDivergent, C)) |
||
1125 | continue; |
||
1126 | LLVM_DEBUG(dbgs() << "process divergent cycle\n"); |
||
1127 | for (const BlockT *BB : C->blocks()) { |
||
1128 | taintAndPushAllDefs(*BB); |
||
1129 | } |
||
1130 | } |
||
1131 | |||
1132 | const auto *BranchCycle = CI.getCycle(DivTermBlock); |
||
1133 | assert(DivDesc.CycleDivBlocks.empty() || BranchCycle); |
||
1134 | for (const auto *DivExitBlock : DivDesc.CycleDivBlocks) { |
||
1135 | propagateCycleExitDivergence(*DivExitBlock, *BranchCycle); |
||
1136 | } |
||
1137 | } |
||
1138 | |||
1139 | template <typename ContextT> |
||
1140 | void GenericUniformityAnalysisImpl<ContextT>::compute() { |
||
1141 | // Initialize worklist. |
||
1142 | auto DivValuesCopy = DivergentValues; |
||
1143 | for (const auto DivVal : DivValuesCopy) { |
||
1144 | assert(isDivergent(DivVal) && "Worklist invariant violated!"); |
||
1145 | pushUsers(DivVal); |
||
1146 | } |
||
1147 | |||
1148 | // All values on the Worklist are divergent. |
||
1149 | // Their users may not have been updated yet. |
||
1150 | while (!Worklist.empty()) { |
||
1151 | const InstructionT *I = Worklist.back(); |
||
1152 | Worklist.pop_back(); |
||
1153 | |||
1154 | LLVM_DEBUG(dbgs() << "worklist pop: " << Context.print(I) << "\n"); |
||
1155 | |||
1156 | if (I->isTerminator()) { |
||
1157 | analyzeControlDivergence(*I); |
||
1158 | continue; |
||
1159 | } |
||
1160 | |||
1161 | // propagate value divergence to users |
||
1162 | assert(isDivergent(*I) && "Worklist invariant violated!"); |
||
1163 | pushUsers(*I); |
||
1164 | } |
||
1165 | } |
||
1166 | |||
1167 | template <typename ContextT> |
||
1168 | bool GenericUniformityAnalysisImpl<ContextT>::isAlwaysUniform( |
||
1169 | const InstructionT &Instr) const { |
||
1170 | return UniformOverrides.contains(&Instr); |
||
1171 | } |
||
1172 | |||
1173 | template <typename ContextT> |
||
1174 | GenericUniformityInfo<ContextT>::GenericUniformityInfo( |
||
1175 | FunctionT &Func, const DominatorTreeT &DT, const CycleInfoT &CI, |
||
1176 | const TargetTransformInfo *TTI) |
||
1177 | : F(&Func) { |
||
1178 | DA.reset(new ImplT{Func, DT, CI, TTI}); |
||
1179 | DA->initialize(); |
||
1180 | DA->compute(); |
||
1181 | } |
||
1182 | |||
1183 | template <typename ContextT> |
||
1184 | void GenericUniformityAnalysisImpl<ContextT>::print(raw_ostream &OS) const { |
||
1185 | bool haveDivergentArgs = false; |
||
1186 | |||
1187 | if (DivergentValues.empty()) { |
||
1188 | assert(DivergentTermBlocks.empty()); |
||
1189 | assert(DivergentExitCycles.empty()); |
||
1190 | OS << "ALL VALUES UNIFORM\n"; |
||
1191 | return; |
||
1192 | } |
||
1193 | |||
1194 | for (const auto &entry : DivergentValues) { |
||
1195 | const BlockT *parent = Context.getDefBlock(entry); |
||
1196 | if (!parent) { |
||
1197 | if (!haveDivergentArgs) { |
||
1198 | OS << "DIVERGENT ARGUMENTS:\n"; |
||
1199 | haveDivergentArgs = true; |
||
1200 | } |
||
1201 | OS << " DIVERGENT: " << Context.print(entry) << '\n'; |
||
1202 | } |
||
1203 | } |
||
1204 | |||
1205 | if (!AssumedDivergent.empty()) { |
||
1206 | OS << "CYCLES ASSSUMED DIVERGENT:\n"; |
||
1207 | for (const CycleT *cycle : AssumedDivergent) { |
||
1208 | OS << " " << cycle->print(Context) << '\n'; |
||
1209 | } |
||
1210 | } |
||
1211 | |||
1212 | if (!DivergentExitCycles.empty()) { |
||
1213 | OS << "CYCLES WITH DIVERGENT EXIT:\n"; |
||
1214 | for (const CycleT *cycle : DivergentExitCycles) { |
||
1215 | OS << " " << cycle->print(Context) << '\n'; |
||
1216 | } |
||
1217 | } |
||
1218 | |||
1219 | for (auto &block : F) { |
||
1220 | OS << "\nBLOCK " << Context.print(&block) << '\n'; |
||
1221 | |||
1222 | OS << "DEFINITIONS\n"; |
||
1223 | SmallVector<ConstValueRefT, 16> defs; |
||
1224 | Context.appendBlockDefs(defs, block); |
||
1225 | for (auto value : defs) { |
||
1226 | if (isDivergent(value)) |
||
1227 | OS << " DIVERGENT: "; |
||
1228 | else |
||
1229 | OS << " "; |
||
1230 | OS << Context.print(value) << '\n'; |
||
1231 | } |
||
1232 | |||
1233 | OS << "TERMINATORS\n"; |
||
1234 | SmallVector<const InstructionT *, 8> terms; |
||
1235 | Context.appendBlockTerms(terms, block); |
||
1236 | bool divergentTerminators = hasDivergentTerminator(block); |
||
1237 | for (auto *T : terms) { |
||
1238 | if (divergentTerminators) |
||
1239 | OS << " DIVERGENT: "; |
||
1240 | else |
||
1241 | OS << " "; |
||
1242 | OS << Context.print(T) << '\n'; |
||
1243 | } |
||
1244 | |||
1245 | OS << "END BLOCK\n"; |
||
1246 | } |
||
1247 | } |
||
1248 | |||
1249 | template <typename ContextT> |
||
1250 | bool GenericUniformityInfo<ContextT>::hasDivergence() const { |
||
1251 | return DA->hasDivergence(); |
||
1252 | } |
||
1253 | |||
1254 | /// Whether \p V is divergent at its definition. |
||
1255 | template <typename ContextT> |
||
1256 | bool GenericUniformityInfo<ContextT>::isDivergent(ConstValueRefT V) const { |
||
1257 | return DA->isDivergent(V); |
||
1258 | } |
||
1259 | |||
1260 | template <typename ContextT> |
||
1261 | bool GenericUniformityInfo<ContextT>::hasDivergentTerminator(const BlockT &B) { |
||
1262 | return DA->hasDivergentTerminator(B); |
||
1263 | } |
||
1264 | |||
1265 | /// \brief T helper function for printing. |
||
1266 | template <typename ContextT> |
||
1267 | void GenericUniformityInfo<ContextT>::print(raw_ostream &out) const { |
||
1268 | DA->print(out); |
||
1269 | } |
||
1270 | |||
/// Drains \p Stack depth-first, appending blocks to the modified post order.
///
/// Only successors contained in \p Cycle are pushed (any successor when
/// \p Cycle is null). A block is finalized and appended once none of its
/// in-scope successors remain unfinalized.
///
/// A block that lies in a cycle nested inside \p Cycle is not visited on its
/// own. Instead, the child of \p Cycle that contains it is treated as a single
/// node: its exits inside \p Cycle are finished first, then computeCyclePO()
/// orders the child cycle's blocks.
template <typename ContextT>
void llvm::ModifiedPostOrder<ContextT>::computeStackPO(
    SmallVectorImpl<BlockT *> &Stack, const CycleInfoT &CI, const CycleT *Cycle,
    SmallPtrSetImpl<BlockT *> &Finalized) {
  LLVM_DEBUG(dbgs() << "inside computeStackPO\n");
  while (!Stack.empty()) {
    auto *NextBB = Stack.back();
    // Pushes only check Finalized, so a block can be on the stack more than
    // once. Stale copies are discarded here.
    if (Finalized.count(NextBB)) {
      Stack.pop_back();
      continue;
    }
    LLVM_DEBUG(dbgs() << " visiting " << CI.getSSAContext().print(NextBB)
                      << "\n");
    auto *NestedCycle = CI.getCycle(NextBB);
    // NextBB is in a cycle strictly nested inside Cycle (or in any cycle at
    // all when Cycle is null).
    if (Cycle != NestedCycle && (!Cycle || Cycle->contains(NestedCycle))) {
      LLVM_DEBUG(dbgs() << " found a cycle\n");
      // Climb to the child of Cycle that contains NextBB.
      while (NestedCycle->getParentCycle() != Cycle)
        NestedCycle = NestedCycle->getParentCycle();

      // Push the child cycle's unfinished exits (within Cycle) first.
      SmallVector<BlockT *, 3> NestedExits;
      NestedCycle->getExitBlocks(NestedExits);
      bool PushedNodes = false;
      for (auto *NestedExitBB : NestedExits) {
        LLVM_DEBUG(dbgs() << " examine exit: "
                          << CI.getSSAContext().print(NestedExitBB) << "\n");
        if (Cycle && !Cycle->contains(NestedExitBB))
          continue;
        if (Finalized.count(NestedExitBB))
          continue;
        PushedNodes = true;
        Stack.push_back(NestedExitBB);
        LLVM_DEBUG(dbgs() << " pushed exit: "
                          << CI.getSSAContext().print(NestedExitBB) << "\n");
      }
      if (!PushedNodes) {
        // All loop exits finalized -> finish this node by ordering the
        // whole child cycle.
        Stack.pop_back();
        computeCyclePO(CI, NestedCycle, Finalized);
      }
      continue;
    }

    LLVM_DEBUG(dbgs() << " no nested cycle, going into DAG\n");
    // DAG-style: push the unfinished successors that are within Cycle.
    bool PushedNodes = false;
    for (auto *SuccBB : successors(NextBB)) {
      LLVM_DEBUG(dbgs() << " examine succ: "
                        << CI.getSSAContext().print(SuccBB) << "\n");
      if (Cycle && !Cycle->contains(SuccBB))
        continue;
      if (Finalized.count(SuccBB))
        continue;
      PushedNodes = true;
      Stack.push_back(SuccBB);
      LLVM_DEBUG(dbgs() << " pushed succ: " << CI.getSSAContext().print(SuccBB)
                        << "\n");
    }
    if (!PushedNodes) {
      // Every in-scope successor is finalized: finish this node. Finalized
      // keeps it from being appended twice.
      LLVM_DEBUG(dbgs() << " finishing node: "
                        << CI.getSSAContext().print(NextBB) << "\n");
      Stack.pop_back();
      Finalized.insert(NextBB);
      appendBlock(*NextBB);
    }
  }
  LLVM_DEBUG(dbgs() << "exited computeStackPO\n");
}
||
1339 | |||
1340 | template <typename ContextT> |
||
1341 | void ModifiedPostOrder<ContextT>::computeCyclePO( |
||
1342 | const CycleInfoT &CI, const CycleT *Cycle, |
||
1343 | SmallPtrSetImpl<BlockT *> &Finalized) { |
||
1344 | LLVM_DEBUG(dbgs() << "inside computeCyclePO\n"); |
||
1345 | SmallVector<BlockT *> Stack; |
||
1346 | auto *CycleHeader = Cycle->getHeader(); |
||
1347 | |||
1348 | LLVM_DEBUG(dbgs() << " noted header: " |
||
1349 | << CI.getSSAContext().print(CycleHeader) << "\n"); |
||
1350 | assert(!Finalized.count(CycleHeader)); |
||
1351 | Finalized.insert(CycleHeader); |
||
1352 | |||
1353 | // Visit the header last |
||
1354 | LLVM_DEBUG(dbgs() << " finishing header: " |
||
1355 | << CI.getSSAContext().print(CycleHeader) << "\n"); |
||
1356 | appendBlock(*CycleHeader, Cycle->isReducible()); |
||
1357 | |||
1358 | // Initialize with immediate successors |
||
1359 | for (auto *BB : successors(CycleHeader)) { |
||
1360 | LLVM_DEBUG(dbgs() << " examine succ: " << CI.getSSAContext().print(BB) |
||
1361 | << "\n"); |
||
1362 | if (!Cycle->contains(BB)) |
||
1363 | continue; |
||
1364 | if (BB == CycleHeader) |
||
1365 | continue; |
||
1366 | if (!Finalized.count(BB)) { |
||
1367 | LLVM_DEBUG(dbgs() << " pushed succ: " << CI.getSSAContext().print(BB) |
||
1368 | << "\n"); |
||
1369 | Stack.push_back(BB); |
||
1370 | } |
||
1371 | } |
||
1372 | |||
1373 | // Compute PO inside region |
||
1374 | computeStackPO(Stack, CI, Cycle, Finalized); |
||
1375 | |||
1376 | LLVM_DEBUG(dbgs() << "exited computeCyclePO\n"); |
||
1377 | } |
||
1378 | |||
1379 | /// \brief Generically compute the modified post order. |
||
1380 | template <typename ContextT> |
||
1381 | void llvm::ModifiedPostOrder<ContextT>::compute(const CycleInfoT &CI) { |
||
1382 | SmallPtrSet<BlockT *, 32> Finalized; |
||
1383 | SmallVector<BlockT *> Stack; |
||
1384 | auto *F = CI.getFunction(); |
||
1385 | Stack.reserve(24); // FIXME made-up number |
||
1386 | Stack.push_back(GraphTraits<FunctionT *>::getEntryNode(F)); |
||
1387 | computeStackPO(Stack, CI, nullptr, Finalized); |
||
1388 | } |
||
1389 | |||
1390 | } // namespace llvm |
||
1391 | |||
1392 | #undef DEBUG_TYPE |
||
1393 | |||
1394 | #endif // LLVM_ADT_GENERICUNIFORMITYIMPL_H |