1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 #![cfg(target_os = "linux")]
15 
16 use std::collections::HashMap;
17 use std::io;
18 use std::io::{IoSlice, IoSliceMut, Read, Write};
19 use std::net::Shutdown;
20 use std::os::fd::{FromRawFd, IntoRawFd};
21 use std::str::from_utf8;
22 
23 use ylong_io::{EventTrait, Events, Interest, Poll, Token, UnixDatagram, UnixListener, UnixStream};
24 
/// Filesystem path of the Unix-domain socket that `server` binds and
/// `sdv_uds_stream_test` connects to.
const PATH: &str = "/tmp/io_uds_path1";
/// Token under which `server` registers its listening socket with the `Poll`.
const SERVER: Token = Token(0);
27 
/// SDV test for UnixStream.
///
/// # Brief
/// 1. Start `server` on a background thread and connect a UnixStream to it.
/// 2. Server sends "Hello client".
/// 3. Client reads the message and sends "Hello server".
/// 4. Server receives the message.
#[test]
fn sdv_uds_stream_test() {
    // Remove a socket file left behind by an earlier, aborted run so `bind` can succeed.
    let _ = std::fs::remove_file(PATH);

    let handle = std::thread::spawn(server);

    // The server binds on another thread, so retry with a short back-off. The retry budget
    // is bounded so a server that never binds fails the test instead of spinning forever.
    let mut attempts = 0_u32;
    let mut stream = loop {
        match UnixStream::connect(PATH) {
            Ok(stream) => break stream,
            Err(e) => {
                attempts += 1;
                assert!(attempts < 500, "could not connect to {}: {}", PATH, e);
                std::thread::sleep(std::time::Duration::from_millis(10));
            }
        }
    };

    // Retry only while no data is available yet; any other read error is a real failure
    // and must not be retried forever.
    let mut buffer = [0_u8; 1024];
    let n = loop {
        let slice = IoSliceMut::new(&mut buffer);
        match stream.read_vectored(&mut [slice]) {
            Ok(n) => break n,
            Err(e)
                if matches!(
                    e.kind(),
                    io::ErrorKind::WouldBlock | io::ErrorKind::Interrupted
                ) =>
            {
                std::thread::sleep(std::time::Duration::from_micros(300));
            }
            Err(e) => panic!("failed to read from the server: {}", e),
        }
    };
    assert_eq!(from_utf8(&buffer[0..n]).unwrap(), "Hello client");

    let reply = b"Hello server";
    let written = stream.write_vectored(&[IoSlice::new(reply)]).unwrap();
    assert_eq!(written, reply.len());

    // Propagates both a panic in the server thread and an `Err` returned by `server`.
    handle.join().unwrap().unwrap();
    stream.shutdown(Shutdown::Both).unwrap();
    std::fs::remove_file(PATH).unwrap();
}
68 
server() -> io::Result<()>69 fn server() -> io::Result<()> {
70     let poll = Poll::new()?;
71     let mut server = UnixListener::bind(PATH)?;
72 
73     poll.register(&mut server, SERVER, Interest::READABLE)?;
74     let mut events = Events::with_capacity(128);
75     // Map of `Token` -> `UnixListener`.
76     let mut connections = HashMap::new();
77     let mut unique_token = Token(SERVER.0 + 1);
78     for _ in 0..3 {
79         poll.poll(&mut events, None)?;
80 
81         for event in events.iter() {
82             if SERVER == event.token() {
83                 let (mut stream, _) = server.accept()?;
84                 let token = Token(unique_token.0 + 1);
85                 unique_token = Token(unique_token.0 + 1);
86                 poll.register(&mut stream, token, Interest::READABLE | Interest::WRITABLE)?;
87                 connections.insert(token, stream);
88             } else {
89                 match connections.get_mut(&event.token()) {
90                     Some(connection) => {
91                         if event.is_writable() {
92                             match connection.write(b"Hello client") {
93                                 Err(_) => {
94                                     poll.deregister(connection)?;
95                                     poll.register(connection, event.token(), Interest::READABLE)?;
96                                     break;
97                                 }
98                                 Ok(_) => {
99                                     poll.deregister(connection)?;
100                                     poll.register(connection, event.token(), Interest::READABLE)?;
101                                     break;
102                                 }
103                             }
104                         } else if event.is_readable() {
105                             let mut msg_buf = [0_u8; 100];
106                             match connection.read(&mut msg_buf) {
107                                 Ok(0) => poll.deregister(connection)?,
108                                 Ok(n) => {
109                                     if let Ok(str_buf) = from_utf8(&msg_buf[0..n]) {
110                                         assert_eq!(str_buf, "Hello server");
111                                     } else {
112                                         println!("Received (none UTF-8) data: {:?}", &msg_buf);
113                                     }
114                                 }
115                                 Err(_n) => {
116                                     poll.deregister(connection)?;
117                                     break;
118                                 }
119                             }
120                         }
121                     }
122                     None => break,
123                 }
124             }
125         }
126     }
127     Ok(())
128 }
129 
/// SDV test for UnixDatagram.
///
/// # Brief
/// 1. Create a pair of UnixDatagram and check that the local and peer addresses
///    of an unnamed socket both print as "(unnamed)".
/// 2. Clone one end and shut down its write half. Sending on the clone must fail with
///    `BrokenPipe`, and receiving on it must fail as well.
/// 3. Send "Hello" over a fresh pair.
/// 4. Pass the receiving end through `into_raw_fd`/`from_raw_fd`, then receive the message.
/// 5. Check the received bytes, the sender address and the `Debug` output.
#[test]
fn sdv_uds_send_recv() {
    // `_` does not bind, so the other half of this pair is closed immediately.
    let (first, _) = UnixDatagram::pair().unwrap();
    assert_eq!(format!("{:?}", first.local_addr().unwrap()), "(unnamed)");
    assert_eq!(format!("{:?}", first.peer_addr().unwrap()), "(unnamed)");

    let write_closed = first.try_clone().unwrap();
    write_closed.shutdown(Shutdown::Write).unwrap();
    let send_err = write_closed.send(b"Hello").unwrap_err();
    assert_eq!(send_err.kind(), io::ErrorKind::BrokenPipe);

    let (tx, rx) = UnixDatagram::pair().unwrap();
    let sent = tx.send(b"Hello").expect("sender send failed");
    assert_eq!(sent, "Hello".len());
    let mut scratch = [0; 5];
    assert!(write_closed.recv(&mut scratch).is_err());

    let fd = rx.into_raw_fd();
    // SAFETY: `fd` was just released by `into_raw_fd`, so no other owner will use or close
    // it; the new `UnixDatagram` takes over sole ownership.
    let rx = unsafe { UnixDatagram::from_raw_fd(fd) };

    let mut recv_buf = [0_u8; 12];
    let received = loop {
        match rx.recv_from(&mut recv_buf[..]) {
            Err(e) if e.kind() == io::ErrorKind::WouldBlock => continue,
            Err(e) => panic!("{:?}", e),
            Ok((len, from)) => {
                assert_eq!(format!("{from:?}"), "(unnamed)");
                break len;
            }
        }
    };

    let debug_repr = format!("{rx:?}");
    let expected_fd = format!("fd: FileDesc(OwnedFd {{ fd: {fd} }})");
    assert!(debug_repr.contains(&expected_fd));
    assert!(debug_repr.contains("local: (unnamed), peer: (unnamed)"));

    assert_eq!(&recv_buf[..received], b"Hello");
}
181