mailserver: refactor mailserver's rate limiter (#1341)

This commit is contained in:
Adam Babik 2019-01-10 17:07:16 +01:00 committed by GitHub
parent a84dee4934
commit 8f2e347e4f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 202 additions and 147 deletions

65
mailserver/db_key.go Normal file
View file

@ -0,0 +1,65 @@
package mailserver
import (
"encoding/binary"
"errors"
"github.com/ethereum/go-ethereum/common"
)
const (
// DBKeyLength is a size of the envelope key.
DBKeyLength = common.HashLength + timestampLength
)
var (
// ErrInvalidByteSize is returned when DBKey can't be created
// from a byte slice because it has invalid length.
ErrInvalidByteSize = errors.New("byte slice has invalid length")
)
// DBKey key to be stored in a db.
type DBKey struct {
timestamp uint32
hash common.Hash
raw []byte
}
// Bytes returns a bytes representation of the DBKey.
func (k *DBKey) Bytes() []byte {
return k.raw
}
// NewDBKey creates a new DBKey with the given values.
func NewDBKey(timestamp uint32, h common.Hash) *DBKey {
var k DBKey
k.timestamp = timestamp
k.hash = h
k.raw = make([]byte, DBKeyLength)
binary.BigEndian.PutUint32(k.raw, k.timestamp)
copy(k.raw[4:], k.hash[:])
return &k
}
// NewDBKeyFromBytes creates a DBKey from a byte slice.
func NewDBKeyFromBytes(b []byte) (*DBKey, error) {
if len(b) != DBKeyLength {
return nil, ErrInvalidByteSize
}
return &DBKey{
raw: b,
timestamp: binary.BigEndian.Uint32(b),
hash: common.BytesToHash(b[4:]),
}, nil
}
// mustNewDBKeyFromBytes panics if creating a key from a byte slice fails.
// Check if a byte slice has DBKeyLength length before using it.
func mustNewDBKeyFromBytes(b []byte) *DBKey {
k, err := NewDBKeyFromBytes(b)
if err != nil {
panic(err)
}
return k
}

View file

