|
@@ -20,10 +20,9 @@ import (
|
|
|
"github.com/aws/aws-sdk-go-v2/service/s3"
|
|
|
"github.com/aws/aws-sdk-go-v2/service/sts"
|
|
|
|
|
|
- "github.com/imgproxy/imgproxy/v3/config"
|
|
|
+ "github.com/imgproxy/imgproxy/v3/httpheaders"
|
|
|
"github.com/imgproxy/imgproxy/v3/ierrors"
|
|
|
"github.com/imgproxy/imgproxy/v3/transport/common"
|
|
|
- "github.com/imgproxy/imgproxy/v3/transport/generichttp"
|
|
|
)
|
|
|
|
|
|
type s3Client interface {
|
|
@@ -41,33 +40,34 @@ type transport struct {
|
|
|
clientsByBucket map[string]s3Client
|
|
|
|
|
|
mu sync.RWMutex
|
|
|
+
|
|
|
+ config *Config
|
|
|
}
|
|
|
|
|
|
-func New() (http.RoundTripper, error) {
|
|
|
- conf, err := awsConfig.LoadDefaultConfig(context.Background())
|
|
|
- if err != nil {
|
|
|
- return nil, ierrors.Wrap(err, 0, ierrors.WithPrefix("can't load AWS S3 config"))
|
|
|
+func New(config *Config, trans *http.Transport) (http.RoundTripper, error) {
|
|
|
+ if err := config.Validate(); err != nil {
|
|
|
+ return nil, err
|
|
|
}
|
|
|
|
|
|
- trans, err := generichttp.New(false)
|
|
|
+ conf, err := awsConfig.LoadDefaultConfig(context.Background())
|
|
|
if err != nil {
|
|
|
- return nil, err
|
|
|
+ return nil, ierrors.Wrap(err, 0, ierrors.WithPrefix("can't load AWS S3 config"))
|
|
|
}
|
|
|
|
|
|
conf.HTTPClient = &http.Client{Transport: trans}
|
|
|
|
|
|
- if len(config.S3Region) != 0 {
|
|
|
- conf.Region = config.S3Region
|
|
|
+ if len(config.Region) != 0 {
|
|
|
+ conf.Region = config.Region
|
|
|
}
|
|
|
|
|
|
if len(conf.Region) == 0 {
|
|
|
conf.Region = "us-west-1"
|
|
|
}
|
|
|
|
|
|
- if len(config.S3AssumeRoleArn) != 0 {
|
|
|
- creds := stscreds.NewAssumeRoleProvider(sts.NewFromConfig(conf), config.S3AssumeRoleArn, func(o *stscreds.AssumeRoleOptions) {
|
|
|
- if len(config.S3AssumeRoleExternalID) != 0 {
|
|
|
- o.ExternalID = aws.String(config.S3AssumeRoleExternalID)
|
|
|
+ if len(config.AssumeRoleArn) != 0 {
|
|
|
+ creds := stscreds.NewAssumeRoleProvider(sts.NewFromConfig(conf), config.AssumeRoleArn, func(o *stscreds.AssumeRoleOptions) {
|
|
|
+ if len(config.AssumeRoleExternalID) != 0 {
|
|
|
+ o.ExternalID = aws.String(config.AssumeRoleExternalID)
|
|
|
}
|
|
|
})
|
|
|
conf.Credentials = creds
|
|
@@ -79,18 +79,18 @@ func New() (http.RoundTripper, error) {
|
|
|
},
|
|
|
}
|
|
|
|
|
|
- if len(config.S3Endpoint) != 0 {
|
|
|
- endpoint := config.S3Endpoint
|
|
|
+ if len(config.Endpoint) != 0 {
|
|
|
+ endpoint := config.Endpoint
|
|
|
if !strings.HasPrefix(endpoint, "http://") && !strings.HasPrefix(endpoint, "https://") {
|
|
|
endpoint = "http://" + endpoint
|
|
|
}
|
|
|
clientOptions = append(clientOptions, func(o *s3.Options) {
|
|
|
o.BaseEndpoint = aws.String(endpoint)
|
|
|
- o.UsePathStyle = config.S3EndpointUsePathStyle
|
|
|
+ o.UsePathStyle = config.EndpointUsePathStyle
|
|
|
})
|
|
|
}
|
|
|
|
|
|
- client, err := createClient(conf, clientOptions)
|
|
|
+ client, err := createClient(conf, clientOptions, config)
|
|
|
if err != nil {
|
|
|
return nil, ierrors.Wrap(err, 0, ierrors.WithPrefix("can't create S3 client"))
|
|
|
}
|
|
@@ -101,6 +101,7 @@ func New() (http.RoundTripper, error) {
|
|
|
defaultConfig: conf,
|
|
|
clientsByRegion: map[string]s3Client{conf.Region: client},
|
|
|
clientsByBucket: make(map[string]s3Client),
|
|
|
+ config: config,
|
|
|
}, nil
|
|
|
}
|
|
|
|
|
@@ -114,7 +115,7 @@ func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
|
Proto: "HTTP/1.0",
|
|
|
ProtoMajor: 1,
|
|
|
ProtoMinor: 0,
|
|
|
- Header: http.Header{"Content-Type": {"text/plain"}},
|
|
|
+ Header: http.Header{httpheaders.ContentType: {"text/plain"}},
|
|
|
ContentLength: int64(body.Len()),
|
|
|
Body: io.NopCloser(body),
|
|
|
Close: false,
|
|
@@ -136,17 +137,14 @@ func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
|
if r := req.Header.Get("Range"); len(r) != 0 {
|
|
|
input.Range = aws.String(r)
|
|
|
} else {
|
|
|
- if config.ETagEnabled {
|
|
|
- if ifNoneMatch := req.Header.Get("If-None-Match"); len(ifNoneMatch) > 0 {
|
|
|
- input.IfNoneMatch = aws.String(ifNoneMatch)
|
|
|
- }
|
|
|
+ if ifNoneMatch := req.Header.Get("If-None-Match"); len(ifNoneMatch) > 0 {
|
|
|
+ input.IfNoneMatch = aws.String(ifNoneMatch)
|
|
|
}
|
|
|
- if config.LastModifiedEnabled {
|
|
|
- if ifModifiedSince := req.Header.Get("If-Modified-Since"); len(ifModifiedSince) > 0 {
|
|
|
- parsedIfModifiedSince, err := time.Parse(http.TimeFormat, ifModifiedSince)
|
|
|
- if err == nil {
|
|
|
- input.IfModifiedSince = &parsedIfModifiedSince
|
|
|
- }
|
|
|
+
|
|
|
+ if ifModifiedSince := req.Header.Get("If-Modified-Since"); len(ifModifiedSince) > 0 {
|
|
|
+ parsedIfModifiedSince, err := time.Parse(http.TimeFormat, ifModifiedSince)
|
|
|
+ if err == nil {
|
|
|
+ input.IfModifiedSince = &parsedIfModifiedSince
|
|
|
}
|
|
|
}
|
|
|
}
|
|
@@ -183,7 +181,7 @@ func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
|
contentLength = *output.ContentLength
|
|
|
}
|
|
|
|
|
|
- if config.S3DecryptionClientEnabled {
|
|
|
+ if t.config.DecryptionClientEnabled {
|
|
|
if unencryptedContentLength := output.Metadata["X-Amz-Meta-X-Amz-Unencrypted-Content-Length"]; len(unencryptedContentLength) != 0 {
|
|
|
cl, err := strconv.ParseInt(unencryptedContentLength, 10, 64)
|
|
|
if err != nil {
|
|
@@ -195,31 +193,31 @@ func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
|
|
|
|
header := make(http.Header)
|
|
|
if contentLength > 0 {
|
|
|
- header.Set("Content-Length", strconv.FormatInt(contentLength, 10))
|
|
|
+ header.Set(httpheaders.ContentLength, strconv.FormatInt(contentLength, 10))
|
|
|
}
|
|
|
if output.ContentType != nil {
|
|
|
- header.Set("Content-Type", *output.ContentType)
|
|
|
+ header.Set(httpheaders.ContentType, *output.ContentType)
|
|
|
}
|
|
|
if output.ContentEncoding != nil {
|
|
|
- header.Set("Content-Encoding", *output.ContentEncoding)
|
|
|
+ header.Set(httpheaders.ContentEncoding, *output.ContentEncoding)
|
|
|
}
|
|
|
if output.CacheControl != nil {
|
|
|
- header.Set("Cache-Control", *output.CacheControl)
|
|
|
+ header.Set(httpheaders.CacheControl, *output.CacheControl)
|
|
|
}
|
|
|
if output.ExpiresString != nil {
|
|
|
- header.Set("Expires", *output.ExpiresString)
|
|
|
+ header.Set(httpheaders.Expires, *output.ExpiresString)
|
|
|
}
|
|
|
if output.ETag != nil {
|
|
|
- header.Set("ETag", *output.ETag)
|
|
|
+ header.Set(httpheaders.Etag, *output.ETag)
|
|
|
}
|
|
|
if output.LastModified != nil {
|
|
|
- header.Set("Last-Modified", output.LastModified.Format(http.TimeFormat))
|
|
|
+ header.Set(httpheaders.LastModified, output.LastModified.Format(http.TimeFormat))
|
|
|
}
|
|
|
if output.AcceptRanges != nil {
|
|
|
- header.Set("Accept-Ranges", *output.AcceptRanges)
|
|
|
+ header.Set(httpheaders.AcceptRanges, *output.AcceptRanges)
|
|
|
}
|
|
|
if output.ContentRange != nil {
|
|
|
- header.Set("Content-Range", *output.ContentRange)
|
|
|
+ header.Set(httpheaders.ContentRange, *output.ContentRange)
|
|
|
statusCode = http.StatusPartialContent
|
|
|
}
|
|
|
|
|
@@ -269,7 +267,7 @@ func (t *transport) createBucketClient(bucket, region string) (s3Client, error)
|
|
|
conf := t.defaultConfig.Copy()
|
|
|
conf.Region = region
|
|
|
|
|
|
- client, err := createClient(conf, t.clientOptions)
|
|
|
+ client, err := createClient(conf, t.clientOptions, t.config)
|
|
|
if err != nil {
|
|
|
return nil, ierrors.Wrap(err, 0, ierrors.WithPrefix("can't create regional S3 client"))
|
|
|
}
|
|
@@ -280,10 +278,10 @@ func (t *transport) createBucketClient(bucket, region string) (s3Client, error)
|
|
|
return client, nil
|
|
|
}
|
|
|
|
|
|
-func createClient(conf aws.Config, opts []func(*s3.Options)) (s3Client, error) {
|
|
|
+func createClient(conf aws.Config, opts []func(*s3.Options), config *Config) (s3Client, error) {
|
|
|
client := s3.NewFromConfig(conf, opts...)
|
|
|
|
|
|
- if config.S3DecryptionClientEnabled {
|
|
|
+ if config.DecryptionClientEnabled {
|
|
|
kmsClient := kms.NewFromConfig(conf)
|
|
|
keyring := s3CryptoMaterials.NewKmsDecryptOnlyAnyKeyKeyring(kmsClient)
|
|
|
|