diff --git a/stdlib/public/Concurrency/Task.cpp b/stdlib/public/Concurrency/Task.cpp index 9e03b1cd85172..1b0de154e2f49 100644 --- a/stdlib/public/Concurrency/Task.cpp +++ b/stdlib/public/Concurrency/Task.cpp @@ -1758,18 +1758,27 @@ swift_task_addCancellationHandlerImpl( CancellationNotificationStatusRecord(unsigned_handler, context); bool fireHandlerNow = false; - addStatusRecordToSelf(record, [&](ActiveTaskStatus oldStatus, ActiveTaskStatus& newStatus) { if (oldStatus.isCancelled()) { - fireHandlerNow = true; // We don't fire the cancellation handler here since this function needs // to be idempotent + fireHandlerNow = true; + + // don't add the record, because that would risk triggering it from + // task_cancel, concurrently with the record->run() we're about to do below. + return false; } - return true; + return true; // add the record }); if (fireHandlerNow) { record->run(); + + // we have not added the record to the task because it has fired immediately, + // and therefore we can clean it up immediately rather than wait until removeCancellationHandler + // which would be triggered at the end of the withTaskCancellationHandler block. + swift_task_dealloc(record); + return nullptr; // indicate to the remove... method, that there was no task added } return record; } @@ -1777,8 +1786,17 @@ swift_task_addCancellationHandlerImpl( SWIFT_CC(swift) static void swift_task_removeCancellationHandlerImpl( CancellationNotificationStatusRecord *record) { - removeStatusRecordFromSelf(record); - swift_task_dealloc(record); + if (!record) { + // seems we never added the record but have run it immediately, + // so we make no attempts to remove it. + return; + } + + if (auto poppedRecord = + popStatusRecordOfType(swift_task_getCurrent())) { + assert(record == poppedRecord && "The removed record did not match the expected record!"); + swift_task_dealloc(record); + } } SWIFT_CC(swift) diff --git a/stdlib/public/Concurrency/TaskPrivate.h b/stdlib/public/Concurrency/TaskPrivate.h index 49b5baa17ef7c..2235e7a687b23 100644 --- a/stdlib/public/Concurrency/TaskPrivate.h +++ b/stdlib/public/Concurrency/TaskPrivate.h @@ -244,6 +244,16 @@ void removeStatusRecordWhere( llvm::function_ref condition, llvm::function_refupdateStatus = nullptr); +/// Remove and return a status record of the given type. This function removes a +/// singlw record, and leaves subsequent records as-is if there are any. +/// Returns `nullptr` if there are no matching records. +/// +/// NOTE: When using this function with new record types, make sure to provide +/// an explicit instantiation in TaskStatus.cpp. +template +SWIFT_CC(swift) +TaskStatusRecordT* popStatusRecordOfType(AsyncTask *task); + /// Remove a status record from the current task. This must be called /// synchronously with the task. SWIFT_CC(swift) diff --git a/stdlib/public/Concurrency/TaskStatus.cpp b/stdlib/public/Concurrency/TaskStatus.cpp index 802d2c87a193f..38cc28b45535e 100644 --- a/stdlib/public/Concurrency/TaskStatus.cpp +++ b/stdlib/public/Concurrency/TaskStatus.cpp @@ -350,15 +350,18 @@ void swift::removeStatusRecordWhere( }); } -// Remove and return a status record of the given type. There must be at most -// one matching record. Returns nullptr if there are none. template -static TaskStatusRecordT *popStatusRecordOfType(AsyncTask *task) { +SWIFT_CC(swift) +TaskStatusRecordT* swift::popStatusRecordOfType(AsyncTask *task) { TaskStatusRecordT *record = nullptr; + bool alreadyRemovedRecord = false; removeStatusRecordWhere(task, [&](ActiveTaskStatus s, TaskStatusRecord *r) { + if (alreadyRemovedRecord) + return false; + if (auto *match = dyn_cast(r)) { - assert(!record && "two matching records found"); record = match; + alreadyRemovedRecord = true; return true; // Remove this record. } @@ -562,6 +565,10 @@ static void swift_task_popTaskExecutorPreferenceImpl( swift_task_dealloc(record); } +// Since the header would have incomplete declarations, we instead instantiate a concrete version of the function here +template SWIFT_CC(swift) +CancellationNotificationStatusRecord* swift::popStatusRecordOfType(AsyncTask *); + void AsyncTask::pushInitialTaskExecutorPreference( TaskExecutorRef preferredExecutor, bool owned) { void *allocation = _swift_task_alloc_specific( @@ -879,7 +886,7 @@ static void swift_task_cancelImpl(AsyncTask *task) { } newStatus.traceStatusChanged(task, false); - if (newStatus.getInnermostRecord() == NULL) { + if (newStatus.getInnermostRecord() == nullptr) { // No records, nothing to propagate return; } diff --git a/test/Concurrency/Runtime/cancellation_handler_only_once.swift b/test/Concurrency/Runtime/cancellation_handler_only_once.swift new file mode 100644 index 0000000000000..b188b1f0c02b0 --- /dev/null +++ b/test/Concurrency/Runtime/cancellation_handler_only_once.swift @@ -0,0 +1,62 @@ +// RUN: %target-run-simple-swift( -Xfrontend -disable-availability-checking -target %target-swift-5.1-abi-triple %import-libdispatch) | %FileCheck %s +// REQUIRES: concurrency +// REQUIRES: executable_test + +// rdar://76038845 +// REQUIRES: concurrency_runtime +// UNSUPPORTED: back_deployment_runtime +// UNSUPPORTED: freestanding + +import Synchronization + +struct State { + var cancelled = 0 + var continuation: CheckedContinuation? +} + +func testFunc(_ iteration: Int) async -> Task { + let state = Mutex(State()) + + let task = Task { + await withTaskCancellationHandler { + await withCheckedContinuation { continuation in + let cancelled = state.withLock { + if $0.cancelled > 0 { + return true + } else { + $0.continuation = continuation + return false + } + } + if cancelled { + continuation.resume() + } + } + } onCancel: { + let continuation = state.withLock { + $0.cancelled += 1 + return $0.continuation.take() + } + continuation?.resume() + } + } + + // This task cancel is racing with installing the cancellation handler, + // and we may either hit the cancellation handler: + // - after this cancel was issued, and therefore the handler runs immediately + task.cancel() + _ = await task.value + + let cancelled = state.withLock { $0.cancelled } + precondition(cancelled == 1, "cancelled more than once, iteration: \(iteration)") + + return task +} + +var ts: [Task] = [] +for iteration in 0..<1_000 { + let t = await testFunc(iteration) + ts.append(t) +} + +print("done") // CHECK: done