Subversion Repositories QNX 8.QNX8 LLVM/Clang compiler suite

Rev

Blame | Last modification | View Log | Download | RSS feed

  1. //===- ReleaseModeModelRunner.h - Fast, precompiled model runner  ---------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // This file implements a model runner wrapping an AOT compiled ML model.
  10. // Only inference is supported.
  11. //
  12. //===----------------------------------------------------------------------===//
  13.  
  14. #ifndef LLVM_ANALYSIS_RELEASEMODEMODELRUNNER_H
  15. #define LLVM_ANALYSIS_RELEASEMODEMODELRUNNER_H
  16.  
  17. #include "llvm/Analysis/MLModelRunner.h"
  18. #include "llvm/Analysis/TensorSpec.h"
  19. #include "llvm/Support/ErrorHandling.h"
  20.  
  21. #include <memory>
  22. #include <vector>
  23.  
  24. namespace llvm {
  25.  
  26. /// ReleaseModeModelRunner - production mode implementation of the
  27. /// MLModelRunner. It uses an AOT-compiled SavedModel for efficient execution.
  28. template <class TGen>
  29. class ReleaseModeModelRunner final : public MLModelRunner {
  30. public:
  31.   /// FeatureNames' type should be an indexed collection of std::string, like
  32.   /// std::array or std::vector, that has a size() method.
  33.   template <class FType>
  34.   ReleaseModeModelRunner(LLVMContext &Ctx, const FType &InputSpec,
  35.                          StringRef DecisionName, StringRef FeedPrefix = "feed_",
  36.                          StringRef FetchPrefix = "fetch_")
  37.       : MLModelRunner(Ctx, MLModelRunner::Kind::Release, InputSpec.size()),
  38.         CompiledModel(std::make_unique<TGen>()) {
  39.     assert(CompiledModel && "The CompiledModel should be valid");
  40.  
  41.     for (size_t I = 0; I < InputSpec.size(); ++I) {
  42.       const int Index =
  43.           CompiledModel->LookupArgIndex(FeedPrefix.str() + InputSpec[I].name());
  44.       void *Buffer = nullptr;
  45.       if (Index >= 0)
  46.         Buffer = CompiledModel->arg_data(Index);
  47.       setUpBufferForTensor(I, InputSpec[I], Buffer);
  48.     }
  49.  
  50.     ResultIndex = CompiledModel->LookupResultIndex(FetchPrefix.str() +
  51.                                                    DecisionName.str());
  52.     assert(ResultIndex >= 0 && "Cannot find DecisionName in inlining model");
  53.   }
  54.  
  55.   virtual ~ReleaseModeModelRunner() = default;
  56.  
  57.   static bool classof(const MLModelRunner *R) {
  58.     return R->getKind() == MLModelRunner::Kind::Release;
  59.   }
  60.  
  61. private:
  62.   void *evaluateUntyped() override {
  63.     CompiledModel->Run();
  64.     return CompiledModel->result_data(ResultIndex);
  65.   }
  66.  
  67.   int32_t ResultIndex = -1;
  68.   std::unique_ptr<TGen> CompiledModel;
  69. };
  70.  
  71. /// A mock class satisfying the interface expected by ReleaseModeModelRunner for
  72. /// its `TGen` parameter. Useful to avoid conditional compilation complexity, as
  73. /// a compile-time replacement for a real AOT-ed model.
  74. class NoopSavedModelImpl final {
  75. #define NOOP_MODEL_ERRMSG                                                      \
  76.   "The mock AOT-ed saved model is a compile-time stub and should not be "      \
  77.   "called."
  78.  
  79. public:
  80.   NoopSavedModelImpl() = default;
  81.   int LookupArgIndex(const std::string &) { llvm_unreachable(NOOP_MODEL_ERRMSG); }
  82.   int LookupResultIndex(const std::string &) { llvm_unreachable(NOOP_MODEL_ERRMSG); }
  83.   void Run() { llvm_unreachable(NOOP_MODEL_ERRMSG); }
  84.   void *result_data(int) { llvm_unreachable(NOOP_MODEL_ERRMSG); }
  85.   void *arg_data(int) { llvm_unreachable(NOOP_MODEL_ERRMSG); }
  86. #undef NOOP_MODEL_ERRMSG
  87. };
  88. } // namespace llvm
  89.  
  90. #endif // LLVM_ANALYSIS_RELEASEMODEMODELRUNNER_H
  91.