Skip to content

Commit dbaefde

Browse files
authored
Fix async session tests. (#433)
1 parent 20378ae commit dbaefde

File tree

2 files changed

+26
-8
lines changed

2 files changed

+26
-8
lines changed

tests/conftest.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
import httpx
1616
import pytest
17+
import asyncio
18+
from functools import wraps
1719

1820
from tests.utils.client_configuration import ClientConfiguration
1921
from tests.utils.list_resource import list_data_to_dicts, list_response_of
@@ -26,7 +28,6 @@
2628

2729
from jwt import PyJWKClient
2830
from unittest.mock import Mock, patch
29-
from functools import wraps
3031

3132

3233
def _get_test_client_setup(
@@ -310,7 +311,19 @@ def inner(
310311

311312
def with_jwks_mock(func):
312313
@wraps(func)
313-
def wrapper(*args, **kwargs):
314+
async def async_wrapper(*args, **kwargs):
315+
# Create mock JWKS client
316+
mock_jwks = Mock(spec=PyJWKClient)
317+
mock_signing_key = Mock()
318+
mock_signing_key.key = kwargs["session_constants"]["PUBLIC_KEY"]
319+
mock_jwks.get_signing_key_from_jwt.return_value = mock_signing_key
320+
321+
# Apply the mock
322+
with patch("workos.session.PyJWKClient", return_value=mock_jwks):
323+
return await func(*args, **kwargs)
324+
325+
@wraps(func)
326+
def sync_wrapper(*args, **kwargs):
314327
# Create mock JWKS client
315328
mock_jwks = Mock(spec=PyJWKClient)
316329
mock_signing_key = Mock()
@@ -321,4 +334,7 @@ def wrapper(*args, **kwargs):
321334
with patch("workos.session.PyJWKClient", return_value=mock_jwks):
322335
return func(*args, **kwargs)
323336

324-
return wrapper
337+
# Return appropriate wrapper based on whether the function is async or not
338+
if asyncio.iscoroutinefunction(func):
339+
return async_wrapper
340+
return sync_wrapper

tests/test_session.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import pytest
2-
from unittest.mock import Mock, patch
2+
from unittest.mock import AsyncMock, Mock, patch
33
import jwt
44
from datetime import datetime, timezone
55

@@ -396,6 +396,7 @@ def test_refresh_success_with_aud_claim(
396396

397397

398398
class TestAsyncSession(SessionFixtures):
399+
@pytest.mark.asyncio
399400
@with_jwks_mock
400401
async def test_refresh_success(self, session_constants, mock_user_management):
401402
session_data = AsyncSession.seal_data(
@@ -413,8 +414,8 @@ async def test_refresh_success(self, session_constants, mock_user_management):
413414
"user": session_constants["TEST_USER"],
414415
}
415416

416-
mock_user_management.authenticate_with_refresh_token.return_value = (
417-
RefreshTokenAuthenticationResponse(**mock_response)
417+
mock_user_management.authenticate_with_refresh_token = AsyncMock(
418+
return_value=(RefreshTokenAuthenticationResponse(**mock_response))
418419
)
419420

420421
session = AsyncSession(
@@ -451,6 +452,7 @@ async def test_refresh_success(self, session_constants, mock_user_management):
451452
},
452453
)
453454

455+
@pytest.mark.asyncio
454456
@with_jwks_mock
455457
async def test_refresh_success_with_aud_claim(
456458
self, session_constants, mock_user_management
@@ -479,8 +481,8 @@ async def test_refresh_success_with_aud_claim(
479481
"user": session_constants["TEST_USER"],
480482
}
481483

482-
mock_user_management.authenticate_with_refresh_token.return_value = (
483-
RefreshTokenAuthenticationResponse(**mock_response)
484+
mock_user_management.authenticate_with_refresh_token = AsyncMock(
485+
return_value=(RefreshTokenAuthenticationResponse(**mock_response))
484486
)
485487

486488
session = AsyncSession(

0 commit comments

Comments
 (0)