|
@@ -4,7 +4,8 @@ import (
|
|
|
"context"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
- http "net/http"
|
|
|
+ "net/http"
|
|
|
+ "strconv"
|
|
|
"strings"
|
|
|
"sync"
|
|
|
"time"
|
|
@@ -14,20 +15,27 @@ import (
|
|
|
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
|
|
|
"github.com/aws/aws-sdk-go/aws/request"
|
|
|
"github.com/aws/aws-sdk-go/aws/session"
|
|
|
+ "github.com/aws/aws-sdk-go/service/kms"
|
|
|
"github.com/aws/aws-sdk-go/service/s3"
|
|
|
+ "github.com/aws/aws-sdk-go/service/s3/s3crypto"
|
|
|
"github.com/aws/aws-sdk-go/service/s3/s3manager"
|
|
|
|
|
|
"github.com/imgproxy/imgproxy/v3/config"
|
|
|
defaultTransport "github.com/imgproxy/imgproxy/v3/transport"
|
|
|
)
|
|
|
|
|
|
+type s3Client interface {
|
|
|
+ GetObjectRequest(input *s3.GetObjectInput) (req *request.Request, output *s3.GetObjectOutput)
|
|
|
+}
|
|
|
+
|
|
|
// transport implements RoundTripper for the 's3' protocol.
|
|
|
type transport struct {
|
|
|
session *session.Session
|
|
|
- defaultClient *s3.S3
|
|
|
+ defaultClient s3Client
|
|
|
+ defaultConfig *aws.Config
|
|
|
|
|
|
- clientsByRegion map[string]*s3.S3
|
|
|
- clientsByBucket map[string]*s3.S3
|
|
|
+ clientsByRegion map[string]s3Client
|
|
|
+ clientsByBucket map[string]s3Client
|
|
|
|
|
|
mu sync.RWMutex
|
|
|
}
|
|
@@ -49,7 +57,7 @@ func New() (http.RoundTripper, error) {
|
|
|
|
|
|
sess, err := session.NewSession()
|
|
|
if err != nil {
|
|
|
- return nil, fmt.Errorf("Can't create S3 session: %s", err)
|
|
|
+ return nil, fmt.Errorf("can't create S3 session: %s", err)
|
|
|
}
|
|
|
|
|
|
if len(config.S3Region) != 0 {
|
|
@@ -64,19 +72,19 @@ func New() (http.RoundTripper, error) {
|
|
|
conf.Credentials = stscreds.NewCredentials(sess, config.S3AssumeRoleArn)
|
|
|
}
|
|
|
|
|
|
- client := s3.New(sess, conf)
|
|
|
-
|
|
|
- clientRegion := "us-west-1"
|
|
|
- if client.Config.Region != nil {
|
|
|
- clientRegion = *client.Config.Region
|
|
|
+ client, err := createClient(sess, conf)
|
|
|
+ if err != nil {
|
|
|
+ return nil, fmt.Errorf("can't create S3 client: %s", err)
|
|
|
}
|
|
|
|
|
|
- return &transport{
|
|
|
- session: sess,
|
|
|
- defaultClient: client,
|
|
|
+ clientRegion := *sess.Config.Region
|
|
|
|
|
|
- clientsByRegion: map[string]*s3.S3{clientRegion: client},
|
|
|
- clientsByBucket: make(map[string]*s3.S3),
|
|
|
+ return &transport{
|
|
|
+ session: sess,
|
|
|
+ defaultClient: client,
|
|
|
+ defaultConfig: conf,
|
|
|
+ clientsByRegion: map[string]s3Client{clientRegion: client},
|
|
|
+ clientsByBucket: make(map[string]s3Client),
|
|
|
}, nil
|
|
|
}
|
|
|
|
|
@@ -113,7 +121,7 @@ func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
|
return handleError(req, err)
|
|
|
}
|
|
|
|
|
|
- s3req, _ := client.GetObjectRequest(input)
|
|
|
+ s3req, objectOutput := client.GetObjectRequest(input)
|
|
|
s3req.SetContext(req.Context())
|
|
|
|
|
|
if err := s3req.Send(); err != nil {
|
|
@@ -124,15 +132,27 @@ func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
|
return handleError(req, err)
|
|
|
}
|
|
|
|
|
|
+ if config.S3DecryptionClientEnabled {
|
|
|
+ s3req.HTTPResponse.Body = objectOutput.Body
|
|
|
+
|
|
|
+ if unencryptedContentLength := s3req.HTTPResponse.Header.Get("X-Amz-Meta-X-Amz-Unencrypted-Content-Length"); len(unencryptedContentLength) != 0 {
|
|
|
+ contentLength, err := strconv.ParseInt(unencryptedContentLength, 10, 64)
|
|
|
+ if err != nil {
|
|
|
+ handleError(req, err)
|
|
|
+ }
|
|
|
+ s3req.HTTPResponse.ContentLength = contentLength
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
return s3req.HTTPResponse, nil
|
|
|
}
|
|
|
|
|
|
-func (t *transport) getClient(ctx context.Context, bucket string) (*s3.S3, error) {
|
|
|
+func (t *transport) getClient(ctx context.Context, bucket string) (s3Client, error) {
|
|
|
if !config.S3MultiRegion {
|
|
|
return t.defaultClient, nil
|
|
|
}
|
|
|
|
|
|
- var client *s3.S3
|
|
|
+ var client s3Client
|
|
|
|
|
|
func() {
|
|
|
t.mu.RLock()
|
|
@@ -152,7 +172,7 @@ func (t *transport) getClient(ctx context.Context, bucket string) (*s3.S3, error
|
|
|
return client, nil
|
|
|
}
|
|
|
|
|
|
- region, err := s3manager.GetBucketRegionWithClient(ctx, t.defaultClient, bucket)
|
|
|
+ region, err := s3manager.GetBucketRegion(ctx, t.session, bucket, *t.session.Config.Region)
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
@@ -162,10 +182,13 @@ func (t *transport) getClient(ctx context.Context, bucket string) (*s3.S3, error
|
|
|
return client, nil
|
|
|
}
|
|
|
|
|
|
- conf := t.defaultClient.Config.Copy()
|
|
|
+ conf := t.defaultConfig.Copy()
|
|
|
conf.Region = aws.String(region)
|
|
|
|
|
|
- client = s3.New(t.session, conf)
|
|
|
+ client, err = createClient(t.session, conf)
|
|
|
+ if err != nil {
|
|
|
+ return nil, fmt.Errorf("can't create regional S3 client: %s", err)
|
|
|
+ }
|
|
|
|
|
|
t.clientsByRegion[region] = client
|
|
|
t.clientsByBucket[bucket] = client
|
|
@@ -173,6 +196,40 @@ func (t *transport) getClient(ctx context.Context, bucket string) (*s3.S3, error
|
|
|
return client, nil
|
|
|
}
|
|
|
|
|
|
+func createClient(sess *session.Session, conf *aws.Config) (s3Client, error) {
|
|
|
+ if config.S3DecryptionClientEnabled {
|
|
|
+ // `s3crypto.NewDecryptionClientV2` doesn't accept additional configs, so we
|
|
|
+ // need to copy the session with an additional config
|
|
|
+ sess = sess.Copy(conf)
|
|
|
+
|
|
|
+ cryptoRegistry, err := createCryptoRegistry(sess)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ return s3crypto.NewDecryptionClientV2(sess, cryptoRegistry)
|
|
|
+ } else {
|
|
|
+ return s3.New(sess, conf), nil
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func createCryptoRegistry(sess *session.Session) (*s3crypto.CryptoRegistry, error) {
|
|
|
+ kmsClient := kms.New(sess)
|
|
|
+
|
|
|
+ cr := s3crypto.NewCryptoRegistry()
|
|
|
+ if err := s3crypto.RegisterKMSWrapWithAnyCMK(cr, kmsClient); err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ if err := s3crypto.RegisterKMSContextWrapWithAnyCMK(cr, kmsClient); err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ if err := s3crypto.RegisterAESGCMContentCipher(cr); err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ return cr, nil
|
|
|
+}
|
|
|
+
|
|
|
func handleError(req *http.Request, err error) (*http.Response, error) {
|
|
|
if s3err, ok := err.(awserr.Error); ok && s3err.Code() == request.CanceledErrorCode {
|
|
|
if e := s3err.OrigErr(); e != nil {
|