diff --git a/oxenmq/address.h b/oxenmq/address.h index 6be27cc..5ffb2d7 100644 --- a/oxenmq/address.h +++ b/oxenmq/address.h @@ -31,6 +31,7 @@ #include #include #include +#include namespace oxenmq { @@ -206,4 +207,12 @@ struct address { // Outputs address.full_address() when sent to an ostream. std::ostream& operator<<(std::ostream& o, const address& a); -} +} // namespace oxenmq + +namespace std { +template<> struct hash { + std::size_t operator()(const oxenmq::address& a) const noexcept { + return std::hash{}(a.full_address(oxenmq::address::encoding::hex)); + } +}; +} // namespace std diff --git a/tests/test_address.cpp b/tests/test_address.cpp index 4e9bf52..7a901f7 100644 --- a/tests/test_address.cpp +++ b/tests/test_address.cpp @@ -129,3 +129,34 @@ TEST_CASE("tcp QR-code friendly addresses", "[address][tcp][qr]") { REQUIRE_THROWS_AS(address{"CURVE://PUBLIC.LOKI.FOUNDATION:12345/AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="}, std::invalid_argument); } +TEST_CASE("address hashing", "[address][hash]") { + address a{"tcp://public.loki.foundation:12345"}; + address b{"tcp+curve://public.loki.foundation:12345/" + pk_hex}; + address c{"ipc:///tmp/some.sock"}; + address d{"ipc:///tmp/some.other.sock"}; + + std::hash hasher{}; + REQUIRE( hasher(a) != hasher(b) ); + REQUIRE( hasher(a) != hasher(c) ); + REQUIRE( hasher(a) != hasher(d) ); + REQUIRE( hasher(b) != hasher(c) ); + REQUIRE( hasher(b) != hasher(d) ); + REQUIRE( hasher(c) != hasher(d) ); + + std::unordered_set set; + set.insert(a); + set.insert(b); + set.insert(c); + set.insert(d); + + CHECK( set.size() == 4 ); + std::unordered_map count; + for (const auto& addr : set) + count[addr]++; + + REQUIRE( count.size() == 4 ); + CHECK( count[a] == 1 ); + CHECK( count[b] == 1 ); + CHECK( count[c] == 1 ); + CHECK( count[d] == 1 ); +}