123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283 |
- package server
- import (
- "context"
- "errors"
- "net/http"
- "net/http/httptest"
- "sync/atomic"
- "testing"
- "time"
- "github.com/imgproxy/imgproxy/v3/httpheaders"
- "github.com/imgproxy/imgproxy/v3/monitoring"
- "github.com/stretchr/testify/suite"
- )
- type ServerTestSuite struct {
- suite.Suite
- config *Config
- monitoring *monitoring.Monitoring
- blankRouter *Router
- }
- func (s *ServerTestSuite) SetupTest() {
- c := NewDefaultConfig()
- s.config = &c
- s.config.Bind = "127.0.0.1:0" // Use port 0 for auto-assignment
- mc := monitoring.NewDefaultConfig()
- m, err := monitoring.New(s.T().Context(), &mc, 1)
- s.Require().NoError(err)
- s.monitoring = m
- r, err := NewRouter(s.config, m)
- s.Require().NoError(err)
- s.blankRouter = r
- }
- func (s *ServerTestSuite) mockHandler(reqID string, rw ResponseWriter, r *http.Request) error {
- return nil
- }
- func (s *ServerTestSuite) wrapRW(rw http.ResponseWriter) ResponseWriter {
- return s.blankRouter.rwFactory.NewWriter(rw)
- }
- func (s *ServerTestSuite) TestStartServerWithInvalidBind() {
- ctx, cancel := context.WithCancel(s.T().Context())
- // Track if cancel was called using atomic
- var cancelCalled atomic.Bool
- cancelWrapper := func() {
- cancel()
- cancelCalled.Store(true)
- }
- invalidConfig := NewDefaultConfig()
- invalidConfig.Bind = "-1.-1.-1.-1" // Invalid address
- r, err := NewRouter(&invalidConfig, s.monitoring)
- s.Require().NoError(err)
- server, err := Start(cancelWrapper, r)
- s.Require().Error(err)
- s.Nil(server)
- s.Contains(err.Error(), "can't start server")
- // Check if cancel was called using Eventually
- s.Require().Eventually(cancelCalled.Load, 100*time.Millisecond, 10*time.Millisecond)
- // Also verify the context was cancelled
- s.Require().Eventually(func() bool {
- select {
- case <-ctx.Done():
- return true
- default:
- return false
- }
- }, 100*time.Millisecond, 10*time.Millisecond)
- }
- func (s *ServerTestSuite) TestShutdown() {
- _, cancel := context.WithCancel(context.Background())
- defer cancel()
- server, err := Start(cancel, s.blankRouter)
- s.Require().NoError(err)
- s.NotNil(server)
- // Test graceful shutdown
- shutdownCtx, shutdownCancel := context.WithTimeout(s.T().Context(), 10*time.Second)
- defer shutdownCancel()
- // Should not panic or hang
- s.NotPanics(func() {
- server.Shutdown(shutdownCtx)
- })
- }
- func (s *ServerTestSuite) TestWithCORS() {
- tests := []struct {
- name string
- corsAllowOrigin string
- expectedOrigin string
- expectedMethods string
- }{
- {
- name: "WithCORSOrigin",
- corsAllowOrigin: "https://example.com",
- expectedOrigin: "https://example.com",
- expectedMethods: "GET, OPTIONS",
- },
- {
- name: "NoCORSOrigin",
- corsAllowOrigin: "",
- expectedOrigin: "",
- expectedMethods: "",
- },
- }
- for _, tt := range tests {
- s.Run(tt.name, func() {
- config := NewDefaultConfig()
- config.CORSAllowOrigin = tt.corsAllowOrigin
- router, err := NewRouter(&config, s.monitoring)
- s.Require().NoError(err)
- wrappedHandler := router.WithCORS(s.mockHandler)
- req := httptest.NewRequest("GET", "/test", nil)
- rw := httptest.NewRecorder()
- wrappedHandler("test-req-id", s.wrapRW(rw), req)
- s.Equal(tt.expectedOrigin, rw.Header().Get(httpheaders.AccessControlAllowOrigin))
- s.Equal(tt.expectedMethods, rw.Header().Get(httpheaders.AccessControlAllowMethods))
- })
- }
- }
- func (s *ServerTestSuite) TestWithSecret() {
- tests := []struct {
- name string
- secret string
- authHeader string
- expectError bool
- }{
- {
- name: "ValidSecret",
- secret: "test-secret",
- authHeader: "Bearer test-secret",
- },
- {
- name: "InvalidSecret",
- secret: "foo-secret",
- authHeader: "Bearer wrong-secret",
- expectError: true,
- },
- {
- name: "NoSecretConfigured",
- secret: "",
- authHeader: "",
- },
- }
- for _, tt := range tests {
- s.Run(tt.name, func() {
- config := NewDefaultConfig()
- config.Secret = tt.secret
- router, err := NewRouter(&config, s.monitoring)
- s.Require().NoError(err)
- wrappedHandler := router.WithSecret(s.mockHandler)
- req := httptest.NewRequest("GET", "/test", nil)
- if tt.authHeader != "" {
- req.Header.Set(httpheaders.Authorization, tt.authHeader)
- }
- rw := httptest.NewRecorder()
- err = wrappedHandler("test-req-id", s.wrapRW(rw), req)
- if tt.expectError {
- s.Require().Error(err)
- } else {
- s.Require().NoError(err)
- }
- })
- }
- }
- func (s *ServerTestSuite) TestIntoSuccess() {
- mockHandler := func(reqID string, rw ResponseWriter, r *http.Request) error {
- rw.WriteHeader(http.StatusOK)
- return nil
- }
- wrappedHandler := s.blankRouter.WithReportError(mockHandler)
- req := httptest.NewRequest("GET", "/test", nil)
- rw := httptest.NewRecorder()
- wrappedHandler("test-req-id", s.wrapRW(rw), req)
- s.Equal(http.StatusOK, rw.Code)
- }
- func (s *ServerTestSuite) TestIntoWithError() {
- testError := errors.New("test error")
- mockHandler := func(reqID string, rw ResponseWriter, r *http.Request) error {
- return testError
- }
- wrappedHandler := s.blankRouter.WithReportError(mockHandler)
- req := httptest.NewRequest("GET", "/test", nil)
- rw := httptest.NewRecorder()
- wrappedHandler("test-req-id", s.wrapRW(rw), req)
- s.Equal(http.StatusInternalServerError, rw.Code)
- s.Equal("text/plain", rw.Header().Get(httpheaders.ContentType))
- }
- func (s *ServerTestSuite) TestIntoPanicWithError() {
- testError := errors.New("panic error")
- mockHandler := func(reqID string, rw ResponseWriter, r *http.Request) error {
- panic(testError)
- }
- wrappedHandler := s.blankRouter.WithPanic(mockHandler)
- req := httptest.NewRequest("GET", "/test", nil)
- rw := httptest.NewRecorder()
- s.NotPanics(func() {
- err := wrappedHandler("test-req-id", s.wrapRW(rw), req)
- s.Require().Error(err, "panic error")
- })
- s.Equal(http.StatusOK, rw.Code)
- }
- func (s *ServerTestSuite) TestIntoPanicWithAbortHandler() {
- mockHandler := func(reqID string, rw ResponseWriter, r *http.Request) error {
- panic(http.ErrAbortHandler)
- }
- wrappedHandler := s.blankRouter.WithPanic(mockHandler)
- req := httptest.NewRequest("GET", "/test", nil)
- rw := httptest.NewRecorder()
- // Should re-panic with ErrAbortHandler
- s.Panics(func() {
- wrappedHandler("test-req-id", s.wrapRW(rw), req)
- })
- }
- func (s *ServerTestSuite) TestIntoPanicWithNonError() {
- mockHandler := func(reqID string, rw ResponseWriter, r *http.Request) error {
- panic("string panic")
- }
- wrappedHandler := s.blankRouter.WithPanic(mockHandler)
- req := httptest.NewRequest("GET", "/test", nil)
- rw := httptest.NewRecorder()
- // Should re-panic with non-error panics
- s.NotPanics(func() {
- err := wrappedHandler("test-req-id", s.wrapRW(rw), req)
- s.Require().Error(err, "string panic")
- })
- }
- func TestServerTestSuite(t *testing.T) {
- suite.Run(t, new(ServerTestSuite))
- }
|