s3.go 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335
  1. package s3
  2. import (
  3. "context"
  4. "errors"
  5. "io"
  6. "net/http"
  7. "strconv"
  8. "strings"
  9. "sync"
  10. "time"
  11. s3Crypto "github.com/aws/amazon-s3-encryption-client-go/v3/client"
  12. s3CryptoMaterials "github.com/aws/amazon-s3-encryption-client-go/v3/materials"
  13. "github.com/aws/aws-sdk-go-v2/aws"
  14. awsHttp "github.com/aws/aws-sdk-go-v2/aws/transport/http"
  15. awsConfig "github.com/aws/aws-sdk-go-v2/config"
  16. "github.com/aws/aws-sdk-go-v2/credentials/stscreds"
  17. "github.com/aws/aws-sdk-go-v2/service/kms"
  18. "github.com/aws/aws-sdk-go-v2/service/s3"
  19. "github.com/aws/aws-sdk-go-v2/service/sts"
  20. "github.com/imgproxy/imgproxy/v3/fetcher/transport/common"
  21. "github.com/imgproxy/imgproxy/v3/httpheaders"
  22. "github.com/imgproxy/imgproxy/v3/ierrors"
  23. )
  24. type s3Client interface {
  25. GetObject(ctx context.Context, input *s3.GetObjectInput, opts ...func(*s3.Options)) (*s3.GetObjectOutput, error)
  26. }
  27. // transport implements RoundTripper for the 's3' protocol.
  28. type transport struct {
  29. clientOptions []func(*s3.Options)
  30. defaultClient s3Client
  31. defaultConfig aws.Config
  32. clientsByRegion map[string]s3Client
  33. clientsByBucket map[string]s3Client
  34. mu sync.RWMutex
  35. config *Config
  36. querySeparator string
  37. }
  38. func New(config *Config, trans *http.Transport, querySeparator string) (http.RoundTripper, error) {
  39. if err := config.Validate(); err != nil {
  40. return nil, err
  41. }
  42. conf, err := awsConfig.LoadDefaultConfig(context.Background())
  43. if err != nil {
  44. return nil, ierrors.Wrap(err, 0, ierrors.WithPrefix("can't load AWS S3 config"))
  45. }
  46. conf.HTTPClient = &http.Client{Transport: trans}
  47. if len(config.Region) != 0 {
  48. conf.Region = config.Region
  49. }
  50. if len(conf.Region) == 0 {
  51. conf.Region = "us-west-1"
  52. }
  53. if len(config.AssumeRoleArn) != 0 {
  54. creds := stscreds.NewAssumeRoleProvider(sts.NewFromConfig(conf), config.AssumeRoleArn, func(o *stscreds.AssumeRoleOptions) {
  55. if len(config.AssumeRoleExternalID) != 0 {
  56. o.ExternalID = aws.String(config.AssumeRoleExternalID)
  57. }
  58. })
  59. conf.Credentials = creds
  60. }
  61. clientOptions := []func(*s3.Options){
  62. func(o *s3.Options) {
  63. o.DisableLogOutputChecksumValidationSkipped = true
  64. },
  65. }
  66. if len(config.Endpoint) != 0 {
  67. endpoint := config.Endpoint
  68. if !strings.HasPrefix(endpoint, "http://") && !strings.HasPrefix(endpoint, "https://") {
  69. endpoint = "http://" + endpoint
  70. }
  71. clientOptions = append(clientOptions, func(o *s3.Options) {
  72. o.BaseEndpoint = aws.String(endpoint)
  73. o.UsePathStyle = config.EndpointUsePathStyle
  74. })
  75. }
  76. client, err := createClient(conf, clientOptions, config)
  77. if err != nil {
  78. return nil, ierrors.Wrap(err, 0, ierrors.WithPrefix("can't create S3 client"))
  79. }
  80. return &transport{
  81. clientOptions: clientOptions,
  82. defaultClient: client,
  83. defaultConfig: conf,
  84. clientsByRegion: map[string]s3Client{conf.Region: client},
  85. clientsByBucket: make(map[string]s3Client),
  86. config: config,
  87. querySeparator: querySeparator,
  88. }, nil
  89. }
  90. func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
  91. bucket, key, query := common.GetBucketAndKey(req.URL, t.querySeparator)
  92. if len(bucket) == 0 || len(key) == 0 {
  93. body := strings.NewReader("Invalid S3 URL: bucket name or object key is empty")
  94. return &http.Response{
  95. StatusCode: http.StatusNotFound,
  96. Proto: "HTTP/1.0",
  97. ProtoMajor: 1,
  98. ProtoMinor: 0,
  99. Header: http.Header{httpheaders.ContentType: {"text/plain"}},
  100. ContentLength: int64(body.Len()),
  101. Body: io.NopCloser(body),
  102. Close: false,
  103. Request: req,
  104. }, nil
  105. }
  106. input := &s3.GetObjectInput{
  107. Bucket: aws.String(bucket),
  108. Key: aws.String(key),
  109. }
  110. if len(query) > 0 {
  111. input.VersionId = aws.String(query)
  112. }
  113. statusCode := http.StatusOK
  114. if r := req.Header.Get("Range"); len(r) != 0 {
  115. input.Range = aws.String(r)
  116. } else {
  117. if ifNoneMatch := req.Header.Get("If-None-Match"); len(ifNoneMatch) > 0 {
  118. input.IfNoneMatch = aws.String(ifNoneMatch)
  119. }
  120. if ifModifiedSince := req.Header.Get("If-Modified-Since"); len(ifModifiedSince) > 0 {
  121. parsedIfModifiedSince, err := time.Parse(http.TimeFormat, ifModifiedSince)
  122. if err == nil {
  123. input.IfModifiedSince = &parsedIfModifiedSince
  124. }
  125. }
  126. }
  127. client := t.getBucketClient(bucket)
  128. output, err := client.GetObject(req.Context(), input)
  129. defer func() {
  130. if err != nil && output != nil && output.Body != nil {
  131. output.Body.Close()
  132. }
  133. }()
  134. if err != nil {
  135. // Check if the error is the region mismatch error.
  136. // If so, create a new client with the correct region and retry the request.
  137. if region := regionFromError(err); len(region) != 0 {
  138. client, err = t.createBucketClient(bucket, region)
  139. if err != nil {
  140. return handleError(req, err)
  141. }
  142. output, err = client.GetObject(req.Context(), input)
  143. }
  144. }
  145. if err != nil {
  146. return handleError(req, err)
  147. }
  148. contentLength := int64(-1)
  149. if output.ContentLength != nil {
  150. contentLength = *output.ContentLength
  151. }
  152. if t.config.DecryptionClientEnabled {
  153. if unencryptedContentLength := output.Metadata["X-Amz-Meta-X-Amz-Unencrypted-Content-Length"]; len(unencryptedContentLength) != 0 {
  154. cl, err := strconv.ParseInt(unencryptedContentLength, 10, 64)
  155. if err != nil {
  156. handleError(req, err)
  157. }
  158. contentLength = cl
  159. }
  160. }
  161. header := make(http.Header)
  162. if contentLength > 0 {
  163. header.Set(httpheaders.ContentLength, strconv.FormatInt(contentLength, 10))
  164. }
  165. if output.ContentType != nil {
  166. header.Set(httpheaders.ContentType, *output.ContentType)
  167. }
  168. if output.ContentEncoding != nil {
  169. header.Set(httpheaders.ContentEncoding, *output.ContentEncoding)
  170. }
  171. if output.CacheControl != nil {
  172. header.Set(httpheaders.CacheControl, *output.CacheControl)
  173. }
  174. if output.ExpiresString != nil {
  175. header.Set(httpheaders.Expires, *output.ExpiresString)
  176. }
  177. if output.ETag != nil {
  178. header.Set(httpheaders.Etag, *output.ETag)
  179. }
  180. if output.LastModified != nil {
  181. header.Set(httpheaders.LastModified, output.LastModified.Format(http.TimeFormat))
  182. }
  183. if output.AcceptRanges != nil {
  184. header.Set(httpheaders.AcceptRanges, *output.AcceptRanges)
  185. }
  186. if output.ContentRange != nil {
  187. header.Set(httpheaders.ContentRange, *output.ContentRange)
  188. statusCode = http.StatusPartialContent
  189. }
  190. return &http.Response{
  191. StatusCode: statusCode,
  192. Proto: "HTTP/1.0",
  193. ProtoMajor: 1,
  194. ProtoMinor: 0,
  195. Header: header,
  196. ContentLength: contentLength,
  197. Body: output.Body,
  198. Close: true,
  199. Request: req,
  200. }, nil
  201. }
  202. func (t *transport) getBucketClient(bucket string) s3Client {
  203. var client s3Client
  204. func() {
  205. t.mu.RLock()
  206. defer t.mu.RUnlock()
  207. client = t.clientsByBucket[bucket]
  208. }()
  209. if client != nil {
  210. return client
  211. }
  212. return t.defaultClient
  213. }
  214. func (t *transport) createBucketClient(bucket, region string) (s3Client, error) {
  215. t.mu.Lock()
  216. defer t.mu.Unlock()
  217. // Check again if someone did this before us
  218. if client := t.clientsByBucket[bucket]; client != nil {
  219. return client, nil
  220. }
  221. if client := t.clientsByRegion[region]; client != nil {
  222. t.clientsByBucket[bucket] = client
  223. return client, nil
  224. }
  225. conf := t.defaultConfig.Copy()
  226. conf.Region = region
  227. client, err := createClient(conf, t.clientOptions, t.config)
  228. if err != nil {
  229. return nil, ierrors.Wrap(err, 0, ierrors.WithPrefix("can't create regional S3 client"))
  230. }
  231. t.clientsByRegion[region] = client
  232. t.clientsByBucket[bucket] = client
  233. return client, nil
  234. }
  235. func createClient(conf aws.Config, opts []func(*s3.Options), config *Config) (s3Client, error) {
  236. client := s3.NewFromConfig(conf, opts...)
  237. if config.DecryptionClientEnabled {
  238. kmsClient := kms.NewFromConfig(conf)
  239. keyring := s3CryptoMaterials.NewKmsDecryptOnlyAnyKeyKeyring(kmsClient)
  240. cmm, err := s3CryptoMaterials.NewCryptographicMaterialsManager(keyring)
  241. if err != nil {
  242. return nil, err
  243. }
  244. return s3Crypto.New(client, cmm)
  245. } else {
  246. return client, nil
  247. }
  248. }
  249. func regionFromError(err error) string {
  250. var rerr *awsHttp.ResponseError
  251. if !errors.As(err, &rerr) {
  252. return ""
  253. }
  254. if rerr.Response == nil || rerr.Response.StatusCode != 301 {
  255. return ""
  256. }
  257. return rerr.Response.Header.Get("X-Amz-Bucket-Region")
  258. }
  259. func handleError(req *http.Request, err error) (*http.Response, error) {
  260. var rerr *awsHttp.ResponseError
  261. if !errors.As(err, &rerr) {
  262. return nil, ierrors.Wrap(err, 0)
  263. }
  264. if rerr.Response == nil || rerr.Response.StatusCode < 100 || rerr.Response.StatusCode == 301 {
  265. return nil, ierrors.Wrap(err, 0)
  266. }
  267. return &http.Response{
  268. StatusCode: rerr.Response.StatusCode,
  269. Proto: "HTTP/1.0",
  270. ProtoMajor: 1,
  271. ProtoMinor: 0,
  272. Header: http.Header{"Content-Type": {"text/plain"}},
  273. ContentLength: int64(len(err.Error())),
  274. Body: io.NopCloser(strings.NewReader(err.Error())),
  275. Close: false,
  276. Request: req,
  277. }, nil
  278. }