Subversion Repositories QNX 8.QNX8 LLVM/Clang compiler suite

Rev

Blame | Last modification | View Log | Download | RSS feed

  1. //===- llvm/MatrixBuilder.h - Builder to lower matrix ops -------*- C++ -*-===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // This file defines the MatrixBuilder class, which is used as a convenient way
  10. // to lower matrix operations to LLVM IR.
  11. //
  12. //===----------------------------------------------------------------------===//
  13.  
  14. #ifndef LLVM_IR_MATRIXBUILDER_H
  15. #define LLVM_IR_MATRIXBUILDER_H
  16.  
  17. #include "llvm/IR/Constant.h"
  18. #include "llvm/IR/Constants.h"
  19. #include "llvm/IR/IRBuilder.h"
  20. #include "llvm/IR/InstrTypes.h"
  21. #include "llvm/IR/Instruction.h"
  22. #include "llvm/IR/IntrinsicInst.h"
  23. #include "llvm/IR/Type.h"
  24. #include "llvm/IR/Value.h"
  25. #include "llvm/Support/Alignment.h"
  26.  
  27. namespace llvm {
  28.  
  29. class Function;
  30. class Twine;
  31. class Module;
  32.  
  33. class MatrixBuilder {
  34.   IRBuilderBase &B;
  35.   Module *getModule() { return B.GetInsertBlock()->getParent()->getParent(); }
  36.  
  37.   std::pair<Value *, Value *> splatScalarOperandIfNeeded(Value *LHS,
  38.                                                          Value *RHS) {
  39.     assert((LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()) &&
  40.            "One of the operands must be a matrix (embedded in a vector)");
  41.     if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
  42.       assert(!isa<ScalableVectorType>(LHS->getType()) &&
  43.              "LHS Assumed to be fixed width");
  44.       RHS = B.CreateVectorSplat(
  45.           cast<VectorType>(LHS->getType())->getElementCount(), RHS,
  46.           "scalar.splat");
  47.     } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
  48.       assert(!isa<ScalableVectorType>(RHS->getType()) &&
  49.              "RHS Assumed to be fixed width");
  50.       LHS = B.CreateVectorSplat(
  51.           cast<VectorType>(RHS->getType())->getElementCount(), LHS,
  52.           "scalar.splat");
  53.     }
  54.     return {LHS, RHS};
  55.   }
  56.  
  57. public:
  58.   MatrixBuilder(IRBuilderBase &Builder) : B(Builder) {}
  59.  
  60.   /// Create a column major, strided matrix load.
  61.   /// \p EltTy   - Matrix element type
  62.   /// \p DataPtr - Start address of the matrix read
  63.   /// \p Rows    - Number of rows in matrix (must be a constant)
  64.   /// \p Columns - Number of columns in matrix (must be a constant)
  65.   /// \p Stride  - Space between columns
  66.   CallInst *CreateColumnMajorLoad(Type *EltTy, Value *DataPtr, Align Alignment,
  67.                                   Value *Stride, bool IsVolatile, unsigned Rows,
  68.                                   unsigned Columns, const Twine &Name = "") {
  69.     auto *RetType = FixedVectorType::get(EltTy, Rows * Columns);
  70.  
  71.     Value *Ops[] = {DataPtr, Stride, B.getInt1(IsVolatile), B.getInt32(Rows),
  72.                     B.getInt32(Columns)};
  73.     Type *OverloadedTypes[] = {RetType, Stride->getType()};
  74.  
  75.     Function *TheFn = Intrinsic::getDeclaration(
  76.         getModule(), Intrinsic::matrix_column_major_load, OverloadedTypes);
  77.  
  78.     CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
  79.     Attribute AlignAttr =
  80.         Attribute::getWithAlignment(Call->getContext(), Alignment);
  81.     Call->addParamAttr(0, AlignAttr);
  82.     return Call;
  83.   }
  84.  
  85.   /// Create a column major, strided matrix store.
  86.   /// \p Matrix  - Matrix to store
  87.   /// \p Ptr     - Pointer to write back to
  88.   /// \p Stride  - Space between columns
  89.   CallInst *CreateColumnMajorStore(Value *Matrix, Value *Ptr, Align Alignment,
  90.                                    Value *Stride, bool IsVolatile,
  91.                                    unsigned Rows, unsigned Columns,
  92.                                    const Twine &Name = "") {
  93.     Value *Ops[] = {Matrix,           Ptr,
  94.                     Stride,           B.getInt1(IsVolatile),
  95.                     B.getInt32(Rows), B.getInt32(Columns)};
  96.     Type *OverloadedTypes[] = {Matrix->getType(), Stride->getType()};
  97.  
  98.     Function *TheFn = Intrinsic::getDeclaration(
  99.         getModule(), Intrinsic::matrix_column_major_store, OverloadedTypes);
  100.  
  101.     CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
  102.     Attribute AlignAttr =
  103.         Attribute::getWithAlignment(Call->getContext(), Alignment);
  104.     Call->addParamAttr(1, AlignAttr);
  105.     return Call;
  106.   }
  107.  
  108.   /// Create a llvm.matrix.transpose call, transposing \p Matrix with \p Rows
  109.   /// rows and \p Columns columns.
  110.   CallInst *CreateMatrixTranspose(Value *Matrix, unsigned Rows,
  111.                                   unsigned Columns, const Twine &Name = "") {
  112.     auto *OpType = cast<VectorType>(Matrix->getType());
  113.     auto *ReturnType =
  114.         FixedVectorType::get(OpType->getElementType(), Rows * Columns);
  115.  
  116.     Type *OverloadedTypes[] = {ReturnType};
  117.     Value *Ops[] = {Matrix, B.getInt32(Rows), B.getInt32(Columns)};
  118.     Function *TheFn = Intrinsic::getDeclaration(
  119.         getModule(), Intrinsic::matrix_transpose, OverloadedTypes);
  120.  
  121.     return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
  122.   }
  123.  
  124.   /// Create a llvm.matrix.multiply call, multiplying matrixes \p LHS and \p
  125.   /// RHS.
  126.   CallInst *CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows,
  127.                                  unsigned LHSColumns, unsigned RHSColumns,
  128.                                  const Twine &Name = "") {
  129.     auto *LHSType = cast<VectorType>(LHS->getType());
  130.     auto *RHSType = cast<VectorType>(RHS->getType());
  131.  
  132.     auto *ReturnType =
  133.         FixedVectorType::get(LHSType->getElementType(), LHSRows * RHSColumns);
  134.  
  135.     Value *Ops[] = {LHS, RHS, B.getInt32(LHSRows), B.getInt32(LHSColumns),
  136.                     B.getInt32(RHSColumns)};
  137.     Type *OverloadedTypes[] = {ReturnType, LHSType, RHSType};
  138.  
  139.     Function *TheFn = Intrinsic::getDeclaration(
  140.         getModule(), Intrinsic::matrix_multiply, OverloadedTypes);
  141.     return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
  142.   }
  143.  
  144.   /// Insert a single element \p NewVal into \p Matrix at indices (\p RowIdx, \p
  145.   /// ColumnIdx).
  146.   Value *CreateMatrixInsert(Value *Matrix, Value *NewVal, Value *RowIdx,
  147.                             Value *ColumnIdx, unsigned NumRows) {
  148.     return B.CreateInsertElement(
  149.         Matrix, NewVal,
  150.         B.CreateAdd(B.CreateMul(ColumnIdx, ConstantInt::get(
  151.                                                ColumnIdx->getType(), NumRows)),
  152.                     RowIdx));
  153.   }
  154.  
  155.   /// Add matrixes \p LHS and \p RHS. Support both integer and floating point
  156.   /// matrixes.
  157.   Value *CreateAdd(Value *LHS, Value *RHS) {
  158.     assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
  159.     if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
  160.       assert(!isa<ScalableVectorType>(LHS->getType()) &&
  161.              "LHS Assumed to be fixed width");
  162.       RHS = B.CreateVectorSplat(
  163.           cast<VectorType>(LHS->getType())->getElementCount(), RHS,
  164.           "scalar.splat");
  165.     } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
  166.       assert(!isa<ScalableVectorType>(RHS->getType()) &&
  167.              "RHS Assumed to be fixed width");
  168.       LHS = B.CreateVectorSplat(
  169.           cast<VectorType>(RHS->getType())->getElementCount(), LHS,
  170.           "scalar.splat");
  171.     }
  172.  
  173.     return cast<VectorType>(LHS->getType())
  174.                    ->getElementType()
  175.                    ->isFloatingPointTy()
  176.                ? B.CreateFAdd(LHS, RHS)
  177.                : B.CreateAdd(LHS, RHS);
  178.   }
  179.  
  180.   /// Subtract matrixes \p LHS and \p RHS. Support both integer and floating
  181.   /// point matrixes.
  182.   Value *CreateSub(Value *LHS, Value *RHS) {
  183.     assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
  184.     if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
  185.       assert(!isa<ScalableVectorType>(LHS->getType()) &&
  186.              "LHS Assumed to be fixed width");
  187.       RHS = B.CreateVectorSplat(
  188.           cast<VectorType>(LHS->getType())->getElementCount(), RHS,
  189.           "scalar.splat");
  190.     } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
  191.       assert(!isa<ScalableVectorType>(RHS->getType()) &&
  192.              "RHS Assumed to be fixed width");
  193.       LHS = B.CreateVectorSplat(
  194.           cast<VectorType>(RHS->getType())->getElementCount(), LHS,
  195.           "scalar.splat");
  196.     }
  197.  
  198.     return cast<VectorType>(LHS->getType())
  199.                    ->getElementType()
  200.                    ->isFloatingPointTy()
  201.                ? B.CreateFSub(LHS, RHS)
  202.                : B.CreateSub(LHS, RHS);
  203.   }
  204.  
  205.   /// Multiply matrix \p LHS with scalar \p RHS or scalar \p LHS with matrix \p
  206.   /// RHS.
  207.   Value *CreateScalarMultiply(Value *LHS, Value *RHS) {
  208.     std::tie(LHS, RHS) = splatScalarOperandIfNeeded(LHS, RHS);
  209.     if (LHS->getType()->getScalarType()->isFloatingPointTy())
  210.       return B.CreateFMul(LHS, RHS);
  211.     return B.CreateMul(LHS, RHS);
  212.   }
  213.  
  214.   /// Divide matrix \p LHS by scalar \p RHS. If the operands are integers, \p
  215.   /// IsUnsigned indicates whether UDiv or SDiv should be used.
  216.   Value *CreateScalarDiv(Value *LHS, Value *RHS, bool IsUnsigned) {
  217.     assert(LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy());
  218.     assert(!isa<ScalableVectorType>(LHS->getType()) &&
  219.            "LHS Assumed to be fixed width");
  220.     RHS =
  221.         B.CreateVectorSplat(cast<VectorType>(LHS->getType())->getElementCount(),
  222.                             RHS, "scalar.splat");
  223.     return cast<VectorType>(LHS->getType())
  224.                    ->getElementType()
  225.                    ->isFloatingPointTy()
  226.                ? B.CreateFDiv(LHS, RHS)
  227.                : (IsUnsigned ? B.CreateUDiv(LHS, RHS) : B.CreateSDiv(LHS, RHS));
  228.   }
  229.  
  230.   /// Create an assumption that \p Idx is less than \p NumElements.
  231.   void CreateIndexAssumption(Value *Idx, unsigned NumElements,
  232.                              Twine const &Name = "") {
  233.     Value *NumElts =
  234.         B.getIntN(Idx->getType()->getScalarSizeInBits(), NumElements);
  235.     auto *Cmp = B.CreateICmpULT(Idx, NumElts);
  236.     if (isa<ConstantInt>(Cmp))
  237.       assert(cast<ConstantInt>(Cmp)->isOne() && "Index must be valid!");
  238.     else
  239.       B.CreateAssumption(Cmp);
  240.   }
  241.  
  242.   /// Compute the index to access the element at (\p RowIdx, \p ColumnIdx) from
  243.   /// a matrix with \p NumRows embedded in a vector.
  244.   Value *CreateIndex(Value *RowIdx, Value *ColumnIdx, unsigned NumRows,
  245.                      Twine const &Name = "") {
  246.     unsigned MaxWidth = std::max(RowIdx->getType()->getScalarSizeInBits(),
  247.                                  ColumnIdx->getType()->getScalarSizeInBits());
  248.     Type *IntTy = IntegerType::get(RowIdx->getType()->getContext(), MaxWidth);
  249.     RowIdx = B.CreateZExt(RowIdx, IntTy);
  250.     ColumnIdx = B.CreateZExt(ColumnIdx, IntTy);
  251.     Value *NumRowsV = B.getIntN(MaxWidth, NumRows);
  252.     return B.CreateAdd(B.CreateMul(ColumnIdx, NumRowsV), RowIdx);
  253.   }
  254. };
  255.  
  256. } // end namespace llvm
  257.  
  258. #endif // LLVM_IR_MATRIXBUILDER_H
  259.