1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "concat_builder.h"
17
18 namespace OHOS {
19 namespace NeuralNetworkRuntime {
20 namespace Ops {
// Lower bound on inputsIndex.size() accepted by ConcatBuilder::Build().
static constexpr int MINIMUM_INTPUT{2};
// Exact outputsIndex.size() required by ConcatBuilder::Build().
static constexpr int OUTPUT_NUM{1};
// Limit handed to CheckParamIndex() when Build() validates paramsIndex.
static constexpr int PARAM_MAX_NUM{1};
// Element count the axis parameter tensor must have in ConcatBuilder::SetAxis().
static constexpr int AXIS_LENGTH{1};
// Value stored in m_name once Build() succeeds.
static const std::string OP_NAME{"Concat"};
26
ConcatBuilder()27 ConcatBuilder::ConcatBuilder() {}
28
~ConcatBuilder()29 ConcatBuilder::~ConcatBuilder() {}
30
SetAxis(const std::shared_ptr<NNTensor> & tensor)31 OH_NN_ReturnCode ConcatBuilder::SetAxis(const std::shared_ptr<NNTensor>& tensor)
32 {
33 tensor->IdentifyOpParameter();
34
35 if (tensor->GetElementCount() != AXIS_LENGTH) {
36 LOGE("[Concat] SetAxis failed, the Activation shoule be a scalar");
37 return OH_NN_INVALID_PARAMETER;
38 }
39
40 if (tensor->GetDataType() != OH_NN_INT64) {
41 LOGE("[Concat] SetAxis failed, the axis should be type OH_NN_INT64.");
42 return OH_NN_INVALID_PARAMETER;
43 }
44
45 void* buffer = tensor->GetBuffer();
46 if (buffer == nullptr) {
47 LOGE("[Concat] SetAxis GetBuffer return nullptr.");
48 return OH_NN_INVALID_PARAMETER;
49 }
50 m_axis = *(static_cast<int64_t*>(buffer));
51
52 return OH_NN_SUCCESS;
53 }
54
Build(const std::vector<uint32_t> & paramsIndex,const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)55 OH_NN_ReturnCode ConcatBuilder::Build(const std::vector<uint32_t>& paramsIndex,
56 const std::vector<uint32_t>& inputsIndex, const std::vector<uint32_t>& outputsIndex,
57 const std::vector<std::shared_ptr<NNTensor>>& allTensors)
58 {
59 if (m_isBuild) {
60 LOGE("[Concat] Build failed, operation has been build, cannot build again.");
61 return OH_NN_OPERATION_FORBIDDEN;
62 }
63
64 if (inputsIndex.size() < MINIMUM_INTPUT) {
65 LOGE("[Concat] Build failed, Concat need more than one inputs.");
66 return OH_NN_INVALID_PARAMETER;
67 }
68
69 if (outputsIndex.size() != OUTPUT_NUM) {
70 LOGE("[Concat] Build failed, The number of index of outputs not equal to 1.");
71 return OH_NN_INVALID_PARAMETER;
72 }
73
74 OH_NN_ReturnCode returnCode = SetInputsAndOutputs(inputsIndex, outputsIndex, allTensors);
75 if (returnCode != OH_NN_SUCCESS) {
76 LOGE("[Concat] Build failed, set inputs or outputs failed.");
77 return returnCode;
78 }
79
80 returnCode = CheckParamIndex(paramsIndex, allTensors, PARAM_MAX_NUM);
81 if (returnCode != OH_NN_SUCCESS) {
82 LOGE("[Concat] Build failed, passed invalid param index.");
83 return returnCode;
84 }
85
86 for (int i : paramsIndex) {
87 std::shared_ptr<NNTensor> tensor = allTensors[i];
88 if (m_paramMap.find(tensor->GetType()) != m_paramMap.end()) {
89 returnCode = (this->*(m_paramMap[tensor->GetType()]))(tensor);
90 } else {
91 LOGE("[Concat] Build failed, param invalid, type=%d", tensor->GetType());
92 return OH_NN_INVALID_PARAMETER;
93 }
94
95 if (returnCode != OH_NN_SUCCESS) {
96 LOGE("[Concat] Build failed, passed invalid param.");
97 return returnCode;
98 }
99 }
100
101 // The quantization type of the first output determinies that of the operator.
102 SetQuantType(outputsIndex, allTensors);
103
104 m_isBuild = true;
105 m_name = OP_NAME;
106 return OH_NN_SUCCESS;
107 }
108
SetInputsAndOutputs(const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)109 OH_NN_ReturnCode ConcatBuilder::SetInputsAndOutputs(const std::vector<uint32_t>& inputsIndex,
110 const std::vector<uint32_t>& outputsIndex,
111 const std::vector<std::shared_ptr<NNTensor>>& allTensors)
112 {
113 size_t allTensorsSize = allTensors.size();
114 bool isOverTensorSize = std::any_of(inputsIndex.begin(), inputsIndex.end(), [allTensorsSize](uint32_t index) {
115 return index >= allTensorsSize;
116 });
117 if (isOverTensorSize) {
118 LOGE("[Concat] Invalid input index, it is out of range %zu.", allTensorsSize);
119 return OH_NN_INVALID_PARAMETER;
120 }
121
122 isOverTensorSize = std::any_of(outputsIndex.begin(), outputsIndex.end(), [allTensorsSize](uint32_t index) {
123 return index >= allTensorsSize;
124 });
125 if (isOverTensorSize) {
126 LOGE("[Concat] Invalid output index, it is out of range %zu.", allTensorsSize);
127 return OH_NN_INVALID_PARAMETER;
128 }
129
130 m_inputsIndex.clear();
131 m_inputsIndex = inputsIndex;
132
133 m_outputsIndex.clear();
134 m_outputsIndex = outputsIndex;
135
136 return OH_NN_SUCCESS;
137 }
138
GetPrimitive()139 LiteGraphPrimitvePtr ConcatBuilder::GetPrimitive()
140 {
141 if (!m_isBuild) {
142 LOGE("[Concat] GetPrimitive failed, cannot get primitive before call build.");
143 return {nullptr, DestroyLiteGraphPrimitive};
144 }
145
146 void* primitive = mindspore::lite::MindIR_Concat_CreatePrimitive(m_axis);
147 LiteGraphPrimitvePtr graphPrimitivePtr(primitive, DestroyLiteGraphPrimitive);
148 return graphPrimitivePtr;
149 }
150
// NOTE(review): REGISTER_OPS is defined outside this file; presumably it associates
// ConcatBuilder with the OH_NN_OPS_CONCAT operation type — confirm against the macro definition.
REGISTER_OPS(ConcatBuilder, OH_NN_OPS_CONCAT);
152 } // namespace Ops
153 } // namespace NeuralNetworkRuntime
154 } // namespace OHOS