diff options
author | Jeff Vander Stoep <jeffv@google.com> | 2024-02-05 14:19:11 +0000 |
---|---|---|
committer | Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com> | 2024-02-05 14:19:11 +0000 |
commit | 646d9067cfe1430335326bd1a5ffe46d370addd2 (patch) | |
tree | 3044f4002ed695e6540a4969306b4304349073ce | |
parent | 3b1161ba9718578d32f382256dee18a97d01d180 (diff) | |
parent | a7879fe08376d6c078eed44ffd47a14fb340af59 (diff) | |
download | tokio-util-646d9067cfe1430335326bd1a5ffe46d370addd2.tar.gz |
Upgrade tokio-util to 0.7.10 am: a7879fe083HEADmastermainemu-34-2-dev
Original change: https://android-review.googlesource.com/c/platform/external/rust/crates/tokio-util/+/2950425
Change-Id: If80db06d7ac2be560713c61ae717d03ae2fcb068
Signed-off-by: Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>
38 files changed, 1862 insertions, 318 deletions
diff --git a/.cargo_vcs_info.json b/.cargo_vcs_info.json new file mode 100644 index 0000000..a9990be --- /dev/null +++ b/.cargo_vcs_info.json @@ -0,0 +1,6 @@ +{ + "git": { + "sha1": "503fad79087ed5791c7a018e07621689ea5e4676" + }, + "path_in_vcs": "tokio-util" +}
\ No newline at end of file @@ -23,9 +23,9 @@ rust_library { host_supported: true, crate_name: "tokio_util", cargo_env_compat: true, - cargo_pkg_version: "0.7.7", + cargo_pkg_version: "0.7.10", srcs: ["src/lib.rs"], - edition: "2018", + edition: "2021", features: [ "codec", "compat", diff --git a/CHANGELOG.md b/CHANGELOG.md index 0c11b21..b98092c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,75 @@ +# 0.7.10 (October 24th, 2023) + +### Added + +- task: add `TaskTracker` ([#6033]) +- task: add `JoinMap::keys` ([#6046]) +- io: implement `Seek` for `SyncIoBridge` ([#6058]) + +### Changed + +- deps: update hashbrown to 0.14 ([#6102]) + +[#6033]: https://github.com/tokio-rs/tokio/pull/6033 +[#6046]: https://github.com/tokio-rs/tokio/pull/6046 +[#6058]: https://github.com/tokio-rs/tokio/pull/6058 +[#6102]: https://github.com/tokio-rs/tokio/pull/6102 + +# 0.7.9 (September 20th, 2023) + +### Added + +- io: add passthrough `AsyncRead`/`AsyncWrite` to `InspectWriter`/`InspectReader` ([#5739]) +- task: add spawn blocking methods to `JoinMap` ([#5797]) +- io: pass through traits for `StreamReader` and `SinkWriter` ([#5941]) +- io: add `SyncIoBridge::into_inner` ([#5971]) + +### Fixed + +- sync: handle possibly dangling reference safely ([#5812]) +- util: fix broken intra-doc link ([#5849]) +- compat: fix clippy warnings ([#5891]) + +### Documented + +- codec: Specify the line ending of `LinesCodec` ([#5982]) + +[#5739]: https://github.com/tokio-rs/tokio/pull/5739 +[#5797]: https://github.com/tokio-rs/tokio/pull/5797 +[#5941]: https://github.com/tokio-rs/tokio/pull/5941 +[#5971]: https://github.com/tokio-rs/tokio/pull/5971 +[#5812]: https://github.com/tokio-rs/tokio/pull/5812 +[#5849]: https://github.com/tokio-rs/tokio/pull/5849 +[#5891]: https://github.com/tokio-rs/tokio/pull/5891 +[#5982]: https://github.com/tokio-rs/tokio/pull/5982 + +# 0.7.8 (April 25th, 2023) + +This release bumps the MSRV of tokio-util to 1.56. + +### Added + +- time: add `DelayQueue::peek` ([#5569]) + +### Changed + +This release contains one performance improvement: + +- sync: try to lock the parent first in `CancellationToken` ([#5561]) + +### Fixed + +- time: fix panic in `DelayQueue` ([#5630]) + +### Documented + +- sync: improve `CancellationToken` doc on child tokens ([#5632]) + +[#5561]: https://github.com/tokio-rs/tokio/pull/5561 +[#5569]: https://github.com/tokio-rs/tokio/pull/5569 +[#5630]: https://github.com/tokio-rs/tokio/pull/5630 +[#5632]: https://github.com/tokio-rs/tokio/pull/5632 + # 0.7.7 (February 12, 2023) This release reverts the removal of the `Encoder` bound on the `FramedParts` @@ -10,10 +10,10 @@ # See Cargo.toml.orig for the original contents. [package] -edition = "2018" -rust-version = "1.49" +edition = "2021" +rust-version = "1.56" name = "tokio-util" -version = "0.7.7" +version = "0.7.10" authors = ["Tokio Contributors <team@tokio.rs>"] description = """ Additional utilities for working with Tokio. @@ -26,13 +26,13 @@ repository = "https://github.com/tokio-rs/tokio" [package.metadata.docs.rs] all-features = true -rustdoc-args = [ +rustc-args = [ "--cfg", "docsrs", "--cfg", "tokio_unstable", ] -rustc-args = [ +rustdoc-args = [ "--cfg", "docsrs", "--cfg", @@ -57,14 +57,14 @@ version = "0.3.0" optional = true [dependencies.pin-project-lite] -version = "0.2.0" +version = "0.2.11" [dependencies.slab] version = "0.4.4" optional = true [dependencies.tokio] -version = "1.22.0" +version = "1.28.0" features = ["sync"] [dependencies.tracing] @@ -85,6 +85,9 @@ version = "0.3.5" [dev-dependencies.parking_lot] version = "0.12.0" +[dev-dependencies.tempfile] +version = "3.1.0" + [dev-dependencies.tokio] version = "1.0.0" features = ["full"] @@ -127,5 +130,5 @@ time = [ ] [target."cfg(tokio_unstable)".dependencies.hashbrown] -version = "0.12.0" +version = "0.14.0" optional = true diff --git a/Cargo.toml.orig b/Cargo.toml.orig index 267662b..437dc5a 100644 --- a/Cargo.toml.orig +++ b/Cargo.toml.orig @@ -4,9 +4,9 @@ name = "tokio-util" # - Remove path dependencies # - Update CHANGELOG.md. # - Create "tokio-util-0.7.x" git tag. -version = "0.7.7" -edition = "2018" -rust-version = "1.49" +version = "0.7.10" +edition = "2021" +rust-version = "1.56" authors = ["Tokio Contributors <team@tokio.rs>"] license = "MIT" repository = "https://github.com/tokio-rs/tokio" @@ -34,18 +34,18 @@ rt = ["tokio/rt", "tokio/sync", "futures-util", "hashbrown"] __docs_rs = ["futures-util"] [dependencies] -tokio = { version = "1.22.0", path = "../tokio", features = ["sync"] } +tokio = { version = "1.28.0", path = "../tokio", features = ["sync"] } bytes = "1.0.0" futures-core = "0.3.0" futures-sink = "0.3.0" futures-io = { version = "0.3.0", optional = true } futures-util = { version = "0.3.0", optional = true } -pin-project-lite = "0.2.0" +pin-project-lite = "0.2.11" slab = { version = "0.4.4", optional = true } # Backs `DelayQueue` tracing = { version = "0.1.25", default-features = false, features = ["std"], optional = true } [target.'cfg(tokio_unstable)'.dependencies] -hashbrown = { version = "0.12.0", optional = true } +hashbrown = { version = "0.14.0", optional = true } [dev-dependencies] tokio = { version = "1.0.0", path = "../tokio", features = ["full"] } @@ -56,6 +56,7 @@ async-stream = "0.3.0" futures = "0.3.0" futures-test = "0.3.5" parking_lot = "0.12.0" +tempfile = "3.1.0" [package.metadata.docs.rs] all-features = true @@ -1,19 +1,20 @@ +# This project was upgraded with external_updater. +# Usage: tools/external_updater/updater.sh update external/rust/crates/tokio-util +# For more info, check https://cs.android.com/android/platform/superproject/+/main:tools/external_updater/README.md + name: "tokio-util" description: "Utilities for working with Tokio." third_party { - url { - type: HOMEPAGE - value: "https://crates.io/crates/tokio-util" - } - url { - type: ARCHIVE - value: "https://static.crates.io/crates/tokio-util/tokio-util-0.7.7.crate" - } - version: "0.7.7" license_type: NOTICE last_upgrade_date { - year: 2023 - month: 3 - day: 3 + year: 2024 + month: 2 + day: 5 + } + homepage: "https://crates.io/crates/tokio-util" + identifier { + type: "Archive" + value: "https://static.crates.io/crates/tokio-util/tokio-util-0.7.10.crate" + version: "0.7.10" } } diff --git a/src/codec/lines_codec.rs b/src/codec/lines_codec.rs index 7a0a8f0..5a6035d 100644 --- a/src/codec/lines_codec.rs +++ b/src/codec/lines_codec.rs @@ -6,6 +6,8 @@ use std::{cmp, fmt, io, str, usize}; /// A simple [`Decoder`] and [`Encoder`] implementation that splits up data into lines. /// +/// This uses the `\n` character as the line ending on all platforms. +/// /// [`Decoder`]: crate::codec::Decoder /// [`Encoder`]: crate::codec::Encoder #[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)] diff --git a/src/compat.rs b/src/compat.rs index 6a8802d..423bd95 100644 --- a/src/compat.rs +++ b/src/compat.rs @@ -227,12 +227,14 @@ impl<T: tokio::io::AsyncSeek> futures_io::AsyncSeek for Compat<T> { pos: io::SeekFrom, ) -> Poll<io::Result<u64>> { if self.seek_pos != Some(pos) { + // Ensure previous seeks have finished before starting a new one + ready!(self.as_mut().project().inner.poll_complete(cx))?; self.as_mut().project().inner.start_seek(pos)?; *self.as_mut().project().seek_pos = Some(pos); } let res = ready!(self.as_mut().project().inner.poll_complete(cx)); *self.as_mut().project().seek_pos = None; - Poll::Ready(res.map(|p| p as u64)) + Poll::Ready(res) } } @@ -255,7 +257,7 @@ impl<T: futures_io::AsyncSeek> tokio::io::AsyncSeek for Compat<T> { }; let res = ready!(self.as_mut().project().inner.poll_seek(cx, pos)); *self.as_mut().project().seek_pos = None; - Poll::Ready(res.map(|p| p as u64)) + Poll::Ready(res) } } diff --git a/src/either.rs b/src/either.rs index 9225e53..8a02398 100644 --- a/src/either.rs +++ b/src/either.rs @@ -116,7 +116,7 @@ where } fn consume(self: Pin<&mut Self>, amt: usize) { - delegate_call!(self.consume(amt)) + delegate_call!(self.consume(amt)); } } diff --git a/src/io/copy_to_bytes.rs b/src/io/copy_to_bytes.rs index 9509e71..f0b5c35 100644 --- a/src/io/copy_to_bytes.rs +++ b/src/io/copy_to_bytes.rs @@ -1,4 +1,5 @@ use bytes::Bytes; +use futures_core::stream::Stream; use futures_sink::Sink; use pin_project_lite::pin_project; use std::pin::Pin; @@ -66,3 +67,10 @@ where self.project().inner.poll_close(cx) } } + +impl<S: Stream> Stream for CopyToBytes<S> { + type Item = S::Item; + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + self.project().inner.poll_next(cx) + } +} diff --git a/src/io/inspect.rs b/src/io/inspect.rs index ec5bb97..c860b80 100644 --- a/src/io/inspect.rs +++ b/src/io/inspect.rs @@ -52,6 +52,42 @@ impl<R: AsyncRead, F: FnMut(&[u8])> AsyncRead for InspectReader<R, F> { } } +impl<R: AsyncWrite, F> AsyncWrite for InspectReader<R, F> { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<std::result::Result<usize, std::io::Error>> { + self.project().reader.poll_write(cx, buf) + } + + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<std::result::Result<(), std::io::Error>> { + self.project().reader.poll_flush(cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<std::result::Result<(), std::io::Error>> { + self.project().reader.poll_shutdown(cx) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll<Result<usize>> { + self.project().reader.poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + self.reader.is_write_vectored() + } +} + pin_project! { /// An adapter that lets you inspect the data that's being written. /// @@ -132,3 +168,13 @@ impl<W: AsyncWrite, F: FnMut(&[u8])> AsyncWrite for InspectWriter<W, F> { self.writer.is_write_vectored() } } + +impl<W: AsyncRead, F> AsyncRead for InspectWriter<W, F> { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<std::io::Result<()>> { + self.project().writer.poll_read(cx, buf) + } +} diff --git a/src/io/sink_writer.rs b/src/io/sink_writer.rs index f2af262..e078952 100644 --- a/src/io/sink_writer.rs +++ b/src/io/sink_writer.rs @@ -1,11 +1,12 @@ use futures_core::ready; use futures_sink::Sink; +use futures_core::stream::Stream; use pin_project_lite::pin_project; use std::io; use std::pin::Pin; use std::task::{Context, Poll}; -use tokio::io::AsyncWrite; +use tokio::io::{AsyncRead, AsyncWrite}; pin_project! { /// Convert a [`Sink`] of byte chunks into an [`AsyncWrite`]. @@ -59,7 +60,7 @@ pin_project! { /// [`CopyToBytes`]: crate::io::CopyToBytes /// [`Encoder`]: crate::codec::Encoder /// [`Sink`]: futures_sink::Sink - /// [`codec`]: tokio_util::codec + /// [`codec`]: crate::codec #[derive(Debug)] pub struct SinkWriter<S> { #[pin] @@ -115,3 +116,20 @@ where self.project().inner.poll_close(cx).map_err(Into::into) } } + +impl<S: Stream> Stream for SinkWriter<S> { + type Item = S::Item; + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + self.project().inner.poll_next(cx) + } +} + +impl<S: AsyncRead> AsyncRead for SinkWriter<S> { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + self.project().inner.poll_read(cx, buf) + } +} diff --git a/src/io/stream_reader.rs b/src/io/stream_reader.rs index 3353722..6ecf8ec 100644 --- a/src/io/stream_reader.rs +++ b/src/io/stream_reader.rs @@ -1,5 +1,6 @@ use bytes::Buf; use futures_core::stream::Stream; +use futures_sink::Sink; use std::io; use std::pin::Pin; use std::task::{Context, Poll}; @@ -165,7 +166,7 @@ where B: Buf, E: Into<std::io::Error>, { - /// Convert a stream of byte chunks into an [`AsyncRead`](tokio::io::AsyncRead). + /// Convert a stream of byte chunks into an [`AsyncRead`]. /// /// The item should be a [`Result`] with the ok variant being something that /// implements the [`Buf`] trait (e.g. `Vec<u8>` or `Bytes`). The error @@ -324,3 +325,22 @@ impl<S, B> StreamReader<S, B> { } } } + +impl<S: Sink<T, Error = E>, E, T> Sink<T> for StreamReader<S, E> { + type Error = E; + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.project().inner.poll_ready(cx) + } + + fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> { + self.project().inner.start_send(item) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.project().inner.poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.project().inner.poll_close(cx) + } +} diff --git a/src/io/sync_bridge.rs b/src/io/sync_bridge.rs index f87bfbb..2402207 100644 --- a/src/io/sync_bridge.rs +++ b/src/io/sync_bridge.rs @@ -1,6 +1,7 @@ -use std::io::{BufRead, Read, Write}; +use std::io::{BufRead, Read, Seek, Write}; use tokio::io::{ - AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, + AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt, AsyncWrite, + AsyncWriteExt, }; /// Use a [`tokio::io::AsyncRead`] synchronously as a [`std::io::Read`] or @@ -79,6 +80,13 @@ impl<T: AsyncWrite + Unpin> Write for SyncIoBridge<T> { } } +impl<T: AsyncSeek + Unpin> Seek for SyncIoBridge<T> { + fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result<u64> { + let src = &mut self.src; + self.rt.block_on(AsyncSeekExt::seek(src, pos)) + } +} + // Because https://doc.rust-lang.org/std/io/trait.Write.html#method.is_write_vectored is at the time // of this writing still unstable, we expose this as part of a standalone method. impl<T: AsyncWrite> SyncIoBridge<T> { @@ -140,4 +148,9 @@ impl<T: Unpin> SyncIoBridge<T> { pub fn new_with_handle(src: T, rt: tokio::runtime::Handle) -> Self { Self { src, rt } } + + /// Consume this bridge, returning the underlying stream. + pub fn into_inner(self) -> T { + self.src + } } @@ -55,151 +55,6 @@ pub mod sync; pub mod either; -#[cfg(any(feature = "io", feature = "codec"))] -mod util { - use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +pub use bytes; - use bytes::{Buf, BufMut}; - use futures_core::ready; - use std::io::{self, IoSlice}; - use std::mem::MaybeUninit; - use std::pin::Pin; - use std::task::{Context, Poll}; - - /// Try to read data from an `AsyncRead` into an implementer of the [`BufMut`] trait. - /// - /// [`BufMut`]: bytes::Buf - /// - /// # Example - /// - /// ``` - /// use bytes::{Bytes, BytesMut}; - /// use tokio_stream as stream; - /// use tokio::io::Result; - /// use tokio_util::io::{StreamReader, poll_read_buf}; - /// use futures::future::poll_fn; - /// use std::pin::Pin; - /// # #[tokio::main] - /// # async fn main() -> std::io::Result<()> { - /// - /// // Create a reader from an iterator. This particular reader will always be - /// // ready. - /// let mut read = StreamReader::new(stream::iter(vec![Result::Ok(Bytes::from_static(&[0, 1, 2, 3]))])); - /// - /// let mut buf = BytesMut::new(); - /// let mut reads = 0; - /// - /// loop { - /// reads += 1; - /// let n = poll_fn(|cx| poll_read_buf(Pin::new(&mut read), cx, &mut buf)).await?; - /// - /// if n == 0 { - /// break; - /// } - /// } - /// - /// // one or more reads might be necessary. - /// assert!(reads >= 1); - /// assert_eq!(&buf[..], &[0, 1, 2, 3]); - /// # Ok(()) - /// # } - /// ``` - #[cfg_attr(not(feature = "io"), allow(unreachable_pub))] - pub fn poll_read_buf<T: AsyncRead, B: BufMut>( - io: Pin<&mut T>, - cx: &mut Context<'_>, - buf: &mut B, - ) -> Poll<io::Result<usize>> { - if !buf.has_remaining_mut() { - return Poll::Ready(Ok(0)); - } - - let n = { - let dst = buf.chunk_mut(); - - // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a - // transparent wrapper around `[MaybeUninit<u8>]`. - let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit<u8>]) }; - let mut buf = ReadBuf::uninit(dst); - let ptr = buf.filled().as_ptr(); - ready!(io.poll_read(cx, &mut buf)?); - - // Ensure the pointer does not change from under us - assert_eq!(ptr, buf.filled().as_ptr()); - buf.filled().len() - }; - - // Safety: This is guaranteed to be the number of initialized (and read) - // bytes due to the invariants provided by `ReadBuf::filled`. - unsafe { - buf.advance_mut(n); - } - - Poll::Ready(Ok(n)) - } - - /// Try to write data from an implementer of the [`Buf`] trait to an - /// [`AsyncWrite`], advancing the buffer's internal cursor. - /// - /// This function will use [vectored writes] when the [`AsyncWrite`] supports - /// vectored writes. - /// - /// # Examples - /// - /// [`File`] implements [`AsyncWrite`] and [`Cursor<&[u8]>`] implements - /// [`Buf`]: - /// - /// ```no_run - /// use tokio_util::io::poll_write_buf; - /// use tokio::io; - /// use tokio::fs::File; - /// - /// use bytes::Buf; - /// use std::io::Cursor; - /// use std::pin::Pin; - /// use futures::future::poll_fn; - /// - /// #[tokio::main] - /// async fn main() -> io::Result<()> { - /// let mut file = File::create("foo.txt").await?; - /// let mut buf = Cursor::new(b"data to write"); - /// - /// // Loop until the entire contents of the buffer are written to - /// // the file. - /// while buf.has_remaining() { - /// poll_fn(|cx| poll_write_buf(Pin::new(&mut file), cx, &mut buf)).await?; - /// } - /// - /// Ok(()) - /// } - /// ``` - /// - /// [`Buf`]: bytes::Buf - /// [`AsyncWrite`]: tokio::io::AsyncWrite - /// [`File`]: tokio::fs::File - /// [vectored writes]: tokio::io::AsyncWrite::poll_write_vectored - #[cfg_attr(not(feature = "io"), allow(unreachable_pub))] - pub fn poll_write_buf<T: AsyncWrite, B: Buf>( - io: Pin<&mut T>, - cx: &mut Context<'_>, - buf: &mut B, - ) -> Poll<io::Result<usize>> { - const MAX_BUFS: usize = 64; - - if !buf.has_remaining() { - return Poll::Ready(Ok(0)); - } - - let n = if io.is_write_vectored() { - let mut slices = [IoSlice::new(&[]); MAX_BUFS]; - let cnt = buf.chunks_vectored(&mut slices); - ready!(io.poll_write_vectored(cx, &slices[..cnt]))? - } else { - ready!(io.poll_write(cx, buf.chunk()))? - }; - - buf.advance(n); - - Poll::Ready(Ok(n)) - } -} +mod util; diff --git a/src/sync/cancellation_token.rs b/src/sync/cancellation_token.rs index c44be69..5ef8ba2 100644 --- a/src/sync/cancellation_token.rs +++ b/src/sync/cancellation_token.rs @@ -4,6 +4,7 @@ pub(crate) mod guard; mod tree_node; use crate::loom::sync::Arc; +use crate::util::MaybeDangling; use core::future::Future; use core::pin::Pin; use core::task::{Context, Poll}; @@ -77,11 +78,23 @@ pin_project! { /// [`CancellationToken`] by value instead of using a reference. #[must_use = "futures do nothing unless polled"] pub struct WaitForCancellationFutureOwned { - // Since `future` is the first field, it is dropped before the - // cancellation_token field. This ensures that the reference inside the - // `Notified` remains valid. + // This field internally has a reference to the cancellation token, but camouflages + // the relationship with `'static`. To avoid Undefined Behavior, we must ensure + // that the reference is only used while the cancellation token is still alive. To + // do that, we ensure that the future is the first field, so that it is dropped + // before the cancellation token. + // + // We use `MaybeDanglingFuture` here because without it, the compiler could assert + // the reference inside `future` to be valid even after the destructor of that + // field runs. (Specifically, when the `WaitForCancellationFutureOwned` is passed + // as an argument to a function, the reference can be asserted to be valid for the + // rest of that function.) To avoid that, we use `MaybeDangling` which tells the + // compiler that the reference stored inside it might not be valid. + // + // See <https://users.rust-lang.org/t/unsafe-code-review-semi-owning-weak-rwlock-t-guard/95706> + // for more info. #[pin] - future: tokio::sync::futures::Notified<'static>, + future: MaybeDangling<tokio::sync::futures::Notified<'static>>, cancellation_token: CancellationToken, } } @@ -97,6 +110,8 @@ impl core::fmt::Debug for CancellationToken { } impl Clone for CancellationToken { + /// Creates a clone of the `CancellationToken` which will get cancelled + /// whenever the current token gets cancelled, and vice versa. fn clone(&self) -> Self { tree_node::increase_handle_refcount(&self.inner); CancellationToken { @@ -118,7 +133,7 @@ impl Default for CancellationToken { } impl CancellationToken { - /// Creates a new CancellationToken in the non-cancelled state. + /// Creates a new `CancellationToken` in the non-cancelled state. pub fn new() -> CancellationToken { CancellationToken { inner: Arc::new(tree_node::TreeNode::new()), @@ -126,7 +141,8 @@ impl CancellationToken { } /// Creates a `CancellationToken` which will get cancelled whenever the - /// current token gets cancelled. + /// current token gets cancelled. Unlike a cloned `CancellationToken`, + /// cancelling a child token does not cancel the parent token. /// /// If the current token is already cancelled, the child token will get /// returned in cancelled state. @@ -276,7 +292,7 @@ impl WaitForCancellationFutureOwned { // # Safety // // cancellation_token is dropped after future due to the field ordering. - future: unsafe { Self::new_future(&cancellation_token) }, + future: MaybeDangling::new(unsafe { Self::new_future(&cancellation_token) }), cancellation_token, } } @@ -317,8 +333,9 @@ impl Future for WaitForCancellationFutureOwned { // # Safety // // cancellation_token is dropped after future due to the field ordering. - this.future - .set(unsafe { Self::new_future(this.cancellation_token) }); + this.future.set(MaybeDangling::new(unsafe { + Self::new_future(this.cancellation_token) + })); } } } diff --git a/src/sync/cancellation_token/tree_node.rs b/src/sync/cancellation_token/tree_node.rs index 8f97dee..b7a9805 100644 --- a/src/sync/cancellation_token/tree_node.rs +++ b/src/sync/cancellation_token/tree_node.rs @@ -1,12 +1,12 @@ //! This mod provides the logic for the inner tree structure of the CancellationToken. //! -//! CancellationTokens are only light handles with references to TreeNode. -//! All the logic is actually implemented in the TreeNode. +//! CancellationTokens are only light handles with references to [`TreeNode`]. +//! All the logic is actually implemented in the [`TreeNode`]. //! -//! A TreeNode is part of the cancellation tree and may have one parent and an arbitrary number of +//! A [`TreeNode`] is part of the cancellation tree and may have one parent and an arbitrary number of //! children. //! -//! A TreeNode can receive the request to perform a cancellation through a CancellationToken. +//! A [`TreeNode`] can receive the request to perform a cancellation through a CancellationToken. //! This cancellation request will cancel the node and all of its descendants. //! //! As soon as a node cannot get cancelled any more (because it was already cancelled or it has no @@ -151,47 +151,43 @@ fn with_locked_node_and_parent<F, Ret>(node: &Arc<TreeNode>, func: F) -> Ret where F: FnOnce(MutexGuard<'_, Inner>, Option<MutexGuard<'_, Inner>>) -> Ret, { - let mut potential_parent = { - let locked_node = node.inner.lock().unwrap(); - match locked_node.parent.clone() { - Some(parent) => parent, - // If we locked the node and its parent is `None`, we are in a valid state - // and can return. - None => return func(locked_node, None), - } - }; + use std::sync::TryLockError; + let mut locked_node = node.inner.lock().unwrap(); + + // Every time this fails, the number of ancestors of the node decreases, + // so the loop must succeed after a finite number of iterations. loop { - // Deadlock safety: - // - // Due to invariant #2, we know that we have to lock the parent first, and then the child. - // This is true even if the potential_parent is no longer the current parent or even its - // sibling, as the invariant still holds. - let locked_parent = potential_parent.inner.lock().unwrap(); - let locked_node = node.inner.lock().unwrap(); - - let actual_parent = match locked_node.parent.clone() { - Some(parent) => parent, - // If we locked the node and its parent is `None`, we are in a valid state - // and can return. - None => { - // Was the wrong parent, so unlock it before calling `func` - drop(locked_parent); - return func(locked_node, None); + // Look up the parent of the currently locked node. + let potential_parent = match locked_node.parent.as_ref() { + Some(potential_parent) => potential_parent.clone(), + None => return func(locked_node, None), + }; + + // Lock the parent. This may require unlocking the child first. + let locked_parent = match potential_parent.inner.try_lock() { + Ok(locked_parent) => locked_parent, + Err(TryLockError::WouldBlock) => { + drop(locked_node); + // Deadlock safety: + // + // Due to invariant #2, the potential parent must come before + // the child in the creation order. Therefore, we can safely + // lock the child while holding the parent lock. + let locked_parent = potential_parent.inner.lock().unwrap(); + locked_node = node.inner.lock().unwrap(); + locked_parent } + Err(TryLockError::Poisoned(err)) => Err(err).unwrap(), }; - // Loop until we managed to lock both the node and its parent - if Arc::ptr_eq(&actual_parent, &potential_parent) { - return func(locked_node, Some(locked_parent)); + // If we unlocked the child, then the parent may have changed. Check + // that we still have the right parent. + if let Some(actual_parent) = locked_node.parent.as_ref() { + if Arc::ptr_eq(actual_parent, &potential_parent) { + return func(locked_node, Some(locked_parent)); + } } - - // Drop locked_parent before reassigning to potential_parent, - // as potential_parent is borrowed in it - drop(locked_node); - drop(locked_parent); - - potential_parent = actual_parent; } } @@ -243,11 +239,7 @@ fn remove_child(parent: &mut Inner, mut node: MutexGuard<'_, Inner>) { let len = parent.children.len(); if 4 * len <= parent.children.capacity() { - // equal to: - // parent.children.shrink_to(2 * len); - // but shrink_to was not yet stabilized in our minimal compatible version - let old_children = std::mem::replace(&mut parent.children, Vec::with_capacity(2 * len)); - parent.children.extend(old_children); + parent.children.shrink_to(2 * len); } } diff --git a/src/sync/mpsc.rs b/src/sync/mpsc.rs index 55ed5c4..fd48c72 100644 --- a/src/sync/mpsc.rs +++ b/src/sync/mpsc.rs @@ -44,7 +44,7 @@ enum State<T> { pub struct PollSender<T> { sender: Option<Sender<T>>, state: State<T>, - acquire: ReusableBoxFuture<'static, Result<OwnedPermit<T>, PollSendError<T>>>, + acquire: PollSenderFuture<T>, } // Creates a future for acquiring a permit from the underlying channel. This is used to ensure @@ -64,13 +64,56 @@ async fn make_acquire_future<T>( } } -impl<T: Send + 'static> PollSender<T> { +type InnerFuture<'a, T> = ReusableBoxFuture<'a, Result<OwnedPermit<T>, PollSendError<T>>>; + +#[derive(Debug)] +// TODO: This should be replace with a type_alias_impl_trait to eliminate `'static` and all the transmutes +struct PollSenderFuture<T>(InnerFuture<'static, T>); + +impl<T> PollSenderFuture<T> { + /// Create with an empty inner future with no `Send` bound. + fn empty() -> Self { + // We don't use `make_acquire_future` here because our relaxed bounds on `T` are not + // compatible with the transitive bounds required by `Sender<T>`. + Self(ReusableBoxFuture::new(async { unreachable!() })) + } +} + +impl<T: Send> PollSenderFuture<T> { + /// Create with an empty inner future. + fn new() -> Self { + let v = InnerFuture::new(make_acquire_future(None)); + // This is safe because `make_acquire_future(None)` is actually `'static` + Self(unsafe { mem::transmute::<InnerFuture<'_, T>, InnerFuture<'static, T>>(v) }) + } + + /// Poll the inner future. + fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<OwnedPermit<T>, PollSendError<T>>> { + self.0.poll(cx) + } + + /// Replace the inner future. + fn set(&mut self, sender: Option<Sender<T>>) { + let inner: *mut InnerFuture<'static, T> = &mut self.0; + let inner: *mut InnerFuture<'_, T> = inner.cast(); + // SAFETY: The `make_acquire_future(sender)` future must not exist after the type `T` + // becomes invalid, and this casts away the type-level lifetime check for that. However, the + // inner future is never moved out of this `PollSenderFuture<T>`, so the future will not + // live longer than the `PollSenderFuture<T>` lives. A `PollSenderFuture<T>` is guaranteed + // to not exist after the type `T` becomes invalid, because it is annotated with a `T`, so + // this is ok. + let inner = unsafe { &mut *inner }; + inner.set(make_acquire_future(sender)); + } +} + +impl<T: Send> PollSender<T> { /// Creates a new `PollSender`. pub fn new(sender: Sender<T>) -> Self { Self { sender: Some(sender.clone()), state: State::Idle(sender), - acquire: ReusableBoxFuture::new(make_acquire_future(None)), + acquire: PollSenderFuture::new(), } } @@ -97,7 +140,7 @@ impl<T: Send + 'static> PollSender<T> { State::Idle(sender) => { // Start trying to acquire a permit to reserve a slot for our send, and // immediately loop back around to poll it the first time. - self.acquire.set(make_acquire_future(Some(sender))); + self.acquire.set(Some(sender)); (None, State::Acquiring) } State::Acquiring => match self.acquire.poll(cx) { @@ -194,7 +237,7 @@ impl<T: Send + 'static> PollSender<T> { match self.state { State::Idle(_) => self.state = State::Closed, State::Acquiring => { - self.acquire.set(make_acquire_future(None)); + self.acquire.set(None); self.state = State::Closed; } _ => {} @@ -215,7 +258,7 @@ impl<T: Send + 'static> PollSender<T> { // We're currently trying to reserve a slot to send into. State::Acquiring => { // Replacing the future drops the in-flight one. - self.acquire.set(make_acquire_future(None)); + self.acquire.set(None); // If we haven't closed yet, we have to clone our stored sender since we have no way // to get it back from the acquire future we just dropped. @@ -255,9 +298,7 @@ impl<T> Clone for PollSender<T> { Self { sender, state, - // We don't use `make_acquire_future` here because our relaxed bounds on `T` are not - // compatible with the transitive bounds required by `Sender<T>`. - acquire: ReusableBoxFuture::new(async { unreachable!() }), + acquire: PollSenderFuture::empty(), } } } diff --git a/src/sync/poll_semaphore.rs b/src/sync/poll_semaphore.rs index 6b44574..4960a7c 100644 --- a/src/sync/poll_semaphore.rs +++ b/src/sync/poll_semaphore.rs @@ -29,7 +29,7 @@ impl PollSemaphore { /// Closes the semaphore. pub fn close(&self) { - self.semaphore.close() + self.semaphore.close(); } /// Obtain a clone of the inner semaphore. diff --git a/src/sync/reusable_box.rs b/src/sync/reusable_box.rs index 1b8ef60..1fae38c 100644 --- a/src/sync/reusable_box.rs +++ b/src/sync/reusable_box.rs @@ -1,7 +1,6 @@ use std::alloc::Layout; use std::fmt; -use std::future::Future; -use std::marker::PhantomData; +use std::future::{self, Future}; use std::mem::{self, ManuallyDrop}; use std::pin::Pin; use std::ptr; @@ -61,7 +60,7 @@ impl<'a, T> ReusableBoxFuture<'a, T> { F: Future + Send + 'a, { // future::Pending<T> is a ZST so this never allocates. - let boxed = mem::replace(&mut this.boxed, Box::pin(Pending(PhantomData))); + let boxed = mem::replace(&mut this.boxed, Box::pin(future::pending())); reuse_pin_box(boxed, future, |boxed| this.boxed = Pin::from(boxed)) } @@ -156,16 +155,3 @@ impl<O, F: FnOnce() -> O> Drop for CallOnDrop<O, F> { f(); } } - -/// The same as `std::future::Pending<T>`; we can't use that type directly because on rustc -/// versions <1.60 it didn't unconditionally implement `Send`. -// FIXME: use `std::future::Pending<T>` once the MSRV is >=1.60 -struct Pending<T>(PhantomData<fn() -> T>); - -impl<T> Future for Pending<T> { - type Output = T; - - fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> { - Poll::Pending - } -} diff --git a/src/task/join_map.rs b/src/task/join_map.rs index c6bf5bc..1fbe274 100644 --- a/src/task/join_map.rs +++ b/src/task/join_map.rs @@ -5,6 +5,7 @@ use std::collections::hash_map::RandomState; use std::fmt; use std::future::Future; use std::hash::{BuildHasher, Hash, Hasher}; +use std::marker::PhantomData; use tokio::runtime::Handle; use tokio::task::{AbortHandle, Id, JoinError, JoinSet, LocalSet}; @@ -316,6 +317,60 @@ where self.insert(key, task); } + /// Spawn the blocking code on the blocking threadpool and store it in this `JoinMap` with the provided + /// key. + /// + /// If a task previously existed in the `JoinMap` for this key, that task + /// will be cancelled and replaced with the new one. The previous task will + /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will + /// *not* return a cancelled [`JoinError`] for that task. + /// + /// Note that blocking tasks cannot be cancelled after execution starts. + /// Replaced blocking tasks will still run to completion if the task has begun + /// to execute when it is replaced. A blocking task which is replaced before + /// it has been scheduled on a blocking worker thread will be cancelled. + /// + /// # Panics + /// + /// This method panics if called outside of a Tokio runtime. + /// + /// [`join_next`]: Self::join_next + #[track_caller] + pub fn spawn_blocking<F>(&mut self, key: K, f: F) + where + F: FnOnce() -> V, + F: Send + 'static, + V: Send, + { + let task = self.tasks.spawn_blocking(f); + self.insert(key, task) + } + + /// Spawn the blocking code on the blocking threadpool of the provided runtime and store it in this + /// `JoinMap` with the provided key. + /// + /// If a task previously existed in the `JoinMap` for this key, that task + /// will be cancelled and replaced with the new one. The previous task will + /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will + /// *not* return a cancelled [`JoinError`] for that task. + /// + /// Note that blocking tasks cannot be cancelled after execution starts. + /// Replaced blocking tasks will still run to completion if the task has begun + /// to execute when it is replaced. A blocking task which is replaced before + /// it has been scheduled on a blocking worker thread will be cancelled. + /// + /// [`join_next`]: Self::join_next + #[track_caller] + pub fn spawn_blocking_on<F>(&mut self, key: K, f: F, handle: &Handle) + where + F: FnOnce() -> V, + F: Send + 'static, + V: Send, + { + let task = self.tasks.spawn_blocking_on(f, handle); + self.insert(key, task); + } + /// Spawn the provided task on the current [`LocalSet`] and store it in this /// `JoinMap` with the provided key. /// @@ -572,6 +627,19 @@ where } } + /// Returns an iterator visiting all keys in this `JoinMap` in arbitrary order. + /// + /// If a task has completed, but its output hasn't yet been consumed by a + /// call to [`join_next`], this method will still return its key. + /// + /// [`join_next`]: fn@Self::join_next + pub fn keys(&self) -> JoinMapKeys<'_, K, V> { + JoinMapKeys { + iter: self.tasks_by_key.keys(), + _value: PhantomData, + } + } + /// Returns `true` if this `JoinMap` contains a task for the provided key. /// /// If the task has completed, but its output hasn't yet been consumed by a @@ -805,3 +873,32 @@ impl<K: PartialEq> PartialEq for Key<K> { } impl<K: Eq> Eq for Key<K> {} + +/// An iterator over the keys of a [`JoinMap`]. +#[derive(Debug, Clone)] +pub struct JoinMapKeys<'a, K, V> { + iter: hashbrown::hash_map::Keys<'a, Key<K>, AbortHandle>, + /// To make it easier to change JoinMap in the future, keep V as a generic + /// parameter. + _value: PhantomData<&'a V>, +} + +impl<'a, K, V> Iterator for JoinMapKeys<'a, K, V> { + type Item = &'a K; + + fn next(&mut self) -> Option<&'a K> { + self.iter.next().map(|key| &key.key) + } + + fn size_hint(&self) -> (usize, Option<usize>) { + self.iter.size_hint() + } +} + +impl<'a, K, V> ExactSizeIterator for JoinMapKeys<'a, K, V> { + fn len(&self) -> usize { + self.iter.len() + } +} + +impl<'a, K, V> std::iter::FusedIterator for JoinMapKeys<'a, K, V> {} diff --git a/src/task/mod.rs b/src/task/mod.rs index de41dd5..e37015a 100644 --- a/src/task/mod.rs +++ b/src/task/mod.rs @@ -9,4 +9,7 @@ pub use spawn_pinned::LocalPoolHandle; #[cfg(tokio_unstable)] #[cfg_attr(docsrs, doc(cfg(all(tokio_unstable, feature = "rt"))))] -pub use join_map::JoinMap; +pub use join_map::{JoinMap, JoinMapKeys}; + +pub mod task_tracker; +pub use task_tracker::TaskTracker; diff --git a/src/task/task_tracker.rs b/src/task/task_tracker.rs new file mode 100644 index 0000000..d8f3bb4 --- /dev/null +++ b/src/task/task_tracker.rs @@ -0,0 +1,719 @@ +//! Types related to the [`TaskTracker`] collection. +//! +//! See the documentation of [`TaskTracker`] for more information. + +use pin_project_lite::pin_project; +use std::fmt; +use std::future::Future; +use std::pin::Pin; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use std::task::{Context, Poll}; +use tokio::sync::{futures::Notified, Notify}; + +#[cfg(feature = "rt")] +use tokio::{ + runtime::Handle, + task::{JoinHandle, LocalSet}, +}; + +/// A task tracker used for waiting until tasks exit. +/// +/// This is usually used together with [`CancellationToken`] to implement [graceful shutdown]. The +/// `CancellationToken` is used to signal to tasks that they should shut down, and the +/// `TaskTracker` is used to wait for them to finish shutting down. +/// +/// The `TaskTracker` will also keep track of a `closed` boolean. This is used to handle the case +/// where the `TaskTracker` is empty, but we don't want to shut down yet. This means that the +/// [`wait`] method will wait until *both* of the following happen at the same time: +/// +/// * The `TaskTracker` must be closed using the [`close`] method. +/// * The `TaskTracker` must be empty, that is, all tasks that it is tracking must have exited. +/// +/// When a call to [`wait`] returns, it is guaranteed that all tracked tasks have exited and that +/// the destructor of the future has finished running. However, there might be a short amount of +/// time where [`JoinHandle::is_finished`] returns false. +/// +/// # Comparison to `JoinSet` +/// +/// The main Tokio crate has a similar collection known as [`JoinSet`]. The `JoinSet` type has a +/// lot more features than `TaskTracker`, so `TaskTracker` should only be used when one of its +/// unique features is required: +/// +/// 1. When tasks exit, a `TaskTracker` will allow the task to immediately free its memory. +/// 2. By not closing the `TaskTracker`, [`wait`] will be prevented from from returning even if +/// the `TaskTracker` is empty. +/// 3. A `TaskTracker` does not require mutable access to insert tasks. +/// 4. A `TaskTracker` can be cloned to share it with many tasks. +/// +/// The first point is the most important one. A [`JoinSet`] keeps track of the return value of +/// every inserted task. This means that if the caller keeps inserting tasks and never calls +/// [`join_next`], then their return values will keep building up and consuming memory, _even if_ +/// most of the tasks have already exited. This can cause the process to run out of memory. With a +/// `TaskTracker`, this does not happen. Once tasks exit, they are immediately removed from the +/// `TaskTracker`. +/// +/// # Examples +/// +/// For more examples, please see the topic page on [graceful shutdown]. +/// +/// ## Spawn tasks and wait for them to exit +/// +/// This is a simple example. For this case, [`JoinSet`] should probably be used instead. +/// +/// ``` +/// use tokio_util::task::TaskTracker; +/// +/// #[tokio::main] +/// async fn main() { +/// let tracker = TaskTracker::new(); +/// +/// for i in 0..10 { +/// tracker.spawn(async move { +/// println!("Task {} is running!", i); +/// }); +/// } +/// // Once we spawned everything, we close the tracker. +/// tracker.close(); +/// +/// // Wait for everything to finish. +/// tracker.wait().await; +/// +/// println!("This is printed after all of the tasks."); +/// } +/// ``` +/// +/// ## Wait for tasks to exit +/// +/// This example shows the intended use-case of `TaskTracker`. It is used together with +/// [`CancellationToken`] to implement graceful shutdown. +/// ``` +/// use tokio_util::sync::CancellationToken; +/// use tokio_util::task::TaskTracker; +/// use tokio::time::{self, Duration}; +/// +/// async fn background_task(num: u64) { +/// for i in 0..10 { +/// time::sleep(Duration::from_millis(100*num)).await; +/// println!("Background task {} in iteration {}.", num, i); +/// } +/// } +/// +/// #[tokio::main] +/// # async fn _hidden() {} +/// # #[tokio::main(flavor = "current_thread", start_paused = true)] +/// async fn main() { +/// let tracker = TaskTracker::new(); +/// let token = CancellationToken::new(); +/// +/// for i in 0..10 { +/// let token = token.clone(); +/// tracker.spawn(async move { +/// // Use a `tokio::select!` to kill the background task if the token is +/// // cancelled. +/// tokio::select! { +/// () = background_task(i) => { +/// println!("Task {} exiting normally.", i); +/// }, +/// () = token.cancelled() => { +/// // Do some cleanup before we really exit. +/// time::sleep(Duration::from_millis(50)).await; +/// println!("Task {} finished cleanup.", i); +/// }, +/// } +/// }); +/// } +/// +/// // Spawn a background task that will send the shutdown signal. +/// { +/// let tracker = tracker.clone(); +/// tokio::spawn(async move { +/// // Normally you would use something like ctrl-c instead of +/// // sleeping. +/// time::sleep(Duration::from_secs(2)).await; +/// tracker.close(); +/// token.cancel(); +/// }); +/// } +/// +/// // Wait for all tasks to exit. +/// tracker.wait().await; +/// +/// println!("All tasks have exited now."); +/// } +/// ``` +/// +/// [`CancellationToken`]: crate::sync::CancellationToken +/// [`JoinHandle::is_finished`]: tokio::task::JoinHandle::is_finished +/// [`JoinSet`]: tokio::task::JoinSet +/// [`close`]: Self::close +/// [`join_next`]: tokio::task::JoinSet::join_next +/// [`wait`]: Self::wait +/// [graceful shutdown]: https://tokio.rs/tokio/topics/shutdown +pub struct TaskTracker { + inner: Arc<TaskTrackerInner>, +} + +/// Represents a task tracked by a [`TaskTracker`]. +#[must_use] +#[derive(Debug)] +pub struct TaskTrackerToken { + task_tracker: TaskTracker, +} + +struct TaskTrackerInner { + /// Keeps track of the state. + /// + /// The lowest bit is whether the task tracker is closed. + /// + /// The rest of the bits count the number of tracked tasks. + state: AtomicUsize, + /// Used to notify when the last task exits. + on_last_exit: Notify, +} + +pin_project! { + /// A future that is tracked as a task by a [`TaskTracker`]. + /// + /// The associated [`TaskTracker`] cannot complete until this future is dropped. + /// + /// This future is returned by [`TaskTracker::track_future`]. + #[must_use = "futures do nothing unless polled"] + pub struct TrackedFuture<F> { + #[pin] + future: F, + token: TaskTrackerToken, + } +} + +pin_project! { + /// A future that completes when the [`TaskTracker`] is empty and closed. + /// + /// This future is returned by [`TaskTracker::wait`]. + #[must_use = "futures do nothing unless polled"] + pub struct TaskTrackerWaitFuture<'a> { + #[pin] + future: Notified<'a>, + inner: Option<&'a TaskTrackerInner>, + } +} + +impl TaskTrackerInner { + #[inline] + fn new() -> Self { + Self { + state: AtomicUsize::new(0), + on_last_exit: Notify::new(), + } + } + + #[inline] + fn is_closed_and_empty(&self) -> bool { + // If empty and closed bit set, then we are done. + // + // The acquire load will synchronize with the release store of any previous call to + // `set_closed` and `drop_task`. + self.state.load(Ordering::Acquire) == 1 + } + + #[inline] + fn set_closed(&self) -> bool { + // The AcqRel ordering makes the closed bit behave like a `Mutex<bool>` for synchronization + // purposes. We do this because it makes the return value of `TaskTracker::{close,reopen}` + // more meaningful for the user. Without these orderings, this assert could fail: + // ``` + // // thread 1 + // some_other_atomic.store(true, Relaxed); + // tracker.close(); + // + // // thread 2 + // if tracker.reopen() { + // assert!(some_other_atomic.load(Relaxed)); + // } + // ``` + // However, with the AcqRel ordering, we establish a happens-before relationship from the + // call to `close` and the later call to `reopen` that returned true. + let state = self.state.fetch_or(1, Ordering::AcqRel); + + // If there are no tasks, and if it was not already closed: + if state == 0 { + self.notify_now(); + } + + (state & 1) == 0 + } + + #[inline] + fn set_open(&self) -> bool { + // See `set_closed` regarding the AcqRel ordering. + let state = self.state.fetch_and(!1, Ordering::AcqRel); + (state & 1) == 1 + } + + #[inline] + fn add_task(&self) { + self.state.fetch_add(2, Ordering::Relaxed); + } + + #[inline] + fn drop_task(&self) { + let state = self.state.fetch_sub(2, Ordering::Release); + + // If this was the last task and we are closed: + if state == 3 { + self.notify_now(); + } + } + + #[cold] + fn notify_now(&self) { + // Insert an acquire fence. This matters for `drop_task` but doesn't matter for + // `set_closed` since it already uses AcqRel. + // + // This synchronizes with the release store of any other call to `drop_task`, and with the + // release store in the call to `set_closed`. That ensures that everything that happened + // before those other calls to `drop_task` or `set_closed` will be visible after this load, + // and those things will also be visible to anything woken by the call to `notify_waiters`. + self.state.load(Ordering::Acquire); + + self.on_last_exit.notify_waiters(); + } +} + +impl TaskTracker { + /// Creates a new `TaskTracker`. + /// + /// The `TaskTracker` will start out as open. + #[must_use] + pub fn new() -> Self { + Self { + inner: Arc::new(TaskTrackerInner::new()), + } + } + + /// Waits until this `TaskTracker` is both closed and empty. + /// + /// If the `TaskTracker` is already closed and empty when this method is called, then it + /// returns immediately. + /// + /// The `wait` future is resistant against [ABA problems][aba]. That is, if the `TaskTracker` + /// becomes both closed and empty for a short amount of time, then it is guarantee that all + /// `wait` futures that were created before the short time interval will trigger, even if they + /// are not polled during that short time interval. + /// + /// # Cancel safety + /// + /// This method is cancel safe. + /// + /// However, the resistance against [ABA problems][aba] is lost when using `wait` as the + /// condition in a `tokio::select!` loop. + /// + /// [aba]: https://en.wikipedia.org/wiki/ABA_problem + #[inline] + pub fn wait(&self) -> TaskTrackerWaitFuture<'_> { + TaskTrackerWaitFuture { + future: self.inner.on_last_exit.notified(), + inner: if self.inner.is_closed_and_empty() { + None + } else { + Some(&self.inner) + }, + } + } + + /// Close this `TaskTracker`. + /// + /// This allows [`wait`] futures to complete. It does not prevent you from spawning new tasks. + /// + /// Returns `true` if this closed the `TaskTracker`, or `false` if it was already closed. + /// + /// [`wait`]: Self::wait + #[inline] + pub fn close(&self) -> bool { + self.inner.set_closed() + } + + /// Reopen this `TaskTracker`. + /// + /// This prevents [`wait`] futures from completing even if the `TaskTracker` is empty. + /// + /// Returns `true` if this reopened the `TaskTracker`, or `false` if it was already open. + /// + /// [`wait`]: Self::wait + #[inline] + pub fn reopen(&self) -> bool { + self.inner.set_open() + } + + /// Returns `true` if this `TaskTracker` is [closed](Self::close). + #[inline] + #[must_use] + pub fn is_closed(&self) -> bool { + (self.inner.state.load(Ordering::Acquire) & 1) != 0 + } + + /// Returns the number of tasks tracked by this `TaskTracker`. + #[inline] + #[must_use] + pub fn len(&self) -> usize { + self.inner.state.load(Ordering::Acquire) >> 1 + } + + /// Returns `true` if there are no tasks in this `TaskTracker`. + #[inline] + #[must_use] + pub fn is_empty(&self) -> bool { + self.inner.state.load(Ordering::Acquire) <= 1 + } + + /// Spawn the provided future on the current Tokio runtime, and track it in this `TaskTracker`. + /// + /// This is equivalent to `tokio::spawn(tracker.track_future(task))`. + #[inline] + #[track_caller] + #[cfg(feature = "rt")] + #[cfg_attr(docsrs, doc(cfg(feature = "rt")))] + pub fn spawn<F>(&self, task: F) -> JoinHandle<F::Output> + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + tokio::task::spawn(self.track_future(task)) + } + + /// Spawn the provided future on the provided Tokio runtime, and track it in this `TaskTracker`. + /// + /// This is equivalent to `handle.spawn(tracker.track_future(task))`. + #[inline] + #[track_caller] + #[cfg(feature = "rt")] + #[cfg_attr(docsrs, doc(cfg(feature = "rt")))] + pub fn spawn_on<F>(&self, task: F, handle: &Handle) -> JoinHandle<F::Output> + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + handle.spawn(self.track_future(task)) + } + + /// Spawn the provided future on the current [`LocalSet`], and track it in this `TaskTracker`. + /// + /// This is equivalent to `tokio::task::spawn_local(tracker.track_future(task))`. + /// + /// [`LocalSet`]: tokio::task::LocalSet + #[inline] + #[track_caller] + #[cfg(feature = "rt")] + #[cfg_attr(docsrs, doc(cfg(feature = "rt")))] + pub fn spawn_local<F>(&self, task: F) -> JoinHandle<F::Output> + where + F: Future + 'static, + F::Output: 'static, + { + tokio::task::spawn_local(self.track_future(task)) + } + + /// Spawn the provided future on the provided [`LocalSet`], and track it in this `TaskTracker`. + /// + /// This is equivalent to `local_set.spawn_local(tracker.track_future(task))`. + /// + /// [`LocalSet`]: tokio::task::LocalSet + #[inline] + #[track_caller] + #[cfg(feature = "rt")] + #[cfg_attr(docsrs, doc(cfg(feature = "rt")))] + pub fn spawn_local_on<F>(&self, task: F, local_set: &LocalSet) -> JoinHandle<F::Output> + where + F: Future + 'static, + F::Output: 'static, + { + local_set.spawn_local(self.track_future(task)) + } + + /// Spawn the provided blocking task on the current Tokio runtime, and track it in this `TaskTracker`. + /// + /// This is equivalent to `tokio::task::spawn_blocking(tracker.track_future(task))`. + #[inline] + #[track_caller] + #[cfg(feature = "rt")] + #[cfg(not(target_family = "wasm"))] + #[cfg_attr(docsrs, doc(cfg(feature = "rt")))] + pub fn spawn_blocking<F, T>(&self, task: F) -> JoinHandle<T> + where + F: FnOnce() -> T, + F: Send + 'static, + T: Send + 'static, + { + let token = self.token(); + tokio::task::spawn_blocking(move || { + let res = task(); + drop(token); + res + }) + } + + /// Spawn the provided blocking task on the provided Tokio runtime, and track it in this `TaskTracker`. + /// + /// This is equivalent to `handle.spawn_blocking(tracker.track_future(task))`. + #[inline] + #[track_caller] + #[cfg(feature = "rt")] + #[cfg(not(target_family = "wasm"))] + #[cfg_attr(docsrs, doc(cfg(feature = "rt")))] + pub fn spawn_blocking_on<F, T>(&self, task: F, handle: &Handle) -> JoinHandle<T> + where + F: FnOnce() -> T, + F: Send + 'static, + T: Send + 'static, + { + let token = self.token(); + handle.spawn_blocking(move || { + let res = task(); + drop(token); + res + }) + } + + /// Track the provided future. + /// + /// The returned [`TrackedFuture`] will count as a task tracked by this collection, and will + /// prevent calls to [`wait`] from returning until the task is dropped. + /// + /// The task is removed from the collection when it is dropped, not when [`poll`] returns + /// [`Poll::Ready`]. + /// + /// # Examples + /// + /// Track a future spawned with [`tokio::spawn`]. + /// + /// ``` + /// # async fn my_async_fn() {} + /// use tokio_util::task::TaskTracker; + /// + /// # #[tokio::main(flavor = "current_thread")] + /// # async fn main() { + /// let tracker = TaskTracker::new(); + /// + /// tokio::spawn(tracker.track_future(my_async_fn())); + /// # } + /// ``` + /// + /// Track a future spawned on a [`JoinSet`]. + /// ``` + /// # async fn my_async_fn() {} + /// use tokio::task::JoinSet; + /// use tokio_util::task::TaskTracker; + /// + /// # #[tokio::main(flavor = "current_thread")] + /// # async fn main() { + /// let tracker = TaskTracker::new(); + /// let mut join_set = JoinSet::new(); + /// + /// join_set.spawn(tracker.track_future(my_async_fn())); + /// # } + /// ``` + /// + /// [`JoinSet`]: tokio::task::JoinSet + /// [`Poll::Pending`]: std::task::Poll::Pending + /// [`poll`]: std::future::Future::poll + /// [`wait`]: Self::wait + #[inline] + pub fn track_future<F: Future>(&self, future: F) -> TrackedFuture<F> { + TrackedFuture { + future, + token: self.token(), + } + } + + /// Creates a [`TaskTrackerToken`] representing a task tracked by this `TaskTracker`. + /// + /// This token is a lower-level utility than the spawn methods. Each token is considered to + /// correspond to a task. As long as the token exists, the `TaskTracker` cannot complete. + /// Furthermore, the count returned by the [`len`] method will include the tokens in the count. + /// + /// Dropping the token indicates to the `TaskTracker` that the task has exited. + /// + /// [`len`]: TaskTracker::len + #[inline] + pub fn token(&self) -> TaskTrackerToken { + self.inner.add_task(); + TaskTrackerToken { + task_tracker: self.clone(), + } + } + + /// Returns `true` if both task trackers correspond to the same set of tasks. + /// + /// # Examples + /// + /// ``` + /// use tokio_util::task::TaskTracker; + /// + /// let tracker_1 = TaskTracker::new(); + /// let tracker_2 = TaskTracker::new(); + /// let tracker_1_clone = tracker_1.clone(); + /// + /// assert!(TaskTracker::ptr_eq(&tracker_1, &tracker_1_clone)); + /// assert!(!TaskTracker::ptr_eq(&tracker_1, &tracker_2)); + /// ``` + #[inline] + #[must_use] + pub fn ptr_eq(left: &TaskTracker, right: &TaskTracker) -> bool { + Arc::ptr_eq(&left.inner, &right.inner) + } +} + +impl Default for TaskTracker { + /// Creates a new `TaskTracker`. + /// + /// The `TaskTracker` will start out as open. + #[inline] + fn default() -> TaskTracker { + TaskTracker::new() + } +} + +impl Clone for TaskTracker { + /// Returns a new `TaskTracker` that tracks the same set of tasks. + /// + /// Since the new `TaskTracker` shares the same set of tasks, changes to one set are visible in + /// all other clones. + /// + /// # Examples + /// + /// ``` + /// use tokio_util::task::TaskTracker; + /// + /// #[tokio::main] + /// # async fn _hidden() {} + /// # #[tokio::main(flavor = "current_thread")] + /// async fn main() { + /// let tracker = TaskTracker::new(); + /// let cloned = tracker.clone(); + /// + /// // Spawns on `tracker` are visible in `cloned`. + /// tracker.spawn(std::future::pending::<()>()); + /// assert_eq!(cloned.len(), 1); + /// + /// // Spawns on `cloned` are visible in `tracker`. + /// cloned.spawn(std::future::pending::<()>()); + /// assert_eq!(tracker.len(), 2); + /// + /// // Calling `close` is visible to `cloned`. + /// tracker.close(); + /// assert!(cloned.is_closed()); + /// + /// // Calling `reopen` is visible to `tracker`. + /// cloned.reopen(); + /// assert!(!tracker.is_closed()); + /// } + /// ``` + #[inline] + fn clone(&self) -> TaskTracker { + Self { + inner: self.inner.clone(), + } + } +} + +fn debug_inner(inner: &TaskTrackerInner, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let state = inner.state.load(Ordering::Acquire); + let is_closed = (state & 1) != 0; + let len = state >> 1; + + f.debug_struct("TaskTracker") + .field("len", &len) + .field("is_closed", &is_closed) + .field("inner", &(inner as *const TaskTrackerInner)) + .finish() +} + +impl fmt::Debug for TaskTracker { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + debug_inner(&self.inner, f) + } +} + +impl TaskTrackerToken { + /// Returns the [`TaskTracker`] that this token is associated with. + #[inline] + #[must_use] + pub fn task_tracker(&self) -> &TaskTracker { + &self.task_tracker + } +} + +impl Clone for TaskTrackerToken { + /// Returns a new `TaskTrackerToken` associated with the same [`TaskTracker`]. + /// + /// This is equivalent to `token.task_tracker().token()`. + #[inline] + fn clone(&self) -> TaskTrackerToken { + self.task_tracker.token() + } +} + +impl Drop for TaskTrackerToken { + /// Dropping the token indicates to the [`TaskTracker`] that the task has exited. + #[inline] + fn drop(&mut self) { + self.task_tracker.inner.drop_task(); + } +} + +impl<F: Future> Future for TrackedFuture<F> { + type Output = F::Output; + + #[inline] + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<F::Output> { + self.project().future.poll(cx) + } +} + +impl<F: fmt::Debug> fmt::Debug for TrackedFuture<F> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TrackedFuture") + .field("future", &self.future) + .field("task_tracker", self.token.task_tracker()) + .finish() + } +} + +impl<'a> Future for TaskTrackerWaitFuture<'a> { + type Output = (); + + #[inline] + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + let me = self.project(); + + let inner = match me.inner.as_ref() { + None => return Poll::Ready(()), + Some(inner) => inner, + }; + + let ready = inner.is_closed_and_empty() || me.future.poll(cx).is_ready(); + if ready { + *me.inner = None; + Poll::Ready(()) + } else { + Poll::Pending + } + } +} + +impl<'a> fmt::Debug for TaskTrackerWaitFuture<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + struct Helper<'a>(&'a TaskTrackerInner); + + impl fmt::Debug for Helper<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + debug_inner(self.0, f) + } + } + + f.debug_struct("TaskTrackerWaitFuture") + .field("future", &self.future) + .field("task_tracker", &self.inner.map(Helper)) + .finish() + } +} diff --git a/src/time/delay_queue.rs b/src/time/delay_queue.rs index ee66adb..9136d90 100644 --- a/src/time/delay_queue.rs +++ b/src/time/delay_queue.rs @@ -62,7 +62,7 @@ use std::task::{self, Poll, Waker}; /// performance and scalability benefits. /// /// State associated with each entry is stored in a [`slab`]. This amortizes the cost of allocation, -/// and allows reuse of the memory allocated for expired entires. +/// and allows reuse of the memory allocated for expired entries. /// /// Capacity can be checked using [`capacity`] and allocated preemptively by using /// the [`reserve`] method. @@ -874,6 +874,41 @@ impl<T> DelayQueue<T> { self.slab.compact(); } + /// Gets the [`Key`] that [`poll_expired`] will pull out of the queue next, without + /// pulling it out or waiting for the deadline to expire. + /// + /// Entries that have already expired may be returned in any order, but it is + /// guaranteed that this method returns them in the same order as when items + /// are popped from the `DelayQueue`. + /// + /// # Examples + /// + /// Basic usage + /// + /// ```rust + /// use tokio_util::time::DelayQueue; + /// use std::time::Duration; + /// + /// # #[tokio::main] + /// # async fn main() { + /// let mut delay_queue = DelayQueue::new(); + /// + /// let key1 = delay_queue.insert("foo", Duration::from_secs(10)); + /// let key2 = delay_queue.insert("bar", Duration::from_secs(5)); + /// let key3 = delay_queue.insert("baz", Duration::from_secs(15)); + /// + /// assert_eq!(delay_queue.peek().unwrap(), key2); + /// # } + /// ``` + /// + /// [`Key`]: struct@Key + /// [`poll_expired`]: method@Self::poll_expired + pub fn peek(&self) -> Option<Key> { + use self::wheel::Stack; + + self.expired.peek().or_else(|| self.wheel.peek()) + } + /// Returns the next time to poll as determined by the wheel fn next_deadline(&mut self) -> Option<Instant> { self.wheel @@ -1166,6 +1201,10 @@ impl<T> wheel::Stack for Stack<T> { } } + fn peek(&self) -> Option<Self::Owned> { + self.head + } + #[track_caller] fn remove(&mut self, item: &Self::Borrowed, store: &mut Self::Store) { let key = *item; diff --git a/src/time/wheel/level.rs b/src/time/wheel/level.rs index 8ea30af..4290acf 100644 --- a/src/time/wheel/level.rs +++ b/src/time/wheel/level.rs @@ -140,11 +140,31 @@ impl<T: Stack> Level<T> { // TODO: This can probably be simplified w/ power of 2 math let level_start = now - (now % level_range); - let deadline = level_start + slot as u64 * slot_range; - + let mut deadline = level_start + slot as u64 * slot_range; + if deadline < now { + // A timer is in a slot "prior" to the current time. This can occur + // because we do not have an infinite hierarchy of timer levels, and + // eventually a timer scheduled for a very distant time might end up + // being placed in a slot that is beyond the end of all of the + // arrays. + // + // To deal with this, we first limit timers to being scheduled no + // more than MAX_DURATION ticks in the future; that is, they're at + // most one rotation of the top level away. Then, we force timers + // that logically would go into the top+1 level, to instead go into + // the top level's slots. + // + // What this means is that the top level's slots act as a + // pseudo-ring buffer, and we rotate around them indefinitely. If we + // compute a deadline before now, and it's the top level, it + // therefore means we're actually looking at a slot in the future. + debug_assert_eq!(self.level, super::NUM_LEVELS - 1); + + deadline += level_range; + } debug_assert!( deadline >= now, - "deadline={}; now={}; level={}; slot={}; occupied={:b}", + "deadline={:016X}; now={:016X}; level={}; slot={}; occupied={:b}", deadline, now, self.level, @@ -206,6 +226,10 @@ impl<T: Stack> Level<T> { ret } + + pub(crate) fn peek_entry_slot(&self, slot: usize) -> Option<T::Owned> { + self.slot[slot].peek() + } } impl<T> fmt::Debug for Level<T> { diff --git a/src/time/wheel/mod.rs b/src/time/wheel/mod.rs index ffa05ab..10a9900 100644 --- a/src/time/wheel/mod.rs +++ b/src/time/wheel/mod.rs @@ -139,6 +139,12 @@ where self.next_expiration().map(|expiration| expiration.deadline) } + /// Next key that will expire + pub(crate) fn peek(&self) -> Option<T::Owned> { + self.next_expiration() + .and_then(|expiration| self.peek_entry(&expiration)) + } + /// Advances the timer up to the instant represented by `now`. pub(crate) fn poll(&mut self, now: u64, store: &mut T::Store) -> Option<T::Owned> { loop { @@ -244,6 +250,10 @@ where self.levels[expiration.level].pop_entry_slot(expiration.slot, store) } + fn peek_entry(&self, expiration: &Expiration) -> Option<T::Owned> { + self.levels[expiration.level].peek_entry_slot(expiration.slot) + } + fn level_for(&self, when: u64) -> usize { level_for(self.elapsed, when) } @@ -254,8 +264,11 @@ fn level_for(elapsed: u64, when: u64) -> usize { // Mask in the trailing bits ignored by the level calculation in order to cap // the possible leading zeros - let masked = elapsed ^ when | SLOT_MASK; - + let mut masked = elapsed ^ when | SLOT_MASK; + if masked >= MAX_DURATION { + // Fudge the timer into the top level + masked = MAX_DURATION - 1; + } let leading_zeros = masked.leading_zeros() as usize; let significant = 63 - leading_zeros; significant / 6 diff --git a/src/time/wheel/stack.rs b/src/time/wheel/stack.rs index c87adca..7d32f27 100644 --- a/src/time/wheel/stack.rs +++ b/src/time/wheel/stack.rs @@ -22,6 +22,9 @@ pub(crate) trait Stack: Default { /// Pop an item from the stack fn pop(&mut self, store: &mut Self::Store) -> Option<Self::Owned>; + /// Peek into the stack. + fn peek(&self) -> Option<Self::Owned>; + fn remove(&mut self, item: &Self::Borrowed, store: &mut Self::Store); fn when(item: &Self::Borrowed, store: &Self::Store) -> u64; diff --git a/src/util/maybe_dangling.rs b/src/util/maybe_dangling.rs new file mode 100644 index 0000000..c29a089 --- /dev/null +++ b/src/util/maybe_dangling.rs @@ -0,0 +1,67 @@ +use core::future::Future; +use core::mem::MaybeUninit; +use core::pin::Pin; +use core::task::{Context, Poll}; + +/// A wrapper type that tells the compiler that the contents might not be valid. +/// +/// This is necessary mainly when `T` contains a reference. In that case, the +/// compiler will sometimes assume that the reference is always valid; in some +/// cases it will assume this even after the destructor of `T` runs. For +/// example, when a reference is used as a function argument, then the compiler +/// will assume that the reference is valid until the function returns, even if +/// the reference is destroyed during the function. When the reference is used +/// as part of a self-referential struct, that assumption can be false. Wrapping +/// the reference in this type prevents the compiler from making that +/// assumption. +/// +/// # Invariants +/// +/// The `MaybeUninit` will always contain a valid value until the destructor runs. +// +// Reference +// See <https://users.rust-lang.org/t/unsafe-code-review-semi-owning-weak-rwlock-t-guard/95706> +// +// TODO: replace this with an official solution once RFC #3336 or similar is available. +// <https://github.com/rust-lang/rfcs/pull/3336> +#[repr(transparent)] +pub(crate) struct MaybeDangling<T>(MaybeUninit<T>); + +impl<T> Drop for MaybeDangling<T> { + fn drop(&mut self) { + // Safety: `0` is always initialized. + unsafe { core::ptr::drop_in_place(self.0.as_mut_ptr()) }; + } +} + +impl<T> MaybeDangling<T> { + pub(crate) fn new(inner: T) -> Self { + Self(MaybeUninit::new(inner)) + } +} + +impl<F: Future> Future for MaybeDangling<F> { + type Output = F::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + // Safety: `0` is always initialized. + let fut = unsafe { self.map_unchecked_mut(|this| this.0.assume_init_mut()) }; + fut.poll(cx) + } +} + +#[test] +fn maybedangling_runs_drop() { + struct SetOnDrop<'a>(&'a mut bool); + + impl Drop for SetOnDrop<'_> { + fn drop(&mut self) { + *self.0 = true; + } + } + + let mut success = false; + + drop(MaybeDangling::new(SetOnDrop(&mut success))); + assert!(success); +} diff --git a/src/util/mod.rs b/src/util/mod.rs new file mode 100644 index 0000000..a17f25a --- /dev/null +++ b/src/util/mod.rs @@ -0,0 +1,8 @@ +mod maybe_dangling; +#[cfg(any(feature = "io", feature = "codec"))] +mod poll_buf; + +pub(crate) use maybe_dangling::MaybeDangling; +#[cfg(any(feature = "io", feature = "codec"))] +#[cfg_attr(not(feature = "io"), allow(unreachable_pub))] +pub use poll_buf::{poll_read_buf, poll_write_buf}; diff --git a/src/util/poll_buf.rs b/src/util/poll_buf.rs new file mode 100644 index 0000000..82af1bb --- /dev/null +++ b/src/util/poll_buf.rs @@ -0,0 +1,145 @@ +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +use bytes::{Buf, BufMut}; +use futures_core::ready; +use std::io::{self, IoSlice}; +use std::mem::MaybeUninit; +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// Try to read data from an `AsyncRead` into an implementer of the [`BufMut`] trait. +/// +/// [`BufMut`]: bytes::Buf +/// +/// # Example +/// +/// ``` +/// use bytes::{Bytes, BytesMut}; +/// use tokio_stream as stream; +/// use tokio::io::Result; +/// use tokio_util::io::{StreamReader, poll_read_buf}; +/// use futures::future::poll_fn; +/// use std::pin::Pin; +/// # #[tokio::main] +/// # async fn main() -> std::io::Result<()> { +/// +/// // Create a reader from an iterator. This particular reader will always be +/// // ready. +/// let mut read = StreamReader::new(stream::iter(vec![Result::Ok(Bytes::from_static(&[0, 1, 2, 3]))])); +/// +/// let mut buf = BytesMut::new(); +/// let mut reads = 0; +/// +/// loop { +/// reads += 1; +/// let n = poll_fn(|cx| poll_read_buf(Pin::new(&mut read), cx, &mut buf)).await?; +/// +/// if n == 0 { +/// break; +/// } +/// } +/// +/// // one or more reads might be necessary. +/// assert!(reads >= 1); +/// assert_eq!(&buf[..], &[0, 1, 2, 3]); +/// # Ok(()) +/// # } +/// ``` +#[cfg_attr(not(feature = "io"), allow(unreachable_pub))] +pub fn poll_read_buf<T: AsyncRead, B: BufMut>( + io: Pin<&mut T>, + cx: &mut Context<'_>, + buf: &mut B, +) -> Poll<io::Result<usize>> { + if !buf.has_remaining_mut() { + return Poll::Ready(Ok(0)); + } + + let n = { + let dst = buf.chunk_mut(); + + // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a + // transparent wrapper around `[MaybeUninit<u8>]`. + let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit<u8>]) }; + let mut buf = ReadBuf::uninit(dst); + let ptr = buf.filled().as_ptr(); + ready!(io.poll_read(cx, &mut buf)?); + + // Ensure the pointer does not change from under us + assert_eq!(ptr, buf.filled().as_ptr()); + buf.filled().len() + }; + + // Safety: This is guaranteed to be the number of initialized (and read) + // bytes due to the invariants provided by `ReadBuf::filled`. + unsafe { + buf.advance_mut(n); + } + + Poll::Ready(Ok(n)) +} + +/// Try to write data from an implementer of the [`Buf`] trait to an +/// [`AsyncWrite`], advancing the buffer's internal cursor. +/// +/// This function will use [vectored writes] when the [`AsyncWrite`] supports +/// vectored writes. +/// +/// # Examples +/// +/// [`File`] implements [`AsyncWrite`] and [`Cursor<&[u8]>`] implements +/// [`Buf`]: +/// +/// ```no_run +/// use tokio_util::io::poll_write_buf; +/// use tokio::io; +/// use tokio::fs::File; +/// +/// use bytes::Buf; +/// use std::io::Cursor; +/// use std::pin::Pin; +/// use futures::future::poll_fn; +/// +/// #[tokio::main] +/// async fn main() -> io::Result<()> { +/// let mut file = File::create("foo.txt").await?; +/// let mut buf = Cursor::new(b"data to write"); +/// +/// // Loop until the entire contents of the buffer are written to +/// // the file. +/// while buf.has_remaining() { +/// poll_fn(|cx| poll_write_buf(Pin::new(&mut file), cx, &mut buf)).await?; +/// } +/// +/// Ok(()) +/// } +/// ``` +/// +/// [`Buf`]: bytes::Buf +/// [`AsyncWrite`]: tokio::io::AsyncWrite +/// [`File`]: tokio::fs::File +/// [vectored writes]: tokio::io::AsyncWrite::poll_write_vectored +#[cfg_attr(not(feature = "io"), allow(unreachable_pub))] +pub fn poll_write_buf<T: AsyncWrite, B: Buf>( + io: Pin<&mut T>, + cx: &mut Context<'_>, + buf: &mut B, +) -> Poll<io::Result<usize>> { + const MAX_BUFS: usize = 64; + + if !buf.has_remaining() { + return Poll::Ready(Ok(0)); + } + + let n = if io.is_write_vectored() { + let mut slices = [IoSlice::new(&[]); MAX_BUFS]; + let cnt = buf.chunks_vectored(&mut slices); + ready!(io.poll_write_vectored(cx, &slices[..cnt]))? + } else { + ready!(io.poll_write(cx, buf.chunk()))? + }; + + buf.advance(n); + + Poll::Ready(Ok(n)) +} diff --git a/tests/compat.rs b/tests/compat.rs new file mode 100644 index 0000000..278ebfc --- /dev/null +++ b/tests/compat.rs @@ -0,0 +1,43 @@ +#![cfg(all(feature = "compat"))] +#![cfg(not(target_os = "wasi"))] // WASI does not support all fs operations +#![warn(rust_2018_idioms)] + +use futures_io::SeekFrom; +use futures_util::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt}; +use tempfile::NamedTempFile; +use tokio::fs::OpenOptions; +use tokio_util::compat::TokioAsyncWriteCompatExt; + +#[tokio::test] +async fn compat_file_seek() -> futures_util::io::Result<()> { + let temp_file = NamedTempFile::new()?; + let mut file = OpenOptions::new() + .read(true) + .write(true) + .create(true) + .open(temp_file) + .await? + .compat_write(); + + file.write_all(&[0, 1, 2, 3, 4, 5]).await?; + file.write_all(&[6, 7]).await?; + + assert_eq!(file.stream_position().await?, 8); + + // Modify elements at position 2. + assert_eq!(file.seek(SeekFrom::Start(2)).await?, 2); + file.write_all(&[8, 9]).await?; + + file.flush().await?; + + // Verify we still have 8 elements. + assert_eq!(file.seek(SeekFrom::End(0)).await?, 8); + // Seek back to the start of the file to read and verify contents. + file.seek(SeekFrom::Start(0)).await?; + + let mut buf = Vec::new(); + let num_bytes = file.read_to_end(&mut buf).await?; + assert_eq!(&buf[..num_bytes], &[0, 1, 8, 9, 4, 5, 6, 7]); + + Ok(()) +} diff --git a/tests/io_sync_bridge.rs b/tests/io_sync_bridge.rs index 76bbd0b..50d0e89 100644 --- a/tests/io_sync_bridge.rs +++ b/tests/io_sync_bridge.rs @@ -44,6 +44,18 @@ async fn test_async_write_to_sync() -> Result<(), Box<dyn Error>> { } #[tokio::test] +async fn test_into_inner() -> Result<(), Box<dyn Error>> { + let mut buf = Vec::new(); + SyncIoBridge::new(tokio::io::empty()) + .into_inner() + .read_to_end(&mut buf) + .await + .unwrap(); + assert_eq!(buf.len(), 0); + Ok(()) +} + +#[tokio::test] async fn test_shutdown() -> Result<(), Box<dyn Error>> { let (s1, mut s2) = tokio::io::duplex(1024); let (_rh, wh) = tokio::io::split(s1); diff --git a/tests/length_delimited.rs b/tests/length_delimited.rs index 126e41b..ed5590f 100644 --- a/tests/length_delimited.rs +++ b/tests/length_delimited.rs @@ -12,7 +12,6 @@ use futures::{pin_mut, Sink, Stream}; use std::collections::VecDeque; use std::io; use std::pin::Pin; -use std::task::Poll::*; use std::task::{Context, Poll}; macro_rules! mock { @@ -39,10 +38,10 @@ macro_rules! assert_next_eq { macro_rules! assert_next_pending { ($io:ident) => {{ task::spawn(()).enter(|cx, _| match $io.as_mut().poll_next(cx) { - Ready(Some(Ok(v))) => panic!("value = {:?}", v), - Ready(Some(Err(e))) => panic!("error = {:?}", e), - Ready(None) => panic!("done"), - Pending => {} + Poll::Ready(Some(Ok(v))) => panic!("value = {:?}", v), + Poll::Ready(Some(Err(e))) => panic!("error = {:?}", e), + Poll::Ready(None) => panic!("done"), + Poll::Pending => {} }); }}; } @@ -50,10 +49,10 @@ macro_rules! assert_next_pending { macro_rules! assert_next_err { ($io:ident) => {{ task::spawn(()).enter(|cx, _| match $io.as_mut().poll_next(cx) { - Ready(Some(Ok(v))) => panic!("value = {:?}", v), - Ready(Some(Err(_))) => {} - Ready(None) => panic!("done"), - Pending => panic!("pending"), + Poll::Ready(Some(Ok(v))) => panic!("value = {:?}", v), + Poll::Ready(Some(Err(_))) => {} + Poll::Ready(None) => panic!("done"), + Poll::Pending => panic!("pending"), }); }}; } @@ -186,11 +185,11 @@ fn read_single_frame_multi_packet_wait() { let io = FramedRead::new( mock! { data(b"\x00\x00"), - Pending, + Poll::Pending, data(b"\x00\x09abc"), - Pending, + Poll::Pending, data(b"defghi"), - Pending, + Poll::Pending, }, LengthDelimitedCodec::new(), ); @@ -208,15 +207,15 @@ fn read_multi_frame_multi_packet_wait() { let io = FramedRead::new( mock! { data(b"\x00\x00"), - Pending, + Poll::Pending, data(b"\x00\x09abc"), - Pending, + Poll::Pending, data(b"defghi"), - Pending, + Poll::Pending, data(b"\x00\x00\x00\x0312"), - Pending, + Poll::Pending, data(b"3\x00\x00\x00\x0bhello world"), - Pending, + Poll::Pending, }, LengthDelimitedCodec::new(), ); @@ -250,9 +249,9 @@ fn read_incomplete_head() { fn read_incomplete_head_multi() { let io = FramedRead::new( mock! { - Pending, + Poll::Pending, data(b"\x00"), - Pending, + Poll::Pending, }, LengthDelimitedCodec::new(), ); @@ -268,9 +267,9 @@ fn read_incomplete_payload() { let io = FramedRead::new( mock! { data(b"\x00\x00\x00\x09ab"), - Pending, + Poll::Pending, data(b"cd"), - Pending, + Poll::Pending, }, LengthDelimitedCodec::new(), ); @@ -310,7 +309,7 @@ fn read_update_max_frame_len_at_rest() { fn read_update_max_frame_len_in_flight() { let io = length_delimited::Builder::new().new_read(mock! { data(b"\x00\x00\x00\x09abcd"), - Pending, + Poll::Pending, data(b"efghi"), data(b"\x00\x00\x00\x09abcdefghi"), }); @@ -533,9 +532,9 @@ fn write_single_multi_frame_multi_packet() { fn write_single_frame_would_block() { let io = FramedWrite::new( mock! { - Pending, + Poll::Pending, data(b"\x00\x00"), - Pending, + Poll::Pending, data(b"\x00\x09"), data(b"abcdefghi"), flush(), @@ -640,7 +639,7 @@ fn write_update_max_frame_len_in_flight() { let io = length_delimited::Builder::new().new_write(mock! { data(b"\x00\x00\x00\x06"), data(b"ab"), - Pending, + Poll::Pending, data(b"cdef"), flush(), }); @@ -701,8 +700,6 @@ enum Op { Flush, } -use self::Op::*; - impl AsyncRead for Mock { fn poll_read( mut self: Pin<&mut Self>, @@ -710,15 +707,15 @@ impl AsyncRead for Mock { dst: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>> { match self.calls.pop_front() { - Some(Ready(Ok(Op::Data(data)))) => { + Some(Poll::Ready(Ok(Op::Data(data)))) => { debug_assert!(dst.remaining() >= data.len()); dst.put_slice(&data); - Ready(Ok(())) + Poll::Ready(Ok(())) } - Some(Ready(Ok(_))) => panic!(), - Some(Ready(Err(e))) => Ready(Err(e)), - Some(Pending) => Pending, - None => Ready(Ok(())), + Some(Poll::Ready(Ok(_))) => panic!(), + Some(Poll::Ready(Err(e))) => Poll::Ready(Err(e)), + Some(Poll::Pending) => Poll::Pending, + None => Poll::Ready(Ok(())), } } } @@ -730,31 +727,31 @@ impl AsyncWrite for Mock { src: &[u8], ) -> Poll<Result<usize, io::Error>> { match self.calls.pop_front() { - Some(Ready(Ok(Op::Data(data)))) => { + Some(Poll::Ready(Ok(Op::Data(data)))) => { let len = data.len(); assert!(src.len() >= len, "expect={:?}; actual={:?}", data, src); assert_eq!(&data[..], &src[..len]); - Ready(Ok(len)) + Poll::Ready(Ok(len)) } - Some(Ready(Ok(_))) => panic!(), - Some(Ready(Err(e))) => Ready(Err(e)), - Some(Pending) => Pending, - None => Ready(Ok(0)), + Some(Poll::Ready(Ok(_))) => panic!(), + Some(Poll::Ready(Err(e))) => Poll::Ready(Err(e)), + Some(Poll::Pending) => Poll::Pending, + None => Poll::Ready(Ok(0)), } } fn poll_flush(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { match self.calls.pop_front() { - Some(Ready(Ok(Op::Flush))) => Ready(Ok(())), - Some(Ready(Ok(_))) => panic!(), - Some(Ready(Err(e))) => Ready(Err(e)), - Some(Pending) => Pending, - None => Ready(Ok(())), + Some(Poll::Ready(Ok(Op::Flush))) => Poll::Ready(Ok(())), + Some(Poll::Ready(Ok(_))) => panic!(), + Some(Poll::Ready(Err(e))) => Poll::Ready(Err(e)), + Some(Poll::Pending) => Poll::Pending, + None => Poll::Ready(Ok(())), } } fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { - Ready(Ok(())) + Poll::Ready(Ok(())) } } @@ -771,9 +768,9 @@ impl From<Vec<u8>> for Op { } fn data(bytes: &[u8]) -> Poll<io::Result<Op>> { - Ready(Ok(bytes.into())) + Poll::Ready(Ok(bytes.into())) } fn flush() -> Poll<io::Result<Op>> { - Ready(Ok(Flush)) + Poll::Ready(Ok(Op::Flush)) } diff --git a/tests/mpsc.rs b/tests/mpsc.rs index a3c164d..74b83c2 100644 --- a/tests/mpsc.rs +++ b/tests/mpsc.rs @@ -28,6 +28,29 @@ async fn simple() { } #[tokio::test] +async fn simple_ref() { + let v = vec![1, 2, 3i32]; + + let (send, mut recv) = channel(3); + let mut send = PollSender::new(send); + + for vi in v.iter() { + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_ready_ok!(reserve.poll()); + send.send_item(vi).unwrap(); + } + + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_pending!(reserve.poll()); + + assert_eq!(*recv.recv().await.unwrap(), 1); + assert!(reserve.is_woken()); + assert_ready_ok!(reserve.poll()); + drop(recv); + send.send_item(&42).unwrap(); +} + +#[tokio::test] async fn repeated_poll_reserve() { let (send, mut recv) = channel::<i32>(1); let mut send = PollSender::new(send); diff --git a/tests/task_join_map.rs b/tests/task_join_map.rs index cef08b2..1ab5f9b 100644 --- a/tests/task_join_map.rs +++ b/tests/task_join_map.rs @@ -109,6 +109,30 @@ async fn alternating() { } } +#[tokio::test] +async fn test_keys() { + use std::collections::HashSet; + + let mut map = JoinMap::new(); + + assert_eq!(map.len(), 0); + map.spawn(1, async {}); + assert_eq!(map.len(), 1); + map.spawn(2, async {}); + assert_eq!(map.len(), 2); + + let keys = map.keys().collect::<HashSet<&u32>>(); + assert!(keys.contains(&1)); + assert!(keys.contains(&2)); + + let _ = map.join_next().await.unwrap(); + let _ = map.join_next().await.unwrap(); + + assert_eq!(map.len(), 0); + let keys = map.keys().collect::<HashSet<&u32>>(); + assert!(keys.is_empty()); +} + #[tokio::test(start_paused = true)] async fn abort_by_key() { let mut map = JoinMap::new(); diff --git a/tests/task_tracker.rs b/tests/task_tracker.rs new file mode 100644 index 0000000..f0eb244 --- /dev/null +++ b/tests/task_tracker.rs @@ -0,0 +1,178 @@ +#![warn(rust_2018_idioms)] + +use tokio_test::{assert_pending, assert_ready, task}; +use tokio_util::task::TaskTracker; + +#[test] +fn open_close() { + let tracker = TaskTracker::new(); + assert!(!tracker.is_closed()); + assert!(tracker.is_empty()); + assert_eq!(tracker.len(), 0); + + tracker.close(); + assert!(tracker.is_closed()); + assert!(tracker.is_empty()); + assert_eq!(tracker.len(), 0); + + tracker.reopen(); + assert!(!tracker.is_closed()); + tracker.reopen(); + assert!(!tracker.is_closed()); + + assert!(tracker.is_empty()); + assert_eq!(tracker.len(), 0); + + tracker.close(); + assert!(tracker.is_closed()); + tracker.close(); + assert!(tracker.is_closed()); + + assert!(tracker.is_empty()); + assert_eq!(tracker.len(), 0); +} + +#[test] +fn token_len() { + let tracker = TaskTracker::new(); + + let mut tokens = Vec::new(); + for i in 0..10 { + assert_eq!(tracker.len(), i); + tokens.push(tracker.token()); + } + + assert!(!tracker.is_empty()); + assert_eq!(tracker.len(), 10); + + for (i, token) in tokens.into_iter().enumerate() { + drop(token); + assert_eq!(tracker.len(), 9 - i); + } +} + +#[test] +fn notify_immediately() { + let tracker = TaskTracker::new(); + tracker.close(); + + let mut wait = task::spawn(tracker.wait()); + assert_ready!(wait.poll()); +} + +#[test] +fn notify_immediately_on_reopen() { + let tracker = TaskTracker::new(); + tracker.close(); + + let mut wait = task::spawn(tracker.wait()); + tracker.reopen(); + assert_ready!(wait.poll()); +} + +#[test] +fn notify_on_close() { + let tracker = TaskTracker::new(); + + let mut wait = task::spawn(tracker.wait()); + + assert_pending!(wait.poll()); + tracker.close(); + assert_ready!(wait.poll()); +} + +#[test] +fn notify_on_close_reopen() { + let tracker = TaskTracker::new(); + + let mut wait = task::spawn(tracker.wait()); + + assert_pending!(wait.poll()); + tracker.close(); + tracker.reopen(); + assert_ready!(wait.poll()); +} + +#[test] +fn notify_on_last_task() { + let tracker = TaskTracker::new(); + tracker.close(); + let token = tracker.token(); + + let mut wait = task::spawn(tracker.wait()); + assert_pending!(wait.poll()); + drop(token); + assert_ready!(wait.poll()); +} + +#[test] +fn notify_on_last_task_respawn() { + let tracker = TaskTracker::new(); + tracker.close(); + let token = tracker.token(); + + let mut wait = task::spawn(tracker.wait()); + assert_pending!(wait.poll()); + drop(token); + let token2 = tracker.token(); + assert_ready!(wait.poll()); + drop(token2); +} + +#[test] +fn no_notify_on_respawn_if_open() { + let tracker = TaskTracker::new(); + let token = tracker.token(); + + let mut wait = task::spawn(tracker.wait()); + assert_pending!(wait.poll()); + drop(token); + let token2 = tracker.token(); + assert_pending!(wait.poll()); + drop(token2); +} + +#[test] +fn close_during_exit() { + const ITERS: usize = 5; + + for close_spot in 0..=ITERS { + let tracker = TaskTracker::new(); + let tokens: Vec<_> = (0..ITERS).map(|_| tracker.token()).collect(); + + let mut wait = task::spawn(tracker.wait()); + + for (i, token) in tokens.into_iter().enumerate() { + assert_pending!(wait.poll()); + if i == close_spot { + tracker.close(); + assert_pending!(wait.poll()); + } + drop(token); + } + + if close_spot == ITERS { + assert_pending!(wait.poll()); + tracker.close(); + } + + assert_ready!(wait.poll()); + } +} + +#[test] +fn notify_many() { + let tracker = TaskTracker::new(); + + let mut waits: Vec<_> = (0..10).map(|_| task::spawn(tracker.wait())).collect(); + + for wait in &mut waits { + assert_pending!(wait.poll()); + } + + tracker.close(); + + for wait in &mut waits { + assert_ready!(wait.poll()); + } +} diff --git a/tests/time_delay_queue.rs b/tests/time_delay_queue.rs index 9ceae34..9b7b6cc 100644 --- a/tests/time_delay_queue.rs +++ b/tests/time_delay_queue.rs @@ -2,6 +2,7 @@ #![warn(rust_2018_idioms)] #![cfg(feature = "full")] +use futures::StreamExt; use tokio::time::{self, sleep, sleep_until, Duration, Instant}; use tokio_test::{assert_pending, assert_ready, task}; use tokio_util::time::DelayQueue; @@ -257,6 +258,10 @@ async fn reset_twice() { #[tokio::test] async fn repeatedly_reset_entry_inserted_as_expired() { time::pause(); + + // Instants before the start of the test seem to break in wasm. + time::sleep(ms(1000)).await; + let mut queue = task::spawn(DelayQueue::new()); let now = Instant::now(); @@ -556,6 +561,10 @@ async fn reset_later_after_slot_starts() { #[tokio::test] async fn reset_inserted_expired() { time::pause(); + + // Instants before the start of the test seem to break in wasm. + time::sleep(ms(1000)).await; + let mut queue = task::spawn(DelayQueue::new()); let now = Instant::now(); @@ -778,6 +787,22 @@ async fn compact_change_deadline() { assert!(entry.is_none()); } +#[tokio::test(start_paused = true)] +async fn item_expiry_greater_than_wheel() { + // This function tests that a delay queue that has existed for at least 2^36 milliseconds won't panic when a new item is inserted. + let mut queue = DelayQueue::new(); + for _ in 0..2 { + tokio::time::advance(Duration::from_millis(1 << 35)).await; + queue.insert(0, Duration::from_millis(0)); + queue.next().await; + } + // This should not panic + let no_panic = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + queue.insert(1, Duration::from_millis(1)); + })); + assert!(no_panic.is_ok()); +} + #[cfg_attr(target_os = "wasi", ignore = "FIXME: Does not seem to work with WASI")] #[tokio::test(start_paused = true)] async fn remove_after_compact() { @@ -815,6 +840,44 @@ async fn remove_after_compact_poll() { assert!(panic.is_err()); } +#[tokio::test(start_paused = true)] +async fn peek() { + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + let key = queue.insert_at("foo", now + ms(5)); + let key2 = queue.insert_at("bar", now); + let key3 = queue.insert_at("baz", now + ms(10)); + + assert_eq!(queue.peek(), Some(key2)); + + sleep(ms(6)).await; + + assert_eq!(queue.peek(), Some(key2)); + + let entry = assert_ready_some!(poll!(queue)); + assert_eq!(entry.get_ref(), &"bar"); + + assert_eq!(queue.peek(), Some(key)); + + let entry = assert_ready_some!(poll!(queue)); + assert_eq!(entry.get_ref(), &"foo"); + + assert_eq!(queue.peek(), Some(key3)); + + assert_pending!(poll!(queue)); + + sleep(ms(5)).await; + + assert_eq!(queue.peek(), Some(key3)); + + let entry = assert_ready_some!(poll!(queue)); + assert_eq!(entry.get_ref(), &"baz"); + + assert!(queue.peek().is_none()); +} + fn ms(n: u64) -> Duration { Duration::from_millis(n) } diff --git a/tests/udp.rs b/tests/udp.rs index 1b99806..db726a3 100644 --- a/tests/udp.rs +++ b/tests/udp.rs @@ -13,7 +13,10 @@ use futures::sink::SinkExt; use std::io; use std::sync::Arc; -#[cfg_attr(any(target_os = "macos", target_os = "ios"), allow(unused_assignments))] +#[cfg_attr( + any(target_os = "macos", target_os = "ios", target_os = "tvos"), + allow(unused_assignments) +)] #[tokio::test] async fn send_framed_byte_codec() -> std::io::Result<()> { let mut a_soc = UdpSocket::bind("127.0.0.1:0").await?; @@ -41,7 +44,7 @@ async fn send_framed_byte_codec() -> std::io::Result<()> { b_soc = b.into_inner(); } - #[cfg(not(any(target_os = "macos", target_os = "ios")))] + #[cfg(not(any(target_os = "macos", target_os = "ios", target_os = "tvos")))] // test sending & receiving an empty message { let mut a = UdpFramed::new(a_soc, ByteCodec); |