/*
 * Copyright (c) 2023 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef NEURAL_NETWORK_RUNTIME_NNCOMPILER_H
#define NEURAL_NETWORK_RUNTIME_NNCOMPILER_H

#include "compiler.h"

#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "mindir.h"
#include "device.h"
#include "inner_model.h"
#include "prepared_model.h"
#include "nnexecutor.h"

namespace OHOS {
namespace NeuralNetworkRuntime {

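/*
 * NNCompiler is the Compiler implementation of the Neural Network Runtime.
 * It compiles a model (an InnerModel, or an offline model buffer) into a
 * PreparedModel on the backend device identified by backendID. A compiled
 * model can be persisted to and restored from a cache file or buffer to
 * avoid recompilation, and CreateExecutor() wraps the prepared model in an
 * NNExecutor for inference.
 */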
class NNCompiler : public Compiler {
public:
    NNCompiler() = delete;
    NNCompiler(std::shared_ptr<Device> device, size_t backendID);
    NNCompiler(const void* model, std::shared_ptr<Device> device, size_t backendID);
    ~NNCompiler() override;

    size_t GetBackendID() const override;

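    /* Compilation options: model-cache path and version, performance mode,
     * priority, and whether float16 precision is enabled. */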
    OH_NN_ReturnCode SetCacheDir(const std::string& cacheModelPath, uint32_t version) override;
    OH_NN_ReturnCode SetPerformance(OH_NN_PerformanceMode performance) override;
    OH_NN_ReturnCode SetPriority(OH_NN_Priority priority) override;
    OH_NN_ReturnCode SetEnableFp16(bool isFp16) override;

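    /* Builds the model for the target backend; IsBuild() reports whether a
     * build has already completed. */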
    bool IsBuild() const override;
    OH_NN_ReturnCode Build() override;

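    /* Save the compiled model to the configured cache file or to a
     * caller-provided buffer, or restore a previously compiled model from
     * either source. */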
    OH_NN_ReturnCode SaveToCacheFile() const override;
    OH_NN_ReturnCode RestoreFromCacheFile() override;
    OH_NN_ReturnCode SaveToCacheBuffer(const void* buffer, size_t length, size_t* modelSize) const override;
    OH_NN_ReturnCode RestoreFromCacheBuffer(const void* buffer, size_t length) override;

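    /* Backend-specific extension configuration and options, plus the name of
     * the model being compiled. */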
    OH_NN_ReturnCode SetExtensionConfig(const std::unordered_map<std::string, std::vector<char>>& configs) override;
    OH_NN_ReturnCode SetOptions(const std::vector<std::shared_ptr<void>>& options) override;
    OH_NN_ReturnCode GetModelName(std::string& modelName) override;

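    /* Creates an executor for running inference on the compiled model;
     * expected to be called after a successful Build(). */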
    NNExecutor* CreateExecutor();

private:
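    /* Helpers for releasing device-allocated buffers and for serializing
     * tensor descriptions to, and deserializing them from, the model cache. */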
    void ReleaseBuffer(std::vector<Buffer>& buffers) const;
    void ReleaseBufferByDevice(std::vector<Buffer>& buffers) const;
    OH_NN_ReturnCode SerializeTensorsToBuffer(
        const std::vector<std::pair<std::shared_ptr<TensorDesc>, OH_NN_TensorType>>& tensorDescs,
        Buffer& buffer) const;
    OH_NN_ReturnCode DeserializedTensorsFromBuffer(
        const Buffer& buffer, std::vector<std::pair<std::shared_ptr<TensorDesc>, OH_NN_TensorType>>& tensorDescs);

    OH_NN_ReturnCode OnlineBuild();
    OH_NN_ReturnCode NormalBuild();
    OH_NN_ReturnCode BuildOfflineModel();
    OH_NN_ReturnCode CheckModelParameter() const;
    OH_NN_ReturnCode IsOfflineModel(bool& isOfflineModel) const;
    OH_NN_ReturnCode IsSupportedModel(const std::shared_ptr<mindspore::lite::LiteGraph>& liteGraph,
                                      bool& isSupportedModel) const;

private:
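    /* Build state, compilation options, and the model representations
     * (inner model, meta graph, lite graph) used during compilation. */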
    bool m_isBuild {false};
    bool m_enableFp16 {false};
    std::string m_cachePath;
    uint32_t m_cacheVersion {0};
    std::shared_ptr<Device> m_device {nullptr};
    size_t m_backendID {0};
    OH_NN_Priority m_priority {OH_NN_PRIORITY_NONE};
    OH_NN_PerformanceMode m_performance {OH_NN_PERFORMANCE_NONE};
    std::shared_ptr<PreparedModel> m_preparedModel {nullptr};
    void* m_metaGraph {nullptr};
    InnerModel* m_innerModel {nullptr};
    std::shared_ptr<mindspore::lite::LiteGraph> m_liteGraph {nullptr};
    std::vector<std::pair<std::shared_ptr<TensorDesc>, OH_NN_TensorType>> m_inputTensorDescs;
    std::vector<std::pair<std::shared_ptr<TensorDesc>, OH_NN_TensorType>> m_outputTensorDescs;
    ExtensionConfig m_extensionConfig;
};
} // namespace NeuralNetworkRuntime
} // namespace OHOS

#endif // NEURAL_NETWORK_RUNTIME_NNCOMPILER_H