handler_test.go 17 KB


  1. package stream
  2. import (
  3. "context"
  4. "fmt"
  5. "io"
  6. "net/http"
  7. "net/http/httptest"
  8. "os"
  9. "path/filepath"
  10. "strconv"
  11. "testing"
  12. "time"
  13. "github.com/sirupsen/logrus"
  14. "github.com/stretchr/testify/suite"
  15. "github.com/imgproxy/imgproxy/v3/config"
  16. "github.com/imgproxy/imgproxy/v3/fetcher"
  17. "github.com/imgproxy/imgproxy/v3/headerwriter"
  18. "github.com/imgproxy/imgproxy/v3/httpheaders"
  19. "github.com/imgproxy/imgproxy/v3/options"
  20. "github.com/imgproxy/imgproxy/v3/transport"
  21. )
  // testDataPath is the shared fixtures directory, relative to the package
// directory (which is the working directory while `go test` runs).
const testDataPath = "../../testdata"
  25. type HandlerTestSuite struct {
  26. suite.Suite
  27. handler *Handler
  28. }
  29. func (s *HandlerTestSuite) SetupSuite() {
  30. config.Reset()
  31. config.AllowLoopbackSourceAddresses = true
  32. // Silence logs during tests
  33. logrus.SetOutput(io.Discard)
  34. }
  35. func (s *HandlerTestSuite) TearDownSuite() {
  36. config.Reset()
  37. logrus.SetOutput(os.Stdout)
  38. }
  39. func (s *HandlerTestSuite) SetupTest() {
  40. config.Reset()
  41. config.AllowLoopbackSourceAddresses = true
  42. trc, err := transport.LoadFromEnv(transport.NewDefaultConfig())
  43. s.Require().NoError(err)
  44. tr, err := transport.New(trc)
  45. s.Require().NoError(err)
  46. fc := fetcher.NewDefaultConfig()
  47. fetcher, err := fetcher.NewFetcher(tr, fc)
  48. s.Require().NoError(err)
  49. cfg := NewDefaultConfig()
  50. hwc := headerwriter.NewDefaultConfig()
  51. hw, err := headerwriter.New(hwc)
  52. s.Require().NoError(err)
  53. h, err := New(cfg, hw, fetcher)
  54. s.Require().NoError(err)
  55. s.handler = h
  56. }
  57. func (s *HandlerTestSuite) readTestFile(name string) []byte {
  58. data, err := os.ReadFile(filepath.Join(testDataPath, name))
  59. s.Require().NoError(err)
  60. return data
  61. }
  62. // TestHandlerBasicRequest checks basic streaming request
  63. func (s *HandlerTestSuite) TestHandlerBasicRequest() {
  64. data := s.readTestFile("test1.png")
  65. ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  66. w.Header().Set(httpheaders.ContentType, "image/png")
  67. w.WriteHeader(200)
  68. w.Write(data)
  69. }))
  70. defer ts.Close()
  71. req := httptest.NewRequest("GET", "/", nil)
  72. rw := httptest.NewRecorder()
  73. po := &options.ProcessingOptions{}
  74. err := s.handler.Execute(context.Background(), req, ts.URL, "request-1", po, rw)
  75. s.Require().NoError(err)
  76. res := rw.Result()
  77. s.Require().Equal(200, res.StatusCode)
  78. s.Require().Equal("image/png", res.Header.Get(httpheaders.ContentType))
  79. // Verify we get the original image data
  80. actual := rw.Body.Bytes()
  81. s.Require().Equal(data, actual)
  82. }
  83. // TestHandlerResponseHeadersPassthrough checks that original response headers are
  84. // passed through to the client
  85. func (s *HandlerTestSuite) TestHandlerResponseHeadersPassthrough() {
  86. data := s.readTestFile("test1.png")
  87. contentLength := len(data)
  88. ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  89. w.Header().Set(httpheaders.ContentType, "image/png")
  90. w.Header().Set(httpheaders.ContentLength, strconv.Itoa(contentLength))
  91. w.Header().Set(httpheaders.AcceptRanges, "bytes")
  92. w.Header().Set(httpheaders.Etag, "etag")
  93. w.Header().Set(httpheaders.LastModified, "Wed, 21 Oct 2015 07:28:00 GMT")
  94. w.WriteHeader(200)
  95. w.Write(data)
  96. }))
  97. defer ts.Close()
  98. req := httptest.NewRequest("GET", "/", nil)
  99. rw := httptest.NewRecorder()
  100. po := &options.ProcessingOptions{}
  101. err := s.handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
  102. s.Require().NoError(err)
  103. res := rw.Result()
  104. s.Require().Equal(200, res.StatusCode)
  105. s.Require().Equal("image/png", res.Header.Get(httpheaders.ContentType))
  106. s.Require().Equal(strconv.Itoa(contentLength), res.Header.Get(httpheaders.ContentLength))
  107. s.Require().Equal("bytes", res.Header.Get(httpheaders.AcceptRanges))
  108. s.Require().Equal("etag", res.Header.Get(httpheaders.Etag))
  109. s.Require().Equal("Wed, 21 Oct 2015 07:28:00 GMT", res.Header.Get(httpheaders.LastModified))
  110. }
  111. // TestHandlerRequestHeadersPassthrough checks that original request headers are passed through
  112. // to the server
  113. func (s *HandlerTestSuite) TestHandlerRequestHeadersPassthrough() {
  114. etag := `"test-etag-123"`
  115. data := s.readTestFile("test1.png")
  116. ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  117. // Verify that If-None-Match header is passed through
  118. s.Equal(etag, r.Header.Get(httpheaders.IfNoneMatch))
  119. s.Equal("gzip", r.Header.Get(httpheaders.AcceptEncoding))
  120. s.Equal("bytes=*", r.Header.Get(httpheaders.Range))
  121. w.Header().Set(httpheaders.Etag, etag)
  122. w.WriteHeader(200)
  123. w.Write(data)
  124. }))
  125. defer ts.Close()
  126. req := httptest.NewRequest("GET", "/", nil)
  127. req.Header.Set(httpheaders.IfNoneMatch, etag)
  128. req.Header.Set(httpheaders.AcceptEncoding, "gzip")
  129. req.Header.Set(httpheaders.Range, "bytes=*")
  130. rw := httptest.NewRecorder()
  131. po := &options.ProcessingOptions{}
  132. err := s.handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
  133. s.Require().NoError(err)
  134. res := rw.Result()
  135. s.Require().Equal(200, res.StatusCode)
  136. s.Require().Equal(etag, res.Header.Get(httpheaders.Etag))
  137. }
  138. // TestHandlerContentDisposition checks that Content-Disposition header is set correctly
  139. func (s *HandlerTestSuite) TestHandlerContentDisposition() {
  140. data := s.readTestFile("test1.png")
  141. ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  142. w.Header().Set(httpheaders.ContentType, "image/png")
  143. w.WriteHeader(200)
  144. w.Write(data)
  145. }))
  146. defer ts.Close()
  147. req := httptest.NewRequest("GET", "/", nil)
  148. rw := httptest.NewRecorder()
  149. po := &options.ProcessingOptions{
  150. Filename: "custom_name",
  151. ReturnAttachment: true,
  152. }
  153. // Use a URL with a .png extension to help content disposition logic
  154. imageURL := ts.URL + "/test.png"
  155. err := s.handler.Execute(context.Background(), req, imageURL, "test-req-id", po, rw)
  156. s.Require().NoError(err)
  157. res := rw.Result()
  158. s.Require().Equal(200, res.StatusCode)
  159. s.Require().Contains(res.Header.Get(httpheaders.ContentDisposition), "custom_name.png")
  160. s.Require().Contains(res.Header.Get(httpheaders.ContentDisposition), "attachment")
  161. }
  162. // TestHandlerCacheControl checks that Cache-Control header is set correctly in different cases
  163. func (s *HandlerTestSuite) TestHandlerCacheControl() {
  164. type testCase struct {
  165. name string
  166. cacheControlPassthrough bool
  167. setupOriginHeaders func(http.ResponseWriter)
  168. timestampOffset *time.Duration // nil for no timestamp, otherwise the offset from now
  169. expectedStatusCode int
  170. validate func(*testing.T, *http.Response)
  171. }
  172. // Duration variables for test cases
  173. var (
  174. oneHour = time.Hour
  175. thirtyMinutes = 30 * time.Minute
  176. fortyFiveMinutes = 45 * time.Minute
  177. twoHours = time.Hour * 2
  178. oneMinuteDelta = float64(time.Minute)
  179. )
  180. defaultTTL := 4242
  181. testCases := []testCase{
  182. {
  183. name: "Passthrough",
  184. cacheControlPassthrough: true,
  185. setupOriginHeaders: func(w http.ResponseWriter) {
  186. w.Header().Set(httpheaders.CacheControl, "max-age=3600, public")
  187. },
  188. timestampOffset: nil,
  189. expectedStatusCode: 200,
  190. validate: func(t *testing.T, res *http.Response) {
  191. s.Require().Equal("max-age=3600, public", res.Header.Get(httpheaders.CacheControl))
  192. },
  193. },
  194. // Checks that expires gets convert to cache-control
  195. {
  196. name: "ExpiresPassthrough",
  197. cacheControlPassthrough: true,
  198. setupOriginHeaders: func(w http.ResponseWriter) {
  199. w.Header().Set(httpheaders.Expires, time.Now().Add(oneHour).UTC().Format(http.TimeFormat))
  200. },
  201. timestampOffset: nil,
  202. expectedStatusCode: 200,
  203. validate: func(t *testing.T, res *http.Response) {
  204. // When expires is converted to cache-control, the expires header should be empty
  205. s.Require().Empty(res.Header.Get(httpheaders.Expires))
  206. s.Require().InDelta(oneHour, s.maxAgeValue(res), oneMinuteDelta)
  207. },
  208. },
  209. // It would be set to something like default ttl
  210. {
  211. name: "PassthroughDisabled",
  212. cacheControlPassthrough: false,
  213. setupOriginHeaders: func(w http.ResponseWriter) {
  214. w.Header().Set(httpheaders.CacheControl, "max-age=3600, public")
  215. },
  216. timestampOffset: nil,
  217. expectedStatusCode: 200,
  218. validate: func(t *testing.T, res *http.Response) {
  219. s.Require().Equal(s.maxAgeValue(res), time.Duration(defaultTTL)*time.Second)
  220. },
  221. },
  222. // When expires is set in processing options, but not present in the response
  223. {
  224. name: "WithProcessingOptionsExpires",
  225. cacheControlPassthrough: false,
  226. setupOriginHeaders: func(w http.ResponseWriter) {}, // No origin headers
  227. timestampOffset: &oneHour,
  228. expectedStatusCode: 200,
  229. validate: func(t *testing.T, res *http.Response) {
  230. s.Require().InDelta(oneHour, s.maxAgeValue(res), oneMinuteDelta)
  231. },
  232. },
  233. // When expires is set in processing options, and is present in the response,
  234. // and passthrough is enabled
  235. {
  236. name: "ProcessingOptionsOverridesOrigin",
  237. cacheControlPassthrough: true,
  238. setupOriginHeaders: func(w http.ResponseWriter) {
  239. // Origin has a longer cache time
  240. w.Header().Set(httpheaders.CacheControl, "max-age=7200, public")
  241. },
  242. timestampOffset: &thirtyMinutes,
  243. expectedStatusCode: 200,
  244. validate: func(t *testing.T, res *http.Response) {
  245. s.Require().InDelta(thirtyMinutes, s.maxAgeValue(res), oneMinuteDelta)
  246. },
  247. },
  248. // When expires is not set in po, but both expires and cc are present in response,
  249. // and passthrough is enabled
  250. {
  251. name: "BothHeadersPassthroughEnabled",
  252. cacheControlPassthrough: true,
  253. setupOriginHeaders: func(w http.ResponseWriter) {
  254. // Origin has both Cache-Control and Expires headers
  255. w.Header().Set(httpheaders.CacheControl, "max-age=1800, public")
  256. w.Header().Set(httpheaders.Expires, time.Now().Add(oneHour).UTC().Format(http.TimeFormat))
  257. },
  258. timestampOffset: nil,
  259. expectedStatusCode: 200,
  260. validate: func(t *testing.T, res *http.Response) {
  261. // Cache-Control should take precedence over Expires when both are present
  262. s.Require().InDelta(thirtyMinutes, s.maxAgeValue(res), oneMinuteDelta)
  263. s.Require().Empty(res.Header.Get(httpheaders.Expires))
  264. },
  265. },
  266. // When expires is set in PO AND both cache-control and expires are present in response,
  267. // and passthrough is enabled
  268. {
  269. name: "ProcessingOptionsOverridesBothOriginHeaders",
  270. cacheControlPassthrough: true,
  271. setupOriginHeaders: func(w http.ResponseWriter) {
  272. // Origin has both Cache-Control and Expires headers with longer cache times
  273. w.Header().Set(httpheaders.CacheControl, "max-age=7200, public")
  274. w.Header().Set(httpheaders.Expires, time.Now().Add(twoHours).UTC().Format(http.TimeFormat))
  275. },
  276. timestampOffset: &fortyFiveMinutes, // Shorter than origin headers
  277. expectedStatusCode: 200,
  278. validate: func(t *testing.T, res *http.Response) {
  279. s.Require().InDelta(fortyFiveMinutes, s.maxAgeValue(res), oneMinuteDelta)
  280. s.Require().Empty(res.Header.Get(httpheaders.Expires))
  281. },
  282. },
  283. // No headers set
  284. {
  285. name: "NoOriginHeaders",
  286. cacheControlPassthrough: false,
  287. setupOriginHeaders: func(w http.ResponseWriter) {}, // Origin has no cache headers
  288. timestampOffset: nil,
  289. expectedStatusCode: 200,
  290. validate: func(t *testing.T, res *http.Response) {
  291. s.Require().Equal(s.maxAgeValue(res), time.Duration(defaultTTL)*time.Second)
  292. },
  293. },
  294. }
  295. for _, tc := range testCases {
  296. s.Run(tc.name, func() {
  297. data := s.readTestFile("test1.png")
  298. ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  299. tc.setupOriginHeaders(w)
  300. w.Header().Set(httpheaders.ContentType, "image/png")
  301. w.WriteHeader(200)
  302. w.Write(data)
  303. }))
  304. defer ts.Close()
  305. trc, err := transport.LoadFromEnv(transport.NewDefaultConfig())
  306. s.Require().NoError(err)
  307. // Create new handler with updated config for each test
  308. tr, err := transport.New(trc)
  309. s.Require().NoError(err)
  310. fc := fetcher.NewDefaultConfig()
  311. fetcher, err := fetcher.NewFetcher(tr, fc)
  312. s.Require().NoError(err)
  313. cfg := NewDefaultConfig()
  314. hwc := headerwriter.NewDefaultConfig()
  315. hwc.CacheControlPassthrough = tc.cacheControlPassthrough
  316. hwc.DefaultTTL = 4242
  317. hw, err := headerwriter.New(hwc)
  318. s.Require().NoError(err)
  319. handler, err := New(cfg, hw, fetcher)
  320. s.Require().NoError(err)
  321. req := httptest.NewRequest("GET", "/", nil)
  322. rw := httptest.NewRecorder()
  323. po := &options.ProcessingOptions{}
  324. if tc.timestampOffset != nil {
  325. expires := time.Now().Add(*tc.timestampOffset)
  326. po.Expires = &expires
  327. }
  328. err = handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
  329. s.Require().NoError(err)
  330. res := rw.Result()
  331. s.Require().Equal(tc.expectedStatusCode, res.StatusCode)
  332. tc.validate(s.T(), res)
  333. })
  334. }
  335. }
  336. // maxAgeValue parses max-age from cache-control
  337. func (s *HandlerTestSuite) maxAgeValue(res *http.Response) time.Duration {
  338. cacheControl := res.Header.Get(httpheaders.CacheControl)
  339. if cacheControl == "" {
  340. return 0
  341. }
  342. var maxAge int
  343. fmt.Sscanf(cacheControl, "max-age=%d", &maxAge)
  344. return time.Duration(maxAge) * time.Second
  345. }
  346. // TestHandlerSecurityHeaders tests the security headers set by the streaming service.
  347. func (s *HandlerTestSuite) TestHandlerSecurityHeaders() {
  348. data := s.readTestFile("test1.png")
  349. ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  350. w.Header().Set(httpheaders.ContentType, "image/png")
  351. w.WriteHeader(200)
  352. w.Write(data)
  353. }))
  354. defer ts.Close()
  355. req := httptest.NewRequest("GET", "/", nil)
  356. rw := httptest.NewRecorder()
  357. po := &options.ProcessingOptions{}
  358. err := s.handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
  359. s.Require().NoError(err)
  360. res := rw.Result()
  361. s.Require().Equal(200, res.StatusCode)
  362. s.Require().Equal("script-src 'none'", res.Header.Get(httpheaders.ContentSecurityPolicy))
  363. }
  364. // TestHandlerErrorResponse tests the error responses from the streaming service.
  365. func (s *HandlerTestSuite) TestHandlerErrorResponse() {
  366. ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  367. w.WriteHeader(404)
  368. w.Write([]byte("Not Found"))
  369. }))
  370. defer ts.Close()
  371. req := httptest.NewRequest("GET", "/", nil)
  372. rw := httptest.NewRecorder()
  373. po := &options.ProcessingOptions{}
  374. err := s.handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
  375. s.Require().NoError(err)
  376. res := rw.Result()
  377. s.Require().Equal(404, res.StatusCode)
  378. }
  379. // TestHandlerCookiePassthrough tests the cookie passthrough behavior of the streaming service.
  380. func (s *HandlerTestSuite) TestHandlerCookiePassthrough() {
  381. trc, err := transport.LoadFromEnv(transport.NewDefaultConfig())
  382. s.Require().NoError(err)
  383. // Create new handler with updated config
  384. tr, err := transport.New(trc)
  385. s.Require().NoError(err)
  386. fc := fetcher.NewDefaultConfig()
  387. fetcher, err := fetcher.NewFetcher(tr, fc)
  388. s.Require().NoError(err)
  389. cfg := NewDefaultConfig()
  390. cfg.CookiePassthrough = true
  391. hwc := headerwriter.NewDefaultConfig()
  392. hw, err := headerwriter.New(hwc)
  393. s.Require().NoError(err)
  394. handler, err := New(cfg, hw, fetcher)
  395. s.Require().NoError(err)
  396. data := s.readTestFile("test1.png")
  397. ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  398. // Verify cookies are passed through
  399. cookie, cerr := r.Cookie("test_cookie")
  400. if cerr == nil {
  401. s.Equal("test_value", cookie.Value)
  402. }
  403. w.Header().Set(httpheaders.ContentType, "image/png")
  404. w.WriteHeader(200)
  405. w.Write(data)
  406. }))
  407. defer ts.Close()
  408. req := httptest.NewRequest("GET", "/", nil)
  409. req.Header.Set(httpheaders.Cookie, "test_cookie=test_value")
  410. rw := httptest.NewRecorder()
  411. po := &options.ProcessingOptions{}
  412. err = handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
  413. s.Require().NoError(err)
  414. res := rw.Result()
  415. s.Require().Equal(200, res.StatusCode)
  416. }
  417. // TestHandlerCanonicalHeader tests that the canonical header is set correctly
  418. func (s *HandlerTestSuite) TestHandlerCanonicalHeader() {
  419. data := s.readTestFile("test1.png")
  420. ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  421. w.Header().Set(httpheaders.ContentType, "image/png")
  422. w.WriteHeader(200)
  423. w.Write(data)
  424. }))
  425. defer ts.Close()
  426. for _, sc := range []bool{true, false} {
  427. trc, err := transport.LoadFromEnv(transport.NewDefaultConfig())
  428. s.Require().NoError(err)
  429. // Create new handler with updated config
  430. tr, err := transport.New(trc)
  431. s.Require().NoError(err)
  432. fc := fetcher.NewDefaultConfig()
  433. fetcher, err := fetcher.NewFetcher(tr, fc)
  434. s.Require().NoError(err)
  435. cfg := NewDefaultConfig()
  436. hwc := headerwriter.NewDefaultConfig()
  437. hwc.SetCanonicalHeader = sc
  438. hw, err := headerwriter.New(hwc)
  439. s.Require().NoError(err)
  440. handler, err := New(cfg, hw, fetcher)
  441. s.Require().NoError(err)
  442. req := httptest.NewRequest("GET", "/", nil)
  443. rw := httptest.NewRecorder()
  444. po := &options.ProcessingOptions{}
  445. err = handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
  446. s.Require().NoError(err)
  447. res := rw.Result()
  448. s.Require().Equal(200, res.StatusCode)
  449. if sc {
  450. s.Require().Contains(res.Header.Get(httpheaders.Link), fmt.Sprintf(`<%s>; rel="canonical"`, ts.URL))
  451. } else {
  452. s.Require().Empty(res.Header.Get(httpheaders.Link))
  453. }
  454. }
  455. }
  456. func TestHandler(t *testing.T) {
  457. suite.Run(t, new(HandlerTestSuite))
  458. }