Details | Last modification | View Log | RSS feed
| Rev | Author | Line No. | Line | 
|---|---|---|---|
| 14 | pmbaty | 1 | //===- polly/ScheduleTreeTransform.h ----------------------------*- C++ -*-===// | 
| 2 | // | ||
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
| 4 | // See https://llvm.org/LICENSE.txt for license information. | ||
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| 6 | // | ||
| 7 | //===----------------------------------------------------------------------===// | ||
| 8 | // | ||
| 9 | // Make changes to isl's schedule tree data structure. | ||
| 10 | // | ||
| 11 | //===----------------------------------------------------------------------===// | ||
| 12 | |||
| 13 | #ifndef POLLY_SCHEDULETREETRANSFORM_H | ||
| 14 | #define POLLY_SCHEDULETREETRANSFORM_H | ||
| 15 | |||
| 16 | #include "polly/Support/ISLTools.h" | ||
| 17 | #include "llvm/ADT/ArrayRef.h" | ||
| 18 | #include "llvm/Support/ErrorHandling.h" | ||
| 19 | #include "isl/isl-noexceptions.h" | ||
| 20 | #include <cassert> | ||
| 21 | |||
| 22 | namespace polly { | ||
| 23 | struct BandAttr; | ||
| 24 | |||
| 25 | /// This class defines a simple visitor class that may be used for | ||
| 26 | /// various schedule tree analysis purposes. | ||
| 27 | template <typename Derived, typename RetTy = void, typename... Args> | ||
| 28 | struct ScheduleTreeVisitor { | ||
| 29 | Derived &getDerived() { return *static_cast<Derived *>(this); } | ||
| 30 | const Derived &getDerived() const { | ||
| 31 | return *static_cast<const Derived *>(this); | ||
| 32 |   } | ||
| 33 | |||
| 34 | RetTy visit(isl::schedule_node Node, Args... args) { | ||
| 35 | assert(!Node.is_null()); | ||
| 36 | switch (isl_schedule_node_get_type(Node.get())) { | ||
| 37 | case isl_schedule_node_domain: | ||
| 38 | assert(isl_schedule_node_n_children(Node.get()) == 1); | ||
| 39 | return getDerived().visitDomain(Node.as<isl::schedule_node_domain>(), | ||
| 40 | std::forward<Args>(args)...); | ||
| 41 | case isl_schedule_node_band: | ||
| 42 | assert(isl_schedule_node_n_children(Node.get()) == 1); | ||
| 43 | return getDerived().visitBand(Node.as<isl::schedule_node_band>(), | ||
| 44 | std::forward<Args>(args)...); | ||
| 45 | case isl_schedule_node_sequence: | ||
| 46 | assert(isl_schedule_node_n_children(Node.get()) >= 2); | ||
| 47 | return getDerived().visitSequence(Node.as<isl::schedule_node_sequence>(), | ||
| 48 | std::forward<Args>(args)...); | ||
| 49 | case isl_schedule_node_set: | ||
| 50 | return getDerived().visitSet(Node.as<isl::schedule_node_set>(), | ||
| 51 | std::forward<Args>(args)...); | ||
| 52 | assert(isl_schedule_node_n_children(Node.get()) >= 2); | ||
| 53 | case isl_schedule_node_leaf: | ||
| 54 | assert(isl_schedule_node_n_children(Node.get()) == 0); | ||
| 55 | return getDerived().visitLeaf(Node.as<isl::schedule_node_leaf>(), | ||
| 56 | std::forward<Args>(args)...); | ||
| 57 | case isl_schedule_node_mark: | ||
| 58 | assert(isl_schedule_node_n_children(Node.get()) == 1); | ||
| 59 | return getDerived().visitMark(Node.as<isl::schedule_node_mark>(), | ||
| 60 | std::forward<Args>(args)...); | ||
| 61 | case isl_schedule_node_extension: | ||
| 62 | assert(isl_schedule_node_n_children(Node.get()) == 1); | ||
| 63 | return getDerived().visitExtension( | ||
| 64 | Node.as<isl::schedule_node_extension>(), std::forward<Args>(args)...); | ||
| 65 | case isl_schedule_node_filter: | ||
| 66 | assert(isl_schedule_node_n_children(Node.get()) == 1); | ||
| 67 | return getDerived().visitFilter(Node.as<isl::schedule_node_filter>(), | ||
| 68 | std::forward<Args>(args)...); | ||
| 69 | default: | ||
| 70 | llvm_unreachable("unimplemented schedule node type"); | ||
| 71 |     } | ||
| 72 |   } | ||
| 73 | |||
| 74 | RetTy visitDomain(isl::schedule_node_domain Domain, Args... args) { | ||
| 75 | return getDerived().visitSingleChild(std::move(Domain), | ||
| 76 | std::forward<Args>(args)...); | ||
| 77 |   } | ||
| 78 | |||
| 79 | RetTy visitBand(isl::schedule_node_band Band, Args... args) { | ||
| 80 | return getDerived().visitSingleChild(std::move(Band), | ||
| 81 | std::forward<Args>(args)...); | ||
| 82 |   } | ||
| 83 | |||
| 84 | RetTy visitSequence(isl::schedule_node_sequence Sequence, Args... args) { | ||
| 85 | return getDerived().visitMultiChild(std::move(Sequence), | ||
| 86 | std::forward<Args>(args)...); | ||
| 87 |   } | ||
| 88 | |||
| 89 | RetTy visitSet(isl::schedule_node_set Set, Args... args) { | ||
| 90 | return getDerived().visitMultiChild(std::move(Set), | ||
| 91 | std::forward<Args>(args)...); | ||
| 92 |   } | ||
| 93 | |||
| 94 | RetTy visitLeaf(isl::schedule_node_leaf Leaf, Args... args) { | ||
| 95 | return getDerived().visitNode(std::move(Leaf), std::forward<Args>(args)...); | ||
| 96 |   } | ||
| 97 | |||
| 98 | RetTy visitMark(isl::schedule_node_mark Mark, Args... args) { | ||
| 99 | return getDerived().visitSingleChild(std::move(Mark), | ||
| 100 | std::forward<Args>(args)...); | ||
| 101 |   } | ||
| 102 | |||
| 103 | RetTy visitExtension(isl::schedule_node_extension Extension, Args... args) { | ||
| 104 | return getDerived().visitSingleChild(std::move(Extension), | ||
| 105 | std::forward<Args>(args)...); | ||
| 106 |   } | ||
| 107 | |||
| 108 | RetTy visitFilter(isl::schedule_node_filter Filter, Args... args) { | ||
| 109 | return getDerived().visitSingleChild(std::move(Filter), | ||
| 110 | std::forward<Args>(args)...); | ||
| 111 |   } | ||
| 112 | |||
| 113 | RetTy visitSingleChild(isl::schedule_node Node, Args... args) { | ||
| 114 | return getDerived().visitNode(std::move(Node), std::forward<Args>(args)...); | ||
| 115 |   } | ||
| 116 | |||
| 117 | RetTy visitMultiChild(isl::schedule_node Node, Args... args) { | ||
| 118 | return getDerived().visitNode(std::move(Node), std::forward<Args>(args)...); | ||
| 119 |   } | ||
| 120 | |||
| 121 | RetTy visitNode(isl::schedule_node Node, Args... args) { | ||
| 122 | llvm_unreachable("Unimplemented other"); | ||
| 123 |   } | ||
| 124 | }; | ||
| 125 | |||
| 126 | /// Recursively visit all nodes of a schedule tree. | ||
| 127 | template <typename Derived, typename RetTy = void, typename... Args> | ||
| 128 | struct RecursiveScheduleTreeVisitor | ||
| 129 | : ScheduleTreeVisitor<Derived, RetTy, Args...> { | ||
| 130 | using BaseTy = ScheduleTreeVisitor<Derived, RetTy, Args...>; | ||
| 131 | BaseTy &getBase() { return *this; } | ||
| 132 | const BaseTy &getBase() const { return *this; } | ||
| 133 | Derived &getDerived() { return *static_cast<Derived *>(this); } | ||
| 134 | const Derived &getDerived() const { | ||
| 135 | return *static_cast<const Derived *>(this); | ||
| 136 |   } | ||
| 137 | |||
| 138 |   /// When visiting an entire schedule tree, start at its root node. | ||
| 139 | RetTy visit(isl::schedule Schedule, Args... args) { | ||
| 140 | return getDerived().visit(Schedule.get_root(), std::forward<Args>(args)...); | ||
| 141 |   } | ||
| 142 | |||
| 143 |   // Necessary to allow overload resolution with the added visit(isl::schedule) | ||
| 144 |   // overload. | ||
| 145 | RetTy visit(isl::schedule_node Node, Args... args) { | ||
| 146 | return getBase().visit(Node, std::forward<Args>(args)...); | ||
| 147 |   } | ||
| 148 | |||
| 149 |   /// By default, recursively visit the child nodes. | ||
| 150 | RetTy visitNode(isl::schedule_node Node, Args... args) { | ||
| 151 | for (unsigned i : rangeIslSize(0, Node.n_children())) | ||
| 152 | getDerived().visit(Node.child(i), std::forward<Args>(args)...); | ||
| 153 | return RetTy(); | ||
| 154 |   } | ||
| 155 | }; | ||
| 156 | |||
| 157 | /// Recursively visit all nodes of a schedule tree while allowing changes. | ||
| 158 | /// | ||
| 159 | /// The visit methods return an isl::schedule_node that is used to continue | ||
| 160 | /// visiting the tree. Structural changes such as returning a different node | ||
| 161 | /// will confuse the visitor. | ||
| 162 | template <typename Derived, typename... Args> | ||
| 163 | struct ScheduleNodeRewriter | ||
| 164 | : public RecursiveScheduleTreeVisitor<Derived, isl::schedule_node, | ||
| 165 | Args...> { | ||
| 166 | Derived &getDerived() { return *static_cast<Derived *>(this); } | ||
| 167 | const Derived &getDerived() const { | ||
| 168 | return *static_cast<const Derived *>(this); | ||
| 169 |   } | ||
| 170 | |||
| 171 | isl::schedule_node visitNode(isl::schedule_node Node, Args... args) { | ||
| 172 | return getDerived().visitChildren(Node); | ||
| 173 |   } | ||
| 174 | |||
| 175 | isl::schedule_node visitChildren(isl::schedule_node Node, Args... args) { | ||
| 176 | if (!Node.has_children()) | ||
| 177 | return Node; | ||
| 178 | |||
| 179 | isl::schedule_node It = Node.first_child(); | ||
| 180 | while (true) { | ||
| 181 | It = getDerived().visit(It, std::forward<Args>(args)...); | ||
| 182 | if (!It.has_next_sibling()) | ||
| 183 | break; | ||
| 184 | It = It.next_sibling(); | ||
| 185 |     } | ||
| 186 | return It.parent(); | ||
| 187 |   } | ||
| 188 | }; | ||
| 189 | |||
| 190 | /// Is this node the marker for its parent band? | ||
| 191 | bool isBandMark(const isl::schedule_node &Node); | ||
| 192 | |||
| 193 | /// Extract the BandAttr from a band's wrapping marker. Can also pass the band | ||
| 194 | /// itself and this methods will try to find its wrapping mark. Returns nullptr | ||
| 195 | /// if the band has not BandAttr. | ||
| 196 | BandAttr *getBandAttr(isl::schedule_node MarkOrBand); | ||
| 197 | |||
| 198 | /// Hoist all domains from extension into the root domain node, such that there | ||
| 199 | /// are no more extension nodes (which isl does not support for some | ||
| 200 | /// operations). This assumes that domains added by to extension nodes do not | ||
| 201 | /// overlap. | ||
| 202 | isl::schedule hoistExtensionNodes(isl::schedule Sched); | ||
| 203 | |||
| 204 | /// Replace the AST band @p BandToUnroll by a sequence of all its iterations. | ||
| 205 | /// | ||
| 206 | /// The implementation enumerates all points in the partial schedule and creates | ||
| 207 | /// an ISL sequence node for each point. The number of iterations must be a | ||
| 208 | /// constant. | ||
| 209 | isl::schedule applyFullUnroll(isl::schedule_node BandToUnroll); | ||
| 210 | |||
| 211 | /// Replace the AST band @p BandToUnroll by a partially unrolled equivalent. | ||
| 212 | isl::schedule applyPartialUnroll(isl::schedule_node BandToUnroll, int Factor); | ||
| 213 | |||
| 214 | /// Loop-distribute the band @p BandToFission as much as possible. | ||
| 215 | isl::schedule applyMaxFission(isl::schedule_node BandToFission); | ||
| 216 | |||
| 217 | /// Build the desired set of partial tile prefixes. | ||
| 218 | /// | ||
| 219 | /// We build a set of partial tile prefixes, which are prefixes of the vector | ||
| 220 | /// loop that have exactly VectorWidth iterations. | ||
| 221 | /// | ||
| 222 | /// 1. Drop all constraints involving the dimension that represents the | ||
| 223 | ///    vector loop. | ||
| 224 | /// 2. Constrain the last dimension to get a set, which has exactly VectorWidth | ||
| 225 | ///    iterations. | ||
| 226 | /// 3. Subtract loop domain from it, project out the vector loop dimension and | ||
| 227 | ///    get a set that contains prefixes, which do not have exactly VectorWidth | ||
| 228 | ///    iterations. | ||
| 229 | /// 4. Project out the vector loop dimension of the set that was build on the | ||
| 230 | ///    first step and subtract the set built on the previous step to get the | ||
| 231 | ///    desired set of prefixes. | ||
| 232 | /// | ||
| 233 | /// @param ScheduleRange A range of a map, which describes a prefix schedule | ||
| 234 | ///                      relation. | ||
| 235 | isl::set getPartialTilePrefixes(isl::set ScheduleRange, int VectorWidth); | ||
| 236 | |||
| 237 | /// Create an isl::union_set, which describes the isolate option based on | ||
| 238 | /// IsolateDomain. | ||
| 239 | /// | ||
| 240 | /// @param IsolateDomain An isl::set whose @p OutDimsNum last dimensions should | ||
| 241 | ///                      belong to the current band node. | ||
| 242 | /// @param OutDimsNum    A number of dimensions that should belong to | ||
| 243 | ///                      the current band node. | ||
| 244 | isl::union_set getIsolateOptions(isl::set IsolateDomain, unsigned OutDimsNum); | ||
| 245 | |||
| 246 | /// Create an isl::union_set, which describes the specified option for the | ||
| 247 | /// dimension of the current node. | ||
| 248 | /// | ||
| 249 | /// @param Ctx    An isl::ctx, which is used to create the isl::union_set. | ||
| 250 | /// @param Option The name of the option. | ||
| 251 | isl::union_set getDimOptions(isl::ctx Ctx, const char *Option); | ||
| 252 | |||
| 253 | /// Tile a schedule node. | ||
| 254 | /// | ||
| 255 | /// @param Node            The node to tile. | ||
| 256 | /// @param Identifier      An name that identifies this kind of tiling and | ||
| 257 | ///                        that is used to mark the tiled loops in the | ||
| 258 | ///                        generated AST. | ||
| 259 | /// @param TileSizes       A vector of tile sizes that should be used for | ||
| 260 | ///                        tiling. | ||
| 261 | /// @param DefaultTileSize A default tile size that is used for dimensions | ||
| 262 | ///                        that are not covered by the TileSizes vector. | ||
| 263 | isl::schedule_node tileNode(isl::schedule_node Node, const char *Identifier, | ||
| 264 | llvm::ArrayRef<int> TileSizes, int DefaultTileSize); | ||
| 265 | |||
| 266 | /// Tile a schedule node and unroll point loops. | ||
| 267 | /// | ||
| 268 | /// @param Node            The node to register tile. | ||
| 269 | /// @param TileSizes       A vector of tile sizes that should be used for | ||
| 270 | ///                        tiling. | ||
| 271 | /// @param DefaultTileSize A default tile size that is used for dimensions | ||
| 272 | isl::schedule_node applyRegisterTiling(isl::schedule_node Node, | ||
| 273 | llvm::ArrayRef<int> TileSizes, | ||
| 274 | int DefaultTileSize); | ||
| 275 | |||
| 276 | /// Apply greedy fusion. That is, fuse any loop that is possible to be fused | ||
| 277 | /// top-down. | ||
| 278 | /// | ||
| 279 | /// @param Sched  Sched tree to fuse all the loops in. | ||
| 280 | /// @param Deps   Validity constraints that must be preserved. | ||
| 281 | isl::schedule applyGreedyFusion(isl::schedule Sched, | ||
| 282 | const isl::union_map &Deps); | ||
| 283 | |||
| 284 | } // namespace polly | ||
| 285 | |||
| 286 | #endif // POLLY_SCHEDULETREETRANSFORM_H |