@@ -97,12 +97,40 @@ async def test_common_methods_in_wrapped_call():
97
97
assert mock_call .wait_for_connection .call_count == 1
98
98
99
99
100
+ @pytest .mark .asyncio
101
+ @pytest .mark .parametrize (
102
+ "callable_type,expected_wrapper_type" ,
103
+ [
104
+ (grpc .aio .UnaryStreamMultiCallable , grpc_helpers_async ._WrappedUnaryStreamCall ),
105
+ (grpc .aio .StreamUnaryMultiCallable , grpc_helpers_async ._WrappedStreamUnaryCall ),
106
+ (
107
+ grpc .aio .StreamStreamMultiCallable ,
108
+ grpc_helpers_async ._WrappedStreamStreamCall ,
109
+ ),
110
+ ],
111
+ )
112
+ async def test_wrap_errors_w_stream_type (callable_type , expected_wrapper_type ):
113
+ class ConcreteMulticallable (callable_type ):
114
+ def __call__ (self , * args , ** kwargs ):
115
+ raise NotImplementedError ("Should not be called" )
116
+
117
+ with mock .patch .object (
118
+ grpc_helpers_async , "_wrap_stream_errors"
119
+ ) as wrap_stream_errors :
120
+ callable_ = ConcreteMulticallable ()
121
+ grpc_helpers_async .wrap_errors (callable_ )
122
+ assert wrap_stream_errors .call_count == 1
123
+ wrap_stream_errors .assert_called_once_with (callable_ , expected_wrapper_type )
124
+
125
+
100
126
@pytest .mark .asyncio
101
127
async def test_wrap_stream_errors_unary_stream ():
102
128
mock_call = mock .Mock (aio .UnaryStreamCall , autospec = True )
103
129
multicallable = mock .Mock (return_value = mock_call )
104
130
105
- wrapped_callable = grpc_helpers_async ._wrap_stream_errors (multicallable )
131
+ wrapped_callable = grpc_helpers_async ._wrap_stream_errors (
132
+ multicallable , grpc_helpers_async ._WrappedUnaryStreamCall
133
+ )
106
134
107
135
await wrapped_callable (1 , 2 , three = "four" )
108
136
multicallable .assert_called_once_with (1 , 2 , three = "four" )
@@ -114,7 +142,9 @@ async def test_wrap_stream_errors_stream_unary():
114
142
mock_call = mock .Mock (aio .StreamUnaryCall , autospec = True )
115
143
multicallable = mock .Mock (return_value = mock_call )
116
144
117
- wrapped_callable = grpc_helpers_async ._wrap_stream_errors (multicallable )
145
+ wrapped_callable = grpc_helpers_async ._wrap_stream_errors (
146
+ multicallable , grpc_helpers_async ._WrappedStreamUnaryCall
147
+ )
118
148
119
149
await wrapped_callable (1 , 2 , three = "four" )
120
150
multicallable .assert_called_once_with (1 , 2 , three = "four" )
@@ -126,22 +156,26 @@ async def test_wrap_stream_errors_stream_stream():
126
156
mock_call = mock .Mock (aio .StreamStreamCall , autospec = True )
127
157
multicallable = mock .Mock (return_value = mock_call )
128
158
129
- wrapped_callable = grpc_helpers_async ._wrap_stream_errors (multicallable )
159
+ wrapped_callable = grpc_helpers_async ._wrap_stream_errors (
160
+ multicallable , grpc_helpers_async ._WrappedStreamStreamCall
161
+ )
130
162
131
163
await wrapped_callable (1 , 2 , three = "four" )
132
164
multicallable .assert_called_once_with (1 , 2 , three = "four" )
133
165
assert mock_call .wait_for_connection .call_count == 1
134
166
135
167
136
168
@pytest .mark .asyncio
137
- async def test_wrap_stream_errors_type_error ():
169
+ async def test_wrap_errors_type_error ():
170
+ """
171
+ If wrap_errors is called with an unexpected type, it should raise a TypeError.
172
+ """
138
173
mock_call = mock .Mock ()
139
174
multicallable = mock .Mock (return_value = mock_call )
140
175
141
- wrapped_callable = grpc_helpers_async ._wrap_stream_errors (multicallable )
142
-
143
- with pytest .raises (TypeError ):
144
- await wrapped_callable ()
176
+ with pytest .raises (TypeError ) as exc :
177
+ grpc_helpers_async .wrap_errors (multicallable )
178
+ assert "Unexpected type" in str (exc .value )
145
179
146
180
147
181
@pytest .mark .asyncio
@@ -151,7 +185,9 @@ async def test_wrap_stream_errors_raised():
151
185
mock_call .wait_for_connection = mock .AsyncMock (side_effect = [grpc_error ])
152
186
multicallable = mock .Mock (return_value = mock_call )
153
187
154
- wrapped_callable = grpc_helpers_async ._wrap_stream_errors (multicallable )
188
+ wrapped_callable = grpc_helpers_async ._wrap_stream_errors (
189
+ multicallable , grpc_helpers_async ._WrappedStreamStreamCall
190
+ )
155
191
156
192
with pytest .raises (exceptions .InvalidArgument ):
157
193
await wrapped_callable ()
@@ -166,7 +202,9 @@ async def test_wrap_stream_errors_read():
166
202
mock_call .read = mock .AsyncMock (side_effect = grpc_error )
167
203
multicallable = mock .Mock (return_value = mock_call )
168
204
169
- wrapped_callable = grpc_helpers_async ._wrap_stream_errors (multicallable )
205
+ wrapped_callable = grpc_helpers_async ._wrap_stream_errors (
206
+ multicallable , grpc_helpers_async ._WrappedStreamStreamCall
207
+ )
170
208
171
209
wrapped_call = await wrapped_callable (1 , 2 , three = "four" )
172
210
multicallable .assert_called_once_with (1 , 2 , three = "four" )
@@ -189,7 +227,9 @@ async def test_wrap_stream_errors_aiter():
189
227
mock_call .__aiter__ = mock .Mock (return_value = mocked_aiter )
190
228
multicallable = mock .Mock (return_value = mock_call )
191
229
192
- wrapped_callable = grpc_helpers_async ._wrap_stream_errors (multicallable )
230
+ wrapped_callable = grpc_helpers_async ._wrap_stream_errors (
231
+ multicallable , grpc_helpers_async ._WrappedStreamStreamCall
232
+ )
193
233
wrapped_call = await wrapped_callable ()
194
234
195
235
with pytest .raises (exceptions .InvalidArgument ) as exc_info :
@@ -210,7 +250,9 @@ async def test_wrap_stream_errors_aiter_non_rpc_error():
210
250
mock_call .__aiter__ = mock .Mock (return_value = mocked_aiter )
211
251
multicallable = mock .Mock (return_value = mock_call )
212
252
213
- wrapped_callable = grpc_helpers_async ._wrap_stream_errors (multicallable )
253
+ wrapped_callable = grpc_helpers_async ._wrap_stream_errors (
254
+ multicallable , grpc_helpers_async ._WrappedStreamStreamCall
255
+ )
214
256
wrapped_call = await wrapped_callable ()
215
257
216
258
with pytest .raises (TypeError ) as exc_info :
@@ -224,7 +266,9 @@ async def test_wrap_stream_errors_aiter_called_multiple_times():
224
266
mock_call = mock .Mock (aio .StreamStreamCall , autospec = True )
225
267
multicallable = mock .Mock (return_value = mock_call )
226
268
227
- wrapped_callable = grpc_helpers_async ._wrap_stream_errors (multicallable )
269
+ wrapped_callable = grpc_helpers_async ._wrap_stream_errors (
270
+ multicallable , grpc_helpers_async ._WrappedStreamStreamCall
271
+ )
228
272
wrapped_call = await wrapped_callable ()
229
273
230
274
assert wrapped_call .__aiter__ () == wrapped_call .__aiter__ ()
@@ -239,7 +283,9 @@ async def test_wrap_stream_errors_write():
239
283
mock_call .done_writing = mock .AsyncMock (side_effect = [None , grpc_error ])
240
284
multicallable = mock .Mock (return_value = mock_call )
241
285
242
- wrapped_callable = grpc_helpers_async ._wrap_stream_errors (multicallable )
286
+ wrapped_callable = grpc_helpers_async ._wrap_stream_errors (
287
+ multicallable , grpc_helpers_async ._WrappedStreamStreamCall
288
+ )
243
289
244
290
wrapped_call = await wrapped_callable ()
245
291
@@ -295,7 +341,9 @@ def test_wrap_errors_streaming(wrap_stream_errors):
295
341
result = grpc_helpers_async .wrap_errors (callable_ )
296
342
297
343
assert result == wrap_stream_errors .return_value
298
- wrap_stream_errors .assert_called_once_with (callable_ )
344
+ wrap_stream_errors .assert_called_once_with (
345
+ callable_ , grpc_helpers_async ._WrappedUnaryStreamCall
346
+ )
299
347
300
348
301
349
@pytest .mark .parametrize (
0 commit comments