1 /* 2 * Copyright (c) 2022 Huawei Device Co., Ltd. 3 * Licensed under the Apache License, Version 2.0 (the "License"); 4 * you may not use this file except in compliance with the License. 5 * You may obtain a copy of the License at 6 * 7 * http://www.apache.org/licenses/LICENSE-2.0 8 * 9 * Unless required by applicable law or agreed to in writing, software 10 * distributed under the License is distributed on an "AS IS" BASIS, 11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 * See the License for the specific language governing permissions and 13 * limitations under the License. 14 */ 15 16 #include "sqrt_builder.h" 17 18 #include "mindir.h" 19 20 namespace OHOS { 21 namespace NeuralNetworkRuntime { 22 namespace Ops { 23 static const int INPUT_NUM = 1; 24 static const int OUTPUT_NUM = 1; 25 static const int PARAM_NUM = 0; 26 static const std::string OP_NAME = "Sqrt"; 27 SqrtBuilder()28 SqrtBuilder::SqrtBuilder() {} 29 ~SqrtBuilder()30 SqrtBuilder::~SqrtBuilder() {} 31 32 /** 33 * Build method. 34 * 1.set attr of ops. 35 * 2.set inputIndex of ops. 36 * 3.set outputIndex of ops. 37 */ Build(const std::vector<uint32_t> & paramsIndex,const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)38 OH_NN_ReturnCode SqrtBuilder::Build(const std::vector<uint32_t>& paramsIndex, 39 const std::vector<uint32_t>& inputsIndex, 40 const std::vector<uint32_t>& outputsIndex, 41 const std::vector<std::shared_ptr<NNTensor>>& allTensors) 42 { 43 if (m_isBuild) { 44 LOGE("[SqrtBuilder] Sqrt operation has been build, cannot build again."); 45 return OH_NN_OPERATION_FORBIDDEN; 46 } 47 48 OH_NN_ReturnCode returnCode = CheckIOIndex(inputsIndex, outputsIndex, allTensors, INPUT_NUM, OUTPUT_NUM); 49 if (returnCode != OH_NN_SUCCESS) { 50 LOGE("[SqrtBuilder] Passed invalid input or output index."); 51 return returnCode; 52 } 53 54 m_inputsIndex = inputsIndex; 55 m_outputsIndex = outputsIndex; 56 57 returnCode = CheckParamIndex(paramsIndex, allTensors, PARAM_NUM); 58 if (returnCode != OH_NN_SUCCESS) { 59 LOGE("[SqrtBuilder] Passed invalid param index."); 60 return returnCode; 61 } 62 63 // The quantization type of the first output determinies that of the operator. 64 SetQuantType(outputsIndex, allTensors); 65 m_isBuild = true; 66 m_name = OP_NAME; 67 return OH_NN_SUCCESS; 68 } 69 GetPrimitive()70 LiteGraphTensorPtr SqrtBuilder::GetPrimitive() 71 { 72 if (!m_isBuild) { 73 LOGE("[SqrtBuilder] Cannot get primitive before call build."); 74 return {nullptr, DestroyLiteGraphPrimitive}; 75 } 76 77 auto primitive = mindspore::lite::MindIR_Sqrt_CreatePrimitive(); 78 if (primitive == nullptr) { 79 LOGE("[SqrtBuilder] Create primitive of Sqrt failed."); 80 return {nullptr, DestroyLiteGraphPrimitive}; 81 } 82 83 LiteGraphTensorPtr graphPrimitivePtr(primitive, DestroyLiteGraphPrimitive); 84 return graphPrimitivePtr; 85 } 86 87 REGISTER_OPS(SqrtBuilder, OH_NN_OPS_SQRT); 88 } // namespace Ops 89 } // namespace NeuralNetworkRuntime 90 } // namespace OHOS 91