handler_test.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441
  1. package stream
  import (
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"os"
	"strconv"
	"strings"
	"testing"
	"time"

	"github.com/sirupsen/logrus"
	"github.com/stretchr/testify/suite"

	"github.com/imgproxy/imgproxy/v3/config"
	"github.com/imgproxy/imgproxy/v3/fetcher"
	"github.com/imgproxy/imgproxy/v3/httpheaders"
	"github.com/imgproxy/imgproxy/v3/options"
	"github.com/imgproxy/imgproxy/v3/server/responsewriter"
	"github.com/imgproxy/imgproxy/v3/testutil"
)
  20. type HandlerTestSuite struct {
  21. testutil.LazySuite
  22. testData *testutil.TestDataProvider
  23. rwConf testutil.LazyObj[*responsewriter.Config]
  24. rwFactory testutil.LazyObj[*responsewriter.Factory]
  25. config testutil.LazyObj[*Config]
  26. handler testutil.LazyObj[*Handler]
  27. testServer testutil.LazyTestServer
  28. }
  29. func (s *HandlerTestSuite) SetupSuite() {
  30. config.Reset()
  31. s.testData = testutil.NewTestDataProvider(s.T)
  32. s.rwConf, _ = testutil.NewLazySuiteObj(
  33. s,
  34. func() (*responsewriter.Config, error) {
  35. c := responsewriter.NewDefaultConfig()
  36. return &c, nil
  37. },
  38. )
  39. s.rwFactory, _ = testutil.NewLazySuiteObj(
  40. s,
  41. func() (*responsewriter.Factory, error) {
  42. return responsewriter.NewFactory(s.rwConf())
  43. },
  44. )
  45. s.config, _ = testutil.NewLazySuiteObj(
  46. s,
  47. func() (*Config, error) {
  48. c := NewDefaultConfig()
  49. return &c, nil
  50. },
  51. )
  52. s.handler, _ = testutil.NewLazySuiteObj(
  53. s,
  54. func() (*Handler, error) {
  55. fc := fetcher.NewDefaultConfig()
  56. fc.Transport.HTTP.AllowLoopbackSourceAddresses = true
  57. fetcher, err := fetcher.New(&fc)
  58. s.Require().NoError(err)
  59. return New(s.config(), fetcher)
  60. },
  61. )
  62. s.testServer, _ = testutil.NewLazySuiteTestServer(s)
  63. // Silence logs during tests
  64. logrus.SetOutput(io.Discard)
  65. }
  66. func (s *HandlerTestSuite) TearDownSuite() {
  67. logrus.SetOutput(os.Stdout)
  68. }
  69. func (s *HandlerTestSuite) SetupSubTest() {
  70. // We use t.Run() a lot, so we need to reset lazy objects at the beginning of each subtest
  71. s.ResetLazyObjects()
  72. }
  73. func (s *HandlerTestSuite) execute(
  74. imageURL string,
  75. header http.Header,
  76. po *options.ProcessingOptions,
  77. ) *http.Response {
  78. imageURL = s.testServer().URL() + imageURL
  79. req := httptest.NewRequest("GET", "/", nil)
  80. httpheaders.CopyAll(header, req.Header, true)
  81. ctx := s.T().Context()
  82. rw := httptest.NewRecorder()
  83. rww := s.rwFactory().NewWriter(rw)
  84. err := s.handler().Execute(ctx, req, imageURL, "test-req-id", po, rww)
  85. s.Require().NoError(err)
  86. return rw.Result()
  87. }
  88. // TestHandlerBasicRequest checks basic streaming request
  89. func (s *HandlerTestSuite) TestHandlerBasicRequest() {
  90. data := s.testData.Read("test1.png")
  91. s.testServer().SetHeaders(httpheaders.ContentType, "image/png").SetBody(data)
  92. res := s.execute("", nil, &options.ProcessingOptions{})
  93. s.Require().Equal(200, res.StatusCode)
  94. s.Require().Equal("image/png", res.Header.Get(httpheaders.ContentType))
  95. // Verify we get the original image data
  96. actual, err := io.ReadAll(res.Body)
  97. s.Require().NoError(err)
  98. s.Require().Equal(data, actual)
  99. }
  100. // TestHandlerResponseHeadersPassthrough checks that original response headers are
  101. // passed through to the client
  102. func (s *HandlerTestSuite) TestHandlerResponseHeadersPassthrough() {
  103. data := s.testData.Read("test1.png")
  104. contentLength := len(data)
  105. s.testServer().SetHeaders(
  106. httpheaders.ContentType, "image/png",
  107. httpheaders.ContentLength, strconv.Itoa(contentLength),
  108. httpheaders.AcceptRanges, "bytes",
  109. httpheaders.Etag, "etag",
  110. httpheaders.LastModified, "Wed, 21 Oct 2015 07:28:00 GMT",
  111. ).SetBody(data)
  112. res := s.execute("", nil, &options.ProcessingOptions{})
  113. s.Require().Equal(200, res.StatusCode)
  114. s.Require().Equal("image/png", res.Header.Get(httpheaders.ContentType))
  115. s.Require().Equal(strconv.Itoa(contentLength), res.Header.Get(httpheaders.ContentLength))
  116. s.Require().Equal("bytes", res.Header.Get(httpheaders.AcceptRanges))
  117. s.Require().Equal("etag", res.Header.Get(httpheaders.Etag))
  118. s.Require().Equal("Wed, 21 Oct 2015 07:28:00 GMT", res.Header.Get(httpheaders.LastModified))
  119. }
  120. // TestHandlerRequestHeadersPassthrough checks that original request headers are passed through
  121. // to the server
  122. func (s *HandlerTestSuite) TestHandlerRequestHeadersPassthrough() {
  123. etag := `"test-etag-123"`
  124. data := s.testData.Read("test1.png")
  125. s.testServer().
  126. SetBody(data).
  127. SetHeaders(httpheaders.Etag, etag).
  128. SetHook(func(r *http.Request, rw http.ResponseWriter) {
  129. // Verify that If-None-Match header is passed through
  130. s.Equal(etag, r.Header.Get(httpheaders.IfNoneMatch))
  131. s.Equal("gzip", r.Header.Get(httpheaders.AcceptEncoding))
  132. s.Equal("bytes=*", r.Header.Get(httpheaders.Range))
  133. })
  134. h := make(http.Header)
  135. h.Set(httpheaders.IfNoneMatch, etag)
  136. h.Set(httpheaders.AcceptEncoding, "gzip")
  137. h.Set(httpheaders.Range, "bytes=*")
  138. res := s.execute("", h, &options.ProcessingOptions{})
  139. s.Require().Equal(200, res.StatusCode)
  140. s.Require().Equal(etag, res.Header.Get(httpheaders.Etag))
  141. }
  142. // TestHandlerContentDisposition checks that Content-Disposition header is set correctly
  143. func (s *HandlerTestSuite) TestHandlerContentDisposition() {
  144. data := s.testData.Read("test1.png")
  145. s.testServer().SetHeaders(httpheaders.ContentType, "image/png").SetBody(data)
  146. po := &options.ProcessingOptions{
  147. Filename: "custom_name",
  148. ReturnAttachment: true,
  149. }
  150. // Use a URL with a .png extension to help content disposition logic
  151. res := s.execute("/test.png", nil, po)
  152. s.Require().Equal(200, res.StatusCode)
  153. s.Require().Contains(res.Header.Get(httpheaders.ContentDisposition), "custom_name.png")
  154. s.Require().Contains(res.Header.Get(httpheaders.ContentDisposition), "attachment")
  155. }
  156. // TestHandlerCacheControl checks that Cache-Control header is set correctly in different cases
  157. func (s *HandlerTestSuite) TestHandlerCacheControl() {
  158. type testCase struct {
  159. name string
  160. cacheControlPassthrough bool
  161. setupOriginHeaders func()
  162. timestampOffset *time.Duration // nil for no timestamp, otherwise the offset from now
  163. expectedStatusCode int
  164. validate func(*testing.T, *http.Response)
  165. }
  166. // Duration variables for test cases
  167. var (
  168. oneHour = time.Hour
  169. thirtyMinutes = 30 * time.Minute
  170. fortyFiveMinutes = 45 * time.Minute
  171. twoHours = time.Hour * 2
  172. oneMinuteDelta = float64(time.Minute)
  173. )
  174. defaultTTL := 4242
  175. testCases := []testCase{
  176. {
  177. name: "Passthrough",
  178. cacheControlPassthrough: true,
  179. setupOriginHeaders: func() {
  180. s.testServer().SetHeaders(httpheaders.CacheControl, "max-age=3600, public")
  181. },
  182. timestampOffset: nil,
  183. expectedStatusCode: 200,
  184. validate: func(t *testing.T, res *http.Response) {
  185. s.Require().Equal("max-age=3600, public", res.Header.Get(httpheaders.CacheControl))
  186. },
  187. },
  188. // Checks that expires gets convert to cache-control
  189. {
  190. name: "ExpiresPassthrough",
  191. cacheControlPassthrough: true,
  192. setupOriginHeaders: func() {
  193. s.testServer().SetHeaders(httpheaders.Expires, time.Now().Add(oneHour).UTC().Format(http.TimeFormat))
  194. },
  195. timestampOffset: nil,
  196. expectedStatusCode: 200,
  197. validate: func(t *testing.T, res *http.Response) {
  198. // When expires is converted to cache-control, the expires header should be empty
  199. s.Require().Empty(res.Header.Get(httpheaders.Expires))
  200. s.Require().InDelta(oneHour, s.maxAgeValue(res), oneMinuteDelta)
  201. },
  202. },
  203. // It would be set to something like default ttl
  204. {
  205. name: "PassthroughDisabled",
  206. cacheControlPassthrough: false,
  207. setupOriginHeaders: func() {
  208. s.testServer().SetHeaders(httpheaders.CacheControl, "max-age=3600, public")
  209. },
  210. timestampOffset: nil,
  211. expectedStatusCode: 200,
  212. validate: func(t *testing.T, res *http.Response) {
  213. s.Require().Equal(s.maxAgeValue(res), time.Duration(defaultTTL)*time.Second)
  214. },
  215. },
  216. // When expires is set in processing options, but not present in the response
  217. {
  218. name: "WithProcessingOptionsExpires",
  219. cacheControlPassthrough: false,
  220. timestampOffset: &oneHour,
  221. expectedStatusCode: 200,
  222. validate: func(t *testing.T, res *http.Response) {
  223. s.Require().InDelta(oneHour, s.maxAgeValue(res), oneMinuteDelta)
  224. },
  225. },
  226. // When expires is set in processing options, and is present in the response,
  227. // and passthrough is enabled
  228. {
  229. name: "ProcessingOptionsOverridesOrigin",
  230. cacheControlPassthrough: true,
  231. setupOriginHeaders: func() {
  232. // Origin has a longer cache time
  233. s.testServer().SetHeaders(httpheaders.CacheControl, "max-age=7200, public")
  234. },
  235. timestampOffset: &thirtyMinutes,
  236. expectedStatusCode: 200,
  237. validate: func(t *testing.T, res *http.Response) {
  238. s.Require().InDelta(thirtyMinutes, s.maxAgeValue(res), oneMinuteDelta)
  239. },
  240. },
  241. // When expires is not set in po, but both expires and cc are present in response,
  242. // and passthrough is enabled
  243. {
  244. name: "BothHeadersPassthroughEnabled",
  245. cacheControlPassthrough: true,
  246. setupOriginHeaders: func() {
  247. // Origin has both Cache-Control and Expires headers
  248. s.testServer().SetHeaders(httpheaders.CacheControl, "max-age=1800, public")
  249. s.testServer().SetHeaders(httpheaders.Expires, time.Now().Add(oneHour).UTC().Format(http.TimeFormat))
  250. },
  251. timestampOffset: nil,
  252. expectedStatusCode: 200,
  253. validate: func(t *testing.T, res *http.Response) {
  254. // Cache-Control should take precedence over Expires when both are present
  255. s.Require().InDelta(thirtyMinutes, s.maxAgeValue(res), oneMinuteDelta)
  256. s.Require().Empty(res.Header.Get(httpheaders.Expires))
  257. },
  258. },
  259. // When expires is set in PO AND both cache-control and expires are present in response,
  260. // and passthrough is enabled
  261. {
  262. name: "ProcessingOptionsOverridesBothOriginHeaders",
  263. cacheControlPassthrough: true,
  264. setupOriginHeaders: func() {
  265. // Origin has both Cache-Control and Expires headers with longer cache times
  266. s.testServer().SetHeaders(httpheaders.CacheControl, "max-age=7200, public")
  267. s.testServer().SetHeaders(httpheaders.Expires, time.Now().Add(twoHours).UTC().Format(http.TimeFormat))
  268. },
  269. timestampOffset: &fortyFiveMinutes, // Shorter than origin headers
  270. expectedStatusCode: 200,
  271. validate: func(t *testing.T, res *http.Response) {
  272. s.Require().InDelta(fortyFiveMinutes, s.maxAgeValue(res), oneMinuteDelta)
  273. s.Require().Empty(res.Header.Get(httpheaders.Expires))
  274. },
  275. },
  276. // No headers set
  277. {
  278. name: "NoOriginHeaders",
  279. cacheControlPassthrough: false,
  280. timestampOffset: nil,
  281. expectedStatusCode: 200,
  282. validate: func(t *testing.T, res *http.Response) {
  283. s.Require().Equal(s.maxAgeValue(res), time.Duration(defaultTTL)*time.Second)
  284. },
  285. },
  286. }
  287. for _, tc := range testCases {
  288. s.Run(tc.name, func() {
  289. data := s.testData.Read("test1.png")
  290. if tc.setupOriginHeaders != nil {
  291. tc.setupOriginHeaders()
  292. }
  293. s.testServer().SetHeaders(httpheaders.ContentType, "image/png").SetBody(data)
  294. s.rwConf().CacheControlPassthrough = tc.cacheControlPassthrough
  295. s.rwConf().DefaultTTL = 4242
  296. po := &options.ProcessingOptions{}
  297. if tc.timestampOffset != nil {
  298. expires := time.Now().Add(*tc.timestampOffset)
  299. po.Expires = &expires
  300. }
  301. res := s.execute("", nil, po)
  302. s.Require().Equal(tc.expectedStatusCode, res.StatusCode)
  303. tc.validate(s.T(), res)
  304. })
  305. }
  306. }
  307. // maxAgeValue parses max-age from cache-control
  308. func (s *HandlerTestSuite) maxAgeValue(res *http.Response) time.Duration {
  309. cacheControl := res.Header.Get(httpheaders.CacheControl)
  310. if cacheControl == "" {
  311. return 0
  312. }
  313. var maxAge int
  314. fmt.Sscanf(cacheControl, "max-age=%d", &maxAge)
  315. return time.Duration(maxAge) * time.Second
  316. }
  317. // TestHandlerSecurityHeaders tests the security headers set by the streaming service.
  318. func (s *HandlerTestSuite) TestHandlerSecurityHeaders() {
  319. data := s.testData.Read("test1.png")
  320. s.testServer().SetHeaders(httpheaders.ContentType, "image/png").SetBody(data)
  321. res := s.execute("", nil, &options.ProcessingOptions{})
  322. s.Require().Equal(http.StatusOK, res.StatusCode)
  323. s.Require().Equal("script-src 'none'", res.Header.Get(httpheaders.ContentSecurityPolicy))
  324. }
  325. // TestHandlerErrorResponse tests the error responses from the streaming service.
  326. func (s *HandlerTestSuite) TestHandlerErrorResponse() {
  327. s.testServer().SetStatusCode(http.StatusNotFound).SetBody([]byte("Not Found"))
  328. res := s.execute("", nil, &options.ProcessingOptions{})
  329. s.Require().Equal(http.StatusNotFound, res.StatusCode)
  330. }
  331. // TestHandlerCookiePassthrough tests the cookie passthrough behavior of the streaming service.
  332. func (s *HandlerTestSuite) TestHandlerCookiePassthrough() {
  333. s.config().CookiePassthrough = true
  334. data := s.testData.Read("test1.png")
  335. s.testServer().
  336. SetHeaders(httpheaders.Cookie, "test_cookie=test_value").
  337. SetHook(func(r *http.Request, rw http.ResponseWriter) {
  338. // Verify cookies are passed through
  339. cookie, cerr := r.Cookie("test_cookie")
  340. if cerr == nil {
  341. s.Equal("test_value", cookie.Value)
  342. }
  343. }).SetBody(data)
  344. h := make(http.Header)
  345. h.Set(httpheaders.Cookie, "test_cookie=test_value")
  346. res := s.execute("", h, &options.ProcessingOptions{})
  347. s.Require().Equal(200, res.StatusCode)
  348. }
  349. // TestHandlerCanonicalHeader tests that the canonical header is set correctly
  350. func (s *HandlerTestSuite) TestHandlerCanonicalHeader() {
  351. data := s.testData.Read("test1.png")
  352. s.testServer().SetHeaders(httpheaders.ContentType, "image/png").SetBody(data)
  353. for _, sc := range []bool{true, false} {
  354. s.rwConf().SetCanonicalHeader = sc
  355. res := s.execute("", nil, &options.ProcessingOptions{})
  356. s.Require().Equal(200, res.StatusCode)
  357. if sc {
  358. s.Require().Contains(res.Header.Get(httpheaders.Link), fmt.Sprintf(`<%s>; rel="canonical"`, s.testServer().URL()))
  359. } else {
  360. s.Require().Empty(res.Header.Get(httpheaders.Link))
  361. }
  362. }
  363. }
  364. func TestHandler(t *testing.T) {
  365. suite.Run(t, new(HandlerTestSuite))
  366. }