@@ -7,7 +7,7 @@ use pyo3::types::{PyBytes, PyTuple};
7
7
use std:: collections:: HashMap ;
8
8
use std:: collections:: HashSet ;
9
9
use std:: marker:: PhantomData ;
10
- use std:: sync:: Arc ;
10
+ use std:: sync:: { Arc , OnceLock } ;
11
11
use std:: time:: Duration ;
12
12
use temporal_sdk_core:: api:: errors:: { PollActivityError , PollWfError } ;
13
13
use temporal_sdk_core:: replay:: { HistoryForReplay , ReplayWorkerInput } ;
@@ -64,6 +64,18 @@ pub struct TunerHolder {
64
64
local_activity_slot_supplier : SlotSupplier ,
65
65
}
66
66
67
+ // pub fn set_task_locals_on_tuner<'a>(py: Python<'a>, tuner: &TunerHolder) -> PyResult<()> {
68
+ // // TODO: All suppliers
69
+ // if let SlotSupplier::Custom(ref cs) = tuner.workflow_slot_supplier {
70
+ // Python::with_gil(|py| {
71
+ // let py_obj = cs.inner.as_ref(py);
72
+ // py_obj.call_method0("set_task_locals")?;
73
+ // Ok(())
74
+ // })?;
75
+ // };
76
+ // Ok(())
77
+ // }
78
+
67
79
#[ derive( FromPyObject ) ]
68
80
pub enum SlotSupplier {
69
81
FixedSize ( FixedSizeSlotSupplier ) ,
@@ -190,17 +202,60 @@ impl CustomSlotSupplier {
190
202
}
191
203
}
192
204
205
+ // Shouldn't really need this callback nonsense, it should be possible to do this from the pyo3
206
+ // asyncio library, but we'd have to vendor the whole thing to make the right improvements. When
207
+ // pyo3 is upgraded and we are using
208
+
209
+ #[ pyclass]
210
+ struct CreatedTaskForSlotCallback {
211
+ stored_task : Arc < OnceLock < PyObject > > ,
212
+ }
213
+
214
+ #[ pymethods]
215
+ impl CreatedTaskForSlotCallback {
216
+ fn __call__ ( & self , task : PyObject ) -> PyResult < ( ) > {
217
+ self . stored_task . set ( task) . expect ( "must only be set once" ) ;
218
+ Ok ( ( ) )
219
+ }
220
+ }
221
+
222
+ struct TaskCanceller {
223
+ stored_task : Arc < OnceLock < PyObject > > ,
224
+ }
225
+
226
+ impl TaskCanceller {
227
+ fn new ( stored_task : Arc < OnceLock < PyObject > > ) -> Self {
228
+ TaskCanceller { stored_task }
229
+ }
230
+ }
231
+
232
+ impl Drop for TaskCanceller {
233
+ fn drop ( & mut self ) {
234
+ if let Some ( task) = self . stored_task . get ( ) {
235
+ Python :: with_gil ( |py| {
236
+ task. call_method0 ( py, "cancel" )
237
+ . expect ( "Failed to cancel task" ) ;
238
+ } ) ;
239
+ }
240
+ }
241
+ }
242
+
193
243
#[ async_trait:: async_trait]
194
244
impl < SK : SlotKind + Send + Sync > SlotSupplierTrait for CustomSlotSupplierOfType < SK > {
195
245
type SlotKind = SK ;
196
246
197
247
async fn reserve_slot ( & self , ctx : & dyn SlotReservationContext ) -> SlotSupplierPermit {
198
248
loop {
249
+ let stored_task = Arc :: new ( OnceLock :: new ( ) ) ;
250
+ let _task_canceller = TaskCanceller :: new ( stored_task. clone ( ) ) ;
199
251
let pypermit = match Python :: with_gil ( |py| {
200
252
let py_obj = self . inner . as_ref ( py) ;
201
253
let called = py_obj. call_method1 (
202
254
"reserve_slot" ,
203
- ( SlotReserveCtx :: from_ctx ( Self :: SlotKind :: kind ( ) , ctx) , ) ,
255
+ (
256
+ SlotReserveCtx :: from_ctx ( SK :: kind ( ) , ctx) ,
257
+ CreatedTaskForSlotCallback { stored_task } ,
258
+ ) ,
204
259
) ?;
205
260
runtime:: THREAD_TASK_LOCAL
206
261
. with ( |tl| pyo3_asyncio:: into_future_with_locals ( tl. get ( ) . unwrap ( ) , called) )
@@ -232,7 +287,7 @@ impl<SK: SlotKind + Send + Sync> SlotSupplierTrait for CustomSlotSupplierOfType<
232
287
let py_obj = self . inner . as_ref ( py) ;
233
288
let pa = py_obj. call_method1 (
234
289
"try_reserve_slot" ,
235
- ( SlotReserveCtx :: from_ctx ( Self :: SlotKind :: kind ( ) , ctx) , ) ,
290
+ ( SlotReserveCtx :: from_ctx ( SK :: kind ( ) , ctx) , ) ,
236
291
) ?;
237
292
238
293
if pa. is_none ( ) {
@@ -362,6 +417,8 @@ pub fn new_replay_worker<'a>(
362
417
impl WorkerRef {
363
418
fn validate < ' p > ( & self , py : Python < ' p > ) -> PyResult < & ' p PyAny > {
364
419
let worker = self . worker . as_ref ( ) . unwrap ( ) . clone ( ) ;
420
+ // Set custom slot supplier task locals so they can run futures
421
+ // match worker.get_config().tuner {}
365
422
self . runtime . future_into_py ( py, async move {
366
423
worker
367
424
. validate ( )
0 commit comments