Details | Last modification | View Log | RSS feed
Rev | Author | Line No. | Line |
---|---|---|---|
14 | pmbaty | 1 | //===- TFUtils.h - utilities for tensorflow C API ---------------*- C++ -*-===// |
2 | // |
||
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
||
4 | // See https://llvm.org/LICENSE.txt for license information. |
||
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
||
6 | // |
||
7 | //===----------------------------------------------------------------------===// |
||
8 | // |
||
9 | #ifndef LLVM_ANALYSIS_UTILS_TFUTILS_H |
||
10 | #define LLVM_ANALYSIS_UTILS_TFUTILS_H |
||
11 | |||
12 | #include "llvm/Config/llvm-config.h" |
||
13 | |||
14 | #ifdef LLVM_HAVE_TFLITE |
||
#include "llvm/ADT/StringMap.h"
#include "llvm/Analysis/TensorSpec.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/JSON.h"

#include <memory>
#include <optional>
#include <vector>
||
22 | |||
23 | namespace llvm { |
||
24 | |||
25 | /// Load a SavedModel, find the given inputs and outputs, and setup storage |
||
26 | /// for input tensors. The user is responsible for correctly dimensioning the |
||
27 | /// input tensors and setting their values before calling evaluate(). |
||
28 | /// To initialize: |
||
29 | /// - construct the object |
||
30 | /// - initialize the input tensors using initInput. Indices must correspond to |
||
31 | /// indices in the InputNames used at construction. |
||
32 | /// To use: |
||
33 | /// - set input values by using getInput to get each input tensor, and then |
||
34 | /// setting internal scalars, for all dimensions (tensors are row-major: |
||
35 | /// https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/c/c_api.h#L205) |
||
36 | /// - call evaluate. The input tensors' values are not consumed after this, and |
||
37 | /// may still be read. |
||
38 | /// - use the outputs in the output vector |
||
39 | class TFModelEvaluatorImpl; |
||
40 | class EvaluationResultImpl; |
||
41 | |||
42 | class TFModelEvaluator final { |
||
43 | public: |
||
44 | /// The result of a model evaluation. Handles the lifetime of the output |
||
45 | /// tensors, which means that their values need to be used before |
||
46 | /// the EvaluationResult's dtor is called. |
||
47 | class EvaluationResult { |
||
48 | public: |
||
49 | EvaluationResult(const EvaluationResult &) = delete; |
||
50 | EvaluationResult &operator=(const EvaluationResult &Other) = delete; |
||
51 | |||
52 | EvaluationResult(EvaluationResult &&Other); |
||
53 | EvaluationResult &operator=(EvaluationResult &&Other); |
||
54 | |||
55 | ~EvaluationResult(); |
||
56 | |||
57 | /// Get a (const) pointer to the first element of the tensor at Index. |
||
58 | template <typename T> T *getTensorValue(size_t Index) { |
||
59 | return static_cast<T *>(getUntypedTensorValue(Index)); |
||
60 | } |
||
61 | |||
62 | template <typename T> const T *getTensorValue(size_t Index) const { |
||
63 | return static_cast<T *>(getUntypedTensorValue(Index)); |
||
64 | } |
||
65 | |||
66 | /// Get a (const) pointer to the untyped data of the tensor. |
||
67 | void *getUntypedTensorValue(size_t Index); |
||
68 | const void *getUntypedTensorValue(size_t Index) const; |
||
69 | |||
70 | private: |
||
71 | friend class TFModelEvaluator; |
||
72 | EvaluationResult(std::unique_ptr<EvaluationResultImpl> Impl); |
||
73 | std::unique_ptr<EvaluationResultImpl> Impl; |
||
74 | }; |
||
75 | |||
76 | TFModelEvaluator(StringRef SavedModelPath, |
||
77 | const std::vector<TensorSpec> &InputSpecs, |
||
78 | const std::vector<TensorSpec> &OutputSpecs, |
||
79 | const char *Tags = "serve"); |
||
80 | |||
81 | ~TFModelEvaluator(); |
||
82 | TFModelEvaluator(const TFModelEvaluator &) = delete; |
||
83 | TFModelEvaluator(TFModelEvaluator &&) = delete; |
||
84 | |||
85 | /// Evaluate the model, assuming it is valid. Returns std::nullopt if the |
||
86 | /// evaluation fails or the model is invalid, or an EvaluationResult |
||
87 | /// otherwise. The inputs are assumed to have been already provided via |
||
88 | /// getInput(). When returning std::nullopt, it also invalidates this object. |
||
89 | std::optional<EvaluationResult> evaluate(); |
||
90 | |||
91 | /// Provides access to the input vector. |
||
92 | template <typename T> T *getInput(size_t Index) { |
||
93 | return static_cast<T *>(getUntypedInput(Index)); |
||
94 | } |
||
95 | |||
96 | /// Returns true if the tensorflow model was loaded successfully, false |
||
97 | /// otherwise. |
||
98 | bool isValid() const { return !!Impl; } |
||
99 | |||
100 | /// Untyped access to input. |
||
101 | void *getUntypedInput(size_t Index); |
||
102 | |||
103 | private: |
||
104 | std::unique_ptr<TFModelEvaluatorImpl> Impl; |
||
105 | }; |
||
106 | |||
107 | } // namespace llvm |
||
108 | |||
109 | #endif // LLVM_HAVE_TFLITE |
||
110 | #endif // LLVM_ANALYSIS_UTILS_TFUTILS_H |