1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "broadcast_to_builder.h"
17 
18 namespace OHOS {
19 namespace NeuralNetworkRuntime {
20 namespace Ops {
// Expected input tensor count handed to CheckIOIndex() in Build().
static constexpr int INPUT_NUM{1};
// Expected output tensor count handed to CheckIOIndex() in Build().
static constexpr int OUTPUT_NUM{1};
// Parameter count limit handed to CheckParamIndex() in Build().
static constexpr int PARAM_MAX_NUM{1};
// Operator name stored into m_name once Build() succeeds.
static const std::string OP_NAME{"BroadcastTo"};
25 
BroadcastToBuilder()26 BroadcastToBuilder::BroadcastToBuilder() {}
27 
~BroadcastToBuilder()28 BroadcastToBuilder::~BroadcastToBuilder() {}
29 
SetShape(const std::shared_ptr<NNTensor> & tensor)30 OH_NN_ReturnCode BroadcastToBuilder::SetShape(const std::shared_ptr<NNTensor>& tensor)
31 {
32     if (tensor->GetDataType() != OH_NN_INT64) {
33         LOGE("[BroadcastTo] The shape should be type OH_NN_INT64.");
34         return OH_NN_INVALID_PARAMETER;
35     }
36 
37     m_shape.clear();
38 
39     void* buffer = tensor->GetBuffer();
40     if (buffer == nullptr) {
41         LOGE("[BroadcastTo] Tensor buffer is nullptr.");
42         return OH_NN_INVALID_PARAMETER;
43     }
44 
45     int64_t* pShape = static_cast<int64_t*>(buffer);
46 
47     uint32_t elementCount = tensor->GetElementCount();
48     for (uint32_t i = 0; i < elementCount; ++i) {
49         m_shape.emplace_back(*pShape);
50         ++pShape;
51     }
52     return OH_NN_SUCCESS;
53 }
54 
Build(const std::vector<uint32_t> & paramsIndex,const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)55 OH_NN_ReturnCode BroadcastToBuilder::Build(const std::vector<uint32_t>& paramsIndex,
56                                            const std::vector<uint32_t>& inputsIndex,
57                                            const std::vector<uint32_t>& outputsIndex,
58                                            const std::vector<std::shared_ptr<NNTensor>>& allTensors)
59 {
60     if (m_isBuild) {
61         LOGE("[BroadcastTo] Build failed, the broadcastTo operation has been build. cannot build again.");
62         return OH_NN_OPERATION_FORBIDDEN;
63     }
64 
65     auto ret = CheckIOIndex(inputsIndex, outputsIndex, allTensors, INPUT_NUM, OUTPUT_NUM);
66     if (ret != OH_NN_SUCCESS) {
67         LOGE("[BroadcastTo] Build failed, passed invalid input or output index.");
68         return ret;
69     }
70 
71     m_inputsIndex = inputsIndex;
72     m_outputsIndex = outputsIndex;
73 
74     ret = CheckParamIndex(paramsIndex, allTensors, PARAM_MAX_NUM);
75     if (ret != OH_NN_SUCCESS) {
76         LOGE("[BroadcastTo] Build failed, passed invalid param index.");
77         return ret;
78     }
79 
80     for (int i : paramsIndex) {
81         std::shared_ptr<NNTensor> tensor = allTensors[i];
82         tensor->IdentifyOpParameter();
83         if (m_paramMap.find(tensor->GetType()) != m_paramMap.end()) {
84             ret = (this->*(m_paramMap[tensor->GetType()]))(tensor);
85         } else {
86             LOGE("[BroadcastTo] Build failed, param invalid, type=%d", tensor->GetType());
87             return OH_NN_INVALID_PARAMETER;
88         }
89 
90         if (ret != OH_NN_SUCCESS) {
91             LOGE("[BroadcastTo] Build failed, passed invalid param.");
92             return ret;
93         }
94     }
95 
96     m_name = OP_NAME;
97     m_isBuild = true;
98     return OH_NN_SUCCESS;
99 }
100 
GetPrimitive()101 LiteGraphPrimitvePtr BroadcastToBuilder::GetPrimitive()
102 {
103     if (!m_isBuild) {
104         LOGE("[BroadcastTo] GetPrimitive failed, cannot get primitive before call build.");
105         return {nullptr, DestroyLiteGraphPrimitive};
106     }
107 
108     void* primitive = mindspore::lite::MindIR_BroadcastTo_CreatePrimitive(m_shape);
109     LiteGraphPrimitvePtr graphPrimitivePtr(primitive, DestroyLiteGraphPrimitive) ;
110     return graphPrimitivePtr;
111 }
112 
// Invokes the REGISTER_OPS macro (defined outside this file) with this builder
// class and the OH_NN_OPS_BROADCAST_TO operation type.
// NOTE(review): presumably this makes the builder discoverable for that
// operation type - confirm against the macro definition.
113 REGISTER_OPS(BroadcastToBuilder, OH_NN_OPS_BROADCAST_TO);
114 } // namespace Ops
115 } // namespace NeuralNetworkRuntime
116 } // namespace OHOS
117