|
@@ -1,10 +1,12 @@
|
|
|
package s3
|
|
|
|
|
|
import (
|
|
|
+ "context"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
http "net/http"
|
|
|
"strings"
|
|
|
+ "sync"
|
|
|
"time"
|
|
|
|
|
|
"github.com/aws/aws-sdk-go/aws"
|
|
@@ -13,6 +15,7 @@ import (
|
|
|
"github.com/aws/aws-sdk-go/aws/request"
|
|
|
"github.com/aws/aws-sdk-go/aws/session"
|
|
|
"github.com/aws/aws-sdk-go/service/s3"
|
|
|
+ "github.com/aws/aws-sdk-go/service/s3/s3manager"
|
|
|
|
|
|
"github.com/imgproxy/imgproxy/v3/config"
|
|
|
defaultTransport "github.com/imgproxy/imgproxy/v3/transport"
|
|
@@ -20,26 +23,28 @@ import (
|
|
|
|
|
|
// transport implements RoundTripper for the 's3' protocol.
|
|
|
type transport struct {
|
|
|
- svc *s3.S3
|
|
|
+ session *session.Session
|
|
|
+ defaultClient *s3.S3
|
|
|
+
|
|
|
+ clientsByRegion map[string]*s3.S3
|
|
|
+ clientsByBucket map[string]*s3.S3
|
|
|
+
|
|
|
+ mu sync.RWMutex
|
|
|
}
|
|
|
|
|
|
func New() (http.RoundTripper, error) {
|
|
|
- s3Conf := aws.NewConfig()
|
|
|
+ conf := aws.NewConfig()
|
|
|
|
|
|
trans, err := defaultTransport.New(false)
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
|
|
|
- s3Conf.HTTPClient = &http.Client{Transport: trans}
|
|
|
-
|
|
|
- if len(config.S3Region) != 0 {
|
|
|
- s3Conf.Region = aws.String(config.S3Region)
|
|
|
- }
|
|
|
+ conf.HTTPClient = &http.Client{Transport: trans}
|
|
|
|
|
|
if len(config.S3Endpoint) != 0 {
|
|
|
- s3Conf.Endpoint = aws.String(config.S3Endpoint)
|
|
|
- s3Conf.S3ForcePathStyle = aws.Bool(true)
|
|
|
+ conf.Endpoint = aws.String(config.S3Endpoint)
|
|
|
+ conf.S3ForcePathStyle = aws.Bool(true)
|
|
|
}
|
|
|
|
|
|
sess, err := session.NewSession()
|
|
@@ -47,18 +52,35 @@ func New() (http.RoundTripper, error) {
|
|
|
return nil, fmt.Errorf("Can't create S3 session: %s", err)
|
|
|
}
|
|
|
|
|
|
- if len(config.S3AssumeRoleArn) != 0 {
|
|
|
- s3Conf.Credentials = stscreds.NewCredentials(sess, config.S3AssumeRoleArn)
|
|
|
+ if len(config.S3Region) != 0 {
|
|
|
+ sess.Config.Region = aws.String(config.S3Region)
|
|
|
}
|
|
|
|
|
|
if sess.Config.Region == nil || len(*sess.Config.Region) == 0 {
|
|
|
sess.Config.Region = aws.String("us-west-1")
|
|
|
}
|
|
|
|
|
|
- return transport{s3.New(sess, s3Conf)}, nil
|
|
|
+ if len(config.S3AssumeRoleArn) != 0 {
|
|
|
+ conf.Credentials = stscreds.NewCredentials(sess, config.S3AssumeRoleArn)
|
|
|
+ }
|
|
|
+
|
|
|
+ client := s3.New(sess, conf)
|
|
|
+
|
|
|
+ clientRegion := "us-west-1"
|
|
|
+ if client.Config.Region != nil {
|
|
|
+ clientRegion = *client.Config.Region
|
|
|
+ }
|
|
|
+
|
|
|
+ return &transport{
|
|
|
+ session: sess,
|
|
|
+ defaultClient: client,
|
|
|
+
|
|
|
+ clientsByRegion: map[string]*s3.S3{clientRegion: client},
|
|
|
+ clientsByBucket: make(map[string]*s3.S3),
|
|
|
+ }, nil
|
|
|
}
|
|
|
|
|
|
-func (t transport) RoundTrip(req *http.Request) (resp *http.Response, err error) {
|
|
|
+func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
|
input := &s3.GetObjectInput{
|
|
|
Bucket: aws.String(req.URL.Host),
|
|
|
Key: aws.String(req.URL.Path),
|
|
@@ -86,7 +108,12 @@ func (t transport) RoundTrip(req *http.Request) (resp *http.Response, err error)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- s3req, _ := t.svc.GetObjectRequest(input)
|
|
|
+ client, err := t.getClient(req.Context(), *input.Bucket)
|
|
|
+ if err != nil {
|
|
|
+ return handleError(req, err)
|
|
|
+ }
|
|
|
+
|
|
|
+ s3req, _ := client.GetObjectRequest(input)
|
|
|
s3req.SetContext(req.Context())
|
|
|
|
|
|
if err := s3req.Send(); err != nil {
|
|
@@ -94,29 +121,81 @@ func (t transport) RoundTrip(req *http.Request) (resp *http.Response, err error)
|
|
|
s3req.HTTPResponse.Body.Close()
|
|
|
}
|
|
|
|
|
|
- if s3err, ok := err.(awserr.Error); ok && s3err.Code() == request.CanceledErrorCode {
|
|
|
- if e := s3err.OrigErr(); e != nil {
|
|
|
- return nil, e
|
|
|
- }
|
|
|
- }
|
|
|
+ return handleError(req, err)
|
|
|
+ }
|
|
|
+
|
|
|
+ return s3req.HTTPResponse, nil
|
|
|
+}
|
|
|
|
|
|
- if s3err, ok := err.(awserr.RequestFailure); !ok || s3err.StatusCode() < 100 || s3err.StatusCode() == 301 {
|
|
|
- return nil, err
|
|
|
- } else {
|
|
|
- body := strings.NewReader(s3err.Message())
|
|
|
- return &http.Response{
|
|
|
- StatusCode: s3err.StatusCode(),
|
|
|
- Proto: "HTTP/1.0",
|
|
|
- ProtoMajor: 1,
|
|
|
- ProtoMinor: 0,
|
|
|
- Header: http.Header{},
|
|
|
- ContentLength: int64(body.Len()),
|
|
|
- Body: io.NopCloser(body),
|
|
|
- Close: false,
|
|
|
- Request: s3req.HTTPRequest,
|
|
|
- }, nil
|
|
|
+func (t *transport) getClient(ctx context.Context, bucket string) (*s3.S3, error) {
|
|
|
+ if !config.S3MultiRegion {
|
|
|
+ return t.defaultClient, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ var client *s3.S3
|
|
|
+
|
|
|
+ func() {
|
|
|
+ t.mu.RLock()
|
|
|
+ defer t.mu.RUnlock()
|
|
|
+ client = t.clientsByBucket[bucket]
|
|
|
+ }()
|
|
|
+
|
|
|
+ if client != nil {
|
|
|
+ return client, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ t.mu.Lock()
|
|
|
+ defer t.mu.Unlock()
|
|
|
+
|
|
|
+ // Check again if someone did this before us
|
|
|
+ if client = t.clientsByBucket[bucket]; client != nil {
|
|
|
+ return client, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ region, err := s3manager.GetBucketRegionWithClient(ctx, t.defaultClient, bucket)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ if client = t.clientsByRegion[region]; client != nil {
|
|
|
+ t.clientsByBucket[bucket] = client
|
|
|
+ return client, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ conf := t.defaultClient.Config.Copy()
|
|
|
+ conf.Region = aws.String(region)
|
|
|
+
|
|
|
+ client = s3.New(t.session, conf)
|
|
|
+
|
|
|
+ t.clientsByRegion[region] = client
|
|
|
+ t.clientsByBucket[bucket] = client
|
|
|
+
|
|
|
+ return client, nil
|
|
|
+}
|
|
|
+
|
|
|
+func handleError(req *http.Request, err error) (*http.Response, error) {
|
|
|
+ if s3err, ok := err.(awserr.Error); ok && s3err.Code() == request.CanceledErrorCode {
|
|
|
+ if e := s3err.OrigErr(); e != nil {
|
|
|
+ return nil, e
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- return s3req.HTTPResponse, nil
|
|
|
+ s3err, ok := err.(awserr.RequestFailure)
|
|
|
+ if !ok || s3err.StatusCode() < 100 || s3err.StatusCode() == 301 {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ body := strings.NewReader(s3err.Message())
|
|
|
+
|
|
|
+ return &http.Response{
|
|
|
+ StatusCode: s3err.StatusCode(),
|
|
|
+ Proto: "HTTP/1.0",
|
|
|
+ ProtoMajor: 1,
|
|
|
+ ProtoMinor: 0,
|
|
|
+ Header: http.Header{},
|
|
|
+ ContentLength: int64(body.Len()),
|
|
|
+ Body: io.NopCloser(body),
|
|
|
+ Close: false,
|
|
|
+ Request: req,
|
|
|
+ }, nil
|
|
|
}
|