//===- MLModelRunner.h ---- ML model runner interface -----------*- C++ -*-===//
 
//
 
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 
// See https://llvm.org/LICENSE.txt for license information.
 
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
//
 
//===----------------------------------------------------------------------===//
 
//
 
 
 
#ifndef LLVM_ANALYSIS_MLMODELRUNNER_H
 
#define LLVM_ANALYSIS_MLMODELRUNNER_H
 
 
 
#include "llvm/Analysis/TensorSpec.h"
 
#include "llvm/IR/PassManager.h"
 
 
 
namespace llvm {
 
class LLVMContext;
 
 
 
/// MLModelRunner interface: abstraction of a mechanism for evaluating a
 
/// tensorflow "saved model".
 
/// NOTE: feature indices are expected to be consistent all accross
 
/// MLModelRunners (pertaining to the same model), and also Loggers (see
 
/// TFUtils.h)
 
class MLModelRunner {
 
public:
 
  // Disallows copy and assign.
 
  MLModelRunner(const MLModelRunner &) = delete;
 
  MLModelRunner &operator=(const MLModelRunner &) = delete;
 
  virtual ~MLModelRunner() = default;
 
 
 
  template <typename T> T evaluate() {
 
    return *reinterpret_cast<T *>(evaluateUntyped());
 
  }
 
 
 
  template <typename T, typename I> T *getTensor(I FeatureID) {
 
    return reinterpret_cast<T *>(
 
        getTensorUntyped(static_cast<size_t>(FeatureID)));
 
  }
 
 
 
  template <typename T, typename I> const T *getTensor(I FeatureID) const {
 
    return reinterpret_cast<const T *>(
 
        getTensorUntyped(static_cast<size_t>(FeatureID)));
 
  }
 
 
 
  void *getTensorUntyped(size_t Index) { return InputBuffers[Index]; }
 
  const void *getTensorUntyped(size_t Index) const {
 
    return (const_cast<MLModelRunner *>(this))->getTensorUntyped(Index);
 
  }
 
 
 
  enum class Kind : int { Unknown, Release, Development, NoOp };
 
  Kind getKind() const { return Type; }
 
 
 
protected:
 
  MLModelRunner(LLVMContext &Ctx, Kind Type, size_t NrInputs)
 
      : Ctx(Ctx), Type(Type), InputBuffers(NrInputs) {
 
    assert(Type != Kind::Unknown);
 
  }
 
  virtual void *evaluateUntyped() = 0;
 
 
 
  void setUpBufferForTensor(size_t Index, const TensorSpec &Spec,
 
                            void *Buffer) {
 
    if (!Buffer) {
 
      OwnedBuffers.emplace_back(Spec.getTotalTensorBufferSize());
 
      Buffer = OwnedBuffers.back().data();
 
    }
 
    InputBuffers[Index] = Buffer;
 
  }
 
 
 
  LLVMContext &Ctx;
 
  const Kind Type;
 
 
 
private:
 
  std::vector<void *> InputBuffers;
 
  std::vector<std::vector<char *>> OwnedBuffers;
 
};
 
} // namespace llvm
 
 
 
#endif // LLVM_ANALYSIS_MLMODELRUNNER_H