Subversion Repositories QNX 8.QNX8 LLVM/Clang compiler suite

Rev

Details | Last modification | View Log | RSS feed

Rev Author Line No. Line
14 pmbaty 1
//===- llvm/IR/MatrixBuilder.h - Builder to lower matrix ops ----*- C++ -*-===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
// This file defines the MatrixBuilder class, which is used as a convenient way
10
// to lower matrix operations to LLVM IR.
11
//
12
//===----------------------------------------------------------------------===//
13
 
14
#ifndef LLVM_IR_MATRIXBUILDER_H
15
#define LLVM_IR_MATRIXBUILDER_H
16
 
17
#include "llvm/IR/Constant.h"
18
#include "llvm/IR/Constants.h"
19
#include "llvm/IR/IRBuilder.h"
20
#include "llvm/IR/InstrTypes.h"
21
#include "llvm/IR/Instruction.h"
22
#include "llvm/IR/IntrinsicInst.h"
23
#include "llvm/IR/Type.h"
24
#include "llvm/IR/Value.h"
25
#include "llvm/Support/Alignment.h"
26
 
27
namespace llvm {
28
 
29
class Function;
30
class Twine;
31
class Module;
32
 
33
class MatrixBuilder {
34
  IRBuilderBase &B;
35
  Module *getModule() { return B.GetInsertBlock()->getParent()->getParent(); }
36
 
37
  std::pair<Value *, Value *> splatScalarOperandIfNeeded(Value *LHS,
38
                                                         Value *RHS) {
39
    assert((LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()) &&
40
           "One of the operands must be a matrix (embedded in a vector)");
41
    if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
42
      assert(!isa<ScalableVectorType>(LHS->getType()) &&
43
             "LHS Assumed to be fixed width");
44
      RHS = B.CreateVectorSplat(
45
          cast<VectorType>(LHS->getType())->getElementCount(), RHS,
46
          "scalar.splat");
47
    } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
48
      assert(!isa<ScalableVectorType>(RHS->getType()) &&
49
             "RHS Assumed to be fixed width");
50
      LHS = B.CreateVectorSplat(
51
          cast<VectorType>(RHS->getType())->getElementCount(), LHS,
52
          "scalar.splat");
53
    }
54
    return {LHS, RHS};
55
  }
56
 
57
public:
58
  MatrixBuilder(IRBuilderBase &Builder) : B(Builder) {}
59
 
60
  /// Create a column major, strided matrix load.
61
  /// \p EltTy   - Matrix element type
62
  /// \p DataPtr - Start address of the matrix read
63
  /// \p Rows    - Number of rows in matrix (must be a constant)
64
  /// \p Columns - Number of columns in matrix (must be a constant)
65
  /// \p Stride  - Space between columns
66
  CallInst *CreateColumnMajorLoad(Type *EltTy, Value *DataPtr, Align Alignment,
67
                                  Value *Stride, bool IsVolatile, unsigned Rows,
68
                                  unsigned Columns, const Twine &Name = "") {
69
    auto *RetType = FixedVectorType::get(EltTy, Rows * Columns);
70
 
71
    Value *Ops[] = {DataPtr, Stride, B.getInt1(IsVolatile), B.getInt32(Rows),
72
                    B.getInt32(Columns)};
73
    Type *OverloadedTypes[] = {RetType, Stride->getType()};
74
 
75
    Function *TheFn = Intrinsic::getDeclaration(
76
        getModule(), Intrinsic::matrix_column_major_load, OverloadedTypes);
77
 
78
    CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
79
    Attribute AlignAttr =
80
        Attribute::getWithAlignment(Call->getContext(), Alignment);
81
    Call->addParamAttr(0, AlignAttr);
82
    return Call;
83
  }
84
 
85
  /// Create a column major, strided matrix store.
86
  /// \p Matrix  - Matrix to store
87
  /// \p Ptr     - Pointer to write back to
88
  /// \p Stride  - Space between columns
