Skip to content

Commit 3a5a556

Browse files
committed
Add net/tls module
1 parent 648921d commit 3a5a556

26 files changed

+1330
-90
lines changed

Cargo.toml

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,11 @@ yaml = ["mlua/serde", "dep:ouroboros", "dep:serde", "dep:serde_yaml"]
2323
http = ["dep:http", "dep:http-serde-ext"]
2424
task = ["async"]
2525
net = ["task"]
26+
tls = ["net", "dep:rustls", "dep:tokio-rustls", "dep:rustls-pemfile", "dep:rustls-native-certs", "dep:webpki-roots"]
2627

2728
[dependencies]
2829
bytes = "1"
29-
mlua = { version = "0.11", features = ["error-send"] }
30+
mlua = { version = "0.11", features = ["error-send", "macros"] }
3031
ouroboros = { version = "0.18", optional = true }
3132
serde = { version = "1.0", optional = true }
3233
serde_json = { version = "1.0", optional = true }
@@ -43,5 +44,12 @@ http-serde-ext = { version = "1.0", optional = true }
4344
tokio = { version = "1", features = ["full"], optional = true }
4445
tokio-util = { version = "0.7", features = ["time"], optional = true }
4546

47+
# tls
48+
rustls = { version = "0.23", optional = true, default-features = false, features = ["aws-lc-rs"] }
49+
tokio-rustls = { version = "0.26", optional = true, default-features = false }
50+
rustls-pemfile = { version = "2", optional = true }
51+
rustls-native-certs = { version = "0.8", optional = true }
52+
webpki-roots = { version = "1", optional = true }
53+
4654
[dev-dependencies]
4755
tokio = { version = "1", features = ["full"] }

src/macros.rs

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ macro_rules! opt_param {
2525
Ok(None) => Ok(None),
2626
Err(err) => {
2727
use mlua::ErrorContext as _;
28-
Err(err.with_context(|_| format!("invalid `{}`", $name)))
28+
Err(err.with_context(|_| format!("invalid `{}` param", $name)))
2929
}
3030
}
3131
};
@@ -39,7 +39,20 @@ macro_rules! opt_param {
3939
Ok(None) => Ok(None),
4040
Err(err) => {
4141
use mlua::ErrorContext as _;
42-
Err(err.with_context(|_| format!("invalid `{}`", $name)))
42+
Err(err.with_context(|_| format!("invalid `{}` param", $name)))
43+
}
44+
}
45+
};
46+
}
47+
48+
macro_rules! param {
49+
($table:expr, $name:expr) => {
50+
match $table.raw_get::<Option<_>>($name) {
51+
Ok(Some(v)) => Ok(v),
52+
Ok(None) => Err(mlua::Error::runtime(format!("`{}` param is required", $name))),
53+
Err(err) => {
54+
use mlua::ErrorContext as _;
55+
Err(err.with_context(|_| format!("invalid `{}` param", $name)))
4356
}
4457
}
4558
};

src/net/common.rs

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
//! Common types and traits shared between client and server TLS implementations.
2+
3+
use std::borrow::Cow;
4+
use std::{fmt, io};
5+
6+
use mlua::{IntoLua, Lua, Result, Value};
7+
8+
/// Socket address that can be either TCP or Unix domain socket.
9+
pub enum AnySocketAddr {
10+
Tcp(std::net::SocketAddr),
11+
#[cfg(unix)]
12+
Unix(tokio::net::unix::SocketAddr),
13+
}
14+
15+
impl fmt::Display for AnySocketAddr {
16+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
17+
match self {
18+
AnySocketAddr::Tcp(addr) => write!(f, "{addr}"),
19+
#[cfg(unix)]
20+
AnySocketAddr::Unix(addr) => {
21+
let path = addr
22+
.as_pathname()
23+
.map(|p| p.to_string_lossy())
24+
.unwrap_or_else(|| Cow::Borrowed("(unnamed)"));
25+
write!(f, "{path}")
26+
}
27+
}
28+
}
29+
}
30+
31+
impl IntoLua for AnySocketAddr {
32+
fn into_lua(self, lua: &Lua) -> Result<Value> {
33+
lua.create_string(self.to_string()).map(Value::String)
34+
}
35+
}
36+
37+
/// Trait for getting local and peer addresses from various stream types.
38+
pub trait AddressProvider {
39+
fn local_addr(&self) -> io::Result<AnySocketAddr>;
40+
fn peer_addr(&self) -> io::Result<AnySocketAddr>;
41+
}
42+
43+
impl AddressProvider for tokio::net::TcpStream {
44+
fn local_addr(&self) -> io::Result<AnySocketAddr> {
45+
Ok(AnySocketAddr::Tcp(self.local_addr()?))
46+
}
47+
48+
fn peer_addr(&self) -> io::Result<AnySocketAddr> {
49+
Ok(AnySocketAddr::Tcp(self.peer_addr()?))
50+
}
51+
}
52+
53+
#[cfg(unix)]
54+
impl AddressProvider for tokio::net::UnixStream {
55+
fn local_addr(&self) -> io::Result<AnySocketAddr> {
56+
Ok(AnySocketAddr::Unix(self.local_addr()?))
57+
}
58+
59+
fn peer_addr(&self) -> io::Result<AnySocketAddr> {
60+
Ok(AnySocketAddr::Unix(self.peer_addr()?))
61+
}
62+
}

