1 /*
2  * Copyright (c) 2021 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #ifndef KEYWORD_SPOTTING_PLUGIN_H
17 #define KEYWORD_SPOTTING_PLUGIN_H
18 
19 #include <cstdint>
20 #include <map>
21 #include <memory>
22 #include <mutex>
23 #include <vector>
24 #include <string>
25 
26 #include "ai_datatype.h"
27 #include "engine_adapter.h"
28 #include "feature_processor.h"
29 #include "keyword_spotting/kws_constants.h"
30 #include "plugin_helper.h"
31 #include "plugin/i_plugin.h"
32 
33 namespace OHOS {
34 namespace AI {
35 struct KWSWorkplace {
36     PluginConfig config;
37     std::shared_ptr<Feature::FeatureProcessor> normProcessor;
38     std::shared_ptr<Feature::FeatureProcessor> typeConverter;
39     std::shared_ptr<Feature::FeatureProcessor> slideProcessor;
40 };
41 
42 class KWSPlugin : public IPlugin {
43 public:
44     KWSPlugin();
45     ~KWSPlugin();
46 
47     const long long GetVersion() const override;
48     const char *GetName() const override;
49     const char *GetInferMode() const override;
50 
51     int32_t Prepare(long long transactionId, const DataInfo &inputInfo, DataInfo &outputInfo) override;
52     int32_t SetOption(int optionType, const DataInfo &inputInfo) override;
53     int32_t GetOption(int optionType, const DataInfo &inputInfo, DataInfo &outputInfo) override;
54     int32_t SyncProcess(IRequest *request, IResponse *&response) override;
55     int32_t AsyncProcess(IRequest *request, IPluginCallback *callback) override;
56     int32_t Release(bool isFullUnload, long long transactionId, const DataInfo &inputInfo) override;
57 
58 private:
59     int32_t InitComponents(KWSWorkplace &workplace);
60     int32_t GetNormedFeatures(const Array<uint16_t> &input, Array<int32_t> &output, const KWSWorkplace &worker);
61     int32_t BuildConfig(intptr_t handle, PluginConfig &config);
62     int32_t MakeInference(intptr_t handle, Array<int32_t> &input, PluginConfig &config, DataInfo &outputInfo);
63     void FreeHandle(intptr_t handle);
64     void ReleaseAllHandles();
65 
66 private:
67     std::shared_ptr<EngineAdapter> adapter_;
68     std::mutex mutex_;
69     std::map<intptr_t, KWSWorkplace> handles_;
70 };
71 }  // namespace AI
72 }  // namespace OHOS
73 #endif  // KEYWORD_SPOTTING_PLUGIN_H