1 /*
2  * Copyright (c) 2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "stack_builder.h"
17 
18 #include "mindir.h"
19 
20 namespace OHOS {
21 namespace NeuralNetworkRuntime {
22 namespace Ops {
// Smallest number of input tensor indices Build() accepts.
static constexpr int INPUT_MIN_NUM{2};
// Exact number of output tensor indices Build() accepts.
static constexpr int OUTPUT_NUM{1};
// Upper bound on the parameter count handed to CheckParamIndex() by Build().
static constexpr int PARAM_MAX_NUM{1};
// Name assigned to m_name once the operation has been built.
static const std::string OP_NAME{"Stack"};
27 
StackBuilder()28 StackBuilder::StackBuilder() {}
29 
~StackBuilder()30 StackBuilder::~StackBuilder() {}
31 
SetAxis(const std::shared_ptr<NNTensor> & tensor)32 OH_NN_ReturnCode StackBuilder::SetAxis(const std::shared_ptr<NNTensor>& tensor)
33 {
34     if (tensor->GetDataType() != OH_NN_INT64) {
35         LOGE("[StackBuilder] The last input axis should be type OH_NN_INT64.");
36         return OH_NN_INVALID_PARAMETER;
37     }
38 
39     if (tensor->GetElementCount() != 1) {
40         LOGE("[StackBuilder] The last input axis should be scaler.");
41         return OH_NN_INVALID_PARAMETER;
42     }
43 
44     void* buffer = tensor->GetBuffer();
45     if (buffer == nullptr) {
46         LOGE("[StackBuilder] Tensor buffer is nullptr.");
47         return OH_NN_INVALID_PARAMETER;
48     }
49     m_axis = *(static_cast<int64_t*>(buffer));
50 
51     return OH_NN_SUCCESS;
52 }
53 
54 /**
55  * Build method.
56  * 1.set attr of ops.
57  * 2.set inputIndex of ops.
58  * 3.set outputIndex of ops.
59  */
Build(const std::vector<uint32_t> & paramsIndex,const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)60 OH_NN_ReturnCode StackBuilder::Build(const std::vector<uint32_t>& paramsIndex,
61                                      const std::vector<uint32_t>& inputsIndex,
62                                      const std::vector<uint32_t>& outputsIndex,
63                                      const std::vector<std::shared_ptr<NNTensor>>& allTensors)
64 {
65     if (m_isBuild) {
66         LOGE("[StackBuilder] Stack operation has been build, cannot build again.");
67         return OH_NN_OPERATION_FORBIDDEN;
68     }
69 
70     if (inputsIndex.size() < INPUT_MIN_NUM) {
71         LOGE("[StackBuilder] The number of index of inputs don't larger than %d.", INPUT_MIN_NUM);
72         return OH_NN_INVALID_PARAMETER;
73     }
74     if (outputsIndex.size() != OUTPUT_NUM) {
75         LOGE("[StackBuilder] The number of index of outputs don't equal to %d.", OUTPUT_NUM);
76         return OH_NN_INVALID_PARAMETER;
77     }
78 
79     m_inputsIndex = inputsIndex;
80     m_outputsIndex = outputsIndex;
81 
82     auto returnCode = CheckParamIndex(paramsIndex, allTensors, PARAM_MAX_NUM);
83     if (returnCode != OH_NN_SUCCESS) {
84         LOGE("[StackBuilder] Passed invalid param index.");
85         return returnCode;
86     }
87 
88     for (int i : paramsIndex) {
89         std::shared_ptr<NNTensor> tensor = allTensors[i];
90         tensor->IdentifyOpParameter();
91         if (m_paramMap.find(tensor->GetType()) != m_paramMap.end()) {
92             returnCode = (this->*(m_paramMap[tensor->GetType()]))(tensor);
93         } else {
94             LOGE("[StackBuilder] Build failed, param invalid, type=%d", tensor->GetType());
95             return OH_NN_INVALID_PARAMETER;
96         }
97 
98         if (returnCode != OH_NN_SUCCESS) {
99             LOGE("[StackBuilder] Passed invalid param.");
100             return returnCode;
101         }
102     }
103 
104     m_isBuild = true;
105     m_name = OP_NAME;
106     return OH_NN_SUCCESS;
107 }
108 
GetPrimitive()109 LiteGraphTensorPtr StackBuilder::GetPrimitive()
110 {
111     if (!m_isBuild) {
112         LOGE("[StackBuilder] Cannot get primitive before call build.");
113         return {nullptr, DestroyLiteGraphPrimitive};
114     }
115 
116     auto primitive = mindspore::lite::MindIR_Stack_CreatePrimitive(m_axis);
117     if (primitive == nullptr) {
118         LOGE("[StackBuilder] MindIR_Stack_CreatePrimitive failed.");
119         return {nullptr, DestroyLiteGraphPrimitive};
120     }
121 
122     LiteGraphTensorPtr graphPrimitivePtr(primitive, DestroyLiteGraphPrimitive);
123     return graphPrimitivePtr;
124 }
125 
// Associates StackBuilder with the OH_NN_OPS_STACK operation type.
// NOTE(review): REGISTER_OPS is defined outside this file; presumably it adds a
// creator for StackBuilder to the global ops registry at static-init time — confirm.
REGISTER_OPS(StackBuilder, OH_NN_OPS_STACK);
127 } // namespace Ops
128 } // namespace NeuralNetworkRuntime
129 } // namespace OHOS
130