Subversion Repositories QNX 8.QNX8 LLVM/Clang compiler suite

Rev

Blame | Last modification | View Log | Download | RSS feed

  1. //===- ModelUnderTrainingRunner.h -- 'development' mode runner --*- C++ -*-===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9.  
  10. #ifndef LLVM_ANALYSIS_MODELUNDERTRAININGRUNNER_H
  11. #define LLVM_ANALYSIS_MODELUNDERTRAININGRUNNER_H
  12.  
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Analysis/TensorSpec.h"
#include "llvm/Config/llvm-config.h"
#include <cassert>
#include <cstddef>
#include <memory>
#include <optional>
#include <string>
#include <vector>
  17.  
  18. #ifdef LLVM_HAVE_TFLITE
  19. #include "llvm/Analysis/MLModelRunner.h"
  20. #include "llvm/Analysis/Utils/TFUtils.h"
  21. #include "llvm/IR/LLVMContext.h"
  22. #include "llvm/IR/PassManager.h"
  23.  
  24. namespace llvm {
  25.  
  26. /// ModelUnderTrainingRunner - training mode implementation. It uses TF C APIs
  27. /// to dynamically load and evaluate a TF SavedModel
  28. /// (https://www.tensorflow.org/guide/saved_model). Runtime performance is
  29. /// sacrificed for ease of use while training.
  30. class ModelUnderTrainingRunner final : public MLModelRunner {
  31. public:
  32.   // Disallows copy and assign.
  33.   ModelUnderTrainingRunner(const ModelUnderTrainingRunner &) = delete;
  34.   ModelUnderTrainingRunner &
  35.   operator=(const ModelUnderTrainingRunner &) = delete;
  36.  
  37.   const std::vector<TensorSpec> &extraOutputsForLoggingSpecs() const {
  38.     return ExtraOutputsForLogging;
  39.   }
  40.  
  41.   const void *getUntypedExtraOutputValue(size_t ExtraOutputIndex) const {
  42.     return lastEvaluationResult()->getUntypedTensorValue(ExtraOutputIndex + 1);
  43.   }
  44.  
  45.   const std::optional<TFModelEvaluator::EvaluationResult> &
  46.   lastEvaluationResult() const {
  47.     return LastEvaluationResult;
  48.   }
  49.   static bool classof(const MLModelRunner *R) {
  50.     return R->getKind() == MLModelRunner::Kind::Development;
  51.   }
  52.  
  53.   static std::unique_ptr<ModelUnderTrainingRunner>
  54.   createAndEnsureValid(LLVMContext &Ctx, const std::string &ModelPath,
  55.                        StringRef DecisionName,
  56.                        const std::vector<TensorSpec> &InputSpecs,
  57.                        StringRef OutputSpecsPathOverride = "");
  58.  
  59.   ModelUnderTrainingRunner(
  60.       LLVMContext &Ctx, const std::string &ModelPath,
  61.       const std::vector<TensorSpec> &InputSpecs,
  62.       const std::vector<TensorSpec> &OutputSpecs,
  63.       const std::vector<TensorSpec> &ExtraOutputsForLogging = {});
  64.  
  65.   bool isValid() const { return !!Evaluator; }
  66.  
  67. private:
  68.   std::unique_ptr<TFModelEvaluator> Evaluator;
  69.   const std::vector<TensorSpec> OutputSpecs;
  70.   const std::vector<TensorSpec> ExtraOutputsForLogging;
  71.   std::optional<TFModelEvaluator::EvaluationResult> LastEvaluationResult;
  72.   void *evaluateUntyped() override;
  73. };
  74.  
  75. } // namespace llvm
  76. #endif // defined(LLVM_HAVE_TFLITE)
  77. #endif // LLVM_ANALYSIS_MODELUNDERTRAININGRUNNER_H
  78.