89
  CallInst *CreateColumnMajorStore(Value *Matrix, Value *Ptr, Align Alignment,
90
                                   Value *Stride, bool IsVolatile,
91
                                   unsigned Rows, unsigned Columns,
92
                                   const Twine &Name = "") {
93
    Value *Ops[] = {Matrix,           Ptr,
94
                    Stride,           B.getInt1(IsVolatile),
95
                    B.getInt32(Rows), B.getInt32(Columns)};
96
    Type *OverloadedTypes[] = {Matrix->getType(), Stride->getType()};
97
 
98
    Function *TheFn = Intrinsic::getDeclaration(
99
        getModule(), Intrinsic::matrix_column_major_store, OverloadedTypes);
100
 
101
    CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
102
    Attribute AlignAttr =
103
        Attribute::getWithAlignment(Call->getContext(), Alignment);
104
    Call->addParamAttr(1, AlignAttr);
105
    return Call;
106
  }
107
 
108
  /// Create a llvm.matrix.transpose call, transposing \p Matrix with \p Rows
109
  /// rows and \p Columns columns.
110
  CallInst *CreateMatrixTranspose(Value *Matrix, unsigned Rows,
111
                                  unsigned Columns, const Twine &Name = "") {
112
    auto *OpType = cast<VectorType>(Matrix->getType());
113
    auto *ReturnType =
114
        FixedVectorType::get(OpType->getElementType(), Rows * Columns);
115
 
116
    Type *OverloadedTypes[] = {ReturnType};
117
    Value *Ops[] = {Matrix, B.getInt32(Rows), B.getInt32(Columns)};
118
    Function *TheFn = Intrinsic::getDeclaration(
119
        getModule(), Intrinsic::matrix_transpose, OverloadedTypes);
120
 
121
    return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
122
  }
123
 
124
  /// Create a llvm.matrix.multiply call, multiplying matrixes \p LHS and \p
125
  /// RHS.
126
  CallInst *CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows,
127
                                 unsigned LHSColumns, unsigned RHSColumns,
128
                                 const Twine &Name = "") {
129
    auto *LHSType = cast<VectorType>(LHS->getType());
130
    auto *RHSType = cast<VectorType>(RHS->getType());
131
 
132
    auto *ReturnType =
133
        FixedVectorType::get(LHSType->getElementType(), LHSRows * RHSColumns);
134
 
135
    Value *Ops[] = {LHS, RHS, B.getInt32(LHSRows), B.getInt32(LHSColumns),
136
                    B.getInt32(RHSColumns)};
137
    Type *OverloadedTypes[] = {ReturnType, LHSType, RHSType};
138
 
139
    Function *TheFn = Intrinsic::getDeclaration(
140
        getModule(), Intrinsic::matrix_multiply, OverloadedTypes);
141
    return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
142
  }
143
 
144
  /// Insert a single element \p NewVal into \p Matrix at indices (\p RowIdx, \p
145
  /// ColumnIdx).
146
  Value *CreateMatrixInsert(Value *Matrix, Value *NewVal, Value *RowIdx,
147
                            Value *ColumnIdx, unsigned NumRows) {
148
    return B.CreateInsertElement(
149
        Matrix, NewVal,
150
        B.CreateAdd(B.CreateMul(ColumnIdx, ConstantInt::get(
151
                                               ColumnIdx->getType(), NumRows)),
152
                    RowIdx));
153
  }
154
 
155
  /// Add matrixes \p LHS and \p RHS. Support both integer and floating point
156
  /// matrixes.
157
  Value *CreateAdd(Value *LHS, Value *RHS) {
158
    assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
159
    if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
160
      assert(!isa<ScalableVectorType>(LHS->getType()) &&
161
             "LHS Assumed to be fixed width");
162
      RHS = B.CreateVectorSplat(
163
          cast<VectorType>(LHS->getType())->getElementCount(), RHS,
164
          "scalar.splat");
165
    } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
166
      assert(!isa<ScalableVectorType>(RHS->getType()) &&
167
             "RHS Assumed to be fixed width");
168
      LHS = B.CreateVectorSplat(
169
          cast<VectorType>(RHS->getType())->getElementCount(), LHS,
170
          "scalar.splat");
171
    }
172
 
173
    return cast<VectorType>(LHS->getType())
174
                   ->getElementType()
175
                   ->isFloatingPointTy()
176
               ? B.CreateFAdd(LHS, RHS)
177
               : B.CreateAdd(LHS, RHS);
178
  }
179
 
