// Retrieved from the "QNX 8.QNX8 LLVM/Clang compiler suite" Subversion
// repository web viewer (rev 14, author pmbaty).
//===- TFUtils.h - utilities for tensorflow C API ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
9
#ifndef LLVM_ANALYSIS_UTILS_TFUTILS_H
10
#define LLVM_ANALYSIS_UTILS_TFUTILS_H
11
 
12
#include "llvm/Config/llvm-config.h"
13
 
14
#ifdef LLVM_HAVE_TFLITE
15
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/TensorSpec.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/JSON.h"

#include <memory>
#include <optional>
#include <vector>
22
 
23
namespace llvm {
24
 
25
/// Load a SavedModel, find the given inputs and outputs, and setup storage
26
/// for input tensors. The user is responsible for correctly dimensioning the
27
/// input tensors and setting their values before calling evaluate().
28
/// To initialize:
29
/// - construct the object
30
/// - initialize the input tensors using initInput. Indices must correspond to
31
///   indices in the InputNames used at construction.
32
/// To use:
33
/// - set input values by using getInput to get each input tensor, and then
34
///   setting internal scalars, for all dimensions (tensors are row-major:
35
///   https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/c/c_api.h#L205)
36
/// - call evaluate. The input tensors' values are not consumed after this, and
37
///   may still be read.
38
/// - use the outputs in the output vector
39
class TFModelEvaluatorImpl;
40
class EvaluationResultImpl;
41
 
42
class TFModelEvaluator final {
43
public:
44
  /// The result of a model evaluation. Handles the lifetime of the output
45
  /// tensors, which means that their values need to be used before
46
  /// the EvaluationResult's dtor is called.
47
  class EvaluationResult {
48
  public:
49
    EvaluationResult(const EvaluationResult &) = delete;
50
    EvaluationResult &operator=(const EvaluationResult &Other) = delete;
51
 
52
    EvaluationResult(EvaluationResult &&Other);
53
    EvaluationResult &operator=(EvaluationResult &&Other);
54
 
55
    ~EvaluationResult();
56
 
57
    /// Get a (const) pointer to the first element of the tensor at Index.
58
    template <typename T> T *getTensorValue(size_t Index) {
59
      return static_cast<T *>(getUntypedTensorValue(Index));
60
    }
61
 
62
    template <typename T> const T *getTensorValue(size_t Index) const {
63
      return static_cast<T *>(getUntypedTensorValue(Index));
64
    }
65
 
66
    /// Get a (const) pointer to the untyped data of the tensor.
67
    void *getUntypedTensorValue(size_t Index);
68
    const void *getUntypedTensorValue(size_t Index) const;
69
 
70
  private:
71
    friend class TFModelEvaluator;
72
    EvaluationResult(std::unique_ptr<EvaluationResultImpl> Impl);
73
    std::unique_ptr<EvaluationResultImpl> Impl;
74
  };
75
 
76
  TFModelEvaluator(StringRef SavedModelPath,
77
                   const std::vector<TensorSpec> &InputSpecs,
78
                   const std::vector<TensorSpec> &OutputSpecs,
79
                   const char *Tags = "serve");
80
 
81
  ~TFModelEvaluator();
82
  TFModelEvaluator(const TFModelEvaluator &) = delete;
83
  TFModelEvaluator(TFModelEvaluator &&) = delete;
84
 
85
  /// Evaluate the model, assuming it is valid. Returns std::nullopt if the
86
  /// evaluation fails or the model is invalid, or an EvaluationResult
87
  /// otherwise. The inputs are assumed to have been already provided via
88
  /// getInput(). When returning std::nullopt, it also invalidates this object.
89
  std::optional<EvaluationResult> evaluate();
90
 
91
  /// Provides access to the input vector.
92
  template <typename T> T *getInput(size_t Index) {
93
    return static_cast<T *>(getUntypedInput(Index));
94
  }
95
 
96
  /// Returns true if the tensorflow model was loaded successfully, false
97
  /// otherwise.
98
  bool isValid() const { return !!Impl; }
99
 
100
  /// Untyped access to input.
101
  void *getUntypedInput(size_t Index);
102
 
103
private:
104
  std::unique_ptr<TFModelEvaluatorImpl> Impl;
105
};
106
 
107
} // namespace llvm
108
 
109
#endif // LLVM_HAVE_TFLITE
110
#endif // LLVM_ANALYSIS_UTILS_TFUTILS_H