Skip to content

Commit 09b9fd3

Browse files
committed
fix GenericReference iterable query (i.e. __in)
This change adds the ``_ref`` or ``_ref.$id`` prefix to a query if all values in an iterable query (i.e. ``__in``) are ``ObjectId``s or ``DBRef``s and raises an error for a mixed query which will only work for documents. These could possibly be compiled into an ``{$or: ...}`` query, but the automatic expansion can be added as necessary.
1 parent d5867d6 commit 09b9fd3

File tree

1 file changed

+18
-2
lines changed

1 file changed

+18
-2
lines changed

mongoengine/queryset/transform.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,13 +129,15 @@ def query(_doc_cls=None, **kwargs):
129129

130130
singular_ops = [None, "ne", "gt", "gte", "lt", "lte", "not"]
131131
singular_ops += STRING_OPERATORS
132+
is_iterable = False
132133
if op in singular_ops:
133134
value = field.prepare_query_value(op, value)
134135

135136
if isinstance(field, CachedReferenceField) and value:
136137
value = value["_id"]
137138

138139
elif op in ("in", "nin", "all", "near") and not isinstance(value, dict):
140+
is_iterable = True
139141
# Raise an error if the in/nin/all/near param is not iterable.
140142
value = _prepare_query_for_iterable(field, op, value)
141143

@@ -144,10 +146,24 @@ def query(_doc_cls=None, **kwargs):
144146
# * If the value is a DBRef, the key should be "field_name._ref".
145147
# * If the value is an ObjectId, the key should be "field_name._ref.$id".
146148
if isinstance(field, GenericReferenceField):
147-
if isinstance(value, DBRef):
149+
if isinstance(value, DBRef) or (
150+
is_iterable and all(isinstance(v, DBRef) for v in value)
151+
):
148152
parts[-1] += "._ref"
149-
elif isinstance(value, ObjectId):
153+
elif isinstance(value, ObjectId) or (
154+
is_iterable and all(isinstance(v, ObjectId) for v in value)
155+
):
150156
parts[-1] += "._ref.$id"
157+
elif (
158+
is_iterable
159+
and any(isinstance(v, DBRef) for v in value)
160+
and any(isinstance(v, ObjectId) for v in value)
161+
):
162+
raise ValueError(
163+
"The `in`, `nin`, `all`, or `near`-operators cannot "
164+
"be applied to mixed queries of DBRef/ObjectId/%s"
165+
% _doc_cls.__name__
166+
)
151167

152168
# if op and op not in COMPARISON_OPERATORS:
153169
if op:

0 commit comments

Comments
 (0)