Skip to content

Commit 4f862e3

Browse files
committed
Allow FutureSpawner to return the result of the spawned future
`tokio::spawn` can be use both to spawn a forever-running background task or to spawn a task which gets `poll`ed independently and eventually returns a result which the callsite wants. In LDK, we have only ever needed the first, and thus didn't bother defining a return type for `FutureSpawner::spawn`. However, in the next commit we'll start using `FutureSpawner` in a context where we actually do want the spawned future's result. Thus, here, we add a result output to `FutureSpawner::spawn`, mirroring the `tokio::spawn` API.
1 parent ac848fe commit 4f862e3

File tree

3 files changed

+124
-12
lines changed

3 files changed

+124
-12
lines changed

lightning-block-sync/src/gossip.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,12 @@ pub trait UtxoSource: BlockSource + 'static {
4747
pub struct TokioSpawner;
4848
#[cfg(feature = "tokio")]
4949
impl FutureSpawner for TokioSpawner {
50-
fn spawn<T: Future<Output = ()> + Send + 'static>(&self, future: T) {
51-
tokio::spawn(future);
50+
type E = tokio::task::JoinError;
51+
type SpawnedFutureResult<O> = tokio::task::JoinHandle<O>;
52+
fn spawn<O: Send + 'static, F: Future<Output = O> + Send + 'static>(
53+
&self, future: F,
54+
) -> Self::SpawnedFutureResult<O> {
55+
tokio::spawn(future)
5256
}
5357
}
5458

