1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
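//! This module provides data operations (create, insert, delete, update, query) on database tables.
//! Row data values may come from user input because they are bound to prepared-statement parameters
//! instead of being formatted into the SQL text. Table names, column names, the `specific_cond`
//! clause of `Table::delete_with_specific_cond`, and column default values in `Table::add_column`
//! are formatted into the SQL text directly and must therefore be trusted.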
18 
19 use core::ffi::c_void;
20 use std::cmp::Ordering;
21 
22 use asset_definition::{log_throw_error, Conversion, DataType, ErrCode, Result, Value};
23 use asset_log::logi;
24 
25 use crate::{
26     database::Database,
27     statement::Statement,
28     transaction::Transaction,
29     types::{ColumnInfo, DbMap, QueryOptions, UpgradeColumnInfo, DB_UPGRADE_VERSION, SQLITE_ROW},
30 };
31 
extern "C" {
    /// Native helper called right after `Statement::step` to obtain the affected-row count for the
    /// connection `db`.
    /// NOTE(review): presumably a thin shim over `sqlite3_changes` — confirm in the native wrapper.
    fn SqliteChanges(db: *mut c_void) -> i32;
}
35 
/// A handle to one table in a [`Database`], used to build and run SQL statements against it.
///
/// `table_name` is formatted directly into generated SQL text, so it must be a trusted identifier.
#[repr(C)]
pub(crate) struct Table<'a> {
    /// Name of the table; interpolated into SQL statements rather than bound as a parameter.
    pub(crate) table_name: String,
    /// Database connection the statements are prepared and executed on.
    pub(crate) db: &'a Database,
}
41 
42 #[inline(always)]
bind_datas(datas: &DbMap, stmt: &Statement, index: &mut i32) -> Result<()>43 fn bind_datas(datas: &DbMap, stmt: &Statement, index: &mut i32) -> Result<()> {
44     for (_, value) in datas.iter() {
45         stmt.bind_data(*index, value)?;
46         *index += 1;
47     }
48     Ok(())
49 }
50 
bind_where_datas(datas: &DbMap, stmt: &Statement, index: &mut i32) -> Result<()>51 fn bind_where_datas(datas: &DbMap, stmt: &Statement, index: &mut i32) -> Result<()> {
52     for (key, value) in datas.iter() {
53         if *key == "SyncType" {
54             stmt.bind_data(*index, value)?;
55             *index += 1;
56         }
57         stmt.bind_data(*index, value)?;
58         *index += 1;
59     }
60     Ok(())
61 }
62 
bind_where_with_specific_condifion(datas: &[Value], stmt: &Statement, index: &mut i32) -> Result<()>63 fn bind_where_with_specific_condifion(datas: &[Value], stmt: &Statement, index: &mut i32) -> Result<()> {
64     for value in datas.iter() {
65         stmt.bind_data(*index, value)?;
66         *index += 1;
67     }
68     Ok(())
69 }
70 
/// Appends `columns` to `sql`, separated by `,` with no surrounding spaces.
///
/// Appends nothing when `columns` is empty. Callers that need `*` for an empty list should use
/// `build_sql_columns`. The parameter is a slice, so existing `&Vec<&str>` arguments still coerce.
#[inline(always)]
fn build_sql_columns_not_empty(columns: &[&str], sql: &mut String) {
    sql.push_str(&columns.join(","));
}
81 
/// Appends the select list for `columns` to `sql`.
///
/// The names are joined by `,`. When `columns` is empty, `*` is appended to select every column.
/// The parameter is a slice rather than `&Vec`, and existing `&Vec<&str>` arguments still coerce.
#[inline(always)]
fn build_sql_columns(columns: &[&str], sql: &mut String) {
    if columns.is_empty() {
        sql.push('*');
    } else {
        sql.push_str(&columns.join(","));
    }
}
90 
91 #[inline(always)]
build_sql_where(conditions: &DbMap, filter: bool, sql: &mut String)92 fn build_sql_where(conditions: &DbMap, filter: bool, sql: &mut String) {
93     if !conditions.is_empty() || filter {
94         sql.push_str(" where ");
95         if filter {
96             sql.push_str("SyncStatus <> 2");
97             if !conditions.is_empty() {
98                 sql.push_str(" and ");
99             }
100         }
101         if !conditions.is_empty() {
102             for (i, column_name) in conditions.keys().enumerate() {
103                 if *column_name == "SyncType" {
104                     sql.push_str("(SyncType & ?) = ?");
105                 } else {
106                     sql.push_str(column_name);
107                     sql.push_str("=?");
108                 }
109                 if i != conditions.len() - 1 {
110                     sql.push_str(" and ")
111                 }
112             }
113         }
114     }
115 }
116 
/// Appends `len` comma-separated `?` placeholders to `sql` (for example `?,?,?` for `len == 3`).
///
/// Appends nothing when `len` is 0.
#[inline(always)]
fn build_sql_values(len: usize, sql: &mut String) {
    for position in 0..len {
        if position > 0 {
            sql.push(',');
        }
        sql.push('?');
    }
}
126 
from_data_type_to_str(value: &DataType) -> &'static str127 fn from_data_type_to_str(value: &DataType) -> &'static str {
128     match *value {
129         DataType::Bytes => "BLOB",
130         DataType::Number => "INTEGER",
131         DataType::Bool => "INTEGER",
132     }
133 }
134 
from_data_value_to_str_value(value: &Value) -> String135 fn from_data_value_to_str_value(value: &Value) -> String {
136     match *value {
137         Value::Number(i) => format!("{}", i),
138         Value::Bytes(_) => String::from("NOT SUPPORTED"),
139         Value::Bool(b) => format!("{}", b),
140     }
141 }
142 
build_sql_query_options(query_options: Option<&QueryOptions>, sql: &mut String)143 fn build_sql_query_options(query_options: Option<&QueryOptions>, sql: &mut String) {
144     if let Some(option) = query_options {
145         if let Some(order_by) = &option.order_by {
146             if !order_by.is_empty() {
147                 sql.push_str(" order by ");
148                 build_sql_columns_not_empty(order_by, sql);
149             }
150         }
151         if let Some(order) = option.order {
152             let str = if order == Ordering::Greater {
153                 "ASC"
154             } else if order == Ordering::Less {
155                 "DESC"
156             } else {
157                 ""
158             };
159             sql.push_str(format!(" {}", str).as_str());
160         }
161         if let Some(limit) = option.limit {
162             sql.push_str(format!(" limit {}", limit).as_str());
163             if let Some(offset) = option.offset {
164                 sql.push_str(format!(" offset {}", offset).as_str());
165             }
166         } else if let Some(offset) = option.offset {
167             sql.push_str(format!(" limit -1 offset {}", offset).as_str());
168         }
169     }
170 }
171 
build_sql_reverse_condition(reverse_condition: Option<&DbMap>, sql: &mut String)172 fn build_sql_reverse_condition(reverse_condition: Option<&DbMap>, sql: &mut String) {
173     if let Some(conditions) = reverse_condition {
174         if !conditions.is_empty() {
175             sql.push_str(" and ");
176             for (i, column_name) in conditions.keys().enumerate() {
177                 if *column_name == "SyncType" {
178                     sql.push_str("(SyncType & ?) == 0");
179                 } else {
180                     sql.push_str(column_name);
181                     sql.push_str("<>?");
182                 }
183                 if i != conditions.len() - 1 {
184                     sql.push_str(" and ")
185                 }
186             }
187         }
188     }
189 }
190 
get_column_info(columns: &'static [ColumnInfo], db_column: &str) -> Result<&'static ColumnInfo>191 fn get_column_info(columns: &'static [ColumnInfo], db_column: &str) -> Result<&'static ColumnInfo> {
192     for column in columns.iter() {
193         if column.name.eq(db_column) {
194             return Ok(column);
195         }
196     }
197     log_throw_error!(ErrCode::DataCorrupted, "Database is corrupted.")
198 }
199 
200 impl<'a> Table<'a> {
new(table_name: &str, db: &'a Database) -> Table<'a>201     pub(crate) fn new(table_name: &str, db: &'a Database) -> Table<'a> {
202         Table { table_name: table_name.to_string(), db }
203     }
204 
exist(&self) -> Result<bool>205     pub(crate) fn exist(&self) -> Result<bool> {
206         let sql = format!("select * from sqlite_master where type ='table' and name = '{}'", self.table_name);
207         let stmt = Statement::prepare(sql.as_str(), self.db)?;
208         let ret = stmt.step()?;
209         if ret == SQLITE_ROW {
210             Ok(true)
211         } else {
212             Ok(false)
213         }
214     }
215 
216     #[allow(dead_code)]
delete(&self) -> Result<()>217     pub(crate) fn delete(&self) -> Result<()> {
218         let sql = format!("DROP TABLE {}", self.table_name);
219         self.db.exec(&sql)
220     }
221 
222     /// Create a table with name 'table_name' at specific version.
223     /// The columns is descriptions for each column.
create_with_version(&self, columns: &[ColumnInfo], version: u32) -> Result<()>224     pub(crate) fn create_with_version(&self, columns: &[ColumnInfo], version: u32) -> Result<()> {
225         let is_exist = self.exist()?;
226         if is_exist {
227             return Ok(());
228         }
229         let mut sql = format!("CREATE TABLE IF NOT EXISTS {}(", self.table_name);
230         for i in 0..columns.len() {
231             let column = &columns[i];
232             sql.push_str(column.name);
233             sql.push(' ');
234             sql.push_str(from_data_type_to_str(&column.data_type));
235             if column.is_primary_key {
236                 sql.push_str(" PRIMARY KEY");
237             }
238             if column.not_null {
239                 sql.push_str(" NOT NULL");
240             }
241             if i != columns.len() - 1 {
242                 sql.push(',')
243             };
244         }
245         sql.push_str(");");
246         let mut trans = Transaction::new(self.db);
247         trans.begin()?;
248         if self.db.exec(sql.as_str()).is_ok() && self.db.set_version(version).is_ok() {
249             trans.commit()
250         } else {
251             trans.rollback()
252         }
253     }
254 
255     /// Create a table with name 'table_name'.
256     /// The columns is descriptions for each column.
create(&self, columns: &[ColumnInfo]) -> Result<()>257     pub(crate) fn create(&self, columns: &[ColumnInfo]) -> Result<()> {
258         self.create_with_version(columns, DB_UPGRADE_VERSION)
259     }
260 
upgrade(&self, ver: u32, columns: &[UpgradeColumnInfo]) -> Result<()>261     pub(crate) fn upgrade(&self, ver: u32, columns: &[UpgradeColumnInfo]) -> Result<()> {
262         let is_exist = self.exist()?;
263         if !is_exist {
264             return Ok(());
265         }
266         logi!("upgrade table!");
267         let mut trans = Transaction::new(self.db);
268         trans.begin()?;
269         for item in columns {
270             if self.add_column(&item.base_info, &item.default_value).is_err() {
271                 return trans.rollback();
272             }
273         }
274         if self.db.set_version(ver).is_err() {
275             trans.rollback()
276         } else {
277             trans.commit()
278         }
279     }
280 
281     /// Insert a row into table, and datas is the value to be insert.
282     ///
283     /// # Examples
284     ///
285     /// ```
286     /// // SQL: insert into table_name(id,alias) values (3,'alias1')
287     /// let datas = &DbMap::from([("id", Value::Number(3), ("alias", Value::Bytes(b"alias1"))]);
288     /// let ret = table.insert_row(datas);
289     /// ```
insert_row(&self, datas: &DbMap) -> Result<i32>290     pub(crate) fn insert_row(&self, datas: &DbMap) -> Result<i32> {
291         let mut sql = format!("insert into {} (", self.table_name);
292         for (i, column_name) in datas.keys().enumerate() {
293             sql.push_str(column_name);
294             if i != datas.len() - 1 {
295                 sql.push(',');
296             }
297         }
298 
299         sql.push_str(") values (");
300         build_sql_values(datas.len(), &mut sql);
301         sql.push(')');
302         let stmt = Statement::prepare(&sql, self.db)?;
303         let mut index = 1;
304         bind_datas(datas, &stmt, &mut index)?;
305         stmt.step()?;
306         let count = unsafe { SqliteChanges(self.db.handle as _) };
307         Ok(count)
308     }
309 
310     /// Delete row from table.
311     ///
312     /// # Examples
313     ///
314     /// ```
315     /// // SQL: delete from table_name where id=2
316     /// let condition = &DbMap::from([("id", Value::Number(2)]);
317     /// let ret = table.delete_row(condition, None, false);
318     /// ```
delete_row( &self, condition: &DbMap, reverse_condition: Option<&DbMap>, is_filter_sync: bool, ) -> Result<i32>319     pub(crate) fn delete_row(
320         &self,
321         condition: &DbMap,
322         reverse_condition: Option<&DbMap>,
323         is_filter_sync: bool,
324     ) -> Result<i32> {
325         let mut sql = format!("delete from {}", self.table_name);
326         build_sql_where(condition, is_filter_sync, &mut sql);
327         build_sql_reverse_condition(reverse_condition, &mut sql);
328         let stmt = Statement::prepare(&sql, self.db)?;
329         let mut index = 1;
330         bind_where_datas(condition, &stmt, &mut index)?;
331         if let Some(datas) = reverse_condition {
332             bind_datas(datas, &stmt, &mut index)?;
333         }
334         stmt.step()?;
335         let count = unsafe { SqliteChanges(self.db.handle as _) };
336         Ok(count)
337     }
338 
339     /// Delete row from table with specific condition.
340     ///
341     /// # Examples
342     ///
343     /// ```
344     /// // SQL: delete from table_name where id=2
345     /// let specific_cond = "id".to_string();
346     /// let condition_value = Value::Number(2);
347     /// let ret = table.delete_with_specific_cond(specific_cond, condition_value);
348     /// ```
delete_with_specific_cond(&self, specific_cond: &str, condition_value: &[Value]) -> Result<i32>349     pub(crate) fn delete_with_specific_cond(&self, specific_cond: &str, condition_value: &[Value]) -> Result<i32> {
350         let sql: String = format!("delete from {} where {}", self.table_name, specific_cond);
351         let stmt = Statement::prepare(&sql, self.db)?;
352         let mut index = 1;
353         bind_where_with_specific_condifion(condition_value, &stmt, &mut index)?;
354         stmt.step()?;
355         let count = unsafe { SqliteChanges(self.db.handle as _) };
356         Ok(count)
357     }
358 
359     /// Update a row in table.
360     ///
361     /// # Examples
362     ///
363     /// ```
364     /// // SQL: update table_name set alias='update_value' where id=2
365     /// let condition = &DbMap::from([("id", Value::Number(2)]);
366     /// let datas = &DbMap::from([("alias", Value::Bytes(b"update_value")]);
367     /// let ret = table.update_row(conditions, false, datas);
368     /// ```
update_row(&self, condition: &DbMap, is_filter_sync: bool, datas: &DbMap) -> Result<i32>369     pub(crate) fn update_row(&self, condition: &DbMap, is_filter_sync: bool, datas: &DbMap) -> Result<i32> {
370         let mut sql = format!("update {} set ", self.table_name);
371         for (i, column_name) in datas.keys().enumerate() {
372             sql.push_str(column_name);
373             sql.push_str("=?");
374             if i != datas.len() - 1 {
375                 sql.push(',');
376             }
377         }
378         build_sql_where(condition, is_filter_sync, &mut sql);
379         let stmt = Statement::prepare(&sql, self.db)?;
380         let mut index = 1;
381         bind_datas(datas, &stmt, &mut index)?;
382         bind_where_datas(condition, &stmt, &mut index)?;
383         stmt.step()?;
384         let count = unsafe { SqliteChanges(self.db.handle as _) };
385         Ok(count)
386     }
387 
388     /// Query row from table.
389     /// If length of columns is 0, all table columns are queried. (eg. select * xxx)
390     /// If length of condition is 0, all data in the table is queried.
391     ///
392     /// # Examples
393     ///
394     /// ```
395     /// // SQL: select alias,blobs from table_name
396     /// let result_set = table.query_datas_with_key_value(&vec!["alias", "blobs"], false, &vec![]);
397     /// ```
query_row( &self, columns: &Vec<&'static str>, condition: &DbMap, query_options: Option<&QueryOptions>, is_filter_sync: bool, column_info: &'static [ColumnInfo], ) -> Result<Vec<DbMap>>398     pub(crate) fn query_row(
399         &self,
400         columns: &Vec<&'static str>,
401         condition: &DbMap,
402         query_options: Option<&QueryOptions>,
403         is_filter_sync: bool,
404         column_info: &'static [ColumnInfo],
405     ) -> Result<Vec<DbMap>> {
406         let mut sql = String::from("select ");
407         if !columns.is_empty() {
408             sql.push_str("distinct ");
409         }
410         build_sql_columns(columns, &mut sql);
411         sql.push_str(" from ");
412         sql.push_str(self.table_name.as_str());
413         build_sql_where(condition, is_filter_sync, &mut sql);
414         build_sql_query_options(query_options, &mut sql);
415         let stmt = Statement::prepare(&sql, self.db)?;
416         let mut index = 1;
417         bind_where_datas(condition, &stmt, &mut index)?;
418         let mut result = vec![];
419         while stmt.step()? == SQLITE_ROW {
420             let mut record = DbMap::new();
421             let n = stmt.data_count();
422             for i in 0..n {
423                 let column_name = stmt.query_column_name(i)?;
424                 let column_info = get_column_info(column_info, column_name)?;
425                 match stmt.query_column_auto_type(i)? {
426                     Some(Value::Number(n)) if column_info.data_type == DataType::Bool => {
427                         record.insert(column_info.name, Value::Bool(n != 0))
428                     },
429                     Some(n) if n.data_type() == column_info.data_type => record.insert(column_info.name, n),
430                     Some(_) => {
431                         return log_throw_error!(ErrCode::DataCorrupted, "The data in DB has been tampered with.")
432                     },
433                     None => continue,
434                 };
435             }
436             result.push(record);
437         }
438         Ok(result)
439     }
440 
441     /// Count the number of datas with query condition(can be empty).
442     ///
443     /// # Examples
444     ///
445     /// ```
446     /// // SQL: select count(*) as count from table_name where id=3
447     /// let count = table.count_datas(&DbMap::from([("id", Value::Number(3))]), false);
448     /// ```
count_datas(&self, condition: &DbMap, is_filter_sync: bool) -> Result<u32>449     pub(crate) fn count_datas(&self, condition: &DbMap, is_filter_sync: bool) -> Result<u32> {
450         let mut sql = format!("select count(*) as count from {}", self.table_name);
451         build_sql_where(condition, is_filter_sync, &mut sql);
452         let stmt = Statement::prepare(&sql, self.db)?;
453         let mut index = 1;
454         bind_where_datas(condition, &stmt, &mut index)?;
455         stmt.step()?;
456         let count = stmt.query_column_int(0);
457         Ok(count)
458     }
459 
460     /// Check whether data exists in the database table.
461     ///
462     /// # Examples
463     ///
464     /// ```
465     /// // SQL: select count(*) as count from table_name where id=3 and alias='alias'
466     /// let exits = table
467     ///     .is_data_exists(&DbMap::from([("id", Value::Number(3)), ("alias", Value::Bytes(b"alias"))]), false);
468     /// ```
is_data_exists(&self, cond: &DbMap, is_filter_sync: bool) -> Result<bool>469     pub(crate) fn is_data_exists(&self, cond: &DbMap, is_filter_sync: bool) -> Result<bool> {
470         let ret = self.count_datas(cond, is_filter_sync);
471         match ret {
472             Ok(count) => Ok(count > 0),
473             Err(e) => Err(e),
474         }
475     }
476 
477     /// Add new column tp table.
478     /// 1. Primary key cannot be added.
479     /// 2. Cannot add a non-null column with no default value
480     /// 3. Only the integer and blob types support the default value, and the default value of the blob type is null.
481     ///
482     /// # Examples
483     ///
484     /// ```
485     /// // SQL: alter table table_name add cloumn id integer not null
486     /// let ret = table.add_column(
487     ///     ColumnInfo {
488     ///         name: "id",
489     ///         data_type: DataType::INTEGER,
490     ///         is_primary_key: false,
491     ///         not_null: true,
492     ///     },
493     ///     Some(Value::Number(0)),
494     /// );
495     /// ```
add_column(&self, column: &ColumnInfo, default_value: &Option<Value>) -> Result<()>496     pub(crate) fn add_column(&self, column: &ColumnInfo, default_value: &Option<Value>) -> Result<()> {
497         if column.is_primary_key {
498             return log_throw_error!(ErrCode::InvalidArgument, "The primary key already exists in the table.");
499         }
500         if column.not_null && default_value.is_none() {
501             return log_throw_error!(ErrCode::InvalidArgument, "A default value is required for a non-null column.");
502         }
503         let data_type = from_data_type_to_str(&column.data_type);
504         let mut sql = format!("ALTER TABLE {} ADD COLUMN {} {}", self.table_name, column.name, data_type);
505         if let Some(data) = default_value {
506             sql.push_str(" DEFAULT ");
507             sql.push_str(&from_data_value_to_str_value(data));
508         }
509         if column.not_null {
510             sql.push_str(" NOT NULL");
511         }
512         self.db.exec(sql.as_str())
513     }
514 
replace_row(&self, condition: &DbMap, is_filter_sync: bool, datas: &DbMap) -> Result<()>515     pub(crate) fn replace_row(&self, condition: &DbMap, is_filter_sync: bool, datas: &DbMap) -> Result<()> {
516         let mut trans = Transaction::new(self.db);
517         trans.begin()?;
518         if self.delete_row(condition, None, is_filter_sync).is_ok() && self.insert_row(datas).is_ok() {
519             trans.commit()
520         } else {
521             trans.rollback()
522         }
523     }
524 }
525