123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335 |
- package s3
- import (
- "context"
- "errors"
- "io"
- "net/http"
- "strconv"
- "strings"
- "sync"
- "time"
- s3Crypto "github.com/aws/amazon-s3-encryption-client-go/v3/client"
- s3CryptoMaterials "github.com/aws/amazon-s3-encryption-client-go/v3/materials"
- "github.com/aws/aws-sdk-go-v2/aws"
- awsHttp "github.com/aws/aws-sdk-go-v2/aws/transport/http"
- awsConfig "github.com/aws/aws-sdk-go-v2/config"
- "github.com/aws/aws-sdk-go-v2/credentials/stscreds"
- "github.com/aws/aws-sdk-go-v2/service/kms"
- "github.com/aws/aws-sdk-go-v2/service/s3"
- "github.com/aws/aws-sdk-go-v2/service/sts"
- "github.com/imgproxy/imgproxy/v3/config"
- "github.com/imgproxy/imgproxy/v3/ierrors"
- defaultTransport "github.com/imgproxy/imgproxy/v3/transport"
- "github.com/imgproxy/imgproxy/v3/transport/common"
- )
- type s3Client interface {
- GetObject(ctx context.Context, input *s3.GetObjectInput, opts ...func(*s3.Options)) (*s3.GetObjectOutput, error)
- }
- // transport implements RoundTripper for the 's3' protocol.
- type transport struct {
- clientOptions []func(*s3.Options)
- defaultClient s3Client
- defaultConfig aws.Config
- clientsByRegion map[string]s3Client
- clientsByBucket map[string]s3Client
- mu sync.RWMutex
- }
- func New() (http.RoundTripper, error) {
- conf, err := awsConfig.LoadDefaultConfig(context.Background())
- if err != nil {
- return nil, ierrors.Wrap(err, 0, ierrors.WithPrefix("can't load AWS S3 config"))
- }
- trans, err := defaultTransport.New(false)
- if err != nil {
- return nil, err
- }
- conf.HTTPClient = &http.Client{Transport: trans}
- if len(config.S3Region) != 0 {
- conf.Region = config.S3Region
- }
- if len(conf.Region) == 0 {
- conf.Region = "us-west-1"
- }
- if len(config.S3AssumeRoleArn) != 0 {
- creds := stscreds.NewAssumeRoleProvider(sts.NewFromConfig(conf), config.S3AssumeRoleArn, func(o *stscreds.AssumeRoleOptions) {
- if len(config.S3AssumeRoleExternalID) != 0 {
- o.ExternalID = aws.String(config.S3AssumeRoleExternalID)
- }
- })
- conf.Credentials = creds
- }
- clientOptions := []func(*s3.Options){
- func(o *s3.Options) {
- o.DisableLogOutputChecksumValidationSkipped = true
- },
- }
- if len(config.S3Endpoint) != 0 {
- endpoint := config.S3Endpoint
- if !strings.HasPrefix(endpoint, "http://") && !strings.HasPrefix(endpoint, "https://") {
- endpoint = "http://" + endpoint
- }
- clientOptions = append(clientOptions, func(o *s3.Options) {
- o.BaseEndpoint = aws.String(endpoint)
- o.UsePathStyle = config.S3EndpointUsePathStyle
- })
- }
- client, err := createClient(conf, clientOptions)
- if err != nil {
- return nil, ierrors.Wrap(err, 0, ierrors.WithPrefix("can't create S3 client"))
- }
- return &transport{
- clientOptions: clientOptions,
- defaultClient: client,
- defaultConfig: conf,
- clientsByRegion: map[string]s3Client{conf.Region: client},
- clientsByBucket: make(map[string]s3Client),
- }, nil
- }
- func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
- bucket, key, query := common.GetBucketAndKey(req.URL)
- if len(bucket) == 0 || len(key) == 0 {
- body := strings.NewReader("Invalid S3 URL: bucket name or object key is empty")
- return &http.Response{
- StatusCode: http.StatusNotFound,
- Proto: "HTTP/1.0",
- ProtoMajor: 1,
- ProtoMinor: 0,
- Header: http.Header{"Content-Type": {"text/plain"}},
- ContentLength: int64(body.Len()),
- Body: io.NopCloser(body),
- Close: false,
- Request: req,
- }, nil
- }
- input := &s3.GetObjectInput{
- Bucket: aws.String(bucket),
- Key: aws.String(key),
- }
- if len(query) > 0 {
- input.VersionId = aws.String(query)
- }
- statusCode := http.StatusOK
- if r := req.Header.Get("Range"); len(r) != 0 {
- input.Range = aws.String(r)
- } else {
- if config.ETagEnabled {
- if ifNoneMatch := req.Header.Get("If-None-Match"); len(ifNoneMatch) > 0 {
- input.IfNoneMatch = aws.String(ifNoneMatch)
- }
- }
- if config.LastModifiedEnabled {
- if ifModifiedSince := req.Header.Get("If-Modified-Since"); len(ifModifiedSince) > 0 {
- parsedIfModifiedSince, err := time.Parse(http.TimeFormat, ifModifiedSince)
- if err == nil {
- input.IfModifiedSince = &parsedIfModifiedSince
- }
- }
- }
- }
- client := t.getBucketClient(bucket)
- output, err := client.GetObject(req.Context(), input)
- defer func() {
- if err != nil && output != nil && output.Body != nil {
- output.Body.Close()
- }
- }()
- if err != nil {
- // Check if the error is the region mismatch error.
- // If so, create a new client with the correct region and retry the request.
- if region := regionFromError(err); len(region) != 0 {
- client, err = t.createBucketClient(bucket, region)
- if err != nil {
- return handleError(req, err)
- }
- output, err = client.GetObject(req.Context(), input)
- }
- }
- if err != nil {
- return handleError(req, err)
- }
- contentLength := int64(-1)
- if output.ContentLength != nil {
- contentLength = *output.ContentLength
- }
- if config.S3DecryptionClientEnabled {
- if unencryptedContentLength := output.Metadata["X-Amz-Meta-X-Amz-Unencrypted-Content-Length"]; len(unencryptedContentLength) != 0 {
- cl, err := strconv.ParseInt(unencryptedContentLength, 10, 64)
- if err != nil {
- handleError(req, err)
- }
- contentLength = cl
- }
- }
- header := make(http.Header)
- if contentLength > 0 {
- header.Set("Content-Length", strconv.FormatInt(contentLength, 10))
- }
- if output.ContentType != nil {
- header.Set("Content-Type", *output.ContentType)
- }
- if output.ContentEncoding != nil {
- header.Set("Content-Encoding", *output.ContentEncoding)
- }
- if output.CacheControl != nil {
- header.Set("Cache-Control", *output.CacheControl)
- }
- if output.ExpiresString != nil {
- header.Set("Expires", *output.ExpiresString)
- }
- if output.ETag != nil {
- header.Set("ETag", *output.ETag)
- }
- if output.LastModified != nil {
- header.Set("Last-Modified", output.LastModified.Format(http.TimeFormat))
- }
- if output.AcceptRanges != nil {
- header.Set("Accept-Ranges", *output.AcceptRanges)
- }
- if output.ContentRange != nil {
- header.Set("Content-Range", *output.ContentRange)
- statusCode = http.StatusPartialContent
- }
- return &http.Response{
- StatusCode: statusCode,
- Proto: "HTTP/1.0",
- ProtoMajor: 1,
- ProtoMinor: 0,
- Header: header,
- ContentLength: contentLength,
- Body: output.Body,
- Close: true,
- Request: req,
- }, nil
- }
- func (t *transport) getBucketClient(bucket string) s3Client {
- var client s3Client
- func() {
- t.mu.RLock()
- defer t.mu.RUnlock()
- client = t.clientsByBucket[bucket]
- }()
- if client != nil {
- return client
- }
- return t.defaultClient
- }
- func (t *transport) createBucketClient(bucket, region string) (s3Client, error) {
- t.mu.Lock()
- defer t.mu.Unlock()
- // Check again if someone did this before us
- if client := t.clientsByBucket[bucket]; client != nil {
- return client, nil
- }
- if client := t.clientsByRegion[region]; client != nil {
- t.clientsByBucket[bucket] = client
- return client, nil
- }
- conf := t.defaultConfig.Copy()
- conf.Region = region
- client, err := createClient(conf, t.clientOptions)
- if err != nil {
- return nil, ierrors.Wrap(err, 0, ierrors.WithPrefix("can't create regional S3 client"))
- }
- t.clientsByRegion[region] = client
- t.clientsByBucket[bucket] = client
- return client, nil
- }
- func createClient(conf aws.Config, opts []func(*s3.Options)) (s3Client, error) {
- client := s3.NewFromConfig(conf, opts...)
- if config.S3DecryptionClientEnabled {
- kmsClient := kms.NewFromConfig(conf)
- keyring := s3CryptoMaterials.NewKmsDecryptOnlyAnyKeyKeyring(kmsClient)
- cmm, err := s3CryptoMaterials.NewCryptographicMaterialsManager(keyring)
- if err != nil {
- return nil, err
- }
- return s3Crypto.New(client, cmm)
- } else {
- return client, nil
- }
- }
- func regionFromError(err error) string {
- var rerr *awsHttp.ResponseError
- if !errors.As(err, &rerr) {
- return ""
- }
- if rerr.Response == nil || rerr.Response.StatusCode != 301 {
- return ""
- }
- return rerr.Response.Header.Get("X-Amz-Bucket-Region")
- }
- func handleError(req *http.Request, err error) (*http.Response, error) {
- var rerr *awsHttp.ResponseError
- if !errors.As(err, &rerr) {
- return nil, ierrors.Wrap(err, 0)
- }
- if rerr.Response == nil || rerr.Response.StatusCode < 100 || rerr.Response.StatusCode == 301 {
- return nil, ierrors.Wrap(err, 0)
- }
- return &http.Response{
- StatusCode: rerr.Response.StatusCode,
- Proto: "HTTP/1.0",
- ProtoMajor: 1,
- ProtoMinor: 0,
- Header: http.Header{"Content-Type": {"text/plain"}},
- ContentLength: int64(len(err.Error())),
- Body: io.NopCloser(strings.NewReader(err.Error())),
- Close: false,
- Request: req,
- }, nil
- }
|