1//===- TensorSpec.cpp - tensor type abstraction ---------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Implementation file for the abstraction of a tensor type, and JSON loading
10// utils.
11//
12//===----------------------------------------------------------------------===//
#include "llvm/ADT/STLExtras.h"
#include "llvm/Config/config.h"

#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/TensorSpec.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/raw_ostream.h"
#include <array>
#include <cassert>
#include <cstring>
#include <numeric>
27
28using namespace llvm;
29
30namespace llvm {
31
// Define TensorSpec::getDataType<T>() for every (T, E) entry of the
// SUPPORTED_TENSOR_TYPES X-macro: each explicit specialization maps the C++
// element type T to the enumerator TensorType::E.
#define TFUTILS_GETDATATYPE_IMPL(T, E)                                         \
  template <> TensorType TensorSpec::getDataType<T>() { return TensorType::E; }

SUPPORTED_TENSOR_TYPES(TFUTILS_GETDATATYPE_IMPL)

// The helper macro is only needed for the expansion above.
#undef TFUTILS_GETDATATYPE_IMPL
38
39static std::array<std::string, static_cast<size_t>(TensorType::Total)>
40    TensorTypeNames{"INVALID",
41#define TFUTILS_GETNAME_IMPL(T, _) #T,
42                    SUPPORTED_TENSOR_TYPES(TFUTILS_GETNAME_IMPL)
43#undef TFUTILS_GETNAME_IMPL
44    };
45
46StringRef toString(TensorType TT) {
47  return TensorTypeNames[static_cast<size_t>(TT)];
48}
49
50void TensorSpec::toJSON(json::OStream &OS) const {
51  OS.object([&]() {
52    OS.attribute("name", name());
53    OS.attribute("type", toString(type()));
54    OS.attribute("port", port());
55    OS.attributeArray("shape", [&]() {
56      for (size_t D : shape())
57        OS.value(static_cast<int64_t>(D));
58    });
59  });
60}
61
62TensorSpec::TensorSpec(const std::string &Name, int Port, TensorType Type,
63                       size_t ElementSize, const std::vector<int64_t> &Shape)
64    : Name(Name), Port(Port), Type(Type), Shape(Shape),
65      ElementCount(std::accumulate(Shape.begin(), Shape.end(), 1,
66                                   std::multiplies<int64_t>())),
67      ElementSize(ElementSize) {}
68
69std::optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
70                                                const json::Value &Value) {
71  auto EmitError =
72      [&](const llvm::Twine &Message) -> std::optional<TensorSpec> {
73    std::string S;
74    llvm::raw_string_ostream OS(S);
75    OS << Value;
76    Ctx.emitError("Unable to parse JSON Value as spec (" + Message + "): " + S);
77    return std::nullopt;
78  };
79  // FIXME: accept a Path as a parameter, and use it for error reporting.
80  json::Path::Root Root("tensor_spec");
81  json::ObjectMapper Mapper(Value, Root);
82  if (!Mapper)
83    return EmitError("Value is not a dict");
84
85  std::string TensorName;
86  int TensorPort = -1;
87  std::string TensorType;
88  std::vector<int64_t> TensorShape;
89
90  if (!Mapper.map<std::string>("name", TensorName))
91    return EmitError("'name' property not present or not a string");
92  if (!Mapper.map<std::string>("type", TensorType))
93    return EmitError("'type' property not present or not a string");
94  if (!Mapper.map<int>("port", TensorPort))
95    return EmitError("'port' property not present or not an int");
96  if (!Mapper.map<std::vector<int64_t>>("shape", TensorShape))
97    return EmitError("'shape' property not present or not an int array");
98
99#define PARSE_TYPE(T, E)                                                       \
100  if (TensorType == #T)                                                        \
101    return TensorSpec::createSpec<T>(TensorName, TensorShape, TensorPort);
102  SUPPORTED_TENSOR_TYPES(PARSE_TYPE)
103#undef PARSE_TYPE
104  return std::nullopt;
105}
106
107std::string tensorValueToString(const char *Buffer, const TensorSpec &Spec) {
108  switch (Spec.type()) {
109#define _IMR_DBG_PRINTER(T, N)                                                 \
110  case TensorType::N: {                                                        \
111    const T *TypedBuff = reinterpret_cast<const T *>(Buffer);                  \
112    auto R = llvm::make_range(TypedBuff, TypedBuff + Spec.getElementCount());  \
113    return llvm::join(                                                         \
114        llvm::map_range(R, [](T V) { return std::to_string(V); }), ",");       \
115  }
116    SUPPORTED_TENSOR_TYPES(_IMR_DBG_PRINTER)
117#undef _IMR_DBG_PRINTER
118  case TensorType::Total:
119  case TensorType::Invalid:
120    llvm_unreachable("invalid tensor type");
121  }
122  // To appease warnings about not all control paths returning a value.
123  return "";
124}
125
126} // namespace llvm
127