1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 #[cfg(feature = "async")]
15 mod async_utils;
16 
17 #[cfg(feature = "sync")]
18 mod sync_utils;
19 
20 use tokio::runtime::Runtime;
21 
/// Declares the handle type that `start_http_server!` returns for a protocol.
///
/// * `HTTPS;` declares `TlsHandle`, which only records the listening port.
/// * `HTTP;` declares `HttpHandle`, which records the port plus the three channels used to
///   coordinate start-up and shutdown with the test. This arm also expands to a
///   `use tokio::sync::mpsc::{Receiver, Sender};` item in the calling module.
macro_rules! define_service_handle {
    (HTTPS;) => {
        /// Handle of a TLS test server.
        pub struct TlsHandle {
            /// Local port the server is bound to.
            pub port: u16,
        }
    };
    (HTTP;) => {
        use tokio::sync::mpsc::{Receiver, Sender};

        /// Handle of a plain-HTTP test server.
        pub struct HttpHandle {
            /// Local port the server is bound to.
            pub port: u16,

            /// Lets the server notify the client when it is up and running.
            pub server_start: Receiver<()>,

            /// Lets the client notify the server when it is ready to shut down.
            pub client_shutdown: Sender<()>,

            /// Lets the server notify the client when it has shut down.
            pub server_shutdown: Receiver<()>,
        }
    };
}
49 
/// Starts `ServerNum` test servers of the given protocol and pushes their handles into
/// `Handles`.
///
/// For each server, a task that runs `start_http_server!` is spawned on `Runtime`. That macro
/// calls `tokio::spawn`, so it needs a runtime context. The calling thread then blocks on the
/// task with `block_on`. The task's output *is* the handle, so the `JoinHandle` returns it
/// directly to this thread.
///
/// In the `HTTP` form, the task also waits for the server's `server_start` notification
/// before returning. A pushed HTTP handle therefore belongs to a server that has signalled
/// start-up.
///
/// # Panics
///
/// Panics if a start-up task fails. In the `HTTP` form, it also panics if the start channel
/// closes unexpectedly.
#[macro_export]
macro_rules! start_server {
    (
        HTTPS;
        ServerNum: $server_num: expr,
        Runtime: $runtime: expr,
        Handles: $handle_vec: expr,
        ServeFnName: $service_fn: ident,
    ) => {{
        for _ in 0..$server_num {
            // The task returns the handle as its output; no side channel is needed.
            let server_task = $runtime.spawn(async move {
                start_http_server!(
                    HTTPS;
                    $service_fn
                )
            });
            let handle = $runtime
                .block_on(server_task)
                .expect("Runtime start server coroutine failed");
            $handle_vec.push(handle);
        }
    }};
    (
        HTTP;
        ServerNum: $server_num: expr,
        Runtime: $runtime: expr,
        Handles: $handle_vec: expr,
        ServeFnName: $service_fn: ident,
    ) => {{
        for _ in 0..$server_num {
            // The task waits until the server reports it is running, then returns the handle
            // as its output.
            let server_task = $runtime.spawn(async move {
                let mut handle = start_http_server!(
                    HTTP;
                    $service_fn
                );
                handle
                    .server_start
                    .recv()
                    .await
                    .expect("Start channel (Server-Half) be closed unexpectedly");
                handle
            });
            let handle = $runtime
                .block_on(server_task)
                .expect("Runtime start server coroutine failed");
            $handle_vec.push(handle);
        }
    }};
}
110 
/// Starts a test server on `127.0.0.1`, on a port assigned by the OS, and returns its handle.
///
/// Both forms call `tokio::spawn`, so they must run inside a Tokio runtime context.
///
/// * `HTTP; $server_fn`:
///   - Serves connections with `service_fn($server_fn)` through `hyper::Server` with graceful
///     shutdown, and returns an `HttpHandle`.
///   - When the graceful-shutdown future is first polled, the server task sends on
///     `server_start`.
///   - It then waits for a message on `client_shutdown`, lets the server finish, and finally
///     sends on `server_shutdown`.
/// * `HTTPS; $service_fn`:
///   - Must be used in an `async` context, because it awaits the bind.
///   - Loads `tests/file/key.pem` and `tests/file/cert.pem`.
///   - Accepts a single TCP connection, completes the TLS handshake (ALPN `http/1.1`), and
///     serves that connection over HTTP/1.1 with `service_fn($service_fn)`.
///   - Returns a `TlsHandle`.
///
/// # Panics
///
/// Panics if binding the listener or reading its local address fails. Failures inside the
/// spawned server tasks panic those tasks.
#[macro_export]
macro_rules! start_http_server {
    (
        HTTP;
        $server_fn: ident
    ) => {{
        use hyper::service::{make_service_fn, service_fn};
        use std::convert::Infallible;
        use tokio::sync::mpsc::channel;

        let (start_tx, start_rx) = channel::<()>(1);
        let (client_tx, mut client_rx) = channel::<()>(1);
        let (server_tx, server_rx) = channel::<()>(1);

        let tcp_listener = std::net::TcpListener::bind("127.0.0.1:0").expect("server bind port failed !");
        let addr = tcp_listener.local_addr().expect("get server local address failed!");
        let port = addr.port();
        let server = hyper::Server::from_tcp(tcp_listener).expect("build hyper server from tcp listener failed !");

        tokio::spawn(async move {
            let make_svc =
                make_service_fn(|_conn| async { Ok::<_, Infallible>(service_fn($server_fn)) });
            server
                .serve(make_svc)
                .with_graceful_shutdown(async {
                    // Announce start-up, then wait for the client's shutdown request.
                    start_tx
                        .send(())
                        .await
                        .expect("Start channel (Client-Half) be closed unexpectedly");
                    client_rx
                        .recv()
                        .await
                        .expect("Client channel (Client-Half) be closed unexpectedly");
                })
                .await
                .expect("Start server failed");
            // The server has stopped; tell the client.
            server_tx
                .send(())
                .await
                .expect("Server channel (Client-Half) be closed unexpectedly");
        });

        HttpHandle {
            port,
            server_start: start_rx,
            client_shutdown: client_tx,
            server_shutdown: server_rx,
        }
    }};
    (
        HTTPS;
        $service_fn: ident
    ) => {{
        // Bind to port 0 so the OS assigns a free port, as the `HTTP` arm does. Probing ports
        // one by one races other tests and never terminates if every probe fails.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
            .await
            .expect("server bind port failed !");
        let port = listener
            .local_addr()
            .expect("get server local address failed!")
            .port();

        tokio::spawn(async move {
            let mut acceptor = openssl::ssl::SslAcceptor::mozilla_intermediate(openssl::ssl::SslMethod::tls())
                .expect("SslAcceptorBuilder error");
            acceptor
                .set_session_id_context(b"test")
                .expect("Set session id error");
            acceptor
                .set_private_key_file("tests/file/key.pem", openssl::ssl::SslFiletype::PEM)
                .expect("Set private key error");
            acceptor
                .set_certificate_chain_file("tests/file/cert.pem")
                .expect("Set cert error");
            acceptor.set_alpn_protos(b"\x08http/1.1").unwrap();
            acceptor.set_alpn_select_callback(|_, client| {
                openssl::ssl::select_next_proto(b"\x08http/1.1", client).ok_or(openssl::ssl::AlpnError::NOACK)
            });

            let acceptor = acceptor.build();

            // Only a single connection is accepted and served.
            let (stream, _) = listener.accept().await.expect("TCP listener accept error");
            let ssl = openssl::ssl::Ssl::new(acceptor.context()).expect("Ssl Error");
            let mut stream = tokio_openssl::SslStream::new(ssl, stream).expect("SslStream Error");
            core::pin::Pin::new(&mut stream).accept().await.unwrap(); // SSL negotiation finished successfully

            hyper::server::conn::Http::new()
                .http1_only(true)
                .http1_keep_alive(true)
                .serve_connection(stream, hyper::service::service_fn($service_fn))
                .await
        });

        TlsHandle { port }
    }};
}
216 
/// Creates a request with `ylong_http::request::RequestBuilder` (enabled by the `sync` feature).
///
/// The URL is formatted as `"{Host}:{Port}"`. Each `Header` pair is added in the order given.
/// The body is a `ylong_http::body::TextBody` built from `Body.as_bytes()`.
///
/// NOTE(review): the `async` feature defines another `#[macro_export]` macro with this same
/// name. Confirm that the test build never enables both features at once.
///
/// # Panics
///
/// Panics with "Request build failed" if building the request fails.
#[macro_export]
#[cfg(feature = "sync")]
macro_rules! ylong_request {
    (
        Request: {
            Method: $method: expr,
            Host: $host: expr,
            Port: $port: expr,
            $(
                Header: $req_n: expr, $req_v: expr,
            )*
            Body: $req_body: expr,
        },
    ) => {
        ylong_http::request::RequestBuilder::new()
            .method($method)
            .url(format!("{}:{}", $host, $port).as_str())
            $(.header($req_n, $req_v))*
            .body(ylong_http::body::TextBody::from_bytes($req_body.as_bytes()))
            .expect("Request build failed")
    };
}
240 
/// Creates a request with `ylong_http_client::async_impl::RequestBuilder` (enabled by the
/// `async` feature).
///
/// The URL is formatted as `"{Host}:{Port}"`. Each `Header` pair is added in the order given.
/// The body is built with `ylong_http_client::async_impl::Body::slice` from `Body.as_bytes()`.
///
/// NOTE(review): the `sync` feature defines another `#[macro_export]` macro with this same
/// name. Confirm that the test build never enables both features at once.
///
/// # Panics
///
/// Panics with "Request build failed" if building the request fails.
#[macro_export]
#[cfg(feature = "async")]
macro_rules! ylong_request {
    (
        Request: {
            Method: $method: expr,
            Host: $host: expr,
            Port: $port: expr,
            $(
                Header: $req_n: expr, $req_v: expr,
            )*
            Body: $req_body: expr,
        },
    ) => {
        ylong_http_client::async_impl::RequestBuilder::new()
             .method($method)
             .url(format!("{}:{}", $host, $port).as_str())
             $(.header($req_n, $req_v))*
             .body(ylong_http_client::async_impl::Body::slice($req_body.as_bytes()))
             .expect("Request build failed")
    };
}
264 
/// Defines the hyper service function `$server_fn_name`.
///
/// The generated function has the signature
/// `async fn $server_fn_name(request: hyper::Request<hyper::Body>)
/// -> Result<hyper::Response<hyper::Body>, std::convert::Infallible>`.
/// It checks each request against the expectations listed for its method and replies with
/// the matching canned response.
///
/// For the entry whose `Method` matches the request, the generated function asserts that:
/// * the request URI is `/`;
/// * the request version, formatted with `{:?}`, equals the entry's `Response` `Version`;
/// * every listed request `Header` is present and has the given value;
/// * the request body equals the entry's request `Body`.
///
/// It then returns an HTTP/1.1 response with the listed `Status`, headers and body. A request
/// with any other method panics with "Unrecognized METHOD !".
///
/// The `ASYNC;` and `SYNC;` forms generate the same function; `SYNC;` simply forwards to
/// `ASYNC;`.
///
/// Each `Method` may appear at most once per invocation. Every entry becomes its own `match`
/// arm, so a repeated method yields a duplicate arm that can never be reached.
#[macro_export]
macro_rules! set_server_fn {
    (
        ASYNC;
        $server_fn_name: ident,
        $(Request: {
            Method: $method: expr,
            $(
                Header: $req_n: expr, $req_v: expr,
            )*
            Body: $req_body: expr,
        },
        Response: {
            Status: $status: expr,
            Version: $version: expr,
            $(
                Header: $resp_n: expr, $resp_v: expr,
            )*
            Body: $resp_body: expr,
        },)*
    ) => {
        async fn $server_fn_name(request: hyper::Request<hyper::Body>) -> Result<hyper::Response<hyper::Body>, std::convert::Infallible> {
            match request.method().as_str() {
                $(
                    $method => {
                        assert_eq!($method, request.method().as_str(), "Assert request method failed");
                        assert_eq!(
                            "/",
                            request.uri().to_string(),
                            "Assert request uri failed",
                        );
                        assert_eq!(
                            $version,
                            format!("{:?}", request.version()),
                            "Assert request version failed",
                        );
                        // The panic messages are only formatted when a lookup actually fails.
                        $(assert_eq!(
                            $req_v,
                            request
                                .headers()
                                .get($req_n)
                                .unwrap_or_else(|| panic!("Get request header \"{}\" failed", $req_n))
                                .to_str()
                                .unwrap_or_else(|e| panic!(
                                    "Convert request header \"{}\" into string failed: {:?}",
                                    $req_n, e
                                )),
                            "Assert request header {} failed", $req_n,
                        );)*
                        let body = hyper::body::to_bytes(request.into_body()).await
                            .expect("Get request body failed");
                        assert_eq!($req_body.as_bytes(), body, "Assert request body failed");
                        Ok(
                            hyper::Response::builder()
                                .version(hyper::Version::HTTP_11)
                                .status($status)
                                $(.header($resp_n, $resp_v))*
                                .body($resp_body.into())
                                .expect("Build response failed")
                        )
                    },
                )*
                _ => {panic!("Unrecognized METHOD !");},
            }
        }
    };
    (
        SYNC;
        $($rest: tt)*
    ) => {
        // The server side is the same async hyper function for both client flavours.
        $crate::set_server_fn!(ASYNC; $($rest)*);
    };
}
395 
/// Asks a server started by `start_http_server!(HTTP; ...)` to shut down, then waits for it
/// to confirm that it has stopped.
///
/// Expands to two awaited steps, so it can only be used inside an `async` context:
/// 1. send `()` on the handle's `client_shutdown` channel;
/// 2. await a message on the handle's `server_shutdown` channel.
///
/// `ServerHandle` is expanded once per step, so pass a place such as a local variable, not an
/// expression with side effects. `recv` is called on its `server_shutdown` field, so the handle
/// must be usable mutably.
///
/// # Panics
///
/// Panics if either channel has been closed unexpectedly.
#[macro_export]
macro_rules! ensure_server_shutdown {
    (ServerHandle: $handle: expr) => {{
        // Step 1: tell the server that the client is done.
        $handle.client_shutdown.send(()).await
            .expect("Client channel (Server-Half) be closed unexpectedly");
        // Step 2: wait until the server reports that it has stopped.
        $handle.server_shutdown.recv().await
            .expect("Server channel (Server-Half) be closed unexpectedly");
    }};
}
411 
init_test_work_runtime(thread_num: usize) -> Runtime412 pub fn init_test_work_runtime(thread_num: usize) -> Runtime {
413     tokio::runtime::Builder::new_multi_thread()
414         .worker_threads(thread_num)
415         .enable_all()
416         .build()
417         .expect("Build runtime failed.")
418 }
419