Skip to content

Commit c9e5c5f

Browse files
committed
Allow FutureSpawner to return the result of the spawned future
`tokio::spawn` can be use both to spawn a forever-running background task or to spawn a task which gets `poll`ed independently and eventually returns a result which the callsite wants. In LDK, we have only ever needed the first, and thus didn't bother defining a return type for `FutureSpawner::spawn`. However, in the next commit we'll start using `FutureSpawner` in a context where we actually do want the spawned future's result. Thus, here, we add a result output to `FutureSpawner::spawn`, mirroring the `tokio::spawn` API.
1 parent ac848fe commit c9e5c5f

File tree

3 files changed

+125
-12
lines changed

3 files changed

+125
-12
lines changed

lightning-block-sync/src/gossip.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,12 @@ pub trait UtxoSource: BlockSource + 'static {
4747
pub struct TokioSpawner;
4848
#[cfg(feature = "tokio")]
4949
impl FutureSpawner for TokioSpawner {
50-
fn spawn<T: Future<Output = ()> + Send + 'static>(&self, future: T) {
51-
tokio::spawn(future);
50+
type E = tokio::task::JoinError;
51+
type SpawnedFutureResult<O> = tokio::task::JoinHandle<O>;
52+
fn spawn<O: Send + 'static, F: Future<Output = O> + Send + 'static>(
53+
&self, future: F,
54+
) -> Self::SpawnedFutureResult<O> {
55+
tokio::spawn(future)
5256
}
5357
}
5458

