Details | Last modification | View Log | RSS feed
Rev | Author | Line No. | Line |
---|---|---|---|
14 | pmbaty | 1 | //===- TensorSpec.h - type descriptor for a tensor --------------*- C++ -*-===// |
2 | // |
||
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
||
4 | // See https://llvm.org/LICENSE.txt for license information. |
||
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
||
6 | // |
||
7 | //===----------------------------------------------------------------------===// |
||
8 | // |
||
9 | #ifndef LLVM_ANALYSIS_TENSORSPEC_H |
||
10 | #define LLVM_ANALYSIS_TENSORSPEC_H |
||
11 | |||
12 | #include "llvm/Config/llvm-config.h" |
||
13 | |||
14 | #include "llvm/ADT/StringMap.h" |
||
15 | #include "llvm/IR/LLVMContext.h" |
||
16 | #include "llvm/Support/JSON.h" |
||
17 | |||
18 | #include <memory> |
||
19 | #include <optional> |
||
20 | #include <vector> |
||
21 | |||
22 | namespace llvm { |
||
/// TensorSpec encapsulates the specification of a tensor: its dimensions, or
/// "shape" (row-major), its element type (the supported element types are the
/// C types listed in SUPPORTED_TENSOR_TYPES below), its name, and its port
/// (see "TensorFlow: Large-Scale Machine Learning on Heterogeneous Distributed
/// Systems", section 4.2, para 2:
/// https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/45166.pdf)
///
||
/// X-macro enumerating every element type a TensorSpec may carry. Each entry
/// is M(CType, Name): CType is the C type that holds one element, and Name is
/// the identifier used for the matching TensorType enumerator. That enumerator
/// is what TensorSpec equality compares, and it can also be used when mapping
/// onto an underlying evaluator's type system.
/// Requirement: each CType must have the same size and encoding (e.g.
/// endian-ness) as the evaluator's representation of that type.
/// Note: the order of entries determines the order of TensorType enumerators.
#define SUPPORTED_TENSOR_TYPES(M)                                              \
  M(float, Float)                                                              \
  M(double, Double)                                                            \
  M(int8_t, Int8)                                                              \
  M(uint8_t, UInt8)                                                            \
  M(int16_t, Int16)                                                            \
  M(uint16_t, UInt16)                                                          \
  M(int32_t, Int32)                                                            \
  M(uint32_t, UInt32)                                                          \
  M(int64_t, Int64)                                                            \
  M(uint64_t, UInt64)
||
45 | |||
46 | enum class TensorType { |
||
47 | Invalid, |
||
48 | #define _TENSOR_TYPE_ENUM_MEMBERS(_, Name) Name, |
||
49 | SUPPORTED_TENSOR_TYPES(_TENSOR_TYPE_ENUM_MEMBERS) |
||
50 | #undef _TENSOR_TYPE_ENUM_MEMBERS |
||
51 | Total |
||
52 | }; |
||
53 | |||
54 | class TensorSpec final { |
||
55 | public: |
||
56 | template <typename T> |
||
57 | static TensorSpec createSpec(const std::string &Name, |
||
58 | const std::vector<int64_t> &Shape, |
||
59 | int Port = 0) { |
||
60 | return TensorSpec(Name, Port, getDataType<T>(), sizeof(T), Shape); |
||
61 | } |
||
62 | |||
63 | const std::string &name() const { return Name; } |
||
64 | int port() const { return Port; } |
||
65 | TensorType type() const { return Type; } |
||
66 | const std::vector<int64_t> &shape() const { return Shape; } |
||
67 | |||
68 | bool operator==(const TensorSpec &Other) const { |
||
69 | return Name == Other.Name && Port == Other.Port && Type == Other.Type && |
||
70 | Shape == Other.Shape; |
||
71 | } |
||
72 | |||
73 | bool operator!=(const TensorSpec &Other) const { return !(*this == Other); } |
||
74 | |||
75 | /// Get the number of elements in a tensor with this shape. |
||
76 | size_t getElementCount() const { return ElementCount; } |
||
77 | /// Get the size, in bytes, of one element. |
||
78 | size_t getElementByteSize() const { return ElementSize; } |
||
79 | /// Get the total size of a memory buffer needed to store the whole tensor. |
||
80 | size_t getTotalTensorBufferSize() const { return ElementCount * ElementSize; } |
||
81 | |||
82 | template <typename T> bool isElementType() const { |
||
83 | return getDataType<T>() == Type; |
||
84 | } |
||
85 | |||
86 | TensorSpec(const std::string &NewName, const TensorSpec &Other) |
||
87 | : TensorSpec(NewName, Other.Port, Other.Type, Other.ElementSize, |
||
88 | Other.Shape) {} |
||
89 | |||
90 | void toJSON(json::OStream &OS) const; |
||
91 | |||
92 | private: |
||
93 | TensorSpec(const std::string &Name, int Port, TensorType Type, |
||
94 | size_t ElementSize, const std::vector<int64_t> &Shape); |
||
95 | |||
96 | template <typename T> static TensorType getDataType(); |
||
97 | |||
98 | std::string Name; |
||
99 | int Port = 0; |
||
100 | TensorType Type = TensorType::Invalid; |
||
101 | std::vector<int64_t> Shape; |
||
102 | size_t ElementCount = 0; |
||
103 | size_t ElementSize = 0; |
||
104 | }; |
||
105 | |||
106 | /// Construct a TensorSpec from a JSON dictionary of the form: |
||
107 | /// { "name": <string>, |
||
108 | /// "port": <int>, |
||
109 | /// "type": <string. Use LLVM's types, e.g. float, double, int64_t>, |
||
110 | /// "shape": <array of ints> } |
||
111 | /// For the "type" field, see the C++ primitive types used in |
||
112 | /// TFUTILS_SUPPORTED_TYPES. |
||
113 | std::optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx, |
||
114 | const json::Value &Value); |
||
115 | |||
116 | #define TFUTILS_GETDATATYPE_DEF(T, Name) \ |
||
117 | template <> TensorType TensorSpec::getDataType<T>(); |
||
118 | SUPPORTED_TENSOR_TYPES(TFUTILS_GETDATATYPE_DEF) |
||
119 | |||
120 | #undef TFUTILS_GETDATATYPE_DEF |
||
121 | } // namespace llvm |
||
122 | |||
123 | #endif // LLVM_ANALYSIS_TENSORSPEC_H |