1
0
Fork 0
mirror of https://github.com/denoland/deno.git synced 2024-11-21 15:04:11 -05:00

refactor(npm): improve locking around updating npm resolution (#24104)

Introduces a `SyncReadAsyncWriteLock` to make it harder to write to the
npm resolution without first waiting async in a queue. For the npm
resolution, reading synchronously is fine, but when updating, someone
should wait async, clone the data, then write the data at the end back.
This commit is contained in:
David Sherret 2024-06-05 15:17:35 -04:00 committed by GitHub
parent 7ed90a20d0
commit 1b355d8a87
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 417 additions and 225 deletions

View file

@ -28,7 +28,6 @@ use crate::util::fs::hard_link_dir_recursive;
mod registry_info;
mod tarball;
mod tarball_extract;
mod value_creator;
pub use registry_info::RegistryInfoDownloader;
pub use tarball::TarballCache;

View file

@ -13,7 +13,6 @@ use deno_core::futures::FutureExt;
use deno_core::parking_lot::Mutex;
use deno_core::serde_json;
use deno_core::url::Url;
use deno_npm::npm_rc::RegistryConfig;
use deno_npm::npm_rc::ResolvedNpmRc;
use deno_npm::registry::NpmPackageInfo;
@ -21,24 +20,14 @@ use crate::args::CacheSetting;
use crate::http_util::HttpClientProvider;
use crate::npm::common::maybe_auth_header_for_npm_registry;
use crate::util::progress_bar::ProgressBar;
use crate::util::sync::MultiRuntimeAsyncValueCreator;
use super::value_creator::MultiRuntimeAsyncValueCreator;
use super::NpmCache;
// todo(dsherret): create seams and unit test this
#[derive(Debug, Clone)]
enum MemoryCacheItem {
/// The cache item hasn't loaded yet.
Pending(Arc<MultiRuntimeAsyncValueCreator<FutureResult>>),
/// The item has loaded in the past and was stored in the file system cache.
/// There is no reason to request this package from the npm registry again
/// for the duration of execution.
FsCached,
/// An item is memory cached when it fails saving to the file system cache
/// or the package does not exist.
MemoryCached(Result<Option<Arc<NpmPackageInfo>>, Arc<AnyError>>),
}
type LoadResult = Result<FutureResult, Arc<AnyError>>;
type LoadFuture = LocalBoxFuture<'static, LoadResult>;
#[derive(Debug, Clone)]
enum FutureResult {
@ -47,8 +36,18 @@ enum FutureResult {
ErroredFsCache(Arc<NpmPackageInfo>),
}
type PendingRegistryLoadFuture =
LocalBoxFuture<'static, Result<FutureResult, AnyError>>;
#[derive(Debug, Clone)]
enum MemoryCacheItem {
/// The cache item hasn't loaded yet.
Pending(Arc<MultiRuntimeAsyncValueCreator<LoadResult>>),
/// The item has loaded in the past and was stored in the file system cache.
/// There is no reason to request this package from the npm registry again
/// for the duration of execution.
FsCached,
/// An item is memory cached when it fails saving to the file system cache
/// or the package does not exist.
MemoryCached(Result<Option<Arc<NpmPackageInfo>>, Arc<AnyError>>),
}
/// Downloads packuments from the npm registry.
///
@ -82,26 +81,18 @@ impl RegistryInfoDownloader {
self: &Arc<Self>,
name: &str,
) -> Result<Option<Arc<NpmPackageInfo>>, AnyError> {
let registry_url = self.npmrc.get_registry_url(name);
let registry_config = self.npmrc.get_registry_config(name);
self
.load_package_info_inner(name, registry_url, registry_config)
.await
.with_context(|| {
format!(
"Error getting response at {} for package \"{}\"",
self.get_package_url(name, registry_url),
name
)
})
self.load_package_info_inner(name).await.with_context(|| {
format!(
"Error getting response at {} for package \"{}\"",
self.get_package_url(name),
name
)
})
}
async fn load_package_info_inner(
self: &Arc<Self>,
name: &str,
registry_url: &Url,
registry_config: &RegistryConfig,
) -> Result<Option<Arc<NpmPackageInfo>>, AnyError> {
if *self.cache.cache_setting() == CacheSetting::Only {
return Err(custom_error(
@ -117,9 +108,11 @@ impl RegistryInfoDownloader {
if let Some(cache_item) = mem_cache.get(name) {
cache_item.clone()
} else {
let future =
self.create_load_future(name, registry_url, registry_config);
let value_creator = MultiRuntimeAsyncValueCreator::new(future);
let value_creator = MultiRuntimeAsyncValueCreator::new({
let downloader = self.clone();
let name = name.to_string();
Box::new(move || downloader.create_load_future(&name))
});
let cache_item = MemoryCacheItem::Pending(Arc::new(value_creator));
mem_cache.insert(name.to_string(), cache_item.clone());
cache_item
@ -138,11 +131,7 @@ impl RegistryInfoDownloader {
maybe_info.clone().map_err(|e| anyhow!("{}", e))
}
MemoryCacheItem::Pending(value_creator) => {
let downloader = self.clone();
let future = value_creator.get(move || {
downloader.create_load_future(name, registry_url, registry_config)
});
match future.await {
match value_creator.get().await {
Ok(FutureResult::SavedFsCache(info)) => {
// return back the future and mark this package as having
// been saved in the cache for next time it's requested
@ -199,14 +188,10 @@ impl RegistryInfoDownloader {
}
}
fn create_load_future(
self: &Arc<Self>,
name: &str,
registry_url: &Url,
registry_config: &RegistryConfig,
) -> PendingRegistryLoadFuture {
fn create_load_future(self: &Arc<Self>, name: &str) -> LoadFuture {
let downloader = self.clone();
let package_url = self.get_package_url(name, registry_url);
let package_url = self.get_package_url(name);
let registry_config = self.npmrc.get_registry_config(name);
let maybe_auth_header = maybe_auth_header_for_npm_registry(registry_config);
let guard = self.progress_bar.update(package_url.as_str());
let name = name.to_string();
@ -242,10 +227,12 @@ impl RegistryInfoDownloader {
None => Ok(FutureResult::PackageNotExists),
}
}
.map(|r| r.map_err(Arc::new))
.boxed_local()
}
fn get_package_url(&self, name: &str, registry_url: &Url) -> Url {
fn get_package_url(&self, name: &str) -> Url {
let registry_url = self.npmrc.get_registry_url(name);
// list of all characters used in npm packages:
// !, ', (, ), *, -, ., /, [0-9], @, [A-Za-z], _, ~
const ASCII_SET: percent_encoding::AsciiSet =

View file

@ -20,18 +20,21 @@ use crate::args::CacheSetting;
use crate::http_util::HttpClientProvider;
use crate::npm::common::maybe_auth_header_for_npm_registry;
use crate::util::progress_bar::ProgressBar;
use crate::util::sync::MultiRuntimeAsyncValueCreator;
use super::tarball_extract::verify_and_extract_tarball;
use super::tarball_extract::TarballExtractionMode;
use super::value_creator::MultiRuntimeAsyncValueCreator;
use super::NpmCache;
// todo(dsherret): create seams and unit test this
type LoadResult = Result<(), Arc<AnyError>>;
type LoadFuture = LocalBoxFuture<'static, LoadResult>;
#[derive(Debug, Clone)]
enum MemoryCacheItem {
/// The cache item hasn't finished yet.
Pending(Arc<MultiRuntimeAsyncValueCreator<()>>),
Pending(Arc<MultiRuntimeAsyncValueCreator<LoadResult>>),
/// The result errored.
Errored(Arc<AnyError>),
/// This package has already been cached.
@ -91,8 +94,14 @@ impl TarballCache {
if let Some(cache_item) = mem_cache.get(package_nv) {
cache_item.clone()
} else {
let future = self.create_setup_future(package_nv.clone(), dist.clone());
let value_creator = MultiRuntimeAsyncValueCreator::new(future);
let value_creator = MultiRuntimeAsyncValueCreator::new({
let tarball_cache = self.clone();
let package_nv = package_nv.clone();
let dist = dist.clone();
Box::new(move || {
tarball_cache.create_setup_future(package_nv.clone(), dist.clone())
})
});
let cache_item = MemoryCacheItem::Pending(Arc::new(value_creator));
mem_cache.insert(package_nv.clone(), cache_item.clone());
cache_item
@ -103,12 +112,7 @@ impl TarballCache {
MemoryCacheItem::Cached => Ok(()),
MemoryCacheItem::Errored(err) => Err(anyhow!("{}", err)),
MemoryCacheItem::Pending(creator) => {
let tarball_cache = self.clone();
let result = creator
.get(move || {
tarball_cache.create_setup_future(package_nv.clone(), dist.clone())
})
.await;
let result = creator.get().await;
match result {
Ok(_) => {
*self.memory_cache.lock().get_mut(package_nv).unwrap() =
@ -130,7 +134,7 @@ impl TarballCache {
self: &Arc<Self>,
package_nv: PackageNv,
dist: NpmPackageVersionDistInfo,
) -> LocalBoxFuture<'static, Result<(), AnyError>> {
) -> LoadFuture {
let tarball_cache = self.clone();
async move {
let registry_url = tarball_cache.npmrc.get_registry_url(&package_nv.name);
@ -197,6 +201,8 @@ impl TarballCache {
bail!("Could not find npm package tarball at: {}", dist.tarball);
}
}
}.boxed_local()
}
.map(|r| r.map_err(Arc::new))
.boxed_local()
}
}

View file

@ -1,101 +0,0 @@
// Copyright 2018-2024 the Deno authors. All rights reserved. MIT license.
use std::sync::Arc;
use deno_core::error::AnyError;
use deno_core::futures::future::BoxFuture;
use deno_core::futures::future::LocalBoxFuture;
use deno_core::futures::future::Shared;
use deno_core::futures::FutureExt;
use deno_core::parking_lot::Mutex;
use tokio::task::JoinError;
// todo(dsherret): unit test this
type FutureResult<TResult> = Result<TResult, Arc<AnyError>>;
type JoinResult<TResult> = Result<FutureResult<TResult>, Arc<JoinError>>;
#[derive(Debug)]
struct State<TResult> {
retry_index: usize,
future: Shared<BoxFuture<'static, JoinResult<TResult>>>,
}
/// Attempts to create a shared value asynchronously on one tokio runtime while
/// many runtimes are requesting the value.
///
/// This is only useful when the value needs to get created once across
/// many runtimes.
///
/// This handles the case where one tokio runtime goes down while another
/// one is still running.
#[derive(Debug)]
pub struct MultiRuntimeAsyncValueCreator<TResult: Send + Clone + 'static> {
state: Mutex<State<TResult>>,
}
impl<TResult: Send + Clone + 'static> MultiRuntimeAsyncValueCreator<TResult> {
pub fn new(
future: LocalBoxFuture<'static, Result<TResult, AnyError>>,
) -> Self {
Self {
state: Mutex::new(State {
retry_index: 0,
future: Self::create_shared_future(future),
}),
}
}
pub async fn get(
&self,
recreate_future: impl Fn() -> LocalBoxFuture<'static, Result<TResult, AnyError>>,
) -> Result<TResult, Arc<AnyError>> {
let (mut future, mut retry_index) = {
let state = self.state.lock();
(state.future.clone(), state.retry_index)
};
loop {
let result = future.await;
match result {
Ok(result) => return result,
Err(join_error) => {
if join_error.is_cancelled() {
let mut state = self.state.lock();
if state.retry_index == retry_index {
// we were the first one to retry, so create a new future
// that we'll run from the current runtime
state.retry_index += 1;
state.future = Self::create_shared_future(recreate_future());
}
retry_index = state.retry_index;
future = state.future.clone();
// just in case we're stuck in a loop
if retry_index > 1000 {
panic!("Something went wrong.") // should never happen
}
} else {
panic!("{}", join_error);
}
}
}
}
}
fn create_shared_future(
future: LocalBoxFuture<'static, Result<TResult, AnyError>>,
) -> Shared<BoxFuture<'static, JoinResult<TResult>>> {
deno_core::unsync::spawn(future)
.map(|result| match result {
Ok(Ok(value)) => Ok(Ok(value)),
Ok(Err(err)) => Ok(Err(Arc::new(err))),
Err(err) => Err(Arc::new(err)),
})
.boxed()
.shared()
}
}

View file

@ -6,7 +6,6 @@ use std::sync::Arc;
use deno_core::error::AnyError;
use deno_core::parking_lot::Mutex;
use deno_core::parking_lot::RwLock;
use deno_lockfile::NpmPackageDependencyLockfileInfo;
use deno_lockfile::NpmPackageLockfileInfo;
use deno_npm::registry::NpmPackageInfo;
@ -31,7 +30,7 @@ use deno_semver::package::PackageReq;
use deno_semver::VersionReq;
use crate::args::Lockfile;
use crate::util::sync::TaskQueue;
use crate::util::sync::SyncReadAsyncWriteLock;
use super::CliNpmRegistryApi;
@ -42,8 +41,7 @@ use super::CliNpmRegistryApi;
/// This does not interact with the file system.
pub struct NpmResolution {
api: Arc<CliNpmRegistryApi>,
snapshot: RwLock<NpmResolutionSnapshot>,
update_queue: TaskQueue,
snapshot: SyncReadAsyncWriteLock<NpmResolutionSnapshot>,
maybe_lockfile: Option<Arc<Mutex<Lockfile>>>,
}
@ -74,8 +72,7 @@ impl NpmResolution {
) -> Self {
Self {
api,
snapshot: RwLock::new(initial_snapshot),
update_queue: Default::default(),
snapshot: SyncReadAsyncWriteLock::new(initial_snapshot),
maybe_lockfile,
}
}
@ -85,16 +82,16 @@ impl NpmResolution {
package_reqs: &[PackageReq],
) -> Result<(), AnyError> {
// only allow one thread in here at a time
let _permit = self.update_queue.acquire().await;
let snapshot_lock = self.snapshot.acquire().await;
let snapshot = add_package_reqs_to_snapshot(
&self.api,
package_reqs,
self.maybe_lockfile.clone(),
|| self.snapshot.read().clone(),
|| snapshot_lock.read().clone(),
)
.await?;
*self.snapshot.write() = snapshot;
*snapshot_lock.write() = snapshot;
Ok(())
}
@ -103,7 +100,7 @@ impl NpmResolution {
package_reqs: &[PackageReq],
) -> Result<(), AnyError> {
// only allow one thread in here at a time
let _permit = self.update_queue.acquire().await;
let snapshot_lock = self.snapshot.acquire().await;
let reqs_set = package_reqs.iter().collect::<HashSet<_>>();
let snapshot = add_package_reqs_to_snapshot(
@ -111,7 +108,7 @@ impl NpmResolution {
package_reqs,
self.maybe_lockfile.clone(),
|| {
let snapshot = self.snapshot.read().clone();
let snapshot = snapshot_lock.read().clone();
let has_removed_package = !snapshot
.package_reqs()
.keys()
@ -126,24 +123,24 @@ impl NpmResolution {
)
.await?;
*self.snapshot.write() = snapshot;
*snapshot_lock.write() = snapshot;
Ok(())
}
pub async fn resolve_pending(&self) -> Result<(), AnyError> {
// only allow one thread in here at a time
let _permit = self.update_queue.acquire().await;
let snapshot_lock = self.snapshot.acquire().await;
let snapshot = add_package_reqs_to_snapshot(
&self.api,
&Vec::new(),
self.maybe_lockfile.clone(),
|| self.snapshot.read().clone(),
|| snapshot_lock.read().clone(),
)
.await?;
*self.snapshot.write() = snapshot;
*snapshot_lock.write() = snapshot;
Ok(())
}
@ -229,8 +226,10 @@ impl NpmResolution {
pkg_info: &NpmPackageInfo,
) -> Result<PackageNv, NpmPackageVersionResolutionError> {
debug_assert_eq!(pkg_req.name, pkg_info.name);
let _permit = self.update_queue.acquire().await;
let mut snapshot = self.snapshot.write();
// only allow one thread in here at a time
let snapshot_lock = self.snapshot.acquire().await;
let mut snapshot = snapshot_lock.write();
let pending_resolver = get_npm_pending_resolver(&self.api);
let nv = pending_resolver.resolve_package_req_as_pending(
&mut snapshot,
@ -244,8 +243,10 @@ impl NpmResolution {
&self,
reqs_with_pkg_infos: &[(&PackageReq, Arc<NpmPackageInfo>)],
) -> Vec<Result<PackageNv, NpmPackageVersionResolutionError>> {
let _permit = self.update_queue.acquire().await;
let mut snapshot = self.snapshot.write();
// only allow one thread in here at a time
let snapshot_lock = self.snapshot.acquire().await;
let mut snapshot = snapshot_lock.write();
let pending_resolver = get_npm_pending_resolver(&self.api);
let mut results = Vec::with_capacity(reqs_with_pkg_infos.len());
for (pkg_req, pkg_info) in reqs_with_pkg_infos {

View file

@ -0,0 +1,20 @@
// Copyright 2018-2024 the Deno authors. All rights reserved. MIT license.
use tokio_util::sync::CancellationToken;
#[derive(Debug, Default, Clone)]
pub struct AsyncFlag(CancellationToken);
impl AsyncFlag {
pub fn raise(&self) {
self.0.cancel();
}
pub fn is_raised(&self) -> bool {
self.0.is_cancelled()
}
pub fn wait_raised(&self) -> impl std::future::Future<Output = ()> + '_ {
self.0.cancelled()
}
}

View file

@ -0,0 +1,35 @@
// Copyright 2018-2024 the Deno authors. All rights reserved. MIT license.
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
/// Simplifies the use of an atomic boolean as a flag.
#[derive(Debug, Default)]
pub struct AtomicFlag(AtomicBool);
impl AtomicFlag {
/// Raises the flag returning if the raise was successful.
pub fn raise(&self) -> bool {
!self.0.swap(true, Ordering::SeqCst)
}
/// Gets if the flag is raised.
pub fn is_raised(&self) -> bool {
self.0.load(Ordering::SeqCst)
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn atomic_flag_raises() {
let flag = AtomicFlag::default();
assert!(!flag.is_raised()); // false by default
assert!(flag.raise());
assert!(flag.is_raised());
assert!(!flag.raise());
assert!(flag.is_raised());
}
}

14
cli/util/sync/mod.rs Normal file
View file

@ -0,0 +1,14 @@
// Copyright 2018-2024 the Deno authors. All rights reserved. MIT license.
mod async_flag;
mod atomic_flag;
mod sync_read_async_write_lock;
mod task_queue;
mod value_creator;
pub use async_flag::AsyncFlag;
pub use atomic_flag::AtomicFlag;
pub use sync_read_async_write_lock::SyncReadAsyncWriteLock;
pub use task_queue::TaskQueue;
pub use task_queue::TaskQueuePermit;
pub use value_creator::MultiRuntimeAsyncValueCreator;

View file

@ -0,0 +1,62 @@
// Copyright 2018-2024 the Deno authors. All rights reserved. MIT license.
use deno_core::parking_lot::RwLock;
use deno_core::parking_lot::RwLockReadGuard;
use deno_core::parking_lot::RwLockWriteGuard;
use super::TaskQueue;
use super::TaskQueuePermit;
/// A lock that can be read synchronously at any time (including when
/// being written to), but must write asynchronously.
pub struct SyncReadAsyncWriteLockWriteGuard<'a, T: Send + Sync> {
_update_permit: TaskQueuePermit<'a>,
data: &'a RwLock<T>,
}
impl<'a, T: Send + Sync> SyncReadAsyncWriteLockWriteGuard<'a, T> {
pub fn read(&self) -> RwLockReadGuard<'_, T> {
self.data.read()
}
/// Warning: Only `write()` with data you created within this
/// write this `SyncReadAsyncWriteLockWriteGuard`.
///
/// ```rs
/// let mut data = lock.write().await;
///
/// let mut data = data.read().clone();
/// data.value = 2;
/// *data.write() = data;
/// ```
pub fn write(&self) -> RwLockWriteGuard<'_, T> {
self.data.write()
}
}
/// A lock that can only be
pub struct SyncReadAsyncWriteLock<T: Send + Sync> {
data: RwLock<T>,
update_queue: TaskQueue,
}
impl<T: Send + Sync> SyncReadAsyncWriteLock<T> {
pub fn new(data: T) -> Self {
Self {
data: RwLock::new(data),
update_queue: TaskQueue::default(),
}
}
pub fn read(&self) -> RwLockReadGuard<'_, T> {
self.data.read()
}
pub async fn acquire(&self) -> SyncReadAsyncWriteLockWriteGuard<'_, T> {
let update_permit = self.update_queue.acquire().await;
SyncReadAsyncWriteLockWriteGuard {
_update_permit: update_permit,
data: &self.data,
}
}
}

View file

@ -1,30 +1,13 @@
// Copyright 2018-2024 the Deno authors. All rights reserved. MIT license.
use std::collections::LinkedList;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use deno_core::futures::task::AtomicWaker;
use deno_core::futures::Future;
use deno_core::parking_lot::Mutex;
use tokio_util::sync::CancellationToken;
/// Simplifies the use of an atomic boolean as a flag.
#[derive(Debug, Default)]
pub struct AtomicFlag(AtomicBool);
impl AtomicFlag {
/// Raises the flag returning if the raise was successful.
pub fn raise(&self) -> bool {
!self.0.swap(true, Ordering::SeqCst)
}
/// Gets if the flag is raised.
pub fn is_raised(&self) -> bool {
self.0.load(Ordering::SeqCst)
}
}
use super::AtomicFlag;
#[derive(Debug, Default)]
struct TaskQueueTaskItem {
@ -161,23 +144,6 @@ impl<'a> Future for TaskQueuePermitAcquireFuture<'a> {
}
}
#[derive(Debug, Default, Clone)]
pub struct AsyncFlag(CancellationToken);
impl AsyncFlag {
pub fn raise(&self) {
self.0.cancel();
}
pub fn is_raised(&self) -> bool {
self.0.is_cancelled()
}
pub fn wait_raised(&self) -> impl std::future::Future<Output = ()> + '_ {
self.0.cancelled()
}
}
#[cfg(test)]
mod test {
use deno_core::futures;
@ -186,16 +152,6 @@ mod test {
use super::*;
#[test]
fn atomic_flag_raises() {
let flag = AtomicFlag::default();
assert!(!flag.is_raised()); // false by default
assert!(flag.raise());
assert!(flag.is_raised());
assert!(!flag.raise());
assert!(flag.is_raised());
}
#[tokio::test]
async fn task_queue_runs_one_after_other() {
let task_queue = TaskQueue::default();

View file

@ -0,0 +1,213 @@
// Copyright 2018-2024 the Deno authors. All rights reserved. MIT license.
use std::sync::Arc;
use deno_core::futures::future::BoxFuture;
use deno_core::futures::future::LocalBoxFuture;
use deno_core::futures::future::Shared;
use deno_core::futures::FutureExt;
use deno_core::parking_lot::Mutex;
use tokio::task::JoinError;
type JoinResult<TResult> = Result<TResult, Arc<JoinError>>;
type CreateFutureFn<TResult> =
Box<dyn Fn() -> LocalBoxFuture<'static, TResult> + Send + Sync>;
#[derive(Debug)]
struct State<TResult> {
retry_index: usize,
future: Option<Shared<BoxFuture<'static, JoinResult<TResult>>>>,
}
/// Attempts to create a shared value asynchronously on one tokio runtime while
/// many runtimes are requesting the value.
///
/// This is only useful when the value needs to get created once across
/// many runtimes.
///
/// This handles the case where the tokio runtime creating the value goes down
/// while another one is waiting on the value.
pub struct MultiRuntimeAsyncValueCreator<TResult: Send + Clone + 'static> {
create_future: CreateFutureFn<TResult>,
state: Mutex<State<TResult>>,
}
impl<TResult: Send + Clone + 'static> std::fmt::Debug
for MultiRuntimeAsyncValueCreator<TResult>
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("MultiRuntimeAsyncValueCreator").finish()
}
}
impl<TResult: Send + Clone + 'static> MultiRuntimeAsyncValueCreator<TResult> {
pub fn new(create_future: CreateFutureFn<TResult>) -> Self {
Self {
state: Mutex::new(State {
retry_index: 0,
future: None,
}),
create_future,
}
}
pub async fn get(&self) -> TResult {
let (mut future, mut retry_index) = {
let mut state = self.state.lock();
let future = match &state.future {
Some(future) => future.clone(),
None => {
let future = self.create_shared_future();
state.future = Some(future.clone());
future
}
};
(future, state.retry_index)
};
loop {
let result = future.await;
match result {
Ok(result) => return result,
Err(join_error) => {
if join_error.is_cancelled() {
let mut state = self.state.lock();
if state.retry_index == retry_index {
// we were the first one to retry, so create a new future
// that we'll run from the current runtime
state.retry_index += 1;
state.future = Some(self.create_shared_future());
}
retry_index = state.retry_index;
future = state.future.as_ref().unwrap().clone();
// just in case we're stuck in a loop
if retry_index > 1000 {
panic!("Something went wrong.") // should never happen
}
} else {
panic!("{}", join_error);
}
}
}
}
}
fn create_shared_future(
&self,
) -> Shared<BoxFuture<'static, JoinResult<TResult>>> {
let future = (self.create_future)();
deno_core::unsync::spawn(future)
.map(|result| result.map_err(Arc::new))
.boxed()
.shared()
}
}
#[cfg(test)]
mod test {
use deno_core::unsync::spawn;
use super::*;
#[tokio::test]
async fn single_runtime() {
let value_creator = MultiRuntimeAsyncValueCreator::new(Box::new(|| {
async { 1 }.boxed_local()
}));
let value = value_creator.get().await;
assert_eq!(value, 1);
}
#[test]
fn multi_runtimes() {
let value_creator =
Arc::new(MultiRuntimeAsyncValueCreator::new(Box::new(|| {
async {
tokio::task::yield_now().await;
1
}
.boxed_local()
})));
let handles = (0..3)
.map(|_| {
let value_creator = value_creator.clone();
std::thread::spawn(|| {
create_runtime().block_on(async move { value_creator.get().await })
})
})
.collect::<Vec<_>>();
for handle in handles {
assert_eq!(handle.join().unwrap(), 1);
}
}
#[test]
fn multi_runtimes_first_never_finishes() {
let is_first_run = Arc::new(Mutex::new(true));
let (tx, rx) = std::sync::mpsc::channel::<()>();
let value_creator = Arc::new(MultiRuntimeAsyncValueCreator::new({
let is_first_run = is_first_run.clone();
Box::new(move || {
let is_first_run = is_first_run.clone();
let tx = tx.clone();
async move {
let is_first_run = {
let mut is_first_run = is_first_run.lock();
let initial_value = *is_first_run;
*is_first_run = false;
tx.send(()).unwrap();
initial_value
};
if is_first_run {
tokio::time::sleep(std::time::Duration::from_millis(30_000)).await;
panic!("TIMED OUT"); // should not happen
} else {
tokio::task::yield_now().await;
}
1
}
.boxed_local()
})
}));
std::thread::spawn({
let value_creator = value_creator.clone();
let is_first_run = is_first_run.clone();
move || {
create_runtime().block_on(async {
let value_creator = value_creator.clone();
// spawn a task that will never complete
spawn(async move { value_creator.get().await });
// wait for the task to set is_first_run to false
while *is_first_run.lock() {
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
}
// now exit the runtime while the value_creator is still pending
})
}
});
let handle = {
let value_creator = value_creator.clone();
std::thread::spawn(|| {
create_runtime().block_on(async move {
let value_creator = value_creator.clone();
rx.recv().unwrap();
// even though the other runtime shutdown, this get() should
// recover and still get the value
value_creator.get().await
})
})
};
assert_eq!(handle.join().unwrap(), 1);
}
fn create_runtime() -> tokio::runtime::Runtime {
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap()
}
}