server_test.go 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283
  1. package server
  2. import (
  3. "context"
  4. "errors"
  5. "net/http"
  6. "net/http/httptest"
  7. "sync/atomic"
  8. "testing"
  9. "time"
  10. "github.com/imgproxy/imgproxy/v3/httpheaders"
  11. "github.com/imgproxy/imgproxy/v3/monitoring"
  12. "github.com/stretchr/testify/suite"
  13. )
  14. type ServerTestSuite struct {
  15. suite.Suite
  16. config *Config
  17. monitoring *monitoring.Monitoring
  18. blankRouter *Router
  19. }
  20. func (s *ServerTestSuite) SetupTest() {
  21. c := NewDefaultConfig()
  22. s.config = &c
  23. s.config.Bind = "127.0.0.1:0" // Use port 0 for auto-assignment
  24. mc := monitoring.NewDefaultConfig()
  25. m, err := monitoring.New(s.T().Context(), &mc, 1)
  26. s.Require().NoError(err)
  27. s.monitoring = m
  28. r, err := NewRouter(s.config, m)
  29. s.Require().NoError(err)
  30. s.blankRouter = r
  31. }
  32. func (s *ServerTestSuite) mockHandler(reqID string, rw ResponseWriter, r *http.Request) error {
  33. return nil
  34. }
  35. func (s *ServerTestSuite) wrapRW(rw http.ResponseWriter) ResponseWriter {
  36. return s.blankRouter.rwFactory.NewWriter(rw)
  37. }
  38. func (s *ServerTestSuite) TestStartServerWithInvalidBind() {
  39. ctx, cancel := context.WithCancel(s.T().Context())
  40. // Track if cancel was called using atomic
  41. var cancelCalled atomic.Bool
  42. cancelWrapper := func() {
  43. cancel()
  44. cancelCalled.Store(true)
  45. }
  46. invalidConfig := NewDefaultConfig()
  47. invalidConfig.Bind = "-1.-1.-1.-1" // Invalid address
  48. r, err := NewRouter(&invalidConfig, s.monitoring)
  49. s.Require().NoError(err)
  50. server, err := Start(cancelWrapper, r)
  51. s.Require().Error(err)
  52. s.Nil(server)
  53. s.Contains(err.Error(), "can't start server")
  54. // Check if cancel was called using Eventually
  55. s.Require().Eventually(cancelCalled.Load, 100*time.Millisecond, 10*time.Millisecond)
  56. // Also verify the context was cancelled
  57. s.Require().Eventually(func() bool {
  58. select {
  59. case <-ctx.Done():
  60. return true
  61. default:
  62. return false
  63. }
  64. }, 100*time.Millisecond, 10*time.Millisecond)
  65. }
  66. func (s *ServerTestSuite) TestShutdown() {
  67. _, cancel := context.WithCancel(context.Background())
  68. defer cancel()
  69. server, err := Start(cancel, s.blankRouter)
  70. s.Require().NoError(err)
  71. s.NotNil(server)
  72. // Test graceful shutdown
  73. shutdownCtx, shutdownCancel := context.WithTimeout(s.T().Context(), 10*time.Second)
  74. defer shutdownCancel()
  75. // Should not panic or hang
  76. s.NotPanics(func() {
  77. server.Shutdown(shutdownCtx)
  78. })
  79. }
  80. func (s *ServerTestSuite) TestWithCORS() {
  81. tests := []struct {
  82. name string
  83. corsAllowOrigin string
  84. expectedOrigin string
  85. expectedMethods string
  86. }{
  87. {
  88. name: "WithCORSOrigin",
  89. corsAllowOrigin: "https://example.com",
  90. expectedOrigin: "https://example.com",
  91. expectedMethods: "GET, OPTIONS",
  92. },
  93. {
  94. name: "NoCORSOrigin",
  95. corsAllowOrigin: "",
  96. expectedOrigin: "",
  97. expectedMethods: "",
  98. },
  99. }
  100. for _, tt := range tests {
  101. s.Run(tt.name, func() {
  102. config := NewDefaultConfig()
  103. config.CORSAllowOrigin = tt.corsAllowOrigin
  104. router, err := NewRouter(&config, s.monitoring)
  105. s.Require().NoError(err)
  106. wrappedHandler := router.WithCORS(s.mockHandler)
  107. req := httptest.NewRequest("GET", "/test", nil)
  108. rw := httptest.NewRecorder()
  109. wrappedHandler("test-req-id", s.wrapRW(rw), req)
  110. s.Equal(tt.expectedOrigin, rw.Header().Get(httpheaders.AccessControlAllowOrigin))
  111. s.Equal(tt.expectedMethods, rw.Header().Get(httpheaders.AccessControlAllowMethods))
  112. })
  113. }
  114. }
  115. func (s *ServerTestSuite) TestWithSecret() {
  116. tests := []struct {
  117. name string
  118. secret string
  119. authHeader string
  120. expectError bool
  121. }{
  122. {
  123. name: "ValidSecret",
  124. secret: "test-secret",
  125. authHeader: "Bearer test-secret",
  126. },
  127. {
  128. name: "InvalidSecret",
  129. secret: "foo-secret",
  130. authHeader: "Bearer wrong-secret",
  131. expectError: true,
  132. },
  133. {
  134. name: "NoSecretConfigured",
  135. secret: "",
  136. authHeader: "",
  137. },
  138. }
  139. for _, tt := range tests {
  140. s.Run(tt.name, func() {
  141. config := NewDefaultConfig()
  142. config.Secret = tt.secret
  143. router, err := NewRouter(&config, s.monitoring)
  144. s.Require().NoError(err)
  145. wrappedHandler := router.WithSecret(s.mockHandler)
  146. req := httptest.NewRequest("GET", "/test", nil)
  147. if tt.authHeader != "" {
  148. req.Header.Set(httpheaders.Authorization, tt.authHeader)
  149. }
  150. rw := httptest.NewRecorder()
  151. err = wrappedHandler("test-req-id", s.wrapRW(rw), req)
  152. if tt.expectError {
  153. s.Require().Error(err)
  154. } else {
  155. s.Require().NoError(err)
  156. }
  157. })
  158. }
  159. }
  160. func (s *ServerTestSuite) TestIntoSuccess() {
  161. mockHandler := func(reqID string, rw ResponseWriter, r *http.Request) error {
  162. rw.WriteHeader(http.StatusOK)
  163. return nil
  164. }
  165. wrappedHandler := s.blankRouter.WithReportError(mockHandler)
  166. req := httptest.NewRequest("GET", "/test", nil)
  167. rw := httptest.NewRecorder()
  168. wrappedHandler("test-req-id", s.wrapRW(rw), req)
  169. s.Equal(http.StatusOK, rw.Code)
  170. }
  171. func (s *ServerTestSuite) TestIntoWithError() {
  172. testError := errors.New("test error")
  173. mockHandler := func(reqID string, rw ResponseWriter, r *http.Request) error {
  174. return testError
  175. }
  176. wrappedHandler := s.blankRouter.WithReportError(mockHandler)
  177. req := httptest.NewRequest("GET", "/test", nil)
  178. rw := httptest.NewRecorder()
  179. wrappedHandler("test-req-id", s.wrapRW(rw), req)
  180. s.Equal(http.StatusInternalServerError, rw.Code)
  181. s.Equal("text/plain", rw.Header().Get(httpheaders.ContentType))
  182. }
  183. func (s *ServerTestSuite) TestIntoPanicWithError() {
  184. testError := errors.New("panic error")
  185. mockHandler := func(reqID string, rw ResponseWriter, r *http.Request) error {
  186. panic(testError)
  187. }
  188. wrappedHandler := s.blankRouter.WithPanic(mockHandler)
  189. req := httptest.NewRequest("GET", "/test", nil)
  190. rw := httptest.NewRecorder()
  191. s.NotPanics(func() {
  192. err := wrappedHandler("test-req-id", s.wrapRW(rw), req)
  193. s.Require().Error(err, "panic error")
  194. })
  195. s.Equal(http.StatusOK, rw.Code)
  196. }
  197. func (s *ServerTestSuite) TestIntoPanicWithAbortHandler() {
  198. mockHandler := func(reqID string, rw ResponseWriter, r *http.Request) error {
  199. panic(http.ErrAbortHandler)
  200. }
  201. wrappedHandler := s.blankRouter.WithPanic(mockHandler)
  202. req := httptest.NewRequest("GET", "/test", nil)
  203. rw := httptest.NewRecorder()
  204. // Should re-panic with ErrAbortHandler
  205. s.Panics(func() {
  206. wrappedHandler("test-req-id", s.wrapRW(rw), req)
  207. })
  208. }
  209. func (s *ServerTestSuite) TestIntoPanicWithNonError() {
  210. mockHandler := func(reqID string, rw ResponseWriter, r *http.Request) error {
  211. panic("string panic")
  212. }
  213. wrappedHandler := s.blankRouter.WithPanic(mockHandler)
  214. req := httptest.NewRequest("GET", "/test", nil)
  215. rw := httptest.NewRecorder()
  216. // Should re-panic with non-error panics
  217. s.NotPanics(func() {
  218. err := wrappedHandler("test-req-id", s.wrapRW(rw), req)
  219. s.Require().Error(err, "string panic")
  220. })
  221. }
  222. func TestServerTestSuite(t *testing.T) {
  223. suite.Run(t, new(ServerTestSuite))
  224. }