1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 use std::cmp::min;
15 use std::ops::Deref;
16 use std::pin::Pin;
17 use std::sync::atomic::Ordering;
18 use std::task::{Context, Poll};
19 
20 use ylong_http::error::HttpError;
21 use ylong_http::h2;
22 use ylong_http::h2::{ErrorCode, Frame, FrameFlags, H2Error, Payload, PseudoHeaders};
23 use ylong_http::headers::Headers;
24 use ylong_http::request::uri::Scheme;
25 use ylong_http::request::RequestPart;
26 use ylong_http::response::status::StatusCode;
27 use ylong_http::response::ResponsePart;
28 
29 use crate::async_impl::conn::StreamData;
30 use crate::async_impl::request::Message;
31 use crate::async_impl::{HttpBody, Response};
32 use crate::error::{ErrorKind, HttpClientError};
33 use crate::runtime::{AsyncRead, ReadBuf};
34 use crate::util::dispatcher::http2::Http2Conn;
35 use crate::util::h2::{BodyDataRef, RequestWrapper};
36 use crate::util::normalizer::BodyLengthParser;
37 
/// Flag byte with no bits set; used as the starting value when building the
/// flags of an outgoing frame.
const UNUSED_FLAG: u8 = 0x0;
39 
request<S>( mut conn: Http2Conn<S>, mut message: Message, ) -> Result<Response, HttpClientError> where S: Sync + Send + Unpin + 'static,40 pub(crate) async fn request<S>(
41     mut conn: Http2Conn<S>,
42     mut message: Message,
43 ) -> Result<Response, HttpClientError>
44 where
45     S: Sync + Send + Unpin + 'static,
46 {
47     message
48         .interceptor
49         .intercept_request(message.request.ref_mut())?;
50     let part = message.request.ref_mut().part().clone();
51 
52     // TODO Implement trailer.
53     let headers = build_headers_frame(conn.id, part, false)
54         .map_err(|e| HttpClientError::from_error(ErrorKind::Request, e))?;
55     let data = BodyDataRef::new(message.request.clone());
56     let stream = RequestWrapper {
57         header: headers,
58         data,
59     };
60     conn.send_frame_to_controller(stream)?;
61     let frame = conn.receiver.recv().await?;
62     frame_2_response(conn, frame, message)
63 }
64 
frame_2_response<S>( conn: Http2Conn<S>, headers_frame: Frame, mut message: Message, ) -> Result<Response, HttpClientError> where S: Sync + Send + Unpin + 'static,65 fn frame_2_response<S>(
66     conn: Http2Conn<S>,
67     headers_frame: Frame,
68     mut message: Message,
69 ) -> Result<Response, HttpClientError>
70 where
71     S: Sync + Send + Unpin + 'static,
72 {
73     let part = match headers_frame.payload() {
74         Payload::Headers(headers) => {
75             let (pseudo, fields) = headers.parts();
76             let status_code = match pseudo.status() {
77                 Some(status) => StatusCode::from_bytes(status.as_bytes())
78                     .map_err(|e| HttpClientError::from_error(ErrorKind::Request, e))?,
79                 None => {
80                     return Err(HttpClientError::from_error(
81                         ErrorKind::Request,
82                         HttpError::from(H2Error::StreamError(conn.id, ErrorCode::ProtocolError)),
83                     ));
84                 }
85             };
86             ResponsePart {
87                 version: ylong_http::version::Version::HTTP2,
88                 status: status_code,
89                 headers: fields.clone(),
90             }
91         }
92         Payload::RstStream(reset) => {
93             return Err(HttpClientError::from_error(
94                 ErrorKind::Request,
95                 HttpError::from(H2Error::StreamError(
96                     conn.id,
97                     ErrorCode::try_from(reset.error_code()).unwrap_or(ErrorCode::ProtocolError),
98                 )),
99             ));
100         }
101         _ => {
102             return Err(HttpClientError::from_error(
103                 ErrorKind::Request,
104                 HttpError::from(H2Error::StreamError(conn.id, ErrorCode::ProtocolError)),
105             ));
106         }
107     };
108 
109     let text_io = TextIo::new(conn);
110     // TODO Can http2 have no content-length header field and rely only on the
111     // end_stream flag? flag has a Body
112     let length = match BodyLengthParser::new(message.request.ref_mut().method(), &part).parse() {
113         Ok(length) => length,
114         Err(e) => {
115             return Err(e);
116         }
117     };
118     let body = HttpBody::new(message.interceptor, length, Box::new(text_io), &[0u8; 0])?;
119 
120     Ok(Response::new(
121         ylong_http::response::Response::from_raw_parts(part, body),
122     ))
123 }
124 
build_headers_frame( id: u32, mut part: RequestPart, is_end_stream: bool, ) -> Result<Frame, HttpError>125 pub(crate) fn build_headers_frame(
126     id: u32,
127     mut part: RequestPart,
128     is_end_stream: bool,
129 ) -> Result<Frame, HttpError> {
130     remove_connection_specific_headers(&mut part.headers)?;
131     let pseudo = build_pseudo_headers(&mut part)?;
132     let mut header_part = h2::Parts::new();
133     header_part.set_header_lines(part.headers);
134     header_part.set_pseudo(pseudo);
135     let headers_payload = h2::Headers::new(header_part);
136 
137     let mut flag = FrameFlags::new(UNUSED_FLAG);
138     flag.set_end_headers(true);
139     if is_end_stream {
140         flag.set_end_stream(true);
141     }
142     Ok(Frame::new(
143         id as usize,
144         flag,
145         Payload::Headers(headers_payload),
146     ))
147 }
148 
149 // Illegal headers validation in http2.
150 // [`Connection-Specific Headers`] implementation.
151 //
152 // [`Connection-Specific Headers`]: https://www.rfc-editor.org/rfc/rfc9113.html#name-connection-specific-header-
remove_connection_specific_headers(headers: &mut Headers) -> Result<(), HttpError>153 fn remove_connection_specific_headers(headers: &mut Headers) -> Result<(), HttpError> {
154     const CONNECTION_SPECIFIC_HEADERS: &[&str; 5] = &[
155         "connection",
156         "keep-alive",
157         "proxy-connection",
158         "upgrade",
159         "transfer-encoding",
160     ];
161     for specific_header in CONNECTION_SPECIFIC_HEADERS.iter() {
162         headers.remove(*specific_header);
163     }
164 
165     if let Some(te_ref) = headers.get("te") {
166         let te = te_ref.to_string()?;
167         if te.as_str() != "trailers" {
168             headers.remove("te");
169         }
170     }
171     Ok(())
172 }
173 
build_pseudo_headers(request_part: &mut RequestPart) -> Result<PseudoHeaders, HttpError>174 fn build_pseudo_headers(request_part: &mut RequestPart) -> Result<PseudoHeaders, HttpError> {
175     let mut pseudo = PseudoHeaders::default();
176     match request_part.uri.scheme() {
177         Some(scheme) => {
178             pseudo.set_scheme(Some(String::from(scheme.as_str())));
179         }
180         None => pseudo.set_scheme(Some(String::from(Scheme::HTTP.as_str()))),
181     }
182     pseudo.set_method(Some(String::from(request_part.method.as_str())));
183     pseudo.set_path(
184         request_part
185             .uri
186             .path_and_query()
187             .or_else(|| Some(String::from("/"))),
188     );
189     let host = request_part
190         .headers
191         .remove("host")
192         .and_then(|auth| auth.to_string().ok());
193     pseudo.set_authority(host);
194     Ok(pseudo)
195 }
196 
/// Reader that exposes the response body frames delivered through an
/// `Http2Conn` as an `AsyncRead` byte stream.
struct TextIo<S> {
    /// Connection handle: frames are polled from `handle.receiver`, and
    /// `shutdown` sets `handle.io_shutdown`.
    pub(crate) handle: Http2Conn<S>,
    /// Number of bytes of the DATA frame stored in `remain` that have already
    /// been copied to a caller's buffer.
    pub(crate) offset: usize,
    /// Frame kept across reads: either a DATA frame that did not fit into the
    /// caller's buffer, or a HEADERS frame received while reading the body.
    pub(crate) remain: Option<Frame>,
    /// Set once a DATA frame flagged END_STREAM has been fully consumed or a
    /// RST_STREAM without error has been received; `poll_read` then returns
    /// immediately without filling the buffer.
    pub(crate) is_closed: bool,
}
203 
/// Wrapper around a borrowed `ReadBuf` that gives the body reader a single
/// `append_slice` method regardless of which runtime feature is enabled.
struct HttpReadBuf<'a, 'b> {
    /// The caller-supplied read buffer that body bytes are appended to.
    buf: &'a mut ReadBuf<'b>,
}
207 
208 impl<'a, 'b> HttpReadBuf<'a, 'b> {
append_slice(&mut self, buf: &[u8])209     pub(crate) fn append_slice(&mut self, buf: &[u8]) {
210         #[cfg(feature = "ylong_base")]
211         self.buf.append(buf);
212 
213         #[cfg(feature = "tokio_base")]
214         self.buf.put_slice(buf);
215     }
216 }
217 
218 impl<'a, 'b> Deref for HttpReadBuf<'a, 'b> {
219     type Target = ReadBuf<'b>;
220 
deref(&self) -> &Self::Target221     fn deref(&self) -> &Self::Target {
222         self.buf
223     }
224 }
225 
226 impl<S> TextIo<S>
227 where
228     S: Sync + Send + Unpin + 'static,
229 {
new(handle: Http2Conn<S>) -> Self230     pub(crate) fn new(handle: Http2Conn<S>) -> Self {
231         Self {
232             handle,
233             offset: 0,
234             remain: None,
235             is_closed: false,
236         }
237     }
238 
match_channel_message( poll_result: Poll<Frame>, text_io: &mut TextIo<S>, buf: &mut HttpReadBuf, ) -> Option<Poll<std::io::Result<()>>>239     fn match_channel_message(
240         poll_result: Poll<Frame>,
241         text_io: &mut TextIo<S>,
242         buf: &mut HttpReadBuf,
243     ) -> Option<Poll<std::io::Result<()>>> {
244         match poll_result {
245             Poll::Ready(frame) => match frame.payload() {
246                 Payload::Headers(_) => {
247                     text_io.remain = Some(frame);
248                     text_io.offset = 0;
249                     Some(Poll::Ready(Ok(())))
250                 }
251                 Payload::Data(data) => {
252                     let data = data.data();
253                     let unfilled_len = buf.remaining();
254                     let data_len = data.len();
255                     let fill_len = min(data_len, unfilled_len);
256                     if unfilled_len < data_len {
257                         buf.append_slice(&data[..fill_len]);
258                         text_io.offset += fill_len;
259                         text_io.remain = Some(frame);
260                         Some(Poll::Ready(Ok(())))
261                     } else {
262                         buf.append_slice(&data[..fill_len]);
263                         Self::end_read(text_io, frame.flags().is_end_stream(), data_len)
264                     }
265                 }
266                 Payload::RstStream(reset) => {
267                     if reset.is_no_error() {
268                         text_io.is_closed = true;
269                         Some(Poll::Ready(Ok(())))
270                     } else {
271                         Some(Poll::Ready(Err(std::io::Error::new(
272                             std::io::ErrorKind::Other,
273                             HttpError::from(H2Error::ConnectionError(ErrorCode::ProtocolError)),
274                         ))))
275                     }
276                 }
277                 _ => Some(Poll::Ready(Err(std::io::Error::new(
278                     std::io::ErrorKind::Other,
279                     HttpError::from(H2Error::ConnectionError(ErrorCode::ProtocolError)),
280                 )))),
281             },
282             Poll::Pending => Some(Poll::Pending),
283         }
284     }
285 
end_read( text_io: &mut TextIo<S>, end_stream: bool, data_len: usize, ) -> Option<Poll<std::io::Result<()>>>286     fn end_read(
287         text_io: &mut TextIo<S>,
288         end_stream: bool,
289         data_len: usize,
290     ) -> Option<Poll<std::io::Result<()>>> {
291         text_io.offset = 0;
292         text_io.remain = None;
293         if end_stream {
294             text_io.is_closed = true;
295             Some(Poll::Ready(Ok(())))
296         } else if data_len == 0 {
297             // no data read and is not end stream.
298             None
299         } else {
300             Some(Poll::Ready(Ok(())))
301         }
302     }
303 
read_remaining_data( text_io: &mut TextIo<S>, buf: &mut HttpReadBuf, ) -> Option<Poll<std::io::Result<()>>>304     fn read_remaining_data(
305         text_io: &mut TextIo<S>,
306         buf: &mut HttpReadBuf,
307     ) -> Option<Poll<std::io::Result<()>>> {
308         if let Some(frame) = &text_io.remain {
309             return match frame.payload() {
310                 Payload::Headers(_) => Some(Poll::Ready(Ok(()))),
311                 Payload::Data(data) => {
312                     let data = data.data();
313                     let unfilled_len = buf.remaining();
314                     let data_len = data.len() - text_io.offset;
315                     let fill_len = min(unfilled_len, data_len);
316                     // The peripheral function already ensures that the remaing of buf will not be
317                     // 0.
318                     if unfilled_len < data_len {
319                         buf.append_slice(&data[text_io.offset..text_io.offset + fill_len]);
320                         text_io.offset += fill_len;
321                         Some(Poll::Ready(Ok(())))
322                     } else {
323                         buf.append_slice(&data[text_io.offset..text_io.offset + fill_len]);
324                         Self::end_read(text_io, frame.flags().is_end_stream(), data_len)
325                     }
326                 }
327                 _ => Some(Poll::Ready(Err(std::io::Error::new(
328                     std::io::ErrorKind::Other,
329                     HttpError::from(H2Error::ConnectionError(ErrorCode::ProtocolError)),
330                 )))),
331             };
332         }
333         None
334     }
335 }
336 
337 impl<S: Sync + Send + Unpin + 'static> StreamData for TextIo<S> {
shutdown(&self)338     fn shutdown(&self) {
339         self.handle.io_shutdown.store(true, Ordering::Release);
340     }
341 }
342 
343 impl<S: Sync + Send + Unpin + 'static> AsyncRead for TextIo<S> {
poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<std::io::Result<()>>344     fn poll_read(
345         self: Pin<&mut Self>,
346         cx: &mut Context<'_>,
347         buf: &mut ReadBuf<'_>,
348     ) -> Poll<std::io::Result<()>> {
349         let text_io = self.get_mut();
350         let mut buf = HttpReadBuf { buf };
351 
352         if buf.remaining() == 0 || text_io.is_closed {
353             return Poll::Ready(Ok(()));
354         }
355         while buf.remaining() != 0 {
356             if let Some(result) = Self::read_remaining_data(text_io, &mut buf) {
357                 return result;
358             }
359 
360             let poll_result = text_io
361                 .handle
362                 .receiver
363                 .poll_recv(cx)
364                 .map_err(|_e| std::io::Error::from(std::io::ErrorKind::Other))?;
365 
366             if let Some(result) = Self::match_channel_message(poll_result, text_io, &mut buf) {
367                 return result;
368             }
369         }
370         Poll::Ready(Ok(()))
371     }
372 }
373 
// Unit tests for HEADERS frame construction and for reading a response body
// through `TextIo`.
#[cfg(feature = "http2")]
#[cfg(test)]
mod ut_http2 {
    use ylong_http::body::TextBody;
    use ylong_http::h2::Payload;
    use ylong_http::request::RequestBuilder;

