Subversion Repositories QNX 8.QNX8 LLVM/Clang compiler suite


Rev Author Line No. Line
14 pmbaty 1
//===- GenericUniformAnalysis.cpp --------------------*- C++ -*------------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
// This template implementation resides in a separate file so that it
10
// does not get injected into every .cpp file that includes the
11
// generic header.
12
//
13
// DO NOT INCLUDE THIS FILE WHEN MERELY USING UNIFORMITYINFO.
14
//
15
// This file should only be included by files that implement a
16
// specialization of the relevant templates. Currently these are:
17
// - UniformityAnalysis.cpp
18
//
19
// Note: The DEBUG_TYPE macro should be defined before using this
20
// file so that any use of LLVM_DEBUG is associated with the
21
// including file rather than this file.
22
//
23
//===----------------------------------------------------------------------===//
24
///
25
/// \file
26
/// \brief Implementation of uniformity analysis.
27
///
28
/// The algorithm is a fixed point iteration that starts with the assumption
29
/// that all control flow and all values are uniform. Starting from sources of
30
/// divergence (whose discovery must be implemented by a CFG- or even
31
/// target-specific derived class), divergence of values is propagated from
32
/// definition to uses in a straight-forward way. The main complexity lies in
33
/// the propagation of the impact of divergent control flow on the divergence of
34
/// values (sync dependencies).
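///
/// For illustration only: the generic analysis does not itself know what a
/// divergence source is; a derived specialization reports them (a GPU
/// target would typically report a thread-id read such as
/// call i32 @llvm.amdgcn.workitem.id.x()). Given IR along the lines of
///
///   %tid  = call i32 @llvm.amdgcn.workitem.id.x()
///   %v    = add i32 %tid, 1
///   %cond = icmp slt i32 %v, 10
///   br i1 %cond, label %then, label %else
///
/// the iteration marks %v and %cond divergent by def-use propagation, and
/// the divergent branch is then handed to the sync dependence machinery
/// described below to find the phi nodes it affects.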
35
///
36
//===----------------------------------------------------------------------===//
37
 
38
#ifndef LLVM_ADT_GENERICUNIFORMITYIMPL_H
39
#define LLVM_ADT_GENERICUNIFORMITYIMPL_H
40
 
41
#include "llvm/ADT/GenericUniformityInfo.h"
42
 
43
#include "llvm/ADT/SmallPtrSet.h"
44
#include "llvm/ADT/SparseBitVector.h"
45
#include "llvm/ADT/StringExtras.h"
46
#include "llvm/Support/raw_ostream.h"
47
 
48
#include <set>
49
 
50
#define DEBUG_TYPE "uniformity"
51
 
52
using namespace llvm;
53
 
54
namespace llvm {
55
 
56
template <typename Range> auto unique(Range &&R) {
57
  return std::unique(adl_begin(R), adl_end(R));
58
}
59
 
60
/// Construct a specially modified post-order traversal of cycles.
61
///
62
/// The ModifiedPO is constructed using a virtually modified CFG as follows:
63
///
64
/// 1. The successors of pre-entry nodes (predecessors of a cycle
65
///    entry that are outside the cycle) are replaced by the
66
///    successors of the successors of the header.
67
/// 2. Successors of the cycle header are replaced by the exit blocks
68
///    of the cycle.
69
///
70
/// Effectively, we produce a depth-first numbering with the following
71
/// properties:
72
///
73
/// 1. Nodes after a cycle are numbered earlier than the cycle header.
74
/// 2. The header is numbered earlier than the nodes in the cycle.
75
/// 3. The numbering of the nodes within the cycle forms an interval
76
///    starting with the header.
77
///
78
/// Effectively, the virtual modification arranges the nodes in a
79
/// cycle as a DAG with the header as the sole leaf, and successors of
80
/// the header as the roots. A reverse traversal of this numbering has
81
/// the following invariant on the unmodified original CFG:
82
///
83
///    Each node is visited after all its predecessors, except if that
84
///    predecessor is the cycle header.
85
///
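///
/// A small hypothetical example (for exposition only): for the CFG
/// Entry -> H, H -> B, B -> {H, X}, X -> End, the blocks {H, B} form a
/// reducible cycle with header H and exit X, and the computed numbering is
///
///    End = 0, X = 1, H = 2, B = 3, Entry = 4
///
/// so the exit and everything after it are numbered before the header
/// (property 1), the header is numbered before the cycle body (property 2),
/// and the cycle occupies the interval [2, 3] starting at its header
/// (property 3).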
86
template <typename ContextT> class ModifiedPostOrder {
87
public:
88
  using BlockT = typename ContextT::BlockT;
89
  using FunctionT = typename ContextT::FunctionT;
90
  using DominatorTreeT = typename ContextT::DominatorTreeT;
91
 
92
  using CycleInfoT = GenericCycleInfo<ContextT>;
93
  using CycleT = typename CycleInfoT::CycleT;
94
  using const_iterator = typename std::vector<BlockT *>::const_iterator;
95
 
96
  ModifiedPostOrder(const ContextT &C) : Context(C) {}
97
 
98
  bool empty() const { return m_order.empty(); }
99
  size_t size() const { return m_order.size(); }
100
 
101
  void clear() { m_order.clear(); }
102
  void compute(const CycleInfoT &CI);
103
 
104
  unsigned count(BlockT *BB) const { return POIndex.count(BB); }
105
  const BlockT *operator[](size_t idx) const { return m_order[idx]; }
106
 
107
  void appendBlock(const BlockT &BB, bool isReducibleCycleHeader = false) {
108
    POIndex[&BB] = m_order.size();
109
    m_order.push_back(&BB);
110
    LLVM_DEBUG(dbgs() << "ModifiedPO(" << POIndex[&BB]
111
                      << "): " << Context.print(&BB) << "\n");
112
    if (isReducibleCycleHeader)
113
      ReducibleCycleHeaders.insert(&BB);
114
  }
115
 
116
  unsigned getIndex(const BlockT *BB) const {
117
    assert(POIndex.count(BB));
118
    return POIndex.lookup(BB);
119
  }
120
 
121
  bool isReducibleCycleHeader(const BlockT *BB) const {
122
    return ReducibleCycleHeaders.contains(BB);
123
  }
124
 
125
private:
126
  SmallVector<const BlockT *> m_order;
127
  DenseMap<const BlockT *, unsigned> POIndex;
128
  SmallPtrSet<const BlockT *, 32> ReducibleCycleHeaders;
129
  const ContextT &Context;
130
 
131
  void computeCyclePO(const CycleInfoT &CI, const CycleT *Cycle,
132
                      SmallPtrSetImpl<BlockT *> &Finalized);
133
 
134
  void computeStackPO(SmallVectorImpl<BlockT *> &Stack, const CycleInfoT &CI,
135
                      const CycleT *Cycle,
136
                      SmallPtrSetImpl<BlockT *> &Finalized);
137
};
138
 
139
template <typename> class DivergencePropagator;
140
 
141
/// \class GenericSyncDependenceAnalysis
142
///
143
/// \brief Locate join blocks for disjoint paths starting at a divergent branch.
144
///
145
/// An analysis per divergent branch that returns the set of basic
146
/// blocks whose phi nodes become divergent due to divergent control.
147
/// These are the blocks that are reachable by two disjoint paths from
148
/// the branch, or cycle exits reachable along a path that is disjoint
149
/// from a path to the cycle latch.
150
 
151
152
//
153
// Originally implemented in SyncDependenceAnalysis.cpp for DivergenceAnalysis.
154
//
155
// The SyncDependenceAnalysis is used in the UniformityAnalysis to model
156
// control-induced divergence in phi nodes.
157
//
158
// -- Reference --
159
// The algorithm is an extension of Section 5 of
160
//
161
//   An abstract interpretation for SPMD divergence
162
//       on reducible control flow graphs.
163
//   Julian Rosemann, Simon Moll and Sebastian Hack
164
//   POPL '21
165
//
166
//
167
// -- Sync dependence --
168
// Sync dependence characterizes the control flow aspect of the
169
// propagation of branch divergence. For example,
170
//
171
//   %cond = icmp slt i32 %tid, 10
172
//   br i1 %cond, label %then, label %else
173
// then:
174
//   br label %merge
175
// else:
176
//   br label %merge
177
// merge:
178
//   %a = phi i32 [ 0, %then ], [ 1, %else ]
179
//
180
// Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
181
// because %tid is not on its use-def chains, %a is sync dependent on %tid
182
// because the branch "br i1 %cond" depends on %tid and affects which value %a
183
// is assigned to.
184
//
185
//
186
// -- Reduction to SSA construction --
187
// There are two disjoint paths from A to X, if a certain variant of SSA
188
// construction places a phi node in X under the following set-up scheme.
189
//
190
// This variant of SSA construction ignores incoming undef values.
191
// That is, paths from the entry without a definition do not result in
192
// phi nodes.
193
//
194
//       entry
195
//     /      \
196
//    A        \
197
//  /   \       Y
198
// B     C     /
199
//  \   /  \  /
200
//    D     E
201
//     \   /
202
//       F
203
//
204
// Assume that A contains a divergent branch. We are interested
205
// in the set of all blocks where each block is reachable from A
206
// via two disjoint paths. This would be the set {D, F} in this
207
// case.
208
// To generally reduce this query to SSA construction we introduce
209
// a virtual variable x and assign to x different values in each
210
// successor block of A.
211
//
212
//           entry
213
//         /      \
214
//        A        \
215
//      /   \       Y
216
// x = 0   x = 1   /
217
//      \  /   \  /
218
//        D     E
219
//         \   /
220
//           F
221
//
222
// Our flavor of SSA construction for x will construct the following
223
//
224
//            entry
225
//          /      \
226
//         A        \
227
//       /   \       Y
228
// x0 = 0   x1 = 1  /
229
//       \   /   \ /
230
//     x2 = phi   E
231
//         \     /
232
//         x3 = phi
233
//
234
// The blocks D and F contain phi nodes and are thus each reachable
235
// by two disjoint paths from A.
236
//
237
// -- Remarks --
238
// * In case of cycle exits we need to check for temporal divergence.
239
//   To this end, we check whether the definition of x differs between the
240
//   cycle exit and the cycle header (_after_ SSA construction).
241
//
242
// * In the presence of irreducible control flow, the fixed point is
243
//   reached only after multiple iterations. This is because labels
244
//   reaching the header of a cycle must be repropagated through the
245
//   cycle. This is true even in a reducible cycle, since the labels
246
//   may have been produced by a nested irreducible cycle.
247
//
248
// * Note that SyncDependenceAnalysis is not concerned with the points
249
//   of convergence in an irreducible cycle. Its only purpose is to
250
//   identify join blocks. The "diverged entry" criterion is
251
//   separately applied on join blocks to determine if an entire
252
//   irreducible cycle is assumed to be divergent.
253
//
254
// * Relevant related work:
255
//     A simple algorithm for global data flow analysis problems.
256
//     Matthew S. Hecht and Jeffrey D. Ullman.
257
//     SIAM Journal on Computing, 4(4):519–532, December 1975.
258
//
259
template <typename ContextT> class GenericSyncDependenceAnalysis {
260
public:
261
  using BlockT = typename ContextT::BlockT;
262
  using DominatorTreeT = typename ContextT::DominatorTreeT;
263
  using FunctionT = typename ContextT::FunctionT;
264
  using ValueRefT = typename ContextT::ValueRefT;
265
  using InstructionT = typename ContextT::InstructionT;
266
 
267
  using CycleInfoT = GenericCycleInfo<ContextT>;
268
  using CycleT = typename CycleInfoT::CycleT;
269
 
270
  using ConstBlockSet = SmallPtrSet<const BlockT *, 4>;
271
  using ModifiedPO = ModifiedPostOrder<ContextT>;
272
 
273
  // * if BlockLabels[B] == C then C is the dominating definition at
274
  //   block B
275
  // * if BlockLabels[B] == nullptr then we haven't seen B yet
276
  // * if BlockLabels[B] == B then:
277
  //   - B is a join point of disjoint paths from X, or,
278
  //   - B is an immediate successor of X (initial value), or,
279
  //   - B is X
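  //
  // For illustration, in the example CFG from the comment block above
  // (A branches to B and C, whose paths rejoin at D), propagation from
  // X = A leaves BlockLabels[B] == B and BlockLabels[C] == C (immediate
  // successors of X) and BlockLabels[D] == D (join of disjoint paths),
  // while a block reachable from only one side, such as E, keeps that
  // side's label (here C).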
280
  using BlockLabelMap = DenseMap<const BlockT *, const BlockT *>;
281
 
282
  /// Information discovered by the sync dependence analysis for each
283
  /// divergent branch.
284
  struct DivergenceDescriptor {
285
    // Join points of diverged paths.
286
    ConstBlockSet JoinDivBlocks;
287
    // Divergent cycle exits
288
    ConstBlockSet CycleDivBlocks;
289
    // Labels assigned to blocks on diverged paths.
290
    BlockLabelMap BlockLabels;
291
  };
292
 
293
  using DivergencePropagatorT = DivergencePropagator<ContextT>;
294
 
295
  GenericSyncDependenceAnalysis(const ContextT &Context,
296
                                const DominatorTreeT &DT, const CycleInfoT &CI);
297
 
298
  /// \brief Computes divergent join points and cycle exits caused by branch
299
  /// divergence in \p Term.
300
  ///
301
  /// This returns a pair of sets:
302
  /// * The set of blocks which are reachable by disjoint paths from
303
  ///   \p Term.
304
  ///   * The set also contains cycle exits if there are two disjoint paths:
305
  ///   one from \p Term to the cycle exit and another from \p Term to
306
  ///   the cycle header.
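  ///
  /// Illustrative call pattern (this is how analyzeControlDivergence later
  /// in this file consumes the result; the names are the ones used there):
  ///
  ///   const auto &DivDesc = SDA.getJoinBlocks(DivTermBlock);
  ///   for (const auto *JoinBlock : DivDesc.JoinDivBlocks) {
  ///     // taint phi nodes at the join
  ///   }
  ///   for (const auto *DivExitBlock : DivDesc.CycleDivBlocks) {
  ///     // propagate divergence out of the exited cycle
  ///   }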
307
  const DivergenceDescriptor &getJoinBlocks(const BlockT *DivTermBlock);
308
 
309
private:
310
  static DivergenceDescriptor EmptyDivergenceDesc;
311
 
312
  ModifiedPO CyclePO;
313
 
314
  const DominatorTreeT &DT;
315
  const CycleInfoT &CI;
316
 
317
  DenseMap<const BlockT *, std::unique_ptr<DivergenceDescriptor>>
318
      CachedControlDivDescs;
319
};
320
 
321
/// \brief Analysis that identifies uniform values in a data-parallel
322
/// execution.
323
///
324
/// This analysis propagates divergence in a data-parallel context
325
/// from sources of divergence to all users. It can be instantiated
326
/// for an IR that provides a suitable SSAContext.
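///
/// A minimal usage sketch (it mirrors what GenericUniformityInfo does later
/// in this file; SomeContext is a stand-in for a concrete SSAContext
/// specialization and is assumed here only for illustration):
///
///   GenericUniformityAnalysisImpl<SomeContext> Impl(F, DT, CI, TTI);
///   Impl.initialize();   // the specialization seeds sources of divergence
///   Impl.compute();      // fixed-point propagation over the worklist
///   if (Impl.hasDivergence())
///     Impl.print(errs());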
327
template <typename ContextT> class GenericUniformityAnalysisImpl {
328
public:
329
  using BlockT = typename ContextT::BlockT;
330
  using FunctionT = typename ContextT::FunctionT;
331
  using ValueRefT = typename ContextT::ValueRefT;
332
  using ConstValueRefT = typename ContextT::ConstValueRefT;
333
  using InstructionT = typename ContextT::InstructionT;
334
  using DominatorTreeT = typename ContextT::DominatorTreeT;
335
 
336
  using CycleInfoT = GenericCycleInfo<ContextT>;
337
  using CycleT = typename CycleInfoT::CycleT;
338
 
339
  using SyncDependenceAnalysisT = GenericSyncDependenceAnalysis<ContextT>;
340
  using DivergenceDescriptorT =
341
      typename SyncDependenceAnalysisT::DivergenceDescriptor;
342
  using BlockLabelMapT = typename SyncDependenceAnalysisT::BlockLabelMap;
343
 
344
  GenericUniformityAnalysisImpl(const FunctionT &F, const DominatorTreeT &DT,
345
                                const CycleInfoT &CI,
346
                                const TargetTransformInfo *TTI)
347
      : Context(CI.getSSAContext()), F(F), CI(CI), TTI(TTI), DT(DT),
348
        SDA(Context, DT, CI) {}
349
 
350
  void initialize();
351
 
352
  const FunctionT &getFunction() const { return F; }
353
 
354
  /// \brief Mark \p Instr as a value that is always uniform.
355
  void addUniformOverride(const InstructionT &Instr);
356
 
357
  /// \brief Mark \p DivVal as a value that is always divergent.
358
  /// \returns Whether the tracked divergence state of \p DivVal changed.
359
  bool markDivergent(const InstructionT &I);
360
  bool markDivergent(ConstValueRefT DivVal);
361
  bool markDefsDivergent(const InstructionT &Instr,
362
                         bool AllDefsDivergent = true);
363
 
364
  /// \brief Propagate divergence to all instructions in the region.
365
  /// Divergence is seeded by calls to \p markDivergent.
366
  void compute();
367
 
368
  /// \brief Whether any value was marked or analyzed to be divergent.
369
  bool hasDivergence() const { return !DivergentValues.empty(); }
370
 
371
  /// \brief Whether \p Instr will always return a uniform value regardless of its
372
  /// operands.
373
  bool isAlwaysUniform(const InstructionT &Instr) const;
374
 
375
  bool hasDivergentDefs(const InstructionT &I) const;
376
 
377
  bool isDivergent(const InstructionT &I) const {
378
    if (I.isTerminator()) {
379
      return DivergentTermBlocks.contains(I.getParent());
380
    }
381
    return hasDivergentDefs(I);
382
  }
383
 
384
  /// \brief Whether \p Val is divergent at its definition.
385
  bool isDivergent(ConstValueRefT V) const { return DivergentValues.count(V); }
386
 
387
  bool hasDivergentTerminator(const BlockT &B) const {
388
    return DivergentTermBlocks.contains(&B);
389
  }
390
 
391
  void print(raw_ostream &out) const;
392
 
393
protected:
394
  /// \brief Value/block pair representing a single phi input.
395
  struct PhiInput {
396
    ConstValueRefT value;
397
    BlockT *predBlock;
398
 
399
    PhiInput(ConstValueRefT value, BlockT *predBlock)
400
        : value(value), predBlock(predBlock) {}
401
  };
402
 
403
  const ContextT &Context;
404
  const FunctionT &F;
405
  const CycleInfoT &CI;
406
  const TargetTransformInfo *TTI = nullptr;
407
 
408
  // Detected/marked divergent values.
409
  std::set<ConstValueRefT> DivergentValues;
410
  SmallPtrSet<const BlockT *, 32> DivergentTermBlocks;
411
 
412
  // Internal worklist for divergence propagation.
413
  std::vector<const InstructionT *> Worklist;
414
 
415
  /// \brief Mark \p Term as divergent and push all Instructions that become
416
  /// divergent as a result on the worklist.
417
  void analyzeControlDivergence(const InstructionT &Term);
418
 
419
private:
420
  const DominatorTreeT &DT;
421
 
422
  // Recognized cycles with divergent exits.
423
  SmallPtrSet<const CycleT *, 16> DivergentExitCycles;
424
 
425
  // Cycles assumed to be divergent.
426
  //
427
  // We don't use a set here because every insertion needs an explicit
428
  // traversal of all existing members.
429
  SmallVector<const CycleT *> AssumedDivergent;
430
 
431
  // The SDA links divergent branches to divergent control-flow joins.
432
  SyncDependenceAnalysisT SDA;
433
 
434
  // Set of known-uniform values.
435
  SmallPtrSet<const InstructionT *, 32> UniformOverrides;
436
 
437
  /// \brief Mark all nodes in \p JoinBlock as divergent and push them on
438
  /// the worklist.
439
  void taintAndPushAllDefs(const BlockT &JoinBlock);
440
 
441
  /// \brief Mark all phi nodes in \p JoinBlock as divergent and push them on
442
  /// the worklist.
443
  void taintAndPushPhiNodes(const BlockT &JoinBlock);
444
 
445
  /// \brief Identify all Instructions that become divergent because \p DivExit
446
  /// is a divergent cycle exit of \p DivCycle. Mark those instructions as
447
  /// divergent and push them on the worklist.
448
  void propagateCycleExitDivergence(const BlockT &DivExit,
449
                                    const CycleT &DivCycle);
450
 
451
  /// \brief Internal implementation function for propagateCycleExitDivergence.
452
  void analyzeCycleExitDivergence(const CycleT &OuterDivCycle);
453
 
454
  /// \brief Mark all instructions as divergent that use a value defined in \p
455
  /// OuterDivCycle. Push their users on the worklist.
456
  void analyzeTemporalDivergence(const InstructionT &I,
457
                                 const CycleT &OuterDivCycle);
458
 
459
  /// \brief Push all users of \p Val (in the region) to the worklist.
460
  void pushUsers(const InstructionT &I);
461
  void pushUsers(ConstValueRefT V);
462
 
463
  bool usesValueFromCycle(const InstructionT &I, const CycleT &DefCycle) const;
464
 
465
  /// \brief Whether \p Val is divergent when read in \p ObservingBlock.
466
  bool isTemporalDivergent(const BlockT &ObservingBlock,
467
                           ConstValueRefT Val) const;
468
};
469
 
470
template <typename ImplT>
471
void GenericUniformityAnalysisImplDeleter<ImplT>::operator()(ImplT *Impl) {
472
  delete Impl;
473
}
474
 
475
/// Compute divergence starting with a divergent branch.
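///
/// Typical use, mirroring GenericSyncDependenceAnalysis::getJoinBlocks
/// (defined later in this file), which is the only intended caller:
///
///   DivergencePropagator<ContextT> Propagator(CyclePO, DT, CI, *DivTermBlock);
///   auto DivDesc = Propagator.computeJoinPoints();
///   // DivDesc->JoinDivBlocks and DivDesc->CycleDivBlocks now hold the
///   // join blocks and divergent cycle exits seen from DivTermBlock.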
476
template <typename ContextT> class DivergencePropagator {
477
public:
478
  using BlockT = typename ContextT::BlockT;
479
  using DominatorTreeT = typename ContextT::DominatorTreeT;
480
  using FunctionT = typename ContextT::FunctionT;
481
  using ValueRefT = typename ContextT::ValueRefT;
482
 
483
  using CycleInfoT = GenericCycleInfo<ContextT>;
484
  using CycleT = typename CycleInfoT::CycleT;
485
 
486
  using ModifiedPO = ModifiedPostOrder<ContextT>;
487
  using SyncDependenceAnalysisT = GenericSyncDependenceAnalysis<ContextT>;
488
  using DivergenceDescriptorT =
489
      typename SyncDependenceAnalysisT::DivergenceDescriptor;
490
  using BlockLabelMapT = typename SyncDependenceAnalysisT::BlockLabelMap;
491
 
492
  const ModifiedPO &CyclePOT;
493
  const DominatorTreeT &DT;
494
  const CycleInfoT &CI;
495
  const BlockT &DivTermBlock;
496
  const ContextT &Context;
497
 
498
  // Track blocks that receive a new label. Every time we relabel a
499
  // cycle header, we make another pass over the modified post-order in
500
  // order to propagate the header label. The bit vector also allows
501
  // us to skip labels that have not changed.
502
  SparseBitVector<> FreshLabels;
503
 
504
  // divergent join and cycle exit descriptor.
505
  std::unique_ptr<DivergenceDescriptorT> DivDesc;
506
  BlockLabelMapT &BlockLabels;
507
 
508
  DivergencePropagator(const ModifiedPO &CyclePOT, const DominatorTreeT &DT,
509
                       const CycleInfoT &CI, const BlockT &DivTermBlock)
510
      : CyclePOT(CyclePOT), DT(DT), CI(CI), DivTermBlock(DivTermBlock),
511
        Context(CI.getSSAContext()), DivDesc(new DivergenceDescriptorT),
512
        BlockLabels(DivDesc->BlockLabels) {}
513
 
514
  void printDefs(raw_ostream &Out) {
515
    Out << "Propagator::BlockLabels {\n";
516
    for (int BlockIdx = (int)CyclePOT.size() - 1; BlockIdx >= 0; --BlockIdx) {
517
      const auto *Block = CyclePOT[BlockIdx];
518
      const auto *Label = BlockLabels[Block];
519
      Out << Context.print(Block) << "(" << BlockIdx << ") : ";
520
      if (!Label) {
521
        Out << "<null>\n";
522
      } else {
523
        Out << Context.print(Label) << "\n";
524
      }
525
    }
526
    Out << "}\n";
527
  }
528
 
529
  // Push a definition (\p PushedLabel) to \p SuccBlock and return whether this
530
  // causes a divergent join.
531
  bool computeJoin(const BlockT &SuccBlock, const BlockT &PushedLabel) {
532
    const auto *OldLabel = BlockLabels[&SuccBlock];
533
 
534
    LLVM_DEBUG(dbgs() << "labeling " << Context.print(&SuccBlock) << ":\n"
535
                      << "\tpushed label: " << Context.print(&PushedLabel)
536
                      << "\n"
537
                      << "\told label: " << Context.print(OldLabel) << "\n");
538
 
539
    // Early exit if there is no change in the label.
540
    if (OldLabel == &PushedLabel)
541
      return false;
542
 
543
    if (OldLabel != &SuccBlock) {
544
      auto SuccIdx = CyclePOT.getIndex(&SuccBlock);
545
      // Assigning a new label, mark this in FreshLabels.
546
      LLVM_DEBUG(dbgs() << "\tfresh label: " << SuccIdx << "\n");
547
      FreshLabels.set(SuccIdx);
548
    }
549
 
550
    // This is not a join if the succ was previously unlabeled.
551
    if (!OldLabel) {
552
      LLVM_DEBUG(dbgs() << "\tnew label: " << Context.print(&PushedLabel)
553
                        << "\n");
554
      BlockLabels[&SuccBlock] = &PushedLabel;
555
      return false;
556
    }
557
 
558
    // This is a new join. Label the join block as itself, and not as
559
    // the pushed label.
560
    LLVM_DEBUG(dbgs() << "\tnew label: " << Context.print(&SuccBlock) << "\n");
561
    BlockLabels[&SuccBlock] = &SuccBlock;
562
 
563
    return true;
564
  }
565
 
566
  // visiting a virtual cycle exit edge from the cycle header --> temporal
567
  // divergence on join
568
  bool visitCycleExitEdge(const BlockT &ExitBlock, const BlockT &Label) {
569
    if (!computeJoin(ExitBlock, Label))
570
      return false;
571
 
572
    // Identified a divergent cycle exit
573
    DivDesc->CycleDivBlocks.insert(&ExitBlock);
574
    LLVM_DEBUG(dbgs() << "\tDivergent cycle exit: " << Context.print(&ExitBlock)
575
                      << "\n");
576
    return true;
577
  }
578
 
579
  // process \p SuccBlock with reaching definition \p Label
580
  bool visitEdge(const BlockT &SuccBlock, const BlockT &Label) {
581
    if (!computeJoin(SuccBlock, Label))
582
      return false;
583
 
584
    // Divergent, disjoint paths join.
585
    DivDesc->JoinDivBlocks.insert(&SuccBlock);
586
    LLVM_DEBUG(dbgs() << "\tDivergent join: " << Context.print(&SuccBlock)
587
                      << "\n");
588
    return true;
589
  }
590
 
591
  std::unique_ptr<DivergenceDescriptorT> computeJoinPoints() {
592
    assert(DivDesc);
593
 
594
    LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints: "
595
                      << Context.print(&DivTermBlock) << "\n");
596
 
597
    // Early stopping criterion
598
    int FloorIdx = CyclePOT.size() - 1;
599
    const BlockT *FloorLabel = nullptr;
600
    int DivTermIdx = CyclePOT.getIndex(&DivTermBlock);
601
 
602
    // Bootstrap with branch targets
603
    auto const *DivTermCycle = CI.getCycle(&DivTermBlock);
604
    for (const auto *SuccBlock : successors(&DivTermBlock)) {
605
      if (DivTermCycle && !DivTermCycle->contains(SuccBlock)) {
606
        // If DivTerm exits the cycle immediately, computeJoin() might
607
        // not reach SuccBlock with a different label. We need to
608
        // check for this exit now.
609
        DivDesc->CycleDivBlocks.insert(SuccBlock);
610
        LLVM_DEBUG(dbgs() << "\tImmediate divergent cycle exit: "
611
                          << Context.print(SuccBlock) << "\n");
612
      }
613
      auto SuccIdx = CyclePOT.getIndex(SuccBlock);
614
      visitEdge(*SuccBlock, *SuccBlock);
615
      FloorIdx = std::min<int>(FloorIdx, SuccIdx);
616
    }
617
 
618
    while (true) {
619
      auto BlockIdx = FreshLabels.find_last();
620
      if (BlockIdx == -1 || BlockIdx < FloorIdx)
621
        break;
622
 
623
      LLVM_DEBUG(dbgs() << "Current labels:\n"; printDefs(dbgs()));
624
 
625
      FreshLabels.reset(BlockIdx);
626
      if (BlockIdx == DivTermIdx) {
627
        LLVM_DEBUG(dbgs() << "Skipping DivTermBlock\n");
628
        continue;
629
      }
630
 
631
      const auto *Block = CyclePOT[BlockIdx];
632
      LLVM_DEBUG(dbgs() << "visiting " << Context.print(Block) << " at index "
633
                        << BlockIdx << "\n");
634
 
635
      const auto *Label = BlockLabels[Block];
636
      assert(Label);
637
 
638
      bool CausedJoin = false;
639
      int LoweredFloorIdx = FloorIdx;
640
 
641
      // If the current block is the header of a reducible cycle that
642
      // contains the divergent branch, then the label should be
643
      // propagated to the cycle exits. Such a header is the "last
644
      // possible join" of any disjoint paths within this cycle. This
645
      // prevents detection of spurious joins at the entries of any
646
      // irreducible child cycles.
647
      //
648
      // This conclusion about the header is true for any choice of DFS:
649
      //
650
      //   If some DFS has a reducible cycle C with header H, then for
651
      //   any other DFS, H is the header of a cycle C' that is a
652
      //   superset of C. For a divergent branch inside the subgraph
653
      //   C, any join node inside C is either H, or some node
654
      //   encountered without passing through H.
655
      //
656
      auto getReducibleParent = [&](const BlockT *Block) -> const CycleT * {
657
        if (!CyclePOT.isReducibleCycleHeader(Block))
658
          return nullptr;
659
        const auto *BlockCycle = CI.getCycle(Block);
660
        if (BlockCycle->contains(&DivTermBlock))
661
          return BlockCycle;
662
        return nullptr;
663
      };
664
 
665
      if (const auto *BlockCycle = getReducibleParent(Block)) {
666
        SmallVector<BlockT *, 4> BlockCycleExits;
667
        BlockCycle->getExitBlocks(BlockCycleExits);
668
        for (auto *BlockCycleExit : BlockCycleExits) {
669
          CausedJoin |= visitCycleExitEdge(*BlockCycleExit, *Label);
670
          LoweredFloorIdx =
671
              std::min<int>(LoweredFloorIdx, CyclePOT.getIndex(BlockCycleExit));
672
        }
673
      } else {
674
        for (const auto *SuccBlock : successors(Block)) {
675
          CausedJoin |= visitEdge(*SuccBlock, *Label);
676
          LoweredFloorIdx =
677
              std::min<int>(LoweredFloorIdx, CyclePOT.getIndex(SuccBlock));
678
        }
679
      }
680
 
681
      // Floor update
682
      if (CausedJoin) {
683
        // 1. Different labels pushed to successors
684
        FloorIdx = LoweredFloorIdx;
685
      } else if (FloorLabel != Label) {
686
        // 2. No join caused BUT we pushed a label that is different than the
687
        // last pushed label
688
        FloorIdx = LoweredFloorIdx;
689
        FloorLabel = Label;
690
      }
691
    }
692
 
693
    LLVM_DEBUG(dbgs() << "Final labeling:\n"; printDefs(dbgs()));
694
 
695
    // Check every cycle containing DivTermBlock for exit divergence.
696
    // A cycle has exit divergence if the label of an exit block does
697
    // not match the label of its header.
698
    for (const auto *Cycle = CI.getCycle(&DivTermBlock); Cycle;
699
         Cycle = Cycle->getParentCycle()) {
700
      if (Cycle->isReducible()) {
701
        // The exit divergence of a reducible cycle is recorded while
702
        // propagating labels.
703
        continue;
704
      }
705
      SmallVector<BlockT *> Exits;
706
      Cycle->getExitBlocks(Exits);
707
      auto *Header = Cycle->getHeader();
708
      auto *HeaderLabel = BlockLabels[Header];
709
      for (const auto *Exit : Exits) {
710
        if (BlockLabels[Exit] != HeaderLabel) {
711
          // Identified a divergent cycle exit
712
          DivDesc->CycleDivBlocks.insert(Exit);
713
          LLVM_DEBUG(dbgs() << "\tDivergent cycle exit: " << Context.print(Exit)
714
                            << "\n");
715
        }
716
      }
717
    }
718
 
719
    return std::move(DivDesc);
720
  }
721
};
722
 
723
template <typename ContextT>
724
typename llvm::GenericSyncDependenceAnalysis<ContextT>::DivergenceDescriptor
725
    llvm::GenericSyncDependenceAnalysis<ContextT>::EmptyDivergenceDesc;
726
 
727
template <typename ContextT>
728
llvm::GenericSyncDependenceAnalysis<ContextT>::GenericSyncDependenceAnalysis(
729
    const ContextT &Context, const DominatorTreeT &DT, const CycleInfoT &CI)
730
    : CyclePO(Context), DT(DT), CI(CI) {
731
  CyclePO.compute(CI);
732
}
733
 
734
template <typename ContextT>
735
auto llvm::GenericSyncDependenceAnalysis<ContextT>::getJoinBlocks(
736
    const BlockT *DivTermBlock) -> const DivergenceDescriptor & {
737
  // trivial case
738
  if (succ_size(DivTermBlock) <= 1) {
739
    return EmptyDivergenceDesc;
740
  }
741
 
742
  // already available in cache?
743
  auto ItCached = CachedControlDivDescs.find(DivTermBlock);
744
  if (ItCached != CachedControlDivDescs.end())
745
    return *ItCached->second;
746
 
747
  // compute all join points
748
  DivergencePropagatorT Propagator(CyclePO, DT, CI, *DivTermBlock);
749
  auto DivDesc = Propagator.computeJoinPoints();
750
 
751
  auto printBlockSet = [&](ConstBlockSet &Blocks) {
752
    return Printable([&](raw_ostream &Out) {
753
      Out << "[";
754
      ListSeparator LS;
755
      for (const auto *BB : Blocks) {
756
        Out << LS << CI.getSSAContext().print(BB);
757
      }
758
      Out << "]\n";
759
    });
760
  };
761
 
762
  LLVM_DEBUG(
763
      dbgs() << "\nResult (" << CI.getSSAContext().print(DivTermBlock)
764
             << "):\n  JoinDivBlocks: " << printBlockSet(DivDesc->JoinDivBlocks)
765
             << "  CycleDivBlocks: " << printBlockSet(DivDesc->CycleDivBlocks)
766
             << "\n");
767
  (void)printBlockSet;
768
 
769
  auto ItInserted =
770
      CachedControlDivDescs.try_emplace(DivTermBlock, std::move(DivDesc));
771
  assert(ItInserted.second);
772
  return *ItInserted.first->second;
773
}
774
 
775
template <typename ContextT>
776
bool GenericUniformityAnalysisImpl<ContextT>::markDivergent(
777
    const InstructionT &I) {
778
  if (I.isTerminator()) {
779
    if (DivergentTermBlocks.insert(I.getParent()).second) {
780
      LLVM_DEBUG(dbgs() << "marked divergent term block: "
781
                        << Context.print(I.getParent()) << "\n");
782
      return true;
783
    }
784
    return false;
785
  }
786
 
787
  return markDefsDivergent(I);
788
}
789
 
790
template <typename ContextT>
791
bool GenericUniformityAnalysisImpl<ContextT>::markDivergent(
792
    ConstValueRefT Val) {
793
  if (DivergentValues.insert(Val).second) {
794
    LLVM_DEBUG(dbgs() << "marked divergent: " << Context.print(Val) << "\n");
795
    return true;
796
  }
797
  return false;
798
}
799
 
800
template <typename ContextT>
801
void GenericUniformityAnalysisImpl<ContextT>::addUniformOverride(
802
    const InstructionT &Instr) {
803
  UniformOverrides.insert(&Instr);
804
}
805
 
806
template <typename ContextT>
807
void GenericUniformityAnalysisImpl<ContextT>::analyzeTemporalDivergence(
808
    const InstructionT &I, const CycleT &OuterDivCycle) {
809
  if (isDivergent(I))
810
    return;
811
 
812
  LLVM_DEBUG(dbgs() << "Analyze temporal divergence: " << Context.print(&I)
813
                    << "\n");
814
  if (!usesValueFromCycle(I, OuterDivCycle))
815
    return;
816
 
817
  if (isAlwaysUniform(I))
818
    return;
819
 
820
  if (markDivergent(I))
821
    Worklist.push_back(&I);
822
}
823
 
824
// Mark all external users of values defined inside \param
825
// OuterDivCycle as divergent.
826
//
827
// This follows all live out edges wherever they may lead. Potential
828
// users of values defined inside DivCycle could be anywhere in the
829
// dominance region of DivCycle (including its fringes for phi nodes).
830
// A cycle C dominates a block B iff every path from the entry block
831
// to B must pass through a block contained in C. If C is a reducible
832
// cycle (or natural loop), C dominates B iff the header of C
833
// dominates B. But in general, we iteratively examine cycle
834
// exits and their successors.
835
template <typename ContextT>
836
void GenericUniformityAnalysisImpl<ContextT>::analyzeCycleExitDivergence(
837
    const CycleT &OuterDivCycle) {
838
  // Set of blocks that are dominated by the cycle, i.e., each is only
839
  // reachable from paths that pass through the cycle.
840
  SmallPtrSet<BlockT *, 16> DomRegion;
841
 
842
  // The boundary of DomRegion, formed by blocks that are not
843
  // dominated by the cycle.
844
  SmallVector<BlockT *> DomFrontier;
845
  OuterDivCycle.getExitBlocks(DomFrontier);
846
 
847
  // Returns true if BB is dominated by the cycle.
848
  auto isInDomRegion = [&](BlockT *BB) {
849
    for (auto *P : predecessors(BB)) {
850
      if (OuterDivCycle.contains(P))
851
        continue;
852
      if (DomRegion.count(P))
853
        continue;
854
      return false;
855
    }
856
    return true;
857
  };
858
 
859
  // Keep advancing the frontier along successor edges, while
860
  // promoting blocks to DomRegion.
861
  while (true) {
862
    bool Promoted = false;
863
    SmallVector<BlockT *> Temp;
864
    for (auto *W : DomFrontier) {
865
      if (!isInDomRegion(W)) {
866
        Temp.push_back(W);
867
        continue;
868
      }
869
      DomRegion.insert(W);
870
      Promoted = true;
871
      for (auto *Succ : successors(W)) {
872
        if (DomRegion.contains(Succ))
873
          continue;
874
        Temp.push_back(Succ);
875
      }
876
    }
877
    if (!Promoted)
878
      break;
879
    DomFrontier = Temp;
880
  }
881
 
882
  // At DomFrontier, only the PHI nodes are affected by temporal
883
  // divergence.
884
  for (const auto *UserBlock : DomFrontier) {
885
    LLVM_DEBUG(dbgs() << "Analyze phis after cycle exit: "
886
                      << Context.print(UserBlock) << "\n");
887
    for (const auto &Phi : UserBlock->phis()) {
888
      LLVM_DEBUG(dbgs() << "  " << Context.print(&Phi) << "\n");
889
      analyzeTemporalDivergence(Phi, OuterDivCycle);
890
    }
891
  }
892
 
893
  // All instructions inside the dominance region are affected by
894
  // temporal divergence.
895
  for (const auto *UserBlock : DomRegion) {
896
    LLVM_DEBUG(dbgs() << "Analyze non-phi users after cycle exit: "
897
                      << Context.print(UserBlock) << "\n");
898
    for (const auto &I : *UserBlock) {
899
      LLVM_DEBUG(dbgs() << "  " << Context.print(&I) << "\n");
900
      analyzeTemporalDivergence(I, OuterDivCycle);
901
    }
902
  }
903
}
904
 
905
template <typename ContextT>
906
void GenericUniformityAnalysisImpl<ContextT>::propagateCycleExitDivergence(
907
    const BlockT &DivExit, const CycleT &InnerDivCycle) {
908
  LLVM_DEBUG(dbgs() << "\tpropCycleExitDiv " << Context.print(&DivExit)
909
                    << "\n");
910
  auto *DivCycle = &InnerDivCycle;
911
  auto *OuterDivCycle = DivCycle;
912
  auto *ExitLevelCycle = CI.getCycle(&DivExit);
913
  const unsigned CycleExitDepth =
914
      ExitLevelCycle ? ExitLevelCycle->getDepth() : 0;
915
 
916
  // Find outer-most cycle that does not contain \p DivExit
917
  while (DivCycle && DivCycle->getDepth() > CycleExitDepth) {
918
    LLVM_DEBUG(dbgs() << "  Found exiting cycle: "
919
                      << Context.print(DivCycle->getHeader()) << "\n");
920
    OuterDivCycle = DivCycle;
921
    DivCycle = DivCycle->getParentCycle();
922
  }
923
  LLVM_DEBUG(dbgs() << "\tOuter-most exiting cycle: "
924
                    << Context.print(OuterDivCycle->getHeader()) << "\n");
925
 
926
  if (!DivergentExitCycles.insert(OuterDivCycle).second)
927
    return;
928
 
929
  // Exit divergence does not matter if the cycle itself is assumed to
930
  // be divergent.
931
  for (const auto *C : AssumedDivergent) {
932
    if (C->contains(OuterDivCycle))
933
      return;
934
  }
935
 
936
  analyzeCycleExitDivergence(*OuterDivCycle);
937
}
938
 
939
template <typename ContextT>
940
void GenericUniformityAnalysisImpl<ContextT>::taintAndPushAllDefs(
941
    const BlockT &BB) {
942
  LLVM_DEBUG(dbgs() << "taintAndPushAllDefs " << Context.print(&BB) << "\n");
943
  for (const auto &I : instrs(BB)) {
944
    // Terminators do not produce values; they are divergent only if
945
    // the condition is divergent. That is handled when the divergent
946
    // condition is placed in the worklist.
947
    if (I.isTerminator())
948
      break;
949
 
950
    // Mark this as divergent. We don't check if the instruction is
951
    // always uniform. In a cycle where the thread convergence is not
952
    // statically known, the instruction is not statically converged,
953
    // and its outputs cannot be statically uniform.
954
    if (markDivergent(I))
955
      Worklist.push_back(&I);
956
  }
957
}
958
 
959
/// Mark divergent phi nodes in a join block
960
template <typename ContextT>
961
void GenericUniformityAnalysisImpl<ContextT>::taintAndPushPhiNodes(
962
    const BlockT &JoinBlock) {
963
  LLVM_DEBUG(dbgs() << "taintAndPushPhiNodes in " << Context.print(&JoinBlock)
964
                    << "\n");
965
  for (const auto &Phi : JoinBlock.phis()) {
966
    if (ContextT::isConstantValuePhi(Phi))
967
      continue;
968
    if (markDivergent(Phi))
969
      Worklist.push_back(&Phi);
970
  }
971
}
972
 
973
/// Add \p Candidate to \p Cycles if it is not already contained in \p Cycles.
974
///
975
/// \return true iff \p Candidate was added to \p Cycles.
976
template <typename CycleT>
977
static bool insertIfNotContained(SmallVector<CycleT *> &Cycles,
978
                                 CycleT *Candidate) {
979
  if (llvm::any_of(Cycles,
980
                   [Candidate](CycleT *C) { return C->contains(Candidate); }))
981
    return false;
982
  Cycles.push_back(Candidate);
983
  return true;
984
}
985
 
986
/// Return the outermost cycle made divergent by branch outside it.
987
///
988
/// If two paths that diverged outside an irreducible cycle join
989
/// inside that cycle, then that whole cycle is assumed to be
990
/// divergent. This does not apply if the cycle is reducible.
991
template <typename CycleT, typename BlockT>
992
static const CycleT *getExtDivCycle(const CycleT *Cycle,
993
                                    const BlockT *DivTermBlock,
994
                                    const BlockT *JoinBlock) {
995
  assert(Cycle);
996
  assert(Cycle->contains(JoinBlock));
997
 
998
  if (Cycle->contains(DivTermBlock))
999
    return nullptr;
1000
 
1001
  if (Cycle->isReducible()) {
1002
    assert(Cycle->getHeader() == JoinBlock);
1003
    return nullptr;
1004
  }
1005
 
1006
  const auto *Parent = Cycle->getParentCycle();
1007
  while (Parent && !Parent->contains(DivTermBlock)) {
1008
    // If the join is inside a child, then the parent must be
1009
    // irreducible. The only join in a reducible cycle is its own
1010
    // header.
1011
    assert(!Parent->isReducible());
1012
    Cycle = Parent;
1013
    Parent = Cycle->getParentCycle();
1014
  }
1015
 
1016
  LLVM_DEBUG(dbgs() << "cycle made divergent by external branch\n");
1017
  return Cycle;
1018
}
1019
 
1020
/// Return the outermost cycle made divergent by branch inside it.
1021
///
1022
/// This checks the "diverged entry" criterion defined in the
1023
/// docs/ConvergenceAnalysis.html.
1024
template <typename ContextT, typename CycleT, typename BlockT,
1025
          typename DominatorTreeT>
1026
static const CycleT *
1027
getIntDivCycle(const CycleT *Cycle, const BlockT *DivTermBlock,
1028
               const BlockT *JoinBlock, const DominatorTreeT &DT,
1029
               ContextT &Context) {
1030
  LLVM_DEBUG(dbgs() << "examine join " << Context.print(JoinBlock)
1031
                    << "for internal branch " << Context.print(DivTermBlock)
1032
                    << "\n");
1033
  if (DT.properlyDominates(DivTermBlock, JoinBlock))
1034
    return nullptr;
1035
 
1036
  // Find the smallest common cycle, if one exists.
1037
  assert(Cycle && Cycle->contains(JoinBlock));
1038
  while (Cycle && !Cycle->contains(DivTermBlock)) {
1039
    Cycle = Cycle->getParentCycle();
1040
  }
1041
  if (!Cycle || Cycle->isReducible())
1042
    return nullptr;
1043
 
1044
  if (DT.properlyDominates(Cycle->getHeader(), JoinBlock))
1045
    return nullptr;
1046
 
1047
  LLVM_DEBUG(dbgs() << "  header " << Context.print(Cycle->getHeader())
1048
                    << " does not dominate join\n");
1049
 
1050
  const auto *Parent = Cycle->getParentCycle();
1051
  while (Parent && !DT.properlyDominates(Parent->getHeader(), JoinBlock)) {
1052
    LLVM_DEBUG(dbgs() << "  header " << Context.print(Parent->getHeader())
1053
                      << " does not dominate join\n");
1054
    Cycle = Parent;
1055
    Parent = Parent->getParentCycle();
1056
  }
1057
 
1058
  LLVM_DEBUG(dbgs() << "  cycle made divergent by internal branch\n");
1059
  return Cycle;
1060
}
1061
 
1062
template <typename ContextT, typename CycleT, typename BlockT,
1063
          typename DominatorTreeT>
1064
static const CycleT *
1065
getOutermostDivergentCycle(const CycleT *Cycle, const BlockT *DivTermBlock,
1066
                           const BlockT *JoinBlock, const DominatorTreeT &DT,
1067
                           ContextT &Context) {
1068
  if (!Cycle)
1069
    return nullptr;
1070
 
1071
  // First try to expand Cycle to the largest cycle that contains JoinBlock
1072
  // but not DivTermBlock.
1073
  const auto *Ext = getExtDivCycle(Cycle, DivTermBlock, JoinBlock);
1074
 
1075
  // Continue expanding to the largest cycle that contains both.
1076
  const auto *Int = getIntDivCycle(Cycle, DivTermBlock, JoinBlock, DT, Context);
1077
 
1078
  if (Int)
1079
    return Int;
1080
  return Ext;
1081
}
1082
 
1083
template <typename ContextT>
1084
void GenericUniformityAnalysisImpl<ContextT>::analyzeControlDivergence(
1085
    const InstructionT &Term) {
1086
  const auto *DivTermBlock = Term.getParent();
1087
  DivergentTermBlocks.insert(DivTermBlock);
1088
  LLVM_DEBUG(dbgs() << "analyzeControlDiv " << Context.print(DivTermBlock)
1089
                    << "\n");
1090
 
1091
  // Don't propagate divergence from unreachable blocks.
1092
  if (!DT.isReachableFromEntry(DivTermBlock))
1093
    return;
1094
 
1095
  const auto &DivDesc = SDA.getJoinBlocks(DivTermBlock);
1096
  SmallVector<const CycleT *> DivCycles;
1097
 
1098
  // Iterate over all blocks now reachable by a disjoint path join
1099
  for (const auto *JoinBlock : DivDesc.JoinDivBlocks) {
1100
    const auto *Cycle = CI.getCycle(JoinBlock);
1101
    LLVM_DEBUG(dbgs() << "visiting join block " << Context.print(JoinBlock)
1102
                      << "\n");
1103
    if (const auto *Outermost = getOutermostDivergentCycle(
1104
            Cycle, DivTermBlock, JoinBlock, DT, Context)) {
1105
      LLVM_DEBUG(dbgs() << "found divergent cycle\n");
1106
      DivCycles.push_back(Outermost);
1107
      continue;
1108
    }
1109
    taintAndPushPhiNodes(*JoinBlock);
1110
  }
1111
 
1112
  // Sort by order of decreasing depth. This allows later cycles to be skipped
1113
  // because they are already contained in earlier ones.
1114
  llvm::sort(DivCycles, [](const CycleT *A, const CycleT *B) {
1115
    return A->getDepth() > B->getDepth();
1116
  });
1117
 
1118
  // Cycles that are assumed divergent due to the diverged entry
1119
  // criterion potentially contain temporal divergence depending on
1120
  // the DFS chosen. Conservatively, all values produced in such a
1121
  // cycle are assumed divergent. "Cycle invariant" values may be
1122
  // assumed uniform, but that requires further analysis.
1123
  for (auto *C : DivCycles) {
1124
    if (!insertIfNotContained(AssumedDivergent, C))
1125
      continue;
1126
    LLVM_DEBUG(dbgs() << "process divergent cycle\n");
1127
    for (const BlockT *BB : C->blocks()) {
1128
      taintAndPushAllDefs(*BB);
1129
    }
1130
  }
1131
 
1132
  const auto *BranchCycle = CI.getCycle(DivTermBlock);
1133
  assert(DivDesc.CycleDivBlocks.empty() || BranchCycle);
1134
  for (const auto *DivExitBlock : DivDesc.CycleDivBlocks) {
1135
    propagateCycleExitDivergence(*DivExitBlock, *BranchCycle);
1136
  }
1137
}
1138
 
1139
template <typename ContextT>
1140
void GenericUniformityAnalysisImpl<ContextT>::compute() {
1141
  // Initialize worklist.
1142
  auto DivValuesCopy = DivergentValues;
1143
  for (const auto DivVal : DivValuesCopy) {
1144
    assert(isDivergent(DivVal) && "Worklist invariant violated!");
1145
    pushUsers(DivVal);
1146
  }
1147
 
1148
  // All values on the Worklist are divergent.
1149
  // Their users may not have been updated yet.
1150
  while (!Worklist.empty()) {
1151
    const InstructionT *I = Worklist.back();
1152
    Worklist.pop_back();
1153
 
1154
    LLVM_DEBUG(dbgs() << "worklist pop: " << Context.print(I) << "\n");
1155
 
1156
    if (I->isTerminator()) {
1157
      analyzeControlDivergence(*I);
1158
      continue;
1159
    }
1160
 
1161
    // propagate value divergence to users
1162
    assert(isDivergent(*I) && "Worklist invariant violated!");
1163
    pushUsers(*I);
1164
  }
1165
}
1166
 
1167
template <typename ContextT>
1168
bool GenericUniformityAnalysisImpl<ContextT>::isAlwaysUniform(
1169
    const InstructionT &Instr) const {
1170
  return UniformOverrides.contains(&Instr);
1171
}
1172
 
1173
template <typename ContextT>
1174
GenericUniformityInfo<ContextT>::GenericUniformityInfo(
1175
    FunctionT &Func, const DominatorTreeT &DT, const CycleInfoT &CI,
1176
    const TargetTransformInfo *TTI)
1177
    : F(&Func) {
1178
  DA.reset(new ImplT{Func, DT, CI, TTI});
1179
  DA->initialize();
1180
  DA->compute();
1181
}
1182
 
1183
template <typename ContextT>
1184
void GenericUniformityAnalysisImpl<ContextT>::print(raw_ostream &OS) const {
1185
  bool haveDivergentArgs = false;
1186
 
1187
  if (DivergentValues.empty()) {
1188
    assert(DivergentTermBlocks.empty());
1189
    assert(DivergentExitCycles.empty());
1190
    OS << "ALL VALUES UNIFORM\n";
1191
    return;
1192
  }
1193
 
1194
  for (const auto &entry : DivergentValues) {
1195
    const BlockT *parent = Context.getDefBlock(entry);
1196
    if (!parent) {
1197
      if (!haveDivergentArgs) {
1198
        OS << "DIVERGENT ARGUMENTS:\n";
1199
        haveDivergentArgs = true;
1200
      }
1201
      OS << "  DIVERGENT: " << Context.print(entry) << '\n';
1202
    }
1203
  }
1204
 
1205
  if (!AssumedDivergent.empty()) {
1206
    OS << "CYCLES ASSSUMED DIVERGENT:\n";
1207
    for (const CycleT *cycle : AssumedDivergent) {
1208
      OS << "  " << cycle->print(Context) << '\n';
1209
    }
1210
  }
1211
 
1212
  if (!DivergentExitCycles.empty()) {
1213
    OS << "CYCLES WITH DIVERGENT EXIT:\n";
1214
    for (const CycleT *cycle : DivergentExitCycles) {
1215
      OS << "  " << cycle->print(Context) << '\n';
1216
    }
1217
  }
1218
 
1219
  for (auto &block : F) {
1220
    OS << "\nBLOCK " << Context.print(&block) << '\n';
1221
 
1222
    OS << "DEFINITIONS\n";
1223
    SmallVector<ConstValueRefT, 16> defs;
1224
    Context.appendBlockDefs(defs, block);
1225
    for (auto value : defs) {
1226
      if (isDivergent(value))
1227
        OS << "  DIVERGENT: ";
1228
      else
1229
        OS << "             ";
1230
      OS << Context.print(value) << '\n';
1231
    }
1232
 
1233
    OS << "TERMINATORS\n";
1234
    SmallVector<const InstructionT *, 8> terms;
1235
    Context.appendBlockTerms(terms, block);
1236
    bool divergentTerminators = hasDivergentTerminator(block);
1237
    for (auto *T : terms) {
1238
      if (divergentTerminators)
1239
        OS << "  DIVERGENT: ";
1240
      else
1241
        OS << "             ";
1242
      OS << Context.print(T) << '\n';
1243
    }
1244
 
1245
    OS << "END BLOCK\n";
1246
  }
1247
}
1248
 
1249
template <typename ContextT>
1250
bool GenericUniformityInfo<ContextT>::hasDivergence() const {
1251
  return DA->hasDivergence();
1252
}
1253
 
1254
/// Whether \p V is divergent at its definition.
1255
template <typename ContextT>
1256
bool GenericUniformityInfo<ContextT>::isDivergent(ConstValueRefT V) const {
1257
  return DA->isDivergent(V);
1258
}
1259
 
1260
template <typename ContextT>
1261
bool GenericUniformityInfo<ContextT>::hasDivergentTerminator(const BlockT &B) {
1262
  return DA->hasDivergentTerminator(B);
1263
}
1264
 
1265
/// \brief Template helper function for printing.
1266
template <typename ContextT>
1267
void GenericUniformityInfo<ContextT>::print(raw_ostream &out) const {
1268
  DA->print(out);
1269
}
1270
 
1271
template <typename ContextT>
1272
void llvm::ModifiedPostOrder<ContextT>::computeStackPO(
1273
    SmallVectorImpl<BlockT *> &Stack, const CycleInfoT &CI, const CycleT *Cycle,
1274
    SmallPtrSetImpl<BlockT *> &Finalized) {
1275
  LLVM_DEBUG(dbgs() << "inside computeStackPO\n");
1276
  while (!Stack.empty()) {
1277
    auto *NextBB = Stack.back();
1278
    if (Finalized.count(NextBB)) {
1279
      Stack.pop_back();
1280
      continue;
1281
    }
1282
    LLVM_DEBUG(dbgs() << "  visiting " << CI.getSSAContext().print(NextBB)
1283
                      << "\n");
1284
    auto *NestedCycle = CI.getCycle(NextBB);
1285
    if (Cycle != NestedCycle && (!Cycle || Cycle->contains(NestedCycle))) {
1286
      LLVM_DEBUG(dbgs() << "  found a cycle\n");
1287
      while (NestedCycle->getParentCycle() != Cycle)
1288
        NestedCycle = NestedCycle->getParentCycle();
1289
 
1290
      SmallVector<BlockT *, 3> NestedExits;
1291
      NestedCycle->getExitBlocks(NestedExits);
1292
      bool PushedNodes = false;
1293
      for (auto *NestedExitBB : NestedExits) {
1294
        LLVM_DEBUG(dbgs() << "  examine exit: "
1295
                          << CI.getSSAContext().print(NestedExitBB) << "\n");
1296
        if (Cycle && !Cycle->contains(NestedExitBB))
1297
          continue;
1298
        if (Finalized.count(NestedExitBB))
1299
          continue;
1300
        PushedNodes = true;
1301
        Stack.push_back(NestedExitBB);
1302
        LLVM_DEBUG(dbgs() << "  pushed exit: "
1303
                          << CI.getSSAContext().print(NestedExitBB) << "\n");
1304
      }
1305
      if (!PushedNodes) {
1306
        // All loop exits finalized -> finish this node
1307
        Stack.pop_back();
1308
        computeCyclePO(CI, NestedCycle, Finalized);
1309
      }
1310
      continue;
1311
    }
1312
 
1313
    LLVM_DEBUG(dbgs() << "  no nested cycle, going into DAG\n");
1314
    // DAG-style
1315
    bool PushedNodes = false;
1316
    for (auto *SuccBB : successors(NextBB)) {
1317
      LLVM_DEBUG(dbgs() << "  examine succ: "
1318
                        << CI.getSSAContext().print(SuccBB) << "\n");
1319
      if (Cycle && !Cycle->contains(SuccBB))
1320
        continue;
1321
      if (Finalized.count(SuccBB))
1322
        continue;
1323
      PushedNodes = true;
1324
      Stack.push_back(SuccBB);
1325
      LLVM_DEBUG(dbgs() << "  pushed succ: " << CI.getSSAContext().print(SuccBB)
1326
                        << "\n");
1327
    }
1328
    if (!PushedNodes) {
1329
      // Never push nodes twice
1330
      LLVM_DEBUG(dbgs() << "  finishing node: "
1331
                        << CI.getSSAContext().print(NextBB) << "\n");
1332
      Stack.pop_back();
1333
      Finalized.insert(NextBB);
1334
      appendBlock(*NextBB);
1335
    }
1336
  }
1337
  LLVM_DEBUG(dbgs() << "exited computeStackPO\n");
1338
}
1339
 
1340
template <typename ContextT>
1341
void ModifiedPostOrder<ContextT>::computeCyclePO(
1342
    const CycleInfoT &CI, const CycleT *Cycle,
1343
    SmallPtrSetImpl<BlockT *> &Finalized) {
1344
  LLVM_DEBUG(dbgs() << "inside computeCyclePO\n");
1345
  SmallVector<BlockT *> Stack;
1346
  auto *CycleHeader = Cycle->getHeader();
1347
 
1348
  LLVM_DEBUG(dbgs() << "  noted header: "
1349
                    << CI.getSSAContext().print(CycleHeader) << "\n");
1350
  assert(!Finalized.count(CycleHeader));
1351
  Finalized.insert(CycleHeader);
1352
 
1353
  // Visit the header last
1354
  LLVM_DEBUG(dbgs() << "  finishing header: "
1355
                    << CI.getSSAContext().print(CycleHeader) << "\n");
1356
  appendBlock(*CycleHeader, Cycle->isReducible());
1357
 
1358
  // Initialize with immediate successors
1359
  for (auto *BB : successors(CycleHeader)) {
1360
    LLVM_DEBUG(dbgs() << "  examine succ: " << CI.getSSAContext().print(BB)
1361
                      << "\n");
1362
    if (!Cycle->contains(BB))
1363
      continue;
1364
    if (BB == CycleHeader)
1365
      continue;
1366
    if (!Finalized.count(BB)) {
1367
      LLVM_DEBUG(dbgs() << "  pushed succ: " << CI.getSSAContext().print(BB)
1368
                        << "\n");
1369
      Stack.push_back(BB);
1370
    }
1371
  }
1372
 
1373
  // Compute PO inside region
1374
  computeStackPO(Stack, CI, Cycle, Finalized);
1375
 
1376
  LLVM_DEBUG(dbgs() << "exited computeCyclePO\n");
1377
}
1378
 
1379
/// \brief Generically compute the modified post order.
1380
template <typename ContextT>
1381
void llvm::ModifiedPostOrder<ContextT>::compute(const CycleInfoT &CI) {
1382
  SmallPtrSet<BlockT *, 32> Finalized;
1383
  SmallVector<BlockT *> Stack;
1384
  auto *F = CI.getFunction();
1385
  Stack.reserve(24); // FIXME made-up number
1386
  Stack.push_back(GraphTraits<FunctionT *>::getEntryNode(F));
1387
  computeStackPO(Stack, CI, nullptr, Finalized);
1388
}
1389
 
1390
} // namespace llvm
1391
 
1392
#undef DEBUG_TYPE
1393
 
1394
#endif // LLVM_ADT_GENERICUNIFORMITYIMPL_H