1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 //! Frame send coroutine.
15 
16 use std::future::Future;
17 use std::pin::Pin;
18 use std::sync::{Arc, Mutex};
19 use std::task::{Context, Poll};
20 
21 use ylong_http::h2::{ErrorCode, Frame, FrameEncoder, H2Error, Payload, Setting, Settings};
22 
23 use crate::runtime::{AsyncWrite, UnboundedReceiver, WriteHalf};
24 use crate::util::dispatcher::http2::{DispatchErrorKind, SettingsState, SettingsSync};
25 
/// Frame-send coroutine of an HTTP/2 connection.
///
/// Receives frames from `req_rx`, encodes them with `encoder` and writes the
/// encoded bytes to `writer`. Driven by polling it as a `Future`; it only
/// resolves when an error occurs.
pub(crate) struct SendData<S> {
    /// Encodes the current frame into bytes. Its header table size and max
    /// frame size are updated when a SETTINGS ACK frame passes through.
    encoder: FrameEncoder,
    /// Settings state shared with the rest of the dispatcher; set to
    /// `Acknowledging` when a non-ACK SETTINGS frame is sent.
    settings: Arc<Mutex<SettingsSync>>,
    /// Write half of the connection stream that encoded bytes go to.
    writer: WriteHalf<S>,
    /// Channel supplying the frames to send.
    req_rx: UnboundedReceiver<Frame>,
    /// Whether the coroutine is waiting for a frame or writing one.
    state: InputState,
    /// Encoded bytes not yet fully accepted by `writer` (kept across `Pending`).
    buf: WriteBuf,
}
34 
/// Current step of the send loop.
enum InputState {
    /// Waiting for the next frame from the request channel.
    RecvFrame,
    /// Encoding the current frame and writing it to the stream.
    WriteFrame,
}
39 
/// Classification of an outgoing frame with respect to SETTINGS handling.
enum SettingState {
    /// The frame is not a SETTINGS frame; it is sent unchanged.
    Not,
    /// A SETTINGS frame without the ACK flag. Carries a copy of the settings
    /// being sent so they can be recorded as awaiting acknowledgement.
    Local(Settings),
    /// A SETTINGS frame with the ACK flag. Its settings have been applied to
    /// the encoder, and a `Settings::ack()` frame is sent in its place.
    Ack,
}
45 
/// Fixed-size staging buffer for encoded frame bytes.
///
/// `buf[start..end]` holds bytes that were encoded but not yet accepted by the
/// writer; `empty` is `true` when nothing is pending.
pub(crate) struct WriteBuf {
    /// Backing storage the encoder writes into.
    buf: [u8; 1024],
    /// One past the last pending byte.
    end: usize,
    /// First pending byte (everything before it has been written).
    start: usize,
    /// `true` when there are no pending bytes.
    empty: bool,
}

impl WriteBuf {
    /// Builds a zero-filled buffer with no pending bytes.
    pub(crate) fn new() -> Self {
        let storage = [0u8; 1024];
        Self {
            empty: true,
            start: 0,
            end: 0,
            buf: storage,
        }
    }

