//===- llvm/IR/MatrixBuilder.h - Builder to lower matrix ops ----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the MatrixBuilder class, which is used as a convenient way
// to lower matrix operations to LLVM IR.
//
//===----------------------------------------------------------------------===//
-   
- #ifndef LLVM_IR_MATRIXBUILDER_H 
- #define LLVM_IR_MATRIXBUILDER_H 
-   
- #include "llvm/IR/Constant.h" 
- #include "llvm/IR/Constants.h" 
- #include "llvm/IR/IRBuilder.h" 
- #include "llvm/IR/InstrTypes.h" 
- #include "llvm/IR/Instruction.h" 
- #include "llvm/IR/IntrinsicInst.h" 
- #include "llvm/IR/Type.h" 
- #include "llvm/IR/Value.h" 
- #include "llvm/Support/Alignment.h" 
-   
- namespace llvm { 
-   
- class Function; 
- class Twine; 
- class Module; 
-   
- class MatrixBuilder { 
-   IRBuilderBase &B; 
-   Module *getModule() { return B.GetInsertBlock()->getParent()->getParent(); } 
-   
-   std::pair<Value *, Value *> splatScalarOperandIfNeeded(Value *LHS, 
-                                                          Value *RHS) { 
-     assert((LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()) && 
-            "One of the operands must be a matrix (embedded in a vector)"); 
-     if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) { 
-       assert(!isa<ScalableVectorType>(LHS->getType()) && 
-              "LHS Assumed to be fixed width"); 
-       RHS = B.CreateVectorSplat( 
-           cast<VectorType>(LHS->getType())->getElementCount(), RHS, 
-           "scalar.splat"); 
-     } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) { 
-       assert(!isa<ScalableVectorType>(RHS->getType()) && 
-              "RHS Assumed to be fixed width"); 
-       LHS = B.CreateVectorSplat( 
-           cast<VectorType>(RHS->getType())->getElementCount(), LHS, 
-           "scalar.splat"); 
-     } 
-     return {LHS, RHS}; 
-   } 
-   
- public: 
-   MatrixBuilder(IRBuilderBase &Builder) : B(Builder) {} 
-   
-   /// Create a column major, strided matrix load. 
-   /// \p EltTy   - Matrix element type 
-   /// \p DataPtr - Start address of the matrix read 
-   /// \p Rows    - Number of rows in matrix (must be a constant) 
-   /// \p Columns - Number of columns in matrix (must be a constant) 
-   /// \p Stride  - Space between columns 
-   CallInst *CreateColumnMajorLoad(Type *EltTy, Value *DataPtr, Align Alignment, 
-                                   Value *Stride, bool IsVolatile, unsigned Rows, 
-                                   unsigned Columns, const Twine &Name = "") { 
-     auto *RetType = FixedVectorType::get(EltTy, Rows * Columns); 
-   
-     Value *Ops[] = {DataPtr, Stride, B.getInt1(IsVolatile), B.getInt32(Rows), 
-                     B.getInt32(Columns)}; 
-     Type *OverloadedTypes[] = {RetType, Stride->getType()}; 
-   
-     Function *TheFn = Intrinsic::getDeclaration( 
-         getModule(), Intrinsic::matrix_column_major_load, OverloadedTypes); 
-   
-     CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name); 
-     Attribute AlignAttr = 
-         Attribute::getWithAlignment(Call->getContext(), Alignment); 
-     Call->addParamAttr(0, AlignAttr); 
-     return Call; 
-   } 
-   
-   /// Create a column major, strided matrix store. 
-   /// \p Matrix  - Matrix to store 
-   /// \p Ptr     - Pointer to write back to 
-   /// \p Stride  - Space between columns 
-   CallInst *CreateColumnMajorStore(Value *Matrix, Value *Ptr, Align Alignment, 
-                                    Value *Stride, bool IsVolatile, 
-                                    unsigned Rows, unsigned Columns, 
-                                    const Twine &Name = "") { 
-     Value *Ops[] = {Matrix,           Ptr, 
-                     Stride,           B.getInt1(IsVolatile), 
-                     B.getInt32(Rows), B.getInt32(Columns)}; 
-     Type *OverloadedTypes[] = {Matrix->getType(), Stride->getType()}; 
-   
-     Function *TheFn = Intrinsic::getDeclaration( 
-         getModule(), Intrinsic::matrix_column_major_store, OverloadedTypes); 
-   
-     CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name); 
-     Attribute AlignAttr = 
-         Attribute::getWithAlignment(Call->getContext(), Alignment); 
-     Call->addParamAttr(1, AlignAttr); 
-     return Call; 
-   } 
-   
-   /// Create a llvm.matrix.transpose call, transposing \p Matrix with \p Rows 
-   /// rows and \p Columns columns. 
-   CallInst *CreateMatrixTranspose(Value *Matrix, unsigned Rows, 
-                                   unsigned Columns, const Twine &Name = "") { 
-     auto *OpType = cast<VectorType>(Matrix->getType()); 
-     auto *ReturnType = 
-         FixedVectorType::get(OpType->getElementType(), Rows * Columns); 
-   
-     Type *OverloadedTypes[] = {ReturnType}; 
-     Value *Ops[] = {Matrix, B.getInt32(Rows), B.getInt32(Columns)}; 
-     Function *TheFn = Intrinsic::getDeclaration( 
-         getModule(), Intrinsic::matrix_transpose, OverloadedTypes); 
-   
-     return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name); 
-   } 
-   
-   /// Create a llvm.matrix.multiply call, multiplying matrixes \p LHS and \p 
-   /// RHS. 
-   CallInst *CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows, 
-                                  unsigned LHSColumns, unsigned RHSColumns, 
-                                  const Twine &Name = "") { 
-     auto *LHSType = cast<VectorType>(LHS->getType()); 
-     auto *RHSType = cast<VectorType>(RHS->getType()); 
-   
-     auto *ReturnType = 
-         FixedVectorType::get(LHSType->getElementType(), LHSRows * RHSColumns); 
-   
-     Value *Ops[] = {LHS, RHS, B.getInt32(LHSRows), B.getInt32(LHSColumns), 
-                     B.getInt32(RHSColumns)}; 
-     Type *OverloadedTypes[] = {ReturnType, LHSType, RHSType}; 
-   
-     Function *TheFn = Intrinsic::getDeclaration( 
-         getModule(), Intrinsic::matrix_multiply, OverloadedTypes); 
-     return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name); 
-   } 
-   
-   /// Insert a single element \p NewVal into \p Matrix at indices (\p RowIdx, \p 
-   /// ColumnIdx). 
-   Value *CreateMatrixInsert(Value *Matrix, Value *NewVal, Value *RowIdx, 
-                             Value *ColumnIdx, unsigned NumRows) { 
-     return B.CreateInsertElement( 
-         Matrix, NewVal, 
-         B.CreateAdd(B.CreateMul(ColumnIdx, ConstantInt::get( 
-                                                ColumnIdx->getType(), NumRows)), 
-                     RowIdx)); 
-   } 
-   
-   /// Add matrixes \p LHS and \p RHS. Support both integer and floating point 
-   /// matrixes. 
-   Value *CreateAdd(Value *LHS, Value *RHS) { 
-     assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()); 
-     if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) { 
-       assert(!isa<ScalableVectorType>(LHS->getType()) && 
-              "LHS Assumed to be fixed width"); 
-       RHS = B.CreateVectorSplat( 
-           cast<VectorType>(LHS->getType())->getElementCount(), RHS, 
-           "scalar.splat"); 
-     } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) { 
-       assert(!isa<ScalableVectorType>(RHS->getType()) && 
-              "RHS Assumed to be fixed width"); 
-       LHS = B.CreateVectorSplat( 
-           cast<VectorType>(RHS->getType())->getElementCount(), LHS, 
-           "scalar.splat"); 
-     } 
-   
-     return cast<VectorType>(LHS->getType()) 
-                    ->getElementType() 
-                    ->isFloatingPointTy() 
-                ? B.CreateFAdd(LHS, RHS) 
-                : B.CreateAdd(LHS, RHS); 
-   } 
-   
-   /// Subtract matrixes \p LHS and \p RHS. Support both integer and floating 
-   /// point matrixes. 
-   Value *CreateSub(Value *LHS, Value *RHS) { 
-     assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()); 
-     if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) { 
-       assert(!isa<ScalableVectorType>(LHS->getType()) && 
-              "LHS Assumed to be fixed width"); 
-       RHS = B.CreateVectorSplat( 
-           cast<VectorType>(LHS->getType())->getElementCount(), RHS, 
-           "scalar.splat"); 
-     } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) { 
-       assert(!isa<ScalableVectorType>(RHS->getType()) && 
-              "RHS Assumed to be fixed width"); 
-       LHS = B.CreateVectorSplat( 
-           cast<VectorType>(RHS->getType())->getElementCount(), LHS, 
-           "scalar.splat"); 
-     } 
-   
-     return cast<VectorType>(LHS->getType()) 
-                    ->getElementType() 
-                    ->isFloatingPointTy() 
-                ? B.CreateFSub(LHS, RHS) 
-                : B.CreateSub(LHS, RHS); 
-   } 
-   
-   /// Multiply matrix \p LHS with scalar \p RHS or scalar \p LHS with matrix \p 
-   /// RHS. 
-   Value *CreateScalarMultiply(Value *LHS, Value *RHS) { 
-     std::tie(LHS, RHS) = splatScalarOperandIfNeeded(LHS, RHS); 
-     if (LHS->getType()->getScalarType()->isFloatingPointTy()) 
-       return B.CreateFMul(LHS, RHS); 
-     return B.CreateMul(LHS, RHS); 
-   } 
-   
-   /// Divide matrix \p LHS by scalar \p RHS. If the operands are integers, \p 
-   /// IsUnsigned indicates whether UDiv or SDiv should be used. 
-   Value *CreateScalarDiv(Value *LHS, Value *RHS, bool IsUnsigned) { 
-     assert(LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()); 
-     assert(!isa<ScalableVectorType>(LHS->getType()) && 
-            "LHS Assumed to be fixed width"); 
-     RHS = 
-         B.CreateVectorSplat(cast<VectorType>(LHS->getType())->getElementCount(), 
-                             RHS, "scalar.splat"); 
-     return cast<VectorType>(LHS->getType()) 
-                    ->getElementType() 
-                    ->isFloatingPointTy() 
-                ? B.CreateFDiv(LHS, RHS) 
-                : (IsUnsigned ? B.CreateUDiv(LHS, RHS) : B.CreateSDiv(LHS, RHS)); 
-   } 
-   
-   /// Create an assumption that \p Idx is less than \p NumElements. 
-   void CreateIndexAssumption(Value *Idx, unsigned NumElements, 
-                              Twine const &Name = "") { 
-     Value *NumElts = 
-         B.getIntN(Idx->getType()->getScalarSizeInBits(), NumElements); 
-     auto *Cmp = B.CreateICmpULT(Idx, NumElts); 
-     if (isa<ConstantInt>(Cmp)) 
-       assert(cast<ConstantInt>(Cmp)->isOne() && "Index must be valid!"); 
-     else 
-       B.CreateAssumption(Cmp); 
-   } 
-   
-   /// Compute the index to access the element at (\p RowIdx, \p ColumnIdx) from 
-   /// a matrix with \p NumRows embedded in a vector. 
-   Value *CreateIndex(Value *RowIdx, Value *ColumnIdx, unsigned NumRows, 
-                      Twine const &Name = "") { 
-     unsigned MaxWidth = std::max(RowIdx->getType()->getScalarSizeInBits(), 
-                                  ColumnIdx->getType()->getScalarSizeInBits()); 
-     Type *IntTy = IntegerType::get(RowIdx->getType()->getContext(), MaxWidth); 
-     RowIdx = B.CreateZExt(RowIdx, IntTy); 
-     ColumnIdx = B.CreateZExt(ColumnIdx, IntTy); 
-     Value *NumRowsV = B.getIntN(MaxWidth, NumRows); 
-     return B.CreateAdd(B.CreateMul(ColumnIdx, NumRowsV), RowIdx); 
-   } 
- }; 
-   
- } // end namespace llvm 
-   
- #endif // LLVM_IR_MATRIXBUILDER_H 
-