/*
 * Copyright (c) 2022 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef TENSORFLOW_LITE_DELEGATES_NNRT_DELEGATE_KERNEL_H
#define TENSORFLOW_LITE_DELEGATES_NNRT_DELEGATE_KERNEL_H

#include <cstddef>
#include <utility>
#include <vector>

#include "neural_network_runtime.h"
#include "tensorflow/lite/c/common.h"

#include "tensor_mapping.h"
#include "nnrt_op_builder.h"

namespace tflite {
namespace delegate {
namespace nnrt {

// Represents a TFLite subgraph that will be delegated to NNRT.
// The subgraph appears as a single kernel node in the main TFLite graph,
// and this class implements the Init/Prepare/Invoke lifecycle of that node.
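//
// Illustrative usage sketch (not part of this header): a delegate would
// typically wrap this kernel in a TfLiteRegistration whose callbacks forward
// to Init/Prepare/Invoke. The registration name and the exact error handling
// below are assumptions for illustration, not APIs defined in this file.
//
//     TfLiteRegistration nnrtKernelRegistration{};
//     // init: 'buffer' carries the TfLiteDelegateParams describing the delegated partition.
//     nnrtKernelRegistration.init = [](TfLiteContext* context, const char* buffer, size_t length) -> void* {
//         auto* params = reinterpret_cast<const TfLiteDelegateParams*>(buffer);
//         auto* kernel = new NnrtDelegateKernel();
//         if (kernel->Init(context, params) != kTfLiteOk) {
//             delete kernel;
//             kernel = nullptr;
//         }
//         return kernel;
//     };
//     // free: releases the kernel created in init.
//     nnrtKernelRegistration.free = [](TfLiteContext* context, void* buffer) {
//         delete reinterpret_cast<NnrtDelegateKernel*>(buffer);
//     };
//     // prepare/invoke: node->user_data holds the kernel returned from init.
//     nnrtKernelRegistration.prepare = [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
//         return reinterpret_cast<NnrtDelegateKernel*>(node->user_data)->Prepare(context, node);
//     };
//     nnrtKernelRegistration.invoke = [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
//         return reinterpret_cast<NnrtDelegateKernel*>(node->user_data)->Invoke(context, node);
//     };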
class NnrtDelegateKernel {
public:
    explicit NnrtDelegateKernel(const NnrtApi* nnrt)
        : m_initialised(false),
          m_compiled(false),
          m_nnrtDevice{0},
          m_nnrt(nnrt),
          m_nnModel(nullptr),
          m_pNnCompilation(nullptr) {}

    NnrtDelegateKernel() : NnrtDelegateKernel(NnrtImplementation()) {}
    virtual ~NnrtDelegateKernel()
    {
        m_nnrt->OH_NNModel_Destroy(&m_nnModel);
        m_nnrt->OH_NNCompilation_Destroy(&m_pNnCompilation);
        m_nnrt = nullptr;
    }

    // Returns true if the op with the given builtin code can be accelerated with NNRT.
    static bool Validate(const int32_t builtinCode);

    // Initializes the kernel (an NN model) and builds the NN model.
    TfLiteStatus Init(TfLiteContext* context, const TfLiteDelegateParams* params);

    // Creates the NNRT compilation for the NN model. Assumes that Init has
    // been called and completed successfully.
    TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node);

    // Invokes the NN model. Expects Init and Prepare to have completed successfully.
    TfLiteStatus Invoke(TfLiteContext* context, TfLiteNode* node);

private:
    TfLiteStatus Map(int32_t builtinCode, const NnrtOpMappingArgs& mappingArgs, int32_t& nnOpType) const;
    TfLiteStatus AddOpsAndTensors(TfLiteContext* context, const TfLiteIntArray* inputTensors,
        const NnrtDelegate::Options& delegateOptions);
    TfLiteStatus BuildGraph(TfLiteContext* context, const NnrtDelegate::Options& options,
        const TfLiteIntArray* inputTensors, const TfLiteIntArray* outputTensors);
    TfLiteStatus ConvertTensorTypeToNn(TfLiteContext* context, const std::pair<int32_t, int32_t>& indexPair,
        OH_NN_QuantParam* nnQuantParam, OH_NN_Tensor& nnTensor);
    TfLiteStatus SetInputTensors(TfLiteContext* context, TfLiteNode* node, OH_NNExecutor* pNnExecution,
        OH_NN_Tensor& nnTensor);
    TfLiteStatus SetOutputTensors(TfLiteContext* context, TfLiteNode* node, OH_NNExecutor* pNnExecution);
    TfLiteStatus SetNnOptions(TfLiteContext* context, const NnrtDelegate::Options& delegateOptions);

private:
    // True if initialization has been completed successfully.
    bool m_initialised;

    // True if compilation has been completed successfully.
    bool m_compiled;

    // NN device handle.
    size_t m_nnrtDevice;

    // Access to NNRT.
    const NnrtApi* m_nnrt;

    // NN API state.
    OH_NNModel* m_nnModel;
    OH_NNCompilation* m_pNnCompilation;

    // Node indices that this delegate is responsible for. These indices
    // index into the nodes array in the TfLiteContext.
    std::vector<int32_t> m_delegateNodes;

    // Tracks the tensor indices we use.
    TensorMapping m_tensorMapping;
};
} // namespace nnrt
} // namespace delegate
} // namespace tflite

#endif // TENSORFLOW_LITE_DELEGATES_NNRT_DELEGATE_KERNEL_H