1 /*
2  * Copyright (c) 2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "strided_slice_builder.h"
17 
18 #include "mindir.h"
19 
20 #include "interfaces/kits/c/neural_network_runtime/neural_network_runtime_type.h"
21 
22 namespace OHOS {
23 namespace NeuralNetworkRuntime {
24 namespace Ops {
// Exact number of input / output tensors a StridedSlice node must be wired with
// (passed to CheckIOIndex).
static constexpr int INPUT_NUM = 4;
static constexpr int OUTPUT_NUM = 1;
// Maximum parameter-tensor count handed to CheckParamIndex.
static constexpr int PARAM_MAX_NUM = 5;
// Name given to the operation once Build() succeeds.
static const std::string OP_NAME = "StridedSlice";
29 
StridedSliceBuilder()30 StridedSliceBuilder::StridedSliceBuilder() {}
31 
~StridedSliceBuilder()32 StridedSliceBuilder::~StridedSliceBuilder() {}
33 
SetInputOutput(const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)34 OH_NN_ReturnCode StridedSliceBuilder::SetInputOutput(const std::vector<uint32_t>& inputsIndex,
35                                                      const std::vector<uint32_t>& outputsIndex,
36                                                      const std::vector<std::shared_ptr<NNTensor>>& allTensors)
37 {
38     OH_NN_ReturnCode returnCode = CheckIOIndex(inputsIndex, outputsIndex, allTensors, INPUT_NUM, OUTPUT_NUM);
39     if (returnCode != OH_NN_SUCCESS) {
40         LOGE("[StridedSliceBuilder] Passed invalid input or output index.");
41         return returnCode;
42     }
43 
44     m_inputsIndex = inputsIndex;
45     m_outputsIndex = outputsIndex;
46 
47     return OH_NN_SUCCESS;
48 }
49 
SetBeginMask(const std::shared_ptr<NNTensor> & tensor)50 OH_NN_ReturnCode StridedSliceBuilder::SetBeginMask(const std::shared_ptr<NNTensor>& tensor)
51 {
52     if (tensor->GetDataType() != OH_NN_INT64) {
53         LOGE("[StridedSliceBuilder] The 5th input beginMask should be type HNN_INT64.");
54         return OH_NN_INVALID_PARAMETER;
55     }
56 
57     void* buffer = tensor->GetBuffer();
58     if (buffer == nullptr) {
59         LOGE("[StridedSliceBuilder] Tensor buffer is nullptr.");
60         return OH_NN_INVALID_PARAMETER;
61     }
62     m_begin_mask = *(static_cast<int64_t*>(buffer));
63 
64     return OH_NN_SUCCESS;
65 }
66 
SetEndMask(const std::shared_ptr<NNTensor> & tensor)67 OH_NN_ReturnCode StridedSliceBuilder::SetEndMask(const std::shared_ptr<NNTensor>& tensor)
68 {
69     if (tensor->GetDataType() != OH_NN_INT64) {
70         LOGE("[StridedSliceBuilder] The 6th input endMask should be type HNN_INT64.");
71         return OH_NN_INVALID_PARAMETER;
72     }
73 
74     void* buffer = tensor->GetBuffer();
75     if (buffer == nullptr) {
76         LOGE("[StridedSliceBuilder] Tensor buffer is nullptr.");
77         return OH_NN_INVALID_PARAMETER;
78     }
79     m_end_mask = *(static_cast<int64_t*>(buffer));
80 
81     return OH_NN_SUCCESS;
82 }
83 
SetEllipsisMask(const std::shared_ptr<NNTensor> & tensor)84 OH_NN_ReturnCode StridedSliceBuilder::SetEllipsisMask(const std::shared_ptr<NNTensor>& tensor)
85 {
86     if (tensor->GetDataType() != OH_NN_INT64) {
87         LOGE("[StridedSliceBuilder] The 7th input ellipsisMask should be type HNN_INT64.");
88         return OH_NN_INVALID_PARAMETER;
89     }
90 
91     void* buffer = tensor->GetBuffer();
92     if (buffer == nullptr) {
93         LOGE("[StridedSliceBuilder] Tensor buffer is nullptr.");
94         return OH_NN_INVALID_PARAMETER;
95     }
96     m_ellipsis_mask = *(static_cast<int64_t*>(buffer));
97 
98     return OH_NN_SUCCESS;
99 }
100 
SetNewAxisMask(const std::shared_ptr<NNTensor> & tensor)101 OH_NN_ReturnCode StridedSliceBuilder::SetNewAxisMask(const std::shared_ptr<NNTensor>& tensor)
102 {
103     if (tensor->GetDataType() != OH_NN_INT64) {
104         LOGE("[StridedSliceBuilder] The 8th input newAxisMask should be type HNN_INT64.");
105         return OH_NN_INVALID_PARAMETER;
106     }
107 
108     void* buffer = tensor->GetBuffer();
109     if (buffer == nullptr) {
110         LOGE("[StridedSliceBuilder] Tensor buffer is nullptr.");
111         return OH_NN_INVALID_PARAMETER;
112     }
113     m_new_axis_mask = *(static_cast<int64_t*>(buffer));
114 
115     return OH_NN_SUCCESS;
116 }
117 
SetShrinkAxisMask(const std::shared_ptr<NNTensor> & tensor)118 OH_NN_ReturnCode StridedSliceBuilder::SetShrinkAxisMask(const std::shared_ptr<NNTensor>& tensor)
119 {
120     if (tensor->GetDataType() != OH_NN_INT64) {
121         LOGE("[StridedSliceBuilder] The 9th input shrinkAxisMAsk should be type HNN_INT64.");
122         return OH_NN_INVALID_PARAMETER;
123     }
124 
125     void* buffer = tensor->GetBuffer();
126     if (buffer == nullptr) {
127         LOGE("[StridedSliceBuilder] Tensor buffer is nullptr.");
128         return OH_NN_INVALID_PARAMETER;
129     }
130     m_shrink_axis_mask = *(static_cast<int64_t*>(buffer));
131 
132     return OH_NN_SUCCESS;
133 }
134 
135 /**
136  * Build method.
137  * 1.set attr of ops.
138  * 2.set inputIndex of ops.
139  * 3.set outputIndex of ops.
140  */
Build(const std::vector<uint32_t> & paramsIndex,const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)141 OH_NN_ReturnCode StridedSliceBuilder::Build(const std::vector<uint32_t>& paramsIndex,
142                                             const std::vector<uint32_t>& inputsIndex,
143                                             const std::vector<uint32_t>& outputsIndex,
144                                             const std::vector<std::shared_ptr<NNTensor>>& allTensors)
145 {
146     if (m_isBuild) {
147         LOGE("[StridedSliceBuilder] StridedSlice operation has been build, cannot build again.");
148         return OH_NN_OPERATION_FORBIDDEN;
149     }
150 
151     OH_NN_ReturnCode returnCode = SetInputOutput(inputsIndex, outputsIndex, allTensors);
152     if (returnCode != OH_NN_SUCCESS) {
153         LOGE("[StridedSliceBuilder] Set index of inputs or outputs failed.");
154         return returnCode;
155     }
156 
157     returnCode = CheckParamIndex(paramsIndex, allTensors, PARAM_MAX_NUM);
158     if (returnCode != OH_NN_SUCCESS) {
159         LOGE("[StridedSliceBuilder] Passed invalid param index.");
160         return returnCode;
161     }
162 
163     for (int i : paramsIndex) {
164         std::shared_ptr<NNTensor> tensor = allTensors[i];
165         tensor->IdentifyOpParameter();
166         if (m_paramMap.find(tensor->GetType()) != m_paramMap.end()) {
167             returnCode = (this->*(m_paramMap[tensor->GetType()]))(tensor);
168         } else {
169             LOGE("[StridedSliceBuilder] Build failed, param invalid, type=%d", tensor->GetType());
170             return OH_NN_INVALID_PARAMETER;
171         }
172 
173         if (returnCode != OH_NN_SUCCESS) {
174             LOGE("[StridedSliceBuilder] Passed invalid param.");
175             return returnCode;
176         }
177     }
178 
179     m_isBuild = true;
180     m_name = OP_NAME;
181     return OH_NN_SUCCESS;
182 }
183 
GetPrimitive()184 LiteGraphPrimitvePtr StridedSliceBuilder::GetPrimitive()
185 {
186     if (!m_isBuild) {
187         LOGE("[StridedSliceBuilder] Cannot get primitive before call build.");
188         return {nullptr, DestroyLiteGraphPrimitive};
189     }
190 
191     auto primitive = mindspore::lite::MindIR_StridedSlice_CreatePrimitive(m_begin_mask, m_end_mask, m_ellipsis_mask,
192         m_new_axis_mask, m_shrink_axis_mask);
193     if (primitive == nullptr) {
194         LOGE("[StridedSliceBuilder] MindIR_StridedSlice_CreatePrimitive failed.");
195         return {nullptr, DestroyLiteGraphPrimitive};
196     }
197 
198     LiteGraphPrimitvePtr graphPrimitivePtr(primitive, DestroyLiteGraphPrimitive);
199     return graphPrimitivePtr;
200 }
201 
// Associates StridedSliceBuilder with the OH_NN_OPS_STRIDED_SLICE operation type
// via the project's REGISTER_OPS macro (presumably an op-builder registry; see the
// macro definition for the exact mechanism).
REGISTER_OPS(StridedSliceBuilder, OH_NN_OPS_STRIDED_SLICE);
203 } // namespace Ops
204 } // namespace NeuralNetworkRuntime
205 } // namespace OHOS
206