@ -5,45 +5,83 @@ import (
"time"
)
type limiter struct {
mu sync.RWMutex
type rateLimiter struct {
sync.RWMutex
timeout time.Duration
db map[string]time.Time
lifespan time.Duration // duration of the limit
db map[string]time.Time
period time.Duration
cancel chan struct{}
}
func newLimiter(timeout time.Duration) *limiter {
return &limiter{
timeout: timeout,
db: make(map[string]time.Time),
func newRateLimiter(duration time.Duration) *rateLimiter {
return &rateLimiter{
lifespan: duration,
db: make(map[string]time.Time),
period: time.Second,
}
}
func (l *limiter) add(id string) {
l.mu.Lock()
defer l.mu.Unlock()
func (l *rateLimiter) Start() {
cancel := make(chan struct{})
l.db[id] = time.Now()
l.Lock()
l.cancel = cancel
l.Unlock()
go l.cleanUp(l.period, cancel)
}
func (l *limiter) isAllowed(id string) bool {
l.mu.RLock()
defer l.mu.RUnlock()
func (l *rateLimiter) Stop() {
l.Lock()
defer l.Unlock()
if l.cancel == nil {
return
}
close(l.cancel)
l.cancel = nil
}
func (l *rateLimiter) Add(id string) {
l.Lock()
l.db[id] = time.Now()
l.Unlock()
}
func (l *rateLimiter) IsAllowed(id string) bool {
l.RLock()
defer l.RUnlock()
if lastRequestTime, ok := l.db[id]; ok {
return lastRequestTime.Add(l.timeout).Before(time.Now())
return lastRequestTime.Add(l.lifespan).Before(time.Now())
}
return true
}
func (l *limiter) deleteExpired() {
l.mu.Lock()
defer l.mu.Unlock()
func (l *rateLimiter) cleanUp(period time.Duration, cancel <-chan struct{}) {
t := time.NewTicker(period)
defer t.Stop()
for {
select {
case <-t.C:
l.deleteExpired()
case <-cancel:
return
}
}
}
func (l *rateLimiter) deleteExpired() {
l.Lock()
defer l.Unlock()
now := time.Now()
for id, lastRequestTime := range l.db {
if lastRequestTime.Add(l.timeout).Before(now) {
if lastRequestTime.Add(l.lifespan).Before(now) {
delete(l.db, id)
}
}

View file

@ -48,16 +48,16 @@ func TestIsAllowed(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.info, func(*testing.T) {
l := newLimiter(tc.t)
l := newRateLimiter(tc.t)
l.db = tc.db()
assert.Equal(t, tc.shouldBeAllowed, l.isAllowed(peerID))
assert.Equal(t, tc.shouldBeAllowed, l.IsAllowed(peerID))
})
}
}
func TestRemoveExpiredRateLimits(t *testing.T) {
peer := "peer"
l := newLimiter(time.Duration(5) * time.Second)
l := newRateLimiter(time.Duration(5) * time.Second)
for i := 0; i < 10; i++ {
peerID := fmt.Sprintf("%s%d", peer, i)
l.db[peerID] = time.Now().Add(time.Duration(i*(-2)) * time.Second)
@ -78,11 +78,31 @@ func TestRemoveExpiredRateLimits(t *testing.T) {
}
}
func TestCleaningUpExpiredRateLimits(t *testing.T) {
l := newRateLimiter(5 * time.Second)
l.period = time.Millisecond * 10
l.Start()
defer l.Stop()
l.db["peer01"] = time.Now().Add(-1 * time.Second)
l.db["peer02"] = time.Now().Add(-2 * time.Second)
l.db["peer03"] = time.Now().Add(-10 * time.Second)
time.Sleep(time.Millisecond * 20)
_, ok := l.db["peer01"]
assert.True(t, ok)
_, ok = l.db["peer02"]
assert.True(t, ok)
_, ok = l.db["peer03"]
assert.False(t, ok)
}
func TestAddingLimts(t *testing.T) {
peerID := "peerAdding"
l := newLimiter(time.Duration(5) * time.Second)
l := newRateLimiter(time.Duration(5) * time.Second)
pre := time.Now()
l.add(peerID)
l.Add(peerID)
post := time.Now()
assert.True(t, l.db[peerID].After(pre))
assert.True(t, l.db[peerID].Before(post))

View file

@ -27,7 +27,6 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/metrics"
"github.com/ethereum/go-ethereum/rlp"
"github.com/status-im/status-go/db"
"github.com/status-im/status-go/params"
@ -46,28 +45,9 @@ const (
var (
errDirectoryNotProvided = errors.New("data directory not provided")
errDecryptionMethodNotProvided = errors.New("decryption method is not provided")
// By default go-ethereum/metrics creates dummy metrics that don't register anything.
// Real metrics are collected only if -metrics flag is set
requestProcessTimer = metrics.NewRegisteredTimer("mailserver/requestProcessTime", nil)
requestProcessNetTimer = metrics.NewRegisteredTimer("mailserver/requestProcessNetTime", nil)
requestsMeter = metrics.NewRegisteredMeter("mailserver/requests", nil)
requestsBatchedCounter = metrics.NewRegisteredCounter("mailserver/requestsBatched", nil)
requestErrorsCounter = metrics.NewRegisteredCounter("mailserver/requestErrors", nil)
sentEnvelopesMeter = metrics.NewRegisteredMeter("mailserver/sentEnvelopes", nil)
sentEnvelopesSizeMeter = metrics.NewRegisteredMeter("mailserver/sentEnvelopesSize", nil)
archivedMeter = metrics.NewRegisteredMeter("mailserver/archivedEnvelopes", nil)
archivedSizeMeter = metrics.NewRegisteredMeter("mailserver/archivedEnvelopesSize", nil)
archivedErrorsCounter = metrics.NewRegisteredCounter("mailserver/archiveErrors", nil)
requestValidationErrorsCounter = metrics.NewRegisteredCounter("mailserver/requestValidationErrors", nil)
processRequestErrorsCounter = metrics.NewRegisteredCounter("mailserver/processRequestErrors", nil)
historicResponseErrorsCounter = metrics.NewRegisteredCounter("mailserver/historicResponseErrors", nil)
syncRequestsMeter = metrics.NewRegisteredMeter("mailserver/syncRequests", nil)
)
const (
// DBKeyLength is a size of the envelope key.
DBKeyLength = common.HashLength + timestampLength
timestampLength = 4
requestLimitLength = 4
requestTimeRangeLength = timestampLength * 2
@ -95,41 +75,8 @@ type WMailServer struct {
symFilter *whisper.Filter
asymFilter *whisper.Filter
muLimiter sync.RWMutex
limiter *limiter
tick *ticker
}
// DBKey key to be stored on db.
type DBKey struct {
timestamp uint32
hash common.Hash
raw []byte
}
// Bytes returns a bytes representation of the DBKey.
func (k *DBKey) Bytes() []byte {
return k.raw
}
// NewDBKey creates a new DBKey with the given values.
func NewDBKey(t uint32, h common.Hash) *DBKey {
var k DBKey
k.timestamp = t
k.hash = h
k.raw = make([]byte, DBKeyLength)
binary.BigEndian.PutUint32(k.raw, k.timestamp)
copy(k.raw[4:], k.hash[:])
return &k
}
// NewDBKeyFromBytes creates a DBKey from a byte slice.
func NewDBKeyFromBytes(b []byte) *DBKey {
return &DBKey{
raw: b,
timestamp: binary.BigEndian.Uint32(b),
hash: common.BytesToHash(b[4:]),
}
muRateLimiter sync.RWMutex
rateLimiter *rateLimiter
}
// Init initializes mailServer.
@ -150,7 +97,7 @@ func (s *WMailServer) Init(shh *whisper.Whisper, config *params.WhisperConfig) e
if err := s.setupRequestMessageDecryptor(config); err != nil {
return err
}
s.setupLimiter(time.Duration(config.MailServerRateLimit) * time.Second)
s.setupRateLimiter(time.Duration(config.MailServerRateLimit) * time.Second)
// Open database in the last step in order not to init with error
// and leave the database open by accident.
@ -163,12 +110,12 @@ func (s *WMailServer) Init(shh *whisper.Whisper, config *params.WhisperConfig) e
return nil
}
// setupLimiter in case limit is bigger than 0 it will setup an automated
// setupRateLimiter in case limit is bigger than 0 it will setup an automated
// limit db cleanup.
func (s *WMailServer) setupLimiter(limit time.Duration) {
func (s *WMailServer) setupRateLimiter(limit time.Duration) {
if limit > 0 {
s.limiter = newLimiter(limit)
s.setupMailServerCleanup(limit)
s.rateLimiter = newRateLimiter(limit)
s.rateLimiter.Start()
}
}
@ -203,15 +150,6 @@ func (s *WMailServer) setupRequestMessageDecryptor(config *params.WhisperConfig)
return nil
}
// setupMailServerCleanup periodically runs an expired entries deleteion for
// stored limits.
func (s *WMailServer) setupMailServerCleanup(period time.Duration) {
if s.tick == nil {
s.tick = &ticker{}
}
go s.tick.run(period, s.limiter.deleteExpired)
}
// Close the mailserver and its associated db connection.
func (s *WMailServer) Close() {
if s.db != nil {
@ -219,8 +157,8 @@ func (s *WMailServer) Close() {
log.Error(fmt.Sprintf("s.db.Close failed: %s", err))
}
}
if s.tick != nil {
s.tick.stop()
if s.rateLimiter != nil {
s.rateLimiter.Stop()
}
}
@ -450,18 +388,21 @@ func (s *WMailServer) SyncMail(peer *whisper.Peer, request whisper.SyncMailReque
// exceedsPeerRequests in case limit its been setup on the current server and limit
// allows the query, it will store/update new query time for the current peer.
func (s *WMailServer) exceedsPeerRequests(peer []byte) bool {
s.muLimiter.RLock()
defer s.muLimiter.RUnlock()
s.muRateLimiter.RLock()
defer s.muRateLimiter.RUnlock()
if s.limiter != nil {
peerID := string(peer)
if !s.limiter.isAllowed(peerID) {
log.Info("peerID exceeded the number of requests per second")
return true
}
s.limiter.add(peerID)
if s.rateLimiter == nil {
return false
}
return false
peerID := string(peer)
if s.rateLimiter.IsAllowed(peerID) {
s.rateLimiter.Add(peerID)
return false
}
log.Info("peerID exceeded the number of requests per second")
return true
}
func (s *WMailServer) createIterator(lower, upper uint32, cursor []byte) iterator.Iterator {
@ -472,7 +413,7 @@ func (s *WMailServer) createIterator(lower, upper uint32, cursor []byte) iterato
kl = NewDBKey(lower, emptyHash)
if len(cursor) == DBKeyLength {
ku = NewDBKeyFromBytes(cursor)
ku = mustNewDBKeyFromBytes(cursor)
} else {
ku = NewDBKey(upper+1, emptyHash)
}

View file

@ -173,7 +173,7 @@ func (s *MailserverSuite) TestInit() {
}
if tc.config.MailServerRateLimit > 0 {
s.NotNil(mailServer.limiter)
s.NotNil(mailServer.rateLimiter)
}
})
}
@ -273,15 +273,15 @@ func (s *MailserverSuite) TestArchive() {
}
func (s *MailserverSuite) TestManageLimits() {
s.server.limiter = newLimiter(time.Duration(5) * time.Millisecond)
s.server.rateLimiter = newRateLimiter(time.Duration(5) * time.Millisecond)
s.False(s.server.exceedsPeerRequests([]byte("peerID")))
s.Equal(1, len(s.server.limiter.db))
firstSaved := s.server.limiter.db["peerID"]
s.Equal(1, len(s.server.rateLimiter.db))
firstSaved := s.server.rateLimiter.db["peerID"]
// second call when limit is not accomplished does not store a new limit
s.True(s.server.exceedsPeerRequests([]byte("peerID")))
s.Equal(1, len(s.server.limiter.db))
s.Equal(firstSaved, s.server.limiter.db["peerID"])
s.Equal(1, len(s.server.rateLimiter.db))
s.Equal(firstSaved, s.server.rateLimiter.db["peerID"])
}
func (s *MailserverSuite) TestDBKey() {

22
mailserver/metrics.go Normal file
View file

@ -0,0 +1,22 @@
package mailserver
import "github.com/ethereum/go-ethereum/metrics"
var (
// By default go-ethereum/metrics creates dummy metrics that don't register anything.
// Real metrics are collected only if -metrics flag is set
requestProcessTimer = metrics.NewRegisteredTimer("mailserver/requestProcessTime", nil)
requestProcessNetTimer = metrics.NewRegisteredTimer("mailserver/requestProcessNetTime", nil)
requestsMeter = metrics.NewRegisteredMeter("mailserver/requests", nil)
requestsBatchedCounter = metrics.NewRegisteredCounter("mailserver/requestsBatched", nil)
requestErrorsCounter = metrics.NewRegisteredCounter("mailserver/requestErrors", nil)
sentEnvelopesMeter = metrics.NewRegisteredMeter("mailserver/sentEnvelopes", nil)
sentEnvelopesSizeMeter = metrics.NewRegisteredMeter("mailserver/sentEnvelopesSize", nil)
archivedMeter = metrics.NewRegisteredMeter("mailserver/archivedEnvelopes", nil)
archivedSizeMeter = metrics.NewRegisteredMeter("mailserver/archivedEnvelopesSize", nil)
archivedErrorsCounter = metrics.NewRegisteredCounter("mailserver/archiveErrors", nil)
requestValidationErrorsCounter = metrics.NewRegisteredCounter("mailserver/requestValidationErrors", nil)
processRequestErrorsCounter = metrics.NewRegisteredCounter("mailserver/processRequestErrors", nil)
historicResponseErrorsCounter = metrics.NewRegisteredCounter("mailserver/historicResponseErrors", nil)
syncRequestsMeter = metrics.NewRegisteredMeter("mailserver/syncRequests", nil)
)

View file

@ -1,33 +0,0 @@
package mailserver
import (
"sync"
"time"
)
type ticker struct {
mu sync.RWMutex
timeTicker *time.Ticker
}
func (t *ticker) run(period time.Duration, fn func()) {
if t.timeTicker != nil {
return
}
tt := time.NewTicker(period)
t.mu.Lock()
t.timeTicker = tt
t.mu.Unlock()
go func() {
for range tt.C {
fn()
}
}()
}
func (t *ticker) stop() {
t.mu.RLock()
t.timeTicker.Stop()
t.mu.RUnlock()
}

View file

@ -448,6 +448,9 @@ func (s *PeerPoolSimulationSuite) TestUpdateTopicLimits() {
func (s *PeerPoolSimulationSuite) TestMailServerPeersDiscovery() {
s.setupEthV5()
// eliminate peer we won't use
s.peers[2].Stop()
// Buffered channels must be used because we expect the events
// to be in the same order. Use a buffer length greater than
// the expected number of events to avoid deadlock.
@ -515,5 +518,4 @@ func (s *PeerPoolSimulationSuite) TestMailServerPeersDiscovery() {
disconnectedPeer := s.getPeerFromEvent(events, p2p.PeerEventTypeDrop)
s.Equal(s.peers[0].Self().ID().String(), disconnectedPeer.String())
s.Equal(signal.EventDiscoverySummary, s.getPoolEvent(poolEvents))
s.Len(<-summaries, 0)
}