    /// Drops all pending bytes by resetting the cursors.
    ///
    /// The storage itself is not zeroed; only `start`, `end` and `empty` change.
    pub(crate) fn clear(&mut self) {
        self.empty = true;
        self.end = 0;
        self.start = 0;
    }
}
68 
69 impl<S: AsyncWrite + Unpin + Sync + Send + 'static> Future for SendData<S> {
70     type Output = Result<(), DispatchErrorKind>;
71 
poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>72     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
73         let sender = self.get_mut();
74         loop {
75             match sender.state {
76                 InputState::RecvFrame => {
77                     let frame = match sender.poll_recv_frame(cx) {
78                         Poll::Ready(Ok(frame)) => frame,
79                         Poll::Ready(Err(e)) => {
80                             // Errors in the Frame Writer are thrown directly to exit the coroutine.
81                             return Poll::Ready(Err(e));
82                         }
83                         Poll::Pending => return Poll::Pending,
84                     };
85 
86                     let state = sender.update_settings(&frame);
87 
88                     if let SettingState::Local(setting) = &state {
89                         let mut sync = sender.settings.lock().unwrap();
90                         sync.settings = SettingsState::Acknowledging(setting.clone());
91                     }
92 
93                     let frame = if let SettingState::Ack = state {
94                         Settings::ack()
95                     } else {
96                         frame
97                     };
98                     // This error will never happen.
99                     sender.encoder.set_frame(frame).map_err(|_| {
100                         DispatchErrorKind::H2(H2Error::ConnectionError(ErrorCode::IntervalError))
101                     })?;
102                     sender.state = InputState::WriteFrame;
103                 }
104                 InputState::WriteFrame => {
105                     match sender.poll_writer_frame(cx) {
106                         Poll::Ready(Ok(())) => {}
107                         Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
108                         Poll::Pending => return Poll::Pending,
109                     };
110                     sender.state = InputState::RecvFrame;
111                 }
112             }
113         }
114     }
115 }
116 
117 impl<S: AsyncWrite + Unpin + Sync + Send + 'static> SendData<S> {
new( encoder: FrameEncoder, settings: Arc<Mutex<SettingsSync>>, writer: WriteHalf<S>, req_rx: UnboundedReceiver<Frame>, ) -> Self118     pub(crate) fn new(
119         encoder: FrameEncoder,
120         settings: Arc<Mutex<SettingsSync>>,
121         writer: WriteHalf<S>,
122         req_rx: UnboundedReceiver<Frame>,
123     ) -> Self {
124         Self {
125             encoder,
126             settings,
127             writer,
128             req_rx,
129             state: InputState::RecvFrame,
130             buf: WriteBuf::new(),
131         }
132     }
133 
134     // io write interface
poll_writer_frame(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), DispatchErrorKind>>135     fn poll_writer_frame(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), DispatchErrorKind>> {
136         if !self.buf.empty {
137             loop {
138                 match Pin::new(&mut self.writer)
139                     .poll_write(cx, &self.buf.buf[self.buf.start..self.buf.end])
140                     .map_err(|e| DispatchErrorKind::Io(e.kind()))?
141                 {
142                     Poll::Ready(written) => {
143                         self.buf.start += written;
144                         if self.buf.start == self.buf.end {
145                             self.buf.clear();
146                             break;
147                         }
148                     }
149                     Poll::Pending => {
150                         return Poll::Pending;
151                     }
152                 }
153             }
154         }
155 
156         loop {
157             let size = self.encoder.encode(&mut self.buf.buf).map_err(|_| {
158                 DispatchErrorKind::H2(H2Error::ConnectionError(ErrorCode::IntervalError))
159             })?;
160 
161             if size == 0 {
162                 break;
163             }
164             let mut index = 0;
165 
166             loop {
167                 match Pin::new(&mut self.writer)
168                     .poll_write(cx, &self.buf.buf[index..size])
169                     .map_err(|e| DispatchErrorKind::Io(e.kind()))?
170                 {
171                     Poll::Ready(written) => {
172                         index += written;
173                         if index == size {
174                             break;
175                         }
176                     }
177                     Poll::Pending => {
178                         self.buf.start = index;
179                         self.buf.end = size;
180                         self.buf.empty = false;
181                         return Poll::Pending;
182                     }
183                 }
184             }
185         }
186         Poll::Ready(Ok(()))
187     }
188 
189     // io write interface
poll_recv_frame(&mut self, cx: &mut Context<'_>) -> Poll<Result<Frame, DispatchErrorKind>>190     fn poll_recv_frame(&mut self, cx: &mut Context<'_>) -> Poll<Result<Frame, DispatchErrorKind>> {
191         #[cfg(feature = "tokio_base")]
192         match self.req_rx.poll_recv(cx) {
193             Poll::Ready(None) => Poll::Ready(Err(DispatchErrorKind::ChannelClosed)),
194             Poll::Ready(Some(frame)) => Poll::Ready(Ok(frame)),
195             Poll::Pending => Poll::Pending,
196         }
197         #[cfg(feature = "ylong_base")]
198         match self.req_rx.poll_recv(cx) {
199             Poll::Ready(Err(_e)) => Poll::Ready(Err(DispatchErrorKind::ChannelClosed)),
200             Poll::Ready(Ok(frame)) => Poll::Ready(Ok(frame)),
201             Poll::Pending => Poll::Pending,
202         }
203     }
204 
update_settings(&mut self, frame: &Frame) -> SettingState205     fn update_settings(&mut self, frame: &Frame) -> SettingState {
206         let settings = if let Payload::Settings(settings) = frame.payload() {
207             settings
208         } else {
209             return SettingState::Not;
210         };
211         // The ack in Writer is sent from the client to the server to confirm the
212         // Settings of the encoder on the client. The ack in Reader is sent
213         // from the server to the client to confirm the Settings of the decoder on the
214         // client
215         if frame.flags().is_ack() {
216             for setting in settings.get_settings() {
217                 if let Setting::HeaderTableSize(size) = setting {
218                     self.encoder.update_header_table_size(*size as usize);
219                 }
220                 if let Setting::MaxFrameSize(size) = setting {
221                     self.encoder.update_max_frame_size(*size as usize);
222                 }
223             }
224             SettingState::Ack
225         } else {
226             SettingState::Local(settings.clone())
227         }
228     }
229 }
230