1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "stack_builder.h"
17
18 #include "mindir.h"
19
20 namespace OHOS {
21 namespace NeuralNetworkRuntime {
22 namespace Ops {
// Build() rejects the operation when fewer than this many input tensor indices are given.
static constexpr int INPUT_MIN_NUM = 2;
// Build() requires exactly this many output tensor indices.
static constexpr int OUTPUT_NUM = 1;
// Maximum number of parameter tensors; forwarded to CheckParamIndex() in Build().
static constexpr int PARAM_MAX_NUM = 1;
// Operator name stored in m_name once Build() succeeds.
static const std::string OP_NAME{"Stack"};
27
StackBuilder()28 StackBuilder::StackBuilder() {}
29
~StackBuilder()30 StackBuilder::~StackBuilder() {}
31
SetAxis(const std::shared_ptr<NNTensor> & tensor)32 OH_NN_ReturnCode StackBuilder::SetAxis(const std::shared_ptr<NNTensor>& tensor)
33 {
34 if (tensor->GetDataType() != OH_NN_INT64) {
35 LOGE("[StackBuilder] The last input axis should be type OH_NN_INT64.");
36 return OH_NN_INVALID_PARAMETER;
37 }
38
39 if (tensor->GetElementCount() != 1) {
40 LOGE("[StackBuilder] The last input axis should be scaler.");
41 return OH_NN_INVALID_PARAMETER;
42 }
43
44 void* buffer = tensor->GetBuffer();
45 if (buffer == nullptr) {
46 LOGE("[StackBuilder] Tensor buffer is nullptr.");
47 return OH_NN_INVALID_PARAMETER;
48 }
49 m_axis = *(static_cast<int64_t*>(buffer));
50
51 return OH_NN_SUCCESS;
52 }
53
54 /**
55 * Build method.
56 * 1.set attr of ops.
57 * 2.set inputIndex of ops.
58 * 3.set outputIndex of ops.
59 */
Build(const std::vector<uint32_t> & paramsIndex,const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)60 OH_NN_ReturnCode StackBuilder::Build(const std::vector<uint32_t>& paramsIndex,
61 const std::vector<uint32_t>& inputsIndex,
62 const std::vector<uint32_t>& outputsIndex,
63 const std::vector<std::shared_ptr<NNTensor>>& allTensors)
64 {
65 if (m_isBuild) {
66 LOGE("[StackBuilder] Stack operation has been build, cannot build again.");
67 return OH_NN_OPERATION_FORBIDDEN;
68 }
69
70 if (inputsIndex.size() < INPUT_MIN_NUM) {
71 LOGE("[StackBuilder] The number of index of inputs don't larger than %d.", INPUT_MIN_NUM);
72 return OH_NN_INVALID_PARAMETER;
73 }
74 if (outputsIndex.size() != OUTPUT_NUM) {
75 LOGE("[StackBuilder] The number of index of outputs don't equal to %d.", OUTPUT_NUM);
76 return OH_NN_INVALID_PARAMETER;
77 }
78
79 m_inputsIndex = inputsIndex;
80 m_outputsIndex = outputsIndex;
81
82 auto returnCode = CheckParamIndex(paramsIndex, allTensors, PARAM_MAX_NUM);
83 if (returnCode != OH_NN_SUCCESS) {
84 LOGE("[StackBuilder] Passed invalid param index.");
85 return returnCode;
86 }
87
88 for (int i : paramsIndex) {
89 std::shared_ptr<NNTensor> tensor = allTensors[i];
90 tensor->IdentifyOpParameter();
91 if (m_paramMap.find(tensor->GetType()) != m_paramMap.end()) {
92 returnCode = (this->*(m_paramMap[tensor->GetType()]))(tensor);
93 } else {
94 LOGE("[StackBuilder] Build failed, param invalid, type=%d", tensor->GetType());
95 return OH_NN_INVALID_PARAMETER;
96 }
97
98 if (returnCode != OH_NN_SUCCESS) {
99 LOGE("[StackBuilder] Passed invalid param.");
100 return returnCode;
101 }
102 }
103
104 m_isBuild = true;
105 m_name = OP_NAME;
106 return OH_NN_SUCCESS;
107 }
108
GetPrimitive()109 LiteGraphTensorPtr StackBuilder::GetPrimitive()
110 {
111 if (!m_isBuild) {
112 LOGE("[StackBuilder] Cannot get primitive before call build.");
113 return {nullptr, DestroyLiteGraphPrimitive};
114 }
115
116 auto primitive = mindspore::lite::MindIR_Stack_CreatePrimitive(m_axis);
117 if (primitive == nullptr) {
118 LOGE("[StackBuilder] MindIR_Stack_CreatePrimitive failed.");
119 return {nullptr, DestroyLiteGraphPrimitive};
120 }
121
122 LiteGraphTensorPtr graphPrimitivePtr(primitive, DestroyLiteGraphPrimitive);
123 return graphPrimitivePtr;
124 }
125
// Associates StackBuilder with the OH_NN_OPS_STACK operation type.
// NOTE(review): REGISTER_OPS is defined outside this file; presumably it adds the builder
// to a global operation registry at static-initialization time — confirm against its definition.
REGISTER_OPS(StackBuilder, OH_NN_OPS_STACK);
127 } // namespace Ops
128 } // namespace NeuralNetworkRuntime
129 } // namespace OHOS
130