Details | Last modification | View Log | RSS feed
| Rev | Author | Line No. | Line |
|---|---|---|---|
| 14 | pmbaty | 1 | //===- TensorSpec.h - type descriptor for a tensor --------------*- C++ -*-===// |
| 2 | // |
||
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
||
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
||
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
||
| 6 | // |
||
| 7 | //===----------------------------------------------------------------------===// |
||
| 8 | // |
||
| 9 | #ifndef LLVM_ANALYSIS_TENSORSPEC_H |
||
| 10 | #define LLVM_ANALYSIS_TENSORSPEC_H |
||
| 11 | |||
| 12 | #include "llvm/Config/llvm-config.h" |
||
| 13 | |||
| 14 | #include "llvm/ADT/StringMap.h" |
||
| 15 | #include "llvm/IR/LLVMContext.h" |
||
| 16 | #include "llvm/Support/JSON.h" |
||
| 17 | |||
| 18 | #include <memory> |
||
| 19 | #include <optional> |
||
| 20 | #include <vector> |
||
| 21 | |||
| 22 | namespace llvm { |
||
/// TensorSpec encapsulates the specification of a tensor: its dimensions, or
/// "shape" (row-major), its type (see TensorSpec::getDataType specializations
/// for supported types), its name and port (see "TensorFlow: Large-Scale
/// Machine Learning on Heterogeneous Distributed Systems", section 4.2, para 2:
/// https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/45166.pdf)
///
/// Known tensor types, written as an X-macro: SUPPORTED_TENSOR_TYPES(M) expands
/// to one M(CType, Name) invocation per supported type. The left part is the C
/// type, the right is a name we can use to identify the type (to implement
/// TensorSpec equality checks), and to use, if needed, when mapping to an
/// underlying evaluator's type system. The main requirement is that the C type
/// we use has the same size and encoding (e.g. endian-ness) as the one used by
/// the evaluator.
///
/// The order of the entries determines the numeric values of the TensorType
/// enumerators generated from this list, so append new types at the end to
/// keep existing values stable.
#define SUPPORTED_TENSOR_TYPES(M)                                              \
  M(float, Float)                                                              \
  M(double, Double)                                                            \
  M(int8_t, Int8)                                                              \
  M(uint8_t, UInt8)                                                            \
  M(int16_t, Int16)                                                            \
  M(uint16_t, UInt16)                                                          \
  M(int32_t, Int32)                                                            \
  M(uint32_t, UInt32)                                                          \
  M(int64_t, Int64)                                                            \
  M(uint64_t, UInt64)

/// Element types a TensorSpec can describe. Each SUPPORTED_TENSOR_TYPES entry
/// contributes one enumerator, named after the entry's Name column, in list
/// order.
///  - Invalid (value 0) comes first. TensorSpec uses it as the default for a
///    spec whose type has not been set.
///  - Total comes last, so its value equals the number of enumerators that
///    precede it. It is presumably a count sentinel, not a real element type.
enum class TensorType {
  Invalid,
// Expansion helper: the CType column is not needed for the enumerators. The
// name deliberately avoids a leading underscore followed by an uppercase
// letter, because such identifiers are reserved for the implementation in C++.
#define LLVM_TENSOR_TYPE_ENUM_MEMBER(CType, Name) Name,
  SUPPORTED_TENSOR_TYPES(LLVM_TENSOR_TYPE_ENUM_MEMBER)
#undef LLVM_TENSOR_TYPE_ENUM_MEMBER
  Total
};
||
| 53 | |||
| 54 | class TensorSpec final { |
||
| 55 | public: |
||
| 56 | template <typename T> |
||
| 57 | static TensorSpec createSpec(const std::string &Name, |
||
| 58 | const std::vector<int64_t> &Shape, |
||
| 59 | int Port = 0) { |
||
| 60 | return TensorSpec(Name, Port, getDataType<T>(), sizeof(T), Shape); |
||
| 61 | } |
||
| 62 | |||
| 63 | const std::string &name() const { return Name; } |
||
| 64 | int port() const { return Port; } |
||
| 65 | TensorType type() const { return Type; } |
||
| 66 | const std::vector<int64_t> &shape() const { return Shape; } |
||
| 67 | |||
| 68 | bool operator==(const TensorSpec &Other) const { |
||
| 69 | return Name == Other.Name && Port == Other.Port && Type == Other.Type && |
||
| 70 | Shape == Other.Shape; |
||
| 71 | } |
||
| 72 | |||
| 73 | bool operator!=(const TensorSpec &Other) const { return !(*this == Other); } |
||
| 74 | |||
| 75 | /// Get the number of elements in a tensor with this shape. |
||
| 76 | size_t getElementCount() const { return ElementCount; } |
||
| 77 | /// Get the size, in bytes, of one element. |
||
| 78 | size_t getElementByteSize() const { return ElementSize; } |
||
| 79 | /// Get the total size of a memory buffer needed to store the whole tensor. |
||
| 80 | size_t getTotalTensorBufferSize() const { return ElementCount * ElementSize; } |
||
| 81 | |||
| 82 | template <typename T> bool isElementType() const { |
||
| 83 | return getDataType<T>() == Type; |
||
| 84 | } |
||
| 85 | |||
| 86 | TensorSpec(const std::string &NewName, const TensorSpec &Other) |
||
| 87 | : TensorSpec(NewName, Other.Port, Other.Type, Other.ElementSize, |
||
| 88 | Other.Shape) {} |
||
| 89 | |||
| 90 | void toJSON(json::OStream &OS) const; |
||
| 91 | |||
| 92 | private: |
||
| 93 | TensorSpec(const std::string &Name, int Port, TensorType Type, |
||
| 94 | size_t ElementSize, const std::vector<int64_t> &Shape); |
||
| 95 | |||
| 96 | template <typename T> static TensorType getDataType(); |
||
| 97 | |||
| 98 | std::string Name; |
||
| 99 | int Port = 0; |
||
| 100 | TensorType Type = TensorType::Invalid; |
||
| 101 | std::vector<int64_t> Shape; |
||
| 102 | size_t ElementCount = 0; |
||
| 103 | size_t ElementSize = 0; |
||
| 104 | }; |
||
| 105 | |||
| 106 | /// Construct a TensorSpec from a JSON dictionary of the form: |
||
| 107 | /// { "name": <string>, |
||
| 108 | /// "port": <int>, |
||
| 109 | /// "type": <string. Use LLVM's types, e.g. float, double, int64_t>, |
||
| 110 | /// "shape": <array of ints> } |
||
| 111 | /// For the "type" field, see the C++ primitive types used in |
||
| 112 | /// TFUTILS_SUPPORTED_TYPES. |
||
| 113 | std::optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx, |
||
| 114 | const json::Value &Value); |
||
| 115 | |||
| 116 | #define TFUTILS_GETDATATYPE_DEF(T, Name) \ |
||
| 117 | template <> TensorType TensorSpec::getDataType<T>(); |
||
| 118 | SUPPORTED_TENSOR_TYPES(TFUTILS_GETDATATYPE_DEF) |
||
| 119 | |||
| 120 | #undef TFUTILS_GETDATATYPE_DEF |
||
| 121 | } // namespace llvm |
||
| 122 | |||
| 123 | #endif // LLVM_ANALYSIS_TENSORSPEC_H |