Skip to content

Support custom algorithms with jwks #7

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: support-loading-jwk-from-json
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 130 additions & 90 deletions include/jwt-cpp/jwt.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <functional>
#include <iterator>
#include <locale>
#include <map>
#include <memory>
#include <set>
#include <system_error>
Expand Down Expand Up @@ -3418,18 +3419,16 @@ namespace jwt {

bool empty() const noexcept { return jwk_claims.empty(); }

helper::evp_pkey_handle get_pkey() const { return k.get_asymmetric_key(); }

std::string get_oct_key() const { return k.get_symmetric_key(); }
key get_key() const { return k; }

private:
template<typename Decode>
static helper::evp_pkey_handle build_rsa_key(const details::map_of_claims<json_traits>& claims,
Decode&& decode) {
EVP_PKEY* evp_key = nullptr;
auto n = jwt::helper::raw2bn(decode(claims.get_claim("n").as_string()));
auto e = jwt::helper::raw2bn(decode(claims.get_claim("e").as_string()));
#ifdef JWT_OPENSSL_3_0
EVP_PKEY* evp_key = nullptr;
// https://www.openssl.org/docs/manmaster/man7/EVP_PKEY-RSA.html
// see https://www.openssl.org/docs/man3.0/man3/EVP_PKEY_fromdata.html
// and https://stackoverflow.com/questions/68465716/how-to-properly-create-an-rsa-key-from-raw-data-in-openssl-3-0-in-c-language
Expand All @@ -3439,25 +3438,39 @@ namespace jwt {

std::unique_ptr<OSSL_PARAM_BLD, decltype(&OSSL_PARAM_BLD_free)> params_build(OSSL_PARAM_BLD_new(),
OSSL_PARAM_BLD_free);
OSSL_PARAM_BLD_push_BN(params_build.get(), "n", n.get());
OSSL_PARAM_BLD_push_BN(params_build.get(), "e", e.get());
if (!params_build) { throw std::runtime_error("OSSL_PARAM_BLD_new failed"); }
if (OSSL_PARAM_BLD_push_BN(params_build.get(), "n", n.get()) != 1) {
throw std::runtime_error("OSSL_PARAM_BLD_push_BN failed");
}
if (OSSL_PARAM_BLD_push_BN(params_build.get(), "e", e.get()) != 1) {
throw std::runtime_error("OSSL_PARAM_BLD_push_BN failed");
}

std::unique_ptr<OSSL_PARAM, decltype(&OSSL_PARAM_free)> params(OSSL_PARAM_BLD_to_param(params_build.get()),
OSSL_PARAM_free);
EVP_PKEY_fromdata_init(ctx.get());
EVP_PKEY_fromdata(ctx.get(), &evp_key, EVP_PKEY_PUBLIC_KEY, params.get());
if (!params) { throw std::runtime_error("OSSL_PARAM_BLD_to_param failed"); }
if (EVP_PKEY_fromdata_init(ctx.get()) != 1) { throw std::runtime_error("EVP_PKEY_fromdata_init failed"); }
if (EVP_PKEY_fromdata(ctx.get(), &evp_key, EVP_PKEY_PUBLIC_KEY, params.get()) != 1) {
throw std::runtime_error("EVP_PKEY_fromdata failed");
}
return helper::evp_pkey_handle(evp_key);
#else
RSA* rsa = RSA_new();
evp_key = EVP_PKEY_new();
std::unique_ptr<RSA, decltype(&RSA_free)> rsa(RSA_new(), RSA_free);
if (!rsa) { throw std::runtime_error("RSA_new failed"); }
#if defined(JWT_OPENSSL_1_0_0) && !defined(LIBWOLFSSL_VERSION_HEX)
rsa->e = e.release();
rsa->n = n.release();
#else
RSA_set0_key(rsa, n.release(), e.release(), nullptr);
if (RSA_set0_key(rsa.get(), n.release(), e.release(), nullptr) != 1) {
throw std::runtimeruntime_error("RSA_set0_key failed");
}
#endif
EVP_PKEY_assign_RSA(evp_key, rsa);
return helper::evp_pkey_handle(evp_key);
std::unique_ptr<EVP_PKEY, decltype(&EVP_PKEY_free)> evp_key(EVP_PKEY_new(), EVP_PKEY_free);
if (EVP_PKEY_assign_RSA(evp_key.get(), rsa.get()) != 1) {
throw std::runtime_error("EVP_PKEY_assign_RSA failed");
}
rsa.release();
return helper::evp_pkey_handle(evp_key.release());
#endif
}

Expand Down Expand Up @@ -3489,6 +3502,97 @@ namespace jwt {
key k;
};

struct algo_base {
virtual ~algo_base() = default;
virtual void verify(const std::string& data, const std::string& sig, std::error_code& ec) = 0;
};
template<typename T>
struct algo : public algo_base {
T alg;
explicit algo(T a) : alg(a) {}
void verify(const std::string& data, const std::string& sig, std::error_code& ec) override {
alg.verify(data, sig, ec);
}
};

struct algorithm_db {
using builder_fn = std::function<std::unique_ptr<algo_base>(const key&)>;
using algname_to_builder_fn = std::map<std::string, builder_fn>;
enum type { empty, basic };

algorithm_db() : algorithm_db(empty) {}
algorithm_db(type t) {
if (t == empty) { supported_algorithms.clear(); }
}

builder_fn create_algorithm(const std::string& name) const {
const auto algorithm = supported_algorithms.find(name);
if (algorithm != supported_algorithms.end()) { return algorithm->second; }
return nullptr;
}

void register_algorithm(const std::string& alg_name, builder_fn build_fn) {
supported_algorithms.insert_or_assign(alg_name, build_fn);
}

private:
algname_to_builder_fn supported_algorithms = {
{"RS256",
[](const key& key) {
return std::make_unique<algo<jwt::algorithm::rs256>>(jwt::algorithm::rs256(key.get_asymmetric_key()));
}},
{"RS384",
[](const key& key) {
return std::make_unique<algo<jwt::algorithm::rs384>>(jwt::algorithm::rs384(key.get_asymmetric_key()));
}},
{"RS512",
[](const key& key) {
return std::make_unique<algo<jwt::algorithm::rs512>>(jwt::algorithm::rs512(key.get_asymmetric_key()));
}},
{"PS256",
[](const key& key) {
return std::make_unique<algo<jwt::algorithm::ps256>>(jwt::algorithm::ps256(key.get_asymmetric_key()));
}},
{"PS384",
[](const key& key) {
return std::make_unique<algo<jwt::algorithm::ps384>>(jwt::algorithm::ps384(key.get_asymmetric_key()));
}},
{"PS512",
[](const key& key) {
return std::make_unique<algo<jwt::algorithm::ps512>>(jwt::algorithm::ps512(key.get_asymmetric_key()));
}},
{"ES256",
[](const key& key) {
return std::make_unique<algo<jwt::algorithm::es256>>(jwt::algorithm::es256(key.get_asymmetric_key()));
}},
{"ES384",
[](const key& key) {
return std::make_unique<algo<jwt::algorithm::es384>>(jwt::algorithm::es384(key.get_asymmetric_key()));
}},
{"ES512",
[](const key& key) {
return std::make_unique<algo<jwt::algorithm::es512>>(jwt::algorithm::es512(key.get_asymmetric_key()));
}},
{"ES256K",
[](const key& key) {
return std::make_unique<algo<jwt::algorithm::es256k>>(
jwt::algorithm::es256k(key.get_asymmetric_key()));
}},
{"HS256",
[](const key& key) {
return std::make_unique<algo<jwt::algorithm::hs256>>(jwt::algorithm::hs256(key.get_symmetric_key()));
}},
{"HS384",
[](const key& key) {
return std::make_unique<algo<jwt::algorithm::hs384>>(jwt::algorithm::hs384(key.get_symmetric_key()));
}},
{"HS512",
[](const key& key) {
return std::make_unique<algo<jwt::algorithm::hs512>>(jwt::algorithm::hs512(key.get_symmetric_key()));
}},
};
};

/**
* Verifier class used to check if a decoded token contains all claims required by your application and has a valid
* signature.
Expand All @@ -3510,32 +3614,15 @@ namespace jwt {
std::function<void(const verify_ops::verify_context<json_traits>&, std::error_code& ec)>;

private:
struct algo_base {
virtual ~algo_base() = default;
virtual void verify(const std::string& data, const std::string& sig, std::error_code& ec) = 0;
};
template<typename T>
struct algo : public algo_base {
T alg;
explicit algo(T a) : alg(a) {}
void verify(const std::string& data, const std::string& sig, std::error_code& ec) override {
alg.verify(data, sig, ec);
}
};
/// Required claims
std::unordered_map<typename json_traits::string_type, verify_check_fn_t> claims;
/// Leeway time for exp, nbf and iat
size_t default_leeway = 0;
/// Instance of clock type
Clock clock;
algorithm_db supported_algorithms;
/// Supported algorithms
std::unordered_map<std::string, std::shared_ptr<algo_base>> algs;
using alg_name = std::string;
using alg_list = std::vector<alg_name>;
using algorithms = std::unordered_map<std::string, alg_list>;
algorithms supported_alg = {{"RSA", {"RS256", "RS384", "RS512", "PS256", "PS384", "PS512"}},
{"EC", {"ES256", "ES384", "ES512", "ES256K"}},
{"oct", {"HS256", "HS384", "HS512"}}};

typedef std::vector<jwt::jwk<json_traits>> key_list;
/// https://datatracker.ietf.org/doc/html/rfc7517#section-4.5 - kid to keys
Expand All @@ -3551,78 +3638,31 @@ namespace jwt {
}
}

bool is_valid_combination(const jwt::jwk<json_traits>& key, const std::string& alg_name) const {
const alg_list& x = supported_alg.find(key.get_key_type())->second;
return std::find(x.cbegin(), x.cend(), alg_name) != x.cend();
bool is_valid_combination(const std::string& key_type, const std::string& alg_name) const {
// TODO:mk check whether key type can be used with the algorithm
return true;
}

inline std::unique_ptr<algo_base> from_key_and_alg(const jwt::jwk<json_traits>& key,
const std::string& alg_name, std::error_code& ec) const {
std::unique_ptr<algo_base> from_key_and_alg(const jwt::jwk<json_traits>& key, const std::string& alg_name,
std::error_code& ec) const {
ec.clear();
algorithms::const_iterator it = supported_alg.find(key.get_key_type());
if (it == supported_alg.end()) {
ec = error::token_verification_error::wrong_algorithm;
return nullptr;
}

const alg_list& supported_jwt_algorithms = it->second;
if (std::find(supported_jwt_algorithms.begin(), supported_jwt_algorithms.end(), alg_name) ==
supported_jwt_algorithms.end()) {
auto create = supported_algorithms.create_algorithm(alg_name);
if (create == nullptr) {
ec = error::token_verification_error::wrong_algorithm;
return nullptr;
}

if (alg_name == "RS256") {
return std::unique_ptr<algo<jwt::algorithm::rs256>>(
new algo<jwt::algorithm::rs256>(jwt::algorithm::rs256(key.get_pkey())));
} else if (alg_name == "RS384") {
return std::unique_ptr<algo<jwt::algorithm::rs384>>(
new algo<jwt::algorithm::rs384>(jwt::algorithm::rs384(key.get_pkey())));
} else if (alg_name == "RS512") {
return std::unique_ptr<algo<jwt::algorithm::rs512>>(
new algo<jwt::algorithm::rs512>(jwt::algorithm::rs512(key.get_pkey())));
} else if (alg_name == "PS256") {
return std::unique_ptr<algo<jwt::algorithm::ps256>>(
new algo<jwt::algorithm::ps256>(jwt::algorithm::ps256(key.get_pkey())));
} else if (alg_name == "PS384") {
return std::unique_ptr<algo<jwt::algorithm::ps384>>(
new algo<jwt::algorithm::ps384>(jwt::algorithm::ps384(key.get_pkey())));
} else if (alg_name == "PS512") {
return std::unique_ptr<algo<jwt::algorithm::ps512>>(
new algo<jwt::algorithm::ps512>(jwt::algorithm::ps512(key.get_pkey())));
} else if (alg_name == "ES256") {
return std::unique_ptr<algo<jwt::algorithm::es256>>(
new algo<jwt::algorithm::es256>(jwt::algorithm::es256(key.get_pkey())));
} else if (alg_name == "ES384") {
return std::unique_ptr<algo<jwt::algorithm::es384>>(
new algo<jwt::algorithm::es384>(jwt::algorithm::es384(key.get_pkey())));
} else if (alg_name == "ES512") {
return std::unique_ptr<algo<jwt::algorithm::es512>>(
new algo<jwt::algorithm::es512>(jwt::algorithm::es512(key.get_pkey())));
} else if (alg_name == "ES256K") {
return std::unique_ptr<algo<jwt::algorithm::es256k>>(
new algo<jwt::algorithm::es256k>(jwt::algorithm::es256k(key.get_pkey())));
} else if (alg_name == "HS256") {
return std::unique_ptr<algo<jwt::algorithm::hs256>>(
new algo<jwt::algorithm::hs256>(jwt::algorithm::hs256(key.get_oct_key())));
} else if (alg_name == "HS384") {
return std::unique_ptr<algo<jwt::algorithm::hs384>>(
new algo<jwt::algorithm::hs384>(jwt::algorithm::hs384(key.get_oct_key())));
} else if (alg_name == "HS512") {
return std::unique_ptr<algo<jwt::algorithm::hs512>>(
new algo<jwt::algorithm::hs512>(jwt::algorithm::hs512(key.get_oct_key())));
}

ec = error::token_verification_error::wrong_algorithm;
return nullptr;
return create(key.get_key());
}

public:
/**
* Constructor for building a new verifier instance
* \param c Clock instance
*/
explicit verifier(Clock c) : clock(c) {
explicit verifier(Clock c) : verifier(c, algorithm_db(algorithm_db::basic)) {}

verifier(Clock c, algorithm_db algorithms) : clock(c), supported_algorithms(algorithms) {
claims["exp"] = [](const verify_ops::verify_context<json_traits>& ctx, std::error_code& ec) {
if (!ctx.jwt.has_expires_at()) return;
auto exp = ctx.jwt.get_expires_at();
Expand Down Expand Up @@ -3821,7 +3861,7 @@ namespace jwt {
if (key_set_it != keys.end()) {
const key_list& keys = key_set_it->second;
for (const auto& key : keys) {
if (is_valid_combination(key, algo)) {
if (is_valid_combination(key.get_key_type(), algo)) {
key_found = true;
auto alg = from_key_and_alg(key, algo, ec);
alg->verify(data, sig, ec);
Expand Down
25 changes: 25 additions & 0 deletions tests/JwkTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,28 @@ TEST(JwkTest, HmacKey) {
auto decoded_token = jwt::decode(token);
ASSERT_NO_THROW(verifier.verify(decoded_token));
}

TEST(JwkTest, CustomAlgorithm) {
// {"alg":"my-custom-alg","typ":"JWS"}.{"iss":"auth0"}.valid_signature
std::string token = "eyJhbGciOiJteS1jdXN0b20tYWxnIiwidHlwIjoiSldTIn0.eyJpc3MiOiJhdXRoMCJ9.dmFsaWRfc2lnbmF0dXJl";
std::string secret_key = R"({
"kty": "oct",
"k": "c2VjcmV0"
})";

struct custom_verification_algorithm {
void verify(const std::string& data, const std::string& sig, std::error_code& ec) {}
};

jwt::algorithm_db my_verification_algorithms;
my_verification_algorithms.register_algorithm("my-custom-alg", [](const jwt::key&) {
return std::make_unique<jwt::algo<custom_verification_algorithm>>(custom_verification_algorithm());
});
auto verifier = jwt::verifier<jwt::default_clock, jwt::traits::kazuho_picojson>(jwt::default_clock(),
my_verification_algorithms);

auto jwk = jwt::parse_jwk(secret_key);
verifier.allow_key(jwk);
auto decoded_token = jwt::decode(token);
ASSERT_NO_THROW(verifier.verify(decoded_token));
}