|
@@ -2,6 +2,7 @@ package s3
|
|
|
|
|
|
import (
|
|
|
"context"
|
|
|
+ "errors"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
"net/http"
|
|
@@ -10,29 +11,32 @@ import (
|
|
|
"sync"
|
|
|
"time"
|
|
|
|
|
|
- "github.com/aws/aws-sdk-go/aws"
|
|
|
- "github.com/aws/aws-sdk-go/aws/awserr"
|
|
|
- "github.com/aws/aws-sdk-go/aws/credentials/stscreds"
|
|
|
- "github.com/aws/aws-sdk-go/aws/request"
|
|
|
- "github.com/aws/aws-sdk-go/aws/session"
|
|
|
- "github.com/aws/aws-sdk-go/service/kms"
|
|
|
- "github.com/aws/aws-sdk-go/service/s3"
|
|
|
- "github.com/aws/aws-sdk-go/service/s3/s3crypto"
|
|
|
- "github.com/aws/aws-sdk-go/service/s3/s3manager"
|
|
|
+ s3Crypto "github.com/aws/amazon-s3-encryption-client-go/v3/client"
|
|
|
+ s3CryptoMaterials "github.com/aws/amazon-s3-encryption-client-go/v3/materials"
|
|
|
+ "github.com/aws/aws-sdk-go-v2/aws"
|
|
|
+ awsHttp "github.com/aws/aws-sdk-go-v2/aws/transport/http"
|
|
|
+ awsConfig "github.com/aws/aws-sdk-go-v2/config"
|
|
|
+ "github.com/aws/aws-sdk-go-v2/credentials/stscreds"
|
|
|
+ s3Manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
|
|
|
+ "github.com/aws/aws-sdk-go-v2/service/kms"
|
|
|
+ "github.com/aws/aws-sdk-go-v2/service/s3"
|
|
|
+ "github.com/aws/aws-sdk-go-v2/service/sts"
|
|
|
|
|
|
"github.com/imgproxy/imgproxy/v3/config"
|
|
|
defaultTransport "github.com/imgproxy/imgproxy/v3/transport"
|
|
|
)
|
|
|
|
|
|
type s3Client interface {
|
|
|
- GetObjectRequest(input *s3.GetObjectInput) (req *request.Request, output *s3.GetObjectOutput)
|
|
|
+ GetObject(ctx context.Context, input *s3.GetObjectInput, opts ...func(*s3.Options)) (*s3.GetObjectOutput, error)
|
|
|
+ HeadBucket(ctx context.Context, input *s3.HeadBucketInput, optFns ...func(*s3.Options)) (*s3.HeadBucketOutput, error)
|
|
|
}
|
|
|
|
|
|
// transport implements RoundTripper for the 's3' protocol.
|
|
|
type transport struct {
|
|
|
- session *session.Session
|
|
|
+ clientOptions []func(*s3.Options)
|
|
|
+
|
|
|
defaultClient s3Client
|
|
|
- defaultConfig *aws.Config
|
|
|
+ defaultConfig aws.Config
|
|
|
|
|
|
clientsByRegion map[string]s3Client
|
|
|
clientsByBucket map[string]s3Client
|
|
@@ -41,7 +45,10 @@ type transport struct {
|
|
|
}
|
|
|
|
|
|
func New() (http.RoundTripper, error) {
|
|
|
- conf := aws.NewConfig()
|
|
|
+ conf, err := awsConfig.LoadDefaultConfig(context.Background())
|
|
|
+ if err != nil {
|
|
|
+ return nil, fmt.Errorf("can't load AWS S3 config: %s", err)
|
|
|
+ }
|
|
|
|
|
|
trans, err := defaultTransport.New(false)
|
|
|
if err != nil {
|
|
@@ -50,40 +57,38 @@ func New() (http.RoundTripper, error) {
|
|
|
|
|
|
conf.HTTPClient = &http.Client{Transport: trans}
|
|
|
|
|
|
- if len(config.S3Endpoint) != 0 {
|
|
|
- conf.Endpoint = aws.String(config.S3Endpoint)
|
|
|
- conf.S3ForcePathStyle = aws.Bool(true)
|
|
|
+ if len(config.S3Region) != 0 {
|
|
|
+ conf.Region = config.S3Region
|
|
|
}
|
|
|
|
|
|
- sess, err := session.NewSession()
|
|
|
- if err != nil {
|
|
|
- return nil, fmt.Errorf("can't create S3 session: %s", err)
|
|
|
+ if len(conf.Region) == 0 {
|
|
|
+ conf.Region = "us-west-1"
|
|
|
}
|
|
|
|
|
|
- if len(config.S3Region) != 0 {
|
|
|
- sess.Config.Region = aws.String(config.S3Region)
|
|
|
+ if len(config.S3AssumeRoleArn) != 0 {
|
|
|
+ creds := stscreds.NewAssumeRoleProvider(sts.NewFromConfig(conf), config.S3AssumeRoleArn)
|
|
|
+ conf.Credentials = creds
|
|
|
}
|
|
|
|
|
|
- if sess.Config.Region == nil || len(*sess.Config.Region) == 0 {
|
|
|
- sess.Config.Region = aws.String("us-west-1")
|
|
|
- }
|
|
|
+ clientOptions := []func(*s3.Options){}
|
|
|
|
|
|
- if len(config.S3AssumeRoleArn) != 0 {
|
|
|
- conf.Credentials = stscreds.NewCredentials(sess, config.S3AssumeRoleArn)
|
|
|
+ if len(config.S3Endpoint) != 0 {
|
|
|
+ clientOptions = append(clientOptions, func(o *s3.Options) {
|
|
|
+ o.BaseEndpoint = aws.String(config.S3Endpoint)
|
|
|
+ o.UsePathStyle = true
|
|
|
+ })
|
|
|
}
|
|
|
|
|
|
- client, err := createClient(sess, conf)
|
|
|
+ client, err := createClient(conf, clientOptions)
|
|
|
if err != nil {
|
|
|
return nil, fmt.Errorf("can't create S3 client: %s", err)
|
|
|
}
|
|
|
|
|
|
- clientRegion := *sess.Config.Region
|
|
|
-
|
|
|
return &transport{
|
|
|
- session: sess,
|
|
|
+ clientOptions: clientOptions,
|
|
|
defaultClient: client,
|
|
|
defaultConfig: conf,
|
|
|
- clientsByRegion: map[string]s3Client{clientRegion: client},
|
|
|
+ clientsByRegion: map[string]s3Client{conf.Region: client},
|
|
|
clientsByBucket: make(map[string]s3Client),
|
|
|
}, nil
|
|
|
}
|
|
@@ -91,13 +96,15 @@ func New() (http.RoundTripper, error) {
|
|
|
func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
|
input := &s3.GetObjectInput{
|
|
|
Bucket: aws.String(req.URL.Host),
|
|
|
- Key: aws.String(req.URL.Path),
|
|
|
+ Key: aws.String(strings.TrimPrefix(req.URL.Path, "/")),
|
|
|
}
|
|
|
|
|
|
if len(req.URL.RawQuery) > 0 {
|
|
|
input.VersionId = aws.String(req.URL.RawQuery)
|
|
|
}
|
|
|
|
|
|
+ statusCode := http.StatusOK
|
|
|
+
|
|
|
if r := req.Header.Get("Range"); len(r) != 0 {
|
|
|
input.Range = aws.String(r)
|
|
|
} else {
|
|
@@ -121,30 +128,71 @@ func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
|
return handleError(req, err)
|
|
|
}
|
|
|
|
|
|
- s3req, objectOutput := client.GetObjectRequest(input)
|
|
|
- s3req.SetContext(req.Context())
|
|
|
-
|
|
|
- if err := s3req.Send(); err != nil {
|
|
|
- if s3req.HTTPResponse != nil && s3req.HTTPResponse.Body != nil {
|
|
|
- s3req.HTTPResponse.Body.Close()
|
|
|
+ output, err := client.GetObject(req.Context(), input)
|
|
|
+ if err != nil {
|
|
|
+ if output != nil && output.Body != nil {
|
|
|
+ output.Body.Close()
|
|
|
}
|
|
|
|
|
|
return handleError(req, err)
|
|
|
}
|
|
|
|
|
|
- if config.S3DecryptionClientEnabled {
|
|
|
- s3req.HTTPResponse.Body = objectOutput.Body
|
|
|
+ contentLength := int64(-1)
|
|
|
+ if output.ContentLength != nil {
|
|
|
+ contentLength = *output.ContentLength
|
|
|
+ }
|
|
|
|
|
|
- if unencryptedContentLength := s3req.HTTPResponse.Header.Get("X-Amz-Meta-X-Amz-Unencrypted-Content-Length"); len(unencryptedContentLength) != 0 {
|
|
|
- contentLength, err := strconv.ParseInt(unencryptedContentLength, 10, 64)
|
|
|
+ if config.S3DecryptionClientEnabled {
|
|
|
+ if unencryptedContentLength := output.Metadata["X-Amz-Meta-X-Amz-Unencrypted-Content-Length"]; len(unencryptedContentLength) != 0 {
|
|
|
+ cl, err := strconv.ParseInt(unencryptedContentLength, 10, 64)
|
|
|
if err != nil {
|
|
|
handleError(req, err)
|
|
|
}
|
|
|
- s3req.HTTPResponse.ContentLength = contentLength
|
|
|
+ contentLength = cl
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- return s3req.HTTPResponse, nil
|
|
|
+ header := make(http.Header)
|
|
|
+ if contentLength > 0 {
|
|
|
+ header.Set("Content-Length", strconv.FormatInt(contentLength, 10))
|
|
|
+ }
|
|
|
+ if output.ContentType != nil {
|
|
|
+ header.Set("Content-Type", *output.ContentType)
|
|
|
+ }
|
|
|
+ if output.ContentEncoding != nil {
|
|
|
+ header.Set("Content-Encoding", *output.ContentEncoding)
|
|
|
+ }
|
|
|
+ if output.CacheControl != nil {
|
|
|
+ header.Set("Cache-Control", *output.CacheControl)
|
|
|
+ }
|
|
|
+ if output.Expires != nil {
|
|
|
+ header.Set("Expires", output.Expires.Format(http.TimeFormat))
|
|
|
+ }
|
|
|
+ if output.ETag != nil {
|
|
|
+ header.Set("ETag", *output.ETag)
|
|
|
+ }
|
|
|
+ if output.LastModified != nil {
|
|
|
+ header.Set("Last-Modified", output.LastModified.Format(http.TimeFormat))
|
|
|
+ }
|
|
|
+ if output.AcceptRanges != nil {
|
|
|
+ header.Set("Accept-Ranges", *output.AcceptRanges)
|
|
|
+ }
|
|
|
+ if output.ContentRange != nil {
|
|
|
+ header.Set("Content-Range", *output.ContentRange)
|
|
|
+ statusCode = http.StatusPartialContent
|
|
|
+ }
|
|
|
+
|
|
|
+ return &http.Response{
|
|
|
+ StatusCode: statusCode,
|
|
|
+ Proto: "HTTP/1.0",
|
|
|
+ ProtoMajor: 1,
|
|
|
+ ProtoMinor: 0,
|
|
|
+ Header: header,
|
|
|
+ ContentLength: contentLength,
|
|
|
+ Body: output.Body,
|
|
|
+ Close: true,
|
|
|
+ Request: req,
|
|
|
+ }, nil
|
|
|
}
|
|
|
|
|
|
func (t *transport) getClient(ctx context.Context, bucket string) (s3Client, error) {
|
|
@@ -172,9 +220,13 @@ func (t *transport) getClient(ctx context.Context, bucket string) (s3Client, err
|
|
|
return client, nil
|
|
|
}
|
|
|
|
|
|
- region, err := s3manager.GetBucketRegion(ctx, t.session, bucket, *t.session.Config.Region)
|
|
|
+ region, err := s3Manager.GetBucketRegion(ctx, t.defaultClient, bucket)
|
|
|
if err != nil {
|
|
|
- return nil, err
|
|
|
+ return nil, fmt.Errorf("can't get bucket region: %s", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ if len(region) == 0 {
|
|
|
+ region = t.defaultConfig.Region
|
|
|
}
|
|
|
|
|
|
if client = t.clientsByRegion[region]; client != nil {
|
|
@@ -183,9 +235,9 @@ func (t *transport) getClient(ctx context.Context, bucket string) (s3Client, err
|
|
|
}
|
|
|
|
|
|
conf := t.defaultConfig.Copy()
|
|
|
- conf.Region = aws.String(region)
|
|
|
+ conf.Region = region
|
|
|
|
|
|
- client, err = createClient(t.session, conf)
|
|
|
+ client, err = createClient(conf, t.clientOptions)
|
|
|
if err != nil {
|
|
|
return nil, fmt.Errorf("can't create regional S3 client: %s", err)
|
|
|
}
|
|
@@ -196,53 +248,38 @@ func (t *transport) getClient(ctx context.Context, bucket string) (s3Client, err
|
|
|
return client, nil
|
|
|
}
|
|
|
|
|
|
-func createClient(sess *session.Session, conf *aws.Config) (s3Client, error) {
|
|
|
+func createClient(conf aws.Config, opts []func(*s3.Options)) (s3Client, error) {
|
|
|
+ client := s3.NewFromConfig(conf, opts...)
|
|
|
+
|
|
|
if config.S3DecryptionClientEnabled {
|
|
|
- // `s3crypto.NewDecryptionClientV2` doesn't accept additional configs, so we
|
|
|
- // need to copy the session with an additional config
|
|
|
- sess = sess.Copy(conf)
|
|
|
+ kmsClient := kms.NewFromConfig(conf)
|
|
|
+ keyring := s3CryptoMaterials.NewKmsDecryptOnlyAnyKeyKeyring(kmsClient)
|
|
|
|
|
|
- cryptoRegistry, err := createCryptoRegistry(sess)
|
|
|
+ cmm, err := s3CryptoMaterials.NewCryptographicMaterialsManager(keyring)
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
|
|
|
- return s3crypto.NewDecryptionClientV2(sess, cryptoRegistry)
|
|
|
+ return s3Crypto.New(client, cmm)
|
|
|
} else {
|
|
|
- return s3.New(sess, conf), nil
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
-func createCryptoRegistry(sess *session.Session) (*s3crypto.CryptoRegistry, error) {
|
|
|
- kmsClient := kms.New(sess)
|
|
|
-
|
|
|
- cr := s3crypto.NewCryptoRegistry()
|
|
|
- if err := s3crypto.RegisterKMSContextWrapWithAnyCMK(cr, kmsClient); err != nil {
|
|
|
- return nil, err
|
|
|
- }
|
|
|
- if err := s3crypto.RegisterAESGCMContentCipher(cr); err != nil {
|
|
|
- return nil, err
|
|
|
+ return client, nil
|
|
|
}
|
|
|
-
|
|
|
- return cr, nil
|
|
|
}
|
|
|
|
|
|
func handleError(req *http.Request, err error) (*http.Response, error) {
|
|
|
- if s3err, ok := err.(awserr.Error); ok && s3err.Code() == request.CanceledErrorCode {
|
|
|
- if e := s3err.OrigErr(); e != nil {
|
|
|
- return nil, e
|
|
|
- }
|
|
|
+ var rerr *awsHttp.ResponseError
|
|
|
+ if !errors.As(err, &rerr) {
|
|
|
+ return nil, err
|
|
|
}
|
|
|
|
|
|
- s3err, ok := err.(awserr.RequestFailure)
|
|
|
- if !ok || s3err.StatusCode() < 100 || s3err.StatusCode() == 301 {
|
|
|
+ if rerr.Response == nil || rerr.Response.StatusCode < 100 || rerr.Response.StatusCode == 301 {
|
|
|
return nil, err
|
|
|
}
|
|
|
|
|
|
- body := strings.NewReader(s3err.Message())
|
|
|
+ body := strings.NewReader(err.Error())
|
|
|
|
|
|
return &http.Response{
|
|
|
- StatusCode: s3err.StatusCode(),
|
|
|
+ StatusCode: rerr.Response.StatusCode,
|
|
|
Proto: "HTTP/1.0",
|
|
|
ProtoMajor: 1,
|
|
|
ProtoMinor: 0,
|