    use crate::async_impl::conn::http2::build_headers_frame;

    // Builds a request from a method, URI, version, any number of headers and
    // a text body; panics if the request cannot be built.
    macro_rules! build_request {
        (
            Request: {
                Method: $method: expr,
                Uri: $uri:expr,
                Version: $version: expr,
                $(
                    Header: $req_n: expr, $req_v: expr,
                )*
                Body: $req_body: expr,
            }
        ) => {
            RequestBuilder::new()
                .method($method)
                .url($uri)
                .version($version)
                $(.header($req_n, $req_v))*
                .body(TextBody::from_bytes($req_body.as_bytes()))
                .expect("Request build failed")
        }
    }

    /// UT for `build_headers_frame`.
    ///
    /// # Brief
    /// 1. Builds a GET request with `te` and `host` headers.
    /// 2. Builds a HEADERS frame without and with the end-stream flag.
    /// 3. Checks the frame flags, the stream id and the pseudo headers.
    #[test]
    fn ut_http2_build_headers_frame() {
        let request = build_request!(
            Request: {
            Method: "GET",
            Uri: "http://127.0.0.1:0/data",
            Version: "HTTP/2.0",
            Header: "te", "trailers",
            Header: "host", "127.0.0.1:0",
            Body: "Hi",
        }
        );
        let frame = build_headers_frame(1, request.part().clone(), false).unwrap();
        assert_eq!(frame.flags().bits(), 0x4);
        let frame = build_headers_frame(1, request.part().clone(), true).unwrap();
        assert_eq!(frame.stream_id(), 1);
        assert_eq!(frame.flags().bits(), 0x5);
        if let Payload::Headers(headers) = frame.payload() {
            let (pseudo, _headers) = headers.parts();
            assert_eq!(pseudo.status(), None);
            assert_eq!(pseudo.scheme().unwrap(), "http");
            assert_eq!(pseudo.method().unwrap(), "GET");
            assert_eq!(pseudo.authority().unwrap(), "127.0.0.1:0");
            assert_eq!(pseudo.path().unwrap(), "/data")
        } else {
            panic!("Unexpected frame type")
        }
    }

