handler_test.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440
  1. package stream
  2. import (
  3. "fmt"
  4. "io"
  5. "net/http"
  6. "net/http/httptest"
  7. "strconv"
  8. "testing"
  9. "time"
  10. "github.com/stretchr/testify/suite"
  11. "github.com/imgproxy/imgproxy/v3/config"
  12. "github.com/imgproxy/imgproxy/v3/fetcher"
  13. "github.com/imgproxy/imgproxy/v3/httpheaders"
  14. "github.com/imgproxy/imgproxy/v3/logger"
  15. "github.com/imgproxy/imgproxy/v3/options"
  16. "github.com/imgproxy/imgproxy/v3/server/responsewriter"
  17. "github.com/imgproxy/imgproxy/v3/testutil"
  18. )
  19. type HandlerTestSuite struct {
  20. testutil.LazySuite
  21. testData *testutil.TestDataProvider
  22. rwConf testutil.LazyObj[*responsewriter.Config]
  23. rwFactory testutil.LazyObj[*responsewriter.Factory]
  24. config testutil.LazyObj[*Config]
  25. handler testutil.LazyObj[*Handler]
  26. testServer testutil.LazyTestServer
  27. }
  28. func (s *HandlerTestSuite) SetupSuite() {
  29. config.Reset()
  30. s.testData = testutil.NewTestDataProvider(s.T)
  31. s.rwConf, _ = testutil.NewLazySuiteObj(
  32. s,
  33. func() (*responsewriter.Config, error) {
  34. c := responsewriter.NewDefaultConfig()
  35. return &c, nil
  36. },
  37. )
  38. s.rwFactory, _ = testutil.NewLazySuiteObj(
  39. s,
  40. func() (*responsewriter.Factory, error) {
  41. return responsewriter.NewFactory(s.rwConf())
  42. },
  43. )
  44. s.config, _ = testutil.NewLazySuiteObj(
  45. s,
  46. func() (*Config, error) {
  47. c := NewDefaultConfig()
  48. return &c, nil
  49. },
  50. )
  51. s.handler, _ = testutil.NewLazySuiteObj(
  52. s,
  53. func() (*Handler, error) {
  54. fc := fetcher.NewDefaultConfig()
  55. fc.Transport.HTTP.AllowLoopbackSourceAddresses = true
  56. fetcher, err := fetcher.New(&fc)
  57. s.Require().NoError(err)
  58. return New(s.config(), fetcher)
  59. },
  60. )
  61. s.testServer, _ = testutil.NewLazySuiteTestServer(s)
  62. // Silence logs during tests
  63. logger.Mute()
  64. }
  65. func (s *HandlerTestSuite) TearDownSuite() {
  66. logger.Unmute()
  67. }
  68. func (s *HandlerTestSuite) SetupSubTest() {
  69. // We use t.Run() a lot, so we need to reset lazy objects at the beginning of each subtest
  70. s.ResetLazyObjects()
  71. }
  72. func (s *HandlerTestSuite) execute(
  73. imageURL string,
  74. header http.Header,
  75. po *options.ProcessingOptions,
  76. ) *http.Response {
  77. imageURL = s.testServer().URL() + imageURL
  78. req := httptest.NewRequest("GET", "/", nil)
  79. httpheaders.CopyAll(header, req.Header, true)
  80. ctx := s.T().Context()
  81. rw := httptest.NewRecorder()
  82. rww := s.rwFactory().NewWriter(rw)
  83. err := s.handler().Execute(ctx, req, imageURL, "test-req-id", po, rww)
  84. s.Require().NoError(err)
  85. return rw.Result()
  86. }
  87. // TestHandlerBasicRequest checks basic streaming request
  88. func (s *HandlerTestSuite) TestHandlerBasicRequest() {
  89. data := s.testData.Read("test1.png")
  90. s.testServer().SetHeaders(httpheaders.ContentType, "image/png").SetBody(data)
  91. res := s.execute("", nil, &options.ProcessingOptions{})
  92. s.Require().Equal(200, res.StatusCode)
  93. s.Require().Equal("image/png", res.Header.Get(httpheaders.ContentType))
  94. // Verify we get the original image data
  95. actual, err := io.ReadAll(res.Body)
  96. s.Require().NoError(err)
  97. s.Require().Equal(data, actual)
  98. }
  99. // TestHandlerResponseHeadersPassthrough checks that original response headers are
  100. // passed through to the client
  101. func (s *HandlerTestSuite) TestHandlerResponseHeadersPassthrough() {
  102. data := s.testData.Read("test1.png")
  103. contentLength := len(data)
  104. s.testServer().SetHeaders(
  105. httpheaders.ContentType, "image/png",
  106. httpheaders.ContentLength, strconv.Itoa(contentLength),
  107. httpheaders.AcceptRanges, "bytes",
  108. httpheaders.Etag, "etag",
  109. httpheaders.LastModified, "Wed, 21 Oct 2015 07:28:00 GMT",
  110. ).SetBody(data)
  111. res := s.execute("", nil, &options.ProcessingOptions{})
  112. s.Require().Equal(200, res.StatusCode)
  113. s.Require().Equal("image/png", res.Header.Get(httpheaders.ContentType))
  114. s.Require().Equal(strconv.Itoa(contentLength), res.Header.Get(httpheaders.ContentLength))
  115. s.Require().Equal("bytes", res.Header.Get(httpheaders.AcceptRanges))
  116. s.Require().Equal("etag", res.Header.Get(httpheaders.Etag))
  117. s.Require().Equal("Wed, 21 Oct 2015 07:28:00 GMT", res.Header.Get(httpheaders.LastModified))
  118. }
  119. // TestHandlerRequestHeadersPassthrough checks that original request headers are passed through
  120. // to the server
  121. func (s *HandlerTestSuite) TestHandlerRequestHeadersPassthrough() {
  122. etag := `"test-etag-123"`
  123. data := s.testData.Read("test1.png")
  124. s.testServer().
  125. SetBody(data).
  126. SetHeaders(httpheaders.Etag, etag).
  127. SetHook(func(r *http.Request, rw http.ResponseWriter) {
  128. // Verify that If-None-Match header is passed through
  129. s.Equal(etag, r.Header.Get(httpheaders.IfNoneMatch))
  130. s.Equal("gzip", r.Header.Get(httpheaders.AcceptEncoding))
  131. s.Equal("bytes=*", r.Header.Get(httpheaders.Range))
  132. })
  133. h := make(http.Header)
  134. h.Set(httpheaders.IfNoneMatch, etag)
  135. h.Set(httpheaders.AcceptEncoding, "gzip")
  136. h.Set(httpheaders.Range, "bytes=*")
  137. res := s.execute("", h, &options.ProcessingOptions{})
  138. s.Require().Equal(200, res.StatusCode)
  139. s.Require().Equal(etag, res.Header.Get(httpheaders.Etag))
  140. }
  141. // TestHandlerContentDisposition checks that Content-Disposition header is set correctly
  142. func (s *HandlerTestSuite) TestHandlerContentDisposition() {
  143. data := s.testData.Read("test1.png")
  144. s.testServer().SetHeaders(httpheaders.ContentType, "image/png").SetBody(data)
  145. po := &options.ProcessingOptions{
  146. Filename: "custom_name",
  147. ReturnAttachment: true,
  148. }
  149. // Use a URL with a .png extension to help content disposition logic
  150. res := s.execute("/test.png", nil, po)
  151. s.Require().Equal(200, res.StatusCode)
  152. s.Require().Contains(res.Header.Get(httpheaders.ContentDisposition), "custom_name.png")
  153. s.Require().Contains(res.Header.Get(httpheaders.ContentDisposition), "attachment")
  154. }
  155. // TestHandlerCacheControl checks that Cache-Control header is set correctly in different cases
  156. func (s *HandlerTestSuite) TestHandlerCacheControl() {
  157. type testCase struct {
  158. name string
  159. cacheControlPassthrough bool
  160. setupOriginHeaders func()
  161. timestampOffset *time.Duration // nil for no timestamp, otherwise the offset from now
  162. expectedStatusCode int
  163. validate func(*testing.T, *http.Response)
  164. }
  165. // Duration variables for test cases
  166. var (
  167. oneHour = time.Hour
  168. thirtyMinutes = 30 * time.Minute
  169. fortyFiveMinutes = 45 * time.Minute
  170. twoHours = time.Hour * 2
  171. oneMinuteDelta = float64(time.Minute)
  172. )
  173. defaultTTL := 4242
  174. testCases := []testCase{
  175. {
  176. name: "Passthrough",
  177. cacheControlPassthrough: true,
  178. setupOriginHeaders: func() {
  179. s.testServer().SetHeaders(httpheaders.CacheControl, "max-age=3600, public")
  180. },
  181. timestampOffset: nil,
  182. expectedStatusCode: 200,
  183. validate: func(t *testing.T, res *http.Response) {
  184. s.Require().Equal("max-age=3600, public", res.Header.Get(httpheaders.CacheControl))
  185. },
  186. },
  187. // Checks that expires gets convert to cache-control
  188. {
  189. name: "ExpiresPassthrough",
  190. cacheControlPassthrough: true,
  191. setupOriginHeaders: func() {
  192. s.testServer().SetHeaders(httpheaders.Expires, time.Now().Add(oneHour).UTC().Format(http.TimeFormat))
  193. },
  194. timestampOffset: nil,
  195. expectedStatusCode: 200,
  196. validate: func(t *testing.T, res *http.Response) {
  197. // When expires is converted to cache-control, the expires header should be empty
  198. s.Require().Empty(res.Header.Get(httpheaders.Expires))
  199. s.Require().InDelta(oneHour, s.maxAgeValue(res), oneMinuteDelta)
  200. },
  201. },
  202. // It would be set to something like default ttl
  203. {
  204. name: "PassthroughDisabled",
  205. cacheControlPassthrough: false,
  206. setupOriginHeaders: func() {
  207. s.testServer().SetHeaders(httpheaders.CacheControl, "max-age=3600, public")
  208. },
  209. timestampOffset: nil,
  210. expectedStatusCode: 200,
  211. validate: func(t *testing.T, res *http.Response) {
  212. s.Require().Equal(s.maxAgeValue(res), time.Duration(defaultTTL)*time.Second)
  213. },
  214. },
  215. // When expires is set in processing options, but not present in the response
  216. {
  217. name: "WithProcessingOptionsExpires",
  218. cacheControlPassthrough: false,
  219. timestampOffset: &oneHour,
  220. expectedStatusCode: 200,
  221. validate: func(t *testing.T, res *http.Response) {
  222. s.Require().InDelta(oneHour, s.maxAgeValue(res), oneMinuteDelta)
  223. },
  224. },
  225. // When expires is set in processing options, and is present in the response,
  226. // and passthrough is enabled
  227. {
  228. name: "ProcessingOptionsOverridesOrigin",
  229. cacheControlPassthrough: true,
  230. setupOriginHeaders: func() {
  231. // Origin has a longer cache time
  232. s.testServer().SetHeaders(httpheaders.CacheControl, "max-age=7200, public")
  233. },
  234. timestampOffset: &thirtyMinutes,
  235. expectedStatusCode: 200,
  236. validate: func(t *testing.T, res *http.Response) {
  237. s.Require().InDelta(thirtyMinutes, s.maxAgeValue(res), oneMinuteDelta)
  238. },
  239. },
  240. // When expires is not set in po, but both expires and cc are present in response,
  241. // and passthrough is enabled
  242. {
  243. name: "BothHeadersPassthroughEnabled",
  244. cacheControlPassthrough: true,
  245. setupOriginHeaders: func() {
  246. // Origin has both Cache-Control and Expires headers
  247. s.testServer().SetHeaders(httpheaders.CacheControl, "max-age=1800, public")
  248. s.testServer().SetHeaders(httpheaders.Expires, time.Now().Add(oneHour).UTC().Format(http.TimeFormat))
  249. },
  250. timestampOffset: nil,
  251. expectedStatusCode: 200,
  252. validate: func(t *testing.T, res *http.Response) {
  253. // Cache-Control should take precedence over Expires when both are present
  254. s.Require().InDelta(thirtyMinutes, s.maxAgeValue(res), oneMinuteDelta)
  255. s.Require().Empty(res.Header.Get(httpheaders.Expires))
  256. },
  257. },
  258. // When expires is set in PO AND both cache-control and expires are present in response,
  259. // and passthrough is enabled
  260. {
  261. name: "ProcessingOptionsOverridesBothOriginHeaders",
  262. cacheControlPassthrough: true,
  263. setupOriginHeaders: func() {
  264. // Origin has both Cache-Control and Expires headers with longer cache times
  265. s.testServer().SetHeaders(httpheaders.CacheControl, "max-age=7200, public")
  266. s.testServer().SetHeaders(httpheaders.Expires, time.Now().Add(twoHours).UTC().Format(http.TimeFormat))
  267. },
  268. timestampOffset: &fortyFiveMinutes, // Shorter than origin headers
  269. expectedStatusCode: 200,
  270. validate: func(t *testing.T, res *http.Response) {
  271. s.Require().InDelta(fortyFiveMinutes, s.maxAgeValue(res), oneMinuteDelta)
  272. s.Require().Empty(res.Header.Get(httpheaders.Expires))
  273. },
  274. },
  275. // No headers set
  276. {
  277. name: "NoOriginHeaders",
  278. cacheControlPassthrough: false,
  279. timestampOffset: nil,
  280. expectedStatusCode: 200,
  281. validate: func(t *testing.T, res *http.Response) {
  282. s.Require().Equal(s.maxAgeValue(res), time.Duration(defaultTTL)*time.Second)
  283. },
  284. },
  285. }
  286. for _, tc := range testCases {
  287. s.Run(tc.name, func() {
  288. data := s.testData.Read("test1.png")
  289. if tc.setupOriginHeaders != nil {
  290. tc.setupOriginHeaders()
  291. }
  292. s.testServer().SetHeaders(httpheaders.ContentType, "image/png").SetBody(data)
  293. s.rwConf().CacheControlPassthrough = tc.cacheControlPassthrough
  294. s.rwConf().DefaultTTL = 4242
  295. po := &options.ProcessingOptions{}
  296. if tc.timestampOffset != nil {
  297. expires := time.Now().Add(*tc.timestampOffset)
  298. po.Expires = &expires
  299. }
  300. res := s.execute("", nil, po)
  301. s.Require().Equal(tc.expectedStatusCode, res.StatusCode)
  302. tc.validate(s.T(), res)
  303. })
  304. }
  305. }
  306. // maxAgeValue parses max-age from cache-control
  307. func (s *HandlerTestSuite) maxAgeValue(res *http.Response) time.Duration {
  308. cacheControl := res.Header.Get(httpheaders.CacheControl)
  309. if cacheControl == "" {
  310. return 0
  311. }
  312. var maxAge int
  313. fmt.Sscanf(cacheControl, "max-age=%d", &maxAge)
  314. return time.Duration(maxAge) * time.Second
  315. }
  316. // TestHandlerSecurityHeaders tests the security headers set by the streaming service.
  317. func (s *HandlerTestSuite) TestHandlerSecurityHeaders() {
  318. data := s.testData.Read("test1.png")
  319. s.testServer().SetHeaders(httpheaders.ContentType, "image/png").SetBody(data)
  320. res := s.execute("", nil, &options.ProcessingOptions{})
  321. s.Require().Equal(http.StatusOK, res.StatusCode)
  322. s.Require().Equal("script-src 'none'", res.Header.Get(httpheaders.ContentSecurityPolicy))
  323. }
  324. // TestHandlerErrorResponse tests the error responses from the streaming service.
  325. func (s *HandlerTestSuite) TestHandlerErrorResponse() {
  326. s.testServer().SetStatusCode(http.StatusNotFound).SetBody([]byte("Not Found"))
  327. res := s.execute("", nil, &options.ProcessingOptions{})
  328. s.Require().Equal(http.StatusNotFound, res.StatusCode)
  329. }
  330. // TestHandlerCookiePassthrough tests the cookie passthrough behavior of the streaming service.
  331. func (s *HandlerTestSuite) TestHandlerCookiePassthrough() {
  332. s.config().CookiePassthrough = true
  333. data := s.testData.Read("test1.png")
  334. s.testServer().
  335. SetHeaders(httpheaders.Cookie, "test_cookie=test_value").
  336. SetHook(func(r *http.Request, rw http.ResponseWriter) {
  337. // Verify cookies are passed through
  338. cookie, cerr := r.Cookie("test_cookie")
  339. if cerr == nil {
  340. s.Equal("test_value", cookie.Value)
  341. }
  342. }).SetBody(data)
  343. h := make(http.Header)
  344. h.Set(httpheaders.Cookie, "test_cookie=test_value")
  345. res := s.execute("", h, &options.ProcessingOptions{})
  346. s.Require().Equal(200, res.StatusCode)
  347. }
  348. // TestHandlerCanonicalHeader tests that the canonical header is set correctly
  349. func (s *HandlerTestSuite) TestHandlerCanonicalHeader() {
  350. data := s.testData.Read("test1.png")
  351. s.testServer().SetHeaders(httpheaders.ContentType, "image/png").SetBody(data)
  352. for _, sc := range []bool{true, false} {
  353. s.rwConf().SetCanonicalHeader = sc
  354. res := s.execute("", nil, &options.ProcessingOptions{})
  355. s.Require().Equal(200, res.StatusCode)
  356. if sc {
  357. s.Require().Contains(res.Header.Get(httpheaders.Link), fmt.Sprintf(`<%s>; rel="canonical"`, s.testServer().URL()))
  358. } else {
  359. s.Require().Empty(res.Header.Get(httpheaders.Link))
  360. }
  361. }
  362. }
  363. func TestHandler(t *testing.T) {
  364. suite.Run(t, new(HandlerTestSuite))
  365. }