diff --git a/README.md b/README.md index 0c4a277..74daf59 100644 --- a/README.md +++ b/README.md @@ -74,8 +74,8 @@ $ make run - [X] Health endpoint - [X] Next birthdays endpoint - [ ] Birthday list endpoint -- [ ] Allow to use a random port in web tests -- [ ] Web server should be optional +- [X] Allow to use a random port in web tests +- [X] Web server should be optional - [ ] Create different message systems to use with the bot - [X] Telegram - [ ] Email diff --git a/example-config.yml b/example-config.yml index 57b9ca2..d59be78 100644 --- a/example-config.yml +++ b/example-config.yml @@ -1,5 +1,6 @@ --- web: + enabled: true port: 8080 birthdays: diff --git a/model/config.go b/model/config.go index 469cd4c..9ec6a18 100644 --- a/model/config.go +++ b/model/config.go @@ -111,14 +111,11 @@ func (lc *LoggerConfig) IsValid() error { } type WebConfig struct { - Port int `yaml:"port"` + Enabled bool `yaml:"enabled"` + Port int `yaml:"port"` } -func (wc *WebConfig) SetDefaults() { - if wc.Port == 0 { - wc.Port = 8080 - } -} +func (wc *WebConfig) SetDefaults() {} func (wc *WebConfig) IsValid() error { return nil diff --git a/server/helpers_test.go b/server/helpers_test.go index c9b6af4..fec1ce9 100644 --- a/server/helpers_test.go +++ b/server/helpers_test.go @@ -28,9 +28,8 @@ func testConfig(t *testing.T) *model.Config { require.NoError(t, f.Close()) require.NoError(t, os.Remove(f.Name())) - // ToDo: allow for a random port to be used return &model.Config{ - Web: &model.WebConfig{Port: 9090}, + Web: &model.WebConfig{Enabled: true, Port: 0}, Birthdays: &model.BirthdaysConfig{File: f.Name()}, } } diff --git a/server/server.go b/server/server.go index 9b88f74..46254df 100644 --- a/server/server.go +++ b/server/server.go @@ -108,10 +108,15 @@ func New(options ...Option) (*Server, error) { } } - if srv.WebServer == nil { + if srv.WebServer == nil && srv.Config.Web.Enabled { srv.Logger.Debug("creating web server") - srv.WebServer = NewWebServer(srv) + ws, err := NewWebServer(srv) + if err != nil { + return nil, fmt.Errorf("cannot create web server: %w", err) + } + + srv.WebServer = ws } return srv, nil @@ -120,8 +125,10 @@ func New(options ...Option) (*Server, error) { func (s *Server) Start() error { s.Logger.Info("starting server") - if err := s.WebServer.Start(); err != nil { - return fmt.Errorf("cannot start web server: %w", err) + if s.WebServer != nil { + if err := s.WebServer.Start(); err != nil { + return fmt.Errorf("cannot start web server: %w", err) + } } for _, worker := range s.workers { @@ -136,8 +143,10 @@ func (s *Server) Start() error { func (s *Server) Stop() error { s.Logger.Info("stopping server") - if err := s.WebServer.Stop(); err != nil { - return fmt.Errorf("cannot stop web server: %w", err) + if s.WebServer != nil { + if err := s.WebServer.Stop(); err != nil { + return fmt.Errorf("cannot stop web server: %w", err) + } } for _, worker := range s.workers { diff --git a/server/server_test.go b/server/server_test.go index 3da18b9..8d76cb1 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -15,6 +15,7 @@ import ( func TestNotify(t *testing.T) { th := SetupTestHelper(t) defer th.TearDown() + t.Run("should correctly use the notification services to notify", func(t *testing.T) { birthday := th.srv.birthdays[0] th.mockNotificationService. diff --git a/server/web.go b/server/web.go index 6709699..d2e1dfd 100644 --- a/server/web.go +++ b/server/web.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "fmt" + "net" "net/http" "github.com/charmbracelet/log" @@ -11,33 +12,37 @@ import ( type WebServer struct { server *Server + listener net.Listener logger *log.Logger httpServer *http.Server } -func NewWebServer(server *Server) *WebServer { +func NewWebServer(server *Server) (*WebServer, error) { + listener, err := net.Listen("tcp", fmt.Sprintf(":%d", server.Config.Web.Port)) + if err != nil { + return nil, fmt.Errorf("cannot create listener: %w", err) + } + ws := &WebServer{ - server: server, - logger: server.Logger, - httpServer: &http.Server{ - Addr: fmt.Sprintf(":%d", server.Config.Web.Port), - }, + server: server, + listener: listener, + logger: server.Logger, } mux := http.NewServeMux() mux.HandleFunc("/health", ws.healthHandler) mux.HandleFunc("/next_birthdays", ws.nextBirthdayHandler) - ws.httpServer.Handler = mux + ws.httpServer = &http.Server{Handler: mux} - return ws + return ws, nil } func (ws *WebServer) Start() error { ws.logger.Debug("starting web server") go func() { - if err := ws.httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + if err := ws.httpServer.Serve(ws.listener); err != nil && !errors.Is(err, http.ErrServerClosed) { ws.logger.Fatal("cannot start web server", "error", err) } }() @@ -55,6 +60,10 @@ func (ws *WebServer) Stop() error { return nil } +func (ws *WebServer) Port() int { + return ws.listener.Addr().(*net.TCPAddr).Port +} + func (ws *WebServer) healthHandler(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "OK") } diff --git a/server/web_test.go b/server/web_test.go new file mode 100644 index 0000000..14d41ac --- /dev/null +++ b/server/web_test.go @@ -0,0 +1,15 @@ +package server + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestPort(t *testing.T) { + th := SetupTestHelper(t) + defer th.TearDown() + + port := th.srv.WebServer.Port() + require.NotEmpty(t, port) +}