Skip to content

Commit c73547b

Browse files
committed
Add net module
1 parent abf5dec commit c73547b

File tree

9 files changed

+412
-0
lines changed

9 files changed

+412
-0
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ regex = ["dep:regex", "dep:ouroboros", "dep:quick_cache"]
2222
yaml = ["mlua/serde", "dep:ouroboros", "dep:serde", "dep:serde_yaml"]
2323
http = ["dep:http", "dep:http-serde-ext"]
2424
task = ["async"]
25+
net = ["task"]
2526

2627
[dependencies]
2728
bytes = "1"

src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,6 @@ pub mod http;
2727

2828
#[cfg(feature = "task")]
2929
pub mod task;
30+
31+
#[cfg(feature = "net")]
32+
pub mod net;

src/net/mod.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
use mlua::{Lua, Result, Table};
2+
3+
/// A loader for the `net` module.
4+
fn loader(lua: &Lua) -> Result<Table> {
5+
let t = lua.create_table()?;
6+
Ok(t)
7+
}
8+
9+
/// Registers the `net` module in the given Lua state.
10+
pub fn register(lua: &Lua, name: Option<&str>) -> Result<Table> {
11+
let name = name.unwrap_or("@net");
12+
let value = loader(lua)?;
13+
lua.register_module(name, &value)?;
14+
Ok(value)
15+
}
16+
17+
pub mod tcp;

src/net/tcp/listener.rs

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
use std::io;
2+
use std::net::SocketAddr;
3+
use std::result::Result as StdResult;
4+
5+
use mlua::{Lua, Result, Table, UserData, UserDataMethods, UserDataRegistry};
6+
use tokio::net::lookup_host;
7+
8+
use super::{SocketOptions, TcpSocket, TcpStream};
9+
10+
pub struct TcpListener(tokio::net::TcpListener);
11+
12+
impl UserData for TcpListener {
13+
fn register(registry: &mut UserDataRegistry<Self>) {
14+
registry.add_method("local_addr", |_, this, ()| Ok(this.0.local_addr()?.to_string()));
15+
16+
registry.add_async_function("listen", listen);
17+
18+
registry.add_async_method("accept", |_, this, ()| async move {
19+
let (stream, _) = lua_try!(this.0.accept().await);
20+
Ok(Ok(TcpStream::from(stream)))
21+
});
22+
}
23+
}
24+
25+
pub async fn listen(
26+
_: Lua,
27+
(addr, params): (String, Option<Table>),
28+
) -> Result<StdResult<TcpListener, String>> {
29+
let addrs = lua_try!(lookup_host(addr).await);
30+
31+
let sock_options = SocketOptions::from_table(&params)?;
32+
let backlog = opt_param!(params, "backlog")?;
33+
34+
let try_listen = |addr: SocketAddr| {
35+
let sock = TcpSocket::new_for_addr(addr)?;
36+
sock.set_options(sock_options)?;
37+
sock.0.set_reuseaddr(true)?;
38+
sock.0.bind(addr)?;
39+
let listener = TcpListener(sock.0.listen(backlog.unwrap_or(1024))?);
40+
io::Result::Ok(listener)
41+
};
42+
43+
let mut last_err = None;
44+
for addr in addrs {
45+
match try_listen(addr) {
46+
Ok(sock) => return Ok(Ok(sock)),
47+
Err(e) => {
48+
last_err = Some(e);
49+
continue;
50+
}
51+
}
52+
}
53+
54+
Ok(Err(last_err.map(|err| err.to_string()).unwrap_or_else(|| {
55+
"could not resolve to any address".to_string()
56+
})))
57+
}

src/net/tcp/mod.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
use mlua::{Lua, Result, Table};
2+
3+
pub use listener::{TcpListener, listen};
4+
pub use stream::{TcpStream, connect};
5+
6+
use socket::{SocketOptions, TcpSocket};
7+
8+
/// A loader for the `net/tcp` module.
9+
fn loader(lua: &Lua) -> Result<Table> {
10+
let t = lua.create_table()?;
11+
t.set("TcpListener", lua.create_proxy::<TcpListener>()?)?;
12+
t.set("TcpStream", lua.create_proxy::<TcpStream>()?)?;
13+
t.set("listen", lua.create_async_function(listen)?)?;
14+
t.set("connect", lua.create_async_function(connect)?)?;
15+
Ok(t)
16+
}
17+
18+
/// Registers the `net/tcp` module in the given Lua state.
19+
pub fn register(lua: &Lua, name: Option<&str>) -> Result<Table> {
20+
let name = name.unwrap_or("@net/tcp");
21+
let value = loader(lua)?;
22+
lua.register_module(name, &value)?;
23+
Ok(value)
24+
}
25+
26+
mod listener;
27+
mod socket;
28+
mod stream;

