1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 use std::sync::mpsc::Receiver;
15 
16 #[cfg(feature = "async")]
17 mod async_utils;
18 
19 #[cfg(feature = "sync")]
20 mod sync_utils;
21 
/// Handle to one scripted test server started by `start_tcp_server!`.
///
/// Derives `Debug` so handles (and collections of them) can be printed in
/// assertion failures.
#[derive(Debug)]
pub struct TcpHandle {
    /// Local address the server listens on, e.g. `127.0.0.1:<port>`.
    pub addr: String,

    /// Receives `()` from the server once it has finished all of its scripted
    /// request/response exchanges (and the optional stream shutdown), letting
    /// the client wait for the server side to complete.
    pub server_shutdown: Receiver<()>,
}
28 
/// Formats a single HTTP header line as `name:value\r\n`.
///
/// The header name is lowercased (ASCII only) and no space is inserted after
/// the colon; the value is copied unchanged. This is the exact byte form the
/// scripted test servers look for in incoming requests and use when writing
/// response headers.
pub fn format_header_str(key: &str, value: &str) -> String {
    format!("{}:{}\r\n", key.to_ascii_lowercase(), value)
}
32 
/// Starts `ServerNum` scripted TCP test servers on `127.0.0.1:0` and pushes a
/// `TcpHandle` for each of them onto `Handles`.
///
/// Two forms are supported:
///
/// * `ASYNC;` - every server runs inside a `ylong_runtime` task. The listener
///   is bound inside the task and its `TcpHandle` is sent back over a channel;
///   the calling thread blocks until it receives that handle. When `Proxy` is
///   true the expected request line uses the absolute form
///   `METHOD http://<server addr><Path> VERSION`. An optional `Sleep` (in
///   milliseconds) delays the response, and an optional `Shutdown` value is
///   passed to `stream.shutdown(..)` after the script finishes.
///   `RequestEnds` is accepted but currently ignored.
/// * `SYNC;` - every server runs on a `std::thread` with 10-second read and
///   write timeouts on the accepted stream. The request line version is always
///   `HTTP/1.1` and an `accept:*/*` header is required.
///
/// For each scripted `Request`, the server reads once into a 4096-byte buffer
/// and asserts that the request line, the `host` header, every listed header
/// and the body add up to the bytes received. If the expected length exceeds
/// the first read, a second read must contain exactly the body. The scripted
/// `Response` is then written back with the status line
/// `<Version> <Status> OK` (the reason phrase is always `OK`). Finally the
/// server sends `()` on the handle's `server_shutdown` channel.
///
/// `TcpHandle` and `format_header_str` are resolved at the call site, so both
/// must be in scope where the macro is invoked.
///
/// # Panics
///
/// The server task/thread panics (failing the test) on any I/O error or
/// request mismatch; the `ASYNC` form also panics in the caller if the
/// server's handle cannot be received.
#[macro_export]
macro_rules! start_tcp_server {
    (
        ASYNC;
        Proxy: $proxy: expr,
        ServerNum: $server_num: expr,
        Handles: $handle_vec: expr,
        $(
        Request: {
            Method: $method: expr,
            Version: $req_version: expr,
            Path: $path: expr,
            $(
                Header: $req_n: expr, $req_v: expr,
            )*
            Body: $req_body: expr,
        },
        $(
        Response: {
            Status: $status: expr,
            Version: $resp_version: expr,
            $(
                Header: $resp_n: expr, $resp_v: expr,
            )*
            Body: $resp_body: expr,
        },
        $(Sleep: $during: expr,)?
        )?
        )*
        $(RequestEnds: $end: expr,)?
        $(Shutdown: $shutdown: expr,)?

    ) => {{
        use std::sync::mpsc::channel;
        use ylong_runtime::net::TcpListener;
        use ylong_runtime::io::{AsyncReadExt, AsyncWriteExt};

        for _i in 0..$server_num {
            // `channel()` returns `(Sender, Receiver)`: the server task keeps
            // both senders, the caller ends up with both receivers.
            let (shutdown_tx, shutdown_rx) = channel();
            let (handle_tx, handle_rx) = channel();

            ylong_runtime::spawn(async move {

                let server = TcpListener::bind("127.0.0.1:0").await.expect("server is failed to bind a address !");
                let addr = server.local_addr().expect("failed to get server address !");
                let handle = TcpHandle {
                    addr: addr.to_string(),
                    server_shutdown: shutdown_rx,
                };
                handle_tx.send(handle).expect("send TcpHandle out coroutine failed !");

                let (mut stream, _client) = server.accept().await.expect("failed to build a tcp stream");

                $(
                {
                    let mut buf = [0u8; 4096];

                    let size = stream.read(&mut buf).await.expect("tcp stream read error !");
                    let mut length = 0;
                    let crlf = "\r\n";
                    let request_str = String::from_utf8_lossy(&buf[..size]);

                    // Proxied requests carry the absolute-form target.
                    let request_line = if $proxy {
                        format!("{} http://{}{} {}{}", $method, addr.to_string().as_str(), $path, $req_version, crlf)
                    } else {
                        format!("{} {} {}{}", $method, $path, $req_version, crlf)
                    };
                    assert!(buf[..size].starts_with(request_line.as_bytes()), "Incorrect Request-Line!");
                    length += request_line.len();

                    let host = format_header_str("host", addr.to_string().as_str());
                    assert!(request_str.contains(host.as_str()), "Incorrect host header!");
                    length += host.len();

                    $(
                    let header_str = format_header_str($req_n, $req_v);
                    assert!(request_str.contains(header_str.as_str()), "Incorrect {} header!", $req_n);
                    length += header_str.len();
                    )*

                    length += crlf.len();
                    length += $req_body.len();

                    // The client may send the head and the body in separate writes.
                    if length > size {
                        let size2 = stream.read(&mut buf).await.expect("tcp stream read error2 !");
                        assert_eq!(&buf[..size2], $req_body.as_bytes());
                        assert_eq!(size + size2, length, "Incorrect total request bytes !");
                    } else {
                        assert_eq!(size, length, "Incorrect total request bytes !");
                    }

                    $(
                    let mut resp_str = format!("{} {} OK\r\n", $resp_version, $status);
                    $(
                    let header = format_header_str($resp_n, $resp_v);
                    resp_str.push_str(header.as_str());
                    )*
                    resp_str.push_str(crlf);
                    resp_str.push_str($resp_body);
                    // Fully qualified so callers need not import `Duration`.
                    $(ylong_runtime::time::sleep(std::time::Duration::from_millis($during)).await;)?
                    stream.write_all(resp_str.as_bytes()).await.expect("server write response failed");
                    )?
                }
                )*

                $(
                    stream.shutdown($shutdown).expect("server shutdown failed");
                )?
                shutdown_tx.send(()).expect("server send order failed !");

            });

            let handle = handle_rx.recv().expect("recv server handle failed !");

            $handle_vec.push(handle);
        }
    }};

    (
        SYNC;
        ServerNum: $server_num: expr,
        Handles: $handle_vec: expr,
        $(Request: {
            Method: $method: expr,
            Path: $path: expr,
            $(
                Header: $req_n: expr, $req_v: expr,
            )*
            Body: $req_body: expr,
        },
        Response: {
            Status: $status: expr,
            Version: $version: expr,
            $(
                Header: $resp_n: expr, $resp_v: expr,
            )*
            Body: $resp_body: expr,
        },)*

    ) => {{
        use std::net::TcpListener;
        use std::io::{Read, Write};
        use std::sync::mpsc::channel;
        use std::time::Duration;

        for _i in 0..$server_num {
            let server = TcpListener::bind("127.0.0.1:0").expect("server is failed to bind a address !");
            let addr = server.local_addr().expect("failed to get server address !");
            // The server thread keeps the sender; the receiver goes into the handle.
            let (shutdown_tx, shutdown_rx) = channel();

            std::thread::spawn(move || {

                let (mut stream, _client) = server.accept().expect("failed to build a tcp stream");
                stream.set_read_timeout(Some(Duration::from_secs(10))).expect("tcp stream set read time out error !");
                stream.set_write_timeout(Some(Duration::from_secs(10))).expect("tcp stream set write time out error !");

                $(
                {
                    let mut buf = [0u8; 4096];

                    let size = stream.read(&mut buf).expect("tcp stream read error !");
                    let mut length = 0;
                    let crlf = "\r\n";
                    let request_str = String::from_utf8_lossy(&buf[..size]);
                    let request_line = format!("{} {} {}{}", $method, $path, "HTTP/1.1", crlf);
                    assert!(buf[..size].starts_with(request_line.as_bytes()), "Incorrect Request-Line!");

                    length += request_line.len();

                    let accept = format_header_str("accept", "*/*");
                    assert!(request_str.contains(accept.as_str()), "Incorrect accept header!");
                    length += accept.len();

                    let host = format_header_str("host", addr.to_string().as_str());
                    assert!(request_str.contains(host.as_str()), "Incorrect host header!");
                    length += host.len();

                    $(
                    let header_str = format_header_str($req_n, $req_v);
                    assert!(request_str.contains(header_str.as_str()), "Incorrect {} header!", $req_n);
                    length += header_str.len();
                    )*

                    length += crlf.len();
                    length += $req_body.len();

                    // The client may send the head and the body in separate writes.
                    if length > size {
                        let size2 = stream.read(&mut buf).expect("tcp stream read error2 !");
                        assert_eq!(&buf[..size2], $req_body.as_bytes());
                        assert_eq!(size + size2, length, "Incorrect total request bytes !");
                    } else {
                        assert_eq!(size, length, "Incorrect total request bytes !");
                    }

                    let mut resp_str = format!("{} {} OK\r\n", $version, $status);
                    $(
                    let header = format_header_str($resp_n, $resp_v);
                    resp_str.push_str(header.as_str());
                    )*
                    resp_str.push_str(crlf);
                    resp_str.push_str($resp_body);

                    stream.write_all(resp_str.as_bytes()).expect("server write response failed");
                }
                )*
                shutdown_tx.send(()).expect("server send order failed !");

            });

            let handle = TcpHandle {
                addr: addr.to_string(),
                server_shutdown: shutdown_rx,
            };
            $handle_vec.push(handle);
        }

    }}
}
251 
/// Builds a synchronous `ylong_http` request for the test clients.
///
/// The request URL is `http://{Addr}{Path}`. Every `Header` pair is added in
/// the order given, and `Body` (anything exposing `as_bytes()`) is wrapped in a
/// `TextBody`. Panics with `"Request build failed"` if the builder rejects the
/// input.
#[macro_export]
#[cfg(feature = "sync")]
macro_rules! build_client_request {
    (
        Request: {
            Method: $m: expr,
            Path: $p: expr,
            Addr: $a: expr,
            $(
                Header: $name: expr, $value: expr,
            )*
            Body: $body: expr,
        },
    ) => {{
        // Stage 1: a builder with the method set.
        let builder = ylong_http::request::RequestBuilder::new().method($m);
        // Stage 2: target URL, headers, body.
        builder
            .url(format!("http://{}{}", $a, $p).as_str())
            $(.header($name, $value))*
            .body(ylong_http::body::TextBody::from_bytes($body.as_bytes()))
            .expect("Request build failed")
    }};
}
275 
/// Builds an asynchronous `ylong_http_client` request for the test clients.
///
/// The request URL is `http://{Addr}{Path}`. `Version` is optional; when given
/// it is passed to the builder's `version` method. Every `Header` pair is added
/// in the order given, and `Body` (anything exposing `as_bytes()`) is wrapped
/// with `Body::slice`. Panics with `"Request build failed"` if the builder
/// rejects the input.
#[macro_export]
#[cfg(feature = "async")]
macro_rules! build_client_request {
    (
        Request: {
            Method: $method: expr,
            $(Version: $version: expr,)?
            Path: $path: expr,
            Addr: $addr: expr,
            $(
                Header: $req_n: expr, $req_v: expr,
            )*
            Body: $req_body: expr,
        },
    ) => {{
        ylong_http_client::async_impl::RequestBuilder::new()
            .method($method)
            $(.version($version))?
            .url(format!("http://{}{}", $addr, $path).as_str())
            $(.header($req_n, $req_v))*
            .body(ylong_http_client::async_impl::Body::slice($req_body.as_bytes()))
            .expect("Request build failed")
    }};
}
301