Skip to content

Commit ec96168

Browse files
committed
Actually cancel reserve slot tasks
1 parent 18264f2 commit ec96168

File tree

3 files changed

+79
-12
lines changed

3 files changed

+79
-12
lines changed

temporalio/bridge/src/worker.rs

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use pyo3::types::{PyBytes, PyTuple};
77
use std::collections::HashMap;
88
use std::collections::HashSet;
99
use std::marker::PhantomData;
10-
use std::sync::Arc;
10+
use std::sync::{Arc, OnceLock};
1111
use std::time::Duration;
1212
use temporal_sdk_core::api::errors::{PollActivityError, PollWfError};
1313
use temporal_sdk_core::replay::{HistoryForReplay, ReplayWorkerInput};
@@ -64,6 +64,18 @@ pub struct TunerHolder {
6464
local_activity_slot_supplier: SlotSupplier,
6565
}
6666

67+
// pub fn set_task_locals_on_tuner<'a>(py: Python<'a>, tuner: &TunerHolder) -> PyResult<()> {
68+
// // TODO: All suppliers
69+
// if let SlotSupplier::Custom(ref cs) = tuner.workflow_slot_supplier {
70+
// Python::with_gil(|py| {
71+
// let py_obj = cs.inner.as_ref(py);
72+
// py_obj.call_method0("set_task_locals")?;
73+
// Ok(())
74+
// })?;
75+
// };
76+
// Ok(())
77+
// }
78+
6779
#[derive(FromPyObject)]
6880
pub enum SlotSupplier {
6981
FixedSize(FixedSizeSlotSupplier),
@@ -190,17 +202,60 @@ impl CustomSlotSupplier {
190202
}
191203
}
192204

