Details | Last modification | View Log | RSS feed
Rev | Author | Line No. | Line |
---|---|---|---|
14 | pmbaty | 1 | //===- llvm/MatrixBuilder.h - Builder to lower matrix ops -------*- C++ -*-===// |
2 | // |
||
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
||
4 | // See https://llvm.org/LICENSE.txt for license information. |
||
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
||
6 | // |
||
7 | //===----------------------------------------------------------------------===// |
||
8 | // |
||
9 | // This file defines the MatrixBuilder class, which is used as a convenient way |
||
10 | // to lower matrix operations to LLVM IR. |
||
11 | // |
||
12 | //===----------------------------------------------------------------------===// |
||
13 | |||
14 | #ifndef LLVM_IR_MATRIXBUILDER_H |
||
15 | #define LLVM_IR_MATRIXBUILDER_H |
||
16 | |||
17 | #include "llvm/IR/Constant.h" |
||
18 | #include "llvm/IR/Constants.h" |
||
19 | #include "llvm/IR/IRBuilder.h" |
||
20 | #include "llvm/IR/InstrTypes.h" |
||
21 | #include "llvm/IR/Instruction.h" |
||
22 | #include "llvm/IR/IntrinsicInst.h" |
||
23 | #include "llvm/IR/Type.h" |
||
24 | #include "llvm/IR/Value.h" |
||
25 | #include "llvm/Support/Alignment.h" |
||
26 | |||
27 | namespace llvm { |
||
28 | |||
29 | class Function; |
||
30 | class Twine; |
||
31 | class Module; |
||
32 | |||
33 | class MatrixBuilder { |
||
34 | IRBuilderBase &B; |
||
35 | Module *getModule() { return B.GetInsertBlock()->getParent()->getParent(); } |
||
36 | |||
37 | std::pair<Value *, Value *> splatScalarOperandIfNeeded(Value *LHS, |
||
38 | Value *RHS) { |
||
39 | assert((LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()) && |
||
40 | "One of the operands must be a matrix (embedded in a vector)"); |
||
41 | if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) { |
||
42 | assert(!isa<ScalableVectorType>(LHS->getType()) && |
||
43 | "LHS Assumed to be fixed width"); |
||
44 | RHS = B.CreateVectorSplat( |
||
45 | cast<VectorType>(LHS->getType())->getElementCount(), RHS, |
||
46 | "scalar.splat"); |
||
47 | } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) { |
||
48 | assert(!isa<ScalableVectorType>(RHS->getType()) && |
||
49 | "RHS Assumed to be fixed width"); |
||
50 | LHS = B.CreateVectorSplat( |
||
51 | cast<VectorType>(RHS->getType())->getElementCount(), LHS, |
||
52 | "scalar.splat"); |
||
53 | } |
||
54 | return {LHS, RHS}; |
||
55 | } |
||
56 | |||
57 | public: |
||
58 | MatrixBuilder(IRBuilderBase &Builder) : B(Builder) {} |
||
59 | |||
60 | /// Create a column major, strided matrix load. |
||
61 | /// \p EltTy - Matrix element type |
||
62 | /// \p DataPtr - Start address of the matrix read |
||
63 | /// \p Rows - Number of rows in matrix (must be a constant) |
||
64 | /// \p Columns - Number of columns in matrix (must be a constant) |
||
65 | /// \p Stride - Space between columns |
||
66 | CallInst *CreateColumnMajorLoad(Type *EltTy, Value *DataPtr, Align Alignment, |
||
67 | Value *Stride, bool IsVolatile, unsigned Rows, |
||
68 | unsigned Columns, const Twine &Name = "") { |
||
69 | auto *RetType = FixedVectorType::get(EltTy, Rows * Columns); |
||
70 | |||
71 | Value *Ops[] = {DataPtr, Stride, B.getInt1(IsVolatile), B.getInt32(Rows), |
||
72 | B.getInt32(Columns)}; |
||
73 | Type *OverloadedTypes[] = {RetType, Stride->getType()}; |
||
74 | |||
75 | Function *TheFn = Intrinsic::getDeclaration( |
||
76 | getModule(), Intrinsic::matrix_column_major_load, OverloadedTypes); |
||
77 | |||
78 | CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name); |
||
79 | Attribute AlignAttr = |
||
80 | Attribute::getWithAlignment(Call->getContext(), Alignment); |
||
81 | Call->addParamAttr(0, AlignAttr); |
||
82 | return Call; |
||
83 | } |
||
84 | |||
85 | /// Create a column major, strided matrix store. |
||
86 | /// \p Matrix - Matrix to store |
||
87 | /// \p Ptr - Pointer to write back to |
||
88 | /// \p Stride - Space between columns |
||
89 | CallInst *CreateColumnMajorStore(Value *Matrix, Value *Ptr, Align Alignment, |
||
90 | Value *Stride, bool IsVolatile, |
||
91 | unsigned Rows, unsigned Columns, |
||
92 | const Twine &Name = "") { |
||
93 | Value *Ops[] = {Matrix, Ptr, |
||
94 | Stride, B.getInt1(IsVolatile), |
||
95 | B.getInt32(Rows), B.getInt32(Columns)}; |
||
96 | Type *OverloadedTypes[] = {Matrix->getType(), Stride->getType()}; |
||
97 | |||
98 | Function *TheFn = Intrinsic::getDeclaration( |
||
99 | getModule(), Intrinsic::matrix_column_major_store, OverloadedTypes); |
||
100 | |||
101 | CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name); |
||
102 | Attribute AlignAttr = |
||
103 | Attribute::getWithAlignment(Call->getContext(), Alignment); |
||
104 | Call->addParamAttr(1, AlignAttr); |
||
105 | return Call; |
||
106 | } |
||
107 | |||
108 | /// Create a llvm.matrix.transpose call, transposing \p Matrix with \p Rows |
||
109 | /// rows and \p Columns columns. |
||
110 | CallInst *CreateMatrixTranspose(Value *Matrix, unsigned Rows, |
||
111 | unsigned Columns, const Twine &Name = "") { |
||
112 | auto *OpType = cast<VectorType>(Matrix->getType()); |
||
113 | auto *ReturnType = |
||
114 | FixedVectorType::get(OpType->getElementType(), Rows * Columns); |
||
115 | |||
116 | Type *OverloadedTypes[] = {ReturnType}; |
||
117 | Value *Ops[] = {Matrix, B.getInt32(Rows), B.getInt32(Columns)}; |
||
118 | Function *TheFn = Intrinsic::getDeclaration( |
||
119 | getModule(), Intrinsic::matrix_transpose, OverloadedTypes); |
||
120 | |||
121 | return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name); |
||
122 | } |
||
123 | |||
124 | /// Create a llvm.matrix.multiply call, multiplying matrixes \p LHS and \p |
||
125 | /// RHS. |
||
126 | CallInst *CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows, |
||
127 | unsigned LHSColumns, unsigned RHSColumns, |
||
128 | const Twine &Name = "") { |
||
129 | auto *LHSType = cast<VectorType>(LHS->getType()); |
||
130 | auto *RHSType = cast<VectorType>(RHS->getType()); |
||
131 | |||
132 | auto *ReturnType = |
||
133 | FixedVectorType::get(LHSType->getElementType(), LHSRows * RHSColumns); |
||
134 | |||
135 | Value *Ops[] = {LHS, RHS, B.getInt32(LHSRows), B.getInt32(LHSColumns), |
||
136 | B.getInt32(RHSColumns)}; |
||
137 | Type *OverloadedTypes[] = {ReturnType, LHSType, RHSType}; |
||
138 | |||
139 | Function *TheFn = Intrinsic::getDeclaration( |
||
140 | getModule(), Intrinsic::matrix_multiply, OverloadedTypes); |
||
141 | return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name); |
||
142 | } |
||
143 | |||
144 | /// Insert a single element \p NewVal into \p Matrix at indices (\p RowIdx, \p |
||
145 | /// ColumnIdx). |
||
146 | Value *CreateMatrixInsert(Value *Matrix, Value *NewVal, Value *RowIdx, |
||
147 | Value *ColumnIdx, unsigned NumRows) { |
||
148 | return B.CreateInsertElement( |
||
149 | Matrix, NewVal, |
||
150 | B.CreateAdd(B.CreateMul(ColumnIdx, ConstantInt::get( |
||
151 | ColumnIdx->getType(), NumRows)), |
||
152 | RowIdx)); |
||
153 | } |
||
154 | |||
155 | /// Add matrixes \p LHS and \p RHS. Support both integer and floating point |
||
156 | /// matrixes. |
||
157 | Value *CreateAdd(Value *LHS, Value *RHS) { |
||
158 | assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()); |
||
159 | if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) { |
||
160 | assert(!isa<ScalableVectorType>(LHS->getType()) && |
||
161 | "LHS Assumed to be fixed width"); |
||
162 | RHS = B.CreateVectorSplat( |
||
163 | cast<VectorType>(LHS->getType())->getElementCount(), RHS, |
||
164 | "scalar.splat"); |
||
165 | } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) { |
||
166 | assert(!isa<ScalableVectorType>(RHS->getType()) && |
||
167 | "RHS Assumed to be fixed width"); |
||
168 | LHS = B.CreateVectorSplat( |
||
169 | cast<VectorType>(RHS->getType())->getElementCount(), LHS, |
||
170 | "scalar.splat"); |
||
171 | } |
||
172 | |||
173 | return cast<VectorType>(LHS->getType()) |
||
174 | ->getElementType() |
||
175 | ->isFloatingPointTy() |
||
176 | ? B.CreateFAdd(LHS, RHS) |
||
177 | : B.CreateAdd(LHS, RHS); |
||
178 | } |
||
179 | |||
180 | /// Subtract matrixes \p LHS and \p RHS. Support both integer and floating |
||
181 | /// point matrixes. |
||
182 | Value *CreateSub(Value *LHS, Value *RHS) { |
||
183 | assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()); |
||
184 | if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) { |
||
185 | assert(!isa<ScalableVectorType>(LHS->getType()) && |
||
186 | "LHS Assumed to be fixed width"); |
||
187 | RHS = B.CreateVectorSplat( |
||
188 | cast<VectorType>(LHS->getType())->getElementCount(), RHS, |
||
189 | "scalar.splat"); |
||
190 | } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) { |
||
191 | assert(!isa<ScalableVectorType>(RHS->getType()) && |
||
192 | "RHS Assumed to be fixed width"); |
||
193 | LHS = B.CreateVectorSplat( |
||
194 | cast<VectorType>(RHS->getType())->getElementCount(), LHS, |
||
195 | "scalar.splat"); |
||
196 | } |
||
197 | |||
198 | return cast<VectorType>(LHS->getType()) |
||
199 | ->getElementType() |
||
200 | ->isFloatingPointTy() |
||
201 | ? B.CreateFSub(LHS, RHS) |
||
202 | : B.CreateSub(LHS, RHS); |
||
203 | } |
||
204 | |||
205 | /// Multiply matrix \p LHS with scalar \p RHS or scalar \p LHS with matrix \p |
||
206 | /// RHS. |
||
207 | Value *CreateScalarMultiply(Value *LHS, Value *RHS) { |
||
208 | std::tie(LHS, RHS) = splatScalarOperandIfNeeded(LHS, RHS); |
||
209 | if (LHS->getType()->getScalarType()->isFloatingPointTy()) |
||
210 | return B.CreateFMul(LHS, RHS); |
||
211 | return B.CreateMul(LHS, RHS); |
||
212 | } |
||
213 | |||
214 | /// Divide matrix \p LHS by scalar \p RHS. If the operands are integers, \p |
||
215 | /// IsUnsigned indicates whether UDiv or SDiv should be used. |
||
216 | Value *CreateScalarDiv(Value *LHS, Value *RHS, bool IsUnsigned) { |
||
217 | assert(LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()); |
||
218 | assert(!isa<ScalableVectorType>(LHS->getType()) && |
||
219 | "LHS Assumed to be fixed width"); |
||
220 | RHS = |
||
221 | B.CreateVectorSplat(cast<VectorType>(LHS->getType())->getElementCount(), |
||
222 | RHS, "scalar.splat"); |
||
223 | return cast<VectorType>(LHS->getType()) |
||
224 | ->getElementType() |
||
225 | ->isFloatingPointTy() |
||
226 | ? B.CreateFDiv(LHS, RHS) |
||
227 | : (IsUnsigned ? B.CreateUDiv(LHS, RHS) : B.CreateSDiv(LHS, RHS)); |
||
228 | } |
||
229 | |||
230 | /// Create an assumption that \p Idx is less than \p NumElements. |
||
231 | void CreateIndexAssumption(Value *Idx, unsigned NumElements, |
||
232 | Twine const &Name = "") { |
||
233 | Value *NumElts = |
||
234 | B.getIntN(Idx->getType()->getScalarSizeInBits(), NumElements); |
||
235 | auto *Cmp = B.CreateICmpULT(Idx, NumElts); |
||
236 | if (isa<ConstantInt>(Cmp)) |
||
237 | assert(cast<ConstantInt>(Cmp)->isOne() && "Index must be valid!"); |
||
238 | else |
||
239 | B.CreateAssumption(Cmp); |
||
240 | } |
||
241 | |||
242 | /// Compute the index to access the element at (\p RowIdx, \p ColumnIdx) from |
||
243 | /// a matrix with \p NumRows embedded in a vector. |
||
244 | Value *CreateIndex(Value *RowIdx, Value *ColumnIdx, unsigned NumRows, |
||
245 | Twine const &Name = "") { |
||
246 | unsigned MaxWidth = std::max(RowIdx->getType()->getScalarSizeInBits(), |
||
247 | ColumnIdx->getType()->getScalarSizeInBits()); |
||
248 | Type *IntTy = IntegerType::get(RowIdx->getType()->getContext(), MaxWidth); |
||
249 | RowIdx = B.CreateZExt(RowIdx, IntTy); |
||
250 | ColumnIdx = B.CreateZExt(ColumnIdx, IntTy); |
||
251 | Value *NumRowsV = B.getIntN(MaxWidth, NumRows); |
||
252 | return B.CreateAdd(B.CreateMul(ColumnIdx, NumRowsV), RowIdx); |
||
253 | } |
||
254 | }; |
||
255 | |||
256 | } // end namespace llvm |
||
257 | |||
258 | #endif // LLVM_IR_MATRIXBUILDER_H |