Subversion Repositories QNX 8.QNX8 LLVM/Clang compiler suite

Rev

Details | Last modification | View Log | RSS feed

Rev Author Line No. Line
14 pmbaty 1
//===- ModelUnderTrainingRunner.h -- 'development' mode runner --*- C++ -*-===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
 
10
#ifndef LLVM_ANALYSIS_MODELUNDERTRAININGRUNNER_H
11
#define LLVM_ANALYSIS_MODELUNDERTRAININGRUNNER_H
12
 
13
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Analysis/TensorSpec.h"
#include "llvm/Config/llvm-config.h"
#include <cassert>
17
 
18
#ifdef LLVM_HAVE_TFLITE
19
#include "llvm/Analysis/MLModelRunner.h"
20
#include "llvm/Analysis/Utils/TFUtils.h"
21
#include "llvm/IR/LLVMContext.h"
22
#include "llvm/IR/PassManager.h"
23
 
24
namespace llvm {
25
 
26
/// ModelUnderTrainingRunner - training mode implementation. It uses TF C APIs
27
/// to dynamically load and evaluate a TF SavedModel
28
/// (https://www.tensorflow.org/guide/saved_model). Runtime performance is
29
/// sacrificed for ease of use while training.
30
class ModelUnderTrainingRunner final : public MLModelRunner {
31
public:
32
  // Disallows copy and assign.
33
  ModelUnderTrainingRunner(const ModelUnderTrainingRunner &) = delete;
34
  ModelUnderTrainingRunner &
35
  operator=(const ModelUnderTrainingRunner &) = delete;
36
 
37
  const std::vector<TensorSpec> &extraOutputsForLoggingSpecs() const {
38
    return ExtraOutputsForLogging;
39
  }
40
 
41
  const void *getUntypedExtraOutputValue(size_t ExtraOutputIndex) const {
42
    return lastEvaluationResult()->getUntypedTensorValue(ExtraOutputIndex + 1);
43
  }
44
 
45
  const std::optional<TFModelEvaluator::EvaluationResult> &
46
  lastEvaluationResult() const {
47
    return LastEvaluationResult;
48
  }
49
  static bool classof(const MLModelRunner *R) {
50
    return R->getKind() == MLModelRunner::Kind::Development;
51
  }
52
 
53
  static std::unique_ptr<ModelUnderTrainingRunner>
54
  createAndEnsureValid(LLVMContext &Ctx, const std::string &ModelPath,
55
                       StringRef DecisionName,
56
                       const std::vector<TensorSpec> &InputSpecs,
57
                       StringRef OutputSpecsPathOverride = "");
58
 
59
  ModelUnderTrainingRunner(
60
      LLVMContext &Ctx, const std::string &ModelPath,
61
      const std::vector<TensorSpec> &InputSpecs,
62
      const std::vector<TensorSpec> &OutputSpecs,
63
      const std::vector<TensorSpec> &ExtraOutputsForLogging = {});
64
 
65
  bool isValid() const { return !!Evaluator; }
66
 
67
private:
68
  std::unique_ptr<TFModelEvaluator> Evaluator;
69
  const std::vector<TensorSpec> OutputSpecs;
70
  const std::vector<TensorSpec> ExtraOutputsForLogging;
71
  std::optional<TFModelEvaluator::EvaluationResult> LastEvaluationResult;
72
  void *evaluateUntyped() override;
73
};
74
 
75
} // namespace llvm
76
#endif // define(LLVM_HAVE_TFLITE)
77
#endif // LLVM_ANALYSIS_MODELUNDERTRAININGRUNNER_H