diff --git a/protocol/transport/envelopes_monitor.go b/protocol/transport/envelopes_monitor.go index d5db3d42a..c839112cb 100644 --- a/protocol/transport/envelopes_monitor.go +++ b/protocol/transport/envelopes_monitor.go @@ -61,16 +61,20 @@ func NewEnvelopesMonitor(w types.Waku, config EnvelopesMonitorConfig) *Envelopes logger: logger.With(zap.Namespace("EnvelopesMonitor")), // key is envelope hash (event.Hash) - envelopes: map[types.Hash]EnvelopeState{}, - messages: map[types.Hash]*types.NewMessage{}, - attempts: map[types.Hash]int{}, - identifiers: make(map[types.Hash][][]byte), + envelopes: map[types.Hash]*monitoredEnvelope{}, // key is hash of the batch (event.Batch) batches: map[types.Hash]map[types.Hash]struct{}{}, } } +type monitoredEnvelope struct { + state EnvelopeState + attempts int + message *types.NewMessage + identifiers [][]byte +} + // EnvelopesMonitor is responsible for monitoring waku envelopes state. type EnvelopesMonitor struct { w types.Waku @@ -78,13 +82,10 @@ type EnvelopesMonitor struct { handler EnvelopeEventsHandler maxAttempts int - mu sync.Mutex - envelopes map[types.Hash]EnvelopeState - batches map[types.Hash]map[types.Hash]struct{} + mu sync.Mutex - messages map[types.Hash]*types.NewMessage - attempts map[types.Hash]int - identifiers map[types.Hash][][]byte + envelopes map[types.Hash]*monitoredEnvelope + batches map[types.Hash]map[types.Hash]struct{} awaitOnlyMailServerConfirmations bool @@ -115,28 +116,30 @@ func (m *EnvelopesMonitor) Stop() { func (m *EnvelopesMonitor) Add(identifiers [][]byte, envelopeHash types.Hash, message types.NewMessage) { m.mu.Lock() defer m.mu.Unlock() - m.identifiers[envelopeHash] = identifiers - // If it's already been marked as sent, we notify the client - if m.envelopes[envelopeHash] == EnvelopeSent { - if m.handler != nil { - m.handler.EnvelopeSent(m.identifiers[envelopeHash]) + + if envelope, ok := m.envelopes[envelopeHash]; !ok { + m.envelopes[envelopeHash] = &monitoredEnvelope{ + state: EnvelopePosted, + attempts: 1, + message: &message, + identifiers: identifiers, + } + } else if envelope.state == EnvelopeSent { + // If it's already been marked as sent, we notify the client + if m.handler != nil { + m.handler.EnvelopeSent(envelope.identifiers) } - } else { - // otherwise we keep track of the message - m.messages[envelopeHash] = &message - m.attempts[envelopeHash] = 1 - m.envelopes[envelopeHash] = EnvelopePosted } } func (m *EnvelopesMonitor) GetState(hash types.Hash) EnvelopeState { m.mu.Lock() defer m.mu.Unlock() - state, exist := m.envelopes[hash] + envelope, exist := m.envelopes[hash] if !exist { return NotRegistered } - return state + return envelope.state } // handleEnvelopeEvents processes waku envelope events @@ -184,17 +187,17 @@ func (m *EnvelopesMonitor) handleEventEnvelopeSent(event types.EnvelopeEvent) { confirmationExpected := event.Batch != (types.Hash{}) - state, ok := m.envelopes[event.Hash] + envelope, ok := m.envelopes[event.Hash] // If confirmations are not expected, we keep track of the envelope // being sent if !ok && !confirmationExpected { - m.envelopes[event.Hash] = EnvelopeSent + m.envelopes[event.Hash] = &monitoredEnvelope{state: EnvelopeSent} return } // if message was already confirmed - skip it - if state == EnvelopeSent { + if envelope.state == EnvelopeSent { return } m.logger.Debug("envelope is sent", zap.String("hash", event.Hash.String()), zap.String("peer", event.Peer.String())) @@ -206,9 +209,9 @@ func (m *EnvelopesMonitor) handleEventEnvelopeSent(event types.EnvelopeEvent) { m.logger.Debug("waiting for a confirmation", zap.String("batch", event.Batch.String())) } else { m.logger.Debug("confirmation not expected, marking as sent") - m.envelopes[event.Hash] = EnvelopeSent + envelope.state = EnvelopeSent if m.handler != nil { - m.handler.EnvelopeSent(m.identifiers[event.Hash]) + m.handler.EnvelopeSent(envelope.identifiers) } } } @@ -251,13 +254,13 @@ func (m *EnvelopesMonitor) handleAcknowledgedBatch(event types.EnvelopeEvent) { if _, exist := failedEnvelopes[hash]; exist { continue } - state, ok := m.envelopes[hash] - if !ok || state == EnvelopeSent { + envelope, ok := m.envelopes[hash] + if !ok || envelope.state == EnvelopeSent { continue } - m.envelopes[hash] = EnvelopeSent + envelope.state = EnvelopeSent if m.handler != nil { - m.handler.EnvelopeSent(m.identifiers[hash]) + m.handler.EnvelopeSent(envelope.identifiers) } } delete(m.batches, event.Batch) @@ -272,36 +275,32 @@ func (m *EnvelopesMonitor) handleEventEnvelopeExpired(event types.EnvelopeEvent) // handleEnvelopeFailure is a common code path for processing envelopes failures. not thread safe, lock // must be used on a higher level. func (m *EnvelopesMonitor) handleEnvelopeFailure(hash types.Hash, err error) { - if state, ok := m.envelopes[hash]; ok { - message, exist := m.messages[hash] - if !exist { - m.logger.Error("message was deleted erroneously", zap.String("envelope hash", hash.String())) - } - attempt := m.attempts[hash] - identifiers := m.identifiers[hash] + if envelope, ok := m.envelopes[hash]; ok { m.clearMessageState(hash) - if state == EnvelopeSent { + if envelope.state == EnvelopeSent { return } - if attempt < m.maxAttempts { - m.logger.Debug("retrying to send a message", zap.String("hash", hash.String()), zap.Int("attempt", attempt+1)) - hex, err := m.api.Post(context.TODO(), *message) + if envelope.attempts < m.maxAttempts { + m.logger.Debug("retrying to send a message", zap.String("hash", hash.String()), zap.Int("attempt", envelope.attempts+1)) + hex, err := m.api.Post(context.TODO(), *envelope.message) if err != nil { - m.logger.Error("failed to retry sending message", zap.String("hash", hash.String()), zap.Int("attempt", attempt+1), zap.Error(err)) + m.logger.Error("failed to retry sending message", zap.String("hash", hash.String()), zap.Int("attempt", envelope.attempts+1), zap.Error(err)) if m.handler != nil { - m.handler.EnvelopeExpired(identifiers, err) + m.handler.EnvelopeExpired(envelope.identifiers, err) } } envelopeID := types.BytesToHash(hex) - m.envelopes[envelopeID] = EnvelopePosted - m.messages[envelopeID] = message - m.attempts[envelopeID] = attempt + 1 - m.identifiers[envelopeID] = identifiers + m.envelopes[envelopeID] = &monitoredEnvelope{ + state: EnvelopePosted, + attempts: envelope.attempts + 1, + message: envelope.message, + identifiers: envelope.identifiers, + } } else { m.logger.Debug("envelope expired", zap.String("hash", hash.String())) if m.handler != nil { - m.handler.EnvelopeExpired(identifiers, err) + m.handler.EnvelopeExpired(envelope.identifiers, err) } } } @@ -313,14 +312,14 @@ func (m *EnvelopesMonitor) handleEventEnvelopeReceived(event types.EnvelopeEvent } m.mu.Lock() defer m.mu.Unlock() - state, ok := m.envelopes[event.Hash] - if !ok || state != EnvelopePosted { + envelope, ok := m.envelopes[event.Hash] + if !ok || envelope.state != EnvelopePosted { return } m.logger.Debug("expected envelope received", zap.String("hash", event.Hash.String()), zap.String("peer", event.Peer.String())) - m.envelopes[event.Hash] = EnvelopeSent + envelope.state = EnvelopeSent if m.handler != nil { - m.handler.EnvelopeSent(m.identifiers[event.Hash]) + m.handler.EnvelopeSent(envelope.identifiers) } } @@ -328,7 +327,4 @@ func (m *EnvelopesMonitor) handleEventEnvelopeReceived(event types.EnvelopeEvent // not thread-safe, should be protected on a higher level. func (m *EnvelopesMonitor) clearMessageState(envelopeID types.Hash) { delete(m.envelopes, envelopeID) - delete(m.messages, envelopeID) - delete(m.attempts, envelopeID) - delete(m.identifiers, envelopeID) } diff --git a/protocol/transport/envelopes_monitor_test.go b/protocol/transport/envelopes_monitor_test.go index 826ebf722..3fbbe3d19 100644 --- a/protocol/transport/envelopes_monitor_test.go +++ b/protocol/transport/envelopes_monitor_test.go @@ -44,13 +44,13 @@ func (s *EnvelopesMonitorSuite) SetupTest() { func (s *EnvelopesMonitorSuite) TestEnvelopePosted() { s.monitor.Add(testIDs, testHash, types.NewMessage{}) s.Contains(s.monitor.envelopes, testHash) - s.Equal(EnvelopePosted, s.monitor.envelopes[testHash]) + s.Equal(EnvelopePosted, s.monitor.envelopes[testHash].state) s.monitor.handleEvent(types.EnvelopeEvent{ Event: types.EventEnvelopeSent, Hash: testHash, }) s.Contains(s.monitor.envelopes, testHash) - s.Equal(EnvelopeSent, s.monitor.envelopes[testHash]) + s.Equal(EnvelopeSent, s.monitor.envelopes[testHash].state) } func (s *EnvelopesMonitorSuite) TestEnvelopePostedOutOfOrder() { @@ -61,7 +61,7 @@ func (s *EnvelopesMonitorSuite) TestEnvelopePostedOutOfOrder() { s.monitor.Add(testIDs, testHash, types.NewMessage{}) s.Require().Contains(s.monitor.envelopes, testHash) - s.Require().Equal(EnvelopeSent, s.monitor.envelopes[testHash]) + s.Require().Equal(EnvelopeSent, s.monitor.envelopes[testHash].state) } func (s *EnvelopesMonitorSuite) TestConfirmedWithAcknowledge() { @@ -71,20 +71,20 @@ func (s *EnvelopesMonitorSuite) TestConfirmedWithAcknowledge() { node := enode.NewV4(&pkey.PublicKey, nil, 0, 0) s.monitor.Add(testIDs, testHash, types.NewMessage{}) s.Contains(s.monitor.envelopes, testHash) - s.Equal(EnvelopePosted, s.monitor.envelopes[testHash]) + s.Equal(EnvelopePosted, s.monitor.envelopes[testHash].state) s.monitor.handleEvent(types.EnvelopeEvent{ Event: types.EventEnvelopeSent, Hash: testHash, Batch: testBatch, }) - s.Equal(EnvelopePosted, s.monitor.envelopes[testHash]) + s.Equal(EnvelopePosted, s.monitor.envelopes[testHash].state) s.monitor.handleEvent(types.EnvelopeEvent{ Event: types.EventBatchAcknowledged, Batch: testBatch, Peer: types.EnodeID(node.ID()), }) s.Contains(s.monitor.envelopes, testHash) - s.Equal(EnvelopeSent, s.monitor.envelopes[testHash]) + s.Equal(EnvelopeSent, s.monitor.envelopes[testHash].state) } func (s *EnvelopesMonitorSuite) TestRemoved() {