// Copyright (c) 2023 Huawei Device Co., Ltd. // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. use std::cmp::min; use std::ops::Deref; use std::pin::Pin; use std::sync::atomic::Ordering; use std::task::{Context, Poll}; use ylong_http::error::HttpError; use ylong_http::h2; use ylong_http::h2::{ErrorCode, Frame, FrameFlags, H2Error, Payload, PseudoHeaders}; use ylong_http::headers::Headers; use ylong_http::request::uri::Scheme; use ylong_http::request::RequestPart; use ylong_http::response::status::StatusCode; use ylong_http::response::ResponsePart; use crate::async_impl::conn::StreamData; use crate::async_impl::request::Message; use crate::async_impl::{HttpBody, Response}; use crate::error::{ErrorKind, HttpClientError}; use crate::runtime::{AsyncRead, ReadBuf}; use crate::util::dispatcher::http2::Http2Conn; use crate::util::h2::{BodyDataRef, RequestWrapper}; use crate::util::normalizer::BodyLengthParser; const UNUSED_FLAG: u8 = 0x0; pub(crate) async fn request( mut conn: Http2Conn, mut message: Message, ) -> Result where S: Sync + Send + Unpin + 'static, { message .interceptor .intercept_request(message.request.ref_mut())?; let part = message.request.ref_mut().part().clone(); // TODO Implement trailer. let headers = build_headers_frame(conn.id, part, false) .map_err(|e| HttpClientError::from_error(ErrorKind::Request, e))?; let data = BodyDataRef::new(message.request.clone()); let stream = RequestWrapper { header: headers, data, }; conn.send_frame_to_controller(stream)?; let frame = conn.receiver.recv().await?; frame_2_response(conn, frame, message) } fn frame_2_response( conn: Http2Conn, headers_frame: Frame, mut message: Message, ) -> Result where S: Sync + Send + Unpin + 'static, { let part = match headers_frame.payload() { Payload::Headers(headers) => { let (pseudo, fields) = headers.parts(); let status_code = match pseudo.status() { Some(status) => StatusCode::from_bytes(status.as_bytes()) .map_err(|e| HttpClientError::from_error(ErrorKind::Request, e))?, None => { return Err(HttpClientError::from_error( ErrorKind::Request, HttpError::from(H2Error::StreamError(conn.id, ErrorCode::ProtocolError)), )); } }; ResponsePart { version: ylong_http::version::Version::HTTP2, status: status_code, headers: fields.clone(), } } Payload::RstStream(reset) => { return Err(HttpClientError::from_error( ErrorKind::Request, HttpError::from(H2Error::StreamError( conn.id, ErrorCode::try_from(reset.error_code()).unwrap_or(ErrorCode::ProtocolError), )), )); } _ => { return Err(HttpClientError::from_error( ErrorKind::Request, HttpError::from(H2Error::StreamError(conn.id, ErrorCode::ProtocolError)), )); } }; let text_io = TextIo::new(conn); // TODO Can http2 have no content-length header field and rely only on the // end_stream flag? flag has a Body let length = match BodyLengthParser::new(message.request.ref_mut().method(), &part).parse() { Ok(length) => length, Err(e) => { return Err(e); } }; let body = HttpBody::new(message.interceptor, length, Box::new(text_io), &[0u8; 0])?; Ok(Response::new( ylong_http::response::Response::from_raw_parts(part, body), )) } pub(crate) fn build_headers_frame( id: u32, mut part: RequestPart, is_end_stream: bool, ) -> Result { remove_connection_specific_headers(&mut part.headers)?; let pseudo = build_pseudo_headers(&mut part)?; let mut header_part = h2::Parts::new(); header_part.set_header_lines(part.headers); header_part.set_pseudo(pseudo); let headers_payload = h2::Headers::new(header_part); let mut flag = FrameFlags::new(UNUSED_FLAG); flag.set_end_headers(true); if is_end_stream { flag.set_end_stream(true); } Ok(Frame::new( id as usize, flag, Payload::Headers(headers_payload), )) } // Illegal headers validation in http2. // [`Connection-Specific Headers`] implementation. // // [`Connection-Specific Headers`]: https://www.rfc-editor.org/rfc/rfc9113.html#name-connection-specific-header- fn remove_connection_specific_headers(headers: &mut Headers) -> Result<(), HttpError> { const CONNECTION_SPECIFIC_HEADERS: &[&str; 5] = &[ "connection", "keep-alive", "proxy-connection", "upgrade", "transfer-encoding", ]; for specific_header in CONNECTION_SPECIFIC_HEADERS.iter() { headers.remove(*specific_header); } if let Some(te_ref) = headers.get("te") { let te = te_ref.to_string()?; if te.as_str() != "trailers" { headers.remove("te"); } } Ok(()) } fn build_pseudo_headers(request_part: &mut RequestPart) -> Result { let mut pseudo = PseudoHeaders::default(); match request_part.uri.scheme() { Some(scheme) => { pseudo.set_scheme(Some(String::from(scheme.as_str()))); } None => pseudo.set_scheme(Some(String::from(Scheme::HTTP.as_str()))), } pseudo.set_method(Some(String::from(request_part.method.as_str()))); pseudo.set_path( request_part .uri .path_and_query() .or_else(|| Some(String::from("/"))), ); let host = request_part .headers .remove("host") .and_then(|auth| auth.to_string().ok()); pseudo.set_authority(host); Ok(pseudo) } struct TextIo { pub(crate) handle: Http2Conn, pub(crate) offset: usize, pub(crate) remain: Option, pub(crate) is_closed: bool, } struct HttpReadBuf<'a, 'b> { buf: &'a mut ReadBuf<'b>, } impl<'a, 'b> HttpReadBuf<'a, 'b> { pub(crate) fn append_slice(&mut self, buf: &[u8]) { #[cfg(feature = "ylong_base")] self.buf.append(buf); #[cfg(feature = "tokio_base")] self.buf.put_slice(buf); } } impl<'a, 'b> Deref for HttpReadBuf<'a, 'b> { type Target = ReadBuf<'b>; fn deref(&self) -> &Self::Target { self.buf } } impl TextIo where S: Sync + Send + Unpin + 'static, { pub(crate) fn new(handle: Http2Conn) -> Self { Self { handle, offset: 0, remain: None, is_closed: false, } } fn match_channel_message( poll_result: Poll, text_io: &mut TextIo, buf: &mut HttpReadBuf, ) -> Option>> { match poll_result { Poll::Ready(frame) => match frame.payload() { Payload::Headers(_) => { text_io.remain = Some(frame); text_io.offset = 0; Some(Poll::Ready(Ok(()))) } Payload::Data(data) => { let data = data.data(); let unfilled_len = buf.remaining(); let data_len = data.len(); let fill_len = min(data_len, unfilled_len); if unfilled_len < data_len { buf.append_slice(&data[..fill_len]); text_io.offset += fill_len; text_io.remain = Some(frame); Some(Poll::Ready(Ok(()))) } else { buf.append_slice(&data[..fill_len]); Self::end_read(text_io, frame.flags().is_end_stream(), data_len) } } Payload::RstStream(reset) => { if reset.is_no_error() { text_io.is_closed = true; Some(Poll::Ready(Ok(()))) } else { Some(Poll::Ready(Err(std::io::Error::new( std::io::ErrorKind::Other, HttpError::from(H2Error::ConnectionError(ErrorCode::ProtocolError)), )))) } } _ => Some(Poll::Ready(Err(std::io::Error::new( std::io::ErrorKind::Other, HttpError::from(H2Error::ConnectionError(ErrorCode::ProtocolError)), )))), }, Poll::Pending => Some(Poll::Pending), } } fn end_read( text_io: &mut TextIo, end_stream: bool, data_len: usize, ) -> Option>> { text_io.offset = 0; text_io.remain = None; if end_stream { text_io.is_closed = true; Some(Poll::Ready(Ok(()))) } else if data_len == 0 { // no data read and is not end stream. None } else { Some(Poll::Ready(Ok(()))) } } fn read_remaining_data( text_io: &mut TextIo, buf: &mut HttpReadBuf, ) -> Option>> { if let Some(frame) = &text_io.remain { return match frame.payload() { Payload::Headers(_) => Some(Poll::Ready(Ok(()))), Payload::Data(data) => { let data = data.data(); let unfilled_len = buf.remaining(); let data_len = data.len() - text_io.offset; let fill_len = min(unfilled_len, data_len); // The peripheral function already ensures that the remaing of buf will not be // 0. if unfilled_len < data_len { buf.append_slice(&data[text_io.offset..text_io.offset + fill_len]); text_io.offset += fill_len; Some(Poll::Ready(Ok(()))) } else { buf.append_slice(&data[text_io.offset..text_io.offset + fill_len]); Self::end_read(text_io, frame.flags().is_end_stream(), data_len) } } _ => Some(Poll::Ready(Err(std::io::Error::new( std::io::ErrorKind::Other, HttpError::from(H2Error::ConnectionError(ErrorCode::ProtocolError)), )))), }; } None } } impl StreamData for TextIo { fn shutdown(&self) { self.handle.io_shutdown.store(true, Ordering::Release); } } impl AsyncRead for TextIo { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { let text_io = self.get_mut(); let mut buf = HttpReadBuf { buf }; if buf.remaining() == 0 || text_io.is_closed { return Poll::Ready(Ok(())); } while buf.remaining() != 0 { if let Some(result) = Self::read_remaining_data(text_io, &mut buf) { return result; } let poll_result = text_io .handle .receiver .poll_recv(cx) .map_err(|_e| std::io::Error::from(std::io::ErrorKind::Other))?; if let Some(result) = Self::match_channel_message(poll_result, text_io, &mut buf) { return result; } } Poll::Ready(Ok(())) } } #[cfg(feature = "http2")] #[cfg(test)] mod ut_http2 { use ylong_http::body::TextBody; use ylong_http::h2::Payload; use ylong_http::request::RequestBuilder; use crate::async_impl::conn::http2::build_headers_frame; macro_rules! build_request { ( Request: { Method: $method: expr, Uri: $uri:expr, Version: $version: expr, $( Header: $req_n: expr, $req_v: expr, )* Body: $req_body: expr, } ) => { RequestBuilder::new() .method($method) .url($uri) .version($version) $(.header($req_n, $req_v))* .body(TextBody::from_bytes($req_body.as_bytes())) .expect("Request build failed") } } #[test] fn ut_http2_build_headers_frame() { let request = build_request!( Request: { Method: "GET", Uri: "http://127.0.0.1:0/data", Version: "HTTP/2.0", Header: "te", "trailers", Header: "host", "127.0.0.1:0", Body: "Hi", } ); let frame = build_headers_frame(1, request.part().clone(), false).unwrap(); assert_eq!(frame.flags().bits(), 0x4); let frame = build_headers_frame(1, request.part().clone(), true).unwrap(); assert_eq!(frame.stream_id(), 1); assert_eq!(frame.flags().bits(), 0x5); if let Payload::Headers(headers) = frame.payload() { let (pseudo, _headers) = headers.parts(); assert_eq!(pseudo.status(), None); assert_eq!(pseudo.scheme().unwrap(), "http"); assert_eq!(pseudo.method().unwrap(), "GET"); assert_eq!(pseudo.authority().unwrap(), "127.0.0.1:0"); assert_eq!(pseudo.path().unwrap(), "/data") } else { panic!("Unexpected frame type") } } /// UT for ensure that the response body(data frame) can read ends normally. /// /// # Brief /// 1. Creates three data frames, one greater than buf, one less than buf, /// and the last one equal to and finished with buf. /// 2. The response body data is read from TextIo using a buf of 10 bytes. /// 3. The body is all read, and the size is the same as the default. /// 5. Checks that result. #[cfg(feature = "ylong_base")] #[test] fn ut_http2_body_poll_read() { use std::pin::Pin; use std::sync::atomic::AtomicBool; use std::sync::Arc; use ylong_http::h2::{Data, Frame, FrameFlags}; use ylong_runtime::futures::poll_fn; use ylong_runtime::io::{AsyncRead, ReadBuf}; use crate::async_impl::conn::http2::TextIo; use crate::util::dispatcher::http2::Http2Conn; let (resp_tx, resp_rx) = ylong_runtime::sync::mpsc::bounded_channel(20); let (req_tx, _req_rx) = crate::runtime::unbounded_channel(); let shutdown = Arc::new(AtomicBool::new(false)); let mut conn: Http2Conn<()> = Http2Conn::new(1, 20, shutdown, req_tx); conn.receiver.set_receiver(resp_rx); let mut text_io = TextIo::new(conn); let data_1 = Frame::new( 1, FrameFlags::new(0), Payload::Data(Data::new(vec![b'a'; 128])), ); let data_2 = Frame::new( 1, FrameFlags::new(0), Payload::Data(Data::new(vec![b'a'; 2])), ); let data_3 = Frame::new( 1, FrameFlags::new(1), Payload::Data(Data::new(vec![b'a'; 10])), ); ylong_runtime::block_on(async { let _ = resp_tx .send(crate::util::dispatcher::http2::RespMessage::Output(data_1)) .await; let _ = resp_tx .send(crate::util::dispatcher::http2::RespMessage::Output(data_2)) .await; let _ = resp_tx .send(crate::util::dispatcher::http2::RespMessage::Output(data_3)) .await; }); ylong_runtime::block_on(async { let mut buf = [0_u8; 10]; let mut output_vec = vec![]; let mut size = buf.len(); // `output_vec < 1024` in order to be able to exit normally in case of an // exception. while size != 0 && output_vec.len() < 1024 { let mut buffer = ReadBuf::new(buf.as_mut_slice()); poll_fn(|cx| Pin::new(&mut text_io).poll_read(cx, &mut buffer)) .await .unwrap(); size = buffer.filled_len(); output_vec.extend_from_slice(&buf[..size]); } assert_eq!(output_vec.len(), 140); }) } }