    /// UT for ensuring that reading the response body (data frames) ends
    /// normally.
    ///
    /// # Brief
    /// 1. Creates three data frames: one larger than the buf, one smaller
    ///    than the buf, and a final one exactly as large as the buf that
    ///    carries the end-stream flag.
    /// 2. The response body data is read from TextIo using a buf of 10 bytes.
    /// 3. The body is read completely, and its size equals the total size of
    ///    the frames.
    /// 4. Checks the result.
    #[cfg(feature = "ylong_base")]
    #[test]
    fn ut_http2_body_poll_read() {
        use std::pin::Pin;
        use std::sync::atomic::AtomicBool;
        use std::sync::Arc;

        use ylong_http::h2::{Data, Frame, FrameFlags};
        use ylong_runtime::futures::poll_fn;
        use ylong_runtime::io::{AsyncRead, ReadBuf};

        use crate::async_impl::conn::http2::TextIo;
        use crate::util::dispatcher::http2::Http2Conn;

        let (resp_tx, resp_rx) = ylong_runtime::sync::mpsc::bounded_channel(20);
        let (req_tx, _req_rx) = crate::runtime::unbounded_channel();
        let shutdown = Arc::new(AtomicBool::new(false));
        let mut conn: Http2Conn<()> = Http2Conn::new(1, 20, shutdown, req_tx);
        conn.receiver.set_receiver(resp_rx);
        let mut text_io = TextIo::new(conn);
        let data_1 = Frame::new(
            1,
            FrameFlags::new(0),
            Payload::Data(Data::new(vec![b'a'; 128])),
        );
        let data_2 = Frame::new(
            1,
            FrameFlags::new(0),
            Payload::Data(Data::new(vec![b'a'; 2])),
        );
        let data_3 = Frame::new(
            1,
            FrameFlags::new(1),
            Payload::Data(Data::new(vec![b'a'; 10])),
        );

        ylong_runtime::block_on(async {
            let _ = resp_tx
                .send(crate::util::dispatcher::http2::RespMessage::Output(data_1))
                .await;
            let _ = resp_tx
                .send(crate::util::dispatcher::http2::RespMessage::Output(data_2))
                .await;
            let _ = resp_tx
                .send(crate::util::dispatcher::http2::RespMessage::Output(data_3))
                .await;
        });

        ylong_runtime::block_on(async {
            let mut buf = [0_u8; 10];
            let mut output_vec = vec![];

            let mut size = buf.len();
            // `output_vec < 1024` in order to be able to exit normally in case of an
            // exception.
            while size != 0 && output_vec.len() < 1024 {
                let mut buffer = ReadBuf::new(buf.as_mut_slice());
                poll_fn(|cx| Pin::new(&mut text_io).poll_read(cx, &mut buffer))
                    .await
                    .unwrap();
                size = buffer.filled_len();
                output_vec.extend_from_slice(&buf[..size]);
            }
            assert_eq!(output_vec.len(), 140);
        })
    }
}
509