1 // Copyright (C) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 mod keeper;
15 mod running_task;
16 use std::collections::{HashMap, HashSet};
17 use std::sync::atomic::{AtomicBool, Ordering};
18 use std::sync::Arc;
19 
20 use keeper::SAKeeper;
21 
22 cfg_oh! {
23     use crate::ability::SYSTEM_CONFIG_MANAGER;
24 }
25 use ylong_runtime::task::JoinHandle;
26 
27 use crate::config::Mode;
28 use crate::error::ErrorCode;
29 use crate::manage::database::RequestDb;
30 use crate::manage::events::{TaskEvent, TaskManagerEvent};
31 use crate::manage::scheduler::qos::{QosChanges, QosDirection};
32 use crate::manage::scheduler::queue::running_task::RunningTask;
33 use crate::manage::task_manager::TaskManagerTx;
34 use crate::service::client::ClientManagerEntry;
35 use crate::service::run_count::RunCountManagerEntry;
36 use crate::task::config::Action;
37 use crate::task::info::State;
38 use crate::task::reason::Reason;
39 use crate::task::request_task::RequestTask;
40 use crate::utils::runtime_spawn;
41 
/// Tasks currently admitted by the QoS scheduler, split by action, together
/// with the cancellation handles of the `RunningTask`s spawned for them.
pub(crate) struct RunningQueue {
    /// Admitted download tasks, keyed by `(uid, task_id)`.
    download_queue: HashMap<(u64, u32), Arc<RequestTask>>,
    /// Admitted upload tasks, keyed by `(uid, task_id)`.
    upload_queue: HashMap<(u64, u32), Arc<RequestTask>>,
    /// Spawned `RunningTask`s keyed by `(uid, task_id)`. An entry is `Some`
    /// while the spawn can still be cancelled, and becomes `None` once it has
    /// been cancelled; the entry itself is only removed by `task_finish`.
    running_tasks: HashMap<(u64, u32), Option<AbortHandle>>,
    /// Cloned into every spawned `RunningTask`.
    keeper: SAKeeper,
    /// Cloned into every spawned `RunningTask`; also used to send
    /// `TaskEvent::Failed` when a task cannot be loaded.
    tx: TaskManagerTx,
    /// Notified of the number of admitted tasks after each reschedule
    /// (only with the `oh` feature).
    run_count_manager: RunCountManagerEntry,
    /// Passed to the database when a task is loaded.
    client_manager: ClientManagerEntry,
    /// IDs of upload tasks that were paused and then resumed; such tasks need to
    /// upload from the breakpoint. An ID is consumed (removed) the next time the
    /// task is loaded from the database during a reschedule.
    pub(crate) upload_resume: HashSet<u32>,
}
53 
54 impl RunningQueue {
new( tx: TaskManagerTx, run_count_manager: RunCountManagerEntry, client_manager: ClientManagerEntry, ) -> Self55     pub(crate) fn new(
56         tx: TaskManagerTx,
57         run_count_manager: RunCountManagerEntry,
58         client_manager: ClientManagerEntry,
59     ) -> Self {
60         Self {
61             download_queue: HashMap::new(),
62             upload_queue: HashMap::new(),
63             keeper: SAKeeper::new(tx.clone()),
64             tx,
65             running_tasks: HashMap::new(),
66             run_count_manager,
67             client_manager,
68             upload_resume: HashSet::new(),
69         }
70     }
71 
get_task(&self, uid: u64, task_id: u32) -> Option<&Arc<RequestTask>>72     pub(crate) fn get_task(&self, uid: u64, task_id: u32) -> Option<&Arc<RequestTask>> {
73         self.download_queue
74             .get(&(uid, task_id))
75             .or_else(|| self.upload_queue.get(&(uid, task_id)))
76     }
77 
task_finish(&mut self, uid: u64, task_id: u32)78     pub(crate) fn task_finish(&mut self, uid: u64, task_id: u32) {
79         self.running_tasks.remove(&(uid, task_id));
80     }
81 
try_restart(&mut self, uid: u64, task_id: u32) -> bool82     pub(crate) fn try_restart(&mut self, uid: u64, task_id: u32) -> bool {
83         if let Some(task) = self
84             .download_queue
85             .get(&(uid, task_id))
86             .or(self.upload_queue.get(&(uid, task_id)))
87         {
88             info!("{} restart running", task_id);
89             let running_task = RunningTask::new(task.clone(), self.tx.clone(), self.keeper.clone());
90             let abort_flag = Arc::new(AtomicBool::new(false));
91             let abort_flag_clone = abort_flag.clone();
92             let join_handle = runtime_spawn(async move {
93                 running_task.run(abort_flag_clone.clone()).await;
94             });
95             let uid = task.uid();
96             let task_id = task.task_id();
97             self.running_tasks.insert(
98                 (uid, task_id),
99                 Some(AbortHandle::new(abort_flag, join_handle)),
100             );
101             true
102         } else {
103             false
104         }
105     }
106 
tasks(&self) -> impl Iterator<Item = &Arc<RequestTask>>107     pub(crate) fn tasks(&self) -> impl Iterator<Item = &Arc<RequestTask>> {
108         self.download_queue
109             .values()
110             .chain(self.upload_queue.values())
111     }
112 
running_tasks(&self) -> usize113     pub(crate) fn running_tasks(&self) -> usize {
114         self.download_queue.len() + self.upload_queue.len()
115     }
116 
dump_tasks(&self)117     pub(crate) fn dump_tasks(&self) {
118         info!("dump all running {}", self.running_tasks());
119 
120         for ((uid, task_id), task) in self.download_queue.iter().chain(self.upload_queue.iter()) {
121             let task_status = task.status.lock().unwrap();
122             info!(
123                 "dump task {}, uid {}, action {}, mode {}, bundle {}, status {:?}",
124                 task_id,
125                 uid,
126                 task.action().repr,
127                 task.mode().repr,
128                 task.bundle(),
129                 *task_status
130             );
131         }
132     }
133 
reschedule(&mut self, qos: QosChanges, qos_remove_queue: &mut Vec<(u64, u32)>)134     pub(crate) fn reschedule(&mut self, qos: QosChanges, qos_remove_queue: &mut Vec<(u64, u32)>) {
135         if let Some(vec) = qos.download {
136             self.reschedule_inner(Action::Download, vec, qos_remove_queue)
137         }
138         if let Some(vec) = qos.upload {
139             self.reschedule_inner(Action::Upload, vec, qos_remove_queue)
140         }
141     }
142 
reschedule_inner( &mut self, action: Action, qos_vec: Vec<QosDirection>, qos_remove_queue: &mut Vec<(u64, u32)>, )143     pub(crate) fn reschedule_inner(
144         &mut self,
145         action: Action,
146         qos_vec: Vec<QosDirection>,
147         qos_remove_queue: &mut Vec<(u64, u32)>,
148     ) {
149         let mut new_queue = HashMap::new();
150 
151         let queue = if action == Action::Download {
152             &mut self.download_queue
153         } else {
154             &mut self.upload_queue
155         };
156 
157         // We need to decide which tasks need to continue running based on `QosChanges`.
158         for qos_direction in qos_vec.iter() {
159             let uid = qos_direction.uid();
160             let task_id = qos_direction.task_id();
161 
162             if let Some(task) = queue.remove(&(uid, task_id)) {
163                 // If we can find that the task is running in `running_tasks`,
164                 // we just need to adjust its rate.
165                 task.speed_limit(qos_direction.direction() as u64);
166                 // Then we put it into `satisfied_tasks`.
167                 new_queue.insert((uid, task_id), task);
168                 continue;
169             }
170 
171             // If the task is not in the current running queue, retrieve
172             // the corresponding task from the database and start it.
173 
174             #[cfg(feature = "oh")]
175             let system_config = unsafe { SYSTEM_CONFIG_MANAGER.assume_init_ref().system_config() };
176             let upload_resume = self.upload_resume.remove(&task_id);
177 
178             let task = match RequestDb::get_instance().get_task(
179                 task_id,
180                 #[cfg(feature = "oh")]
181                 system_config,
182                 &self.client_manager,
183                 upload_resume,
184             ) {
185                 Ok(task) => task,
186                 Err(ErrorCode::TaskNotFound) => continue, // If we cannot find the task, skip it.
187                 Err(ErrorCode::TaskStateErr) => continue, // If we cannot find the task, skip it.
188                 Err(e) => {
189                     info!("get task {} error:{:?}", task_id, e);
190                     if let Some(info) = RequestDb::get_instance().get_task_qos_info(task_id) {
191                         self.tx.send_event(TaskManagerEvent::Task(TaskEvent::Failed(
192                             task_id,
193                             uid,
194                             Reason::OthersError,
195                             Mode::from(info.mode),
196                         )));
197                     }
198                     qos_remove_queue.push((uid, task_id));
199                     continue;
200                 }
201             };
202             task.speed_limit(qos_direction.direction() as u64);
203 
204             new_queue.insert((uid, task_id), task.clone());
205 
206             if self.running_tasks.contains_key(&(uid, task_id)) {
207                 info!("task {} not finished", task_id);
208                 continue;
209             }
210 
211             info!("{} create running", task_id);
212             let running_task = RunningTask::new(task.clone(), self.tx.clone(), self.keeper.clone());
213             RequestDb::get_instance().update_task_state(
214                 running_task.task_id(),
215                 State::Running,
216                 Reason::Default,
217             );
218             let abort_flag = Arc::new(AtomicBool::new(false));
219             let abort_flag_clone = abort_flag.clone();
220             let join_handle = runtime_spawn(async move {
221                 running_task.run(abort_flag_clone).await;
222             });
223 
224             let uid = task.uid();
225             let task_id = task.task_id();
226             self.running_tasks.insert(
227                 (uid, task_id),
228                 Some(AbortHandle::new(abort_flag, join_handle)),
229             );
230         }
231         // every satisfied tasks in running has been moved, set left tasks to Waiting
232 
233         for task in queue.values() {
234             if let Some(join_handle) = self.running_tasks.get_mut(&(task.uid(), task.task_id())) {
235                 if let Some(join_handle) = join_handle.take() {
236                     join_handle.cancel();
237                 };
238             }
239         }
240         *queue = new_queue;
241 
242         #[cfg(feature = "oh")]
243         self.run_count_manager
244             .notify_run_count(self.download_queue.len() + self.upload_queue.len());
245     }
246 }
247 
/// Cancellation handle for a spawned `RunningTask` future.
struct AbortHandle {
    /// Flag shared with the spawned task's `run` call; set to `true` by `cancel`.
    abort_flag: Arc<AtomicBool>,
    /// Runtime handle of the spawned future; cancelled by `cancel`.
    join_handle: JoinHandle<()>,
}
252 
253 impl AbortHandle {
new(abort_flag: Arc<AtomicBool>, join_handle: JoinHandle<()>) -> Self254     fn new(abort_flag: Arc<AtomicBool>, join_handle: JoinHandle<()>) -> Self {
255         Self {
256             abort_flag,
257             join_handle,
258         }
259     }
cancel(self)260     fn cancel(self) {
261         self.abort_flag.store(true, Ordering::Release);
262         self.join_handle.cancel();
263     }
264 }
265