src/net/tcp/socket.rs

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
use std::io::Result as IoResult;
2+
use std::net::SocketAddr;
3+
use std::ops::Deref;
4+
5+
use mlua::{Result, Table};
6+
use tokio::net::TcpSocket as TokioTcpSocket;
7+
8+
pub(crate) struct TcpSocket(pub(crate) TokioTcpSocket);
9+
10+
impl Deref for TcpSocket {
11+
type Target = TokioTcpSocket;
12+
13+
#[inline]
14+
fn deref(&self) -> &Self::Target {
15+
&self.0
16+
}
17+
}
18+
19+
#[derive(Debug, Copy, Clone)]
20+
pub(super) struct SocketOptions {
21+
keepalive: Option<bool>,
22+
nodelay: Option<bool>,
23+
recv_buffer_size: Option<u32>,
24+
send_buffer_size: Option<u32>,
25+
reuseraddr: Option<bool>,
26+
reuseport: Option<bool>,
27+
}
28+
29+
impl TcpSocket {
30+
pub(crate) fn new_for_addr(addr: SocketAddr) -> IoResult<Self> {
31+
let sock = match addr {
32+
SocketAddr::V4(_) => TokioTcpSocket::new_v4()?,
33+
SocketAddr::V6(_) => TokioTcpSocket::new_v6()?,
34+
};
35+
Ok(TcpSocket(sock))
36+
}
37+
38+
pub(crate) fn set_options(&self, options: SocketOptions) -> IoResult<()> {
39+
if let Some(keepalive) = options.keepalive {
40+
self.set_keepalive(keepalive)?;
41+
}
42+
if let Some(nodelay) = options.nodelay {
43+
self.set_nodelay(nodelay)?;
44+
}
45+
if let Some(reuseraddr) = options.reuseraddr {
46+
self.set_reuseaddr(reuseraddr)?;
47+
}
48+
if let Some(reuseport) = options.reuseport {
49+
self.set_reuseport(reuseport)?;
50+
}
51+
if let Some(recv_buffer_size) = options.recv_buffer_size {
52+
self.set_recv_buffer_size(recv_buffer_size)?;
53+
}
54+
if let Some(send_buffer_size) = options.send_buffer_size {
55+
self.set_send_buffer_size(send_buffer_size)?;
56+
}
57+
Ok(())
58+
}
59+
}
60+
61+
impl SocketOptions {
62+
pub(crate) fn from_table(params: &Option<Table>) -> Result<Self> {
63+
Ok(SocketOptions {
64+
keepalive: opt_param!(params, "keepalive")?,
65+
nodelay: opt_param!(params, "nodelay")?,
66+
recv_buffer_size: opt_param!(params, "recv_buffer_size")?,
67+
send_buffer_size: opt_param!(params, "send_buffer_size")?,
68+
reuseraddr: opt_param!(params, "reuseaddr")?,
69+
reuseport: opt_param!(params, "reuseport")?,
70+
})
71+
}
72+
}

