Details | Last modification | View Log | RSS feed
Rev | Author | Line No. | Line |
---|---|---|---|
14 | pmbaty | 1 | //===- polly/ScheduleTreeTransform.h ----------------------------*- C++ -*-===// |
2 | // |
||
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
||
4 | // See https://llvm.org/LICENSE.txt for license information. |
||
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
||
6 | // |
||
7 | //===----------------------------------------------------------------------===// |
||
8 | // |
||
9 | // Make changes to isl's schedule tree data structure. |
||
10 | // |
||
11 | //===----------------------------------------------------------------------===// |
||
12 | |||
13 | #ifndef POLLY_SCHEDULETREETRANSFORM_H |
||
14 | #define POLLY_SCHEDULETREETRANSFORM_H |
||
15 | |||
16 | #include "polly/Support/ISLTools.h" |
||
17 | #include "llvm/ADT/ArrayRef.h" |
||
18 | #include "llvm/Support/ErrorHandling.h" |
||
19 | #include "isl/isl-noexceptions.h" |
||
20 | #include <cassert> |
||
21 | |||
22 | namespace polly { |
||
23 | struct BandAttr; |
||
24 | |||
25 | /// This class defines a simple visitor class that may be used for |
||
26 | /// various schedule tree analysis purposes. |
||
27 | template <typename Derived, typename RetTy = void, typename... Args> |
||
28 | struct ScheduleTreeVisitor { |
||
29 | Derived &getDerived() { return *static_cast<Derived *>(this); } |
||
30 | const Derived &getDerived() const { |
||
31 | return *static_cast<const Derived *>(this); |
||
32 | } |
||
33 | |||
34 | RetTy visit(isl::schedule_node Node, Args... args) { |
||
35 | assert(!Node.is_null()); |
||
36 | switch (isl_schedule_node_get_type(Node.get())) { |
||
37 | case isl_schedule_node_domain: |
||
38 | assert(isl_schedule_node_n_children(Node.get()) == 1); |
||
39 | return getDerived().visitDomain(Node.as<isl::schedule_node_domain>(), |
||
40 | std::forward<Args>(args)...); |
||
41 | case isl_schedule_node_band: |
||
42 | assert(isl_schedule_node_n_children(Node.get()) == 1); |
||
43 | return getDerived().visitBand(Node.as<isl::schedule_node_band>(), |
||
44 | std::forward<Args>(args)...); |
||
45 | case isl_schedule_node_sequence: |
||
46 | assert(isl_schedule_node_n_children(Node.get()) >= 2); |
||
47 | return getDerived().visitSequence(Node.as<isl::schedule_node_sequence>(), |
||
48 | std::forward<Args>(args)...); |
||
49 | case isl_schedule_node_set: |
||
50 | return getDerived().visitSet(Node.as<isl::schedule_node_set>(), |
||
51 | std::forward<Args>(args)...); |
||
52 | assert(isl_schedule_node_n_children(Node.get()) >= 2); |
||
53 | case isl_schedule_node_leaf: |
||
54 | assert(isl_schedule_node_n_children(Node.get()) == 0); |
||
55 | return getDerived().visitLeaf(Node.as<isl::schedule_node_leaf>(), |
||
56 | std::forward<Args>(args)...); |
||
57 | case isl_schedule_node_mark: |
||
58 | assert(isl_schedule_node_n_children(Node.get()) == 1); |
||
59 | return getDerived().visitMark(Node.as<isl::schedule_node_mark>(), |
||
60 | std::forward<Args>(args)...); |
||
61 | case isl_schedule_node_extension: |
||
62 | assert(isl_schedule_node_n_children(Node.get()) == 1); |
||
63 | return getDerived().visitExtension( |
||
64 | Node.as<isl::schedule_node_extension>(), std::forward<Args>(args)...); |
||
65 | case isl_schedule_node_filter: |
||
66 | assert(isl_schedule_node_n_children(Node.get()) == 1); |
||
67 | return getDerived().visitFilter(Node.as<isl::schedule_node_filter>(), |
||
68 | std::forward<Args>(args)...); |
||
69 | default: |
||
70 | llvm_unreachable("unimplemented schedule node type"); |
||
71 | } |
||
72 | } |
||
73 | |||
74 | RetTy visitDomain(isl::schedule_node_domain Domain, Args... args) { |
||
75 | return getDerived().visitSingleChild(std::move(Domain), |
||
76 | std::forward<Args>(args)...); |
||
77 | } |
||
78 | |||
79 | RetTy visitBand(isl::schedule_node_band Band, Args... args) { |
||
80 | return getDerived().visitSingleChild(std::move(Band), |
||
81 | std::forward<Args>(args)...); |
||
82 | } |
||
83 | |||
84 | RetTy visitSequence(isl::schedule_node_sequence Sequence, Args... args) { |
||
85 | return getDerived().visitMultiChild(std::move(Sequence), |
||
86 | std::forward<Args>(args)...); |
||
87 | } |
||
88 | |||
89 | RetTy visitSet(isl::schedule_node_set Set, Args... args) { |
||
90 | return getDerived().visitMultiChild(std::move(Set), |
||
91 | std::forward<Args>(args)...); |
||
92 | } |
||
93 | |||
94 | RetTy visitLeaf(isl::schedule_node_leaf Leaf, Args... args) { |
||
95 | return getDerived().visitNode(std::move(Leaf), std::forward<Args>(args)...); |
||
96 | } |
||
97 | |||
98 | RetTy visitMark(isl::schedule_node_mark Mark, Args... args) { |
||
99 | return getDerived().visitSingleChild(std::move(Mark), |
||
100 | std::forward<Args>(args)...); |
||
101 | } |
||
102 | |||
103 | RetTy visitExtension(isl::schedule_node_extension Extension, Args... args) { |
||
104 | return getDerived().visitSingleChild(std::move(Extension), |
||
105 | std::forward<Args>(args)...); |
||
106 | } |
||
107 | |||
108 | RetTy visitFilter(isl::schedule_node_filter Filter, Args... args) { |
||
109 | return getDerived().visitSingleChild(std::move(Filter), |
||
110 | std::forward<Args>(args)...); |
||
111 | } |
||
112 | |||
113 | RetTy visitSingleChild(isl::schedule_node Node, Args... args) { |
||
114 | return getDerived().visitNode(std::move(Node), std::forward<Args>(args)...); |
||
115 | } |
||
116 | |||
117 | RetTy visitMultiChild(isl::schedule_node Node, Args... args) { |
||
118 | return getDerived().visitNode(std::move(Node), std::forward<Args>(args)...); |
||
119 | } |
||
120 | |||
121 | RetTy visitNode(isl::schedule_node Node, Args... args) { |
||
122 | llvm_unreachable("Unimplemented other"); |
||
123 | } |
||
124 | }; |
||
125 | |||
126 | /// Recursively visit all nodes of a schedule tree. |
||
127 | template <typename Derived, typename RetTy = void, typename... Args> |
||
128 | struct RecursiveScheduleTreeVisitor |
||
129 | : ScheduleTreeVisitor<Derived, RetTy, Args...> { |
||
130 | using BaseTy = ScheduleTreeVisitor<Derived, RetTy, Args...>; |
||
131 | BaseTy &getBase() { return *this; } |
||
132 | const BaseTy &getBase() const { return *this; } |
||
133 | Derived &getDerived() { return *static_cast<Derived *>(this); } |
||
134 | const Derived &getDerived() const { |
||
135 | return *static_cast<const Derived *>(this); |
||
136 | } |
||
137 | |||
138 | /// When visiting an entire schedule tree, start at its root node. |
||
139 | RetTy visit(isl::schedule Schedule, Args... args) { |
||
140 | return getDerived().visit(Schedule.get_root(), std::forward<Args>(args)...); |
||
141 | } |
||
142 | |||
143 | // Necessary to allow overload resolution with the added visit(isl::schedule) |
||
144 | // overload. |
||
145 | RetTy visit(isl::schedule_node Node, Args... args) { |
||
146 | return getBase().visit(Node, std::forward<Args>(args)...); |
||
147 | } |
||
148 | |||
149 | /// By default, recursively visit the child nodes. |
||
150 | RetTy visitNode(isl::schedule_node Node, Args... args) { |
||
151 | for (unsigned i : rangeIslSize(0, Node.n_children())) |
||
152 | getDerived().visit(Node.child(i), std::forward<Args>(args)...); |
||
153 | return RetTy(); |
||
154 | } |
||
155 | }; |
||
156 | |||
157 | /// Recursively visit all nodes of a schedule tree while allowing changes. |
||
158 | /// |
||
159 | /// The visit methods return an isl::schedule_node that is used to continue |
||
160 | /// visiting the tree. Structural changes such as returning a different node |
||
161 | /// will confuse the visitor. |
||
162 | template <typename Derived, typename... Args> |
||
163 | struct ScheduleNodeRewriter |
||
164 | : public RecursiveScheduleTreeVisitor<Derived, isl::schedule_node, |
||
165 | Args...> { |
||
166 | Derived &getDerived() { return *static_cast<Derived *>(this); } |
||
167 | const Derived &getDerived() const { |
||
168 | return *static_cast<const Derived *>(this); |
||
169 | } |
||
170 | |||
171 | isl::schedule_node visitNode(isl::schedule_node Node, Args... args) { |
||
172 | return getDerived().visitChildren(Node); |
||
173 | } |
||
174 | |||
175 | isl::schedule_node visitChildren(isl::schedule_node Node, Args... args) { |
||
176 | if (!Node.has_children()) |
||
177 | return Node; |
||
178 | |||
179 | isl::schedule_node It = Node.first_child(); |
||
180 | while (true) { |
||
181 | It = getDerived().visit(It, std::forward<Args>(args)...); |
||
182 | if (!It.has_next_sibling()) |
||
183 | break; |
||
184 | It = It.next_sibling(); |
||
185 | } |
||
186 | return It.parent(); |
||
187 | } |
||
188 | }; |
||
189 | |||
190 | /// Is this node the marker for its parent band? |
||
191 | bool isBandMark(const isl::schedule_node &Node); |
||
192 | |||
193 | /// Extract the BandAttr from a band's wrapping marker. Can also pass the band |
||
194 | /// itself and this method will try to find its wrapping mark. Returns nullptr |
||
195 | /// if the band has no BandAttr. |
||
196 | BandAttr *getBandAttr(isl::schedule_node MarkOrBand); |
||
197 | |||
198 | /// Hoist all domains from extension into the root domain node, such that there |
||
199 | /// are no more extension nodes (which isl does not support for some |
||
200 | /// operations). This assumes that domains added by extension nodes do not |
||
201 | /// overlap. |
||
202 | isl::schedule hoistExtensionNodes(isl::schedule Sched); |
||
203 | |||
204 | /// Replace the AST band @p BandToUnroll by a sequence of all its iterations. |
||
205 | /// |
||
206 | /// The implementation enumerates all points in the partial schedule and creates |
||
207 | /// an ISL sequence node for each point. The number of iterations must be a |
||
208 | /// constant. |
||
209 | isl::schedule applyFullUnroll(isl::schedule_node BandToUnroll); |
||
210 | |||
211 | /// Replace the AST band @p BandToUnroll by a partially unrolled equivalent. |
||
212 | isl::schedule applyPartialUnroll(isl::schedule_node BandToUnroll, int Factor); |
||
213 | |||
214 | /// Loop-distribute the band @p BandToFission as much as possible. |
||
215 | isl::schedule applyMaxFission(isl::schedule_node BandToFission); |
||
216 | |||
217 | /// Build the desired set of partial tile prefixes. |
||
218 | /// |
||
219 | /// We build a set of partial tile prefixes, which are prefixes of the vector |
||
220 | /// loop that have exactly VectorWidth iterations. |
||
221 | /// |
||
222 | /// 1. Drop all constraints involving the dimension that represents the |
||
223 | /// vector loop. |
||
224 | /// 2. Constrain the last dimension to get a set, which has exactly VectorWidth |
||
225 | /// iterations. |
||
226 | /// 3. Subtract loop domain from it, project out the vector loop dimension and |
||
227 | /// get a set that contains prefixes, which do not have exactly VectorWidth |
||
228 | /// iterations. |
||
229 | /// 4. Project out the vector loop dimension of the set that was built on the |
||
230 | /// first step and subtract the set built on the previous step to get the |
||
231 | /// desired set of prefixes. |
||
232 | /// |
||
233 | /// @param ScheduleRange A range of a map, which describes a prefix schedule |
||
234 | /// relation. |
||
235 | isl::set getPartialTilePrefixes(isl::set ScheduleRange, int VectorWidth); |
||
236 | |||
237 | /// Create an isl::union_set, which describes the isolate option based on |
||
238 | /// IsolateDomain. |
||
239 | /// |
||
240 | /// @param IsolateDomain An isl::set whose @p OutDimsNum last dimensions should |
||
241 | /// belong to the current band node. |
||
242 | /// @param OutDimsNum A number of dimensions that should belong to |
||
243 | /// the current band node. |
||
244 | isl::union_set getIsolateOptions(isl::set IsolateDomain, unsigned OutDimsNum); |
||
245 | |||
246 | /// Create an isl::union_set, which describes the specified option for the |
||
247 | /// dimension of the current node. |
||
248 | /// |
||
249 | /// @param Ctx An isl::ctx, which is used to create the isl::union_set. |
||
250 | /// @param Option The name of the option. |
||
251 | isl::union_set getDimOptions(isl::ctx Ctx, const char *Option); |
||
252 | |||
253 | /// Tile a schedule node. |
||
254 | /// |
||
255 | /// @param Node The node to tile. |
||
256 | /// @param Identifier A name that identifies this kind of tiling and |
||
257 | /// that is used to mark the tiled loops in the |
||
258 | /// generated AST. |
||
259 | /// @param TileSizes A vector of tile sizes that should be used for |
||
260 | /// tiling. |
||
261 | /// @param DefaultTileSize A default tile size that is used for dimensions |
||
262 | /// that are not covered by the TileSizes vector. |
||
263 | isl::schedule_node tileNode(isl::schedule_node Node, const char *Identifier, |
||
264 | llvm::ArrayRef<int> TileSizes, int DefaultTileSize); |
||
265 | |||
266 | /// Tile a schedule node and unroll point loops. |
||
267 | /// |
||
268 | /// @param Node The node to register tile. |
||
269 | /// @param TileSizes A vector of tile sizes that should be used for |
||
270 | /// tiling. |
||
271 | /// @param DefaultTileSize A default tile size that is used for dimensions that are not covered by the TileSizes vector. |
||
272 | isl::schedule_node applyRegisterTiling(isl::schedule_node Node, |
||
273 | llvm::ArrayRef<int> TileSizes, |
||
274 | int DefaultTileSize); |
||
275 | |||
276 | /// Apply greedy fusion. That is, fuse any loop that is possible to be fused |
||
277 | /// top-down. |
||
278 | /// |
||
279 | /// @param Sched Sched tree to fuse all the loops in. |
||
280 | /// @param Deps Validity constraints that must be preserved. |
||
281 | isl::schedule applyGreedyFusion(isl::schedule Sched, |
||
282 | const isl::union_map &Deps); |
||
283 | |||
284 | } // namespace polly |
||
285 | |||
286 | #endif // POLLY_SCHEDULETREETRANSFORM_H |