1 /*
2  * Copyright (c) 2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "concat_builder.h"
17 
18 namespace OHOS {
19 namespace NeuralNetworkRuntime {
20 namespace Ops {
// Build() rejects fewer than this many input tensors (identifier spelling kept for existing uses).
static constexpr int MINIMUM_INTPUT {2};
// Build() requires exactly this many output tensors.
static constexpr int OUTPUT_NUM {1};
// Upper bound on parameter tensors, forwarded to CheckParamIndex() by Build().
static constexpr int PARAM_MAX_NUM {1};
// Required element count of the axis tensor: SetAxis() accepts only a scalar.
static constexpr int AXIS_LENGTH {1};
// Name stored in m_name once Build() succeeds.
static const std::string OP_NAME {"Concat"};
26 
ConcatBuilder()27 ConcatBuilder::ConcatBuilder() {}
28 
~ConcatBuilder()29 ConcatBuilder::~ConcatBuilder() {}
30 
SetAxis(const std::shared_ptr<NNTensor> & tensor)31 OH_NN_ReturnCode ConcatBuilder::SetAxis(const std::shared_ptr<NNTensor>& tensor)
32 {
33     tensor->IdentifyOpParameter();
34 
35     if (tensor->GetElementCount() != AXIS_LENGTH) {
36         LOGE("[Concat] SetAxis failed, the Activation shoule be a scalar");
37         return OH_NN_INVALID_PARAMETER;
38     }
39 
40     if (tensor->GetDataType() != OH_NN_INT64) {
41         LOGE("[Concat] SetAxis failed, the axis should be type OH_NN_INT64.");
42         return OH_NN_INVALID_PARAMETER;
43     }
44 
45     void* buffer = tensor->GetBuffer();
46     if (buffer == nullptr) {
47         LOGE("[Concat] SetAxis GetBuffer return nullptr.");
48         return OH_NN_INVALID_PARAMETER;
49     }
50     m_axis = *(static_cast<int64_t*>(buffer));
51 
52     return OH_NN_SUCCESS;
53 }
54 
Build(const std::vector<uint32_t> & paramsIndex,const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)55 OH_NN_ReturnCode ConcatBuilder::Build(const std::vector<uint32_t>& paramsIndex,
56     const std::vector<uint32_t>& inputsIndex, const std::vector<uint32_t>& outputsIndex,
57     const std::vector<std::shared_ptr<NNTensor>>& allTensors)
58 {
59     if (m_isBuild) {
60         LOGE("[Concat] Build failed, operation has been build, cannot build again.");
61         return OH_NN_OPERATION_FORBIDDEN;
62     }
63 
64     if (inputsIndex.size() < MINIMUM_INTPUT) {
65         LOGE("[Concat] Build failed, Concat need more than one inputs.");
66         return OH_NN_INVALID_PARAMETER;
67     }
68 
69     if (outputsIndex.size() != OUTPUT_NUM) {
70         LOGE("[Concat] Build failed, The number of index of outputs not equal to 1.");
71         return OH_NN_INVALID_PARAMETER;
72     }
73 
74     OH_NN_ReturnCode returnCode = SetInputsAndOutputs(inputsIndex, outputsIndex, allTensors);
75     if (returnCode != OH_NN_SUCCESS) {
76         LOGE("[Concat] Build failed, set inputs or outputs failed.");
77         return returnCode;
78     }
79 
80     returnCode = CheckParamIndex(paramsIndex, allTensors, PARAM_MAX_NUM);
81     if (returnCode != OH_NN_SUCCESS) {
82         LOGE("[Concat] Build failed, passed invalid param index.");
83         return returnCode;
84     }
85 
86     for (int i : paramsIndex) {
87         std::shared_ptr<NNTensor> tensor = allTensors[i];
88         if (m_paramMap.find(tensor->GetType()) != m_paramMap.end()) {
89             returnCode = (this->*(m_paramMap[tensor->GetType()]))(tensor);
90         } else {
91             LOGE("[Concat] Build failed, param invalid, type=%d", tensor->GetType());
92             return OH_NN_INVALID_PARAMETER;
93         }
94 
95         if (returnCode != OH_NN_SUCCESS) {
96             LOGE("[Concat] Build failed, passed invalid param.");
97             return returnCode;
98         }
99     }
100 
101     // The quantization type of the first output determinies that of the operator.
102     SetQuantType(outputsIndex, allTensors);
103 
104     m_isBuild = true;
105     m_name = OP_NAME;
106     return OH_NN_SUCCESS;
107 }
108 
SetInputsAndOutputs(const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)109 OH_NN_ReturnCode ConcatBuilder::SetInputsAndOutputs(const std::vector<uint32_t>& inputsIndex,
110                                                     const std::vector<uint32_t>& outputsIndex,
111                                                     const std::vector<std::shared_ptr<NNTensor>>& allTensors)
112 {
113     size_t allTensorsSize = allTensors.size();
114     bool isOverTensorSize = std::any_of(inputsIndex.begin(), inputsIndex.end(), [allTensorsSize](uint32_t index) {
115         return index >= allTensorsSize;
116     });
117     if (isOverTensorSize) {
118         LOGE("[Concat] Invalid input index, it is out of range %zu.", allTensorsSize);
119         return OH_NN_INVALID_PARAMETER;
120     }
121 
122     isOverTensorSize = std::any_of(outputsIndex.begin(), outputsIndex.end(), [allTensorsSize](uint32_t index) {
123         return index >= allTensorsSize;
124     });
125     if (isOverTensorSize) {
126         LOGE("[Concat] Invalid output index, it is out of range %zu.", allTensorsSize);
127         return OH_NN_INVALID_PARAMETER;
128     }
129 
130     m_inputsIndex.clear();
131     m_inputsIndex = inputsIndex;
132 
133     m_outputsIndex.clear();
134     m_outputsIndex = outputsIndex;
135 
136     return OH_NN_SUCCESS;
137 }
138 
GetPrimitive()139 LiteGraphPrimitvePtr ConcatBuilder::GetPrimitive()
140 {
141     if (!m_isBuild) {
142         LOGE("[Concat] GetPrimitive failed, cannot get primitive before call build.");
143         return {nullptr, DestroyLiteGraphPrimitive};
144     }
145 
146     void* primitive = mindspore::lite::MindIR_Concat_CreatePrimitive(m_axis);
147     LiteGraphPrimitvePtr graphPrimitivePtr(primitive, DestroyLiteGraphPrimitive);
148     return graphPrimitivePtr;
149 }
150 
151 REGISTER_OPS(ConcatBuilder, OH_NN_OPS_CONCAT);
152 } // namespace Ops
153 } // namespace NeuralNetworkRuntime
154 } // namespace OHOS