@@ -273,7 +277,7 @@ where
273277
let gossiper = Arc::clone(&self.gossiper);
274278
let block_cache = Arc::clone(&self.block_cache);
275279
let pmw = Arc::clone(&self.peer_manager_wake);
276-
self.spawn.spawn(async move {
280+
let _ = self.spawn.spawn(async move {
277281
let res = Self::retrieve_utxo(source, block_cache, short_channel_id).await;
278282
fut.resolve(gossiper.network_graph(), &*gossiper, res);
279283
(pmw)();

lightning/src/util/native_async.rs

Lines changed: 109 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,42 @@
88
//! environment.
99
1010
#[cfg(all(test, feature = "std"))]
11-
use crate::sync::Mutex;
11+
use crate::sync::{Arc, Mutex};
1212
use crate::util::async_poll::{MaybeSend, MaybeSync};
1313

14+
#[cfg(all(test, not(feature = "std")))]
15+
use alloc::rc::Rc;
16+
1417
#[cfg(all(test, not(feature = "std")))]
1518
use core::cell::RefCell;
19+
#[cfg(test)]
20+
use core::convert::Infallible;
1621
use core::future::Future;
1722
#[cfg(test)]
1823
use core::pin::Pin;
24+
#[cfg(test)]
25+
use core::task::{Context, Poll};
1926

20-
/// A generic trait which is able to spawn futures in the background.
27+
/// A generic trait which is able to spawn futures to be polled in the background.
28+
///
29+
/// When the spawned future completes, the returned [`Self::SpawnedFutureResult`] should resolve
30+
/// with the output of the spawned future.
31+
///
32+
/// Spawned futures must be polled independently in the background even if the returned
33+
/// [`Self::SpawnedFutureResult`] is dropped without being polled. This matches the semantics of
34+
/// `tokio::spawn`.
2135
pub trait FutureSpawner: MaybeSend + MaybeSync + 'static {
36+
/// The error type of [`Self::SpawnedFutureResult`]. This can be used to indicate that the
37+
/// spawned future was cancelled or panicked.
38+
type E;
39+
/// The result of [`Self::spawn`], a future which completes when the spawned future completes.
40+
type SpawnedFutureResult<O>: Future<Output = Result<O, Self::E>> + Unpin;
2241
/// Spawns the given future as a background task.
2342
///
2443
/// This method MUST NOT block on the given future immediately.
25-
fn spawn<T: Future<Output = ()> + MaybeSend + 'static>(&self, future: T);
44+
fn spawn<O: MaybeSend + 'static, T: Future<Output = O> + MaybeSend + 'static>(
45+
&self, future: T,
46+
) -> Self::SpawnedFutureResult<O>;
2647
}
2748

2849
#[cfg(test)]
@@ -37,6 +58,69 @@ pub(crate) struct FutureQueue(Mutex<Vec<Pin<Box<dyn MaybeSendableFuture>>>>);
3758
#[cfg(all(test, not(feature = "std")))]
3859
pub(crate) struct FutureQueue(RefCell<Vec<Pin<Box<dyn MaybeSendableFuture>>>>);
3960

61+
#[cfg(all(test, feature = "std"))]
62+
pub struct FutureQueueCompletion<O>(Arc<Mutex<Option<O>>>);
63+
#[cfg(all(test, not(feature = "std")))]
64+
pub struct FutureQueueCompletion<O>(Rc<RefCell<Option<O>>>);
65+
66+
#[cfg(all(test, feature = "std"))]
67+
impl<O> FutureQueueCompletion<O> {
68+
fn new() -> Self {
69+
Self(Arc::new(Mutex::new(None)))
70+
}
71+
72+
fn complete(&self, o: O) {
73+
*self.0.lock().unwrap() = Some(o);
74+
}
75+
}
76+
77+
#[cfg(all(test, feature = "std"))]
78+
impl<O> Clone for FutureQueueCompletion<O> {
79+
fn clone(&self) -> Self {
80+
Self(self.0.clone())
81+
}
82+
}
83+
84+
#[cfg(all(test, not(feature = "std")))]
85+
impl<O> FutureQueueCompletion<O> {
86+
fn new() -> Self {
87+
Self(Rc::new(RefCell::new(None)))
88+
}
89+
90+
fn complete(&self, o: O) {
91+
*self.0.lock().unwrap() = Some(o);
92+
}
93+
}
94+
95+
#[cfg(all(test, not(feature = "std")))]
96+
impl<O> Clone for FutureQueueCompletion<O> {
97+
fn clone(&self) -> Self {
98+
Self(self.0.clone())
99+
}
100+
}
101+
102+
#[cfg(all(test, feature = "std"))]
103+
impl<O> Future for FutureQueueCompletion<O> {
104+
type Output = Result<O, Infallible>;
105+
fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<O, Infallible>> {
106+
match Pin::into_inner(self).0.lock().unwrap().take() {
107+
None => Poll::Pending,
108+
Some(o) => Poll::Ready(Ok(o)),
109+
}
110+
}
111+
}
112+
113+
#[cfg(all(test, not(feature = "std")))]
114+
impl<O> Future for FutureQueueCompletion<O> {
115+
type Output = Result<O, Infallible>;
116+
fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<O, Infallible>> {
117+
match Pin::into_inner(self).0.get_mut().take() {
118+
None => Poll::Pending,
119+
Some(o) => Poll::Ready(Ok(o)),
120+
}
121+
}
122+
}
123+
40124
#[cfg(test)]
41125
impl FutureQueue {
42126
pub(crate) fn new() -> Self {
@@ -84,7 +168,16 @@ impl FutureQueue {
84168

85169
#[cfg(test)]
86170
impl FutureSpawner for FutureQueue {
87-
fn spawn<T: Future<Output = ()> + MaybeSend + 'static>(&self, future: T) {
171+
type E = Infallible;
172+
type SpawnedFutureResult<O> = FutureQueueCompletion<O>;
173+
fn spawn<O: MaybeSend + 'static, F: Future<Output = O> + MaybeSend + 'static>(
174+
&self, f: F,
175+
) -> FutureQueueCompletion<O> {
176+
let completion = FutureQueueCompletion::new();
177+
let compl_ref = completion.clone();
178+
let future = async move {
179+
compl_ref.complete(f.await);
180+
};
88181
#[cfg(feature = "std")]
89182
{
90183
self.0.lock().unwrap().push(Box::pin(future));
@@ -93,14 +186,24 @@ impl FutureSpawner for FutureQueue {
93186
{
94187
self.0.borrow_mut().push(Box::pin(future));
95188
}
189+
completion
96190
}
97191
}
98192

99193
#[cfg(test)]
100194
impl<D: core::ops::Deref<Target = FutureQueue> + MaybeSend + MaybeSync + 'static> FutureSpawner
101195
for D
102196
{
103-
fn spawn<T: Future<Output = ()> + MaybeSend + 'static>(&self, future: T) {
197+
type E = Infallible;
198+
type SpawnedFutureResult<O> = FutureQueueCompletion<O>;
199+
fn spawn<O: MaybeSend + 'static, F: Future<Output = O> + MaybeSend + 'static>(
200+
&self, f: F,
201+
) -> FutureQueueCompletion<O> {
202+
let completion = FutureQueueCompletion::new();
203+
let compl_ref = completion.clone();
204+
let future = async move {
205+
compl_ref.complete(f.await);
206+
};
104207
#[cfg(feature = "std")]
105208
{
106209
self.0.lock().unwrap().push(Box::pin(future));
@@ -109,5 +212,6 @@ impl<D: core::ops::Deref<Target = FutureQueue> + MaybeSend + MaybeSync + 'static
109212
{
110213
self.0.borrow_mut().push(Box::pin(future));
111214
}
215+
completion
112216
}
113217
}

lightning/src/util/persist.rs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use alloc::sync::Arc;
1616
use bitcoin::hashes::hex::FromHex;
1717
use bitcoin::{BlockHash, Txid};
1818

19+
use core::convert::Infallible;
1920
use core::future::Future;
2021
use core::mem;
2122
use core::ops::Deref;
@@ -407,7 +408,11 @@ where
407408

408409
struct PanicingSpawner;
409410
impl FutureSpawner for PanicingSpawner {
410-
fn spawn<T: Future<Output = ()> + MaybeSend + 'static>(&self, _: T) {
411+
type E = Infallible;
412+
type SpawnedFutureResult<O> = Box<dyn Future<Output = Result<O, Infallible>> + Unpin>;
413+
fn spawn<O, T: Future<Output = O> + MaybeSend + 'static>(
414+
&self, _: T,
415+
) -> Self::SpawnedFutureResult<O> {
411416
unreachable!();
412417
}
413418
}
@@ -865,7 +870,7 @@ where
865870
let future = inner.persist_new_channel(monitor_name, monitor);
866871
let channel_id = monitor.channel_id();
867872
let completion = (monitor.channel_id(), monitor.get_latest_update_id());
868-
self.0.future_spawner.spawn(async move {
873+
let _ = self.0.future_spawner.spawn(async move {
869874
match future.await {
870875
Ok(()) => inner.async_completed_updates.lock().unwrap().push(completion),
871876
Err(e) => {
@@ -893,7 +898,7 @@ where
893898
None
894899
};
895900
let inner = Arc::clone(&self.0);
896-
self.0.future_spawner.spawn(async move {
901+
let _ = self.0.future_spawner.spawn(async move {
897902
match future.await {
898903
Ok(()) => if let Some(completion) = completion {
899904
inner.async_completed_updates.lock().unwrap().push(completion);
@@ -910,7 +915,7 @@ where
910915

911916
pub(crate) fn spawn_async_archive_persisted_channel(&self, monitor_name: MonitorName) {
912917
let inner = Arc::clone(&self.0);
913-
self.0.future_spawner.spawn(async move {
918+
let _ = self.0.future_spawner.spawn(async move {
914919
inner.archive_persisted_channel(monitor_name).await;
915920
});
916921
}

0 commit comments

Comments
 (0)