src/net/mod.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
use mlua::{Lua, Result, Table};
22

3+
pub use common::{AddressProvider, AnySocketAddr};
4+
35
/// A loader for the `net` module.
46
fn loader(lua: &Lua) -> Result<Table> {
57
let t = lua.create_table()?;
@@ -25,7 +27,10 @@ macro_rules! with_io_timeout {
2527
};
2628
}
2729

28-
pub mod tcp;
30+
mod common;
2931

32+
pub mod tcp;
33+
#[cfg(feature = "tls")]
34+
pub mod tls;
3035
#[cfg(unix)]
3136
pub mod unix;

src/net/tcp/listener.rs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,19 @@ use mlua::{Lua, Result, Table, UserData, UserDataMethods, UserDataRegistry};
66
use tokio::net::lookup_host;
77

88
use super::{SocketOptions, TcpSocket, TcpStream};
9+
use crate::net::common::AnySocketAddr;
910

10-
pub struct TcpListener(tokio::net::TcpListener);
11+
pub struct TcpListener(pub(crate) tokio::net::TcpListener);
12+
13+
impl TcpListener {
14+
pub(crate) fn local_addr(&self) -> io::Result<AnySocketAddr> {
15+
self.0.local_addr().map(AnySocketAddr::Tcp)
16+
}
17+
}
1118