@@ -273,7 +277,7 @@ where
273277
let gossiper = Arc::clone(&self.gossiper);
274278
let block_cache = Arc::clone(&self.block_cache);
275279
let pmw = Arc::clone(&self.peer_manager_wake);
276-
self.spawn.spawn(async move {
280+
let _ = self.spawn.spawn(async move {
277281
let res = Self::retrieve_utxo(source, block_cache, short_channel_id).await;
278282
fut.resolve(gossiper.network_graph(), &*gossiper, res);
279283
(pmw)();

lightning/src/util/native_async.rs

Lines changed: 108 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,41 @@
88
//! environment.
99
1010
#[cfg(all(test, feature = "std"))]
11-
use crate::sync::Mutex;
11+
use crate::sync::{Arc, Mutex};
1212
use crate::util::async_poll::{MaybeSend, MaybeSync};
1313

1414
#[cfg(all(test, not(feature = "std")))]
1515
use core::cell::RefCell;
16+
#[cfg(test)]
17+
use core::convert::Infallible;
1618
use core::future::Future;
1719
#[cfg(test)]
1820
use core::pin::Pin;
21+
#[cfg(all(test, not(feature = "std")))]
22+
use core::rc::Rc;
23+
#[cfg(test)]
24+
use core::task::{Context, Poll};
1925

20-
/// A generic trait which is able to spawn futures in the background.
26+
/// A generic trait which is able to spawn futures to be polled in the background.
27+
///
28+
/// When the spawned future completes, the returned [`Self::SpawnedFutureResult`] should resolve
29+
/// with the output of the spawned future.
30+
///
31+
/// Spawned futures must be polled independently in the background even if the returned
32+
/// [`Self::SpawnedFutureResult`] is dropped without being polled. This matches the semantics of
33+
/// `tokio::spawn`.
2134
pub trait FutureSpawner: MaybeSend + MaybeSync + 'static {
35+
/// The error type of [`Self::SpawnedFutureResult`]. This can be used to indicate that the
36+
/// spawned future was cancelled or panicked.
37+
type E;
38+
/// The result of [`Self::spawn`], a future which completes when the spawned future completes.
39+
type SpawnedFutureResult<O>: Future<Output = Result<O, Self::E>> + Unpin;
2240
/// Spawns the given future as a background task.
2341
///
2442
/// This method MUST NOT block on the given future immediately.
25-
fn spawn<T: Future<Output = ()> + MaybeSend + 'static>(&self, future: T);
43+
fn spawn<O: MaybeSend + 'static, T: Future<Output = O> + MaybeSend + 'static>(
44+
&self, future: T,
45+
) -> Self::SpawnedFutureResult<O>;
2646
}
2747

2848
#[cfg(test)]
@@ -37,6 +57,69 @@ pub(crate) struct FutureQueue(Mutex<Vec<Pin<Box<dyn MaybeSendableFuture>>>>);
3757
#[cfg(all(test, not(feature = "std")))]
3858
pub(crate) struct FutureQueue(RefCell<Vec<Pin<Box<dyn MaybeSendableFuture>>>>);
3959

60+
#[cfg(all(test, feature = "std"))]
61+
pub struct FutureQueueCompletion<O>(Arc<Mutex<Option<O>>>);
62+
#[cfg(all(test, not(feature = "std")))]
63+
pub struct FutureQueueCompletion<O>(Rc<RefCell<Option<O>>>);
64+
65+
#[cfg(all(test, feature = "std"))]
66+
impl<O> FutureQueueCompletion<O> {
67+
fn new() -> Self {
68+
Self(Arc::new(Mutex::new(None)))
69+
}
70+
71+
fn complete(&self, o: O) {
72+
*self.0.lock().unwrap() = Some(o);
73+
}
74+
}
75+
76+
#[cfg(all(test, feature = "std"))]
77+
impl<O> Clone for FutureQueueCompletion<O> {
78+
fn clone(&self) -> Self {
79+
Self(self.0.clone())
80+
}
81+
}
82+
83+
#[cfg(all(test, not(feature = "std")))]
84+
impl<O> FutureQueueCompletion<O> {
85+
fn new() -> Self {
86+
Self(Rc::new(RefCell::new(None)))
87+
}
88+
89+
fn complete(&self, o: O) {
90+
*self.0.lock().unwrap() = Some(o);
91+
}
92+
}
93+
94+
#[cfg(all(test, not(feature = "std")))]
95+
impl<O> Clone for FutureQueueCompletion<O> {
96+
fn clone(&self) -> Self {
97+
Self(self.0.clone())
98+
}
99+
}
100+
101+
#[cfg(all(test, feature = "std"))]
102+
impl<O> Future for FutureQueueCompletion<O> {
103+
type Output = Result<O, Infallible>;
104+
fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<O, Infallible>> {
105+
match Pin::into_inner(self).0.lock().unwrap().take() {
106+
None => Poll::Pending,
107+
Some(o) => Poll::Ready(Ok(o)),
108+
}
109+
}
110+
}
111+
112+
#[cfg(all(test, not(feature = "std")))]
113+
impl<O> Future for FutureQueueCompletion<O> {
114+
type Output = Result<O, Infallible>;
115+
fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<O, Infallible>> {
116+
match Pin::into_inner(self).0.get_mut().take() {
117+
None => Poll::Pending,
118+
Some(o) => Poll::Ready(Ok(o)),
119+
}
120+
}
121+
}
122+
40123
#[cfg(test)]
41124
impl FutureQueue {
42125
pub(crate) fn new() -> Self {
@@ -84,7 +167,16 @@ impl FutureQueue {
84167

85168
#[cfg(test)]
86169
impl FutureSpawner for FutureQueue {
87-
fn spawn<T: Future<Output = ()> + MaybeSend + 'static>(&self, future: T) {
170+
type E = Infallible;
171+
type SpawnedFutureResult<O> = FutureQueueCompletion<O>;
172+
fn spawn<O: MaybeSend + 'static, F: Future<Output = O> + MaybeSend + 'static>(
173+
&self, f: F,
174+
) -> FutureQueueCompletion<O> {
175+
let completion = FutureQueueCompletion::new();
176+
let compl_ref = completion.clone();
177+
let future = async move {
178+
compl_ref.complete(f.await);
179+
};
88180
#[cfg(feature = "std")]
89181
{
90182
self.0.lock().unwrap().push(Box::pin(future));
@@ -93,14 +185,24 @@ impl FutureSpawner for FutureQueue {
93185
{
94186
self.0.borrow_mut().push(Box::pin(future));
95187
}
188+
completion
96189
}
97190
}
98191

99192
#[cfg(test)]
100193
impl<D: core::ops::Deref<Target = FutureQueue> + MaybeSend + MaybeSync + 'static> FutureSpawner
101194
for D
102195
{
103-
fn spawn<T: Future<Output = ()> + MaybeSend + 'static>(&self, future: T) {
196+
type E = Infallible;
197+
type SpawnedFutureResult<O> = FutureQueueCompletion<O>;
198+
fn spawn<O: MaybeSend + 'static, F: Future<Output = O> + MaybeSend + 'static>(
199+
&self, f: F,
200+
) -> FutureQueueCompletion<O> {
201+
let completion = FutureQueueCompletion::new();
202+
let compl_ref = completion.clone();
203+
let future = async move {
204+
compl_ref.complete(f.await);
205+
};
104206
#[cfg(feature = "std")]
105207
{
106208
self.0.lock().unwrap().push(Box::pin(future));
@@ -109,5 +211,6 @@ impl<D: core::ops::Deref<Target = FutureQueue> + MaybeSend + MaybeSync + 'static
109211
{
110212
self.0.borrow_mut().push(Box::pin(future));
111213
}
214+
completion
112215
}
113216
}

lightning/src/util/persist.rs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use alloc::sync::Arc;
1616
use bitcoin::hashes::hex::FromHex;
1717
use bitcoin::{BlockHash, Txid};
1818

19+
use core::convert::Infallible;
1920
use core::future::Future;
2021
use core::mem;
2122
use core::ops::Deref;
@@ -407,7 +408,11 @@ where
407408

408409
struct PanicingSpawner;
409410
impl FutureSpawner for PanicingSpawner {
410-
fn spawn<T: Future<Output = ()> + MaybeSend + 'static>(&self, _: T) {
411+
type E = Infallible;
412+
type SpawnedFutureResult<O> = Box<dyn Future<Output = Result<O, Infallible>> + Unpin>;
413+
fn spawn<O, T: Future<Output = O> + MaybeSend + 'static>(
414+
&self, _: T,
415+
) -> Self::SpawnedFutureResult<O> {
411416
unreachable!();
412417
}
413418
}
@@ -865,7 +870,7 @@ where
865870
let future = inner.persist_new_channel(monitor_name, monitor);
866871
let channel_id = monitor.channel_id();
867872
let completion = (monitor.channel_id(), monitor.get_latest_update_id());
868-
self.0.future_spawner.spawn(async move {
873+
let _ = self.0.future_spawner.spawn(async move {
869874
match future.await {
870875
Ok(()) => inner.async_completed_updates.lock().unwrap().push(completion),
871876
Err(e) => {
@@ -893,7 +898,7 @@ where
893898
None
894899
};
895900
let inner = Arc::clone(&self.0);
896-
self.0.future_spawner.spawn(async move {
901+
let _ = self.0.future_spawner.spawn(async move {
897902
match future.await {
898903
Ok(()) => if let Some(completion) = completion {
899904
inner.async_completed_updates.lock().unwrap().push(completion);
@@ -910,7 +915,7 @@ where
910915

911916
pub(crate) fn spawn_async_archive_persisted_channel(&self, monitor_name: MonitorName) {
912917
let inner = Arc::clone(&self.0);
913-
self.0.future_spawner.spawn(async move {
918+
let _ = self.0.future_spawner.spawn(async move {
914919
inner.archive_persisted_channel(monitor_name).await;
915920
});
916921
}

0 commit comments

Comments
 (0)