handler_test.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512
  1. package stream
import (
	"context"
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"testing"
	"time"

	"github.com/sirupsen/logrus"
	"github.com/stretchr/testify/suite"

	"github.com/imgproxy/imgproxy/v3/config"
	"github.com/imgproxy/imgproxy/v3/headerwriter"
	"github.com/imgproxy/imgproxy/v3/httpheaders"
	"github.com/imgproxy/imgproxy/v3/imagefetcher"
	"github.com/imgproxy/imgproxy/v3/options"
	"github.com/imgproxy/imgproxy/v3/transport"
)
  22. const (
  23. testDataPath = "../../testdata"
  24. )
  25. type HandlerTestSuite struct {
  26. suite.Suite
  27. handler *Handler
  28. }
  29. func (s *HandlerTestSuite) SetupSuite() {
  30. config.Reset()
  31. config.AllowLoopbackSourceAddresses = true
  32. // Silence logs during tests
  33. logrus.SetOutput(io.Discard)
  34. }
  35. func (s *HandlerTestSuite) TearDownSuite() {
  36. config.Reset()
  37. logrus.SetOutput(os.Stdout)
  38. }
  39. func (s *HandlerTestSuite) SetupTest() {
  40. config.Reset()
  41. config.AllowLoopbackSourceAddresses = true
  42. tr, err := transport.NewTransport()
  43. s.Require().NoError(err)
  44. fetcher, err := imagefetcher.NewFetcher(tr, imagefetcher.NewConfigFromEnv())
  45. s.Require().NoError(err)
  46. s.handler = New(NewConfigFromEnv(), headerwriter.NewConfigFromEnv(), fetcher)
  47. }
  48. func (s *HandlerTestSuite) readTestFile(name string) []byte {
  49. data, err := os.ReadFile(filepath.Join(testDataPath, name))
  50. s.Require().NoError(err)
  51. return data
  52. }
  53. // TestHandlerBasicRequest checks basic streaming request
  54. func (s *HandlerTestSuite) TestHandlerBasicRequest() {
  55. data := s.readTestFile("test1.png")
  56. ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  57. w.Header().Set(httpheaders.ContentType, "image/png")
  58. w.WriteHeader(200)
  59. w.Write(data)
  60. }))
  61. defer ts.Close()
  62. req := httptest.NewRequest("GET", "/", nil)
  63. rw := httptest.NewRecorder()
  64. po := &options.ProcessingOptions{}
  65. err := s.handler.Execute(context.Background(), req, ts.URL, "request-1", po, rw)
  66. s.Require().NoError(err)
  67. res := rw.Result()
  68. s.Require().Equal(200, res.StatusCode)
  69. s.Require().Equal("image/png", res.Header.Get(httpheaders.ContentType))
  70. // Verify we get the original image data
  71. actual := rw.Body.Bytes()
  72. s.Require().Equal(data, actual)
  73. }
  74. // TestHandlerResponseHeadersPassthrough checks that original response headers are
  75. // passed through to the client
  76. func (s *HandlerTestSuite) TestHandlerResponseHeadersPassthrough() {
  77. data := s.readTestFile("test1.png")
  78. contentLength := len(data)
  79. ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  80. w.Header().Set(httpheaders.ContentType, "image/png")
  81. w.Header().Set(httpheaders.ContentLength, strconv.Itoa(contentLength))
  82. w.Header().Set(httpheaders.AcceptRanges, "bytes")
  83. w.Header().Set(httpheaders.Etag, "etag")
  84. w.Header().Set(httpheaders.LastModified, "Wed, 21 Oct 2015 07:28:00 GMT")
  85. w.WriteHeader(200)
  86. w.Write(data)
  87. }))
  88. defer ts.Close()
  89. req := httptest.NewRequest("GET", "/", nil)
  90. rw := httptest.NewRecorder()
  91. po := &options.ProcessingOptions{}
  92. err := s.handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
  93. s.Require().NoError(err)
  94. res := rw.Result()
  95. s.Require().Equal(200, res.StatusCode)
  96. s.Require().Equal("image/png", res.Header.Get(httpheaders.ContentType))
  97. s.Require().Equal(strconv.Itoa(contentLength), res.Header.Get(httpheaders.ContentLength))
  98. s.Require().Equal("bytes", res.Header.Get(httpheaders.AcceptRanges))
  99. s.Require().Equal("etag", res.Header.Get(httpheaders.Etag))
  100. s.Require().Equal("Wed, 21 Oct 2015 07:28:00 GMT", res.Header.Get(httpheaders.LastModified))
  101. }
  102. // TestHandlerRequestHeadersPassthrough checks that original request headers are passed through
  103. // to the server
  104. func (s *HandlerTestSuite) TestHandlerRequestHeadersPassthrough() {
  105. etag := `"test-etag-123"`
  106. data := s.readTestFile("test1.png")
  107. ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  108. // Verify that If-None-Match header is passed through
  109. s.Equal(etag, r.Header.Get(httpheaders.IfNoneMatch))
  110. s.Equal("gzip", r.Header.Get(httpheaders.AcceptEncoding))
  111. s.Equal("bytes=*", r.Header.Get(httpheaders.Range))
  112. w.Header().Set(httpheaders.Etag, etag)
  113. w.WriteHeader(200)
  114. w.Write(data)
  115. }))
  116. defer ts.Close()
  117. req := httptest.NewRequest("GET", "/", nil)
  118. req.Header.Set(httpheaders.IfNoneMatch, etag)
  119. req.Header.Set(httpheaders.AcceptEncoding, "gzip")
  120. req.Header.Set(httpheaders.Range, "bytes=*")
  121. rw := httptest.NewRecorder()
  122. po := &options.ProcessingOptions{}
  123. err := s.handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
  124. s.Require().NoError(err)
  125. res := rw.Result()
  126. s.Require().Equal(200, res.StatusCode)
  127. s.Require().Equal(etag, res.Header.Get(httpheaders.Etag))
  128. }
  129. // TestHandlerContentDisposition checks that Content-Disposition header is set correctly
  130. func (s *HandlerTestSuite) TestHandlerContentDisposition() {
  131. data := s.readTestFile("test1.png")
  132. ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  133. w.Header().Set(httpheaders.ContentType, "image/png")
  134. w.WriteHeader(200)
  135. w.Write(data)
  136. }))
  137. defer ts.Close()
  138. req := httptest.NewRequest("GET", "/", nil)
  139. rw := httptest.NewRecorder()
  140. po := &options.ProcessingOptions{
  141. Filename: "custom_name",
  142. ReturnAttachment: true,
  143. }
  144. // Use a URL with a .png extension to help content disposition logic
  145. imageURL := ts.URL + "/test.png"
  146. err := s.handler.Execute(context.Background(), req, imageURL, "test-req-id", po, rw)
  147. s.Require().NoError(err)
  148. res := rw.Result()
  149. s.Require().Equal(200, res.StatusCode)
  150. s.Require().Contains(res.Header.Get(httpheaders.ContentDisposition), "custom_name.png")
  151. s.Require().Contains(res.Header.Get(httpheaders.ContentDisposition), "attachment")
  152. }
  153. // TestHandlerCacheControl checks that Cache-Control header is set correctly in different cases
  154. func (s *HandlerTestSuite) TestHandlerCacheControl() {
  155. type testCase struct {
  156. name string
  157. cacheControlPassthrough bool
  158. setupOriginHeaders func(http.ResponseWriter)
  159. timestampOffset *time.Duration // nil for no timestamp, otherwise the offset from now
  160. expectedStatusCode int
  161. validate func(*testing.T, *http.Response)
  162. }
  163. // Duration variables for test cases
  164. var (
  165. oneHour = time.Hour
  166. thirtyMinutes = 30 * time.Minute
  167. fortyFiveMinutes = 45 * time.Minute
  168. twoHours = time.Hour * 2
  169. oneMinuteDelta = float64(time.Minute)
  170. )
  171. // Set this explicitly for testing purposes
  172. config.TTL = 4242
  173. testCases := []testCase{
  174. {
  175. name: "Passthrough",
  176. cacheControlPassthrough: true,
  177. setupOriginHeaders: func(w http.ResponseWriter) {
  178. w.Header().Set(httpheaders.CacheControl, "max-age=3600, public")
  179. },
  180. timestampOffset: nil,
  181. expectedStatusCode: 200,
  182. validate: func(t *testing.T, res *http.Response) {
  183. s.Require().Equal("max-age=3600, public", res.Header.Get(httpheaders.CacheControl))
  184. },
  185. },
  186. // Checks that expires gets convert to cache-control
  187. {
  188. name: "ExpiresPassthrough",
  189. cacheControlPassthrough: true,
  190. setupOriginHeaders: func(w http.ResponseWriter) {
  191. w.Header().Set(httpheaders.Expires, time.Now().Add(oneHour).UTC().Format(http.TimeFormat))
  192. },
  193. timestampOffset: nil,
  194. expectedStatusCode: 200,
  195. validate: func(t *testing.T, res *http.Response) {
  196. // When expires is converted to cache-control, the expires header should be empty
  197. s.Require().Empty(res.Header.Get(httpheaders.Expires))
  198. s.Require().InDelta(oneHour, s.maxAgeValue(res), oneMinuteDelta)
  199. },
  200. },
  201. // It would be set to something like default ttl
  202. {
  203. name: "PassthroughDisabled",
  204. cacheControlPassthrough: false,
  205. setupOriginHeaders: func(w http.ResponseWriter) {
  206. w.Header().Set(httpheaders.CacheControl, "max-age=3600, public")
  207. },
  208. timestampOffset: nil,
  209. expectedStatusCode: 200,
  210. validate: func(t *testing.T, res *http.Response) {
  211. s.Require().Equal(s.maxAgeValue(res), time.Duration(config.TTL)*time.Second)
  212. },
  213. },
  214. // When expires is set in processing options, but not present in the response
  215. {
  216. name: "WithProcessingOptionsExpires",
  217. cacheControlPassthrough: false,
  218. setupOriginHeaders: func(w http.ResponseWriter) {}, // No origin headers
  219. timestampOffset: &oneHour,
  220. expectedStatusCode: 200,
  221. validate: func(t *testing.T, res *http.Response) {
  222. s.Require().InDelta(oneHour, s.maxAgeValue(res), oneMinuteDelta)
  223. },
  224. },
  225. // When expires is set in processing options, and is present in the response,
  226. // and passthrough is enabled
  227. {
  228. name: "ProcessingOptionsOverridesOrigin",
  229. cacheControlPassthrough: true,
  230. setupOriginHeaders: func(w http.ResponseWriter) {
  231. // Origin has a longer cache time
  232. w.Header().Set(httpheaders.CacheControl, "max-age=7200, public")
  233. },
  234. timestampOffset: &thirtyMinutes,
  235. expectedStatusCode: 200,
  236. validate: func(t *testing.T, res *http.Response) {
  237. s.Require().InDelta(thirtyMinutes, s.maxAgeValue(res), oneMinuteDelta)
  238. },
  239. },
  240. // When expires is not set in po, but both expires and cc are present in response,
  241. // and passthrough is enabled
  242. {
  243. name: "BothHeadersPassthroughEnabled",
  244. cacheControlPassthrough: true,
  245. setupOriginHeaders: func(w http.ResponseWriter) {
  246. // Origin has both Cache-Control and Expires headers
  247. w.Header().Set(httpheaders.CacheControl, "max-age=1800, public")
  248. w.Header().Set(httpheaders.Expires, time.Now().Add(oneHour).UTC().Format(http.TimeFormat))
  249. },
  250. timestampOffset: nil,
  251. expectedStatusCode: 200,
  252. validate: func(t *testing.T, res *http.Response) {
  253. // Cache-Control should take precedence over Expires when both are present
  254. s.Require().InDelta(thirtyMinutes, s.maxAgeValue(res), oneMinuteDelta)
  255. s.Require().Empty(res.Header.Get(httpheaders.Expires))
  256. },
  257. },
  258. // When expires is set in PO AND both cache-control and expires are present in response,
  259. // and passthrough is enabled
  260. {
  261. name: "ProcessingOptionsOverridesBothOriginHeaders",
  262. cacheControlPassthrough: true,
  263. setupOriginHeaders: func(w http.ResponseWriter) {
  264. // Origin has both Cache-Control and Expires headers with longer cache times
  265. w.Header().Set(httpheaders.CacheControl, "max-age=7200, public")
  266. w.Header().Set(httpheaders.Expires, time.Now().Add(twoHours).UTC().Format(http.TimeFormat))
  267. },
  268. timestampOffset: &fortyFiveMinutes, // Shorter than origin headers
  269. expectedStatusCode: 200,
  270. validate: func(t *testing.T, res *http.Response) {
  271. s.Require().InDelta(fortyFiveMinutes, s.maxAgeValue(res), oneMinuteDelta)
  272. s.Require().Empty(res.Header.Get(httpheaders.Expires))
  273. },
  274. },
  275. // No headers set
  276. {
  277. name: "NoOriginHeaders",
  278. cacheControlPassthrough: false,
  279. setupOriginHeaders: func(w http.ResponseWriter) {}, // Origin has no cache headers
  280. timestampOffset: nil,
  281. expectedStatusCode: 200,
  282. validate: func(t *testing.T, res *http.Response) {
  283. s.Require().Equal(s.maxAgeValue(res), time.Duration(config.TTL)*time.Second)
  284. },
  285. },
  286. }
  287. for _, tc := range testCases {
  288. s.Run(tc.name, func() {
  289. // Set config values for this test
  290. config.CacheControlPassthrough = tc.cacheControlPassthrough
  291. config.TTL = 4242 // Set consistent TTL for testing
  292. data := s.readTestFile("test1.png")
  293. ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  294. tc.setupOriginHeaders(w)
  295. w.Header().Set(httpheaders.ContentType, "image/png")
  296. w.WriteHeader(200)
  297. w.Write(data)
  298. }))
  299. defer ts.Close()
  300. // Create new handler with updated config for each test
  301. tr, err := transport.NewTransport()
  302. s.Require().NoError(err)
  303. fetcher, err := imagefetcher.NewFetcher(tr, imagefetcher.NewConfigFromEnv())
  304. s.Require().NoError(err)
  305. handler := New(NewConfigFromEnv(), headerwriter.NewConfigFromEnv(), fetcher)
  306. req := httptest.NewRequest("GET", "/", nil)
  307. rw := httptest.NewRecorder()
  308. po := &options.ProcessingOptions{}
  309. if tc.timestampOffset != nil {
  310. expires := time.Now().Add(*tc.timestampOffset)
  311. po.Expires = &expires
  312. }
  313. err = handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
  314. s.Require().NoError(err)
  315. res := rw.Result()
  316. s.Require().Equal(tc.expectedStatusCode, res.StatusCode)
  317. tc.validate(s.T(), res)
  318. })
  319. }
  320. }
  321. // maxAgeValue parses max-age from cache-control
  322. func (s *HandlerTestSuite) maxAgeValue(res *http.Response) time.Duration {
  323. cacheControl := res.Header.Get(httpheaders.CacheControl)
  324. if cacheControl == "" {
  325. return 0
  326. }
  327. var maxAge int
  328. fmt.Sscanf(cacheControl, "max-age=%d", &maxAge)
  329. return time.Duration(maxAge) * time.Second
  330. }
  331. // TestHandlerSecurityHeaders tests the security headers set by the streaming service.
  332. func (s *HandlerTestSuite) TestHandlerSecurityHeaders() {
  333. data := s.readTestFile("test1.png")
  334. ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  335. w.Header().Set(httpheaders.ContentType, "image/png")
  336. w.WriteHeader(200)
  337. w.Write(data)
  338. }))
  339. defer ts.Close()
  340. req := httptest.NewRequest("GET", "/", nil)
  341. rw := httptest.NewRecorder()
  342. po := &options.ProcessingOptions{}
  343. err := s.handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
  344. s.Require().NoError(err)
  345. res := rw.Result()
  346. s.Require().Equal(200, res.StatusCode)
  347. s.Require().Equal("script-src 'none'", res.Header.Get(httpheaders.ContentSecurityPolicy))
  348. }
  349. // TestHandlerErrorResponse tests the error responses from the streaming service.
  350. func (s *HandlerTestSuite) TestHandlerErrorResponse() {
  351. ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  352. w.WriteHeader(404)
  353. w.Write([]byte("Not Found"))
  354. }))
  355. defer ts.Close()
  356. req := httptest.NewRequest("GET", "/", nil)
  357. rw := httptest.NewRecorder()
  358. po := &options.ProcessingOptions{}
  359. err := s.handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
  360. s.Require().NoError(err)
  361. res := rw.Result()
  362. s.Require().Equal(404, res.StatusCode)
  363. }
  364. // TestHandlerCookiePassthrough tests the cookie passthrough behavior of the streaming service.
  365. func (s *HandlerTestSuite) TestHandlerCookiePassthrough() {
  366. // Enable cookie passthrough for this test
  367. config.CookiePassthrough = true
  368. defer func() {
  369. config.CookiePassthrough = false // Reset after test
  370. }()
  371. // Create new handler with updated config
  372. tr, err := transport.NewTransport()
  373. s.Require().NoError(err)
  374. fetcher, err := imagefetcher.NewFetcher(tr, imagefetcher.NewConfigFromEnv())
  375. s.Require().NoError(err)
  376. handler := New(NewConfigFromEnv(), headerwriter.NewConfigFromEnv(), fetcher)
  377. data := s.readTestFile("test1.png")
  378. ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  379. // Verify cookies are passed through
  380. cookie, cerr := r.Cookie("test_cookie")
  381. if cerr == nil {
  382. s.Equal("test_value", cookie.Value)
  383. }
  384. w.Header().Set(httpheaders.ContentType, "image/png")
  385. w.WriteHeader(200)
  386. w.Write(data)
  387. }))
  388. defer ts.Close()
  389. req := httptest.NewRequest("GET", "/", nil)
  390. req.Header.Set(httpheaders.Cookie, "test_cookie=test_value")
  391. rw := httptest.NewRecorder()
  392. po := &options.ProcessingOptions{}
  393. err = handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
  394. s.Require().NoError(err)
  395. res := rw.Result()
  396. s.Require().Equal(200, res.StatusCode)
  397. }
  398. // TestHandlerCanonicalHeader tests that the canonical header is set correctly
  399. func (s *HandlerTestSuite) TestHandlerCanonicalHeader() {
  400. data := s.readTestFile("test1.png")
  401. ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  402. w.Header().Set(httpheaders.ContentType, "image/png")
  403. w.WriteHeader(200)
  404. w.Write(data)
  405. }))
  406. defer ts.Close()
  407. for _, sc := range []bool{true, false} {
  408. config.SetCanonicalHeader = sc
  409. // Create new handler with updated config
  410. tr, err := transport.NewTransport()
  411. s.Require().NoError(err)
  412. fetcher, err := imagefetcher.NewFetcher(tr, imagefetcher.NewConfigFromEnv())
  413. s.Require().NoError(err)
  414. handler := New(NewConfigFromEnv(), headerwriter.NewConfigFromEnv(), fetcher)
  415. req := httptest.NewRequest("GET", "/", nil)
  416. rw := httptest.NewRecorder()
  417. po := &options.ProcessingOptions{}
  418. err = handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
  419. s.Require().NoError(err)
  420. res := rw.Result()
  421. s.Require().Equal(200, res.StatusCode)
  422. if sc {
  423. s.Require().Contains(res.Header.Get(httpheaders.Link), fmt.Sprintf(`<%s>; rel="canonical"`, ts.URL))
  424. } else {
  425. s.Require().Empty(res.Header.Get(httpheaders.Link))
  426. }
  427. }
  428. }
  429. func TestHandler(t *testing.T) {
  430. suite.Run(t, new(HandlerTestSuite))
  431. }