// s3_test.go (4.8 KB)
  1. package s3
  2. import (
  3. "bytes"
  4. "context"
  5. "net/http"
  6. "net/http/httptest"
  7. "os"
  8. "testing"
  9. "time"
  10. "github.com/aws/aws-sdk-go-v2/aws"
  11. "github.com/aws/aws-sdk-go-v2/service/s3"
  12. "github.com/johannesboyne/gofakes3"
  13. "github.com/johannesboyne/gofakes3/backend/s3mem"
  14. "github.com/stretchr/testify/suite"
  15. "github.com/imgproxy/imgproxy/v3/config"
  16. )
  17. type S3TestSuite struct {
  18. suite.Suite
  19. server *httptest.Server
  20. transport http.RoundTripper
  21. etag string
  22. lastModified time.Time
  23. }
  24. func (s *S3TestSuite) SetupSuite() {
  25. backend := s3mem.New()
  26. faker := gofakes3.New(backend)
  27. s.server = httptest.NewServer(faker.Server())
  28. config.S3Enabled = true
  29. config.S3Endpoint = s.server.URL
  30. os.Setenv("AWS_REGION", "eu-central-1")
  31. os.Setenv("AWS_ACCESS_KEY_ID", "Foo")
  32. os.Setenv("AWS_SECRET_ACCESS_KEY", "Bar")
  33. var err error
  34. s.transport, err = New()
  35. s.Require().NoError(err)
  36. err = backend.CreateBucket("test")
  37. s.Require().NoError(err)
  38. svc, err := s.transport.(*transport).getClient(context.Background(), "test")
  39. s.Require().NoError(err)
  40. s.Require().NotNil(svc)
  41. s.Require().IsType(&s3.Client{}, svc)
  42. client := svc.(*s3.Client)
  43. _, err = client.PutObject(context.Background(), &s3.PutObjectInput{
  44. Body: bytes.NewReader(make([]byte, 32)),
  45. Bucket: aws.String("test"),
  46. Key: aws.String("foo/test.png"),
  47. })
  48. s.Require().NoError(err)
  49. obj, err := client.GetObject(context.Background(), &s3.GetObjectInput{
  50. Bucket: aws.String("test"),
  51. Key: aws.String("foo/test.png"),
  52. })
  53. s.Require().NoError(err)
  54. defer obj.Body.Close()
  55. s.etag = *obj.ETag
  56. s.lastModified = *obj.LastModified
  57. }
  58. func (s *S3TestSuite) TearDownSuite() {
  59. s.server.Close()
  60. config.Reset()
  61. }
  62. func (s *S3TestSuite) TestRoundTripWithETagDisabledReturns200() {
  63. config.ETagEnabled = false
  64. request, _ := http.NewRequest("GET", "s3://test/foo/test.png", nil)
  65. response, err := s.transport.RoundTrip(request)
  66. s.Require().NoError(err)
  67. s.Require().Equal(200, response.StatusCode)
  68. }
  69. func (s *S3TestSuite) TestRoundTripWithETagEnabled() {
  70. config.ETagEnabled = true
  71. request, _ := http.NewRequest("GET", "s3://test/foo/test.png", nil)
  72. response, err := s.transport.RoundTrip(request)
  73. s.Require().NoError(err)
  74. s.Require().Equal(200, response.StatusCode)
  75. s.Require().Equal(s.etag, response.Header.Get("ETag"))
  76. }
  77. func (s *S3TestSuite) TestRoundTripWithIfNoneMatchReturns304() {
  78. config.ETagEnabled = true
  79. request, _ := http.NewRequest("GET", "s3://test/foo/test.png", nil)
  80. request.Header.Set("If-None-Match", s.etag)
  81. response, err := s.transport.RoundTrip(request)
  82. s.Require().NoError(err)
  83. s.Require().Equal(http.StatusNotModified, response.StatusCode)
  84. }
  85. func (s *S3TestSuite) TestRoundTripWithUpdatedETagReturns200() {
  86. config.ETagEnabled = true
  87. request, _ := http.NewRequest("GET", "s3://test/foo/test.png", nil)
  88. request.Header.Set("If-None-Match", s.etag+"_wrong")
  89. response, err := s.transport.RoundTrip(request)
  90. s.Require().NoError(err)
  91. s.Require().Equal(http.StatusOK, response.StatusCode)
  92. }
  93. func (s *S3TestSuite) TestRoundTripWithLastModifiedDisabledReturns200() {
  94. config.LastModifiedEnabled = false
  95. request, _ := http.NewRequest("GET", "s3://test/foo/test.png", nil)
  96. response, err := s.transport.RoundTrip(request)
  97. s.Require().NoError(err)
  98. s.Require().Equal(200, response.StatusCode)
  99. }
  100. func (s *S3TestSuite) TestRoundTripWithLastModifiedEnabled() {
  101. config.LastModifiedEnabled = true
  102. request, _ := http.NewRequest("GET", "s3://test/foo/test.png", nil)
  103. response, err := s.transport.RoundTrip(request)
  104. s.Require().NoError(err)
  105. s.Require().Equal(200, response.StatusCode)
  106. s.Require().Equal(s.lastModified.Format(http.TimeFormat), response.Header.Get("Last-Modified"))
  107. }
  108. func (s *S3TestSuite) TestRoundTripWithIfModifiedSinceReturns304() {
  109. config.LastModifiedEnabled = true
  110. request, _ := http.NewRequest("GET", "s3://test/foo/test.png", nil)
  111. request.Header.Set("If-Modified-Since", s.lastModified.Format(http.TimeFormat))
  112. response, err := s.transport.RoundTrip(request)
  113. s.Require().NoError(err)
  114. s.Require().Equal(http.StatusNotModified, response.StatusCode)
  115. }
  116. func (s *S3TestSuite) TestRoundTripWithUpdatedLastModifiedReturns200() {
  117. config.LastModifiedEnabled = true
  118. request, _ := http.NewRequest("GET", "s3://test/foo/test.png", nil)
  119. request.Header.Set("If-Modified-Since", s.lastModified.Add(-24*time.Hour).Format(http.TimeFormat))
  120. response, err := s.transport.RoundTrip(request)
  121. s.Require().NoError(err)
  122. s.Require().Equal(http.StatusOK, response.StatusCode)
  123. }
  124. func (s *S3TestSuite) TestRoundTripWithMultiregionEnabledReturns200() {
  125. config.S3MultiRegion = true
  126. request, _ := http.NewRequest("GET", "s3://test/foo/test.png", nil)
  127. response, err := s.transport.RoundTrip(request)
  128. s.Require().NoError(err)
  129. s.Require().Equal(200, response.StatusCode)
  130. }
  131. func TestS3Transport(t *testing.T) {
  132. suite.Run(t, new(S3TestSuite))
  133. }