/*
 * Copyright (c) 2022 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "split_builder.h"

namespace OHOS {
namespace NeuralNetworkRuntime {
namespace Ops {
static const int INPUT_NUM = 1;
static const int PARAM_MAX_NUM = 3;
static const std::string OP_NAME = "Split";

SplitBuilder::SplitBuilder() {}

SplitBuilder::~SplitBuilder() {}

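/**
 * Validates the input and output tensor indices of the Split operation and records them in
 * m_inputsIndex and m_outputsIndex. Split expects exactly one input tensor, and every index
 * must fall within the range of allTensors.
 */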
OH_NN_ReturnCode SplitBuilder::SetInputAndOutput(const std::vector<uint32_t> &inputsIndex,
    const std::vector<uint32_t> &outputsIndex, const std::vector<std::shared_ptr<NNTensor>> &allTensors)
{
    auto inputSize = inputsIndex.size();
    if (inputSize != INPUT_NUM) {
        LOGE("[SplitBuilder] The number of inputsIndex should be %d, its number is %zu.", INPUT_NUM, inputSize);
        return OH_NN_INVALID_PARAMETER;
    }

    auto allTensorSize = allTensors.size();
    bool isOverTensorSize = std::any_of(inputsIndex.begin(), inputsIndex.end(), [allTensorSize](uint32_t index) {
        return index >= allTensorSize;
    });
    if (isOverTensorSize) {
        LOGE("[SplitBuilder] InputsIndex of Split is out of range.");
        return OH_NN_INVALID_PARAMETER;
    }

    isOverTensorSize = std::any_of(outputsIndex.begin(), outputsIndex.end(), [allTensorSize](uint32_t index) {
        return index >= allTensorSize;
    });
    if (isOverTensorSize) {
        LOGE("[SplitBuilder] OutputsIndex of Split is out of range.");
        return OH_NN_INVALID_PARAMETER;
    }

    m_inputsIndex = inputsIndex;
    m_outputsIndex = outputsIndex;

    // The quantization type of the first output determines that of the operator.
    SetQuantType(outputsIndex, allTensors);

    return OH_NN_SUCCESS;
}

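/**
 * Reads the axis parameter from the given tensor. The tensor must be a scalar of type
 * OH_NN_INT64 with a valid buffer; its value is stored in m_axis.
 */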
OH_NN_ReturnCode SplitBuilder::SetAxis(const std::shared_ptr<NNTensor>& tensor)
{
    if (tensor->GetDataType() != OH_NN_INT64) {
        LOGE("[SplitBuilder] The 4th input axis should be type OH_NN_INT64.");
        return OH_NN_INVALID_PARAMETER;
    }

    if (tensor->GetElementCount() != 1) {
        LOGE("[SplitBuilder] The 4th input axis should be a scalar.");
        return OH_NN_INVALID_PARAMETER;
    }

    void* buffer = tensor->GetBuffer();
    if (buffer == nullptr) {
        LOGE("[SplitBuilder] Tensor buffer is nullptr.");
        return OH_NN_INVALID_PARAMETER;
    }
    m_axis = *(static_cast<const int64_t *>(buffer));

    return OH_NN_SUCCESS;
}

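/**
 * Reads the outputNum parameter from the given tensor. The tensor must be a scalar of type
 * OH_NN_INT64 with a valid buffer; its value is stored in m_output_num.
 */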
OH_NN_ReturnCode SplitBuilder::SetOutputNum(const std::shared_ptr<NNTensor>& tensor)
{
    if (tensor->GetDataType() != OH_NN_INT64) {
        LOGE("[SplitBuilder] The 2nd input outputNum should be type OH_NN_INT64.");
        return OH_NN_INVALID_PARAMETER;
    }

    if (tensor->GetElementCount() != 1) {
        LOGE("[SplitBuilder] The 2nd input outputNum should be a scalar.");
        return OH_NN_INVALID_PARAMETER;
    }

    void* buffer = tensor->GetBuffer();
    if (buffer == nullptr) {
        LOGE("[SplitBuilder] Tensor buffer is nullptr.");
        return OH_NN_INVALID_PARAMETER;
    }
    m_output_num = *(static_cast<const int64_t *>(buffer));

    return OH_NN_SUCCESS;
}

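/**
 * Reads the sizeSplits parameter from the given tensor. The tensor must be of type
 * OH_NN_INT64; every element is appended to m_size_splits.
 */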
OH_NN_ReturnCode SplitBuilder::SetSizeSplits(const std::shared_ptr<NNTensor>& tensor)
{
    if (tensor->GetDataType() != OH_NN_INT64) {
        LOGE("[SplitBuilder] The 3rd input sizeSplits should be type OH_NN_INT64.");
        return OH_NN_INVALID_PARAMETER;
    }

    const int64_t *size_splits_data_ptr = reinterpret_cast<const int64_t *>(tensor->GetBuffer());
    if (size_splits_data_ptr == nullptr) {
        LOGE("[SplitBuilder] Tensor buffer is nullptr.");
        return OH_NN_INVALID_PARAMETER;
    }
    for (uint32_t i = 0; i < tensor->GetElementCount(); i++) {
        m_size_splits.push_back(*size_splits_data_ptr++);
    }

    return OH_NN_SUCCESS;
}

/**
 * Build method.
 * 1. Set the input and output indices of the operator.
 * 2. Validate the parameter indices.
 * 3. Set the operator attributes (outputNum, sizeSplits, axis).
 */
OH_NN_ReturnCode SplitBuilder::Build(const std::vector<uint32_t> &paramsIndex,
    const std::vector<uint32_t> &inputsIndex,
    const std::vector<uint32_t> &outputsIndex,
    const std::vector<std::shared_ptr<NNTensor>> &allTensors)
{
    if (m_isBuild) {
        LOGE("[SplitBuilder] Split operation has been built; it cannot be built again.");
        return OH_NN_OPERATION_FORBIDDEN;
    }

    OH_NN_ReturnCode returnCode = SetInputAndOutput(inputsIndex, outputsIndex, allTensors);
    if (returnCode != OH_NN_SUCCESS) {
        LOGE("[SplitBuilder] Set index of inputs or outputs failed.");
        return returnCode;
    }

    returnCode = CheckParamIndex(paramsIndex, allTensors, PARAM_MAX_NUM);
    if (returnCode != OH_NN_SUCCESS) {
        LOGE("[SplitBuilder] Build failed, passed invalid param index.");
        return returnCode;
    }

    for (int i : paramsIndex) {
        std::shared_ptr<NNTensor> tensor = allTensors[i];
        tensor->IdentifyOpParameter();
        if (m_paramMap.find(tensor->GetType()) != m_paramMap.end()) {
            returnCode = (this->*(m_paramMap[tensor->GetType()]))(tensor);
        } else {
            LOGE("[SplitBuilder] Build failed, param invalid, type=%d", tensor->GetType());
            return OH_NN_INVALID_PARAMETER;
        }

        if (returnCode != OH_NN_SUCCESS) {
            LOGE("[SplitBuilder] Passed invalid param.");
            return returnCode;
        }
    }

    m_isBuild = true;
    m_name = OP_NAME;
    return OH_NN_SUCCESS;
}

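/**
 * Creates the MindIR Split primitive from the attributes collected by Build. Build must be
 * called successfully before this method; otherwise a null primitive is returned.
 */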
LiteGraphTensorPtr SplitBuilder::GetPrimitive()
{
    if (!m_isBuild) {
        LOGE("[SplitBuilder] Cannot get primitive before calling Build.");
        return { nullptr, DestroyLiteGraphPrimitive };
    }

    auto primitive = mindspore::lite::MindIR_Split_CreatePrimitive(m_output_num, m_size_splits, m_axis);
    if (primitive == nullptr) {
        LOGE("[SplitBuilder] MindIR_Split_CreatePrimitive failed.");
        return { nullptr, DestroyLiteGraphPrimitive };
    }

    LiteGraphTensorPtr graphPrimitivePtr(primitive, DestroyLiteGraphPrimitive);
    return graphPrimitivePtr;
}

REGISTER_OPS(SplitBuilder, OH_NN_OPS_SPLIT);
} // namespace Ops
} // namespace NeuralNetworkRuntime
} // namespace OHOS