- //===- MLModelRunner.h ---- ML model runner interface -----------*- C++ -*-===// 
- // 
- // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
- // See https://llvm.org/LICENSE.txt for license information. 
- // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 
- // 
- //===----------------------------------------------------------------------===// 
- // 
-   
- #ifndef LLVM_ANALYSIS_MLMODELRUNNER_H 
- #define LLVM_ANALYSIS_MLMODELRUNNER_H 
-   
#include "llvm/Analysis/TensorSpec.h"
#include "llvm/IR/PassManager.h"
#include <cassert>
#include <cstddef>
#include <vector>
-   
- namespace llvm { 
- class LLVMContext; 
-   
- /// MLModelRunner interface: abstraction of a mechanism for evaluating a 
- /// tensorflow "saved model". 
- /// NOTE: feature indices are expected to be consistent all accross 
- /// MLModelRunners (pertaining to the same model), and also Loggers (see 
- /// TFUtils.h) 
- class MLModelRunner { 
- public: 
-   // Disallows copy and assign. 
-   MLModelRunner(const MLModelRunner &) = delete; 
-   MLModelRunner &operator=(const MLModelRunner &) = delete; 
-   virtual ~MLModelRunner() = default; 
-   
-   template <typename T> T evaluate() { 
-     return *reinterpret_cast<T *>(evaluateUntyped()); 
-   } 
-   
-   template <typename T, typename I> T *getTensor(I FeatureID) { 
-     return reinterpret_cast<T *>( 
-         getTensorUntyped(static_cast<size_t>(FeatureID))); 
-   } 
-   
-   template <typename T, typename I> const T *getTensor(I FeatureID) const { 
-     return reinterpret_cast<const T *>( 
-         getTensorUntyped(static_cast<size_t>(FeatureID))); 
-   } 
-   
-   void *getTensorUntyped(size_t Index) { return InputBuffers[Index]; } 
-   const void *getTensorUntyped(size_t Index) const { 
-     return (const_cast<MLModelRunner *>(this))->getTensorUntyped(Index); 
-   } 
-   
-   enum class Kind : int { Unknown, Release, Development, NoOp }; 
-   Kind getKind() const { return Type; } 
-   
- protected: 
-   MLModelRunner(LLVMContext &Ctx, Kind Type, size_t NrInputs) 
-       : Ctx(Ctx), Type(Type), InputBuffers(NrInputs) { 
-     assert(Type != Kind::Unknown); 
-   } 
-   virtual void *evaluateUntyped() = 0; 
-   
-   void setUpBufferForTensor(size_t Index, const TensorSpec &Spec, 
-                             void *Buffer) { 
-     if (!Buffer) { 
-       OwnedBuffers.emplace_back(Spec.getTotalTensorBufferSize()); 
-       Buffer = OwnedBuffers.back().data(); 
-     } 
-     InputBuffers[Index] = Buffer; 
-   } 
-   
-   LLVMContext &Ctx; 
-   const Kind Type; 
-   
- private: 
-   std::vector<void *> InputBuffers; 
-   std::vector<std::vector<char *>> OwnedBuffers; 
- }; 
- } // namespace llvm 
-   
- #endif // LLVM_ANALYSIS_MLMODELRUNNER_H 
-