1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 // http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13
14 #[cfg(feature = "async")]
15 mod async_utils;
16
17 #[cfg(feature = "sync")]
18 mod sync_utils;
19
20 use tokio::runtime::Runtime;
21
/// Declares the handle type that a test server hands back to the test body.
///
/// * `HTTP;` brings tokio's mpsc `Receiver`/`Sender` into the calling scope and
///   defines `HttpHandle`, which carries the port plus the three lifecycle channels.
/// * `HTTPS;` defines `TlsHandle`, which only records the port.
macro_rules! define_service_handle {
    (HTTP;) => {
        use tokio::sync::mpsc::{Receiver, Sender};

        pub struct HttpHandle {
            /// Local port the server is bound to.
            pub port: u16,

            /// Server -> client: fires once the server is up and running.
            pub server_start: Receiver<()>,

            /// Client -> server: tells the server it may shut down.
            pub client_shutdown: Sender<()>,

            /// Server -> client: fires once the server has shut down.
            pub server_shutdown: Receiver<()>,
        }
    };
    (HTTPS;) => {
        pub struct TlsHandle {
            /// Local port the server is bound to.
            pub port: u16,
        }
    };
}
49
/// Starts `ServerNum` test servers on `Runtime` and pushes one handle per server
/// into `Handles`.
///
/// Each iteration spawns a task that runs `start_http_server!` for the selected
/// flavour, blocks the calling thread until that task has finished, and then
/// receives the handle over a `std::sync::mpsc` channel. For the `HTTP` flavour
/// the task additionally waits for the server's start notification before the
/// handle is handed over.
///
/// Must be invoked from synchronous code: it calls `block_on` on the runtime.
/// Panics if the spawned task fails or either channel is closed early.
#[macro_export]
macro_rules! start_server {
    (
        HTTPS;
        ServerNum: $server_num: expr,
        Runtime: $runtime: expr,
        Handles: $handle_vec: expr,
        ServeFnName: $service_fn: ident,
    ) => {{
        for _ in 0..$server_num {
            // One-shot hand-off of the handle from the spawned task to this thread.
            let (handle_tx, handle_rx) = std::sync::mpsc::channel();
            let startup_task = $runtime.spawn(async move {
                let tls_handle = start_http_server!(
                    HTTPS;
                    $service_fn
                );
                handle_tx
                    .send(tls_handle)
                    .expect("Failed to send the handle to the test thread.");
            });
            $runtime
                .block_on(startup_task)
                .expect("Runtime start server coroutine failed");
            let ready_handle = handle_rx
                .recv()
                .expect("Handle send channel (Server-Half) be closed unexpectedly");
            $handle_vec.push(ready_handle);
        }
    }};
    (
        HTTP;
        ServerNum: $server_num: expr,
        Runtime: $runtime: expr,
        Handles: $handle_vec: expr,
        ServeFnName: $service_fn: ident,
    ) => {{
        for _ in 0..$server_num {
            // One-shot hand-off of the handle from the spawned task to this thread.
            let (handle_tx, handle_rx) = std::sync::mpsc::channel();
            let startup_task = $runtime.spawn(async move {
                let mut http_handle = start_http_server!(
                    HTTP;
                    $service_fn
                );
                // Do not release the handle before the server reports that it started.
                http_handle
                    .server_start
                    .recv()
                    .await
                    .expect("Start channel (Server-Half) be closed unexpectedly");
                handle_tx
                    .send(http_handle)
                    .expect("Failed to send the handle to the test thread.");
            });
            $runtime
                .block_on(startup_task)
                .expect("Runtime start server coroutine failed");
            let ready_handle = handle_rx
                .recv()
                .expect("Handle send channel (Server-Half) be closed unexpectedly");
            $handle_vec.push(ready_handle);
        }
    }};
}
110
/// Boots a single test server and evaluates to its handle.
///
/// Both arms call `tokio::spawn`, so the macro must be expanded inside a tokio
/// runtime context; the `HTTPS` arm also uses `.await` and therefore needs an
/// async context.
///
/// * `HTTP; <service fn>` binds `127.0.0.1:0`, serves `<service fn>` through hyper
///   with graceful shutdown, and evaluates to an `HttpHandle` holding the port and
///   the start / client-shutdown / server-shutdown channel ends.
/// * `HTTPS; <service fn>` binds `127.0.0.1:0`, accepts exactly one TCP connection,
///   performs the TLS handshake with the key/certificate under `tests/file/`, serves
///   that connection over HTTP/1.1, and evaluates to a `TlsHandle` holding the port.
///
/// Setup failures (bind, TLS configuration, handshake, channel closure) panic.
#[macro_export]
macro_rules! start_http_server {
    (
        HTTP;
        $server_fn: ident
    ) => {{
        use hyper::service::{make_service_fn, service_fn};
        use std::convert::Infallible;
        use tokio::sync::mpsc::channel;

        let (start_tx, start_rx) = channel::<()>(1);
        let (client_tx, mut client_rx) = channel::<()>(1);
        let (server_tx, server_rx) = channel::<()>(1);

        // Port 0: the actual port is read back from the listener below.
        let tcp_listener = std::net::TcpListener::bind("127.0.0.1:0").expect("server bind port failed !");
        let addr = tcp_listener.local_addr().expect("get server local address failed!");
        let port = addr.port();
        let server = hyper::Server::from_tcp(tcp_listener).expect("build hyper server from tcp listener failed !");

        tokio::spawn(async move {
            let make_svc =
                make_service_fn(|_conn| async { Ok::<_, Infallible>(service_fn($server_fn)) });
            server
                .serve(make_svc)
                .with_graceful_shutdown(async {
                    // The shutdown future doubles as the start notification: it first
                    // reports the start, then waits for the client's shutdown request.
                    start_tx
                        .send(())
                        .await
                        .expect("Start channel (Client-Half) be closed unexpectedly");
                    client_rx
                        .recv()
                        .await
                        .expect("Client channel (Client-Half) be closed unexpectedly");
                })
                .await
                .expect("Start server failed");
            // The server future has completed; report the shutdown.
            server_tx
                .send(())
                .await
                .expect("Server channel (Client-Half) be closed unexpectedly");
        });

        HttpHandle {
            port,
            server_start: start_rx,
            client_shutdown: client_tx,
            server_shutdown: server_rx,
        }
    }};
    (
        HTTPS;
        $service_fn: ident
    ) => {{
        // Bind port 0, as the HTTP arm does, instead of scanning upwards from a fixed
        // port: a scan that retries on every error never terminates when the failure
        // is not "port in use".
        let listener = tokio::net::TcpListener::bind(std::net::SocketAddr::from(([127, 0, 0, 1], 0)))
            .await
            .expect("server bind port failed !");
        let port = listener
            .local_addr()
            .expect("get server local address failed!")
            .port();

        tokio::spawn(async move {
            let mut acceptor = openssl::ssl::SslAcceptor::mozilla_intermediate(openssl::ssl::SslMethod::tls())
                .expect("SslAcceptorBuilder error");
            acceptor
                .set_session_id_context(b"test")
                .expect("Set session id error");
            acceptor
                .set_private_key_file("tests/file/key.pem", openssl::ssl::SslFiletype::PEM)
                .expect("Set private key error");
            acceptor
                .set_certificate_chain_file("tests/file/cert.pem")
                .expect("Set cert error");
            // Offer and select only `http/1.1` during ALPN.
            acceptor.set_alpn_protos(b"\x08http/1.1").unwrap();
            acceptor.set_alpn_select_callback(|_, client| {
                openssl::ssl::select_next_proto(b"\x08http/1.1", client).ok_or(openssl::ssl::AlpnError::NOACK)
            });

            let acceptor = acceptor.build();

            // Only one connection is ever accepted by this server.
            let (stream, _) = listener.accept().await.expect("TCP listener accept error");
            let ssl = openssl::ssl::Ssl::new(acceptor.context()).expect("Ssl Error");
            let mut stream = tokio_openssl::SslStream::new(ssl, stream).expect("SslStream Error");
            core::pin::Pin::new(&mut stream).accept().await.unwrap(); // SSL negotiation finished successfully

            hyper::server::conn::Http::new()
                .http1_only(true)
                .http1_keep_alive(true)
                .serve_connection(stream, hyper::service::service_fn($service_fn))
                .await
        });

        TlsHandle {
            port,
        }
    }};
}
216
/// Builds a request with the synchronous `ylong_http` request builder.
///
/// The URL is formatted as `{Host}:{Port}`, every `Header: name, value,` pair is
/// added in the order written, and `Body` is passed as bytes to
/// `TextBody::from_bytes`.
///
/// Panics with "Request build failed" if the builder returns an error.
#[macro_export]
#[cfg(feature = "sync")]
macro_rules! ylong_request {
    (
        Request: {
            Method: $verb: expr,
            Host: $authority: expr,
            Port: $tcp_port: expr,
            $(
                Header: $name: expr, $value: expr,
            )*
            Body: $payload: expr,
        },
    ) => {
        ylong_http::request::RequestBuilder::new()
            .method($verb)
            .url(format!("{}:{}", $authority, $tcp_port).as_str())
            $(.header($name, $value))*
            .body(ylong_http::body::TextBody::from_bytes($payload.as_bytes()))
            .expect("Request build failed")
    };
}
240
/// Builds a request with the asynchronous `ylong_http_client` request builder.
///
/// The URL is formatted as `{Host}:{Port}`, every `Header: name, value,` pair is
/// added in the order written, and `Body` is passed as bytes to `Body::slice`.
///
/// Panics with "Request build failed" if the builder returns an error.
#[macro_export]
#[cfg(feature = "async")]
macro_rules! ylong_request {
    (
        Request: {
            Method: $verb: expr,
            Host: $authority: expr,
            Port: $tcp_port: expr,
            $(
                Header: $name: expr, $value: expr,
            )*
            Body: $payload: expr,
        },
    ) => {
        ylong_http_client::async_impl::RequestBuilder::new()
            .method($verb)
            .url(format!("{}:{}", $authority, $tcp_port).as_str())
            $(.header($name, $value))*
            .body(ylong_http_client::async_impl::Body::slice($payload.as_bytes()))
            .expect("Request build failed")
    };
}
264
/// Defines the async hyper service function `$server_fn_name` used by the test servers.
///
/// The generated function dispatches on the request method. For the matching
/// `Request`/`Response` pair it asserts that
///
/// * the request URI is `/`,
/// * the `Debug` rendering of the request version equals `Version`,
/// * every listed request header is present with the expected value,
/// * the request body equals `Body`,
///
/// and then answers with the configured status, headers and body. The response
/// version is always HTTP/1.1.
///
/// The `ASYNC;` and `SYNC;` arms generate the same function.
///
/// # Panics
///
/// The generated function panics when an assertion fails, when a listed header is
/// missing or not valid as a string, when the body cannot be read, or when the
/// request method matches none of the configured pairs.
#[macro_export]
macro_rules! set_server_fn {
    (
        ASYNC;
        $server_fn_name: ident,
        $(Request: {
            Method: $method: expr,
            $(
                Header: $req_n: expr, $req_v: expr,
            )*
            Body: $req_body: expr,
        },
        Response: {
            Status: $status: expr,
            Version: $version: expr,
            $(
                Header: $resp_n: expr, $resp_v: expr,
            )*
            Body: $resp_body: expr,
        },)*
    ) => {
        async fn $server_fn_name(request: hyper::Request<hyper::Body>) -> Result<hyper::Response<hyper::Body>, std::convert::Infallible> {
            match request.method().as_str() {
                // TODO: Requests sharing the same Method expand to identical match arms, which the compiler reports.
                $(
                    $method => {
                        assert_eq!($method, request.method().as_str(), "Assert request method failed");
                        assert_eq!(
                            "/",
                            request.uri().to_string(),
                            "Assert request uri failed",
                        );
                        assert_eq!(
                            $version,
                            format!("{:?}", request.version()),
                            "Assert request version failed",
                        );
                        // Panic messages are built lazily, only when a header check fails.
                        $(assert_eq!(
                            $req_v,
                            request
                                .headers()
                                .get($req_n)
                                .unwrap_or_else(|| panic!("Get request header \"{}\" failed", $req_n))
                                .to_str()
                                .unwrap_or_else(|err| panic!(
                                    "Convert request header \"{}\" into string failed: {:?}",
                                    $req_n, err
                                )),
                            "Assert request header {} failed", $req_n,
                        );)*
                        let body = hyper::body::to_bytes(request.into_body()).await
                            .expect("Get request body failed");
                        assert_eq!($req_body.as_bytes(), body, "Assert request body failed");
                        Ok(
                            hyper::Response::builder()
                                .version(hyper::Version::HTTP_11)
                                .status($status)
                                $(.header($resp_n, $resp_v))*
                                .body($resp_body.into())
                                .expect("Build response failed")
                        )
                    },
                )*
                _ => {panic!("Unrecognized METHOD !");},
            }
        }

    };
    (
        SYNC;
        $server_fn_name: ident,
        $(Request: {
            Method: $method: expr,
            $(
                Header: $req_n: expr, $req_v: expr,
            )*
            Body: $req_body: expr,
        },
        Response: {
            Status: $status: expr,
            Version: $version: expr,
            $(
                Header: $resp_n: expr, $resp_v: expr,
            )*
            Body: $resp_body: expr,
        },)*
    ) => {
        async fn $server_fn_name(request: hyper::Request<hyper::Body>) -> Result<hyper::Response<hyper::Body>, std::convert::Infallible> {
            match request.method().as_str() {
                // TODO: Requests sharing the same Method expand to identical match arms, which the compiler reports.
                $(
                    $method => {
                        assert_eq!($method, request.method().as_str(), "Assert request method failed");
                        assert_eq!(
                            "/",
                            request.uri().to_string(),
                            "Assert request uri failed",
                        );
                        assert_eq!(
                            $version,
                            format!("{:?}", request.version()),
                            "Assert request version failed",
                        );
                        // Panic messages are built lazily, only when a header check fails.
                        $(assert_eq!(
                            $req_v,
                            request
                                .headers()
                                .get($req_n)
                                .unwrap_or_else(|| panic!("Get request header \"{}\" failed", $req_n))
                                .to_str()
                                .unwrap_or_else(|err| panic!(
                                    "Convert request header \"{}\" into string failed: {:?}",
                                    $req_n, err
                                )),
                            "Assert request header {} failed", $req_n,
                        );)*
                        let body = hyper::body::to_bytes(request.into_body()).await
                            .expect("Get request body failed");
                        assert_eq!($req_body.as_bytes(), body, "Assert request body failed");
                        Ok(
                            hyper::Response::builder()
                                .version(hyper::Version::HTTP_11)
                                .status($status)
                                $(.header($resp_n, $resp_v))*
                                .body($resp_body.into())
                                .expect("Build response failed")
                        )
                    },
                )*
                _ => {panic!("Unrecognized METHOD !");},
            }
        }

    };
}
395
/// Asks a test server to stop and waits until it confirms the shutdown.
///
/// Expands to two awaited statements, so it must be used inside an async context:
/// first a unit message is sent on the handle's `client_shutdown` channel, then
/// one message is received from its `server_shutdown` channel.
///
/// The handle expression is written out once per statement and must therefore be
/// a place that can be used twice. Panics if either channel is closed.
#[macro_export]
macro_rules! ensure_server_shutdown {
    (ServerHandle: $server_handle:expr) => {
        // Step 1: request the shutdown.
        $server_handle
            .client_shutdown
            .send(())
            .await
            .expect("Client channel (Server-Half) be closed unexpectedly");
        // Step 2: wait for the server's confirmation.
        $server_handle
            .server_shutdown
            .recv()
            .await
            .expect("Server channel (Server-Half) be closed unexpectedly");
    };
}
411
init_test_work_runtime(thread_num: usize) -> Runtime412 pub fn init_test_work_runtime(thread_num: usize) -> Runtime {
413 tokio::runtime::Builder::new_multi_thread()
414 .worker_threads(thread_num)
415 .enable_all()
416 .build()
417 .expect("Build runtime failed.")
418 }
419