1
0

handler_test.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528
  1. package stream
import (
	"context"
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"testing"
	"time"

	"github.com/sirupsen/logrus"
	"github.com/stretchr/testify/suite"

	"github.com/imgproxy/imgproxy/v3/config"
	"github.com/imgproxy/imgproxy/v3/fetcher"
	"github.com/imgproxy/imgproxy/v3/headerwriter"
	"github.com/imgproxy/imgproxy/v3/httpheaders"
	"github.com/imgproxy/imgproxy/v3/options"
)
  21. const (
  22. testDataPath = "../../testdata"
  23. )
  24. type HandlerTestSuite struct {
  25. suite.Suite
  26. handler *Handler
  27. }
  28. func (s *HandlerTestSuite) SetupSuite() {
  29. config.Reset()
  30. config.AllowLoopbackSourceAddresses = true
  31. // Silence logs during tests
  32. logrus.SetOutput(io.Discard)
  33. }
  34. func (s *HandlerTestSuite) TearDownSuite() {
  35. config.Reset()
  36. logrus.SetOutput(os.Stdout)
  37. }
  38. func (s *HandlerTestSuite) SetupTest() {
  39. config.Reset()
  40. config.AllowLoopbackSourceAddresses = true
  41. fc := fetcher.NewDefaultConfig()
  42. fetcher, err := fetcher.New(&fc)
  43. s.Require().NoError(err)
  44. cfg := NewDefaultConfig()
  45. hwc := headerwriter.NewDefaultConfig()
  46. hw, err := headerwriter.New(&hwc)
  47. s.Require().NoError(err)
  48. h, err := New(&cfg, hw, fetcher)
  49. s.Require().NoError(err)
  50. s.handler = h
  51. }
  52. func (s *HandlerTestSuite) readTestFile(name string) []byte {
  53. data, err := os.ReadFile(filepath.Join(testDataPath, name))
  54. s.Require().NoError(err)
  55. return data
  56. }
  57. // TestHandlerBasicRequest checks basic streaming request
  58. func (s *HandlerTestSuite) TestHandlerBasicRequest() {
  59. data := s.readTestFile("test1.png")
  60. ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  61. w.Header().Set(httpheaders.ContentType, "image/png")
  62. w.WriteHeader(200)
  63. w.Write(data)
  64. }))
  65. defer ts.Close()
  66. req := httptest.NewRequest("GET", "/", nil)
  67. rw := httptest.NewRecorder()
  68. po := &options.ProcessingOptions{}
  69. err := s.handler.Execute(context.Background(), req, ts.URL, "request-1", po, rw)
  70. s.Require().NoError(err)
  71. res := rw.Result()
  72. s.Require().Equal(200, res.StatusCode)
  73. s.Require().Equal("image/png", res.Header.Get(httpheaders.ContentType))
  74. // Verify we get the original image data
  75. actual := rw.Body.Bytes()
  76. s.Require().Equal(data, actual)
  77. }
  78. // TestHandlerResponseHeadersPassthrough checks that original response headers are
  79. // passed through to the client
  80. func (s *HandlerTestSuite) TestHandlerResponseHeadersPassthrough() {
  81. data := s.readTestFile("test1.png")
  82. contentLength := len(data)
  83. ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  84. w.Header().Set(httpheaders.ContentType, "image/png")
  85. w.Header().Set(httpheaders.ContentLength, strconv.Itoa(contentLength))
  86. w.Header().Set(httpheaders.AcceptRanges, "bytes")
  87. w.Header().Set(httpheaders.Etag, "etag")
  88. w.Header().Set(httpheaders.LastModified, "Wed, 21 Oct 2015 07:28:00 GMT")
  89. w.WriteHeader(200)
  90. w.Write(data)
  91. }))
  92. defer ts.Close()
  93. req := httptest.NewRequest("GET", "/", nil)
  94. rw := httptest.NewRecorder()
  95. po := &options.ProcessingOptions{}
  96. err := s.handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
  97. s.Require().NoError(err)
  98. res := rw.Result()
  99. s.Require().Equal(200, res.StatusCode)
  100. s.Require().Equal("image/png", res.Header.Get(httpheaders.ContentType))
  101. s.Require().Equal(strconv.Itoa(contentLength), res.Header.Get(httpheaders.ContentLength))
  102. s.Require().Equal("bytes", res.Header.Get(httpheaders.AcceptRanges))
  103. s.Require().Equal("etag", res.Header.Get(httpheaders.Etag))
  104. s.Require().Equal("Wed, 21 Oct 2015 07:28:00 GMT", res.Header.Get(httpheaders.LastModified))
  105. }
  106. // TestHandlerRequestHeadersPassthrough checks that original request headers are passed through
  107. // to the server
  108. func (s *HandlerTestSuite) TestHandlerRequestHeadersPassthrough() {
  109. etag := `"test-etag-123"`
  110. data := s.readTestFile("test1.png")
  111. ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  112. // Verify that If-None-Match header is passed through
  113. s.Equal(etag, r.Header.Get(httpheaders.IfNoneMatch))
  114. s.Equal("gzip", r.Header.Get(httpheaders.AcceptEncoding))
  115. s.Equal("bytes=*", r.Header.Get(httpheaders.Range))
  116. w.Header().Set(httpheaders.Etag, etag)
  117. w.WriteHeader(200)
  118. w.Write(data)
  119. }))
  120. defer ts.Close()
  121. req := httptest.NewRequest("GET", "/", nil)
  122. req.Header.Set(httpheaders.IfNoneMatch, etag)
  123. req.Header.Set(httpheaders.AcceptEncoding, "gzip")
  124. req.Header.Set(httpheaders.Range, "bytes=*")
  125. rw := httptest.NewRecorder()
  126. po := &options.ProcessingOptions{}
  127. err := s.handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
  128. s.Require().NoError(err)
  129. res := rw.Result()
  130. s.Require().Equal(200, res.StatusCode)
  131. s.Require().Equal(etag, res.Header.Get(httpheaders.Etag))
  132. }
  133. // TestHandlerContentDisposition checks that Content-Disposition header is set correctly
  134. func (s *HandlerTestSuite) TestHandlerContentDisposition() {
  135. data := s.readTestFile("test1.png")
  136. ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  137. w.Header().Set(httpheaders.ContentType, "image/png")
  138. w.WriteHeader(200)
  139. w.Write(data)
  140. }))
  141. defer ts.Close()
  142. req := httptest.NewRequest("GET", "/", nil)
  143. rw := httptest.NewRecorder()
  144. po := &options.ProcessingOptions{
  145. Filename: "custom_name",
  146. ReturnAttachment: true,
  147. }
  148. // Use a URL with a .png extension to help content disposition logic
  149. imageURL := ts.URL + "/test.png"
  150. err := s.handler.Execute(context.Background(), req, imageURL, "test-req-id", po, rw)
  151. s.Require().NoError(err)
  152. res := rw.Result()
  153. s.Require().Equal(200, res.StatusCode)
  154. s.Require().Contains(res.Header.Get(httpheaders.ContentDisposition), "custom_name.png")
  155. s.Require().Contains(res.Header.Get(httpheaders.ContentDisposition), "attachment")
  156. }
  157. // TestHandlerCacheControl checks that Cache-Control header is set correctly in different cases
  158. func (s *HandlerTestSuite) TestHandlerCacheControl() {
  159. type testCase struct {
  160. name string
  161. cacheControlPassthrough bool
  162. setupOriginHeaders func(http.ResponseWriter)
  163. timestampOffset *time.Duration // nil for no timestamp, otherwise the offset from now
  164. expectedStatusCode int
  165. validate func(*testing.T, *http.Response)
  166. }
  167. // Duration variables for test cases
  168. var (
  169. oneHour = time.Hour
  170. thirtyMinutes = 30 * time.Minute
  171. fortyFiveMinutes = 45 * time.Minute
  172. twoHours = time.Hour * 2
  173. oneMinuteDelta = float64(time.Minute)
  174. )
  175. defaultTTL := 4242
  176. testCases := []testCase{
  177. {
  178. name: "Passthrough",
  179. cacheControlPassthrough: true,
  180. setupOriginHeaders: func(w http.ResponseWriter) {
  181. w.Header().Set(httpheaders.CacheControl, "max-age=3600, public")
  182. },
  183. timestampOffset: nil,
  184. expectedStatusCode: 200,
  185. validate: func(t *testing.T, res *http.Response) {
  186. s.Require().Equal("max-age=3600, public", res.Header.Get(httpheaders.CacheControl))
  187. },
  188. },
  189. // Checks that expires gets convert to cache-control
  190. {
  191. name: "ExpiresPassthrough",
  192. cacheControlPassthrough: true,
  193. setupOriginHeaders: func(w http.ResponseWriter) {
  194. w.Header().Set(httpheaders.Expires, time.Now().Add(oneHour).UTC().Format(http.TimeFormat))
  195. },
  196. timestampOffset: nil,
  197. expectedStatusCode: 200,
  198. validate: func(t *testing.T, res *http.Response) {
  199. // When expires is converted to cache-control, the expires header should be empty
  200. s.Require().Empty(res.Header.Get(httpheaders.Expires))
  201. s.Require().InDelta(oneHour, s.maxAgeValue(res), oneMinuteDelta)
  202. },
  203. },
  204. // It would be set to something like default ttl
  205. {
  206. name: "PassthroughDisabled",
  207. cacheControlPassthrough: false,
  208. setupOriginHeaders: func(w http.ResponseWriter) {
  209. w.Header().Set(httpheaders.CacheControl, "max-age=3600, public")
  210. },
  211. timestampOffset: nil,
  212. expectedStatusCode: 200,
  213. validate: func(t *testing.T, res *http.Response) {
  214. s.Require().Equal(s.maxAgeValue(res), time.Duration(defaultTTL)*time.Second)
  215. },
  216. },
  217. // When expires is set in processing options, but not present in the response
  218. {
  219. name: "WithProcessingOptionsExpires",
  220. cacheControlPassthrough: false,
  221. setupOriginHeaders: func(w http.ResponseWriter) {}, // No origin headers
  222. timestampOffset: &oneHour,
  223. expectedStatusCode: 200,
  224. validate: func(t *testing.T, res *http.Response) {
  225. s.Require().InDelta(oneHour, s.maxAgeValue(res), oneMinuteDelta)
  226. },
  227. },
  228. // When expires is set in processing options, and is present in the response,
  229. // and passthrough is enabled
  230. {
  231. name: "ProcessingOptionsOverridesOrigin",
  232. cacheControlPassthrough: true,
  233. setupOriginHeaders: func(w http.ResponseWriter) {
  234. // Origin has a longer cache time
  235. w.Header().Set(httpheaders.CacheControl, "max-age=7200, public")
  236. },
  237. timestampOffset: &thirtyMinutes,
  238. expectedStatusCode: 200,
  239. validate: func(t *testing.T, res *http.Response) {
  240. s.Require().InDelta(thirtyMinutes, s.maxAgeValue(res), oneMinuteDelta)
  241. },
  242. },
  243. // When expires is not set in po, but both expires and cc are present in response,
  244. // and passthrough is enabled
  245. {
  246. name: "BothHeadersPassthroughEnabled",
  247. cacheControlPassthrough: true,
  248. setupOriginHeaders: func(w http.ResponseWriter) {
  249. // Origin has both Cache-Control and Expires headers
  250. w.Header().Set(httpheaders.CacheControl, "max-age=1800, public")
  251. w.Header().Set(httpheaders.Expires, time.Now().Add(oneHour).UTC().Format(http.TimeFormat))
  252. },
  253. timestampOffset: nil,
  254. expectedStatusCode: 200,
  255. validate: func(t *testing.T, res *http.Response) {
  256. // Cache-Control should take precedence over Expires when both are present
  257. s.Require().InDelta(thirtyMinutes, s.maxAgeValue(res), oneMinuteDelta)
  258. s.Require().Empty(res.Header.Get(httpheaders.Expires))
  259. },
  260. },
  261. // When expires is set in PO AND both cache-control and expires are present in response,
  262. // and passthrough is enabled
  263. {
  264. name: "ProcessingOptionsOverridesBothOriginHeaders",
  265. cacheControlPassthrough: true,
  266. setupOriginHeaders: func(w http.ResponseWriter) {
  267. // Origin has both Cache-Control and Expires headers with longer cache times
  268. w.Header().Set(httpheaders.CacheControl, "max-age=7200, public")
  269. w.Header().Set(httpheaders.Expires, time.Now().Add(twoHours).UTC().Format(http.TimeFormat))
  270. },
  271. timestampOffset: &fortyFiveMinutes, // Shorter than origin headers
  272. expectedStatusCode: 200,
  273. validate: func(t *testing.T, res *http.Response) {
  274. s.Require().InDelta(fortyFiveMinutes, s.maxAgeValue(res), oneMinuteDelta)
  275. s.Require().Empty(res.Header.Get(httpheaders.Expires))
  276. },
  277. },
  278. // No headers set
  279. {
  280. name: "NoOriginHeaders",
  281. cacheControlPassthrough: false,
  282. setupOriginHeaders: func(w http.ResponseWriter) {}, // Origin has no cache headers
  283. timestampOffset: nil,
  284. expectedStatusCode: 200,
  285. validate: func(t *testing.T, res *http.Response) {
  286. s.Require().Equal(s.maxAgeValue(res), time.Duration(defaultTTL)*time.Second)
  287. },
  288. },
  289. }
  290. for _, tc := range testCases {
  291. s.Run(tc.name, func() {
  292. data := s.readTestFile("test1.png")
  293. ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  294. tc.setupOriginHeaders(w)
  295. w.Header().Set(httpheaders.ContentType, "image/png")
  296. w.WriteHeader(200)
  297. w.Write(data)
  298. }))
  299. defer ts.Close()
  300. fc, err := fetcher.LoadConfigFromEnv(nil)
  301. s.Require().NoError(err)
  302. fetcher, err := fetcher.New(fc)
  303. s.Require().NoError(err)
  304. cfg := NewDefaultConfig()
  305. hwc := headerwriter.NewDefaultConfig()
  306. hwc.CacheControlPassthrough = tc.cacheControlPassthrough
  307. hwc.DefaultTTL = 4242
  308. hw, err := headerwriter.New(&hwc)
  309. s.Require().NoError(err)
  310. handler, err := New(&cfg, hw, fetcher)
  311. s.Require().NoError(err)
  312. req := httptest.NewRequest("GET", "/", nil)
  313. rw := httptest.NewRecorder()
  314. po := &options.ProcessingOptions{}
  315. if tc.timestampOffset != nil {
  316. expires := time.Now().Add(*tc.timestampOffset)
  317. po.Expires = &expires
  318. }
  319. err = handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
  320. s.Require().NoError(err)
  321. res := rw.Result()
  322. s.Require().Equal(tc.expectedStatusCode, res.StatusCode)
  323. tc.validate(s.T(), res)
  324. })
  325. }
  326. }
  327. // maxAgeValue parses max-age from cache-control
  328. func (s *HandlerTestSuite) maxAgeValue(res *http.Response) time.Duration {
  329. cacheControl := res.Header.Get(httpheaders.CacheControl)
  330. if cacheControl == "" {
  331. return 0
  332. }
  333. var maxAge int
  334. fmt.Sscanf(cacheControl, "max-age=%d", &maxAge)
  335. return time.Duration(maxAge) * time.Second
  336. }
  337. // TestHandlerSecurityHeaders tests the security headers set by the streaming service.
  338. func (s *HandlerTestSuite) TestHandlerSecurityHeaders() {
  339. data := s.readTestFile("test1.png")
  340. ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  341. w.Header().Set(httpheaders.ContentType, "image/png")
  342. w.WriteHeader(200)
  343. w.Write(data)
  344. }))
  345. defer ts.Close()
  346. req := httptest.NewRequest("GET", "/", nil)
  347. rw := httptest.NewRecorder()
  348. po := &options.ProcessingOptions{}
  349. err := s.handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
  350. s.Require().NoError(err)
  351. res := rw.Result()
  352. s.Require().Equal(200, res.StatusCode)
  353. s.Require().Equal("script-src 'none'", res.Header.Get(httpheaders.ContentSecurityPolicy))
  354. }
  355. // TestHandlerErrorResponse tests the error responses from the streaming service.
  356. func (s *HandlerTestSuite) TestHandlerErrorResponse() {
  357. ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  358. w.WriteHeader(404)
  359. w.Write([]byte("Not Found"))
  360. }))
  361. defer ts.Close()
  362. req := httptest.NewRequest("GET", "/", nil)
  363. rw := httptest.NewRecorder()
  364. po := &options.ProcessingOptions{}
  365. err := s.handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
  366. s.Require().NoError(err)
  367. res := rw.Result()
  368. s.Require().Equal(404, res.StatusCode)
  369. }
  370. // TestHandlerCookiePassthrough tests the cookie passthrough behavior of the streaming service.
  371. func (s *HandlerTestSuite) TestHandlerCookiePassthrough() {
  372. fc, err := fetcher.LoadConfigFromEnv(nil)
  373. s.Require().NoError(err)
  374. fetcher, err := fetcher.New(fc)
  375. s.Require().NoError(err)
  376. cfg := NewDefaultConfig()
  377. cfg.CookiePassthrough = true
  378. hwc := headerwriter.NewDefaultConfig()
  379. hw, err := headerwriter.New(&hwc)
  380. s.Require().NoError(err)
  381. handler, err := New(&cfg, hw, fetcher)
  382. s.Require().NoError(err)
  383. data := s.readTestFile("test1.png")
  384. ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  385. // Verify cookies are passed through
  386. cookie, cerr := r.Cookie("test_cookie")
  387. if cerr == nil {
  388. s.Equal("test_value", cookie.Value)
  389. }
  390. w.Header().Set(httpheaders.ContentType, "image/png")
  391. w.WriteHeader(200)
  392. w.Write(data)
  393. }))
  394. defer ts.Close()
  395. req := httptest.NewRequest("GET", "/", nil)
  396. req.Header.Set(httpheaders.Cookie, "test_cookie=test_value")
  397. rw := httptest.NewRecorder()
  398. po := &options.ProcessingOptions{}
  399. err = handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
  400. s.Require().NoError(err)
  401. res := rw.Result()
  402. s.Require().Equal(200, res.StatusCode)
  403. }
  404. // TestHandlerCanonicalHeader tests that the canonical header is set correctly
  405. func (s *HandlerTestSuite) TestHandlerCanonicalHeader() {
  406. data := s.readTestFile("test1.png")
  407. ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  408. w.Header().Set(httpheaders.ContentType, "image/png")
  409. w.WriteHeader(200)
  410. w.Write(data)
  411. }))
  412. defer ts.Close()
  413. for _, sc := range []bool{true, false} {
  414. fc, err := fetcher.LoadConfigFromEnv(nil)
  415. s.Require().NoError(err)
  416. fetcher, err := fetcher.New(fc)
  417. s.Require().NoError(err)
  418. cfg := NewDefaultConfig()
  419. hwc := headerwriter.NewDefaultConfig()
  420. hwc.SetCanonicalHeader = sc
  421. hw, err := headerwriter.New(&hwc)
  422. s.Require().NoError(err)
  423. handler, err := New(&cfg, hw, fetcher)
  424. s.Require().NoError(err)
  425. req := httptest.NewRequest("GET", "/", nil)
  426. rw := httptest.NewRecorder()
  427. po := &options.ProcessingOptions{}
  428. err = handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
  429. s.Require().NoError(err)
  430. res := rw.Result()
  431. s.Require().Equal(200, res.StatusCode)
  432. if sc {
  433. s.Require().Contains(res.Header.Get(httpheaders.Link), fmt.Sprintf(`<%s>; rel="canonical"`, ts.URL))
  434. } else {
  435. s.Require().Empty(res.Header.Get(httpheaders.Link))
  436. }
  437. }
  438. }
  439. func TestHandler(t *testing.T) {
  440. suite.Run(t, new(HandlerTestSuite))
  441. }