1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 //! Frame recv coroutine.
15 
16 use std::future::Future;
17 use std::pin::Pin;
18 use std::sync::{Arc, Mutex};
19 use std::task::{Context, Poll};
20 
21 use ylong_http::h2::{
22     ErrorCode, Frame, FrameDecoder, FrameKind, FramesIntoIter, H2Error, Payload, Setting,
23 };
24 
25 use crate::runtime::{AsyncRead, BoundedSender, ReadBuf, ReadHalf, SendError};
26 use crate::util::dispatcher::http2::{
27     DispatchErrorKind, OutputMessage, SettingsState, SettingsSync,
28 };
29 
30 pub(crate) type OutputSendFut =
31     Pin<Box<dyn Future<Output = Result<(), SendError<OutputMessage>>> + Send + Sync>>;
32 
33 #[derive(Copy, Clone)]
34 enum DecodeState {
35     Read,
36     Send,
37     Exit(DispatchErrorKind),
38 }
39 
/// Coroutine state for reading HTTP/2 frames from the transport and
/// forwarding them as `OutputMessage`s over `resp_tx`.
pub(crate) struct RecvData<S> {
    // Turns raw bytes read from `reader` into frames.
    decoder: FrameDecoder,
    // Shared settings state; consulted when a SETTINGS ACK arrives so the
    // decoder limits can be updated.
    settings: Arc<Mutex<SettingsSync>>,
    // Read half of the underlying transport.
    reader: ReadHalf<S>,
    // Current state of the read loop.
    state: DecodeState,
    // State to enter once the message parked while in `Send` is delivered.
    next_state: DecodeState,
    // Channel on which decoded frames and exit notifications are sent.
    resp_tx: BoundedSender<OutputMessage>,
    // Send future that returned `Pending` and must be polled again.
    curr_message: Option<OutputSendFut>,
    // Decoded frames not yet forwarded because an earlier send was pending.
    pending_iter: Option<FramesIntoIter>,
}
50 
51 impl<S: AsyncRead + Unpin + Sync + Send + 'static> Future for RecvData<S> {
52     type Output = Result<(), DispatchErrorKind>;
53 
poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>54     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
55         let receiver = self.get_mut();
56         receiver.poll_read_frame(cx)
57     }
58 }
59 
60 impl<S: AsyncRead + Unpin + Sync + Send + 'static> RecvData<S> {
new( decoder: FrameDecoder, settings: Arc<Mutex<SettingsSync>>, reader: ReadHalf<S>, resp_tx: BoundedSender<OutputMessage>, ) -> Self61     pub(crate) fn new(
62         decoder: FrameDecoder,
63         settings: Arc<Mutex<SettingsSync>>,
64         reader: ReadHalf<S>,
65         resp_tx: BoundedSender<OutputMessage>,
66     ) -> Self {
67         Self {
68             decoder,
69             settings,
70             reader,
71             state: DecodeState::Read,
72             next_state: DecodeState::Read,
73             resp_tx,
74             curr_message: None,
75             pending_iter: None,
76         }
77     }
78 
poll_read_frame(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), DispatchErrorKind>>79     fn poll_read_frame(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), DispatchErrorKind>> {
80         let mut buf = [0u8; 1024];
81         loop {
82             match self.state {
83                 DecodeState::Read => {
84                     let mut read_buf = ReadBuf::new(&mut buf);
85                     match Pin::new(&mut self.reader).poll_read(cx, &mut read_buf) {
86                         Poll::Ready(Err(e)) => {
87                             return self.transmit_error(cx, e.into());
88                         }
89                         Poll::Ready(Ok(())) => {}
90                         Poll::Pending => {
91                             return Poll::Pending;
92                         }
93                     }
94                     let read = read_buf.filled().len();
95                     if read == 0 {
96                         return self.transmit_error(cx, DispatchErrorKind::Disconnect);
97                     }
98 
99                     match self.decoder.decode(&buf[..read]) {
100                         Ok(frames) => match self.poll_iterator_frames(cx, frames.into_iter()) {
101                             Poll::Ready(Ok(_)) => {}
102                             Poll::Ready(Err(e)) => {
103                                 return Poll::Ready(Err(e));
104                             }
105                             Poll::Pending => {
106                                 self.next_state = DecodeState::Read;
107                             }
108                         },
109                         Err(e) => {
110                             match self.transmit_message(cx, OutputMessage::OutputExit(e.into())) {
111                                 Poll::Ready(Err(_)) => {
112                                     return Poll::Ready(Err(DispatchErrorKind::ChannelClosed))
113                                 }
114                                 Poll::Ready(Ok(_)) => {}
115                                 Poll::Pending => {
116                                     self.next_state = DecodeState::Read;
117                                     return Poll::Pending;
118                                 }
119                             }
120                         }
121                     }
122                 }
123                 DecodeState::Send => {
124                     match self.poll_blocked_task(cx) {
125                         Poll::Ready(Ok(_)) => {
126                             self.state = self.next_state;
127                             // Reset next state.
128                             self.next_state = DecodeState::Read;
129                         }
130                         Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
131                         Poll::Pending => return Poll::Pending,
132                     }
133                 }
134                 DecodeState::Exit(e) => {
135                     return Poll::Ready(Err(e));
136                 }
137             }
138         }
139     }
140 
poll_blocked_task(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), DispatchErrorKind>>141     fn poll_blocked_task(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), DispatchErrorKind>> {
142         if let Some(mut task) = self.curr_message.take() {
143             match task.as_mut().poll(cx) {
144                 Poll::Ready(Ok(_)) => {}
145                 Poll::Ready(Err(_)) => {
146                     return Poll::Ready(Err(DispatchErrorKind::ChannelClosed));
147                 }
148                 Poll::Pending => {
149                     self.curr_message = Some(task);
150                     return Poll::Pending;
151                 }
152             }
153         }
154 
155         if let Some(iter) = self.pending_iter.take() {
156             return self.poll_iterator_frames(cx, iter);
157         }
158         Poll::Ready(Ok(()))
159     }
160 
poll_iterator_frames( &mut self, cx: &mut Context<'_>, mut iter: FramesIntoIter, ) -> Poll<Result<(), DispatchErrorKind>>161     fn poll_iterator_frames(
162         &mut self,
163         cx: &mut Context<'_>,
164         mut iter: FramesIntoIter,
165     ) -> Poll<Result<(), DispatchErrorKind>> {
166         while let Some(kind) = iter.next() {
167             match kind {
168                 FrameKind::Complete(frame) => {
169                     // TODO Whether to continue processing the remaining frames after connection
170                     // error occurs in the Settings frame.
171                     let message = if let Err(e) = self.update_settings(&frame) {
172                         OutputMessage::OutputExit(DispatchErrorKind::H2(e))
173                     } else {
174                         OutputMessage::Output(frame)
175                     };
176 
177                     match self.transmit_message(cx, message) {
178                         Poll::Ready(Ok(_)) => {}
179                         Poll::Ready(Err(e)) => {
180                             return Poll::Ready(Err(e));
181                         }
182                         Poll::Pending => {
183                             self.pending_iter = Some(iter);
184                             return Poll::Pending;
185                         }
186                     }
187                 }
188                 FrameKind::Partial => {}
189             }
190         }
191         Poll::Ready(Ok(()))
192     }
193 
transmit_error( &mut self, cx: &mut Context<'_>, exit_err: DispatchErrorKind, ) -> Poll<Result<(), DispatchErrorKind>>194     fn transmit_error(
195         &mut self,
196         cx: &mut Context<'_>,
197         exit_err: DispatchErrorKind,
198     ) -> Poll<Result<(), DispatchErrorKind>> {
199         match self.transmit_message(cx, OutputMessage::OutputExit(exit_err)) {
200             Poll::Ready(_) => Poll::Ready(Err(exit_err)),
201             Poll::Pending => {
202                 self.next_state = DecodeState::Exit(exit_err);
203                 Poll::Pending
204             }
205         }
206     }
207 
transmit_message( &mut self, cx: &mut Context<'_>, message: OutputMessage, ) -> Poll<Result<(), DispatchErrorKind>>208     fn transmit_message(
209         &mut self,
210         cx: &mut Context<'_>,
211         message: OutputMessage,
212     ) -> Poll<Result<(), DispatchErrorKind>> {
213         let mut task = {
214             let sender = self.resp_tx.clone();
215             let ft = async move { sender.send(message).await };
216             Box::pin(ft)
217         };
218 
219         match task.as_mut().poll(cx) {
220             Poll::Ready(Ok(_)) => Poll::Ready(Ok(())),
221             // The current coroutine sending the request exited prematurely.
222             Poll::Ready(Err(_)) => Poll::Ready(Err(DispatchErrorKind::ChannelClosed)),
223             Poll::Pending => {
224                 self.state = DecodeState::Send;
225                 self.curr_message = Some(task);
226                 Poll::Pending
227             }
228         }
229     }
230 
update_settings(&mut self, frame: &Frame) -> Result<(), H2Error>231     fn update_settings(&mut self, frame: &Frame) -> Result<(), H2Error> {
232         if let Payload::Settings(_settings) = frame.payload() {
233             if frame.flags().is_ack() {
234                 self.update_decoder_settings()?;
235             }
236         }
237         Ok(())
238     }
239 
update_decoder_settings(&mut self) -> Result<(), H2Error>240     fn update_decoder_settings(&mut self) -> Result<(), H2Error> {
241         let connection = self.settings.lock().unwrap();
242         match &connection.settings {
243             SettingsState::Acknowledging(settings) => {
244                 for setting in settings.get_settings() {
245                     if let Setting::MaxHeaderListSize(size) = setting {
246                         self.decoder.set_max_header_list_size(*size as usize);
247                     }
248                     if let Setting::MaxFrameSize(size) = setting {
249                         self.decoder.set_max_frame_size(*size)?;
250                     }
251                 }
252                 Ok(())
253             }
254             SettingsState::Synced => Err(H2Error::ConnectionError(ErrorCode::ConnectError)),
255         }
256     }
257 }
258