Make from_{hex,base32z,base64} compatible with std::byte

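Make the char handling more generic so that std::byte (or other size-1
types) will work: the validation and decoding functions now cast each input
element to unsigned char before the lookup-table call, and the decoders write
their output as uint8_t instead of casting it to the input iterator's
dereferenced type.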
This commit is contained in:
Jason Rhinelander 2020-05-12 19:38:05 -03:00
parent de395af872
commit 1f60abf50e
4 changed files with 51 additions and 22 deletions

View File

@ -75,7 +75,7 @@ static_assert(b32z_lut.from_b32z('w') == 20 && b32z_lut.from_b32z('T') == 17 &&
/// Converts bytes into a base32z encoded character sequence. /// Converts bytes into a base32z encoded character sequence.
template <typename InputIt, typename OutputIt> template <typename InputIt, typename OutputIt>
void to_base32z(InputIt begin, InputIt end, OutputIt out) { void to_base32z(InputIt begin, InputIt end, OutputIt out) {
static_assert(sizeof(*begin) == 1, "to_base32z requires chars/bytes"); static_assert(sizeof(decltype(*begin)) == 1, "to_base32z requires chars/bytes");
int bits = 0; // Tracks the number of unconsumed bits held in r, will always be in [0, 4] int bits = 0; // Tracks the number of unconsumed bits held in r, will always be in [0, 4]
std::uint_fast16_t r = 0; std::uint_fast16_t r = 0;
while (begin != end) { while (begin != end) {
@ -113,9 +113,9 @@ inline std::string to_base32z(std::string_view s) { return to_base32z<>(s); }
/// Returns true if all elements in the range are base32z characters /// Returns true if all elements in the range are base32z characters
template <typename It> template <typename It>
constexpr bool is_base32z(It begin, It end) { constexpr bool is_base32z(It begin, It end) {
static_assert(sizeof(*begin) == 1, "is_base32z requires chars/bytes"); static_assert(sizeof(decltype(*begin)) == 1, "is_base32z requires chars/bytes");
for (; begin != end; ++begin) { for (; begin != end; ++begin) {
auto c = *begin; auto c = static_cast<unsigned char>(*begin);
if (detail::b32z_lut.from_b32z(c) == 0 && !(c == 'y' || c == 'Y')) if (detail::b32z_lut.from_b32z(c) == 0 && !(c == 'y' || c == 'Y'))
return false; return false;
} }
@ -140,15 +140,14 @@ constexpr bool is_base32z(std::string_view s) { return is_base32z<>(s); }
/// are): which means "yy", "yb", "yyy", "yy9", "yd", etc. all decode to the same 1-byte value "\0". /// are): which means "yy", "yb", "yyy", "yy9", "yd", etc. all decode to the same 1-byte value "\0".
template <typename InputIt, typename OutputIt> template <typename InputIt, typename OutputIt>
void from_base32z(InputIt begin, InputIt end, OutputIt out) { void from_base32z(InputIt begin, InputIt end, OutputIt out) {
using Char = decltype(*begin); static_assert(sizeof(decltype(*begin)) == 1, "from_base32z requires chars/bytes");
static_assert(sizeof(Char) == 1, "from_base32z requires chars/bytes");
uint_fast16_t curr = 0; uint_fast16_t curr = 0;
int bits = 0; // number of bits we've loaded into val; we always keep this < 8. int bits = 0; // number of bits we've loaded into val; we always keep this < 8.
while (begin != end) { while (begin != end) {
curr = curr << 5 | detail::b32z_lut.from_b32z(*begin++); curr = curr << 5 | detail::b32z_lut.from_b32z(static_cast<unsigned char>(*begin++));
if (bits >= 3) { if (bits >= 3) {
bits -= 3; // Added 5, removing 8 bits -= 3; // Added 5, removing 8
*out++ = static_cast<Char>(curr >> bits); *out++ = static_cast<uint8_t>(curr >> bits);
curr &= (1 << bits) - 1; curr &= (1 << bits) - 1;
} else { } else {
bits += 5; bits += 5;

View File

@ -77,7 +77,7 @@ static_assert(b64_lut.from_b64('/') == 63 && b64_lut.from_b64('7') == 59 && b64_
/// Converts bytes into a base64 encoded character sequence. /// Converts bytes into a base64 encoded character sequence.
template <typename InputIt, typename OutputIt> template <typename InputIt, typename OutputIt>
void to_base64(InputIt begin, InputIt end, OutputIt out) { void to_base64(InputIt begin, InputIt end, OutputIt out) {
static_assert(sizeof(*begin) == 1, "to_base64 requires chars/bytes"); static_assert(sizeof(decltype(*begin)) == 1, "to_base64 requires chars/bytes");
int bits = 0; // Tracks the number of unconsumed bits held in r, will always be in {0, 2, 4} int bits = 0; // Tracks the number of unconsumed bits held in r, will always be in {0, 2, 4}
std::uint_fast16_t r = 0; std::uint_fast16_t r = 0;
while (begin != end) { while (begin != end) {
@ -130,21 +130,21 @@ inline std::string to_base64(std::string_view s) { return to_base64<>(s); }
/// but only at the end, only 1 or 2, and only if it pads out the total to a multiple of 4. /// but only at the end, only 1 or 2, and only if it pads out the total to a multiple of 4.
template <typename It> template <typename It>
constexpr bool is_base64(It begin, It end) { constexpr bool is_base64(It begin, It end) {
static_assert(sizeof(*begin) == 1, "is_base64 requires chars/bytes"); static_assert(sizeof(decltype(*begin)) == 1, "is_base64 requires chars/bytes");
using std::distance; using std::distance;
using std::prev; using std::prev;
// Allow 1 or 2 padding chars *if* they pad it to a multiple of 4. // Allow 1 or 2 padding chars *if* they pad it to a multiple of 4.
if (begin != end && distance(begin, end) % 4 == 0) { if (begin != end && distance(begin, end) % 4 == 0) {
auto last = prev(end); auto last = prev(end);
if (*last == '=') if (static_cast<unsigned char>(*last) == '=')
end = last--; end = last--;
if (*last == '=') if (static_cast<unsigned char>(*last) == '=')
end = last; end = last;
} }
for (; begin != end; ++begin) { for (; begin != end; ++begin) {
auto c = *begin; auto c = static_cast<unsigned char>(*begin);
if (detail::b64_lut.from_b64(c) == 0 && c != 'A') if (detail::b64_lut.from_b64(c) == 0 && c != 'A')
return false; return false;
} }
@ -169,12 +169,11 @@ constexpr bool is_base64(std::string_view s) { return is_base64(s.begin(), s.end
/// the last 4 bits of the last character are essentially considered padding. /// the last 4 bits of the last character are essentially considered padding.
template <typename InputIt, typename OutputIt> template <typename InputIt, typename OutputIt>
void from_base64(InputIt begin, InputIt end, OutputIt out) { void from_base64(InputIt begin, InputIt end, OutputIt out) {
using Char = decltype(*begin); static_assert(sizeof(decltype(*begin)) == 1, "from_base64 requires chars/bytes");
static_assert(sizeof(Char) == 1, "from_base64 requires chars/bytes");
uint_fast16_t curr = 0; uint_fast16_t curr = 0;
int bits = 0; // number of bits we've loaded into val; we always keep this < 8. int bits = 0; // number of bits we've loaded into val; we always keep this < 8.
while (begin != end) { while (begin != end) {
Char c = *begin++; auto c = static_cast<unsigned char>(*begin++);
// padding; don't bother checking if we're at the end because is_base64 is a precondition // padding; don't bother checking if we're at the end because is_base64 is a precondition
// and we're allowed UB if it isn't satisfied. // and we're allowed UB if it isn't satisfied.
@ -185,7 +184,7 @@ void from_base64(InputIt begin, InputIt end, OutputIt out) {
bits = 6; bits = 6;
else { else {
bits -= 2; // Added 6, removing 8 bits -= 2; // Added 6, removing 8
*out++ = static_cast<Char>(curr >> bits); *out++ = static_cast<uint8_t>(curr >> bits);
curr &= (1 << bits) - 1; curr &= (1 << bits) - 1;
} }
} }

View File

@ -63,10 +63,10 @@ static_assert(hex_lut.from_hex('a') == 10 && hex_lut.from_hex('F') == 15 && hex_
/// Creates hex digits from a character sequence. /// Creates hex digits from a character sequence.
template <typename InputIt, typename OutputIt> template <typename InputIt, typename OutputIt>
void to_hex(InputIt begin, InputIt end, OutputIt out) { void to_hex(InputIt begin, InputIt end, OutputIt out) {
static_assert(sizeof(*begin) == 1, "to_hex requires chars/bytes"); static_assert(sizeof(decltype(*begin)) == 1, "to_hex requires chars/bytes");
for (; begin != end; ++begin) { for (; begin != end; ++begin) {
auto c = *begin; uint8_t c = static_cast<uint8_t>(*begin);
*out++ = detail::hex_lut.to_hex((c & 0xf0) >> 4); *out++ = detail::hex_lut.to_hex(c >> 4);
*out++ = detail::hex_lut.to_hex(c & 0x0f); *out++ = detail::hex_lut.to_hex(c & 0x0f);
} }
} }
@ -84,9 +84,9 @@ inline std::string to_hex(std::string_view s) { return to_hex<>(s); }
/// Returns true if all elements in the range are hex characters /// Returns true if all elements in the range are hex characters
template <typename It> template <typename It>
constexpr bool is_hex(It begin, It end) { constexpr bool is_hex(It begin, It end) {
static_assert(sizeof(*begin) == 1, "is_hex requires chars/bytes"); static_assert(sizeof(decltype(*begin)) == 1, "is_hex requires chars/bytes");
for (; begin != end; ++begin) { for (; begin != end; ++begin) {
if (detail::hex_lut.from_hex(*begin) == 0 && *begin != '0') if (detail::hex_lut.from_hex(static_cast<unsigned char>(*begin)) == 0 && static_cast<unsigned char>(*begin) != '0')
return false; return false;
} }
return true; return true;
@ -115,7 +115,7 @@ void from_hex(InputIt begin, InputIt end, OutputIt out) {
while (begin != end) { while (begin != end) {
auto a = *begin++; auto a = *begin++;
auto b = *begin++; auto b = *begin++;
*out++ = from_hex_pair(a, b); *out++ = from_hex_pair(static_cast<unsigned char>(a), static_cast<unsigned char>(b));
} }
} }

View File

@ -27,6 +27,17 @@ TEST_CASE("hex encoding/decoding", "[encoding][decoding][hex]") {
REQUIRE( lokimq::from_hex(pk_hex) == pk ); REQUIRE( lokimq::from_hex(pk_hex) == pk );
REQUIRE( lokimq::to_hex(pk) == pk_hex ); REQUIRE( lokimq::to_hex(pk) == pk_hex );
std::vector<std::byte> bytes{{std::byte{0xff}, std::byte{0x42}, std::byte{0x12}, std::byte{0x34}}};
std::basic_string_view<std::byte> b{bytes.data(), bytes.size()};
REQUIRE( lokimq::to_hex(b) == "ff421234"s );
bytes.resize(8);
bytes[0] = std::byte{'f'}; bytes[1] = std::byte{'f'}; bytes[2] = std::byte{'4'}; bytes[3] = std::byte{'2'};
bytes[4] = std::byte{'1'}; bytes[5] = std::byte{'2'}; bytes[6] = std::byte{'3'}; bytes[7] = std::byte{'4'};
std::basic_string_view<std::byte> hex_bytes{bytes.data(), bytes.size()};
REQUIRE( lokimq::is_hex(hex_bytes) );
REQUIRE( lokimq::from_hex(hex_bytes) == "\xff\x42\x12\x34" );
} }
TEST_CASE("base32z encoding/decoding", "[encoding][decoding][base32z]") { TEST_CASE("base32z encoding/decoding", "[encoding][decoding][base32z]") {
@ -67,6 +78,16 @@ TEST_CASE("base32z encoding/decoding", "[encoding][decoding][base32z]") {
REQUIRE( lokimq::to_base32z(pk) == pk_b32z ); REQUIRE( lokimq::to_base32z(pk) == pk_b32z );
REQUIRE( lokimq::from_base32z(pk_b32z) == pk ); REQUIRE( lokimq::from_base32z(pk_b32z) == pk );
std::vector<std::byte> bytes{{std::byte{0}, std::byte{255}}};
std::basic_string_view<std::byte> b{bytes.data(), bytes.size()};
REQUIRE( lokimq::to_base32z(b) == "yd9o" );
bytes.resize(4);
bytes[0] = std::byte{'y'}; bytes[1] = std::byte{'d'}; bytes[2] = std::byte{'9'}; bytes[3] = std::byte{'o'};
std::basic_string_view<std::byte> b32_bytes{bytes.data(), bytes.size()};
REQUIRE( lokimq::is_base32z(b32_bytes) );
REQUIRE( lokimq::from_base32z(b32_bytes) == "\x00\xff"sv );
} }
TEST_CASE("base64 encoding/decoding", "[encoding][decoding][base64]") { TEST_CASE("base64 encoding/decoding", "[encoding][decoding][base64]") {
@ -139,4 +160,14 @@ TEST_CASE("base64 encoding/decoding", "[encoding][decoding][base64]") {
REQUIRE( lokimq::to_base64(pk) == pk_b64 ); REQUIRE( lokimq::to_base64(pk) == pk_b64 );
REQUIRE( lokimq::from_base64(pk_b64) == pk ); REQUIRE( lokimq::from_base64(pk_b64) == pk );
std::vector<std::byte> bytes{{std::byte{0}, std::byte{255}}};
std::basic_string_view<std::byte> b{bytes.data(), bytes.size()};
REQUIRE( lokimq::to_base64(b) == "AP8=" );
bytes.resize(4);
bytes[0] = std::byte{'/'}; bytes[1] = std::byte{'w'}; bytes[2] = std::byte{'A'}; bytes[3] = std::byte{'='};
std::basic_string_view<std::byte> b64_bytes{bytes.data(), bytes.size()};
REQUIRE( lokimq::is_base64(b64_bytes) );
REQUIRE( lokimq::from_base64(b64_bytes) == "\xff\x00"sv );
} }