diff --git a/mailserver/cleaner_test.go b/mailserver/cleaner_test.go index ca84dbb25..8926dd526 100644 --- a/mailserver/cleaner_test.go +++ b/mailserver/cleaner_test.go @@ -87,7 +87,7 @@ func testPrune(t *testing.T, u time.Time, expected int, c *Cleaner, s *WMailServ require.NoError(t, err) count := countMessages(t, s.db) - require.Equal(t, expected, count, fmt.Sprintf("expected %d message, got: %d", expected, count)) + require.Equal(t, expected, count) } func testMessagesCount(t *testing.T, expected int, s *WMailServer) { diff --git a/mailserver/limiter_test.go b/mailserver/limiter_test.go index a4a8662db..0e3905785 100644 --- a/mailserver/limiter_test.go +++ b/mailserver/limiter_test.go @@ -14,7 +14,6 @@ func TestIsAllowed(t *testing.T) { t time.Duration shouldBeAllowed bool db func() map[string]time.Time - errMsg string info string }{ { @@ -23,8 +22,7 @@ func TestIsAllowed(t *testing.T) { db: func() map[string]time.Time { return make(map[string]time.Time) }, - errMsg: "Expected limiter not to allow with empty db", - info: "Expecting limiter.isAllowed to not allow with an empty db", + info: "Expecting limiter.isAllowed to allow with an empty db", }, { t: 5 * time.Millisecond, @@ -34,8 +32,7 @@ func TestIsAllowed(t *testing.T) { db[peerID] = time.Now().Add(time.Duration(-10) * time.Millisecond) return db }, - errMsg: "Expected limiter to allow with peer on its db", - info: "Expecting limiter.isAllowed to allow with an expired peer on its db", + info: "Expecting limiter.isAllowed to allow with an expired peer on its db", }, { t: 5 * time.Millisecond, @@ -45,8 +42,7 @@ func TestIsAllowed(t *testing.T) { db[peerID] = time.Now().Add(time.Duration(-1) * time.Millisecond) return db }, - errMsg: "Expected limiter to not allow with peer on its db", - info: "Expecting limiter.isAllowed to not allow with a non expired peer on its db", + info: "Expecting limiter.isAllowed to not allow with a non expired peer on its db", }, } @@ -54,7 +50,7 @@ func TestIsAllowed(t *testing.T) { t.Run(tc.info, func(*testing.T) { l := newLimiter(tc.t) l.db = tc.db() - assert.Equal(t, tc.shouldBeAllowed, l.isAllowed(peerID), tc.errMsg) + assert.Equal(t, tc.shouldBeAllowed, l.isAllowed(peerID)) }) } } diff --git a/mailserver/mailserver.go b/mailserver/mailserver.go index cf8034fe9..a354b7424 100644 --- a/mailserver/mailserver.go +++ b/mailserver/mailserver.go @@ -101,8 +101,7 @@ func (s *WMailServer) Init(shh *whisper.Whisper, config *params.WhisperConfig) e // setupLimiter in case limit is bigger than 0 it will setup an automated // limit db cleanup. -func (s *WMailServer) setupLimiter(rateLimit time.Duration) { - limit := rateLimit * time.Second +func (s *WMailServer) setupLimiter(limit time.Duration) { if limit > 0 { s.limit = newLimiter(limit) s.setupMailServerCleanup(limit) @@ -165,24 +164,27 @@ func (s *WMailServer) DeliverMail(peer *whisper.Peer, request *whisper.Envelope) log.Error("Whisper peer is nil") return } - s.managePeerLimits(peer.ID()) + if s.exceedsPeerRequests(peer.ID()) { + return + } if ok, lower, upper, bloom := s.validateRequest(peer.ID(), request); ok { s.processRequest(peer, lower, upper, bloom) } } -// managePeerLimits in case limit its been setup on the current server and limit +// exceedsPeerRequests in case limit its been setup on the current server and limit // allows the query, it will store/update new query time for the current peer. -func (s *WMailServer) managePeerLimits(peer []byte) { +func (s *WMailServer) exceedsPeerRequests(peer []byte) bool { if s.limit != nil { peerID := string(peer) if !s.limit.isAllowed(peerID) { log.Info("peerID exceeded the number of requests per second") - return + return true } s.limit.add(peerID) } + return false } // processRequest processes the current request and re-sends all stored messages diff --git a/mailserver/mailserver_test.go b/mailserver/mailserver_test.go index bef2bf38b..e4401ec9f 100644 --- a/mailserver/mailserver_test.go +++ b/mailserver/mailserver_test.go @@ -17,7 +17,6 @@ package mailserver import ( - "bytes" "crypto/ecdsa" "encoding/binary" "errors" @@ -144,12 +143,12 @@ func (s *MailserverSuite) TestArchive() { func (s *MailserverSuite) TestManageLimits() { s.server.limit = newLimiter(time.Duration(5) * time.Millisecond) - s.server.managePeerLimits([]byte("peerID")) + s.False(s.server.exceedsPeerRequests([]byte("peerID"))) s.Equal(1, len(s.server.limit.db)) firstSaved := s.server.limit.db["peerID"] // second call when limit is not accomplished does not store a new limit - s.server.managePeerLimits([]byte("peerID")) + s.True(s.server.exceedsPeerRequests([]byte("peerID"))) s.Equal(1, len(s.server.limit.db)) s.Equal(firstSaved, s.server.limit.db["peerID"]) } @@ -164,128 +163,104 @@ func (s *MailserverSuite) TestDBKey() { } func (s *MailserverSuite) TestMailServer() { - var server WMailServer - - s.setupServer(&server) - defer server.Close() + s.setupServer(s.server) + defer s.server.Close() env, err := generateEnvelope(time.Now()) s.NoError(err) - server.Archive(env) + s.server.Archive(env) testCases := []struct { - params *ServerTestParams - emptyLow bool - lowModifier int32 - uppModifier int32 - topic byte - expect bool - shouldFail bool - info string + params *ServerTestParams + expect bool + isOK bool + info string }{ { - params: s.defaultServerParams(env), - lowModifier: 0, - uppModifier: 0, - expect: true, - shouldFail: false, - info: "Processing a request where from and to are equals to an existing register, should provide results", + params: s.defaultServerParams(env), + expect: true, + isOK: true, + info: "Processing a request where from and to are equal to an existing register, should provide results", }, { - params: s.defaultServerParams(env), - lowModifier: 1, - uppModifier: 1, - expect: false, - shouldFail: false, - info: "Processing a request where from and to are great than any existing register, should not provide results", + params: func() *ServerTestParams { + params := s.defaultServerParams(env) + params.low = params.birth + 1 + params.upp = params.birth + 1 + + return params + }(), + expect: false, + isOK: true, + info: "Processing a request where from and to are greater than any existing register, should not provide results", }, { - params: s.defaultServerParams(env), - lowModifier: 0, - uppModifier: 1, - topic: 0xFF, - expect: false, - shouldFail: false, - info: "Processing a request where to is grat than any existing register and with a specific topic, should not provide results", + params: func() *ServerTestParams { + params := s.defaultServerParams(env) + params.upp = params.birth + 1 + params.topic[0] = 0xFF + + return params + }(), + expect: false, + isOK: true, + info: "Processing a request where to is greater than any existing register and with a specific topic, should not provide results", }, { - params: s.defaultServerParams(env), - emptyLow: true, - lowModifier: 4, - uppModifier: -1, - shouldFail: true, - info: "Processing a request where to is lower than from should fail", + params: func() *ServerTestParams { + params := s.defaultServerParams(env) + params.low = 0 + params.upp = params.birth - 1 + + return params + }(), + isOK: false, + info: "Processing a request where to is lower than from should fail", }, { - params: s.defaultServerParams(env), - emptyLow: true, - lowModifier: 0, - uppModifier: 24, - shouldFail: true, - info: "Processing a request where difference between from and to is > 24 should fail", + params: func() *ServerTestParams { + params := s.defaultServerParams(env) + params.low = 0 + params.upp = params.birth + 24 + + return params + }(), + isOK: false, + info: "Processing a request where difference between from and to is > 24 should fail", }, } for _, tc := range testCases { s.T().Run(tc.info, func(*testing.T) { - if tc.lowModifier != 0 { - tc.params.low = tc.params.birth + uint32(tc.lowModifier) - } - if tc.uppModifier != 0 { - tc.params.upp = tc.params.birth + uint32(tc.uppModifier) - } - if tc.emptyLow { - tc.params.low = 0 - } - if tc.topic == 0xFF { - tc.params.topic[0] = tc.topic - } - request := s.createRequest(tc.params) src := crypto.FromECDSAPub(&tc.params.key.PublicKey) - ok, lower, upper, bloom := server.validateRequest(src, request) - if tc.shouldFail { - if ok { - s.T().Fatal(err) - } - return - } - if !ok { - s.T().Fatalf("request validation failed, seed: %d.", seed) - } - if lower != tc.params.low { - s.T().Fatalf("request validation failed (lower bound), seed: %d.", seed) - } - if upper != tc.params.upp { - s.T().Fatalf("request validation failed (upper bound), seed: %d.", seed) - } - expectedBloom := whisper.TopicToBloom(tc.params.topic) - if !bytes.Equal(bloom, expectedBloom) { - s.T().Fatalf("request validation failed (topic), seed: %d.", seed) - } + ok, lower, upper, bloom := s.server.validateRequest(src, request) + s.Equal(tc.isOK, ok) + if ok { + s.Equal(tc.params.low, lower) + s.Equal(tc.params.upp, upper) + s.Equal(whisper.TopicToBloom(tc.params.topic), bloom) + s.Equal(tc.expect, s.messageExists(env, tc.params.low, tc.params.upp, bloom)) - var exist bool - mail := server.processRequest(nil, tc.params.low, tc.params.upp, bloom) - for _, msg := range mail { - if msg.Hash() == env.Hash() { - exist = true - break - } - } - - if exist != tc.expect { - s.T().Fatalf("error: exist = %v, seed: %d.", exist, seed) - } - - src[0]++ - ok, lower, upper, _ = server.validateRequest(src, request) - if !ok { - // request should be valid regardless of signature - s.T().Fatalf("request validation false negative, seed: %d (lower: %d, upper: %d).", seed, lower, upper) + src[0]++ + ok, _, _, _ = s.server.validateRequest(src, request) + s.True(ok) } }) } } +func (s *MailserverSuite) messageExists(envelope *whisper.Envelope, low, upp uint32, bloom []byte) bool { + var exist bool + mail := s.server.processRequest(nil, low, upp, bloom) + for _, msg := range mail { + if msg.Hash() == envelope.Hash() { + exist = true + break + } + } + return exist +} + func (s *MailserverSuite) TestBloomFromReceivedMessage() { testCases := []struct { msg whisper.ReceivedMessage