Details | Last modification | View Log | RSS feed
| Rev | Author | Line No. | Line |
|---|---|---|---|
| 14 | pmbaty | 1 | //===- TFUtils.h - utilities for tensorflow C API ---------------*- C++ -*-===// |
| 2 | // |
||
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
||
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
||
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
||
| 6 | // |
||
| 7 | //===----------------------------------------------------------------------===// |
||
| 8 | // |
||
| 9 | #ifndef LLVM_ANALYSIS_UTILS_TFUTILS_H |
||
| 10 | #define LLVM_ANALYSIS_UTILS_TFUTILS_H |
||
| 11 | |||
| 12 | #include "llvm/Config/llvm-config.h" |
||
| 13 | |||
| 14 | #ifdef LLVM_HAVE_TFLITE |
||
#include "llvm/ADT/StringMap.h"
#include "llvm/Analysis/TensorSpec.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/JSON.h"

#include <memory>
#include <optional>
#include <vector>
||
| 22 | |||
| 23 | namespace llvm { |
||
| 24 | |||
| 25 | /// Load a SavedModel, find the given inputs and outputs, and setup storage |
||
| 26 | /// for input tensors. The user is responsible for correctly dimensioning the |
||
| 27 | /// input tensors and setting their values before calling evaluate(). |
||
| 28 | /// To initialize: |
||
| 29 | /// - construct the object |
||
| 30 | /// - initialize the input tensors using initInput. Indices must correspond to |
||
| 31 | /// indices in the InputNames used at construction. |
||
| 32 | /// To use: |
||
| 33 | /// - set input values by using getInput to get each input tensor, and then |
||
| 34 | /// setting internal scalars, for all dimensions (tensors are row-major: |
||
| 35 | /// https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/c/c_api.h#L205) |
||
| 36 | /// - call evaluate. The input tensors' values are not consumed after this, and |
||
| 37 | /// may still be read. |
||
| 38 | /// - use the outputs in the output vector |
||
| 39 | class TFModelEvaluatorImpl; |
||
| 40 | class EvaluationResultImpl; |
||
| 41 | |||
| 42 | class TFModelEvaluator final { |
||
| 43 | public: |
||
| 44 | /// The result of a model evaluation. Handles the lifetime of the output |
||
| 45 | /// tensors, which means that their values need to be used before |
||
| 46 | /// the EvaluationResult's dtor is called. |
||
| 47 | class EvaluationResult { |
||
| 48 | public: |
||
| 49 | EvaluationResult(const EvaluationResult &) = delete; |
||
| 50 | EvaluationResult &operator=(const EvaluationResult &Other) = delete; |
||
| 51 | |||
| 52 | EvaluationResult(EvaluationResult &&Other); |
||
| 53 | EvaluationResult &operator=(EvaluationResult &&Other); |
||
| 54 | |||
| 55 | ~EvaluationResult(); |
||
| 56 | |||
| 57 | /// Get a (const) pointer to the first element of the tensor at Index. |
||
| 58 | template <typename T> T *getTensorValue(size_t Index) { |
||
| 59 | return static_cast<T *>(getUntypedTensorValue(Index)); |
||
| 60 | } |
||
| 61 | |||
| 62 | template <typename T> const T *getTensorValue(size_t Index) const { |
||
| 63 | return static_cast<T *>(getUntypedTensorValue(Index)); |
||
| 64 | } |
||
| 65 | |||
| 66 | /// Get a (const) pointer to the untyped data of the tensor. |
||
| 67 | void *getUntypedTensorValue(size_t Index); |
||
| 68 | const void *getUntypedTensorValue(size_t Index) const; |
||
| 69 | |||
| 70 | private: |
||
| 71 | friend class TFModelEvaluator; |
||
| 72 | EvaluationResult(std::unique_ptr<EvaluationResultImpl> Impl); |
||
| 73 | std::unique_ptr<EvaluationResultImpl> Impl; |
||
| 74 | }; |
||
| 75 | |||
| 76 | TFModelEvaluator(StringRef SavedModelPath, |
||
| 77 | const std::vector<TensorSpec> &InputSpecs, |
||
| 78 | const std::vector<TensorSpec> &OutputSpecs, |
||
| 79 | const char *Tags = "serve"); |
||
| 80 | |||
| 81 | ~TFModelEvaluator(); |
||
| 82 | TFModelEvaluator(const TFModelEvaluator &) = delete; |
||
| 83 | TFModelEvaluator(TFModelEvaluator &&) = delete; |
||
| 84 | |||
| 85 | /// Evaluate the model, assuming it is valid. Returns std::nullopt if the |
||
| 86 | /// evaluation fails or the model is invalid, or an EvaluationResult |
||
| 87 | /// otherwise. The inputs are assumed to have been already provided via |
||
| 88 | /// getInput(). When returning std::nullopt, it also invalidates this object. |
||
| 89 | std::optional<EvaluationResult> evaluate(); |
||
| 90 | |||
| 91 | /// Provides access to the input vector. |
||
| 92 | template <typename T> T *getInput(size_t Index) { |
||
| 93 | return static_cast<T *>(getUntypedInput(Index)); |
||
| 94 | } |
||
| 95 | |||
| 96 | /// Returns true if the tensorflow model was loaded successfully, false |
||
| 97 | /// otherwise. |
||
| 98 | bool isValid() const { return !!Impl; } |
||
| 99 | |||
| 100 | /// Untyped access to input. |
||
| 101 | void *getUntypedInput(size_t Index); |
||
| 102 | |||
| 103 | private: |
||
| 104 | std::unique_ptr<TFModelEvaluatorImpl> Impl; |
||
| 105 | }; |
||
| 106 | |||
| 107 | } // namespace llvm |
||
| 108 | |||
| 109 | #endif // LLVM_HAVE_TFLITE |
||
| 110 | #endif // LLVM_ANALYSIS_UTILS_TFUTILS_H |