1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "strided_slice_builder.h"
17
18 #include "mindir.h"
19
20 #include "interfaces/kits/c/neural_network_runtime/neural_network_runtime_type.h"
21
22 namespace OHOS {
23 namespace NeuralNetworkRuntime {
24 namespace Ops {
// Number of input tensors StridedSlice expects; passed to CheckIOIndex in SetInputOutput.
static constexpr int INPUT_NUM = 4;
// Number of output tensors StridedSlice produces; passed to CheckIOIndex in SetInputOutput.
static constexpr int OUTPUT_NUM = 1;
// Upper bound on parameter tensors accepted by Build; passed to CheckParamIndex.
static constexpr int PARAM_MAX_NUM = 5;
// Operation name stored into m_name once Build succeeds.
static const std::string OP_NAME{"StridedSlice"};
29
StridedSliceBuilder()30 StridedSliceBuilder::StridedSliceBuilder() {}
31
~StridedSliceBuilder()32 StridedSliceBuilder::~StridedSliceBuilder() {}
33
SetInputOutput(const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)34 OH_NN_ReturnCode StridedSliceBuilder::SetInputOutput(const std::vector<uint32_t>& inputsIndex,
35 const std::vector<uint32_t>& outputsIndex,
36 const std::vector<std::shared_ptr<NNTensor>>& allTensors)
37 {
38 OH_NN_ReturnCode returnCode = CheckIOIndex(inputsIndex, outputsIndex, allTensors, INPUT_NUM, OUTPUT_NUM);
39 if (returnCode != OH_NN_SUCCESS) {
40 LOGE("[StridedSliceBuilder] Passed invalid input or output index.");
41 return returnCode;
42 }
43
44 m_inputsIndex = inputsIndex;
45 m_outputsIndex = outputsIndex;
46
47 return OH_NN_SUCCESS;
48 }
49
SetBeginMask(const std::shared_ptr<NNTensor> & tensor)50 OH_NN_ReturnCode StridedSliceBuilder::SetBeginMask(const std::shared_ptr<NNTensor>& tensor)
51 {
52 if (tensor->GetDataType() != OH_NN_INT64) {
53 LOGE("[StridedSliceBuilder] The 5th input beginMask should be type HNN_INT64.");
54 return OH_NN_INVALID_PARAMETER;
55 }
56
57 void* buffer = tensor->GetBuffer();
58 if (buffer == nullptr) {
59 LOGE("[StridedSliceBuilder] Tensor buffer is nullptr.");
60 return OH_NN_INVALID_PARAMETER;
61 }
62 m_begin_mask = *(static_cast<int64_t*>(buffer));
63
64 return OH_NN_SUCCESS;
65 }
66
SetEndMask(const std::shared_ptr<NNTensor> & tensor)67 OH_NN_ReturnCode StridedSliceBuilder::SetEndMask(const std::shared_ptr<NNTensor>& tensor)
68 {
69 if (tensor->GetDataType() != OH_NN_INT64) {
70 LOGE("[StridedSliceBuilder] The 6th input endMask should be type HNN_INT64.");
71 return OH_NN_INVALID_PARAMETER;
72 }
73
74 void* buffer = tensor->GetBuffer();
75 if (buffer == nullptr) {
76 LOGE("[StridedSliceBuilder] Tensor buffer is nullptr.");
77 return OH_NN_INVALID_PARAMETER;
78 }
79 m_end_mask = *(static_cast<int64_t*>(buffer));
80
81 return OH_NN_SUCCESS;
82 }
83
SetEllipsisMask(const std::shared_ptr<NNTensor> & tensor)84 OH_NN_ReturnCode StridedSliceBuilder::SetEllipsisMask(const std::shared_ptr<NNTensor>& tensor)
85 {
86 if (tensor->GetDataType() != OH_NN_INT64) {
87 LOGE("[StridedSliceBuilder] The 7th input ellipsisMask should be type HNN_INT64.");
88 return OH_NN_INVALID_PARAMETER;
89 }
90
91 void* buffer = tensor->GetBuffer();
92 if (buffer == nullptr) {
93 LOGE("[StridedSliceBuilder] Tensor buffer is nullptr.");
94 return OH_NN_INVALID_PARAMETER;
95 }
96 m_ellipsis_mask = *(static_cast<int64_t*>(buffer));
97
98 return OH_NN_SUCCESS;
99 }
100
SetNewAxisMask(const std::shared_ptr<NNTensor> & tensor)101 OH_NN_ReturnCode StridedSliceBuilder::SetNewAxisMask(const std::shared_ptr<NNTensor>& tensor)
102 {
103 if (tensor->GetDataType() != OH_NN_INT64) {
104 LOGE("[StridedSliceBuilder] The 8th input newAxisMask should be type HNN_INT64.");
105 return OH_NN_INVALID_PARAMETER;
106 }
107
108 void* buffer = tensor->GetBuffer();
109 if (buffer == nullptr) {
110 LOGE("[StridedSliceBuilder] Tensor buffer is nullptr.");
111 return OH_NN_INVALID_PARAMETER;
112 }
113 m_new_axis_mask = *(static_cast<int64_t*>(buffer));
114
115 return OH_NN_SUCCESS;
116 }
117
SetShrinkAxisMask(const std::shared_ptr<NNTensor> & tensor)118 OH_NN_ReturnCode StridedSliceBuilder::SetShrinkAxisMask(const std::shared_ptr<NNTensor>& tensor)
119 {
120 if (tensor->GetDataType() != OH_NN_INT64) {
121 LOGE("[StridedSliceBuilder] The 9th input shrinkAxisMAsk should be type HNN_INT64.");
122 return OH_NN_INVALID_PARAMETER;
123 }
124
125 void* buffer = tensor->GetBuffer();
126 if (buffer == nullptr) {
127 LOGE("[StridedSliceBuilder] Tensor buffer is nullptr.");
128 return OH_NN_INVALID_PARAMETER;
129 }
130 m_shrink_axis_mask = *(static_cast<int64_t*>(buffer));
131
132 return OH_NN_SUCCESS;
133 }
134
135 /**
136 * Build method.
137 * 1.set attr of ops.
138 * 2.set inputIndex of ops.
139 * 3.set outputIndex of ops.
140 */
Build(const std::vector<uint32_t> & paramsIndex,const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)141 OH_NN_ReturnCode StridedSliceBuilder::Build(const std::vector<uint32_t>& paramsIndex,
142 const std::vector<uint32_t>& inputsIndex,
143 const std::vector<uint32_t>& outputsIndex,
144 const std::vector<std::shared_ptr<NNTensor>>& allTensors)
145 {
146 if (m_isBuild) {
147 LOGE("[StridedSliceBuilder] StridedSlice operation has been build, cannot build again.");
148 return OH_NN_OPERATION_FORBIDDEN;
149 }
150
151 OH_NN_ReturnCode returnCode = SetInputOutput(inputsIndex, outputsIndex, allTensors);
152 if (returnCode != OH_NN_SUCCESS) {
153 LOGE("[StridedSliceBuilder] Set index of inputs or outputs failed.");
154 return returnCode;
155 }
156
157 returnCode = CheckParamIndex(paramsIndex, allTensors, PARAM_MAX_NUM);
158 if (returnCode != OH_NN_SUCCESS) {
159 LOGE("[StridedSliceBuilder] Passed invalid param index.");
160 return returnCode;
161 }
162
163 for (int i : paramsIndex) {
164 std::shared_ptr<NNTensor> tensor = allTensors[i];
165 tensor->IdentifyOpParameter();
166 if (m_paramMap.find(tensor->GetType()) != m_paramMap.end()) {
167 returnCode = (this->*(m_paramMap[tensor->GetType()]))(tensor);
168 } else {
169 LOGE("[StridedSliceBuilder] Build failed, param invalid, type=%d", tensor->GetType());
170 return OH_NN_INVALID_PARAMETER;
171 }
172
173 if (returnCode != OH_NN_SUCCESS) {
174 LOGE("[StridedSliceBuilder] Passed invalid param.");
175 return returnCode;
176 }
177 }
178
179 m_isBuild = true;
180 m_name = OP_NAME;
181 return OH_NN_SUCCESS;
182 }
183
GetPrimitive()184 LiteGraphPrimitvePtr StridedSliceBuilder::GetPrimitive()
185 {
186 if (!m_isBuild) {
187 LOGE("[StridedSliceBuilder] Cannot get primitive before call build.");
188 return {nullptr, DestroyLiteGraphPrimitive};
189 }
190
191 auto primitive = mindspore::lite::MindIR_StridedSlice_CreatePrimitive(m_begin_mask, m_end_mask, m_ellipsis_mask,
192 m_new_axis_mask, m_shrink_axis_mask);
193 if (primitive == nullptr) {
194 LOGE("[StridedSliceBuilder] MindIR_StridedSlice_CreatePrimitive failed.");
195 return {nullptr, DestroyLiteGraphPrimitive};
196 }
197
198 LiteGraphPrimitvePtr graphPrimitivePtr(primitive, DestroyLiteGraphPrimitive);
199 return graphPrimitivePtr;
200 }
201
// Associates StridedSliceBuilder with the OH_NN_OPS_STRIDED_SLICE operation type.
// NOTE(review): REGISTER_OPS is defined elsewhere; presumably it registers this
// builder with the ops registry at static-initialization time — confirm.
REGISTER_OPS(StridedSliceBuilder, OH_NN_OPS_STRIDED_SLICE);
203 } // namespace Ops
204 } // namespace NeuralNetworkRuntime
205 } // namespace OHOS
206