1 /*
2  * Copyright (c) 2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "split_builder.h"
17 
18 namespace OHOS {
19 namespace NeuralNetworkRuntime {
20 namespace Ops {
// Split consumes exactly one data input tensor (enforced in SetInputAndOutput).
static constexpr int INPUT_NUM = 1;
// Upper bound on parameter tensors, passed to CheckParamIndex in Build().
static constexpr int PARAM_MAX_NUM = 3;
// Operator name recorded in m_name once Build() succeeds.
static const std::string OP_NAME = "Split";
24 
SplitBuilder()25 SplitBuilder::SplitBuilder() {}
26 
~SplitBuilder()27 SplitBuilder::~SplitBuilder() {}
28 
SetInputAndOutput(const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)29 OH_NN_ReturnCode SplitBuilder::SetInputAndOutput(const std::vector<uint32_t> &inputsIndex,
30     const std::vector<uint32_t> &outputsIndex, const std::vector<std::shared_ptr<NNTensor>> &allTensors)
31 {
32     auto inputSize = inputsIndex.size();
33     if (inputSize != INPUT_NUM) {
34         LOGE("[SplitBuilder] The number of inputsIndex should be %d, its number is %zu.", INPUT_NUM, inputSize);
35         return OH_NN_INVALID_PARAMETER;
36     }
37 
38     auto allTensorSize = allTensors.size();
39     bool isOverTensorSize = std::any_of(inputsIndex.begin(), inputsIndex.end(), [allTensorSize](uint32_t index) {
40         return index >= allTensorSize;
41     });
42     if (isOverTensorSize) {
43         LOGE("[SplitBuilder] InputsIndex of Split is out of range.");
44         return OH_NN_INVALID_PARAMETER;
45     }
46 
47     isOverTensorSize = std::any_of(outputsIndex.begin(), outputsIndex.end(), [allTensorSize](uint32_t index) {
48         return index >= allTensorSize;
49     });
50     if (isOverTensorSize) {
51         LOGE("[SplitBuilder] InputsIndex of Split is out of range.");
52         return OH_NN_INVALID_PARAMETER;
53     }
54 
55     m_inputsIndex = inputsIndex;
56     m_outputsIndex = outputsIndex;
57 
58     // The quantization type of the first output determinies that of the operator.
59     SetQuantType(outputsIndex, allTensors);
60 
61     return OH_NN_SUCCESS;
62 }
63 
SetAxis(const std::shared_ptr<NNTensor> & tensor)64 OH_NN_ReturnCode SplitBuilder::SetAxis(const std::shared_ptr<NNTensor>& tensor)
65 {
66     if (tensor->GetDataType() != OH_NN_INT64) {
67         LOGE("[SplitBuilder] The 4th input axis should be type OH_NN_INT64.");
68         return OH_NN_INVALID_PARAMETER;
69     }
70 
71     if (tensor->GetElementCount() != 1) {
72         LOGE("[SplitBuilder] The 4th input axis should be scaler.");
73         return OH_NN_INVALID_PARAMETER;
74     }
75 
76     void* buffer = tensor->GetBuffer();
77     if (buffer == nullptr) {
78         LOGE("[SplitBuilder] Tensor buffer is nullptr.");
79         return OH_NN_INVALID_PARAMETER;
80     }
81     m_axis = *(static_cast<const int64_t *>(buffer));
82 
83     return OH_NN_SUCCESS;
84 }
85 
SetOutputNum(const std::shared_ptr<NNTensor> & tensor)86 OH_NN_ReturnCode SplitBuilder::SetOutputNum(const std::shared_ptr<NNTensor>& tensor)
87 {
88     if (tensor->GetDataType() != OH_NN_INT64) {
89         LOGE("[SplitBuilder] The 2nd input outputNum should be type OH_NN_INT64.");
90         return OH_NN_INVALID_PARAMETER;
91     }
92 
93     if (tensor->GetElementCount() != 1) {
94         LOGE("[SoftmaxBuilder] The 2nd input outputNum should be scaler.");
95         return OH_NN_INVALID_PARAMETER;
96     }
97 
98     m_output_num = *(static_cast<const int64_t *>(tensor->GetBuffer()));
99 
100     return OH_NN_SUCCESS;
101 }
102 
SetSizeSplits(const std::shared_ptr<NNTensor> & tensor)103 OH_NN_ReturnCode SplitBuilder::SetSizeSplits(const std::shared_ptr<NNTensor>& tensor)
104 {
105     if (tensor->GetDataType() != OH_NN_INT64) {
106         LOGE("[SplitBuilder] The 3rd input sizeSplit should be type OH_NN_INT64.");
107         return OH_NN_INVALID_PARAMETER;
108     }
109 
110     const int64_t *size_splits_data_ptr = reinterpret_cast<const int64_t *>(tensor->GetBuffer());
111     for (uint32_t i = 0; i < tensor->GetElementCount(); i++) {
112         m_size_splits.push_back(*size_splits_data_ptr++);
113     }
114 
115     return OH_NN_SUCCESS;
116 }
117 
118 /**
119  * Build method.
120  * 1.set attr of ops.
121  * 2.set inputIndex of ops.
122  * 3.set outputIndex of ops.
123  */
Build(const std::vector<uint32_t> & paramsIndex,const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)124 OH_NN_ReturnCode SplitBuilder::Build(const std::vector<uint32_t> &paramsIndex,
125                                      const std::vector<uint32_t> &inputsIndex,
126                                      const std::vector<uint32_t> &outputsIndex,
127                                      const std::vector<std::shared_ptr<NNTensor>> &allTensors)
128 {
129     if (m_isBuild) {
130         LOGE("[SplitBuilder] Split operation has been build, cannot build again.");
131         return OH_NN_OPERATION_FORBIDDEN;
132     }
133 
134     OH_NN_ReturnCode returnCode = SetInputAndOutput(inputsIndex, outputsIndex, allTensors);
135     if (returnCode != OH_NN_SUCCESS) {
136         LOGE("[SplitBuilder] Set index of inputs or outputs failed.");
137         return returnCode;
138     }
139 
140     returnCode = CheckParamIndex(paramsIndex, allTensors, PARAM_MAX_NUM);
141     if (returnCode != OH_NN_SUCCESS) {
142         LOGE("[SplitBuilder] Build failed, passed invalid param index.");
143         return returnCode;
144     }
145 
146     for (int i : paramsIndex) {
147         std::shared_ptr<NNTensor> tensor = allTensors[i];
148         tensor->IdentifyOpParameter();
149         if (m_paramMap.find(tensor->GetType()) != m_paramMap.end()) {
150             returnCode = (this->*(m_paramMap[tensor->GetType()]))(tensor);
151         } else {
152             LOGE("[SplitBuilder] Build failed, param invalid, type=%d", tensor->GetType());
153             return OH_NN_INVALID_PARAMETER;
154         }
155 
156         if (returnCode != OH_NN_SUCCESS) {
157             LOGE("[SplitBuilder] Passed invalid param.");
158             return returnCode;
159         }
160     }
161 
162     m_isBuild = true;
163     m_name = OP_NAME;
164     return OH_NN_SUCCESS;
165 }
166 
GetPrimitive()167 LiteGraphTensorPtr SplitBuilder::GetPrimitive()
168 {
169     if (!m_isBuild) {
170         LOGE("[SplitBuilder] Cannot get primitive before call build.");
171         return { nullptr, DestroyLiteGraphPrimitive };
172     }
173 
174     auto primitive = mindspore::lite::MindIR_Split_CreatePrimitive(m_output_num, m_size_splits, m_axis);
175     if (primitive == nullptr) {
176         LOGE("[SplitBuilder] MindIR_Split_CreatePrimitive failed.");
177         return { nullptr, DestroyLiteGraphPrimitive };
178     }
179 
180     LiteGraphTensorPtr graphPrimitivePtr(primitive, DestroyLiteGraphPrimitive);
181     return graphPrimitivePtr;
182 }
183 
// NOTE(review): REGISTER_OPS is defined outside this file. Presumably it registers SplitBuilder
// as the builder for the OH_NN_OPS_SPLIT operation type; confirm against the macro definition.
REGISTER_OPS(SplitBuilder, OH_NN_OPS_SPLIT);
185 } // namespace Ops
186 } // namespace NeuralNetworkRuntime
187 } // namespace OHOS
188