180
  /// Subtract matrixes \p LHS and \p RHS. Support both integer and floating
181
  /// point matrixes.
182
  Value *CreateSub(Value *LHS, Value *RHS) {
183
    assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
184
    if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
185
      assert(!isa<ScalableVectorType>(LHS->getType()) &&
186
             "LHS Assumed to be fixed width");
187
      RHS = B.CreateVectorSplat(
188
          cast<VectorType>(LHS->getType())->getElementCount(), RHS,
189
          "scalar.splat");
190
    } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
191
      assert(!isa<ScalableVectorType>(RHS->getType()) &&
192
             "RHS Assumed to be fixed width");
193
      LHS = B.CreateVectorSplat(
194
          cast<VectorType>(RHS->getType())->getElementCount(), LHS,
195
          "scalar.splat");
196
    }
197
 
198
    return cast<VectorType>(LHS->getType())
199
                   ->getElementType()
200
                   ->isFloatingPointTy()
201
               ? B.CreateFSub(LHS, RHS)
202
               : B.CreateSub(LHS, RHS);
203
  }
204
 
205
  /// Multiply matrix \p LHS with scalar \p RHS or scalar \p LHS with matrix \p
206
  /// RHS.
207
  Value *CreateScalarMultiply(Value *LHS, Value *RHS) {
208
    std::tie(LHS, RHS) = splatScalarOperandIfNeeded(LHS, RHS);
209
    if (LHS->getType()->getScalarType()->isFloatingPointTy())
210
      return B.CreateFMul(LHS, RHS);
211
    return B.CreateMul(LHS, RHS);
212
  }
213
 
214
  /// Divide matrix \p LHS by scalar \p RHS. If the operands are integers, \p
215
  /// IsUnsigned indicates whether UDiv or SDiv should be used.
216
  Value *CreateScalarDiv(Value *LHS, Value *RHS, bool IsUnsigned) {
217
    assert(LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy());
218
    assert(!isa<ScalableVectorType>(LHS->getType()) &&
219
           "LHS Assumed to be fixed width");
220
    RHS =
221
        B.CreateVectorSplat(cast<VectorType>(LHS->getType())->getElementCount(),
222
                            RHS, "scalar.splat");
223
    return cast<VectorType>(LHS->getType())
224
                   ->getElementType()
225
                   ->isFloatingPointTy()
226
               ? B.CreateFDiv(LHS, RHS)
227
               : (IsUnsigned ? B.CreateUDiv(LHS, RHS) : B.CreateSDiv(LHS, RHS));
228
  }
229
 
230
  /// Create an assumption that \p Idx is less than \p NumElements.
231
  void CreateIndexAssumption(Value *Idx, unsigned NumElements,
232
                             Twine const &Name = "") {
233
    Value *NumElts =
234
        B.getIntN(Idx->getType()->getScalarSizeInBits(), NumElements);
235
    auto *Cmp = B.CreateICmpULT(Idx, NumElts);
236
    if (isa<ConstantInt>(Cmp))
237
      assert(cast<ConstantInt>(Cmp)->isOne() && "Index must be valid!");
238
    else
239
      B.CreateAssumption(Cmp);
240
  }
241
 
242
  /// Compute the index to access the element at (\p RowIdx, \p ColumnIdx) from
243
  /// a matrix with \p NumRows embedded in a vector.
244
  Value *CreateIndex(Value *RowIdx, Value *ColumnIdx, unsigned NumRows,
245
                     Twine const &Name = "") {
246
    unsigned MaxWidth = std::max(RowIdx->getType()->getScalarSizeInBits(),
247
                                 ColumnIdx->getType()->getScalarSizeInBits());
248
    Type *IntTy = IntegerType::get(RowIdx->getType()->getContext(), MaxWidth);
249
    RowIdx = B.CreateZExt(RowIdx, IntTy);
250
    ColumnIdx = B.CreateZExt(ColumnIdx, IntTy);
251
    Value *NumRowsV = B.getIntN(MaxWidth, NumRows);
252
    return B.CreateAdd(B.CreateMul(ColumnIdx, NumRowsV), RowIdx);
253
  }
254
};
255
 
256
} // end namespace llvm
257
 
258
#endif // LLVM_IR_MATRIXBUILDER_H