1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "trigger_db_helper.h"
17 
18 #include <string>
19 
20 #include "rdb_errno.h"
21 #include "rdb_helper.h"
22 #include "rdb_open_callback.h"
23 #include "intell_voice_log.h"
24 #include "intell_voice_service_manager.h"
25 
26 #define LOG_TAG "TriggerDbHelper"
27 
28 using namespace OHOS::NativeRdb;
29 
30 namespace OHOS {
31 namespace IntellVoiceTrigger {
namespace {
// Schema version that introduced the model_type column. The store is also opened at this version.
enum {
    VERSION_ADD_MODEL_TYPE = 2,
};

// Name of the table that stores trigger models.
const std::string TABLE_NAME = "trigger";
}  // namespace
36 
37 class TriggerModelOpenCallback : public RdbOpenCallback {
38 public:
39     int OnCreate(RdbStore &rdbStore) override;
40     int OnUpgrade(RdbStore &rdbStore, int oldVersion, int newVersion) override;
41 private:
42     static void VersionAddModelType(RdbStore &store);
43 };
44 
OnCreate(RdbStore & rdbStore)45 int TriggerModelOpenCallback::OnCreate(RdbStore &rdbStore)
46 {
47     INTELL_VOICE_LOG_INFO("enter");
48     const std::string CREATE_TABLE_Trigger = "CREATE TABLE IF NOT EXISTS " + TABLE_NAME +
49         " (model_uuid INTEGER PRIMARY KEY, vendor_uuid INTEGER, data BLOB, model_version INTEGER)";
50 
51     int32_t result = rdbStore.ExecuteSql(CREATE_TABLE_Trigger);
52     if (result != NativeRdb::E_OK) {
53         INTELL_VOICE_LOG_ERROR("create table failed, ret:%{public}d", result);
54         return result;
55     }
56 
57     VersionAddModelType(rdbStore);
58     return NativeRdb::E_OK;
59 }
60 
OnUpgrade(RdbStore & rdbStore,int oldVersion,int newVersion)61 int TriggerModelOpenCallback::OnUpgrade(RdbStore &rdbStore, int oldVersion, int newVersion)
62 {
63     INTELL_VOICE_LOG_INFO("enter, oldVersion:%{public}d, newVersion:%{public}d", oldVersion, newVersion);
64     if (oldVersion < VERSION_ADD_MODEL_TYPE) {
65         VersionAddModelType(rdbStore);
66     }
67     return E_OK;
68 }
69 
VersionAddModelType(RdbStore & store)70 void TriggerModelOpenCallback::VersionAddModelType(RdbStore &store)
71 {
72     const std::string alterModelType = "ALTER TABLE " + TABLE_NAME +
73         " ADD COLUMN " + "model_type" + " INTEGER";
74     int32_t result = store.ExecuteSql(alterModelType);
75     if (result != NativeRdb::E_OK) {
76         INTELL_VOICE_LOG_WARN("Upgrade rbd model type failed, ret:%{public}d", result);
77     }
78 }
79 
TriggerDbHelper()80 TriggerDbHelper::TriggerDbHelper()
81 {
82     int errCode = E_OK;
83     RdbStoreConfig config("/data/service/el1/public/database/intell_voice_service_manager/triggerModel.db");
84     TriggerModelOpenCallback helper;
85     store_ = RdbHelper::GetRdbStore(config, VERSION_ADD_MODEL_TYPE, helper, errCode);
86     if (store_ == nullptr) {
87         INTELL_VOICE_LOG_ERROR("store is nullptr");
88     }
89 }
90 
~TriggerDbHelper()91 TriggerDbHelper::~TriggerDbHelper()
92 {
93     store_ = nullptr;
94 }
95 
UpdateGenericTriggerModel(std::shared_ptr<GenericTriggerModel> model)96 bool TriggerDbHelper::UpdateGenericTriggerModel(std::shared_ptr<GenericTriggerModel> model)
97 {
98     std::lock_guard<std::mutex> lock(mutex_);
99     INTELL_VOICE_LOG_INFO("enter");
100 
101     if (store_ == nullptr) {
102         INTELL_VOICE_LOG_ERROR("store is nullptr");
103         return false;
104     }
105 
106     if (model == nullptr) {
107         INTELL_VOICE_LOG_ERROR("model is nullptr");
108         return false;
109     }
110 
111     model->Print();
112     int64_t rowId = -1;
113     ValuesBucket values;
114     values.PutInt("model_uuid", model->GetUuid());
115     values.PutInt("vendor_uuid", model->GetVendorUuid());
116     values.PutBlob("data", model->GetData());
117     values.PutInt("model_version", model->GetVersion());
118     values.PutInt("model_type", model->GetType());
119     int ret = store_->InsertWithConflictResolution(rowId, TABLE_NAME, values, ConflictResolution::ON_CONFLICT_REPLACE);
120     if (ret != E_OK) {
121         INTELL_VOICE_LOG_ERROR("update generic model failed");
122         return false;
123     }
124     return true;
125 }
126 
GetVendorUuid(std::shared_ptr<AbsSharedResultSet> & set,int32_t & vendorUuid) const127 bool TriggerDbHelper::GetVendorUuid(std::shared_ptr<AbsSharedResultSet> &set, int32_t &vendorUuid) const
128 {
129     int columnIndex;
130     int ret = set->GetColumnIndex("vendor_uuid", columnIndex);
131     if (ret != E_OK) {
132         INTELL_VOICE_LOG_ERROR("failed to get model uuid column index, ret:%{public}d", ret);
133         return false;
134     }
135 
136     ret = set->GetInt(columnIndex, vendorUuid);
137     if (ret != E_OK) {
138         INTELL_VOICE_LOG_ERROR("failed to get vendor uuid, ret:%{public}d", ret);
139         return false;
140     }
141     return true;
142 }
143 
GetBlob(std::shared_ptr<AbsSharedResultSet> & set,std::vector<uint8_t> & data) const144 bool TriggerDbHelper::GetBlob(std::shared_ptr<AbsSharedResultSet> &set, std::vector<uint8_t> &data) const
145 {
146     int columnIndex;
147     int ret = set->GetColumnIndex("data", columnIndex);
148     if (ret != E_OK) {
149         INTELL_VOICE_LOG_ERROR("failed to get data column index, ret:%{public}d", ret);
150         return false;
151     }
152 
153     ret = set->GetBlob(columnIndex, data);
154     if (ret != E_OK) {
155         INTELL_VOICE_LOG_ERROR("failed to get data, ret:%{public}d", ret);
156         return false;
157     }
158     return true;
159 }
160 
GetModelVersion(std::shared_ptr<AbsSharedResultSet> & set,int32_t & version) const161 bool TriggerDbHelper::GetModelVersion(std::shared_ptr<AbsSharedResultSet> &set, int32_t &version) const
162 {
163     int columnIndex;
164     int ret = set->GetColumnIndex("model_version", columnIndex);
165     if (ret != E_OK) {
166         INTELL_VOICE_LOG_ERROR("failed to get model version column index, ret:%{public}d", ret);
167         return false;
168     }
169 
170     ret = set->GetInt(columnIndex, version);
171     if (ret != E_OK) {
172         INTELL_VOICE_LOG_ERROR("failed to get model version, ret:%{public}d", ret);
173         return false;
174     }
175     return true;
176 }
177 
GetModelType(std::shared_ptr<AbsSharedResultSet> & set,int32_t & type) const178 bool TriggerDbHelper::GetModelType(std::shared_ptr<AbsSharedResultSet> &set, int32_t &type) const
179 {
180     int columnIndex;
181     int ret = set->GetColumnIndex("model_type", columnIndex);
182     if (ret != E_OK) {
183         INTELL_VOICE_LOG_ERROR("failed to get model type column index, ret:%{public}d", ret);
184         return false;
185     }
186 
187     ret = set->GetInt(columnIndex, type);
188     if (ret != E_OK) {
189         INTELL_VOICE_LOG_ERROR("failed to get model type, ret:%{public}d", ret);
190         return false;
191     }
192     INTELL_VOICE_LOG_INFO("model type:%{public}d", type);
193     return true;
194 }
195 
GetGenericTriggerModel(const int32_t modelUuid)196 std::shared_ptr<GenericTriggerModel> TriggerDbHelper::GetGenericTriggerModel(const int32_t modelUuid)
197 {
198     std::lock_guard<std::mutex> lock(mutex_);
199     INTELL_VOICE_LOG_INFO("enter, model uuid:%{public}d", modelUuid);
200     if (store_ == nullptr) {
201         INTELL_VOICE_LOG_ERROR("store is nullptr");
202         return nullptr;
203     }
204 
205     std::shared_ptr<AbsSharedResultSet> set = store_->QuerySql(
206         "SELECT * FROM trigger WHERE model_uuid = ?", std::vector<std::string> {std::to_string(modelUuid)});
207     if (set == nullptr) {
208         INTELL_VOICE_LOG_ERROR("set is nullptr");
209         return nullptr;
210     }
211 
212     set->GoToFirstRow();
213 
214     int32_t vendorUuid;
215     if (!GetVendorUuid(set, vendorUuid)) {
216         INTELL_VOICE_LOG_ERROR("failed to get vendor uuid");
217         return nullptr;
218     }
219 
220     std::vector<uint8_t> data;
221     if (!GetBlob(set, data)) {
222         INTELL_VOICE_LOG_ERROR("failed to get data");
223         return nullptr;
224     }
225 
226     int32_t modelVersion;
227     if (!GetModelVersion(set, modelVersion)) {
228         INTELL_VOICE_LOG_ERROR("failed to get model version");
229         return nullptr;
230     }
231 
232     int32_t type;
233     if (modelVersion >= static_cast<int32_t>(TriggerModel::TriggerModelVersion::MODLE_VERSION_2)) {
234         if (!GetModelType(set, type)) {
235             INTELL_VOICE_LOG_ERROR("failed to get model type");
236             return nullptr;
237         }
238     } else {
239         type = (modelUuid == OHOS::IntellVoiceEngine::VOICE_WAKEUP_MODEL_UUID ?
240             TriggerModel::TriggerModelType::VOICE_WAKEUP_TYPE : TriggerModel::TriggerModelType::PROXIMAL_WAKEUP_TYPE);
241     }
242 
243     std::shared_ptr<GenericTriggerModel> model = std::make_shared<GenericTriggerModel>(modelUuid, modelVersion,
244         static_cast<TriggerModel::TriggerModelType>(type));
245     if (model == nullptr) {
246         INTELL_VOICE_LOG_ERROR("failed to alloc model");
247         return nullptr;
248     }
249     model->SetData(data);
250     return model;
251 }
252 
DeleteGenericTriggerModel(const int32_t modelUuid)253 void TriggerDbHelper::DeleteGenericTriggerModel(const int32_t modelUuid)
254 {
255     std::lock_guard<std::mutex> lock(mutex_);
256     INTELL_VOICE_LOG_INFO("enter");
257     if (store_ == nullptr) {
258         INTELL_VOICE_LOG_ERROR("store is nullptr");
259         return;
260     }
261     int deletedRows;
262     store_->Delete(deletedRows, "trigger", "model_uuid = ?", std::vector<std::string> {std::to_string(modelUuid)});
263 }
264 }  // namespace IntellVoiceTrigger
265 }  // namespace OHOS
266