205+
// Shouldn't really need this callback nonsense, it should be possible to do this from the pyo3
206+
// asyncio library, but we'd have to vendor the whole thing to make the right improvements. When
207+
// pyo3 is upgraded and we are using
208+
209+
#[pyclass]
210+
struct CreatedTaskForSlotCallback {
211+
stored_task: Arc<OnceLock<PyObject>>,
212+
}
213+
214+
#[pymethods]
215+
impl CreatedTaskForSlotCallback {
216+
fn __call__(&self, task: PyObject) -> PyResult<()> {
217+
self.stored_task.set(task).expect("must only be set once");
218+
Ok(())
219+
}
220+
}
221+
222+
struct TaskCanceller {
223+
stored_task: Arc<OnceLock<PyObject>>,
224+
}
225+
226+
impl TaskCanceller {
227+
fn new(stored_task: Arc<OnceLock<PyObject>>) -> Self {
228+
TaskCanceller { stored_task }
229+
}
230+
}
231+
232+
impl Drop for TaskCanceller {
233+
fn drop(&mut self) {
234+
if let Some(task) = self.stored_task.get() {
235+
Python::with_gil(|py| {
236+
task.call_method0(py, "cancel")
237+
.expect("Failed to cancel task");
238+
});
239+
}
240+
}
241+
}
242+
193243
#[async_trait::async_trait]
194244
impl<SK: SlotKind + Send + Sync> SlotSupplierTrait for CustomSlotSupplierOfType<SK> {
195245
type SlotKind = SK;
196246

197247
async fn reserve_slot(&self, ctx: &dyn SlotReservationContext) -> SlotSupplierPermit {
198248
loop {
249+
let stored_task = Arc::new(OnceLock::new());
250+
let _task_canceller = TaskCanceller::new(stored_task.clone());
199251
let pypermit = match Python::with_gil(|py| {
200252
let py_obj = self.inner.as_ref(py);
201253
let called = py_obj.call_method1(
202254
"reserve_slot",
203-
(SlotReserveCtx::from_ctx(Self::SlotKind::kind(), ctx),),
255+
(
256+
SlotReserveCtx::from_ctx(SK::kind(), ctx),
257+
CreatedTaskForSlotCallback { stored_task },
258+
),
204259
)?;
205260
runtime::THREAD_TASK_LOCAL
206261
.with(|tl| pyo3_asyncio::into_future_with_locals(tl.get().unwrap(), called))
@@ -232,7 +287,7 @@ impl<SK: SlotKind + Send + Sync> SlotSupplierTrait for CustomSlotSupplierOfType<
232287
let py_obj = self.inner.as_ref(py);
233288
let pa = py_obj.call_method1(
234289
"try_reserve_slot",
235-
(SlotReserveCtx::from_ctx(Self::SlotKind::kind(), ctx),),
290+
(SlotReserveCtx::from_ctx(SK::kind(), ctx),),
236291
)?;
237292

238293
if pa.is_none() {
@@ -362,6 +417,8 @@ pub fn new_replay_worker<'a>(
362417
impl WorkerRef {
363418
fn validate<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> {
364419
let worker = self.worker.as_ref().unwrap().clone();
420+
// Set custom slot supplier task locals so they can run futures
421+
// match worker.get_config().tuner {}
365422
self.runtime.future_into_py(py, async move {
366423
worker
367424
.validate()

temporalio/worker/_tuning.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from abc import ABC, abstractmethod
44
from dataclasses import dataclass
55
from datetime import timedelta
6-
from typing import Literal, Optional, Union
6+
from typing import Any, Callable, Literal, Optional, Union
77

88
from typing_extensions import TypeAlias
99

@@ -90,13 +90,17 @@ class ResourceBasedSlotSupplier:
9090
]
9191

9292

93-
class _ErrorLoggingSlotSupplier(CustomSlotSupplier):
93+
class _BridgeSlotSupplierWrapper:
9494
def __init__(self, supplier: CustomSlotSupplier):
9595
self._supplier = supplier
9696

97-
async def reserve_slot(self, ctx: SlotReserveContext) -> SlotPermit:
97+
async def reserve_slot(
98+
self, ctx: SlotReserveContext, reserve_cb: Callable[[Any], None]
99+
) -> SlotPermit:
98100
try:
99-
return await self._supplier.reserve_slot(ctx)
101+
reserve_fut = asyncio.create_task(self._supplier.reserve_slot(ctx))
102+
reserve_cb(reserve_fut)
103+
return await reserve_fut
100104
except asyncio.CancelledError:
101105
raise
102106
except Exception:
@@ -160,7 +164,7 @@ def _to_bridge_slot_supplier(
160164
)
161165
elif isinstance(slot_supplier, CustomSlotSupplier):
162166
return temporalio.bridge.worker.BridgeCustomSlotSupplier(
163-
_ErrorLoggingSlotSupplier(slot_supplier)
167+
_BridgeSlotSupplierWrapper(slot_supplier)
164168
)
165169
else:
166170
raise TypeError(f"Unknown slot supplier type: {slot_supplier}")

tests/worker/test_worker.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import asyncio
44
import concurrent.futures
5+
import threading
56
import uuid
67
from datetime import timedelta
78
from typing import Any, Awaitable, Callable, Optional
@@ -353,6 +354,7 @@ def __init__(self, pnum: int):
353354
class MySlotSupplier(CustomSlotSupplier):
354355
reserves = 0
355356
releases = 0
357+
highest_seen_reserve_on_release = 0
356358
used = 0
357359
seen_sticky_kinds = set()
358360
seen_slot_kinds = set()
@@ -388,6 +390,9 @@ def release_slot(self, ctx: SlotReleaseContext) -> None:
388390
assert ctx.permit is not None
389391
assert isinstance(ctx.permit, MyPermit)
390392
assert ctx.permit.pnum is not None
393+
self.highest_seen_reserve_on_release = max(
394+
ctx.permit.pnum, self.highest_seen_reserve_on_release
395+
)
391396
# Info may be empty, and we should see both empty and not
392397
if ctx.slot_info is None:
393398
self.seen_release_info_empty = True
@@ -422,10 +427,11 @@ def reserve_asserts(self, ctx):
422427
await wf1.signal(WaitOnSignalWorkflow.my_signal, "finish")
423428
await wf1.result()
424429

425-
async def releases() -> int:
426-
return ss.releases
427-
428-
assert ss.reserves == ss.releases
430+
# We can't use reserve number directly because there is a technically possible race
431+
# where the python reserve function appears to complete, but Rust doesn't see that.
432+
# This isn't solvable without redoing a chunk of pyo3-asyncio. So we only check
433+
# that the permits passed to release line up.
434+
assert ss.highest_seen_reserve_on_release == ss.releases
429435
# Two workflow tasks, one activity
430436
assert ss.used == 3
431437
assert ss.seen_sticky_kinds == {True, False}

0 commit comments

Comments
 (0)