Skip to content

Commit b2e75b4

Browse files
authored
Simplify useCustomClaims to useTokenClaims (#258)
1 parent 3a67d08 commit b2e75b4

File tree

7 files changed

+130
-135
lines changed

7 files changed

+130
-135
lines changed

__tests__/session.spec.ts

Lines changed: 40 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import { NextRequest, NextResponse } from 'next/server.js';
22
import { cookies, headers } from 'next/headers.js';
33
import { redirect } from 'next/navigation.js';
44
import { generateTestToken } from './test-helpers.js';
5-
import { withAuth, updateSession, refreshSession, updateSessionMiddleware, getCustomClaims } from '../src/session.js';
5+
import { withAuth, updateSession, refreshSession, updateSessionMiddleware, getTokenClaims } from '../src/session.js';
66
import { getWorkOS } from '../src/workos.js';
77
import * as envVariables from '../src/env-variables.js';
88

@@ -843,103 +843,94 @@ describe('session.ts', () => {
843843
});
844844
});
845845

846-
describe('getCustomClaims', () => {
846+
describe('getTokenClaims', () => {
847847
beforeEach(async () => {
848848
const nextCookies = await cookies();
849849
// @ts-expect-error - _reset is part of the mock
850850
nextCookies._reset();
851851
jest.clearAllMocks();
852852
});
853853

854-
it('should return custom claims when accessToken is provided', async () => {
855-
const customClaims = { department: 'engineering', level: 5, metadata: { theme: 'dark' } };
856-
const token = await generateTestToken({
854+
it('should return all token claims when accessToken is provided', async () => {
855+
const tokenPayload = {
857856
sub: 'user_123',
858857
org_id: 'org_123',
859858
role: 'admin',
860859
permissions: ['read', 'write'],
861860
entitlements: ['feature_a'],
862-
...customClaims,
863-
});
861+
department: 'engineering',
862+
level: 5,
863+
metadata: { theme: 'dark' },
864+
};
865+
const token = await generateTestToken(tokenPayload);
864866

865-
const result = await getCustomClaims(token);
867+
const result = await getTokenClaims(token);
866868

867-
expect(result).toEqual(customClaims);
869+
expect(result).toMatchObject(tokenPayload);
868870
});
869871

870-
it('should return null when no accessToken is provided and no session exists', async () => {
871-
const result = await getCustomClaims();
872+
it('should return empty object when no accessToken is provided and no session exists', async () => {
873+
const result = await getTokenClaims();
872874

873-
expect(result).toBeNull();
875+
expect(result).toEqual({});
874876
});
875877

876-
it('should return empty object when token has no custom claims', async () => {
877-
const token = await generateTestToken({
878+
it('should return all standard claims when token has only standard claims', async () => {
879+
const tokenPayload = {
878880
sub: 'user_123',
879881
org_id: 'org_123',
880882
role: 'admin',
881883
permissions: ['read', 'write'],
882884
entitlements: ['feature_a'],
883-
});
885+
};
886+
const token = await generateTestToken(tokenPayload);
884887

885-
const result = await getCustomClaims(token);
888+
const result = await getTokenClaims(token);
886889

887-
expect(result).toEqual({});
890+
expect(result).toMatchObject(tokenPayload);
888891
});
889892

890-
it('should filter out all standard JWT claims', async () => {
891-
const customClaims = { customField: 'value', anotherCustom: 42 };
892-
const token = await generateTestToken({
893+
it('should return all claims including standard JWT claims', async () => {
894+
const customClaims = {
895+
customField: 'value',
896+
anotherCustom: 42,
897+
};
898+
const standardClaims = {
893899
aud: 'audience',
894-
exp: Math.floor(Date.now() / 1000) + 3600,
895-
iat: Math.floor(Date.now() / 1000),
896-
iss: 'issuer',
897900
sub: 'user_123',
898901
sid: 'session_123',
899902
org_id: 'org_123',
900903
role: 'admin',
901904
permissions: ['read', 'write'],
902905
entitlements: ['feature_a'],
903906
jti: 'jwt_123',
904-
nbf: Math.floor(Date.now() / 1000),
905-
...customClaims,
906-
});
907+
};
908+
const token = await generateTestToken({ ...standardClaims, ...customClaims });
907909

908-
const result = await getCustomClaims(token);
910+
const result = await getTokenClaims(token);
909911

910-
expect(result).toEqual(customClaims);
911-
expect(result).not.toHaveProperty('aud');
912-
expect(result).not.toHaveProperty('exp');
913-
expect(result).not.toHaveProperty('iat');
914-
expect(result).not.toHaveProperty('iss');
915-
expect(result).not.toHaveProperty('sub');
916-
expect(result).not.toHaveProperty('sid');
917-
expect(result).not.toHaveProperty('org_id');
918-
expect(result).not.toHaveProperty('role');
919-
expect(result).not.toHaveProperty('permissions');
920-
expect(result).not.toHaveProperty('entitlements');
921-
expect(result).not.toHaveProperty('jti');
922-
expect(result).not.toHaveProperty('nbf');
912+
expect(result).toMatchObject({ ...standardClaims, ...customClaims });
913+
expect(result).toHaveProperty('exp');
914+
expect(result).toHaveProperty('iat');
915+
expect(result).toHaveProperty('iss');
923916
});
924917

925-
it('should handle complex nested custom claims', async () => {
926-
const customClaims = {
918+
it('should handle complex nested claims', async () => {
919+
const tokenPayload = {
920+
sub: 'user_123',
921+
org_id: 'org_123',
927922
metadata: {
928923
preferences: { theme: 'dark', language: 'en' },
929924
settings: ['setting1', 'setting2'],
930925
},
931926
tags: ['tag1', 'tag2'],
932927
permissions_custom: { read: true, write: false },
933928
};
934-
const token = await generateTestToken({
935-
sub: 'user_123',
936-
org_id: 'org_123',
937-
...customClaims,
938-
});
929+
const token = await generateTestToken(tokenPayload);
939930

940-
const result = await getCustomClaims(token);
931+
const result = await getTokenClaims(token);
941932

942-
expect(result).toEqual(customClaims);
933+
expect(result).toMatchObject(tokenPayload);
943934
});
944935
});
945936
});
Lines changed: 50 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
import '@testing-library/jest-dom';
22
import { render, waitFor } from '@testing-library/react';
33
import React from 'react';
4-
import { getAccessTokenAction } from '../src/actions.js';
54
import { useAuth } from '../src/components/authkit-provider.js';
6-
import { useCustomClaims } from '../src/components/useCustomClaims.js';
75

86
jest.mock('../src/actions.js', () => ({
97
getAccessTokenAction: jest.fn(),
@@ -18,20 +16,31 @@ jest.mock('../src/components/authkit-provider.js', () => {
1816
};
1917
});
2018

19+
jest.mock('../src/components/useAccessToken.js', () => ({
20+
useAccessToken: jest.fn(() => ({ accessToken: undefined })),
21+
}));
22+
2123
jest.mock('jose', () => ({
2224
decodeJwt: jest.fn((token: string) => {
25+
if (token === 'malformed-token' || token === 'throw-error-token') {
26+
throw new Error('Invalid JWT');
27+
}
2328
try {
2429
const parts = token.split('.');
25-
if (parts.length !== 3) return null;
30+
if (parts.length !== 3) throw new Error('Invalid JWT');
2631
const payload = JSON.parse(atob(parts[1]));
2732
return payload;
2833
} catch {
29-
return null;
34+
throw new Error('Invalid JWT');
3035
}
3136
}),
3237
}));
3338

34-
describe('useCustomClaims', () => {
39+
// Import after mocks are set up
40+
import { useAccessToken } from '../src/components/useAccessToken.js';
41+
import { useTokenClaims } from '../src/components/useTokenClaims.js';
42+
43+
describe('useTokenClaims', () => {
3544
beforeEach(() => {
3645
jest.clearAllMocks();
3746
jest.useFakeTimers();
@@ -41,32 +50,35 @@ describe('useCustomClaims', () => {
4150
sessionId: 'session_123',
4251
refreshAuth: jest.fn().mockResolvedValue({}),
4352
}));
53+
54+
// Reset useAccessToken mock to default
55+
(useAccessToken as jest.Mock).mockReturnValue({ accessToken: undefined });
4456
});
4557

4658
afterEach(() => {
4759
jest.useRealTimers();
4860
});
4961

50-
const CustomClaimsTestComponent = () => {
51-
const customClaims = useCustomClaims();
62+
const TokenClaimsTestComponent = () => {
63+
const tokenClaims = useTokenClaims();
5264
return (
5365
<div>
54-
<div data-testid="claims">{JSON.stringify(customClaims)}</div>
66+
<div data-testid="claims">{JSON.stringify(tokenClaims)}</div>
5567
</div>
5668
);
5769
};
5870

59-
it('should return null when no access token is available', async () => {
60-
(getAccessTokenAction as jest.Mock).mockResolvedValue(undefined);
71+
it('should return empty object when no access token is available', async () => {
72+
(useAccessToken as jest.Mock).mockReturnValue({ accessToken: undefined });
6173

62-
const { getByTestId } = render(<CustomClaimsTestComponent />);
74+
const { getByTestId } = render(<TokenClaimsTestComponent />);
6375

6476
await waitFor(() => {
65-
expect(getByTestId('claims')).toHaveTextContent('null');
77+
expect(getByTestId('claims')).toHaveTextContent('{}');
6678
});
6779
});
6880

69-
it('should return custom claims when access token is available', async () => {
81+
it('should return all token claims when access token is available', async () => {
7082
const payload = {
7183
aud: 'audience',
7284
exp: 9999999999,
@@ -87,21 +99,16 @@ describe('useCustomClaims', () => {
8799
};
88100
const token = `eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.${btoa(JSON.stringify(payload))}.mock-signature`;
89101

90-
(getAccessTokenAction as jest.Mock).mockResolvedValue(token);
102+
(useAccessToken as jest.Mock).mockReturnValue({ accessToken: token });
91103

92-
const { getByTestId } = render(<CustomClaimsTestComponent />);
104+
const { getByTestId } = render(<TokenClaimsTestComponent />);
93105

94106
await waitFor(() => {
95-
const expectedCustomClaims = {
96-
customField1: 'value1',
97-
customField2: 42,
98-
customObject: { nested: 'data' },
99-
};
100-
expect(getByTestId('claims')).toHaveTextContent(JSON.stringify(expectedCustomClaims));
107+
expect(getByTestId('claims')).toHaveTextContent(JSON.stringify(payload));
101108
});
102109
});
103110

104-
it('should return empty object when token has no custom claims', async () => {
111+
it('should return all standard claims when token has only standard claims', async () => {
105112
const payload = {
106113
aud: 'audience',
107114
exp: 9999999999,
@@ -118,16 +125,16 @@ describe('useCustomClaims', () => {
118125
};
119126
const token = `eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.${btoa(JSON.stringify(payload))}.mock-signature`;
120127

121-
(getAccessTokenAction as jest.Mock).mockResolvedValue(token);
128+
(useAccessToken as jest.Mock).mockReturnValue({ accessToken: token });
122129

123-
const { getByTestId } = render(<CustomClaimsTestComponent />);
130+
const { getByTestId } = render(<TokenClaimsTestComponent />);
124131

125132
await waitFor(() => {
126-
expect(getByTestId('claims')).toHaveTextContent('{}');
133+
expect(getByTestId('claims')).toHaveTextContent(JSON.stringify(payload));
127134
});
128135
});
129136

130-
it('should handle partial standard claims', async () => {
137+
it('should handle partial claims', async () => {
131138
const payload = {
132139
sub: 'user_123',
133140
exp: 9999999999,
@@ -136,20 +143,16 @@ describe('useCustomClaims', () => {
136143
};
137144
const token = `eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.${btoa(JSON.stringify(payload))}.mock-signature`;
138145

139-
(getAccessTokenAction as jest.Mock).mockResolvedValue(token);
146+
(useAccessToken as jest.Mock).mockReturnValue({ accessToken: token });
140147

141-
const { getByTestId } = render(<CustomClaimsTestComponent />);
148+
const { getByTestId } = render(<TokenClaimsTestComponent />);
142149

143150
await waitFor(() => {
144-
const expectedCustomClaims = {
145-
customField: 'value',
146-
anotherCustom: true,
147-
};
148-
expect(getByTestId('claims')).toHaveTextContent(JSON.stringify(expectedCustomClaims));
151+
expect(getByTestId('claims')).toHaveTextContent(JSON.stringify(payload));
149152
});
150153
});
151154

152-
it('should handle complex nested custom claims', async () => {
155+
it('should handle complex nested claims', async () => {
153156
const payload = {
154157
sub: 'user_123',
155158
exp: 9999999999,
@@ -168,26 +171,22 @@ describe('useCustomClaims', () => {
168171
};
169172
const token = `eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.${btoa(JSON.stringify(payload))}.mock-signature`;
170173

171-
(getAccessTokenAction as jest.Mock).mockResolvedValue(token);
174+
(useAccessToken as jest.Mock).mockReturnValue({ accessToken: token });
172175

173-
const { getByTestId } = render(<CustomClaimsTestComponent />);
176+
const { getByTestId } = render(<TokenClaimsTestComponent />);
174177

175178
await waitFor(() => {
176-
const expectedCustomClaims = {
177-
metadata: {
178-
preferences: {
179-
theme: 'dark',
180-
language: 'en',
181-
},
182-
settings: ['setting1', 'setting2'],
183-
},
184-
tags: ['tag1', 'tag2'],
185-
permissions_custom: {
186-
read: true,
187-
write: false,
188-
},
189-
};
190-
expect(getByTestId('claims')).toHaveTextContent(JSON.stringify(expectedCustomClaims));
179+
expect(getByTestId('claims')).toHaveTextContent(JSON.stringify(payload));
180+
});
181+
});
182+
183+
it('should return empty object when decodeJwt throws an error', async () => {
184+
(useAccessToken as jest.Mock).mockReturnValue({ accessToken: 'malformed-token' });
185+
186+
const { getByTestId } = render(<TokenClaimsTestComponent />);
187+
188+
await waitFor(() => {
189+
expect(getByTestId('claims')).toHaveTextContent('{}');
191190
});
192191
});
193192
});

src/components/index.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import { Impersonation } from './impersonation.js';
22
import { AuthKitProvider, useAuth } from './authkit-provider.js';
33
import { useAccessToken } from './useAccessToken.js';
4-
import { useCustomClaims } from './useCustomClaims.js';
4+
import { useTokenClaims } from './useTokenClaims.js';
55

6-
export { Impersonation, AuthKitProvider, useAuth, useAccessToken, useCustomClaims };
6+
export { Impersonation, AuthKitProvider, useAuth, useAccessToken, useTokenClaims };

src/components/useCustomClaims.ts

Lines changed: 0 additions & 24 deletions
This file was deleted.

0 commit comments

Comments
 (0)