1 /*
2 * Copyright (c) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "tensor_desc.h"
17 #include "validation.h"
18 #include "common/log.h"
19
20 namespace OHOS {
21 namespace NeuralNetworkRuntime {
// Storage width, in bytes, of one element whose bit width is named by the
// constant. GetTypeSize() returns these for the matching OH_NN_DataType values.
constexpr uint32_t BIT8_TO_BYTE = 1;   // 8-bit element  -> 1 byte
constexpr uint32_t BIT16_TO_BYTE = 2;  // 16-bit element -> 2 bytes
constexpr uint32_t BIT32_TO_BYTE = 4;  // 32-bit element -> 4 bytes
constexpr uint32_t BIT64_TO_BYTE = 8;  // 64-bit element -> 8 bytes
26
GetTypeSize(OH_NN_DataType type)27 uint32_t GetTypeSize(OH_NN_DataType type)
28 {
29 switch (type) {
30 case OH_NN_BOOL:
31 return sizeof(bool);
32 case OH_NN_INT8:
33 case OH_NN_UINT8:
34 return BIT8_TO_BYTE;
35 case OH_NN_INT16:
36 case OH_NN_UINT16:
37 case OH_NN_FLOAT16:
38 return BIT16_TO_BYTE;
39 case OH_NN_INT32:
40 case OH_NN_UINT32:
41 case OH_NN_FLOAT32:
42 return BIT32_TO_BYTE;
43 case OH_NN_INT64:
44 case OH_NN_UINT64:
45 case OH_NN_FLOAT64:
46 return BIT64_TO_BYTE;
47 default:
48 return 0;
49 }
50 }
51
GetDataType(OH_NN_DataType * dataType) const52 OH_NN_ReturnCode TensorDesc::GetDataType(OH_NN_DataType* dataType) const
53 {
54 if (dataType == nullptr) {
55 LOGE("GetDataType failed, dataType is nullptr.");
56 return OH_NN_INVALID_PARAMETER;
57 }
58 *dataType = m_dataType;
59 return OH_NN_SUCCESS;
60 }
61
SetDataType(OH_NN_DataType dataType)62 OH_NN_ReturnCode TensorDesc::SetDataType(OH_NN_DataType dataType)
63 {
64 if (!Validation::ValidateTensorDataType(dataType)) {
65 LOGE("TensorDesc::SetDataType failed, dataType %{public}d is invalid.", static_cast<int>(dataType));
66 return OH_NN_INVALID_PARAMETER;
67 }
68 m_dataType = dataType;
69 return OH_NN_SUCCESS;
70 }
71
GetFormat(OH_NN_Format * format) const72 OH_NN_ReturnCode TensorDesc::GetFormat(OH_NN_Format* format) const
73 {
74 if (format == nullptr) {
75 LOGE("GetFormat failed, format is nullptr.");
76 return OH_NN_INVALID_PARAMETER;
77 }
78 *format = m_format;
79 return OH_NN_SUCCESS;
80 }
81
SetFormat(OH_NN_Format format)82 OH_NN_ReturnCode TensorDesc::SetFormat(OH_NN_Format format)
83 {
84 if (!Validation::ValidateTensorFormat(format)) {
85 LOGE("TensorDesc::SetFormat failed, format %{public}d is invalid.", static_cast<int>(format));
86 return OH_NN_INVALID_PARAMETER;
87 }
88 m_format = format;
89 return OH_NN_SUCCESS;
90 }
91
GetShape(int32_t ** shape,size_t * shapeNum) const92 OH_NN_ReturnCode TensorDesc::GetShape(int32_t** shape, size_t* shapeNum) const
93 {
94 if (shape == nullptr) {
95 LOGE("GetShape failed, shape is nullptr.");
96 return OH_NN_INVALID_PARAMETER;
97 }
98 if (*shape != nullptr) {
99 LOGE("GetShape failed, *shape is not nullptr.");
100 return OH_NN_INVALID_PARAMETER;
101 }
102 if (shapeNum == nullptr) {
103 LOGE("GetShape failed, shapeNum is nullptr.");
104 return OH_NN_INVALID_PARAMETER;
105 }
106 *shape = const_cast<int32_t*>(m_shape.data());
107 *shapeNum = m_shape.size();
108 return OH_NN_SUCCESS;
109 }
110
SetShape(const int32_t * shape,size_t shapeNum)111 OH_NN_ReturnCode TensorDesc::SetShape(const int32_t* shape, size_t shapeNum)
112 {
113 if (shape == nullptr) {
114 LOGE("SetShape failed, shape is nullptr.");
115 return OH_NN_INVALID_PARAMETER;
116 }
117 if (shapeNum == 0) {
118 LOGE("SetShape failed, shapeNum is 0.");
119 return OH_NN_INVALID_PARAMETER;
120 }
121 m_shape.clear();
122 for (size_t i = 0; i < shapeNum; ++i) {
123 m_shape.emplace_back(shape[i]);
124 }
125 return OH_NN_SUCCESS;
126 }
127
GetElementNum(size_t * elementNum) const128 OH_NN_ReturnCode TensorDesc::GetElementNum(size_t* elementNum) const
129 {
130 if (elementNum == nullptr) {
131 LOGE("GetElementNum failed, elementNum is nullptr.");
132 return OH_NN_INVALID_PARAMETER;
133 }
134 if (m_shape.empty()) {
135 LOGE("GetElementNum failed, shape is empty.");
136 return OH_NN_INVALID_PARAMETER;
137 }
138 *elementNum = 1;
139 size_t shapeNum = m_shape.size();
140 for (size_t i = 0; i < shapeNum; ++i) {
141 if (m_shape[i] <= 0) {
142 LOGW("GetElementNum return 0 with dynamic shape, shape[%{public}zu] is %{public}d.", i, m_shape[i]);
143 *elementNum = 0;
144 return OH_NN_DYNAMIC_SHAPE;
145 }
146 (*elementNum) *= m_shape[i];
147 }
148 return OH_NN_SUCCESS;
149 }
150
GetByteSize(size_t * byteSize) const151 OH_NN_ReturnCode TensorDesc::GetByteSize(size_t* byteSize) const
152 {
153 if (byteSize == nullptr) {
154 LOGE("GetByteSize failed, byteSize is nullptr.");
155 return OH_NN_INVALID_PARAMETER;
156 }
157 *byteSize = 0;
158 size_t elementNum = 0;
159 auto ret = GetElementNum(&elementNum);
160 if (ret == OH_NN_DYNAMIC_SHAPE) {
161 return OH_NN_SUCCESS;
162 } else if (ret != OH_NN_SUCCESS) {
163 LOGE("GetByteSize failed, get element num failed.");
164 return ret;
165 }
166
167 uint32_t typeSize = GetTypeSize(m_dataType);
168 if (typeSize == 0) {
169 LOGE("GetByteSize failed, data type is invalid.");
170 return OH_NN_INVALID_PARAMETER;
171 }
172
173 *byteSize = elementNum * typeSize;
174
175 return OH_NN_SUCCESS;
176 }
177
SetName(const char * name)178 OH_NN_ReturnCode TensorDesc::SetName(const char* name)
179 {
180 if (name == nullptr) {
181 LOGE("SetName failed, name is nullptr.");
182 return OH_NN_INVALID_PARAMETER;
183 }
184 m_name = name;
185 return OH_NN_SUCCESS;
186 }
187
188 // *name will be invalid after TensorDesc is destroyed
GetName(const char ** name) const189 OH_NN_ReturnCode TensorDesc::GetName(const char** name) const
190 {
191 if (name == nullptr) {
192 LOGE("GetName failed, name is nullptr.");
193 return OH_NN_INVALID_PARAMETER;
194 }
195 if (*name != nullptr) {
196 LOGE("GetName failed, *name is not nullptr.");
197 return OH_NN_INVALID_PARAMETER;
198 }
199 *name = m_name.c_str();
200 return OH_NN_SUCCESS;
201 }
202 } // namespace NeuralNetworkRuntime
203 } // namespace OHOS