1
0
Эх сурвалжийг харах

Add client-side decryption support to S3

Garen J. Torikian 1 жил өмнө
parent
commit
b384e2bb7f

+ 1 - 0
CHANGELOG.md

@@ -3,6 +3,7 @@
 ## [Unreleased]
 ### Add
 - Add `status_codes_total` counter to Prometheus metrics.
+- Add client-side decryprion support for S3 integration.
 - (pro) Add the `IMGPROXY_VIDEO_THUMBNAIL_KEYFRAMES` config and the [video_thumbnail_keyframes](https://docs.imgproxy.net/latest/generating_the_url?id=video-thumbnail-keyframes) processing option.
 - (pro) Add the [video_thumbnail_tile](https://docs.imgproxy.net/latest/generating_the_url?id=video-thumbnail-tile) processing option.
 

+ 8 - 5
config/config.go

@@ -99,11 +99,12 @@ var (
 
 	LocalFileSystemRoot string
 
-	S3Enabled       bool
-	S3Region        string
-	S3Endpoint      string
-	S3AssumeRoleArn string
-	S3MultiRegion   bool
+	S3Enabled                 bool
+	S3Region                  string
+	S3Endpoint                string
+	S3AssumeRoleArn           string
+	S3MultiRegion             bool
+	S3DecryptionClientEnabled bool
 
 	GCSEnabled  bool
 	GCSKey      string
@@ -300,6 +301,7 @@ func Reset() {
 	S3Endpoint = ""
 	S3AssumeRoleArn = ""
 	S3MultiRegion = false
+	S3DecryptionClientEnabled = false
 	GCSEnabled = false
 	GCSKey = ""
 	ABSEnabled = false
@@ -501,6 +503,7 @@ func Configure() error {
 	configurators.String(&S3Endpoint, "IMGPROXY_S3_ENDPOINT")
 	configurators.String(&S3AssumeRoleArn, "IMGPROXY_S3_ASSUME_ROLE_ARN")
 	configurators.Bool(&S3MultiRegion, "IMGPROXY_S3_MULTI_REGION")
+	configurators.Bool(&S3DecryptionClientEnabled, "IMGPROXY_S3_USE_DECRYPTION_CLIENT")
 
 	configurators.Bool(&GCSEnabled, "IMGPROXY_USE_GCS")
 	configurators.String(&GCSKey, "IMGPROXY_GCS_KEY")

+ 78 - 21
transport/s3/s3.go

@@ -4,7 +4,8 @@ import (
 	"context"
 	"fmt"
 	"io"
-	http "net/http"
+	"net/http"
+	"strconv"
 	"strings"
 	"sync"
 	"time"
@@ -14,20 +15,27 @@ import (
 	"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
 	"github.com/aws/aws-sdk-go/aws/request"
 	"github.com/aws/aws-sdk-go/aws/session"
+	"github.com/aws/aws-sdk-go/service/kms"
 	"github.com/aws/aws-sdk-go/service/s3"
+	"github.com/aws/aws-sdk-go/service/s3/s3crypto"
 	"github.com/aws/aws-sdk-go/service/s3/s3manager"
 
 	"github.com/imgproxy/imgproxy/v3/config"
 	defaultTransport "github.com/imgproxy/imgproxy/v3/transport"
 )
 
+type s3Client interface {
+	GetObjectRequest(input *s3.GetObjectInput) (req *request.Request, output *s3.GetObjectOutput)
+}
+
 // transport implements RoundTripper for the 's3' protocol.
 type transport struct {
 	session       *session.Session
-	defaultClient *s3.S3
+	defaultClient s3Client
+	defaultConfig *aws.Config
 
-	clientsByRegion map[string]*s3.S3
-	clientsByBucket map[string]*s3.S3
+	clientsByRegion map[string]s3Client
+	clientsByBucket map[string]s3Client
 
 	mu sync.RWMutex
 }
@@ -49,7 +57,7 @@ func New() (http.RoundTripper, error) {
 
 	sess, err := session.NewSession()
 	if err != nil {
-		return nil, fmt.Errorf("Can't create S3 session: %s", err)
+		return nil, fmt.Errorf("can't create S3 session: %s", err)
 	}
 
 	if len(config.S3Region) != 0 {
@@ -64,19 +72,19 @@ func New() (http.RoundTripper, error) {
 		conf.Credentials = stscreds.NewCredentials(sess, config.S3AssumeRoleArn)
 	}
 
-	client := s3.New(sess, conf)
-
-	clientRegion := "us-west-1"
-	if client.Config.Region != nil {
-		clientRegion = *client.Config.Region
+	client, err := createClient(sess, conf)
+	if err != nil {
+		return nil, fmt.Errorf("can't create S3 client: %s", err)
 	}
 
-	return &transport{
-		session:       sess,
-		defaultClient: client,
+	clientRegion := *sess.Config.Region
 
-		clientsByRegion: map[string]*s3.S3{clientRegion: client},
-		clientsByBucket: make(map[string]*s3.S3),
+	return &transport{
+		session:         sess,
+		defaultClient:   client,
+		defaultConfig:   conf,
+		clientsByRegion: map[string]s3Client{clientRegion: client},
+		clientsByBucket: make(map[string]s3Client),
 	}, nil
 }
 
@@ -113,7 +121,7 @@ func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
 		return handleError(req, err)
 	}
 
-	s3req, _ := client.GetObjectRequest(input)
+	s3req, objectOutput := client.GetObjectRequest(input)
 	s3req.SetContext(req.Context())
 
 	if err := s3req.Send(); err != nil {
@@ -124,15 +132,27 @@ func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
 		return handleError(req, err)
 	}
 
+	if config.S3DecryptionClientEnabled {
+		s3req.HTTPResponse.Body = objectOutput.Body
+
+		if unencryptedContentLength := s3req.HTTPResponse.Header.Get("X-Amz-Meta-X-Amz-Unencrypted-Content-Length"); len(unencryptedContentLength) != 0 {
+			contentLength, err := strconv.ParseInt(unencryptedContentLength, 10, 64)
+			if err != nil {
+				handleError(req, err)
+			}
+			s3req.HTTPResponse.ContentLength = contentLength
+		}
+	}
+
 	return s3req.HTTPResponse, nil
 }
 
-func (t *transport) getClient(ctx context.Context, bucket string) (*s3.S3, error) {
+func (t *transport) getClient(ctx context.Context, bucket string) (s3Client, error) {
 	if !config.S3MultiRegion {
 		return t.defaultClient, nil
 	}
 
-	var client *s3.S3
+	var client s3Client
 
 	func() {
 		t.mu.RLock()
@@ -152,7 +172,7 @@ func (t *transport) getClient(ctx context.Context, bucket string) (*s3.S3, error
 		return client, nil
 	}
 
-	region, err := s3manager.GetBucketRegionWithClient(ctx, t.defaultClient, bucket)
+	region, err := s3manager.GetBucketRegion(ctx, t.session, bucket, *t.session.Config.Region)
 	if err != nil {
 		return nil, err
 	}
@@ -162,10 +182,13 @@ func (t *transport) getClient(ctx context.Context, bucket string) (*s3.S3, error
 		return client, nil
 	}
 
-	conf := t.defaultClient.Config.Copy()
+	conf := t.defaultConfig.Copy()
 	conf.Region = aws.String(region)
 
-	client = s3.New(t.session, conf)
+	client, err = createClient(t.session, conf)
+	if err != nil {
+		return nil, fmt.Errorf("can't create regional S3 client: %s", err)
+	}
 
 	t.clientsByRegion[region] = client
 	t.clientsByBucket[bucket] = client
@@ -173,6 +196,40 @@ func (t *transport) getClient(ctx context.Context, bucket string) (*s3.S3, error
 	return client, nil
 }
 
+func createClient(sess *session.Session, conf *aws.Config) (s3Client, error) {
+	if config.S3DecryptionClientEnabled {
+		// `s3crypto.NewDecryptionClientV2` doesn't accept additional configs, so we
+		// need to copy the session with an additional config
+		sess = sess.Copy(conf)
+
+		cryptoRegistry, err := createCryptoRegistry(sess)
+		if err != nil {
+			return nil, err
+		}
+
+		return s3crypto.NewDecryptionClientV2(sess, cryptoRegistry)
+	} else {
+		return s3.New(sess, conf), nil
+	}
+}
+
+func createCryptoRegistry(sess *session.Session) (*s3crypto.CryptoRegistry, error) {
+	kmsClient := kms.New(sess)
+
+	cr := s3crypto.NewCryptoRegistry()
+	if err := s3crypto.RegisterKMSWrapWithAnyCMK(cr, kmsClient); err != nil {
+		return nil, err
+	}
+	if err := s3crypto.RegisterKMSContextWrapWithAnyCMK(cr, kmsClient); err != nil {
+		return nil, err
+	}
+	if err := s3crypto.RegisterAESGCMContentCipher(cr); err != nil {
+		return nil, err
+	}
+
+	return cr, nil
+}
+
 func handleError(req *http.Request, err error) (*http.Response, error) {
 	if s3err, ok := err.(awserr.Error); ok && s3err.Code() == request.CanceledErrorCode {
 		if e := s3err.OrigErr(); e != nil {

+ 5 - 2
transport/s3/s3_test.go

@@ -50,15 +50,18 @@ func (s *S3TestSuite) SetupSuite() {
 	svc, err := s.transport.(*transport).getClient(context.Background(), "test")
 	require.Nil(s.T(), err)
 	require.NotNil(s.T(), svc)
+	require.IsType(s.T(), &s3.S3{}, svc)
 
-	_, err = svc.PutObject(&s3.PutObjectInput{
+	client := svc.(*s3.S3)
+
+	_, err = client.PutObject(&s3.PutObjectInput{
 		Body:   bytes.NewReader(make([]byte, 32)),
 		Bucket: aws.String("test"),
 		Key:    aws.String("foo/test.png"),
 	})
 	require.Nil(s.T(), err)
 
-	obj, err := svc.GetObject(&s3.GetObjectInput{
+	obj, err := client.GetObject(&s3.GetObjectInput{
 		Bucket: aws.String("test"),
 		Key:    aws.String("foo/test.png"),
 	})