server_test.go 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275
  1. package server
  2. import (
  3. "context"
  4. "errors"
  5. "net/http"
  6. "net/http/httptest"
  7. "sync/atomic"
  8. "testing"
  9. "time"
  10. "github.com/imgproxy/imgproxy/v3/httpheaders"
  11. "github.com/stretchr/testify/suite"
  12. )
  13. type ServerTestSuite struct {
  14. suite.Suite
  15. config *Config
  16. blankRouter *Router
  17. }
  18. func (s *ServerTestSuite) SetupTest() {
  19. c := NewDefaultConfig()
  20. s.config = &c
  21. s.config.Bind = "127.0.0.1:0" // Use port 0 for auto-assignment
  22. r, err := NewRouter(s.config)
  23. s.Require().NoError(err)
  24. s.blankRouter = r
  25. }
  26. func (s *ServerTestSuite) mockHandler(reqID string, rw ResponseWriter, r *http.Request) error {
  27. return nil
  28. }
  29. func (s *ServerTestSuite) wrapRW(rw http.ResponseWriter) ResponseWriter {
  30. return s.blankRouter.rwFactory.NewWriter(rw)
  31. }
  32. func (s *ServerTestSuite) TestStartServerWithInvalidBind() {
  33. ctx, cancel := context.WithCancel(s.T().Context())
  34. // Track if cancel was called using atomic
  35. var cancelCalled atomic.Bool
  36. cancelWrapper := func() {
  37. cancel()
  38. cancelCalled.Store(true)
  39. }
  40. invalidConfig := NewDefaultConfig()
  41. invalidConfig.Bind = "-1.-1.-1.-1" // Invalid address
  42. r, err := NewRouter(&invalidConfig)
  43. s.Require().NoError(err)
  44. server, err := Start(cancelWrapper, r)
  45. s.Require().Error(err)
  46. s.Nil(server)
  47. s.Contains(err.Error(), "can't start server")
  48. // Check if cancel was called using Eventually
  49. s.Require().Eventually(cancelCalled.Load, 100*time.Millisecond, 10*time.Millisecond)
  50. // Also verify the context was cancelled
  51. s.Require().Eventually(func() bool {
  52. select {
  53. case <-ctx.Done():
  54. return true
  55. default:
  56. return false
  57. }
  58. }, 100*time.Millisecond, 10*time.Millisecond)
  59. }
  60. func (s *ServerTestSuite) TestShutdown() {
  61. _, cancel := context.WithCancel(context.Background())
  62. defer cancel()
  63. server, err := Start(cancel, s.blankRouter)
  64. s.Require().NoError(err)
  65. s.NotNil(server)
  66. // Test graceful shutdown
  67. shutdownCtx, shutdownCancel := context.WithTimeout(s.T().Context(), 10*time.Second)
  68. defer shutdownCancel()
  69. // Should not panic or hang
  70. s.NotPanics(func() {
  71. server.Shutdown(shutdownCtx)
  72. })
  73. }
  74. func (s *ServerTestSuite) TestWithCORS() {
  75. tests := []struct {
  76. name string
  77. corsAllowOrigin string
  78. expectedOrigin string
  79. expectedMethods string
  80. }{
  81. {
  82. name: "WithCORSOrigin",
  83. corsAllowOrigin: "https://example.com",
  84. expectedOrigin: "https://example.com",
  85. expectedMethods: "GET, OPTIONS",
  86. },
  87. {
  88. name: "NoCORSOrigin",
  89. corsAllowOrigin: "",
  90. expectedOrigin: "",
  91. expectedMethods: "",
  92. },
  93. }
  94. for _, tt := range tests {
  95. s.Run(tt.name, func() {
  96. config := NewDefaultConfig()
  97. config.CORSAllowOrigin = tt.corsAllowOrigin
  98. router, err := NewRouter(&config)
  99. s.Require().NoError(err)
  100. wrappedHandler := router.WithCORS(s.mockHandler)
  101. req := httptest.NewRequest("GET", "/test", nil)
  102. rw := httptest.NewRecorder()
  103. wrappedHandler("test-req-id", s.wrapRW(rw), req)
  104. s.Equal(tt.expectedOrigin, rw.Header().Get(httpheaders.AccessControlAllowOrigin))
  105. s.Equal(tt.expectedMethods, rw.Header().Get(httpheaders.AccessControlAllowMethods))
  106. })
  107. }
  108. }
  109. func (s *ServerTestSuite) TestWithSecret() {
  110. tests := []struct {
  111. name string
  112. secret string
  113. authHeader string
  114. expectError bool
  115. }{
  116. {
  117. name: "ValidSecret",
  118. secret: "test-secret",
  119. authHeader: "Bearer test-secret",
  120. },
  121. {
  122. name: "InvalidSecret",
  123. secret: "foo-secret",
  124. authHeader: "Bearer wrong-secret",
  125. expectError: true,
  126. },
  127. {
  128. name: "NoSecretConfigured",
  129. secret: "",
  130. authHeader: "",
  131. },
  132. }
  133. for _, tt := range tests {
  134. s.Run(tt.name, func() {
  135. config := NewDefaultConfig()
  136. config.Secret = tt.secret
  137. router, err := NewRouter(&config)
  138. s.Require().NoError(err)
  139. wrappedHandler := router.WithSecret(s.mockHandler)
  140. req := httptest.NewRequest("GET", "/test", nil)
  141. if tt.authHeader != "" {
  142. req.Header.Set(httpheaders.Authorization, tt.authHeader)
  143. }
  144. rw := httptest.NewRecorder()
  145. err = wrappedHandler("test-req-id", s.wrapRW(rw), req)
  146. if tt.expectError {
  147. s.Require().Error(err)
  148. } else {
  149. s.Require().NoError(err)
  150. }
  151. })
  152. }
  153. }
  154. func (s *ServerTestSuite) TestIntoSuccess() {
  155. mockHandler := func(reqID string, rw ResponseWriter, r *http.Request) error {
  156. rw.WriteHeader(http.StatusOK)
  157. return nil
  158. }
  159. wrappedHandler := s.blankRouter.WithReportError(mockHandler)
  160. req := httptest.NewRequest("GET", "/test", nil)
  161. rw := httptest.NewRecorder()
  162. wrappedHandler("test-req-id", s.wrapRW(rw), req)
  163. s.Equal(http.StatusOK, rw.Code)
  164. }
  165. func (s *ServerTestSuite) TestIntoWithError() {
  166. testError := errors.New("test error")
  167. mockHandler := func(reqID string, rw ResponseWriter, r *http.Request) error {
  168. return testError
  169. }
  170. wrappedHandler := s.blankRouter.WithReportError(mockHandler)
  171. req := httptest.NewRequest("GET", "/test", nil)
  172. rw := httptest.NewRecorder()
  173. wrappedHandler("test-req-id", s.wrapRW(rw), req)
  174. s.Equal(http.StatusInternalServerError, rw.Code)
  175. s.Equal("text/plain", rw.Header().Get(httpheaders.ContentType))
  176. }
  177. func (s *ServerTestSuite) TestIntoPanicWithError() {
  178. testError := errors.New("panic error")
  179. mockHandler := func(reqID string, rw ResponseWriter, r *http.Request) error {
  180. panic(testError)
  181. }
  182. wrappedHandler := s.blankRouter.WithPanic(mockHandler)
  183. req := httptest.NewRequest("GET", "/test", nil)
  184. rw := httptest.NewRecorder()
  185. s.NotPanics(func() {
  186. err := wrappedHandler("test-req-id", s.wrapRW(rw), req)
  187. s.Require().Error(err, "panic error")
  188. })
  189. s.Equal(http.StatusOK, rw.Code)
  190. }
  191. func (s *ServerTestSuite) TestIntoPanicWithAbortHandler() {
  192. mockHandler := func(reqID string, rw ResponseWriter, r *http.Request) error {
  193. panic(http.ErrAbortHandler)
  194. }
  195. wrappedHandler := s.blankRouter.WithPanic(mockHandler)
  196. req := httptest.NewRequest("GET", "/test", nil)
  197. rw := httptest.NewRecorder()
  198. // Should re-panic with ErrAbortHandler
  199. s.Panics(func() {
  200. wrappedHandler("test-req-id", s.wrapRW(rw), req)
  201. })
  202. }
  203. func (s *ServerTestSuite) TestIntoPanicWithNonError() {
  204. mockHandler := func(reqID string, rw ResponseWriter, r *http.Request) error {
  205. panic("string panic")
  206. }
  207. wrappedHandler := s.blankRouter.WithPanic(mockHandler)
  208. req := httptest.NewRequest("GET", "/test", nil)
  209. rw := httptest.NewRecorder()
  210. // Should re-panic with non-error panics
  211. s.NotPanics(func() {
  212. err := wrappedHandler("test-req-id", s.wrapRW(rw), req)
  213. s.Require().Error(err, "string panic")
  214. })
  215. }
  216. func TestServerTestSuite(t *testing.T) {
  217. suite.Run(t, new(ServerTestSuite))
  218. }