Skip to content

Commit fe3a2bb

Browse files
committed
make algorithm builders non-global
- algorithm_db is kind of a map of algorithm name to algorithm builder function. It allows users to register custom algorithms.
1 parent e0b5b0d commit fe3a2bb

File tree

2 files changed

+102
-100
lines changed

2 files changed

+102
-100
lines changed

include/jwt-cpp/jwt.h

Lines changed: 100 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -3427,73 +3427,94 @@ namespace jwt {
34273427
}
34283428
};
34293429

3430-
using algorithm_builder = std::map<std::string, std::function<std::unique_ptr<algo_base>(const key&)>>;
3431-
static const algorithm_builder default_verification_algorithms = {
3432-
{"RS256",
3433-
[](const key& key) {
3434-
return std::unique_ptr<algo<jwt::algorithm::rs256>>(
3435-
new algo<jwt::algorithm::rs256>(jwt::algorithm::rs256(key.get_asymmetric_key())));
3436-
}},
3437-
{"RS384",
3438-
[](const key& key) {
3439-
return std::unique_ptr<algo<jwt::algorithm::rs384>>(
3440-
new algo<jwt::algorithm::rs384>(jwt::algorithm::rs384(key.get_asymmetric_key())));
3441-
}},
3442-
{"RS512",
3443-
[](const key& key) {
3444-
return std::unique_ptr<algo<jwt::algorithm::rs512>>(
3445-
new algo<jwt::algorithm::rs512>(jwt::algorithm::rs512(key.get_asymmetric_key())));
3446-
}},
3447-
{"PS256",
3448-
[](const key& key) {
3449-
return std::unique_ptr<algo<jwt::algorithm::ps256>>(
3450-
new algo<jwt::algorithm::ps256>(jwt::algorithm::ps256(key.get_asymmetric_key())));
3451-
}},
3452-
{"PS384",
3453-
[](const key& key) {
3454-
return std::unique_ptr<algo<jwt::algorithm::ps384>>(
3455-
new algo<jwt::algorithm::ps384>(jwt::algorithm::ps384(key.get_asymmetric_key())));
3456-
}},
3457-
{"PS512",
3458-
[](const key& key) {
3459-
return std::unique_ptr<algo<jwt::algorithm::ps512>>(
3460-
new algo<jwt::algorithm::ps512>(jwt::algorithm::ps512(key.get_asymmetric_key())));
3461-
}},
3462-
{"ES256",
3463-
[](const key& key) {
3464-
return std::unique_ptr<algo<jwt::algorithm::es256>>(
3465-
new algo<jwt::algorithm::es256>(jwt::algorithm::es256(key.get_asymmetric_key())));
3466-
}},
3467-
{"ES384",
3468-
[](const key& key) {
3469-
return std::unique_ptr<algo<jwt::algorithm::es384>>(
3470-
new algo<jwt::algorithm::es384>(jwt::algorithm::es384(key.get_asymmetric_key())));
3471-
}},
3472-
{"ES512",
3473-
[](const key& key) {
3474-
return std::unique_ptr<algo<jwt::algorithm::es512>>(
3475-
new algo<jwt::algorithm::es512>(jwt::algorithm::es512(key.get_asymmetric_key())));
3476-
}},
3477-
{"ES256K",
3478-
[](const key& key) {
3479-
return std::unique_ptr<algo<jwt::algorithm::es256k>>(
3480-
new algo<jwt::algorithm::es256k>(jwt::algorithm::es256k(key.get_asymmetric_key())));
3481-
}},
3482-
{"HS256",
3483-
[](const key& key) {
3484-
return std::unique_ptr<algo<jwt::algorithm::hs256>>(
3485-
new algo<jwt::algorithm::hs256>(jwt::algorithm::hs256(key.get_symmetric_key())));
3486-
}},
3487-
{"HS384",
3488-
[](const key& key) {
3489-
return std::unique_ptr<algo<jwt::algorithm::hs384>>(
3490-
new algo<jwt::algorithm::hs384>(jwt::algorithm::hs384(key.get_symmetric_key())));
3491-
}},
3492-
{"HS512",
3493-
[](const key& key) {
3494-
return std::unique_ptr<algo<jwt::algorithm::hs512>>(
3495-
new algo<jwt::algorithm::hs512>(jwt::algorithm::hs512(key.get_symmetric_key())));
3496-
}},
3430+
struct algorithm_db {
3431+
using builder_fn = std::function<std::unique_ptr<algo_base>(const key&)>;
3432+
using algname_to_builder_fn = std::map<std::string, builder_fn>;
3433+
enum type { empty, basic };
3434+
3435+
algorithm_db() : algorithm_db(empty) {}
3436+
algorithm_db(type t) {
3437+
if (t == empty) { supported_algorithms = {}; }
3438+
}
3439+
3440+
builder_fn create_algorithm(const std::string& name) const {
3441+
const auto algorithm = supported_algorithms.find(name);
3442+
if (algorithm != supported_algorithms.end()) { return algorithm->second; }
3443+
return nullptr;
3444+
}
3445+
3446+
void register_algorithm(const std::string& alg_name, builder_fn build_fn) {
3447+
supported_algorithms.insert_or_assign(alg_name, build_fn);
3448+
}
3449+
3450+
private:
3451+
algname_to_builder_fn supported_algorithms = {
3452+
{"RS256",
3453+
[](const key& key) {
3454+
return std::unique_ptr<algo<jwt::algorithm::rs256>>(
3455+
new algo<jwt::algorithm::rs256>(jwt::algorithm::rs256(key.get_asymmetric_key())));
3456+
}},
3457+
{"RS384",
3458+
[](const key& key) {
3459+
return std::unique_ptr<algo<jwt::algorithm::rs384>>(
3460+
new algo<jwt::algorithm::rs384>(jwt::algorithm::rs384(key.get_asymmetric_key())));
3461+
}},
3462+
{"RS512",
3463+
[](const key& key) {
3464+
return std::unique_ptr<algo<jwt::algorithm::rs512>>(
3465+
new algo<jwt::algorithm::rs512>(jwt::algorithm::rs512(key.get_asymmetric_key())));
3466+
}},
3467+
{"PS256",
3468+
[](const key& key) {
3469+
return std::unique_ptr<algo<jwt::algorithm::ps256>>(
3470+
new algo<jwt::algorithm::ps256>(jwt::algorithm::ps256(key.get_asymmetric_key())));
3471+
}},
3472+
{"PS384",
3473+
[](const key& key) {
3474+
return std::unique_ptr<algo<jwt::algorithm::ps384>>(
3475+
new algo<jwt::algorithm::ps384>(jwt::algorithm::ps384(key.get_asymmetric_key())));
3476+
}},
3477+
{"PS512",
3478+
[](const key& key) {
3479+
return std::unique_ptr<algo<jwt::algorithm::ps512>>(
3480+
new algo<jwt::algorithm::ps512>(jwt::algorithm::ps512(key.get_asymmetric_key())));
3481+
}},
3482+
{"ES256",
3483+
[](const key& key) {
3484+
return std::unique_ptr<algo<jwt::algorithm::es256>>(
3485+
new algo<jwt::algorithm::es256>(jwt::algorithm::es256(key.get_asymmetric_key())));
3486+
}},
3487+
{"ES384",
3488+
[](const key& key) {
3489+
return std::unique_ptr<algo<jwt::algorithm::es384>>(
3490+
new algo<jwt::algorithm::es384>(jwt::algorithm::es384(key.get_asymmetric_key())));
3491+
}},
3492+
{"ES512",
3493+
[](const key& key) {
3494+
return std::unique_ptr<algo<jwt::algorithm::es512>>(
3495+
new algo<jwt::algorithm::es512>(jwt::algorithm::es512(key.get_asymmetric_key())));
3496+
}},
3497+
{"ES256K",
3498+
[](const key& key) {
3499+
return std::unique_ptr<algo<jwt::algorithm::es256k>>(
3500+
new algo<jwt::algorithm::es256k>(jwt::algorithm::es256k(key.get_asymmetric_key())));
3501+
}},
3502+
{"HS256",
3503+
[](const key& key) {
3504+
return std::unique_ptr<algo<jwt::algorithm::hs256>>(
3505+
new algo<jwt::algorithm::hs256>(jwt::algorithm::hs256(key.get_symmetric_key())));
3506+
}},
3507+
{"HS384",
3508+
[](const key& key) {
3509+
return std::unique_ptr<algo<jwt::algorithm::hs384>>(
3510+
new algo<jwt::algorithm::hs384>(jwt::algorithm::hs384(key.get_symmetric_key())));
3511+
}},
3512+
{"HS512",
3513+
[](const key& key) {
3514+
return std::unique_ptr<algo<jwt::algorithm::hs512>>(
3515+
new algo<jwt::algorithm::hs512>(jwt::algorithm::hs512(key.get_symmetric_key())));
3516+
}},
3517+
};
34973518
};
34983519

34993520
/**
@@ -3523,15 +3544,9 @@ namespace jwt {
35233544
size_t default_leeway = 0;
35243545
/// Instance of clock type
35253546
Clock clock;
3526-
algorithm_builder supported_algorithms;
3547+
algorithm_db supported_algorithms;
35273548
/// Supported algorithms
35283549
std::unordered_map<std::string, std::shared_ptr<algo_base>> algs;
3529-
using alg_name = std::string;
3530-
using alg_list = std::vector<alg_name>;
3531-
using algorithms = std::unordered_map<std::string, alg_list>;
3532-
algorithms supported_alg = {{"RSA", {"RS256", "RS384", "RS512", "PS256", "PS384", "PS512"}},
3533-
{"EC", {"ES256", "ES384", "ES512", "ES256K"}},
3534-
{"oct", {"HS256", "HS384", "HS512"}}};
35353550

35363551
typedef std::vector<jwt::jwk<json_traits>> key_list;
35373552
/// https://datatracker.ietf.org/doc/html/rfc7517#section-4.5 - kid to keys
@@ -3547,44 +3562,31 @@ namespace jwt {
35473562
}
35483563
}
35493564

3550-
bool is_valid_combination(const jwt::jwk<json_traits>& key, const std::string& alg_name) const {
3551-
const alg_list& x = supported_alg.find(key.get_key_type())->second;
3552-
return std::find(x.cbegin(), x.cend(), alg_name) != x.cend();
3565+
bool is_valid_combination(const std::string& key_type, const std::string& alg_name) const {
3566+
// TODO:mk check whether key type can be used with the algorithm
3567+
return true;
35533568
}
35543569

3555-
inline std::unique_ptr<algo_base> from_key_and_alg(const jwt::jwk<json_traits>& key,
3556-
const std::string& alg_name, std::error_code& ec) const {
3570+
std::unique_ptr<algo_base> from_key_and_alg(const jwt::jwk<json_traits>& key, const std::string& alg_name,
3571+
std::error_code& ec) const {
35573572
ec.clear();
3558-
algorithms::const_iterator it = supported_alg.find(key.get_key_type());
3559-
if (it == supported_alg.end()) {
3560-
ec = error::token_verification_error::wrong_algorithm;
3561-
return nullptr;
3562-
}
3563-
3564-
const alg_list& supported_jwt_algorithms = it->second;
3565-
if (std::find(supported_jwt_algorithms.begin(), supported_jwt_algorithms.end(), alg_name) ==
3566-
supported_jwt_algorithms.end()) {
3567-
ec = error::token_verification_error::wrong_algorithm;
3568-
return nullptr;
3569-
}
3570-
3571-
algorithm_builder::const_iterator alg = supported_algorithms.find(alg_name);
3572-
if (alg == supported_algorithms.end()) {
3573+
auto create = supported_algorithms.create_algorithm(alg_name);
3574+
if (create == nullptr) {
35733575
ec = error::token_verification_error::wrong_algorithm;
35743576
return nullptr;
35753577
}
35763578

3577-
return alg->second.operator()(key.get_key());
3579+
return create(key.get_key());
35783580
}
35793581

35803582
public:
35813583
/**
35823584
* Constructor for building a new verifier instance
35833585
* \param c Clock instance
35843586
*/
3585-
explicit verifier(Clock c) : verifier(c, default_verification_algorithms) {}
3587+
explicit verifier(Clock c) : verifier(c, algorithm_db(algorithm_db::basic)) {}
35863588

3587-
verifier(Clock c, algorithm_builder algorithms) : clock(c), supported_algorithms(algorithms) {
3589+
verifier(Clock c, algorithm_db algorithms) : clock(c), supported_algorithms(algorithms) {
35883590
claims["exp"] = [](const verify_ops::verify_context<json_traits>& ctx, std::error_code& ec) {
35893591
if (!ctx.jwt.has_expires_at()) return;
35903592
auto exp = ctx.jwt.get_expires_at();
@@ -3778,7 +3780,7 @@ namespace jwt {
37783780
if (key_set_it != keys.end()) {
37793781
const key_list& keys = key_set_it->second;
37803782
for (const auto& key : keys) {
3781-
if (is_valid_combination(key, algo)) {
3783+
if (is_valid_combination(key.get_key_type(), algo)) {
37823784
key_found = true;
37833785
auto alg = from_key_and_alg(key, algo, ec);
37843786
alg->verify(data, sig, ec);

tests/JwkTest.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,8 @@ TEST(JwkTest, CustomAlgorithm) {
8989
void verify(const std::string& data, const std::string& sig, std::error_code& ec) {}
9090
};
9191

92-
auto my_verification_algorithms = jwt::default_verification_algorithms;
93-
my_verification_algorithms.insert_or_assign(std::string("my-custom-alg"), [](const jwt::key&) {
92+
jwt::algorithm_db my_verification_algorithms;
93+
my_verification_algorithms.register_algorithm("my-custom-alg", [](const jwt::key&) {
9494
return std::unique_ptr<jwt::algo<custom_verification_algorithm>>(
9595
new jwt::algo<custom_verification_algorithm>(custom_verification_algorithm()));
9696
});

0 commit comments

Comments
 (0)