1
0

server_test.go 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271
  1. package server
  2. import (
  3. "context"
  4. "errors"
  5. "net/http"
  6. "net/http/httptest"
  7. "sync/atomic"
  8. "testing"
  9. "time"
  10. "github.com/imgproxy/imgproxy/v3/httpheaders"
  11. "github.com/stretchr/testify/suite"
  12. )
  13. type ServerTestSuite struct {
  14. suite.Suite
  15. config *Config
  16. blankRouter *Router
  17. }
  18. func (s *ServerTestSuite) SetupTest() {
  19. c := NewDefaultConfig()
  20. s.config = &c
  21. s.config.Bind = "127.0.0.1:0" // Use port 0 for auto-assignment
  22. r, err := NewRouter(s.config)
  23. s.Require().NoError(err)
  24. s.blankRouter = r
  25. }
  26. func (s *ServerTestSuite) mockHandler(reqID string, rw http.ResponseWriter, r *http.Request) error {
  27. return nil
  28. }
  29. func (s *ServerTestSuite) TestStartServerWithInvalidBind() {
  30. ctx, cancel := context.WithCancel(s.T().Context())
  31. // Track if cancel was called using atomic
  32. var cancelCalled atomic.Bool
  33. cancelWrapper := func() {
  34. cancel()
  35. cancelCalled.Store(true)
  36. }
  37. invalidConfig := NewDefaultConfig()
  38. invalidConfig.Bind = "-1.-1.-1.-1" // Invalid address
  39. r, err := NewRouter(&invalidConfig)
  40. s.Require().NoError(err)
  41. server, err := Start(cancelWrapper, r)
  42. s.Require().Error(err)
  43. s.Nil(server)
  44. s.Contains(err.Error(), "can't start server")
  45. // Check if cancel was called using Eventually
  46. s.Require().Eventually(cancelCalled.Load, 100*time.Millisecond, 10*time.Millisecond)
  47. // Also verify the context was cancelled
  48. s.Require().Eventually(func() bool {
  49. select {
  50. case <-ctx.Done():
  51. return true
  52. default:
  53. return false
  54. }
  55. }, 100*time.Millisecond, 10*time.Millisecond)
  56. }
  57. func (s *ServerTestSuite) TestShutdown() {
  58. _, cancel := context.WithCancel(context.Background())
  59. defer cancel()
  60. server, err := Start(cancel, s.blankRouter)
  61. s.Require().NoError(err)
  62. s.NotNil(server)
  63. // Test graceful shutdown
  64. shutdownCtx, shutdownCancel := context.WithTimeout(s.T().Context(), 10*time.Second)
  65. defer shutdownCancel()
  66. // Should not panic or hang
  67. s.NotPanics(func() {
  68. server.Shutdown(shutdownCtx)
  69. })
  70. }
  71. func (s *ServerTestSuite) TestWithCORS() {
  72. tests := []struct {
  73. name string
  74. corsAllowOrigin string
  75. expectedOrigin string
  76. expectedMethods string
  77. }{
  78. {
  79. name: "WithCORSOrigin",
  80. corsAllowOrigin: "https://example.com",
  81. expectedOrigin: "https://example.com",
  82. expectedMethods: "GET, OPTIONS",
  83. },
  84. {
  85. name: "NoCORSOrigin",
  86. corsAllowOrigin: "",
  87. expectedOrigin: "",
  88. expectedMethods: "",
  89. },
  90. }
  91. for _, tt := range tests {
  92. s.Run(tt.name, func() {
  93. config := NewDefaultConfig()
  94. config.CORSAllowOrigin = tt.corsAllowOrigin
  95. router, err := NewRouter(&config)
  96. s.Require().NoError(err)
  97. wrappedHandler := router.WithCORS(s.mockHandler)
  98. req := httptest.NewRequest("GET", "/test", nil)
  99. rw := httptest.NewRecorder()
  100. wrappedHandler("test-req-id", rw, req)
  101. s.Equal(tt.expectedOrigin, rw.Header().Get(httpheaders.AccessControlAllowOrigin))
  102. s.Equal(tt.expectedMethods, rw.Header().Get(httpheaders.AccessControlAllowMethods))
  103. })
  104. }
  105. }
  106. func (s *ServerTestSuite) TestWithSecret() {
  107. tests := []struct {
  108. name string
  109. secret string
  110. authHeader string
  111. expectError bool
  112. }{
  113. {
  114. name: "ValidSecret",
  115. secret: "test-secret",
  116. authHeader: "Bearer test-secret",
  117. },
  118. {
  119. name: "InvalidSecret",
  120. secret: "foo-secret",
  121. authHeader: "Bearer wrong-secret",
  122. expectError: true,
  123. },
  124. {
  125. name: "NoSecretConfigured",
  126. secret: "",
  127. authHeader: "",
  128. },
  129. }
  130. for _, tt := range tests {
  131. s.Run(tt.name, func() {
  132. config := NewDefaultConfig()
  133. config.Secret = tt.secret
  134. router, err := NewRouter(&config)
  135. s.Require().NoError(err)
  136. wrappedHandler := router.WithSecret(s.mockHandler)
  137. req := httptest.NewRequest("GET", "/test", nil)
  138. if tt.authHeader != "" {
  139. req.Header.Set(httpheaders.Authorization, tt.authHeader)
  140. }
  141. rw := httptest.NewRecorder()
  142. err = wrappedHandler("test-req-id", rw, req)
  143. if tt.expectError {
  144. s.Require().Error(err)
  145. } else {
  146. s.Require().NoError(err)
  147. }
  148. })
  149. }
  150. }
  151. func (s *ServerTestSuite) TestIntoSuccess() {
  152. mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error {
  153. rw.WriteHeader(http.StatusOK)
  154. return nil
  155. }
  156. wrappedHandler := s.blankRouter.WithReportError(mockHandler)
  157. req := httptest.NewRequest("GET", "/test", nil)
  158. rw := httptest.NewRecorder()
  159. wrappedHandler("test-req-id", rw, req)
  160. s.Equal(http.StatusOK, rw.Code)
  161. }
  162. func (s *ServerTestSuite) TestIntoWithError() {
  163. testError := errors.New("test error")
  164. mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error {
  165. return testError
  166. }
  167. wrappedHandler := s.blankRouter.WithReportError(mockHandler)
  168. req := httptest.NewRequest("GET", "/test", nil)
  169. rw := httptest.NewRecorder()
  170. wrappedHandler("test-req-id", rw, req)
  171. s.Equal(http.StatusInternalServerError, rw.Code)
  172. s.Equal("text/plain", rw.Header().Get(httpheaders.ContentType))
  173. }
  174. func (s *ServerTestSuite) TestIntoPanicWithError() {
  175. testError := errors.New("panic error")
  176. mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error {
  177. panic(testError)
  178. }
  179. wrappedHandler := s.blankRouter.WithPanic(mockHandler)
  180. req := httptest.NewRequest("GET", "/test", nil)
  181. rw := httptest.NewRecorder()
  182. s.NotPanics(func() {
  183. err := wrappedHandler("test-req-id", rw, req)
  184. s.Require().Error(err, "panic error")
  185. })
  186. s.Equal(http.StatusOK, rw.Code)
  187. }
  188. func (s *ServerTestSuite) TestIntoPanicWithAbortHandler() {
  189. mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error {
  190. panic(http.ErrAbortHandler)
  191. }
  192. wrappedHandler := s.blankRouter.WithPanic(mockHandler)
  193. req := httptest.NewRequest("GET", "/test", nil)
  194. rw := httptest.NewRecorder()
  195. // Should re-panic with ErrAbortHandler
  196. s.Panics(func() {
  197. wrappedHandler("test-req-id", rw, req)
  198. })
  199. }
  200. func (s *ServerTestSuite) TestIntoPanicWithNonError() {
  201. mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error {
  202. panic("string panic")
  203. }
  204. wrappedHandler := s.blankRouter.WithPanic(mockHandler)
  205. req := httptest.NewRequest("GET", "/test", nil)
  206. rw := httptest.NewRecorder()
  207. // Should re-panic with non-error panics
  208. s.NotPanics(func() {
  209. err := wrappedHandler("test-req-id", rw, req)
  210. s.Require().Error(err, "string panic")
  211. })
  212. }
  213. func TestServerTestSuite(t *testing.T) {
  214. suite.Run(t, new(ServerTestSuite))
  215. }