Subversion Repositories QNX 8.QNX8 LLVM/Clang compiler suite

Rev

Blame | Last modification | View Log | Download | RSS feed

  1. //===- polly/ScheduleTreeTransform.h ----------------------------*- C++ -*-===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // Make changes to isl's schedule tree data structure.
  10. //
  11. //===----------------------------------------------------------------------===//
  12.  
  13. #ifndef POLLY_SCHEDULETREETRANSFORM_H
  14. #define POLLY_SCHEDULETREETRANSFORM_H
  15.  
  16. #include "polly/Support/ISLTools.h"
  17. #include "llvm/ADT/ArrayRef.h"
  18. #include "llvm/Support/ErrorHandling.h"
  19. #include "isl/isl-noexceptions.h"
  20. #include <cassert>
  21.  
  22. namespace polly {
  23. struct BandAttr;
  24.  
  25. /// This class defines a simple visitor class that may be used for
  26. /// various schedule tree analysis purposes.
  27. template <typename Derived, typename RetTy = void, typename... Args>
  28. struct ScheduleTreeVisitor {
  29.   Derived &getDerived() { return *static_cast<Derived *>(this); }
  30.   const Derived &getDerived() const {
  31.     return *static_cast<const Derived *>(this);
  32.   }
  33.  
  34.   RetTy visit(isl::schedule_node Node, Args... args) {
  35.     assert(!Node.is_null());
  36.     switch (isl_schedule_node_get_type(Node.get())) {
  37.     case isl_schedule_node_domain:
  38.       assert(isl_schedule_node_n_children(Node.get()) == 1);
  39.       return getDerived().visitDomain(Node.as<isl::schedule_node_domain>(),
  40.                                       std::forward<Args>(args)...);
  41.     case isl_schedule_node_band:
  42.       assert(isl_schedule_node_n_children(Node.get()) == 1);
  43.       return getDerived().visitBand(Node.as<isl::schedule_node_band>(),
  44.                                     std::forward<Args>(args)...);
  45.     case isl_schedule_node_sequence:
  46.       assert(isl_schedule_node_n_children(Node.get()) >= 2);
  47.       return getDerived().visitSequence(Node.as<isl::schedule_node_sequence>(),
  48.                                         std::forward<Args>(args)...);
  49.     case isl_schedule_node_set:
  50.       return getDerived().visitSet(Node.as<isl::schedule_node_set>(),
  51.                                    std::forward<Args>(args)...);
  52.       assert(isl_schedule_node_n_children(Node.get()) >= 2);
  53.     case isl_schedule_node_leaf:
  54.       assert(isl_schedule_node_n_children(Node.get()) == 0);
  55.       return getDerived().visitLeaf(Node.as<isl::schedule_node_leaf>(),
  56.                                     std::forward<Args>(args)...);
  57.     case isl_schedule_node_mark:
  58.       assert(isl_schedule_node_n_children(Node.get()) == 1);
  59.       return getDerived().visitMark(Node.as<isl::schedule_node_mark>(),
  60.                                     std::forward<Args>(args)...);
  61.     case isl_schedule_node_extension:
  62.       assert(isl_schedule_node_n_children(Node.get()) == 1);
  63.       return getDerived().visitExtension(
  64.           Node.as<isl::schedule_node_extension>(), std::forward<Args>(args)...);
  65.     case isl_schedule_node_filter:
  66.       assert(isl_schedule_node_n_children(Node.get()) == 1);
  67.       return getDerived().visitFilter(Node.as<isl::schedule_node_filter>(),
  68.                                       std::forward<Args>(args)...);
  69.     default:
  70.       llvm_unreachable("unimplemented schedule node type");
  71.     }
  72.   }
  73.  
  74.   RetTy visitDomain(isl::schedule_node_domain Domain, Args... args) {
  75.     return getDerived().visitSingleChild(std::move(Domain),
  76.                                          std::forward<Args>(args)...);
  77.   }
  78.  
  79.   RetTy visitBand(isl::schedule_node_band Band, Args... args) {
  80.     return getDerived().visitSingleChild(std::move(Band),
  81.                                          std::forward<Args>(args)...);
  82.   }
  83.  
  84.   RetTy visitSequence(isl::schedule_node_sequence Sequence, Args... args) {
  85.     return getDerived().visitMultiChild(std::move(Sequence),
  86.                                         std::forward<Args>(args)...);
  87.   }
  88.  
  89.   RetTy visitSet(isl::schedule_node_set Set, Args... args) {
  90.     return getDerived().visitMultiChild(std::move(Set),
  91.                                         std::forward<Args>(args)...);
  92.   }
  93.  
  94.   RetTy visitLeaf(isl::schedule_node_leaf Leaf, Args... args) {
  95.     return getDerived().visitNode(std::move(Leaf), std::forward<Args>(args)...);
  96.   }
  97.  
  98.   RetTy visitMark(isl::schedule_node_mark Mark, Args... args) {
  99.     return getDerived().visitSingleChild(std::move(Mark),
  100.                                          std::forward<Args>(args)...);
  101.   }
  102.  
  103.   RetTy visitExtension(isl::schedule_node_extension Extension, Args... args) {
  104.     return getDerived().visitSingleChild(std::move(Extension),
  105.                                          std::forward<Args>(args)...);
  106.   }
  107.  
  108.   RetTy visitFilter(isl::schedule_node_filter Filter, Args... args) {
  109.     return getDerived().visitSingleChild(std::move(Filter),
  110.                                          std::forward<Args>(args)...);
  111.   }
  112.  
  113.   RetTy visitSingleChild(isl::schedule_node Node, Args... args) {
  114.     return getDerived().visitNode(std::move(Node), std::forward<Args>(args)...);
  115.   }
  116.  
  117.   RetTy visitMultiChild(isl::schedule_node Node, Args... args) {
  118.     return getDerived().visitNode(std::move(Node), std::forward<Args>(args)...);
  119.   }
  120.  
  121.   RetTy visitNode(isl::schedule_node Node, Args... args) {
  122.     llvm_unreachable("Unimplemented other");
  123.   }
  124. };
  125.  
  126. /// Recursively visit all nodes of a schedule tree.
  127. template <typename Derived, typename RetTy = void, typename... Args>
  128. struct RecursiveScheduleTreeVisitor
  129.     : ScheduleTreeVisitor<Derived, RetTy, Args...> {
  130.   using BaseTy = ScheduleTreeVisitor<Derived, RetTy, Args...>;
  131.   BaseTy &getBase() { return *this; }
  132.   const BaseTy &getBase() const { return *this; }
  133.   Derived &getDerived() { return *static_cast<Derived *>(this); }
  134.   const Derived &getDerived() const {
  135.     return *static_cast<const Derived *>(this);
  136.   }
  137.  
  138.   /// When visiting an entire schedule tree, start at its root node.
  139.   RetTy visit(isl::schedule Schedule, Args... args) {
  140.     return getDerived().visit(Schedule.get_root(), std::forward<Args>(args)...);
  141.   }
  142.  
  143.   // Necessary to allow overload resolution with the added visit(isl::schedule)
  144.   // overload.
  145.   RetTy visit(isl::schedule_node Node, Args... args) {
  146.     return getBase().visit(Node, std::forward<Args>(args)...);
  147.   }
  148.  
  149.   /// By default, recursively visit the child nodes.
  150.   RetTy visitNode(isl::schedule_node Node, Args... args) {
  151.     for (unsigned i : rangeIslSize(0, Node.n_children()))
  152.       getDerived().visit(Node.child(i), std::forward<Args>(args)...);
  153.     return RetTy();
  154.   }
  155. };
  156.  
  157. /// Recursively visit all nodes of a schedule tree while allowing changes.
  158. ///
  159. /// The visit methods return an isl::schedule_node that is used to continue
  160. /// visiting the tree. Structural changes such as returning a different node
  161. /// will confuse the visitor.
  162. template <typename Derived, typename... Args>
  163. struct ScheduleNodeRewriter
  164.     : public RecursiveScheduleTreeVisitor<Derived, isl::schedule_node,
  165.                                           Args...> {
  166.   Derived &getDerived() { return *static_cast<Derived *>(this); }
  167.   const Derived &getDerived() const {
  168.     return *static_cast<const Derived *>(this);
  169.   }
  170.  
  171.   isl::schedule_node visitNode(isl::schedule_node Node, Args... args) {
  172.     return getDerived().visitChildren(Node);
  173.   }
  174.  
  175.   isl::schedule_node visitChildren(isl::schedule_node Node, Args... args) {
  176.     if (!Node.has_children())
  177.       return Node;
  178.  
  179.     isl::schedule_node It = Node.first_child();
  180.     while (true) {
  181.       It = getDerived().visit(It, std::forward<Args>(args)...);
  182.       if (!It.has_next_sibling())
  183.         break;
  184.       It = It.next_sibling();
  185.     }
  186.     return It.parent();
  187.   }
  188. };
  189.  
  190. /// Is this node the marker for its parent band?
  191. bool isBandMark(const isl::schedule_node &Node);
  192.  
  193. /// Extract the BandAttr from a band's wrapping marker. Can also pass the band
  194. /// itself and this methods will try to find its wrapping mark. Returns nullptr
  195. /// if the band has not BandAttr.
  196. BandAttr *getBandAttr(isl::schedule_node MarkOrBand);
  197.  
  198. /// Hoist all domains from extension into the root domain node, such that there
  199. /// are no more extension nodes (which isl does not support for some
  200. /// operations). This assumes that domains added by to extension nodes do not
  201. /// overlap.
  202. isl::schedule hoistExtensionNodes(isl::schedule Sched);
  203.  
  204. /// Replace the AST band @p BandToUnroll by a sequence of all its iterations.
  205. ///
  206. /// The implementation enumerates all points in the partial schedule and creates
  207. /// an ISL sequence node for each point. The number of iterations must be a
  208. /// constant.
  209. isl::schedule applyFullUnroll(isl::schedule_node BandToUnroll);
  210.  
  211. /// Replace the AST band @p BandToUnroll by a partially unrolled equivalent.
  212. isl::schedule applyPartialUnroll(isl::schedule_node BandToUnroll, int Factor);
  213.  
  214. /// Loop-distribute the band @p BandToFission as much as possible.
  215. isl::schedule applyMaxFission(isl::schedule_node BandToFission);
  216.  
  217. /// Build the desired set of partial tile prefixes.
  218. ///
  219. /// We build a set of partial tile prefixes, which are prefixes of the vector
  220. /// loop that have exactly VectorWidth iterations.
  221. ///
  222. /// 1. Drop all constraints involving the dimension that represents the
  223. ///    vector loop.
  224. /// 2. Constrain the last dimension to get a set, which has exactly VectorWidth
  225. ///    iterations.
  226. /// 3. Subtract loop domain from it, project out the vector loop dimension and
  227. ///    get a set that contains prefixes, which do not have exactly VectorWidth
  228. ///    iterations.
  229. /// 4. Project out the vector loop dimension of the set that was build on the
  230. ///    first step and subtract the set built on the previous step to get the
  231. ///    desired set of prefixes.
  232. ///
  233. /// @param ScheduleRange A range of a map, which describes a prefix schedule
  234. ///                      relation.
  235. isl::set getPartialTilePrefixes(isl::set ScheduleRange, int VectorWidth);
  236.  
  237. /// Create an isl::union_set, which describes the isolate option based on
  238. /// IsolateDomain.
  239. ///
  240. /// @param IsolateDomain An isl::set whose @p OutDimsNum last dimensions should
  241. ///                      belong to the current band node.
  242. /// @param OutDimsNum    A number of dimensions that should belong to
  243. ///                      the current band node.
  244. isl::union_set getIsolateOptions(isl::set IsolateDomain, unsigned OutDimsNum);
  245.  
  246. /// Create an isl::union_set, which describes the specified option for the
  247. /// dimension of the current node.
  248. ///
  249. /// @param Ctx    An isl::ctx, which is used to create the isl::union_set.
  250. /// @param Option The name of the option.
  251. isl::union_set getDimOptions(isl::ctx Ctx, const char *Option);
  252.  
  253. /// Tile a schedule node.
  254. ///
  255. /// @param Node            The node to tile.
  256. /// @param Identifier      An name that identifies this kind of tiling and
  257. ///                        that is used to mark the tiled loops in the
  258. ///                        generated AST.
  259. /// @param TileSizes       A vector of tile sizes that should be used for
  260. ///                        tiling.
  261. /// @param DefaultTileSize A default tile size that is used for dimensions
  262. ///                        that are not covered by the TileSizes vector.
  263. isl::schedule_node tileNode(isl::schedule_node Node, const char *Identifier,
  264.                             llvm::ArrayRef<int> TileSizes, int DefaultTileSize);
  265.  
  266. /// Tile a schedule node and unroll point loops.
  267. ///
  268. /// @param Node            The node to register tile.
  269. /// @param TileSizes       A vector of tile sizes that should be used for
  270. ///                        tiling.
  271. /// @param DefaultTileSize A default tile size that is used for dimensions
  272. isl::schedule_node applyRegisterTiling(isl::schedule_node Node,
  273.                                        llvm::ArrayRef<int> TileSizes,
  274.                                        int DefaultTileSize);
  275.  
  276. /// Apply greedy fusion. That is, fuse any loop that is possible to be fused
  277. /// top-down.
  278. ///
  279. /// @param Sched  Sched tree to fuse all the loops in.
  280. /// @param Deps   Validity constraints that must be preserved.
  281. isl::schedule applyGreedyFusion(isl::schedule Sched,
  282.                                 const isl::union_map &Deps);
  283.  
  284. } // namespace polly
  285.  
  286. #endif // POLLY_SCHEDULETREETRANSFORM_H
  287.