Skip to content

Commit 584aa99

Browse files
committed
Cancel remaining type resolvers on exceptions
1 parent 182e2d6 commit 584aa99

File tree

2 files changed

+101
-6
lines changed

2 files changed

+101
-6
lines changed

src/graphql/execution/execute.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2105,25 +2105,39 @@ def default_type_resolver(
21052105
# Otherwise, test each possible type.
21062106
possible_types = info.schema.get_possible_types(abstract_type)
21072107
is_awaitable = info.is_awaitable
2108-
awaitable_is_type_of_results: list[Awaitable] = []
2109-
append_awaitable_results = awaitable_is_type_of_results.append
2108+
awaitable_is_type_of_results: list[Awaitable[bool]] = []
2109+
append_awaitable_result = awaitable_is_type_of_results.append
21102110
awaitable_types: list[GraphQLObjectType] = []
2111-
append_awaitable_types = awaitable_types.append
2111+
append_awaitable_type = awaitable_types.append
21122112

21132113
for type_ in possible_types:
21142114
if type_.is_type_of:
21152115
is_type_of_result = type_.is_type_of(value, info)
21162116

21172117
if is_awaitable(is_type_of_result):
2118-
append_awaitable_results(cast("Awaitable", is_type_of_result))
2119-
append_awaitable_types(type_)
2118+
append_awaitable_result(cast("Awaitable[bool]", is_type_of_result))
2119+
append_awaitable_type(type_)
21202120
elif is_type_of_result:
21212121
return type_.name
21222122

21232123
if awaitable_is_type_of_results:
21242124
# noinspection PyShadowingNames
21252125
async def get_type() -> str | None:
2126-
is_type_of_results = await gather(*awaitable_is_type_of_results)
2126+
tasks = [
2127+
create_task(result) # type: ignore[arg-type]
2128+
for result in awaitable_is_type_of_results
2129+
]
2130+
2131+
try:
2132+
is_type_of_results = await gather(*tasks)
2133+
except Exception:
2134+
# Cancel unfinished tasks before raising the exception
2135+
for task in tasks:
2136+
if not task.done():
2137+
task.cancel()
2138+
await gather(*tasks, return_exceptions=True)
2139+
raise
2140+
21272141
for is_type_of_result, type_ in zip(is_type_of_results, awaitable_types):
21282142
if is_type_of_result:
21292143
return type_.name

tests/execution/test_parallel.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,3 +336,84 @@ async def resolve_iterator(*args):
336336
await barrier.wait()
337337
await asyncio.sleep(0)
338338
assert not completed
339+
340+
@pytest.mark.asyncio
341+
async def cancel_type_resolver():
342+
FooType = GraphQLInterfaceType("Foo", {"foo": GraphQLField(GraphQLString)})
343+
344+
barrier = Barrier(3)
345+
completed = False
346+
347+
async def is_type_of_bar(*_args):
348+
raise RuntimeError("Oops")
349+
350+
BarType = GraphQLObjectType(
351+
"Bar",
352+
{
353+
"foo": GraphQLField(GraphQLString),
354+
},
355+
interfaces=[FooType],
356+
is_type_of=is_type_of_bar,
357+
)
358+
359+
async def is_type_of_baz(*_args):
360+
nonlocal completed
361+
await barrier.wait()
362+
completed = True # pragma: no cover
363+
364+
BazType = GraphQLObjectType(
365+
"Baz",
366+
{
367+
"foo": GraphQLField(GraphQLString),
368+
},
369+
interfaces=[FooType],
370+
is_type_of=is_type_of_baz,
371+
)
372+
373+
schema = GraphQLSchema(
374+
GraphQLObjectType(
375+
"Query",
376+
{
377+
"foo": GraphQLField(
378+
GraphQLList(FooType),
379+
resolve=lambda *_args: [
380+
{"foo": "bar"},
381+
{"foo": "baz"},
382+
],
383+
)
384+
},
385+
),
386+
types=[BarType, BazType],
387+
)
388+
389+
ast = parse(
390+
"""
391+
{
392+
foo {
393+
foo
394+
... on Bar { foobar }
395+
... on Baz { foobaz }
396+
}
397+
}
398+
"""
399+
)
400+
401+
# raises TimeoutError if not parallel
402+
awaitable_result = execute(schema, ast)
403+
assert isinstance(awaitable_result, Awaitable)
404+
result = await asyncio.wait_for(awaitable_result, 1)
405+
406+
assert result == (
407+
{"foo": [None, None]},
408+
[
409+
{"message": "Oops", "locations": [(3, 17)], "path": ["foo", 0]},
410+
{"message": "Oops", "locations": [(3, 17)], "path": ["foo", 1]},
411+
],
412+
)
413+
414+
assert not completed
415+
416+
# Unblock succeed() and check that it does not complete
417+
await barrier.wait()
418+
await asyncio.sleep(0)
419+
assert not completed

0 commit comments

Comments
 (0)