Details | Last modification | View Log | RSS feed
| Rev | Author | Line No. | Line |
|---|---|---|---|
| 14 | pmbaty | 1 | //===- llvm/MatrixBuilder.h - Builder to lower matrix ops -------*- C++ -*-===// |
| 2 | // |
||
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
||
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
||
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
||
| 6 | // |
||
| 7 | //===----------------------------------------------------------------------===// |
||
| 8 | // |
||
| 9 | // This file defines the MatrixBuilder class, which is used as a convenient way |
||
| 10 | // to lower matrix operations to LLVM IR. |
||
| 11 | // |
||
| 12 | //===----------------------------------------------------------------------===// |
||
| 13 | |||
| 14 | #ifndef LLVM_IR_MATRIXBUILDER_H |
||
| 15 | #define LLVM_IR_MATRIXBUILDER_H |
||
| 16 | |||
| 17 | #include "llvm/IR/Constant.h" |
||
| 18 | #include "llvm/IR/Constants.h" |
||
| 19 | #include "llvm/IR/IRBuilder.h" |
||
| 20 | #include "llvm/IR/InstrTypes.h" |
||
| 21 | #include "llvm/IR/Instruction.h" |
||
| 22 | #include "llvm/IR/IntrinsicInst.h" |
||
| 23 | #include "llvm/IR/Type.h" |
||
| 24 | #include "llvm/IR/Value.h" |
||
| 25 | #include "llvm/Support/Alignment.h" |
||
| 26 | |||
| 27 | namespace llvm { |
||
| 28 | |||
| 29 | class Function; |
||
| 30 | class Twine; |
||
| 31 | class Module; |
||
| 32 | |||
| 33 | class MatrixBuilder { |
||
| 34 | IRBuilderBase &B; |
||
| 35 | Module *getModule() { return B.GetInsertBlock()->getParent()->getParent(); } |
||
| 36 | |||
| 37 | std::pair<Value *, Value *> splatScalarOperandIfNeeded(Value *LHS, |
||
| 38 | Value *RHS) { |
||
| 39 | assert((LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()) && |
||
| 40 | "One of the operands must be a matrix (embedded in a vector)"); |
||
| 41 | if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) { |
||
| 42 | assert(!isa<ScalableVectorType>(LHS->getType()) && |
||
| 43 | "LHS Assumed to be fixed width"); |
||
| 44 | RHS = B.CreateVectorSplat( |
||
| 45 | cast<VectorType>(LHS->getType())->getElementCount(), RHS, |
||
| 46 | "scalar.splat"); |
||
| 47 | } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) { |
||
| 48 | assert(!isa<ScalableVectorType>(RHS->getType()) && |
||
| 49 | "RHS Assumed to be fixed width"); |
||
| 50 | LHS = B.CreateVectorSplat( |
||
| 51 | cast<VectorType>(RHS->getType())->getElementCount(), LHS, |
||
| 52 | "scalar.splat"); |
||
| 53 | } |
||
| 54 | return {LHS, RHS}; |
||
| 55 | } |
||
| 56 | |||
| 57 | public: |
||
| 58 | MatrixBuilder(IRBuilderBase &Builder) : B(Builder) {} |
||
| 59 | |||
| 60 | /// Create a column major, strided matrix load. |
||
| 61 | /// \p EltTy - Matrix element type |
||
| 62 | /// \p DataPtr - Start address of the matrix read |
||
| 63 | /// \p Rows - Number of rows in matrix (must be a constant) |
||
| 64 | /// \p Columns - Number of columns in matrix (must be a constant) |
||
| 65 | /// \p Stride - Space between columns |
||
| 66 | CallInst *CreateColumnMajorLoad(Type *EltTy, Value *DataPtr, Align Alignment, |
||
| 67 | Value *Stride, bool IsVolatile, unsigned Rows, |
||
| 68 | unsigned Columns, const Twine &Name = "") { |
||
| 69 | auto *RetType = FixedVectorType::get(EltTy, Rows * Columns); |
||
| 70 | |||
| 71 | Value *Ops[] = {DataPtr, Stride, B.getInt1(IsVolatile), B.getInt32(Rows), |
||
| 72 | B.getInt32(Columns)}; |
||
| 73 | Type *OverloadedTypes[] = {RetType, Stride->getType()}; |
||
| 74 | |||
| 75 | Function *TheFn = Intrinsic::getDeclaration( |
||
| 76 | getModule(), Intrinsic::matrix_column_major_load, OverloadedTypes); |
||
| 77 | |||
| 78 | CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name); |
||
| 79 | Attribute AlignAttr = |
||
| 80 | Attribute::getWithAlignment(Call->getContext(), Alignment); |
||
| 81 | Call->addParamAttr(0, AlignAttr); |
||
| 82 | return Call; |
||
| 83 | } |
||
| 84 | |||
| 85 | /// Create a column major, strided matrix store. |
||
| 86 | /// \p Matrix - Matrix to store |
||
| 87 | /// \p Ptr - Pointer to write back to |
||
| 88 | /// \p Stride - Space between columns |
||
| 89 | CallInst *CreateColumnMajorStore(Value *Matrix, Value *Ptr, Align Alignment, |
||
| 90 | Value *Stride, bool IsVolatile, |
||
| 91 | unsigned Rows, unsigned Columns, |
||
| 92 | const Twine &Name = "") { |
||
| 93 | Value *Ops[] = {Matrix, Ptr, |
||
| 94 | Stride, B.getInt1(IsVolatile), |
||
| 95 | B.getInt32(Rows), B.getInt32(Columns)}; |
||
| 96 | Type *OverloadedTypes[] = {Matrix->getType(), Stride->getType()}; |
||
| 97 | |||
| 98 | Function *TheFn = Intrinsic::getDeclaration( |
||
| 99 | getModule(), Intrinsic::matrix_column_major_store, OverloadedTypes); |
||
| 100 | |||
| 101 | CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name); |
||
| 102 | Attribute AlignAttr = |
||
| 103 | Attribute::getWithAlignment(Call->getContext(), Alignment); |
||
| 104 | Call->addParamAttr(1, AlignAttr); |
||
| 105 | return Call; |
||
| 106 | } |
||
| 107 | |||
| 108 | /// Create a llvm.matrix.transpose call, transposing \p Matrix with \p Rows |
||
| 109 | /// rows and \p Columns columns. |
||
| 110 | CallInst *CreateMatrixTranspose(Value *Matrix, unsigned Rows, |
||
| 111 | unsigned Columns, const Twine &Name = "") { |
||
| 112 | auto *OpType = cast<VectorType>(Matrix->getType()); |
||
| 113 | auto *ReturnType = |
||
| 114 | FixedVectorType::get(OpType->getElementType(), Rows * Columns); |
||
| 115 | |||
| 116 | Type *OverloadedTypes[] = {ReturnType}; |
||
| 117 | Value *Ops[] = {Matrix, B.getInt32(Rows), B.getInt32(Columns)}; |
||
| 118 | Function *TheFn = Intrinsic::getDeclaration( |
||
| 119 | getModule(), Intrinsic::matrix_transpose, OverloadedTypes); |
||
| 120 | |||
| 121 | return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name); |
||
| 122 | } |
||
| 123 | |||
| 124 | /// Create a llvm.matrix.multiply call, multiplying matrixes \p LHS and \p |
||
| 125 | /// RHS. |
||
| 126 | CallInst *CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows, |
||
| 127 | unsigned LHSColumns, unsigned RHSColumns, |
||
| 128 | const Twine &Name = "") { |
||
| 129 | auto *LHSType = cast<VectorType>(LHS->getType()); |
||
| 130 | auto *RHSType = cast<VectorType>(RHS->getType()); |
||
| 131 | |||
| 132 | auto *ReturnType = |
||
| 133 | FixedVectorType::get(LHSType->getElementType(), LHSRows * RHSColumns); |
||
| 134 | |||
| 135 | Value *Ops[] = {LHS, RHS, B.getInt32(LHSRows), B.getInt32(LHSColumns), |
||
| 136 | B.getInt32(RHSColumns)}; |
||
| 137 | Type *OverloadedTypes[] = {ReturnType, LHSType, RHSType}; |
||
| 138 | |||
| 139 | Function *TheFn = Intrinsic::getDeclaration( |
||
| 140 | getModule(), Intrinsic::matrix_multiply, OverloadedTypes); |
||
| 141 | return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name); |
||
| 142 | } |
||
| 143 | |||
| 144 | /// Insert a single element \p NewVal into \p Matrix at indices (\p RowIdx, \p |
||
| 145 | /// ColumnIdx). |
||
| 146 | Value *CreateMatrixInsert(Value *Matrix, Value *NewVal, Value *RowIdx, |
||
| 147 | Value *ColumnIdx, unsigned NumRows) { |
||
| 148 | return B.CreateInsertElement( |
||
| 149 | Matrix, NewVal, |
||
| 150 | B.CreateAdd(B.CreateMul(ColumnIdx, ConstantInt::get( |
||
| 151 | ColumnIdx->getType(), NumRows)), |
||
| 152 | RowIdx)); |
||
| 153 | } |
||
| 154 | |||
| 155 | /// Add matrixes \p LHS and \p RHS. Support both integer and floating point |
||
| 156 | /// matrixes. |
||
| 157 | Value *CreateAdd(Value *LHS, Value *RHS) { |
||
| 158 | assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()); |
||
| 159 | if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) { |
||
| 160 | assert(!isa<ScalableVectorType>(LHS->getType()) && |
||
| 161 | "LHS Assumed to be fixed width"); |
||
| 162 | RHS = B.CreateVectorSplat( |
||
| 163 | cast<VectorType>(LHS->getType())->getElementCount(), RHS, |
||
| 164 | "scalar.splat"); |
||
| 165 | } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) { |
||
| 166 | assert(!isa<ScalableVectorType>(RHS->getType()) && |
||
| 167 | "RHS Assumed to be fixed width"); |
||
| 168 | LHS = B.CreateVectorSplat( |
||
| 169 | cast<VectorType>(RHS->getType())->getElementCount(), LHS, |
||
| 170 | "scalar.splat"); |
||
| 171 | } |
||
| 172 | |||
| 173 | return cast<VectorType>(LHS->getType()) |
||
| 174 | ->getElementType() |
||
| 175 | ->isFloatingPointTy() |
||
| 176 | ? B.CreateFAdd(LHS, RHS) |
||
| 177 | : B.CreateAdd(LHS, RHS); |
||
| 178 | } |
||
| 179 | |||
| 180 | /// Subtract matrixes \p LHS and \p RHS. Support both integer and floating |
||
| 181 | /// point matrixes. |
||
| 182 | Value *CreateSub(Value *LHS, Value *RHS) { |
||
| 183 | assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()); |
||
| 184 | if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) { |
||
| 185 | assert(!isa<ScalableVectorType>(LHS->getType()) && |
||
| 186 | "LHS Assumed to be fixed width"); |
||
| 187 | RHS = B.CreateVectorSplat( |
||
| 188 | cast<VectorType>(LHS->getType())->getElementCount(), RHS, |
||
| 189 | "scalar.splat"); |
||
| 190 | } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) { |
||
| 191 | assert(!isa<ScalableVectorType>(RHS->getType()) && |
||
| 192 | "RHS Assumed to be fixed width"); |
||
| 193 | LHS = B.CreateVectorSplat( |
||
| 194 | cast<VectorType>(RHS->getType())->getElementCount(), LHS, |
||
| 195 | "scalar.splat"); |
||
| 196 | } |
||
| 197 | |||
| 198 | return cast<VectorType>(LHS->getType()) |
||
| 199 | ->getElementType() |
||
| 200 | ->isFloatingPointTy() |
||
| 201 | ? B.CreateFSub(LHS, RHS) |
||
| 202 | : B.CreateSub(LHS, RHS); |
||
| 203 | } |
||
| 204 | |||
| 205 | /// Multiply matrix \p LHS with scalar \p RHS or scalar \p LHS with matrix \p |
||
| 206 | /// RHS. |
||
| 207 | Value *CreateScalarMultiply(Value *LHS, Value *RHS) { |
||
| 208 | std::tie(LHS, RHS) = splatScalarOperandIfNeeded(LHS, RHS); |
||
| 209 | if (LHS->getType()->getScalarType()->isFloatingPointTy()) |
||
| 210 | return B.CreateFMul(LHS, RHS); |
||
| 211 | return B.CreateMul(LHS, RHS); |
||
| 212 | } |
||
| 213 | |||
| 214 | /// Divide matrix \p LHS by scalar \p RHS. If the operands are integers, \p |
||
| 215 | /// IsUnsigned indicates whether UDiv or SDiv should be used. |
||
| 216 | Value *CreateScalarDiv(Value *LHS, Value *RHS, bool IsUnsigned) { |
||
| 217 | assert(LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()); |
||
| 218 | assert(!isa<ScalableVectorType>(LHS->getType()) && |
||
| 219 | "LHS Assumed to be fixed width"); |
||
| 220 | RHS = |
||
| 221 | B.CreateVectorSplat(cast<VectorType>(LHS->getType())->getElementCount(), |
||
| 222 | RHS, "scalar.splat"); |
||
| 223 | return cast<VectorType>(LHS->getType()) |
||
| 224 | ->getElementType() |
||
| 225 | ->isFloatingPointTy() |
||
| 226 | ? B.CreateFDiv(LHS, RHS) |
||
| 227 | : (IsUnsigned ? B.CreateUDiv(LHS, RHS) : B.CreateSDiv(LHS, RHS)); |
||
| 228 | } |
||
| 229 | |||
| 230 | /// Create an assumption that \p Idx is less than \p NumElements. |
||
| 231 | void CreateIndexAssumption(Value *Idx, unsigned NumElements, |
||
| 232 | Twine const &Name = "") { |
||
| 233 | Value *NumElts = |
||
| 234 | B.getIntN(Idx->getType()->getScalarSizeInBits(), NumElements); |
||
| 235 | auto *Cmp = B.CreateICmpULT(Idx, NumElts); |
||
| 236 | if (isa<ConstantInt>(Cmp)) |
||
| 237 | assert(cast<ConstantInt>(Cmp)->isOne() && "Index must be valid!"); |
||
| 238 | else |
||
| 239 | B.CreateAssumption(Cmp); |
||
| 240 | } |
||
| 241 | |||
| 242 | /// Compute the index to access the element at (\p RowIdx, \p ColumnIdx) from |
||
| 243 | /// a matrix with \p NumRows embedded in a vector. |
||
| 244 | Value *CreateIndex(Value *RowIdx, Value *ColumnIdx, unsigned NumRows, |
||
| 245 | Twine const &Name = "") { |
||
| 246 | unsigned MaxWidth = std::max(RowIdx->getType()->getScalarSizeInBits(), |
||
| 247 | ColumnIdx->getType()->getScalarSizeInBits()); |
||
| 248 | Type *IntTy = IntegerType::get(RowIdx->getType()->getContext(), MaxWidth); |
||
| 249 | RowIdx = B.CreateZExt(RowIdx, IntTy); |
||
| 250 | ColumnIdx = B.CreateZExt(ColumnIdx, IntTy); |
||
| 251 | Value *NumRowsV = B.getIntN(MaxWidth, NumRows); |
||
| 252 | return B.CreateAdd(B.CreateMul(ColumnIdx, NumRowsV), RowIdx); |
||
| 253 | } |
||
| 254 | }; |
||
| 255 | |||
| 256 | } // end namespace llvm |
||
| 257 | |||
| 258 | #endif // LLVM_IR_MATRIXBUILDER_H |