|
@@ -16,7 +16,6 @@ import (
|
|
|
awsHttp "github.com/aws/aws-sdk-go-v2/aws/transport/http"
|
|
|
awsConfig "github.com/aws/aws-sdk-go-v2/config"
|
|
|
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
|
|
|
- s3Manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
|
|
|
"github.com/aws/aws-sdk-go-v2/service/kms"
|
|
|
"github.com/aws/aws-sdk-go-v2/service/s3"
|
|
|
"github.com/aws/aws-sdk-go-v2/service/sts"
|
|
@@ -149,17 +148,30 @@ func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- client, err := t.getClient(req.Context(), *input.Bucket)
|
|
|
- if err != nil {
|
|
|
- return handleError(req, err)
|
|
|
- }
|
|
|
+ client := t.getBucketClient(bucket)
|
|
|
|
|
|
output, err := client.GetObject(req.Context(), input)
|
|
|
- if err != nil {
|
|
|
- if output != nil && output.Body != nil {
|
|
|
+
|
|
|
+ defer func() {
|
|
|
+ if err != nil && output != nil && output.Body != nil {
|
|
|
output.Body.Close()
|
|
|
}
|
|
|
+ }()
|
|
|
+
|
|
|
+ if err != nil {
|
|
|
+ // Check if the error is the region mismatch error.
|
|
|
+ // If so, create a new client with the correct region and retry the request.
|
|
|
+ if region := regionFromError(err); len(region) != 0 {
|
|
|
+ client, err = t.createBucketClient(req.Context(), bucket, region)
|
|
|
+ if err != nil {
|
|
|
+ return handleError(req, err)
|
|
|
+ }
|
|
|
+
|
|
|
+ output, err = client.GetObject(req.Context(), input)
|
|
|
+ }
|
|
|
+ }
|
|
|
|
|
|
+ if err != nil {
|
|
|
return handleError(req, err)
|
|
|
}
|
|
|
|
|
@@ -221,11 +233,7 @@ func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
|
}, nil
|
|
|
}
|
|
|
|
|
|
-func (t *transport) getClient(ctx context.Context, bucket string) (s3Client, error) {
|
|
|
- if !config.S3MultiRegion {
|
|
|
- return t.defaultClient, nil
|
|
|
- }
|
|
|
-
|
|
|
+func (t *transport) getBucketClient(bucket string) s3Client {
|
|
|
var client s3Client
|
|
|
|
|
|
func() {
|
|
@@ -235,27 +243,22 @@ func (t *transport) getClient(ctx context.Context, bucket string) (s3Client, err
|
|
|
}()
|
|
|
|
|
|
if client != nil {
|
|
|
- return client, nil
|
|
|
+ return client
|
|
|
}
|
|
|
|
|
|
+ return t.defaultClient
|
|
|
+}
|
|
|
+
|
|
|
+func (t *transport) createBucketClient(ctx context.Context, bucket, region string) (s3Client, error) {
|
|
|
t.mu.Lock()
|
|
|
defer t.mu.Unlock()
|
|
|
|
|
|
// Check again if someone did this before us
|
|
|
- if client = t.clientsByBucket[bucket]; client != nil {
|
|
|
+ if client := t.clientsByBucket[bucket]; client != nil {
|
|
|
return client, nil
|
|
|
}
|
|
|
|
|
|
- region, err := s3Manager.GetBucketRegion(ctx, t.defaultClient, bucket)
|
|
|
- if err != nil {
|
|
|
- return nil, ierrors.Wrap(err, 0, ierrors.WithPrefix("can't get bucket region"))
|
|
|
- }
|
|
|
-
|
|
|
- if len(region) == 0 {
|
|
|
- region = t.defaultConfig.Region
|
|
|
- }
|
|
|
-
|
|
|
- if client = t.clientsByRegion[region]; client != nil {
|
|
|
+ if client := t.clientsByRegion[region]; client != nil {
|
|
|
t.clientsByBucket[bucket] = client
|
|
|
return client, nil
|
|
|
}
|
|
@@ -263,7 +266,7 @@ func (t *transport) getClient(ctx context.Context, bucket string) (s3Client, err
|
|
|
conf := t.defaultConfig.Copy()
|
|
|
conf.Region = region
|
|
|
|
|
|
- client, err = createClient(conf, t.clientOptions)
|
|
|
+ client, err := createClient(conf, t.clientOptions)
|
|
|
if err != nil {
|
|
|
return nil, ierrors.Wrap(err, 0, ierrors.WithPrefix("can't create regional S3 client"))
|
|
|
}
|
|
@@ -292,6 +295,19 @@ func createClient(conf aws.Config, opts []func(*s3.Options)) (s3Client, error) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+func regionFromError(err error) string {
|
|
|
+ var rerr *awsHttp.ResponseError
|
|
|
+ if !errors.As(err, &rerr) {
|
|
|
+ return ""
|
|
|
+ }
|
|
|
+
|
|
|
+ if rerr.Response == nil || rerr.Response.StatusCode != 301 {
|
|
|
+ return ""
|
|
|
+ }
|
|
|
+
|
|
|
+ return rerr.Response.Header.Get("X-Amz-Bucket-Region")
|
|
|
+}
|
|
|
+
|
|
|
func handleError(req *http.Request, err error) (*http.Response, error) {
|
|
|
var rerr *awsHttp.ResponseError
|
|
|
if !errors.As(err, &rerr) {
|