16
16
"""
17
17
18
18
import json
19
- from base64 import b64encode , b64decode , urlsafe_b64decode
20
- from datetime import datetime
19
+ import jwt
21
20
import random
22
- import string
23
- from jose import jwk
24
21
import requests
25
- from requests .sessions import Session
26
22
import six
27
- from requests_oauthlib import OAuth1
28
-
23
+ import string
24
+ from base64 import b64encode , b64decode , urlsafe_b64decode
25
+ from datetime import datetime
26
+ from requests .sessions import Session
29
27
28
+ from intuitlib .config import DISCOVERY_URL , ACCEPT_HEADER
30
29
from intuitlib .enums import Scopes
31
30
from intuitlib .exceptions import AuthClientError
32
- from intuitlib . config import DISCOVERY_URL , ACCEPT_HEADER
31
+
33
32
34
33
def get_discovery_doc (environment , session = None ):
35
34
"""Gets discovery doc based on environment specified.
@@ -153,7 +152,6 @@ def validate_id_token(id_token, client_id, intuit_issuer, jwk_uri):
153
152
154
153
id_token_header = json .loads (b64decode (_correct_padding (id_token_parts [0 ])).decode ('ascii' ))
155
154
id_token_payload = json .loads (b64decode (_correct_padding (id_token_parts [1 ])).decode ('ascii' ))
156
- id_token_signature = urlsafe_b64decode (((_correct_padding (id_token_parts [2 ])).encode ('ascii' )))
157
155
158
156
if id_token_payload ['iss' ] != intuit_issuer :
159
157
return False
@@ -164,12 +162,12 @@ def validate_id_token(id_token, client_id, intuit_issuer, jwk_uri):
164
162
if id_token_payload ['exp' ] < current_time :
165
163
return False
166
164
167
- message = id_token_parts [ 0 ] + '.' + id_token_parts [ 1 ]
168
- keys_dict = get_jwk ( id_token_header [ 'kid' ], jwk_uri )
169
-
170
- public_key = jwk . construct ( keys_dict )
171
- is_signature_valid = public_key . verify ( message . encode ( 'utf-8' ), id_token_signature )
172
- return is_signature_valid
165
+ public_key = get_jwk ( id_token_header [ 'kid' ], jwk_uri ). key
166
+ try :
167
+ jwt . decode ( id_token , public_key , audience = client_id , algorithms = [ 'RS256' ])
168
+ return True
169
+ except jwt . PyJWTError :
170
+ return False
173
171
174
172
def get_jwk (kid , jwk_uri ):
175
173
"""Get JWK for public key information
@@ -178,15 +176,14 @@ def get_jwk(kid, jwk_uri):
178
176
:param jwk_uri: JWK URI
179
177
180
178
:raises HTTPError: if response status != 200
181
- :return: dict containing keys
179
+ :return: Algorithm with the key loaded.
182
180
"""
183
181
184
182
response = requests .get (jwk_uri )
185
183
if response .status_code != 200 :
186
184
raise AuthClientError (response )
187
185
data = response .json ()
188
- keys = next (key for key in data ["keys" ] if key ['kid' ] == kid )
189
- return keys
186
+ return jwt .PyJWKSet .from_dict (data )[kid ]
190
187
191
188
def _correct_padding (val ):
192
189
"""Correct padding for JWT
0 commit comments