Details | Last modification | View Log | RSS feed
| Rev | Author | Line No. | Line | 
|---|---|---|---|
| 14 | pmbaty | 1 | //===- TFUtils.h - utilities for tensorflow C API ---------------*- C++ -*-===// | 
| 2 | // | ||
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
| 4 | // See https://llvm.org/LICENSE.txt for license information. | ||
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| 6 | // | ||
| 7 | //===----------------------------------------------------------------------===// | ||
| 8 | // | ||
| 9 | #ifndef LLVM_ANALYSIS_UTILS_TFUTILS_H | ||
| 10 | #define LLVM_ANALYSIS_UTILS_TFUTILS_H | ||
| 11 | |||
| 12 | #include "llvm/Config/llvm-config.h" | ||
| 13 | |||
| 14 | #ifdef LLVM_HAVE_TFLITE | ||
#include "llvm/ADT/StringMap.h"
#include "llvm/Analysis/TensorSpec.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/JSON.h"

#include <memory>
#include <optional>
#include <vector>
| 22 | |||
| 23 | namespace llvm { | ||
| 24 | |||
| 25 | /// Load a SavedModel, find the given inputs and outputs, and setup storage | ||
| 26 | /// for input tensors. The user is responsible for correctly dimensioning the | ||
| 27 | /// input tensors and setting their values before calling evaluate(). | ||
| 28 | /// To initialize: | ||
| 29 | /// - construct the object | ||
| 30 | /// - initialize the input tensors using initInput. Indices must correspond to | ||
| 31 | ///   indices in the InputNames used at construction. | ||
| 32 | /// To use: | ||
| 33 | /// - set input values by using getInput to get each input tensor, and then | ||
| 34 | ///   setting internal scalars, for all dimensions (tensors are row-major: | ||
| 35 | ///   https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/c/c_api.h#L205) | ||
| 36 | /// - call evaluate. The input tensors' values are not consumed after this, and | ||
| 37 | ///   may still be read. | ||
| 38 | /// - use the outputs in the output vector | ||
| 39 | class TFModelEvaluatorImpl; | ||
| 40 | class EvaluationResultImpl; | ||
| 41 | |||
| 42 | class TFModelEvaluator final { | ||
| 43 | public: | ||
| 44 |   /// The result of a model evaluation. Handles the lifetime of the output | ||
| 45 |   /// tensors, which means that their values need to be used before | ||
| 46 |   /// the EvaluationResult's dtor is called. | ||
| 47 | class EvaluationResult { | ||
| 48 | public: | ||
| 49 | EvaluationResult(const EvaluationResult &) = delete; | ||
| 50 | EvaluationResult &operator=(const EvaluationResult &Other) = delete; | ||
| 51 | |||
| 52 | EvaluationResult(EvaluationResult &&Other); | ||
| 53 | EvaluationResult &operator=(EvaluationResult &&Other); | ||
| 54 | |||
| 55 | ~EvaluationResult(); | ||
| 56 | |||
| 57 |     /// Get a (const) pointer to the first element of the tensor at Index. | ||
| 58 | template <typename T> T *getTensorValue(size_t Index) { | ||
| 59 | return static_cast<T *>(getUntypedTensorValue(Index)); | ||
| 60 |     } | ||
| 61 | |||
| 62 | template <typename T> const T *getTensorValue(size_t Index) const { | ||
| 63 | return static_cast<T *>(getUntypedTensorValue(Index)); | ||
| 64 |     } | ||
| 65 | |||
| 66 |     /// Get a (const) pointer to the untyped data of the tensor. | ||
| 67 | void *getUntypedTensorValue(size_t Index); | ||
| 68 | const void *getUntypedTensorValue(size_t Index) const; | ||
| 69 | |||
| 70 | private: | ||
| 71 | friend class TFModelEvaluator; | ||
| 72 | EvaluationResult(std::unique_ptr<EvaluationResultImpl> Impl); | ||
| 73 | std::unique_ptr<EvaluationResultImpl> Impl; | ||
| 74 | }; | ||
| 75 | |||
| 76 |   TFModelEvaluator(StringRef SavedModelPath, | ||
| 77 | const std::vector<TensorSpec> &InputSpecs, | ||
| 78 | const std::vector<TensorSpec> &OutputSpecs, | ||
| 79 | const char *Tags = "serve"); | ||
| 80 | |||
| 81 | ~TFModelEvaluator(); | ||
| 82 | TFModelEvaluator(const TFModelEvaluator &) = delete; | ||
| 83 | TFModelEvaluator(TFModelEvaluator &&) = delete; | ||
| 84 | |||
| 85 |   /// Evaluate the model, assuming it is valid. Returns std::nullopt if the | ||
| 86 |   /// evaluation fails or the model is invalid, or an EvaluationResult | ||
| 87 |   /// otherwise. The inputs are assumed to have been already provided via | ||
| 88 |   /// getInput(). When returning std::nullopt, it also invalidates this object. | ||
| 89 | std::optional<EvaluationResult> evaluate(); | ||
| 90 | |||
| 91 |   /// Provides access to the input vector. | ||
| 92 | template <typename T> T *getInput(size_t Index) { | ||
| 93 | return static_cast<T *>(getUntypedInput(Index)); | ||
| 94 |   } | ||
| 95 | |||
| 96 |   /// Returns true if the tensorflow model was loaded successfully, false | ||
| 97 |   /// otherwise. | ||
| 98 | bool isValid() const { return !!Impl; } | ||
| 99 | |||
| 100 |   /// Untyped access to input. | ||
| 101 | void *getUntypedInput(size_t Index); | ||
| 102 | |||
| 103 | private: | ||
| 104 | std::unique_ptr<TFModelEvaluatorImpl> Impl; | ||
| 105 | }; | ||
| 106 | |||
| 107 | } // namespace llvm | ||
| 108 | |||
| 109 | #endif // LLVM_HAVE_TFLITE | ||
| 110 | #endif // LLVM_ANALYSIS_UTILS_TFUTILS_H |