1219
impl UserData for TcpListener {
1320
fn register(registry: &mut UserDataRegistry<Self>) {
14-
registry.add_method("local_addr", |_, this, ()| Ok(this.0.local_addr()?.to_string()));
21+
registry.add_method("local_addr", |_, this, ()| Ok(this.local_addr()?));
1522

1623
registry.add_async_function("listen", listen);
1724

@@ -24,9 +31,10 @@ impl UserData for TcpListener {
2431

2532
pub async fn listen(
2633
_: Lua,
27-
(addr, params): (String, Option<Table>),
34+
(addr, port, params): (String, Option<u16>, Option<Table>),
2835
) -> Result<StdResult<TcpListener, String>> {
29-
let addrs = lua_try!(lookup_host(addr).await);
36+
let port = port.unwrap_or(0);
37+
let addrs = lua_try!(lookup_host((addr, port)).await);
3038

3139
let sock_options = SocketOptions::from_table(&params)?;
3240
let backlog = opt_param!(params, "backlog")?;

src/net/tcp/socket.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::io::Result as IoResult;
1+
use std::io;
22
use std::net::SocketAddr;
33
use std::ops::Deref;
44

@@ -26,15 +26,15 @@ pub(super) struct SocketOptions {
2626
}
2727

2828
impl TcpSocket {
29-
pub(crate) fn new_for_addr(addr: SocketAddr) -> IoResult<Self> {
29+
pub(crate) fn new_for_addr(addr: SocketAddr) -> io::Result<Self> {
3030
let sock = match addr {
3131
SocketAddr::V4(_) => tokio::net::TcpSocket::new_v4()?,
3232
SocketAddr::V6(_) => tokio::net::TcpSocket::new_v6()?,
3333
};
3434
Ok(TcpSocket(sock))
3535
}
3636

37-
pub(crate) fn set_options(&self, options: SocketOptions) -> IoResult<()> {
37+
pub(crate) fn set_options(&self, options: SocketOptions) -> io::Result<()> {
3838
if let Some(keepalive) = options.keepalive {
3939
self.set_keepalive(keepalive)?;
4040
}

src/net/tcp/stream.rs

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use std::io;
12
use std::net::SocketAddr;
23
use std::ops::{Deref, DerefMut};
34
use std::result::Result as StdResult;
@@ -7,10 +8,12 @@ use tokio::io::{AsyncReadExt as _, AsyncWriteExt as _};
78
use tokio::net::lookup_host;
89

910
use super::{SocketOptions, TcpSocket};
11+
use crate::net::{AddressProvider, AnySocketAddr};
1012
use crate::time::Duration;
1113

1214
pub struct TcpStream {
1315
pub(crate) stream: tokio::net::TcpStream,
16+
pub(crate) host: Option<String>,
1417
pub(crate) read_timeout: Option<Duration>,
1518
pub(crate) write_timeout: Option<Duration>,
1619
}
@@ -35,19 +38,29 @@ impl From<tokio::net::TcpStream> for TcpStream {
3538
fn from(stream: tokio::net::TcpStream) -> Self {
3639
TcpStream {
3740
stream,
41+
host: None,
3842
read_timeout: None,
3943
write_timeout: None,
4044
}
4145
}
4246
}
4347

48+
impl AddressProvider for TcpStream {
49+
fn local_addr(&self) -> io::Result<AnySocketAddr> {
50+
self.stream.local_addr().map(AnySocketAddr::Tcp)
51+
}
52+
53+
fn peer_addr(&self) -> io::Result<AnySocketAddr> {
54+
self.stream.peer_addr().map(AnySocketAddr::Tcp)
55+
}
56+
}
57+
4458
impl UserData for TcpStream {
4559
fn register(registry: &mut UserDataRegistry<Self>) {
4660
registry.add_async_function("connect", connect);
4761

48-
registry.add_method("local_addr", |_, this, ()| Ok(this.local_addr()?.to_string()));
49-
50-
registry.add_method("peer_addr", |_, this, ()| Ok(this.peer_addr()?.to_string()));
62+
registry.add_method("local_addr", |_, this, ()| Ok(this.local_addr()?));
63+
registry.add_method("peer_addr", |_, this, ()| Ok(this.peer_addr()?));
5164

5265
registry.add_method_mut("set_read_timeout", |_, this, dur: Option<Duration>| {
5366
this.read_timeout = dur;
@@ -101,9 +114,9 @@ impl UserData for TcpStream {
101114

102115
pub async fn connect(
103116
_: Lua,
104-
(addr, params): (String, Option<Table>),
117+
(host, port, params): (String, u16, Option<Table>),
105118
) -> Result<StdResult<TcpStream, String>> {
106-
let addrs = lua_try!(lookup_host(addr).await);
119+
let addrs = lua_try!(lookup_host((&*host, port)).await);
107120
let options = SocketOptions::from_table(&params)?;
108121

109122
let timeout = opt_param!(Duration, params, "timeout")?; // A single timeout for any operation
@@ -123,6 +136,7 @@ pub async fn connect(
123136
Ok(stream) => {
124137
return Ok(Ok(TcpStream {
125138
stream,
139+
host: Some(host.clone()),
126140
read_timeout,
127141
write_timeout,
128142
}));

0 commit comments

Comments
 (0)