From 90d54b1a3d8f18e69c1b36342f2bcd4ff4703846 Mon Sep 17 00:00:00 2001 From: Samuel Hawksby-Robinson Date: Wed, 15 Feb 2023 15:50:30 +0000 Subject: [PATCH] Added timeout functionality to Servers (#3192) * Added timeout functionality to servers currently only possible on the pairnig serve * Removed logging (like a mad man) * handling linter erroring --- server/pairing/payload_manager.go | 3 + server/pairing/server.go | 7 +- server/pairing/server_pairing_test.go | 33 ++++++++++ server/server.go | 25 ++++++-- server/timeout.go | 92 +++++++++++++++++++++++++++ server/timeout_test.go | 68 ++++++++++++++++++++ 6 files changed, 221 insertions(+), 7 deletions(-) create mode 100644 server/timeout.go create mode 100644 server/timeout_test.go diff --git a/server/pairing/payload_manager.go b/server/pairing/payload_manager.go index 2c005955d..5ac61e3c3 100644 --- a/server/pairing/payload_manager.go +++ b/server/pairing/payload_manager.go @@ -61,6 +61,9 @@ type PayloadSourceConfig struct { // they are required in other cases KeyUID string `json:"keyUID"` Password string `json:"password"` + + // Timeout the number of milliseconds after which the pairing server will automatically terminate + Timeout uint `json:"timeout"` } // AccountPayloadManagerConfig represents the initialisation parameters required for a AccountPayloadManager diff --git a/server/pairing/server.go b/server/pairing/server.go index 8f28ae839..4115531ee 100644 --- a/server/pairing/server.go +++ b/server/pairing/server.go @@ -76,7 +76,7 @@ func NewPairingServer(backend *api.GethStatusBackend, config *Config) (*Server, return nil, err } - return &Server{Server: server.NewServer( + s := &Server{Server: server.NewServer( config.Cert, config.Hostname, nil, @@ -88,7 +88,10 @@ func NewPairingServer(backend *api.GethStatusBackend, config *Config) (*Server, PayloadManager: pm, cookieStore: cs, rawMessagePayloadManager: rmpm, - }, nil + } + s.SetTimeout(config.Timeout) + + return s, nil } // MakeConnectionParams generates a *ConnectionParams based on the Server's current state diff --git a/server/pairing/server_pairing_test.go b/server/pairing/server_pairing_test.go index 67c0e320e..e17a0a48f 100644 --- a/server/pairing/server_pairing_test.go +++ b/server/pairing/server_pairing_test.go @@ -41,6 +41,39 @@ func (s *PairingServerSuite) TestMultiBackgroundForeground() { s.Require().Regexp(regexp.MustCompile("(https://\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}:\\d{1,5})"), s.PS.MakeBaseURL().String()) // nolint: gosimple } +func (s *PairingServerSuite) TestMultiTimeout() { + s.PS.SetTimeout(20) + + err := s.PS.Start() + s.Require().NoError(err) + + s.PS.ToBackground() + s.PS.ToForeground() + s.PS.ToBackground() + s.PS.ToBackground() + s.PS.ToForeground() + s.PS.ToForeground() + + s.Require().Regexp(regexp.MustCompile("(https://\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}:\\d{1,5})"), s.PS.MakeBaseURL().String()) // nolint: gosimple + + time.Sleep(7 * time.Millisecond) + s.PS.ToBackground() + time.Sleep(7 * time.Millisecond) + s.PS.ToForeground() + time.Sleep(7 * time.Millisecond) + s.PS.ToBackground() + time.Sleep(7 * time.Millisecond) + s.PS.ToBackground() + time.Sleep(7 * time.Millisecond) + s.PS.ToForeground() + time.Sleep(7 * time.Millisecond) + s.PS.ToForeground() + + // Wait for timeout to expire + time.Sleep(40 * time.Millisecond) + s.Require().False(s.PS.IsRunning()) +} + func (s *PairingServerSuite) TestPairingServer_StartPairing() { // Replace PairingServer.PayloadManager with a MockEncryptOnlyPayloadManager pm, err := NewMockEncryptOnlyPayloadManager(s.EphemeralAES) diff --git a/server/server.go b/server/server.go index ad71cb94b..4fda7d8eb 100644 --- a/server/server.go +++ b/server/server.go @@ -18,15 +18,18 @@ type Server struct { cert *tls.Certificate hostname string handlers HandlerPatternMap + portManger + *timeoutManager } func NewServer(cert *tls.Certificate, hostname string, afterPortChanged func(int), logger *zap.Logger) Server { return Server{ - logger: logger, - cert: cert, - hostname: hostname, - portManger: newPortManager(logger.Named("Server"), afterPortChanged), + logger: logger, + cert: cert, + hostname: hostname, + portManger: newPortManager(logger.Named("Server"), afterPortChanged), + timeoutManager: newTimeoutManager(), } } @@ -73,6 +76,13 @@ func (s *Server) listenAndServe() { s.isRunning = true + s.StartTimeout(func() { + err := s.Stop() + if err != nil { + s.logger.Error("PairingServer termination fail", zap.Error(err)) + } + }) + err = s.server.Serve(listener) if err != http.ErrServerClosed { s.logger.Error("server failed unexpectedly, restarting", zap.Error(err)) @@ -82,11 +92,11 @@ func (s *Server) listenAndServe() { } return } - s.isRunning = false } func (s *Server) resetServer() { + s.StopTimeout() s.server = new(http.Server) s.ResetPort() } @@ -112,6 +122,7 @@ func (s *Server) Start() error { } func (s *Server) Stop() error { + s.StopTimeout() if s.server != nil { return s.server.Shutdown(context.Background()) } @@ -119,6 +130,10 @@ func (s *Server) Stop() error { return nil } +func (s *Server) IsRunning() bool { + return s.isRunning +} + func (s *Server) ToForeground() { if !s.isRunning && (s.server != nil) { err := s.Start() diff --git a/server/timeout.go b/server/timeout.go new file mode 100644 index 000000000..170c9814d --- /dev/null +++ b/server/timeout.go @@ -0,0 +1,92 @@ +package server + +import ( + "sync" + "time" +) + +// timeoutManager represents a discrete encapsulation of timeout functionality. +// this struct expose 3 functions: +// - SetTimeout +// - StartTimeout +// - StopTimeout +type timeoutManager struct { + // timeout number of milliseconds the timeout operation will run before executing the `terminate` func() + // 0 represents an inactive timeout + timeout uint + + // exitQueue handles the cancel signal channels that circumvent timeout operations and prevent the + // execution of any `terminate` func() + exitQueue *exitQueueManager +} + +// newTimeoutManager returns a fully qualified and initialised timeoutManager +func newTimeoutManager() *timeoutManager { + return &timeoutManager{ + exitQueue: &exitQueueManager{queue: []chan struct{}{}}, + } +} + +// SetTimeout sets the value of the timeoutManager.timeout +func (t *timeoutManager) SetTimeout(milliseconds uint) { + t.timeout = milliseconds +} + +// StartTimeout starts a timeout operation based on the set timeoutManager.timeout value +// the given terminate func() will be executed once the timeout duration has passed +func (t *timeoutManager) StartTimeout(terminate func()) { + if t.timeout == 0 { + return + } + t.StopTimeout() + + exit := make(chan struct{}, 1) + t.exitQueue.add(exit) + go t.run(terminate, exit) +} + +// StopTimeout terminates a timeout operation and exits gracefully +func (t *timeoutManager) StopTimeout() { + if t.timeout == 0 { + return + } + t.exitQueue.empty() +} + +// run inits the main timeout run function that awaits for the exit command to be triggered or for the +// timeout duration to elapse and trigger the parameter terminate function. +func (t *timeoutManager) run(terminate func(), exit chan struct{}) { + select { + case <-exit: + return + case <-time.After(time.Duration(t.timeout) * time.Millisecond): + terminate() + return + } +} + +// exitQueueManager +type exitQueueManager struct { + queue []chan struct{} + queueLock sync.Mutex +} + +// add handles new exit channels adding them to the exit queue +func (e *exitQueueManager) add(exit chan struct{}) { + e.queueLock.Lock() + defer e.queueLock.Unlock() + + e.queue = append(e.queue, exit) +} + +// empty sends a signal to every exit channel in the queue and then resets the queue +func (e *exitQueueManager) empty() { + e.queueLock.Lock() + defer e.queueLock.Unlock() + + for i := range e.queue { + e.queue[i] <- struct{}{} + } + + e.queue = []chan struct{}{} +} diff --git a/server/timeout_test.go b/server/timeout_test.go new file mode 100644 index 000000000..209764386 --- /dev/null +++ b/server/timeout_test.go @@ -0,0 +1,68 @@ +package server + +import ( + "crypto/rand" + "math/big" + "testing" + "time" +) + +func TestTimeoutManager(t *testing.T) { + tm := newTimeoutManager() + + // test 0 timeout means timeout does not occur + tm.SetTimeout(0) + + // test fuzzing - 0 timeout - multiple sequential calls to random init and stop funcs + for i := 0; i < 30; i++ { + b, err := rand.Int(rand.Reader, big.NewInt(2)) + if err != nil { + t.Error(err) + } + + if b.Int64() == 1 { + tm.StartTimeout(t.FailNow) + } else { + tm.StopTimeout() + } + } + + // test fuzzing - random timeout - multiple sequential calls to random init and stop funcs + for i := 0; i < 30; i++ { + b, err := rand.Int(rand.Reader, big.NewInt(2)) + if err != nil { + t.Error(err) + } + to, err := rand.Int(rand.Reader, big.NewInt(11)) + if err != nil { + t.Error(err) + } + + tm.SetTimeout(uint(to.Int64() * 20)) + + if b.Int64() == 1 { + tm.StartTimeout(t.FailNow) + } else { + tm.StopTimeout() + } + tm.StopTimeout() + } + + // test StopTimeout() prevents termination func + tm.SetTimeout(20) + tm.StartTimeout(t.FailNow) + time.Sleep(10 * time.Millisecond) + tm.StopTimeout() + + // test StartTimeout() executes termination func on timeout + ok := false + tm.SetTimeout(10) + tm.StartTimeout(func() { + ok = true + }) + time.Sleep(20 * time.Millisecond) + if !ok { + t.FailNow() + } + +}