1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 // http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13
14 use std::sync::mpsc::Receiver;
15
16 #[cfg(feature = "async")]
17 mod async_utils;
18
19 #[cfg(feature = "sync")]
20 mod sync_utils;
21
/// Handle to one scripted test server started by `start_tcp_server!`.
#[derive(Debug)]
pub struct TcpHandle {
    /// Address the server listens on, as produced by `SocketAddr::to_string`
    /// (e.g. `127.0.0.1:54321`).
    pub addr: String,

    /// Receives `()` once the server has finished its whole scripted exchange.
    /// If the server fails part-way it never sends, and the sending end is
    /// eventually dropped.
    pub server_shutdown: Receiver<()>,
}
28
/// Formats a single HTTP header line as `key:value\r\n`.
///
/// Only the ASCII letters of `key` are lowercased. `value` is copied
/// unchanged, and no space is inserted after the colon. The servers started
/// by `start_tcp_server!` use this both to build the header bytes they expect
/// in a request and to build the headers they write back.
pub fn format_header_str(key: &str, value: &str) -> String {
    format!("{}:{}\r\n", key.to_ascii_lowercase(), value)
}
32
/// Spawns scripted HTTP test servers on `127.0.0.1` and pushes a `TcpHandle`
/// for each one onto `Handles`.
///
/// Every server binds `127.0.0.1:0` (an OS-assigned port), accepts a single
/// connection, and then handles the listed `Request` entries in order. For
/// each entry it:
///
/// 1. reads once from the connection and asserts that the data starts with the
///    expected Request-Line and contains a `host:<addr>` header plus every
///    listed `Header` (keys lowercased by `format_header_str`);
/// 2. asserts that the byte count equals Request-Line + headers + `\r\n` +
///    `Body`. If the first read came up short, it does one extra read that
///    must hold exactly `Body`;
/// 3. writes the scripted response. The status line always uses the reason
///    phrase `OK`.
///
/// When the script is finished, the server sends `()` on the handle's
/// `server_shutdown` channel. If an assertion or `expect` inside the server
/// fails, it panics before reaching that send.
///
/// # Arms
///
/// * `ASYNC;` — each server runs as a `ylong_runtime::spawn`ed task:
///   - `Proxy: true` expects an absolute-form target (`http://<addr><path>`);
///   - `Version` sets the expected request version;
///   - `Response` is optional per request;
///   - `Sleep: <ms>` delays that response by the given number of milliseconds;
///   - `Shutdown: <how>` passes `<how>` to `stream.shutdown` after the last
///     request;
///   - `RequestEnds` is accepted but currently unused.
///
///   The caller blocks on a channel until the task has bound its listener.
/// * `SYNC;` — each server runs on a `std::thread` with 10-second read and
///   write timeouts. Every request must be `HTTP/1.1` and must also carry
///   `accept:*/*`, and every request gets a response.
///
/// The expansion names `TcpHandle` and `format_header_str` unqualified, so both
/// must be in scope where the macro is invoked.
#[macro_export]
macro_rules! start_tcp_server {
    (
        ASYNC;
        Proxy: $proxy: expr,
        ServerNum: $server_num: expr,
        Handles: $handle_vec: expr,
        $(
            Request: {
                Method: $method: expr,
                Version: $req_version: expr,
                Path: $path: expr,
                $(
                    Header: $req_n: expr, $req_v: expr,
                )*
                Body: $req_body: expr,
            },
            $(
                Response: {
                    Status: $status: expr,
                    Version: $resp_version: expr,
                    $(
                        Header: $resp_n: expr, $resp_v: expr,
                    )*
                    Body: $resp_body: expr,
                },
                $(Sleep: $during: expr,)?
            )?
        )*
        // Accepted for compatibility with existing callers; not used below.
        $(RequestEnds: $end: expr,)?
        $(Shutdown: $shutdown: expr,)?

    ) => {{
        use std::sync::mpsc::channel;
        use ylong_runtime::io::{AsyncReadExt, AsyncWriteExt};
        use ylong_runtime::net::TcpListener;

        for _i in 0..$server_num {
            // `channel()` returns `(Sender, Receiver)`. Both shutdown ends move
            // into the task; the receiver leaves it again inside the `TcpHandle`.
            let (shutdown_tx, shutdown_rx) = channel();
            // Carries the `TcpHandle` out of the task once the listener is bound.
            let (handle_tx, handle_rx) = channel();

            ylong_runtime::spawn(async move {
                let server = TcpListener::bind("127.0.0.1:0")
                    .await
                    .expect("server is failed to bind a address !");
                let addr = server.local_addr().expect("failed to get server address !");
                let handle = TcpHandle {
                    addr: addr.to_string(),
                    server_shutdown: shutdown_rx,
                };
                handle_tx.send(handle).expect("send TcpHandle out coroutine failed !");

                let (mut stream, _client) = server
                    .accept()
                    .await
                    .expect("failed to build a tcp stream");

                $(
                    {
                        let mut buf = [0u8; 4096];
                        let crlf = "\r\n";

                        // NOTE(review): the Request-Line and header checks only look at
                        // this first read, so they assume all headers arrive in it.
                        let size = stream.read(&mut buf).await.expect("tcp stream read error !");
                        let request_str = String::from_utf8_lossy(&buf[..size]);
                        let mut length = 0;

                        let request_line = if $proxy {
                            format!("{} http://{}{} {}{}", $method, addr, $path, $req_version, crlf)
                        } else {
                            format!("{} {} {}{}", $method, $path, $req_version, crlf)
                        };
                        assert!(buf[..size].starts_with(request_line.as_bytes()), "Incorrect Request-Line!");
                        length += request_line.len();

                        let host = format_header_str("host", &addr.to_string());
                        assert!(request_str.contains(host.as_str()), "Incorrect host header!");
                        length += host.len();

                        $(
                            let header_str = format_header_str($req_n, $req_v);
                            assert!(request_str.contains(header_str.as_str()), "Incorrect {} header!", $req_n);
                            length += header_str.len();
                        )*

                        // Blank line that ends the header section, then the body.
                        length += crlf.len();
                        length += $req_body.len();

                        if length > size {
                            // The body did not fit in the first read; expect it whole in the next one.
                            let size2 = stream.read(&mut buf).await.expect("tcp stream read error2 !");
                            assert_eq!(&buf[..size2], $req_body.as_bytes());
                            assert_eq!(size + size2, length, "Incorrect total request bytes !");
                        } else {
                            assert_eq!(size, length, "Incorrect total request bytes !");
                        }

                        $(
                            let mut resp_str = format!("{} {} OK\r\n", $resp_version, $status);
                            $(
                                resp_str.push_str(&format_header_str($resp_n, $resp_v));
                            )*
                            resp_str.push_str(crlf);
                            resp_str.push_str($resp_body);
                            // Fully qualified so the expansion does not depend on the caller
                            // having `Duration` in scope.
                            $(ylong_runtime::time::sleep(std::time::Duration::from_millis($during)).await;)?
                            stream.write_all(resp_str.as_bytes()).await.expect("server write response failed");
                        )?
                    }
                )*

                $(
                    stream.shutdown($shutdown).expect("server shutdown failed");
                )?
                shutdown_tx.send(()).expect("server send order failed !");
            });

            let handle = handle_rx.recv().expect("recv server handle failed !");
            $handle_vec.push(handle);
        }
    }};

    (
        SYNC;
        ServerNum: $server_num: expr,
        Handles: $handle_vec: expr,
        $(Request: {
            Method: $method: expr,
            Path: $path: expr,
            $(
                Header: $req_n: expr, $req_v: expr,
            )*
            Body: $req_body: expr,
        },
        Response: {
            Status: $status: expr,
            Version: $version: expr,
            $(
                Header: $resp_n: expr, $resp_v: expr,
            )*
            Body: $resp_body: expr,
        },)*

    ) => {{
        use std::io::{Read, Write};
        use std::net::TcpListener;
        use std::sync::mpsc::channel;
        use std::time::Duration;

        for _i in 0..$server_num {
            // Bound before spawning, so the address is known without a hand-off channel.
            let server = TcpListener::bind("127.0.0.1:0").expect("server is failed to bind a address !");
            let addr = server.local_addr().expect("failed to get server address !");
            // `channel()` returns `(Sender, Receiver)`: the sender goes to the server
            // thread, the receiver into the `TcpHandle`.
            let (shutdown_tx, shutdown_rx) = channel();

            std::thread::spawn(move || {
                let (mut stream, _client) = server.accept().expect("failed to build a tcp stream");
                stream.set_read_timeout(Some(Duration::from_secs(10))).expect("tcp stream set read time out error !");
                stream.set_write_timeout(Some(Duration::from_secs(10))).expect("tcp stream set write time out error !");

                $(
                    {
                        let mut buf = [0u8; 4096];
                        let crlf = "\r\n";

                        // NOTE(review): the Request-Line and header checks only look at
                        // this first read, so they assume all headers arrive in it.
                        let size = stream.read(&mut buf).expect("tcp stream read error !");
                        let request_str = String::from_utf8_lossy(&buf[..size]);
                        let mut length = 0;

                        let request_line = format!("{} {} {}{}", $method, $path, "HTTP/1.1", crlf);
                        assert!(buf[..size].starts_with(request_line.as_bytes()), "Incorrect Request-Line!");
                        length += request_line.len();

                        let accept = format_header_str("accept", "*/*");
                        assert!(request_str.contains(accept.as_str()), "Incorrect accept header!");
                        length += accept.len();

                        let host = format_header_str("host", &addr.to_string());
                        assert!(request_str.contains(host.as_str()), "Incorrect host header!");
                        length += host.len();

                        $(
                            let header_str = format_header_str($req_n, $req_v);
                            assert!(request_str.contains(header_str.as_str()), "Incorrect {} header!", $req_n);
                            length += header_str.len();
                        )*

                        // Blank line that ends the header section, then the body.
                        length += crlf.len();
                        length += $req_body.len();

                        if length > size {
                            // The body did not fit in the first read; expect it whole in the next one.
                            let size2 = stream.read(&mut buf).expect("tcp stream read error2 !");
                            assert_eq!(&buf[..size2], $req_body.as_bytes());
                            assert_eq!(size + size2, length, "Incorrect total request bytes !");
                        } else {
                            assert_eq!(size, length, "Incorrect total request bytes !");
                        }

                        let mut resp_str = format!("{} {} OK\r\n", $version, $status);
                        $(
                            resp_str.push_str(&format_header_str($resp_n, $resp_v));
                        )*
                        resp_str.push_str(crlf);
                        resp_str.push_str($resp_body);

                        stream.write_all(resp_str.as_bytes()).expect("server write response failed");
                    }
                )*
                shutdown_tx.send(()).expect("server send order failed !");
            });

            $handle_vec.push(TcpHandle {
                addr: addr.to_string(),
                server_shutdown: shutdown_rx,
            });
        }
    }};
}
251
/// Builds a synchronous `ylong_http` request aimed at `http://<Addr><Path>`.
///
/// `Method` and every `Header` pair are forwarded to
/// `ylong_http::request::RequestBuilder` in the order written. `Body` is
/// wrapped with `ylong_http::body::TextBody::from_bytes` over
/// `Body.as_bytes()`.
///
/// NOTE(review): an `async`-gated `build_client_request` is also
/// `#[macro_export]`ed in this file. Enabling both features at once probably
/// makes the two definitions collide — confirm before building with both.
///
/// # Panics
///
/// Panics with `"Request build failed"` if the builder returns an error.
#[macro_export]
#[cfg(feature = "sync")]
macro_rules! build_client_request {
    (
        Request: {
            Method: $method: expr,
            Path: $path: expr,
            Addr: $addr: expr,
            $(
                Header: $req_n: expr, $req_v: expr,
            )*
            Body: $req_body: expr,
        },
    ) => {{
        use ylong_http::body::TextBody;
        use ylong_http::request::RequestBuilder;

        // Deliberately one expression with no intermediate `let`s, so temporaries
        // created from the macro arguments live until the whole chain finishes.
        RequestBuilder::new()
            .method($method)
            .url(format!("http://{}{}", $addr, $path).as_str())
            $(.header($req_n, $req_v))*
            .body(TextBody::from_bytes($req_body.as_bytes()))
            .expect("Request build failed")
    }};
}
275
/// Creates an async `ylong_http_client` request aimed at `http://<Addr><Path>`.
///
/// `Method`, the optional `Version`, and every `Header` pair are forwarded to
/// `ylong_http_client::async_impl::RequestBuilder` in the order written.
/// `Body` is wrapped with `ylong_http_client::async_impl::Body::slice` over
/// `Body.as_bytes()`.
///
/// NOTE(review): a `sync`-gated `build_client_request` is also
/// `#[macro_export]`ed in this file. Enabling both features at once probably
/// makes the two definitions collide — confirm before building with both.
///
/// # Panics
///
/// Panics with `"Request build failed"` if the builder returns an error.
#[macro_export]
#[cfg(feature = "async")]
macro_rules! build_client_request {
    (
        Request: {
            Method: $method: expr,
            $(Version: $version: expr,)?
            Path: $path: expr,
            Addr: $addr: expr,
            $(
                Header: $req_n: expr, $req_v: expr,
            )*
            Body: $req_body: expr,
        },
    ) => {{
        ylong_http_client::async_impl::RequestBuilder::new()
            .method($method)
            $(.version($version))?
            .url(format!("http://{}{}", $addr, $path).as_str())
            $(.header($req_n, $req_v))*
            .body(ylong_http_client::async_impl::Body::slice($req_body.as_bytes()))
            .expect("Request build failed")
    }};
}
301