1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
#include "tensor_desc.h"

#include <limits>

#include "validation.h"
#include "common/log.h"
19 
20 namespace OHOS {
21 namespace NeuralNetworkRuntime {
// Storage width, in bytes, of 8-, 16-, 32- and 64-bit scalar element types (bit width / 8).
constexpr uint32_t BIT8_TO_BYTE = 8U / 8U;
constexpr uint32_t BIT16_TO_BYTE = 16U / 8U;
constexpr uint32_t BIT32_TO_BYTE = 32U / 8U;
constexpr uint32_t BIT64_TO_BYTE = 64U / 8U;
26 
GetTypeSize(OH_NN_DataType type)27 uint32_t GetTypeSize(OH_NN_DataType type)
28 {
29     switch (type) {
30         case OH_NN_BOOL:
31             return sizeof(bool);
32         case OH_NN_INT8:
33         case OH_NN_UINT8:
34             return BIT8_TO_BYTE;
35         case OH_NN_INT16:
36         case OH_NN_UINT16:
37         case OH_NN_FLOAT16:
38             return BIT16_TO_BYTE;
39         case OH_NN_INT32:
40         case OH_NN_UINT32:
41         case OH_NN_FLOAT32:
42             return BIT32_TO_BYTE;
43         case OH_NN_INT64:
44         case OH_NN_UINT64:
45         case OH_NN_FLOAT64:
46             return BIT64_TO_BYTE;
47         default:
48             return 0;
49     }
50 }
51 
GetDataType(OH_NN_DataType * dataType) const52 OH_NN_ReturnCode TensorDesc::GetDataType(OH_NN_DataType* dataType) const
53 {
54     if (dataType == nullptr) {
55         LOGE("GetDataType failed, dataType is nullptr.");
56         return OH_NN_INVALID_PARAMETER;
57     }
58     *dataType = m_dataType;
59     return OH_NN_SUCCESS;
60 }
61 
SetDataType(OH_NN_DataType dataType)62 OH_NN_ReturnCode TensorDesc::SetDataType(OH_NN_DataType dataType)
63 {
64     if (!Validation::ValidateTensorDataType(dataType)) {
65         LOGE("TensorDesc::SetDataType failed, dataType %{public}d is invalid.", static_cast<int>(dataType));
66         return OH_NN_INVALID_PARAMETER;
67     }
68     m_dataType = dataType;
69     return OH_NN_SUCCESS;
70 }
71 
GetFormat(OH_NN_Format * format) const72 OH_NN_ReturnCode TensorDesc::GetFormat(OH_NN_Format* format) const
73 {
74     if (format == nullptr) {
75         LOGE("GetFormat failed, format is nullptr.");
76         return OH_NN_INVALID_PARAMETER;
77     }
78     *format = m_format;
79     return OH_NN_SUCCESS;
80 }
81 
SetFormat(OH_NN_Format format)82 OH_NN_ReturnCode TensorDesc::SetFormat(OH_NN_Format format)
83 {
84     if (!Validation::ValidateTensorFormat(format)) {
85         LOGE("TensorDesc::SetFormat failed, format %{public}d is invalid.", static_cast<int>(format));
86         return OH_NN_INVALID_PARAMETER;
87     }
88     m_format = format;
89     return OH_NN_SUCCESS;
90 }
91 
GetShape(int32_t ** shape,size_t * shapeNum) const92 OH_NN_ReturnCode TensorDesc::GetShape(int32_t** shape, size_t* shapeNum) const
93 {
94     if (shape == nullptr) {
95         LOGE("GetShape failed, shape is nullptr.");
96         return OH_NN_INVALID_PARAMETER;
97     }
98     if (*shape != nullptr) {
99         LOGE("GetShape failed, *shape is not nullptr.");
100         return OH_NN_INVALID_PARAMETER;
101     }
102     if (shapeNum == nullptr) {
103         LOGE("GetShape failed, shapeNum is nullptr.");
104         return OH_NN_INVALID_PARAMETER;
105     }
106     *shape = const_cast<int32_t*>(m_shape.data());
107     *shapeNum = m_shape.size();
108     return OH_NN_SUCCESS;
109 }
110 
SetShape(const int32_t * shape,size_t shapeNum)111 OH_NN_ReturnCode TensorDesc::SetShape(const int32_t* shape, size_t shapeNum)
112 {
113     if (shape == nullptr) {
114         LOGE("SetShape failed, shape is nullptr.");
115         return OH_NN_INVALID_PARAMETER;
116     }
117     if (shapeNum == 0) {
118         LOGE("SetShape failed, shapeNum is 0.");
119         return OH_NN_INVALID_PARAMETER;
120     }
121     m_shape.clear();
122     for (size_t i = 0; i < shapeNum; ++i) {
123         m_shape.emplace_back(shape[i]);
124     }
125     return OH_NN_SUCCESS;
126 }
127 
GetElementNum(size_t * elementNum) const128 OH_NN_ReturnCode TensorDesc::GetElementNum(size_t* elementNum) const
129 {
130     if (elementNum == nullptr) {
131         LOGE("GetElementNum failed, elementNum is nullptr.");
132         return OH_NN_INVALID_PARAMETER;
133     }
134     if (m_shape.empty()) {
135         LOGE("GetElementNum failed, shape is empty.");
136         return OH_NN_INVALID_PARAMETER;
137     }
138     *elementNum = 1;
139     size_t shapeNum = m_shape.size();
140     for (size_t i = 0; i < shapeNum; ++i) {
141         if (m_shape[i] <= 0) {
142             LOGW("GetElementNum return 0 with dynamic shape, shape[%{public}zu] is %{public}d.", i, m_shape[i]);
143             *elementNum = 0;
144             return OH_NN_DYNAMIC_SHAPE;
145         }
146         (*elementNum) *= m_shape[i];
147     }
148     return OH_NN_SUCCESS;
149 }
150 
GetByteSize(size_t * byteSize) const151 OH_NN_ReturnCode TensorDesc::GetByteSize(size_t* byteSize) const
152 {
153     if (byteSize == nullptr) {
154         LOGE("GetByteSize failed, byteSize is nullptr.");
155         return OH_NN_INVALID_PARAMETER;
156     }
157     *byteSize = 0;
158     size_t elementNum = 0;
159     auto ret = GetElementNum(&elementNum);
160     if (ret == OH_NN_DYNAMIC_SHAPE) {
161         return OH_NN_SUCCESS;
162     } else if (ret != OH_NN_SUCCESS) {
163         LOGE("GetByteSize failed, get element num failed.");
164         return ret;
165     }
166 
167     uint32_t typeSize = GetTypeSize(m_dataType);
168     if (typeSize == 0) {
169         LOGE("GetByteSize failed, data type is invalid.");
170         return OH_NN_INVALID_PARAMETER;
171     }
172 
173     *byteSize = elementNum * typeSize;
174 
175     return OH_NN_SUCCESS;
176 }
177 
SetName(const char * name)178 OH_NN_ReturnCode TensorDesc::SetName(const char* name)
179 {
180     if (name == nullptr) {
181         LOGE("SetName failed, name is nullptr.");
182         return OH_NN_INVALID_PARAMETER;
183     }
184     m_name = name;
185     return OH_NN_SUCCESS;
186 }
187 
188 // *name will be invalid after TensorDesc is destroyed
GetName(const char ** name) const189 OH_NN_ReturnCode TensorDesc::GetName(const char** name) const
190 {
191     if (name == nullptr) {
192         LOGE("GetName failed, name is nullptr.");
193         return OH_NN_INVALID_PARAMETER;
194     }
195     if (*name != nullptr) {
196         LOGE("GetName failed, *name is not nullptr.");
197         return OH_NN_INVALID_PARAMETER;
198     }
199     *name = m_name.c_str();
200     return OH_NN_SUCCESS;
201 }
202 }  // namespace NeuralNetworkRuntime
203 }  // namespace OHOS