src/net/tcp/stream.rs

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
use std::net::SocketAddr;
2+
use std::ops::{Deref, DerefMut};
3+
use std::result::Result as StdResult;
4+
5+
use mlua::{Lua, Result, String as LuaString, Table, UserData, UserDataMethods, UserDataRegistry};
6+
use tokio::io::{AsyncReadExt as _, AsyncWriteExt as _};
7+
use tokio::net::lookup_host;
8+
use tokio::time::timeout;
9+
10+
use super::{SocketOptions, TcpSocket};
11+
use crate::time::Duration;
12+
13+
pub struct TcpStream {
14+
stream: tokio::net::TcpStream,
15+
read_timeout: Option<Duration>,
16+
write_timeout: Option<Duration>,
17+
}
18+
19+
impl Deref for TcpStream {
20+
type Target = tokio::net::TcpStream;
21+
22+
#[inline]
23+
fn deref(&self) -> &Self::Target {
24+
&self.stream
25+
}
26+
}
27+
28+
impl DerefMut for TcpStream {
29+
#[inline]
30+
fn deref_mut(&mut self) -> &mut Self::Target {
31+
&mut self.stream
32+
}
33+
}
34+
35+
macro_rules! with_io_timeout {
36+
($timeout:expr, $fut:expr) => {
37+
match $timeout {
38+
Some(dur) => (timeout(dur.0, $fut).await)
39+
.map_err(|e| std::io::Error::new(std::io::ErrorKind::TimedOut, e))
40+
.flatten(),
41+
None => $fut.await,
42+
}
43+
};
44+
}
45+
46+
impl From<tokio::net::TcpStream> for TcpStream {
47+
fn from(stream: tokio::net::TcpStream) -> Self {
48+
TcpStream {
49+
stream,
50+
read_timeout: None,
51+
write_timeout: None,
52+
}
53+
}
54+
}
55+
56+
impl UserData for TcpStream {
57+
fn register(registry: &mut UserDataRegistry<Self>) {
58+
registry.add_async_function("connect", connect);
59+
60+
registry.add_method("local_addr", |_, this, ()| Ok(this.local_addr()?.to_string()));
61+
62+
registry.add_method("peer_addr", |_, this, ()| Ok(this.peer_addr()?.to_string()));
63+
64+
registry.add_method_mut("set_read_timeout", |_, this, dur: Option<Duration>| {
65+
this.read_timeout = dur;
66+
Ok(())
67+
});
68+
69+
registry.add_method_mut("set_write_timeout", |_, this, dur: Option<Duration>| {
70+
this.write_timeout = dur;
71+
Ok(())
72+
});
73+
74+
registry.add_async_method_mut("read", |lua, mut this, size: usize| async move {
75+
let mut buf = vec![0; size];
76+
let n = with_io_timeout!(this.read_timeout, this.read(&mut buf));
77+
let n = lua_try!(n);
78+
buf.truncate(n);
79+
Ok(Ok(lua.create_string(buf)?))
80+
});
81+
82+
registry.add_async_method_mut("read_to_end", |lua, mut this, ()| async move {
83+
let mut buf = Vec::new();
84+
let n = with_io_timeout!(this.read_timeout, this.read_to_end(&mut buf));
85+
let _n = lua_try!(n);
86+
Ok(Ok(lua.create_string(buf)?))
87+
});
88+
89+
registry.add_async_method_mut("write", |_, mut this, data: LuaString| async move {
90+
let n = with_io_timeout!(this.write_timeout, this.write(&data.as_bytes()));
91+
let n = lua_try!(n);
92+
Ok(Ok(n))
93+
});
94+
95+
registry.add_async_method_mut("write_all", |_, mut this, data: LuaString| async move {
96+
let r = with_io_timeout!(this.write_timeout, this.write_all(&data.as_bytes()));
97+
lua_try!(r);
98+
Ok(Ok(true))
99+
});
100+
101+
registry.add_async_method_mut("flush", |_, mut this, ()| async move {
102+
let r = with_io_timeout!(this.write_timeout, this.flush());
103+
lua_try!(r);
104+
Ok(Ok(true))
105+
});
106+
107+
registry.add_async_method_mut("shutdown", |_, mut this, ()| async move {
108+
lua_try!(this.shutdown().await);
109+
Ok(Ok(true))
110+
});
111+
}
112+
}
113+
114+
pub async fn connect(
115+
_: Lua,
116+
(addr, params): (String, Option<Table>),
117+
) -> Result<StdResult<TcpStream, String>> {
118+
let addrs = lua_try!(lookup_host(addr).await);
119+
let options = SocketOptions::from_table(&params)?;
120+
121+
let timeout = opt_param!(Duration, params, "timeout")?; // A single timeout for any operation
122+
let connect_timeout = opt_param!(Duration, params, "connect_timeout")?.or(timeout);
123+
let read_timeout = opt_param!(Duration, params, "read_timeout")?.or(timeout);
124+
let write_timeout = opt_param!(Duration, params, "write_timeout")?.or(timeout);
125+
126+
let try_connect = |addr: SocketAddr| async move {
127+
let sock = TcpSocket::new_for_addr(addr)?;
128+
sock.set_options(options)?;
129+
with_io_timeout!(connect_timeout, sock.0.connect(addr))
130+
};
131+
132+
let mut last_err = None;
133+
for addr in addrs {
134+
match try_connect(addr).await {
135+
Ok(stream) => {
136+
return Ok(Ok(TcpStream {
137+
stream,
138+
read_timeout,
139+
write_timeout,
140+
}));
141+
}
142+
Err(e) => {
143+
last_err = Some(e);
144+
continue;
145+
}
146+
}
147+
}
148+
149+
Ok(Err(last_err.map(|err| err.to_string()).unwrap_or_else(|| {
150+
"could not resolve to any address".to_string()
151+
})))
152+
}

tests/lib.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@ async fn run_file(modname: &str) -> Result<()> {
2121
mlua_stdlib::regex::register(&lua, None)?;
2222
#[cfg(feature = "http")]
2323
mlua_stdlib::http::register(&lua, None)?;
24+
#[cfg(feature = "net")]
25+
{
26+
mlua_stdlib::net::register(&lua, None)?;
27+
mlua_stdlib::net::tcp::register(&lua, None)?;
28+
}
2429
#[cfg(feature = "task")]
2530
mlua_stdlib::task::register(&lua, None)?;
2631

@@ -91,6 +96,11 @@ include_tests! {
9196
headers,
9297
},
9398

99+
#[cfg(feature = "net")]
100+
net {
101+
tcp,
102+
},
103+
94104
#[cfg(feature = "task")]
95105
task,
96106
}

0 commit comments

Comments
 (0)