// Copyright (c) 2023 Huawei Device Co., Ltd. // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. //! Frame recv coroutine. use std::future::Future; use std::pin::Pin; use std::sync::{Arc, Mutex}; use std::task::{Context, Poll}; use ylong_http::h2::{ ErrorCode, Frame, FrameDecoder, FrameKind, FramesIntoIter, H2Error, Payload, Setting, }; use crate::runtime::{AsyncRead, BoundedSender, ReadBuf, ReadHalf, SendError}; use crate::util::dispatcher::http2::{ DispatchErrorKind, OutputMessage, SettingsState, SettingsSync, }; pub(crate) type OutputSendFut = Pin>> + Send + Sync>>; #[derive(Copy, Clone)] enum DecodeState { Read, Send, Exit(DispatchErrorKind), } pub(crate) struct RecvData { decoder: FrameDecoder, settings: Arc>, reader: ReadHalf, state: DecodeState, next_state: DecodeState, resp_tx: BoundedSender, curr_message: Option, pending_iter: Option, } impl Future for RecvData { type Output = Result<(), DispatchErrorKind>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let receiver = self.get_mut(); receiver.poll_read_frame(cx) } } impl RecvData { pub(crate) fn new( decoder: FrameDecoder, settings: Arc>, reader: ReadHalf, resp_tx: BoundedSender, ) -> Self { Self { decoder, settings, reader, state: DecodeState::Read, next_state: DecodeState::Read, resp_tx, curr_message: None, pending_iter: None, } } fn poll_read_frame(&mut self, cx: &mut Context<'_>) -> Poll> { let mut buf = [0u8; 1024]; loop { match self.state { DecodeState::Read => { let mut read_buf = ReadBuf::new(&mut buf); match Pin::new(&mut self.reader).poll_read(cx, &mut read_buf) { Poll::Ready(Err(e)) => { return self.transmit_error(cx, e.into()); } Poll::Ready(Ok(())) => {} Poll::Pending => { return Poll::Pending; } } let read = read_buf.filled().len(); if read == 0 { return self.transmit_error(cx, DispatchErrorKind::Disconnect); } match self.decoder.decode(&buf[..read]) { Ok(frames) => match self.poll_iterator_frames(cx, frames.into_iter()) { Poll::Ready(Ok(_)) => {} Poll::Ready(Err(e)) => { return Poll::Ready(Err(e)); } Poll::Pending => { self.next_state = DecodeState::Read; } }, Err(e) => { match self.transmit_message(cx, OutputMessage::OutputExit(e.into())) { Poll::Ready(Err(_)) => { return Poll::Ready(Err(DispatchErrorKind::ChannelClosed)) } Poll::Ready(Ok(_)) => {} Poll::Pending => { self.next_state = DecodeState::Read; return Poll::Pending; } } } } } DecodeState::Send => { match self.poll_blocked_task(cx) { Poll::Ready(Ok(_)) => { self.state = self.next_state; // Reset next state. self.next_state = DecodeState::Read; } Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), Poll::Pending => return Poll::Pending, } } DecodeState::Exit(e) => { return Poll::Ready(Err(e)); } } } } fn poll_blocked_task(&mut self, cx: &mut Context<'_>) -> Poll> { if let Some(mut task) = self.curr_message.take() { match task.as_mut().poll(cx) { Poll::Ready(Ok(_)) => {} Poll::Ready(Err(_)) => { return Poll::Ready(Err(DispatchErrorKind::ChannelClosed)); } Poll::Pending => { self.curr_message = Some(task); return Poll::Pending; } } } if let Some(iter) = self.pending_iter.take() { return self.poll_iterator_frames(cx, iter); } Poll::Ready(Ok(())) } fn poll_iterator_frames( &mut self, cx: &mut Context<'_>, mut iter: FramesIntoIter, ) -> Poll> { while let Some(kind) = iter.next() { match kind { FrameKind::Complete(frame) => { // TODO Whether to continue processing the remaining frames after connection // error occurs in the Settings frame. let message = if let Err(e) = self.update_settings(&frame) { OutputMessage::OutputExit(DispatchErrorKind::H2(e)) } else { OutputMessage::Output(frame) }; match self.transmit_message(cx, message) { Poll::Ready(Ok(_)) => {} Poll::Ready(Err(e)) => { return Poll::Ready(Err(e)); } Poll::Pending => { self.pending_iter = Some(iter); return Poll::Pending; } } } FrameKind::Partial => {} } } Poll::Ready(Ok(())) } fn transmit_error( &mut self, cx: &mut Context<'_>, exit_err: DispatchErrorKind, ) -> Poll> { match self.transmit_message(cx, OutputMessage::OutputExit(exit_err)) { Poll::Ready(_) => Poll::Ready(Err(exit_err)), Poll::Pending => { self.next_state = DecodeState::Exit(exit_err); Poll::Pending } } } fn transmit_message( &mut self, cx: &mut Context<'_>, message: OutputMessage, ) -> Poll> { let mut task = { let sender = self.resp_tx.clone(); let ft = async move { sender.send(message).await }; Box::pin(ft) }; match task.as_mut().poll(cx) { Poll::Ready(Ok(_)) => Poll::Ready(Ok(())), // The current coroutine sending the request exited prematurely. Poll::Ready(Err(_)) => Poll::Ready(Err(DispatchErrorKind::ChannelClosed)), Poll::Pending => { self.state = DecodeState::Send; self.curr_message = Some(task); Poll::Pending } } } fn update_settings(&mut self, frame: &Frame) -> Result<(), H2Error> { if let Payload::Settings(_settings) = frame.payload() { if frame.flags().is_ack() { self.update_decoder_settings()?; } } Ok(()) } fn update_decoder_settings(&mut self) -> Result<(), H2Error> { let connection = self.settings.lock().unwrap(); match &connection.settings { SettingsState::Acknowledging(settings) => { for setting in settings.get_settings() { if let Setting::MaxHeaderListSize(size) = setting { self.decoder.set_max_header_list_size(*size as usize); } if let Setting::MaxFrameSize(size) = setting { self.decoder.set_max_frame_size(*size)?; } } Ok(()) } SettingsState::Synced => Err(H2Error::ConnectionError(ErrorCode::ConnectError)), } } }