s3.go 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292
  1. package s3
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "strconv"
  9. "strings"
  10. "sync"
  11. "time"
  12. s3Crypto "github.com/aws/amazon-s3-encryption-client-go/v3/client"
  13. s3CryptoMaterials "github.com/aws/amazon-s3-encryption-client-go/v3/materials"
  14. "github.com/aws/aws-sdk-go-v2/aws"
  15. awsHttp "github.com/aws/aws-sdk-go-v2/aws/transport/http"
  16. awsConfig "github.com/aws/aws-sdk-go-v2/config"
  17. "github.com/aws/aws-sdk-go-v2/credentials/stscreds"
  18. s3Manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
  19. "github.com/aws/aws-sdk-go-v2/service/kms"
  20. "github.com/aws/aws-sdk-go-v2/service/s3"
  21. "github.com/aws/aws-sdk-go-v2/service/sts"
  22. "github.com/imgproxy/imgproxy/v3/config"
  23. defaultTransport "github.com/imgproxy/imgproxy/v3/transport"
  24. )
  25. type s3Client interface {
  26. GetObject(ctx context.Context, input *s3.GetObjectInput, opts ...func(*s3.Options)) (*s3.GetObjectOutput, error)
  27. HeadBucket(ctx context.Context, input *s3.HeadBucketInput, optFns ...func(*s3.Options)) (*s3.HeadBucketOutput, error)
  28. }
  29. // transport implements RoundTripper for the 's3' protocol.
  30. type transport struct {
  31. clientOptions []func(*s3.Options)
  32. defaultClient s3Client
  33. defaultConfig aws.Config
  34. clientsByRegion map[string]s3Client
  35. clientsByBucket map[string]s3Client
  36. mu sync.RWMutex
  37. }
  38. func New() (http.RoundTripper, error) {
  39. conf, err := awsConfig.LoadDefaultConfig(context.Background())
  40. if err != nil {
  41. return nil, fmt.Errorf("can't load AWS S3 config: %s", err)
  42. }
  43. trans, err := defaultTransport.New(false)
  44. if err != nil {
  45. return nil, err
  46. }
  47. conf.HTTPClient = &http.Client{Transport: trans}
  48. if len(config.S3Region) != 0 {
  49. conf.Region = config.S3Region
  50. }
  51. if len(conf.Region) == 0 {
  52. conf.Region = "us-west-1"
  53. }
  54. if len(config.S3AssumeRoleArn) != 0 {
  55. creds := stscreds.NewAssumeRoleProvider(sts.NewFromConfig(conf), config.S3AssumeRoleArn)
  56. conf.Credentials = creds
  57. }
  58. clientOptions := []func(*s3.Options){}
  59. if len(config.S3Endpoint) != 0 {
  60. clientOptions = append(clientOptions, func(o *s3.Options) {
  61. o.BaseEndpoint = aws.String(config.S3Endpoint)
  62. o.UsePathStyle = true
  63. })
  64. }
  65. client, err := createClient(conf, clientOptions)
  66. if err != nil {
  67. return nil, fmt.Errorf("can't create S3 client: %s", err)
  68. }
  69. return &transport{
  70. clientOptions: clientOptions,
  71. defaultClient: client,
  72. defaultConfig: conf,
  73. clientsByRegion: map[string]s3Client{conf.Region: client},
  74. clientsByBucket: make(map[string]s3Client),
  75. }, nil
  76. }
  77. func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
  78. input := &s3.GetObjectInput{
  79. Bucket: aws.String(req.URL.Host),
  80. Key: aws.String(strings.TrimPrefix(req.URL.Path, "/")),
  81. }
  82. if len(req.URL.RawQuery) > 0 {
  83. input.VersionId = aws.String(req.URL.RawQuery)
  84. }
  85. statusCode := http.StatusOK
  86. if r := req.Header.Get("Range"); len(r) != 0 {
  87. input.Range = aws.String(r)
  88. } else {
  89. if config.ETagEnabled {
  90. if ifNoneMatch := req.Header.Get("If-None-Match"); len(ifNoneMatch) > 0 {
  91. input.IfNoneMatch = aws.String(ifNoneMatch)
  92. }
  93. }
  94. if config.LastModifiedEnabled {
  95. if ifModifiedSince := req.Header.Get("If-Modified-Since"); len(ifModifiedSince) > 0 {
  96. parsedIfModifiedSince, err := time.Parse(http.TimeFormat, ifModifiedSince)
  97. if err == nil {
  98. input.IfModifiedSince = &parsedIfModifiedSince
  99. }
  100. }
  101. }
  102. }
  103. client, err := t.getClient(req.Context(), *input.Bucket)
  104. if err != nil {
  105. return handleError(req, err)
  106. }
  107. output, err := client.GetObject(req.Context(), input)
  108. if err != nil {
  109. if output != nil && output.Body != nil {
  110. output.Body.Close()
  111. }
  112. return handleError(req, err)
  113. }
  114. contentLength := int64(-1)
  115. if output.ContentLength != nil {
  116. contentLength = *output.ContentLength
  117. }
  118. if config.S3DecryptionClientEnabled {
  119. if unencryptedContentLength := output.Metadata["X-Amz-Meta-X-Amz-Unencrypted-Content-Length"]; len(unencryptedContentLength) != 0 {
  120. cl, err := strconv.ParseInt(unencryptedContentLength, 10, 64)
  121. if err != nil {
  122. handleError(req, err)
  123. }
  124. contentLength = cl
  125. }
  126. }
  127. header := make(http.Header)
  128. if contentLength > 0 {
  129. header.Set("Content-Length", strconv.FormatInt(contentLength, 10))
  130. }
  131. if output.ContentType != nil {
  132. header.Set("Content-Type", *output.ContentType)
  133. }
  134. if output.ContentEncoding != nil {
  135. header.Set("Content-Encoding", *output.ContentEncoding)
  136. }
  137. if output.CacheControl != nil {
  138. header.Set("Cache-Control", *output.CacheControl)
  139. }
  140. if output.Expires != nil {
  141. header.Set("Expires", output.Expires.Format(http.TimeFormat))
  142. }
  143. if output.ETag != nil {
  144. header.Set("ETag", *output.ETag)
  145. }
  146. if output.LastModified != nil {
  147. header.Set("Last-Modified", output.LastModified.Format(http.TimeFormat))
  148. }
  149. if output.AcceptRanges != nil {
  150. header.Set("Accept-Ranges", *output.AcceptRanges)
  151. }
  152. if output.ContentRange != nil {
  153. header.Set("Content-Range", *output.ContentRange)
  154. statusCode = http.StatusPartialContent
  155. }
  156. return &http.Response{
  157. StatusCode: statusCode,
  158. Proto: "HTTP/1.0",
  159. ProtoMajor: 1,
  160. ProtoMinor: 0,
  161. Header: header,
  162. ContentLength: contentLength,
  163. Body: output.Body,
  164. Close: true,
  165. Request: req,
  166. }, nil
  167. }
  168. func (t *transport) getClient(ctx context.Context, bucket string) (s3Client, error) {
  169. if !config.S3MultiRegion {
  170. return t.defaultClient, nil
  171. }
  172. var client s3Client
  173. func() {
  174. t.mu.RLock()
  175. defer t.mu.RUnlock()
  176. client = t.clientsByBucket[bucket]
  177. }()
  178. if client != nil {
  179. return client, nil
  180. }
  181. t.mu.Lock()
  182. defer t.mu.Unlock()
  183. // Check again if someone did this before us
  184. if client = t.clientsByBucket[bucket]; client != nil {
  185. return client, nil
  186. }
  187. region, err := s3Manager.GetBucketRegion(ctx, t.defaultClient, bucket)
  188. if err != nil {
  189. return nil, fmt.Errorf("can't get bucket region: %s", err)
  190. }
  191. if len(region) == 0 {
  192. region = t.defaultConfig.Region
  193. }
  194. if client = t.clientsByRegion[region]; client != nil {
  195. t.clientsByBucket[bucket] = client
  196. return client, nil
  197. }
  198. conf := t.defaultConfig.Copy()
  199. conf.Region = region
  200. client, err = createClient(conf, t.clientOptions)
  201. if err != nil {
  202. return nil, fmt.Errorf("can't create regional S3 client: %s", err)
  203. }
  204. t.clientsByRegion[region] = client
  205. t.clientsByBucket[bucket] = client
  206. return client, nil
  207. }
  208. func createClient(conf aws.Config, opts []func(*s3.Options)) (s3Client, error) {
  209. client := s3.NewFromConfig(conf, opts...)
  210. if config.S3DecryptionClientEnabled {
  211. kmsClient := kms.NewFromConfig(conf)
  212. keyring := s3CryptoMaterials.NewKmsDecryptOnlyAnyKeyKeyring(kmsClient)
  213. cmm, err := s3CryptoMaterials.NewCryptographicMaterialsManager(keyring)
  214. if err != nil {
  215. return nil, err
  216. }
  217. return s3Crypto.New(client, cmm)
  218. } else {
  219. return client, nil
  220. }
  221. }
  222. func handleError(req *http.Request, err error) (*http.Response, error) {
  223. var rerr *awsHttp.ResponseError
  224. if !errors.As(err, &rerr) {
  225. return nil, err
  226. }
  227. if rerr.Response == nil || rerr.Response.StatusCode < 100 || rerr.Response.StatusCode == 301 {
  228. return nil, err
  229. }
  230. body := strings.NewReader(err.Error())
  231. return &http.Response{
  232. StatusCode: rerr.Response.StatusCode,
  233. Proto: "HTTP/1.0",
  234. ProtoMajor: 1,
  235. ProtoMinor: 0,
  236. Header: http.Header{},
  237. ContentLength: int64(body.Len()),
  238. Body: io.NopCloser(body),
  239. Close: false,
  240. Request: req,
  241. }, nil
  242. }