router_test.go 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321
  1. package server
  2. import (
  3. "net/http"
  4. "net/http/httptest"
  5. "testing"
  6. "github.com/stretchr/testify/suite"
  7. "github.com/imgproxy/imgproxy/v3/httpheaders"
  8. "github.com/imgproxy/imgproxy/v3/monitoring"
  9. )
  10. type RouterTestSuite struct {
  11. suite.Suite
  12. router *Router
  13. }
  14. func (s *RouterTestSuite) SetupTest() {
  15. c := NewDefaultConfig()
  16. mc := monitoring.NewDefaultConfig()
  17. m, err := monitoring.New(s.T().Context(), &mc, 1)
  18. s.Require().NoError(err)
  19. c.PathPrefix = "/api"
  20. r, err := NewRouter(&c, m)
  21. s.Require().NoError(err)
  22. s.router = r
  23. }
  24. // TestHTTPMethods tests route methods registration and HTTP requests
  25. func (s *RouterTestSuite) TestHTTPMethods() {
  26. var capturedMethod string
  27. var capturedPath string
  28. getHandler := func(reqID string, rw ResponseWriter, req *http.Request) error {
  29. capturedMethod = req.Method
  30. capturedPath = req.URL.Path
  31. rw.WriteHeader(200)
  32. rw.Write([]byte("GET response"))
  33. return nil
  34. }
  35. optionsHandler := func(reqID string, rw ResponseWriter, req *http.Request) error {
  36. capturedMethod = req.Method
  37. capturedPath = req.URL.Path
  38. rw.WriteHeader(200)
  39. rw.Write([]byte("OPTIONS response"))
  40. return nil
  41. }
  42. headHandler := func(reqID string, rw ResponseWriter, req *http.Request) error {
  43. capturedMethod = req.Method
  44. capturedPath = req.URL.Path
  45. rw.WriteHeader(200)
  46. return nil
  47. }
  48. // Register routes with different configurations
  49. s.router.GET("/get-test", getHandler) // exact match
  50. s.router.OPTIONS("/options-test*", optionsHandler) // prefix match
  51. s.router.HEAD("/head-test", headHandler) // exact match
  52. tests := []struct {
  53. name string
  54. requestMethod string
  55. requestPath string
  56. expectedBody string
  57. expectedPath string
  58. }{
  59. {
  60. name: "GET",
  61. requestMethod: http.MethodGet,
  62. requestPath: "/api/get-test",
  63. expectedBody: "GET response",
  64. expectedPath: "/api/get-test",
  65. },
  66. {
  67. name: "OPTIONS",
  68. requestMethod: http.MethodOptions,
  69. requestPath: "/api/options-test",
  70. expectedBody: "OPTIONS response",
  71. expectedPath: "/api/options-test",
  72. },
  73. {
  74. name: "OPTIONSPrefixed",
  75. requestMethod: http.MethodOptions,
  76. requestPath: "/api/options-test/sub",
  77. expectedBody: "OPTIONS response",
  78. expectedPath: "/api/options-test/sub",
  79. },
  80. {
  81. name: "HEAD",
  82. requestMethod: http.MethodHead,
  83. requestPath: "/api/head-test",
  84. expectedBody: "",
  85. expectedPath: "/api/head-test",
  86. },
  87. }
  88. for _, tt := range tests {
  89. s.Run(tt.name, func() {
  90. req := httptest.NewRequest(tt.requestMethod, tt.requestPath, nil)
  91. rw := httptest.NewRecorder()
  92. s.router.ServeHTTP(rw, req)
  93. s.Require().Equal(tt.expectedBody, rw.Body.String())
  94. s.Require().Equal(tt.requestMethod, capturedMethod)
  95. s.Require().Equal(tt.expectedPath, capturedPath)
  96. })
  97. }
  98. }
  99. // TestMiddlewareOrder checks middleware ordering and functionality
  100. func (s *RouterTestSuite) TestMiddlewareOrder() {
  101. var order []string
  102. middleware1 := func(next RouteHandler) RouteHandler {
  103. return func(reqID string, rw ResponseWriter, req *http.Request) error {
  104. order = append(order, "middleware1")
  105. return next(reqID, rw, req)
  106. }
  107. }
  108. middleware2 := func(next RouteHandler) RouteHandler {
  109. return func(reqID string, rw ResponseWriter, req *http.Request) error {
  110. order = append(order, "middleware2")
  111. return next(reqID, rw, req)
  112. }
  113. }
  114. handler := func(reqID string, rw ResponseWriter, req *http.Request) error {
  115. order = append(order, "handler")
  116. rw.WriteHeader(200)
  117. return nil
  118. }
  119. s.router.GET("/test", handler, middleware2, middleware1)
  120. req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
  121. rw := httptest.NewRecorder()
  122. s.router.ServeHTTP(rw, req)
  123. // Middleware should execute in the order they are passed (first added first)
  124. s.Require().Equal([]string{"middleware1", "middleware2", "handler"}, order)
  125. }
  126. // TestServeHTTP tests ServeHTTP method
  127. func (s *RouterTestSuite) TestServeHTTP() {
  128. handler := func(reqID string, rw ResponseWriter, req *http.Request) error {
  129. rw.Header().Set("Custom-Header", "test-value")
  130. rw.WriteHeader(200)
  131. rw.Write([]byte("success"))
  132. return nil
  133. }
  134. s.router.GET("/test", handler)
  135. req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
  136. rw := httptest.NewRecorder()
  137. s.router.ServeHTTP(rw, req)
  138. s.Require().Equal(200, rw.Code)
  139. s.Require().Equal("success", rw.Body.String())
  140. s.Require().Equal("test-value", rw.Header().Get("Custom-Header"))
  141. s.Require().Equal(defaultServerName, rw.Header().Get(httpheaders.Server))
  142. s.Require().NotEmpty(rw.Header().Get(httpheaders.XRequestID))
  143. }
  144. // TestRequestID checks request ID generation and validation
  145. func (s *RouterTestSuite) TestRequestID() {
  146. handler := func(reqID string, rw ResponseWriter, req *http.Request) error {
  147. rw.WriteHeader(200)
  148. return nil
  149. }
  150. s.router.GET("/test", handler)
  151. // Test request ID passthrough (if present)
  152. req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
  153. req.Header.Set(httpheaders.XRequestID, "valid-id-123")
  154. rw := httptest.NewRecorder()
  155. s.router.ServeHTTP(rw, req)
  156. s.Require().Equal("valid-id-123", rw.Header().Get(httpheaders.XRequestID))
  157. // Test invalid request ID (should generate a new one)
  158. req = httptest.NewRequest(http.MethodGet, "/api/test", nil)
  159. req.Header.Set(httpheaders.XRequestID, "invalid id with spaces!")
  160. rw = httptest.NewRecorder()
  161. s.router.ServeHTTP(rw, req)
  162. generatedID := rw.Header().Get(httpheaders.XRequestID)
  163. s.Require().NotEqual("invalid id with spaces!", generatedID)
  164. s.Require().NotEmpty(generatedID)
  165. // Test no request ID (should generate a new one)
  166. req = httptest.NewRequest(http.MethodGet, "/api/test", nil)
  167. rw = httptest.NewRecorder()
  168. s.router.ServeHTTP(rw, req)
  169. generatedID = rw.Header().Get(httpheaders.XRequestID)
  170. s.Require().NotEmpty(generatedID)
  171. s.Require().Regexp(`^[A-Za-z0-9_\-]+$`, generatedID)
  172. }
  173. // TestLambdaRequestIDExtraction checks AWS lambda request id extraction
  174. func (s *RouterTestSuite) TestLambdaRequestIDExtraction() {
  175. handler := func(reqID string, rw ResponseWriter, req *http.Request) error {
  176. rw.WriteHeader(200)
  177. return nil
  178. }
  179. s.router.GET("/test", handler)
  180. // Test with valid Lambda context
  181. req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
  182. req.Header.Set(httpheaders.XAmznRequestContextHeader, `{"requestId":"lambda-req-123"}`)
  183. rw := httptest.NewRecorder()
  184. s.router.ServeHTTP(rw, req)
  185. s.Require().Equal("lambda-req-123", rw.Header().Get(httpheaders.XRequestID))
  186. }
  187. // Test IP address handling
  188. func (s *RouterTestSuite) TestReplaceIP() {
  189. var capturedRemoteAddr string
  190. handler := func(reqID string, rw ResponseWriter, req *http.Request) error {
  191. capturedRemoteAddr = req.RemoteAddr
  192. rw.WriteHeader(200)
  193. return nil
  194. }
  195. s.router.GET("/test", handler)
  196. tests := []struct {
  197. name string
  198. originalAddr string
  199. headers map[string]string
  200. expectedAddr string
  201. }{
  202. {
  203. name: "CFConnectingIP",
  204. originalAddr: "original:8080",
  205. headers: map[string]string{
  206. httpheaders.CFConnectingIP: "1.2.3.4",
  207. },
  208. expectedAddr: "1.2.3.4:8080",
  209. },
  210. {
  211. name: "XForwardedForMulti",
  212. originalAddr: "original:8080",
  213. headers: map[string]string{
  214. httpheaders.XForwardedFor: "5.6.7.8, 9.10.11.12",
  215. },
  216. expectedAddr: "5.6.7.8:8080",
  217. },
  218. {
  219. name: "XForwardedForSingle",
  220. originalAddr: "original:8080",
  221. headers: map[string]string{
  222. httpheaders.XForwardedFor: "13.14.15.16",
  223. },
  224. expectedAddr: "13.14.15.16:8080",
  225. },
  226. {
  227. name: "XRealIP",
  228. originalAddr: "original:8080",
  229. headers: map[string]string{
  230. httpheaders.XRealIP: "17.18.19.20",
  231. },
  232. expectedAddr: "17.18.19.20:8080",
  233. },
  234. }
  235. for _, tt := range tests {
  236. s.Run(tt.name, func() {
  237. req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
  238. req.RemoteAddr = tt.originalAddr
  239. for header, value := range tt.headers {
  240. req.Header.Set(header, value)
  241. }
  242. rw := httptest.NewRecorder()
  243. s.router.ServeHTTP(rw, req)
  244. s.Require().Equal(tt.expectedAddr, capturedRemoteAddr)
  245. })
  246. }
  247. }
  248. // TestRouteOrder checks exact/non-exact insertion order
  249. func (s *RouterTestSuite) TestRouteOrder() {
  250. h := func(reqID string, rw ResponseWriter, req *http.Request) error {
  251. return nil
  252. }
  253. s.router.GET("/test*", h)
  254. s.router.GET("/test/path", h)
  255. s.router.GET("/test/path/nested", h)
  256. s.Require().Equal("/api/test/path", s.router.routes[0].path)
  257. s.Require().Equal("/api/test/path/nested", s.router.routes[1].path)
  258. s.Require().Equal("/api/test", s.router.routes[2].path)
  259. }
  260. func TestRouterSuite(t *testing.T) {
  261. suite.Run(t, new(RouterTestSuite))
  262. }