diff --git a/include/jwt-cpp/base.h b/include/jwt-cpp/base.h index 4177ed68..cef493d1 100644 --- a/include/jwt-cpp/base.h +++ b/include/jwt-cpp/base.h @@ -1,9 +1,11 @@ #ifndef JWT_CPP_BASE_H #define JWT_CPP_BASE_H +#include #include #include #include +#include #ifdef __has_cpp_attribute #if __has_cpp_attribute(fallthrough) @@ -21,7 +23,10 @@ namespace jwt { */ namespace alphabet { /** - * \brief valid list of characted when working with [Base64](https://tools.ietf.org/html/rfc3548) + * \brief valid list of character when working with [Base64](https://datatracker.ietf.org/doc/html/rfc4648#section-4) + * + * As directed in [X.509 Parameter](https://datatracker.ietf.org/doc/html/rfc7517#section-4.7) certificate chains are + * base64-encoded as per [Section 4 of RFC4648](https://datatracker.ietf.org/doc/html/rfc4648#section-4) */ struct base64 { static const std::array& data() { @@ -38,7 +43,13 @@ namespace jwt { } }; /** - * \brief valid list of characted when working with [Base64URL](https://tools.ietf.org/html/rfc4648) + * \brief valid list of character when working with [Base64URL](https://tools.ietf.org/html/rfc4648#section-5) + * + * As directed by [RFC 7519 Terminology](https://datatracker.ietf.org/doc/html/rfc7519#section-2) set the definition of Base64URL + * encoding as that in [RFC 7515](https://datatracker.ietf.org/doc/html/rfc7515#section-2) that states: + * + * > Base64 encoding using the URL- and filename-safe character set defined in + * > [Section 5 of RFC 4648 RFC4648](https://tools.ietf.org/html/rfc4648#section-5), with all trailing '=' characters omitted */ struct base64url { static const std::array& data() { @@ -54,155 +65,205 @@ namespace jwt { return fill; } }; + namespace helper { + /** + * @brief A General purpose base64url alphabet respecting the + * [URI Case Normalization](https://datatracker.ietf.org/doc/html/rfc3986#section-6.2.2.1) + * + * This is useful in situations outside of JWT encoding/decoding and is provided as a helper + */ + struct base64url_percent_encoding { + static const std::array& data() { + static constexpr std::array data{ + {'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', + 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', + 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', + 'w', 'x', 'y', 'z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-', '_'}}; + return data; + } + static const std::initializer_list& fill() { + static std::initializer_list fill{"%3D", "%3d"}; + return fill; + } + }; + } // namespace helper + + inline uint32_t index(const std::array& alphabet, char symbol) { + auto itr = std::find_if(alphabet.cbegin(), alphabet.cend(), [symbol](char c) { return c == symbol; }); + if (itr == alphabet.cend()) { throw std::runtime_error("Invalid input: not within alphabet"); } + + return std::distance(alphabet.cbegin(), itr); + } } // namespace alphabet /** - * \brief Alphabet generic methods for working with encoding/decoding the base64 family + * \brief A collection of fellable functions for working with base64 and base64url */ - class base { - public: - template - static std::string encode(const std::string& bin) { - return encode(bin, T::data(), T::fill()); - } - template - static std::string decode(const std::string& base) { - return decode(base, T::data(), T::fill()); - } - template - static std::string pad(const std::string& base) { - return pad(base, T::fill()); - } - template - static std::string trim(const std::string& base) { - return trim(base, T::fill()); - } + namespace base { - private: - static std::string encode(const std::string& bin, const std::array& alphabet, - const std::string& fill) { - size_t size = bin.size(); - std::string res; + namespace details { + struct padding { + size_t count = 0; + size_t length = 0; - // clear incomplete bytes - size_t fast_size = size - size % 3; - for (size_t i = 0; i < fast_size;) { - uint32_t octet_a = static_cast(bin[i++]); - uint32_t octet_b = static_cast(bin[i++]); - uint32_t octet_c = static_cast(bin[i++]); + padding() = default; + padding(size_t count, size_t length) : count(count), length(length) {} - uint32_t triple = (octet_a << 0x10) + (octet_b << 0x08) + octet_c; + padding operator+(const padding& p) { return padding(count + p.count, length + p.length); } - res += alphabet[(triple >> 3 * 6) & 0x3F]; - res += alphabet[(triple >> 2 * 6) & 0x3F]; - res += alphabet[(triple >> 1 * 6) & 0x3F]; - res += alphabet[(triple >> 0 * 6) & 0x3F]; - } + friend bool operator==(const padding& lhs, const padding& rhs) { + return lhs.count == rhs.count && lhs.length == rhs.length; + } + }; + + inline padding count_padding(const std::string& base, const std::vector& fills) { + for (const auto& fill : fills) { + if (base.size() < fill.size()) continue; + // Does the end of the input exactly match the fill pattern? + if (base.substr(base.size() - fill.size()) == fill) { + return padding{1, fill.length()} + + count_padding(base.substr(0, base.size() - fill.size()), fills); + } + } - if (fast_size == size) return res; - - size_t mod = size % 3; - - uint32_t octet_a = fast_size < size ? static_cast(bin[fast_size++]) : 0; - uint32_t octet_b = fast_size < size ? static_cast(bin[fast_size++]) : 0; - uint32_t octet_c = fast_size < size ? static_cast(bin[fast_size++]) : 0; - - uint32_t triple = (octet_a << 0x10) + (octet_b << 0x08) + octet_c; - - switch (mod) { - case 1: - res += alphabet[(triple >> 3 * 6) & 0x3F]; - res += alphabet[(triple >> 2 * 6) & 0x3F]; - res += fill; - res += fill; - break; - case 2: - res += alphabet[(triple >> 3 * 6) & 0x3F]; - res += alphabet[(triple >> 2 * 6) & 0x3F]; - res += alphabet[(triple >> 1 * 6) & 0x3F]; - res += fill; - break; - default: break; + return {}; } - return res; - } + inline std::string encode(const std::string& bin, const std::array& alphabet, + const std::string& fill) { + size_t size = bin.size(); + std::string res; + + // clear incomplete bytes + size_t fast_size = size - size % 3; + for (size_t i = 0; i < fast_size;) { + uint32_t octet_a = static_cast(bin[i++]); + uint32_t octet_b = static_cast(bin[i++]); + uint32_t octet_c = static_cast(bin[i++]); + + uint32_t triple = (octet_a << 0x10) + (octet_b << 0x08) + octet_c; + + res += alphabet[(triple >> 3 * 6) & 0x3F]; + res += alphabet[(triple >> 2 * 6) & 0x3F]; + res += alphabet[(triple >> 1 * 6) & 0x3F]; + res += alphabet[(triple >> 0 * 6) & 0x3F]; + } + + if (fast_size == size) return res; - static std::string decode(const std::string& base, const std::array& alphabet, - const std::string& fill) { - size_t size = base.size(); - - size_t fill_cnt = 0; - while (size > fill.size()) { - if (base.substr(size - fill.size(), fill.size()) == fill) { - fill_cnt++; - size -= fill.size(); - if (fill_cnt > 2) throw std::runtime_error("Invalid input: too much fill"); - } else + size_t mod = size % 3; + + uint32_t octet_a = fast_size < size ? static_cast(bin[fast_size++]) : 0; + uint32_t octet_b = fast_size < size ? static_cast(bin[fast_size++]) : 0; + uint32_t octet_c = fast_size < size ? static_cast(bin[fast_size++]) : 0; + + uint32_t triple = (octet_a << 0x10) + (octet_b << 0x08) + octet_c; + + switch (mod) { + case 1: + res += alphabet[(triple >> 3 * 6) & 0x3F]; + res += alphabet[(triple >> 2 * 6) & 0x3F]; + res += fill; + res += fill; break; + case 2: + res += alphabet[(triple >> 3 * 6) & 0x3F]; + res += alphabet[(triple >> 2 * 6) & 0x3F]; + res += alphabet[(triple >> 1 * 6) & 0x3F]; + res += fill; + break; + default: break; + } + + return res; } - if ((size + fill_cnt) % 4 != 0) throw std::runtime_error("Invalid input: incorrect total size"); + inline std::string decode(const std::string& base, const std::array& alphabet, + const std::vector& fill) { + const auto pad = count_padding(base, fill); + if (pad.count > 2) throw std::runtime_error("Invalid input: too much fill"); - size_t out_size = size / 4 * 3; - std::string res; - res.reserve(out_size); + const size_t size = base.size() - pad.length; + if ((size + pad.count) % 4 != 0) throw std::runtime_error("Invalid input: incorrect total size"); - auto get_sextet = [&](size_t offset) { - for (size_t i = 0; i < alphabet.size(); i++) { - if (alphabet[i] == base[offset]) return static_cast(i); + size_t out_size = size / 4 * 3; + std::string res; + res.reserve(out_size); + + auto get_sextet = [&](size_t offset) { return alphabet::index(alphabet, base[offset]); }; + + size_t fast_size = size - size % 4; + for (size_t i = 0; i < fast_size;) { + uint32_t sextet_a = get_sextet(i++); + uint32_t sextet_b = get_sextet(i++); + uint32_t sextet_c = get_sextet(i++); + uint32_t sextet_d = get_sextet(i++); + + uint32_t triple = + (sextet_a << 3 * 6) + (sextet_b << 2 * 6) + (sextet_c << 1 * 6) + (sextet_d << 0 * 6); + + res += static_cast((triple >> 2 * 8) & 0xFFU); + res += static_cast((triple >> 1 * 8) & 0xFFU); + res += static_cast((triple >> 0 * 8) & 0xFFU); } - throw std::runtime_error("Invalid input: not within alphabet"); - }; - size_t fast_size = size - size % 4; - for (size_t i = 0; i < fast_size;) { - uint32_t sextet_a = get_sextet(i++); - uint32_t sextet_b = get_sextet(i++); - uint32_t sextet_c = get_sextet(i++); - uint32_t sextet_d = get_sextet(i++); + if (pad.count == 0) return res; + + uint32_t triple = (get_sextet(fast_size) << 3 * 6) + (get_sextet(fast_size + 1) << 2 * 6); - uint32_t triple = (sextet_a << 3 * 6) + (sextet_b << 2 * 6) + (sextet_c << 1 * 6) + (sextet_d << 0 * 6); + switch (pad.count) { + case 1: + triple |= (get_sextet(fast_size + 2) << 1 * 6); + res += static_cast((triple >> 2 * 8) & 0xFFU); + res += static_cast((triple >> 1 * 8) & 0xFFU); + break; + case 2: res += static_cast((triple >> 2 * 8) & 0xFFU); break; + default: break; + } - res += static_cast((triple >> 2 * 8) & 0xFFU); - res += static_cast((triple >> 1 * 8) & 0xFFU); - res += static_cast((triple >> 0 * 8) & 0xFFU); + return res; } - if (fill_cnt == 0) return res; + inline std::string decode(const std::string& base, const std::array& alphabet, + const std::string& fill) { + return decode(base, alphabet, std::vector{fill}); + } - uint32_t triple = (get_sextet(fast_size) << 3 * 6) + (get_sextet(fast_size + 1) << 2 * 6); + inline std::string pad(const std::string& base, const std::string& fill) { + std::string padding; + switch (base.size() % 4) { + case 1: padding += fill; JWT_FALLTHROUGH; + case 2: padding += fill; JWT_FALLTHROUGH; + case 3: padding += fill; JWT_FALLTHROUGH; + default: break; + } - switch (fill_cnt) { - case 1: - triple |= (get_sextet(fast_size + 2) << 1 * 6); - res += static_cast((triple >> 2 * 8) & 0xFFU); - res += static_cast((triple >> 1 * 8) & 0xFFU); - break; - case 2: res += static_cast((triple >> 2 * 8) & 0xFFU); break; - default: break; + return base + padding; } - return res; - } - - static std::string pad(const std::string& base, const std::string& fill) { - std::string padding; - switch (base.size() % 4) { - case 1: padding += fill; JWT_FALLTHROUGH; - case 2: padding += fill; JWT_FALLTHROUGH; - case 3: padding += fill; JWT_FALLTHROUGH; - default: break; + inline std::string trim(const std::string& base, const std::string& fill) { + auto pos = base.find(fill); + return base.substr(0, pos); } + } // namespace details - return base + padding; + template + std::string encode(const std::string& bin) { + return details::encode(bin, T::data(), T::fill()); } - - static std::string trim(const std::string& base, const std::string& fill) { - auto pos = base.find(fill); - return base.substr(0, pos); + template + std::string decode(const std::string& base) { + return details::decode(base, T::data(), T::fill()); + } + template + std::string pad(const std::string& base) { + return details::pad(base, T::fill()); + } + template + std::string trim(const std::string& base) { + return details::trim(base, T::fill()); } - }; + } // namespace base } // namespace jwt #endif diff --git a/tests/BaseTest.cpp b/tests/BaseTest.cpp index 7c302248..210798af 100644 --- a/tests/BaseTest.cpp +++ b/tests/BaseTest.cpp @@ -1,6 +1,48 @@ #include "jwt-cpp/base.h" #include +TEST(BaseTest, Base64Index) { + ASSERT_EQ(0, jwt::alphabet::index(jwt::alphabet::base64::data(), 'A')); + ASSERT_EQ(32, jwt::alphabet::index(jwt::alphabet::base64::data(), 'g')); + ASSERT_EQ(62, jwt::alphabet::index(jwt::alphabet::base64::data(), '+')); +} + +TEST(BaseTest, Base64URLIndex) { + ASSERT_EQ(0, jwt::alphabet::index(jwt::alphabet::base64url::data(), 'A')); + ASSERT_EQ(32, jwt::alphabet::index(jwt::alphabet::base64url::data(), 'g')); + ASSERT_EQ(62, jwt::alphabet::index(jwt::alphabet::base64url::data(), '-')); +} + +TEST(BaseTest, BaseDetailsCountPadding) { + using jwt::base::details::padding; + ASSERT_EQ(padding{}, jwt::base::details::count_padding("ABC", {"~"})); + ASSERT_EQ((padding{3, 3}), jwt::base::details::count_padding("ABC~~~", {"~"})); + ASSERT_EQ((padding{5, 5}), jwt::base::details::count_padding("ABC~~~~~", {"~"})); + + ASSERT_EQ(padding{}, jwt::base::details::count_padding("ABC", {"~", "!"})); + ASSERT_EQ((padding{1, 1}), jwt::base::details::count_padding("ABC!", {"~", "!"})); + ASSERT_EQ((padding{1, 1}), jwt::base::details::count_padding("ABC~", {"~", "!"})); + ASSERT_EQ((padding{3, 3}), jwt::base::details::count_padding("ABC~~!", {"~", "!"})); + ASSERT_EQ((padding{3, 3}), jwt::base::details::count_padding("ABC!~~", {"~", "!"})); + ASSERT_EQ((padding{5, 5}), jwt::base::details::count_padding("ABC~~!~~", {"~", "!"})); + + ASSERT_EQ((padding{2, 6}), jwt::base::details::count_padding("MTIzNA%3d%3d", {"%3d", "%3D"})); + ASSERT_EQ((padding{2, 6}), jwt::base::details::count_padding("MTIzNA%3d%3D", {"%3d", "%3D"})); + ASSERT_EQ((padding{2, 6}), jwt::base::details::count_padding("MTIzNA%3D%3d", {"%3d", "%3D"})); + ASSERT_EQ((padding{2, 6}), jwt::base::details::count_padding("MTIzNA%3D%3D", {"%3d", "%3D"})); + + // Some fake scenarios + + ASSERT_EQ(padding{}, jwt::base::details::count_padding("", {"~"})); + ASSERT_EQ(padding{}, jwt::base::details::count_padding("ABC", {"~", "~~!"})); + ASSERT_EQ(padding{}, jwt::base::details::count_padding("ABC!", {"~", "~~!"})); + ASSERT_EQ((padding{1, 1}), jwt::base::details::count_padding("ABC~", {"~", "~~!"})); + ASSERT_EQ((padding{1, 3}), jwt::base::details::count_padding("ABC~~!", {"~", "~~!"})); + ASSERT_EQ((padding{2, 2}), jwt::base::details::count_padding("ABC!~~", {"~", "~~!"})); + ASSERT_EQ((padding{3, 5}), jwt::base::details::count_padding("ABC~~!~~", {"~", "~~!"})); + ASSERT_EQ(padding{}, jwt::base::details::count_padding("ABC~~!~~", {})); +} + TEST(BaseTest, Base64Decode) { ASSERT_EQ("1", jwt::base::decode("MQ==")); ASSERT_EQ("12", jwt::base::decode("MTI=")); @@ -15,6 +57,16 @@ TEST(BaseTest, Base64DecodeURL) { ASSERT_EQ("1234", jwt::base::decode("MTIzNA%3d%3d")); } +TEST(BaseTest, Base64DecodeURLCaseInsensitive) { + ASSERT_EQ("1", jwt::base::decode("MQ%3d%3d")); + ASSERT_EQ("1", jwt::base::decode("MQ%3D%3d")); + ASSERT_EQ("1", jwt::base::decode("MQ%3d%3D")); + ASSERT_EQ("12", jwt::base::decode("MTI%3d")); + ASSERT_EQ("123", jwt::base::decode("MTIz")); + ASSERT_EQ("1234", jwt::base::decode("MTIzNA%3d%3d")); + ASSERT_EQ("1234", jwt::base::decode("MTIzNA%3D%3D")); +} + TEST(BaseTest, Base64Encode) { ASSERT_EQ("MQ==", jwt::base::encode("1")); ASSERT_EQ("MTI=", jwt::base::encode("12"));