|
|
@@ -3,7 +3,7 @@ package s3
|
|
|
import (
|
|
|
"context"
|
|
|
"errors"
|
|
|
- "io"
|
|
|
+ "fmt"
|
|
|
"net/http"
|
|
|
"strconv"
|
|
|
"strings"
|
|
|
@@ -20,17 +20,18 @@ import (
|
|
|
"github.com/aws/aws-sdk-go-v2/service/s3"
|
|
|
"github.com/aws/aws-sdk-go-v2/service/sts"
|
|
|
|
|
|
- "github.com/imgproxy/imgproxy/v3/fetcher/transport/common"
|
|
|
"github.com/imgproxy/imgproxy/v3/httpheaders"
|
|
|
"github.com/imgproxy/imgproxy/v3/ierrors"
|
|
|
+ "github.com/imgproxy/imgproxy/v3/storage/common"
|
|
|
+ "github.com/imgproxy/imgproxy/v3/storage/response"
|
|
|
)
|
|
|
|
|
|
type s3Client interface {
|
|
|
GetObject(ctx context.Context, input *s3.GetObjectInput, opts ...func(*s3.Options)) (*s3.GetObjectOutput, error)
|
|
|
}
|
|
|
|
|
|
-// transport implements RoundTripper for the 's3' protocol.
|
|
|
-type transport struct {
|
|
|
+// Storage implements S3 Storage
|
|
|
+type Storage struct {
|
|
|
clientOptions []func(*s3.Options)
|
|
|
|
|
|
defaultClient s3Client
|
|
|
@@ -41,11 +42,11 @@ type transport struct {
|
|
|
|
|
|
mu sync.RWMutex
|
|
|
|
|
|
- config *Config
|
|
|
- querySeparator string
|
|
|
+ config *Config
|
|
|
}
|
|
|
|
|
|
-func New(config *Config, trans *http.Transport, querySeparator string) (http.RoundTripper, error) {
|
|
|
+// New creates a new S3 storage instance
|
|
|
+func New(config *Config, trans *http.Transport) (*Storage, error) {
|
|
|
if err := config.Validate(); err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
@@ -96,33 +97,32 @@ func New(config *Config, trans *http.Transport, querySeparator string) (http.Rou
|
|
|
return nil, ierrors.Wrap(err, 0, ierrors.WithPrefix("can't create S3 client"))
|
|
|
}
|
|
|
|
|
|
- return &transport{
|
|
|
+ return &Storage{
|
|
|
clientOptions: clientOptions,
|
|
|
defaultClient: client,
|
|
|
defaultConfig: conf,
|
|
|
clientsByRegion: map[string]s3Client{conf.Region: client},
|
|
|
clientsByBucket: make(map[string]s3Client),
|
|
|
config: config,
|
|
|
- querySeparator: querySeparator,
|
|
|
}, nil
|
|
|
}
|
|
|
|
|
|
-func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
|
- bucket, key, query := common.GetBucketAndKey(req.URL, t.querySeparator)
|
|
|
-
|
|
|
+// GetObject retrieves an object from Azure cloud
|
|
|
+func (s *Storage) GetObject(
|
|
|
+ ctx context.Context,
|
|
|
+ reqHeader http.Header,
|
|
|
+ bucket, key, query string,
|
|
|
+) (*response.Object, error) {
|
|
|
+ // If either bucket or object key is empty, return 404
|
|
|
if len(bucket) == 0 || len(key) == 0 {
|
|
|
- body := strings.NewReader("Invalid S3 URL: bucket name or object key is empty")
|
|
|
- return &http.Response{
|
|
|
- StatusCode: http.StatusNotFound,
|
|
|
- Proto: "HTTP/1.0",
|
|
|
- ProtoMajor: 1,
|
|
|
- ProtoMinor: 0,
|
|
|
- Header: http.Header{httpheaders.ContentType: {"text/plain"}},
|
|
|
- ContentLength: int64(body.Len()),
|
|
|
- Body: io.NopCloser(body),
|
|
|
- Close: false,
|
|
|
- Request: req,
|
|
|
- }, nil
|
|
|
+ return response.NewNotFound(
|
|
|
+ "invalid S3 Storage URL: bucket name or object key are empty",
|
|
|
+ ), nil
|
|
|
+ }
|
|
|
+
|
|
|
+ // Check if access to the container is allowed
|
|
|
+ if !common.IsBucketAllowed(bucket, s.config.AllowedBuckets, s.config.DeniedBuckets) {
|
|
|
+ return nil, fmt.Errorf("access to the S3 bucket %s is denied", bucket)
|
|
|
}
|
|
|
|
|
|
input := &s3.GetObjectInput{
|
|
|
@@ -134,16 +134,14 @@ func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
|
input.VersionId = aws.String(query)
|
|
|
}
|
|
|
|
|
|
- statusCode := http.StatusOK
|
|
|
-
|
|
|
- if r := req.Header.Get("Range"); len(r) != 0 {
|
|
|
+ if r := reqHeader.Get(httpheaders.Range); len(r) != 0 {
|
|
|
input.Range = aws.String(r)
|
|
|
} else {
|
|
|
- if ifNoneMatch := req.Header.Get("If-None-Match"); len(ifNoneMatch) > 0 {
|
|
|
+ if ifNoneMatch := reqHeader.Get(httpheaders.IfNoneMatch); len(ifNoneMatch) > 0 {
|
|
|
input.IfNoneMatch = aws.String(ifNoneMatch)
|
|
|
}
|
|
|
|
|
|
- if ifModifiedSince := req.Header.Get("If-Modified-Since"); len(ifModifiedSince) > 0 {
|
|
|
+ if ifModifiedSince := reqHeader.Get(httpheaders.IfModifiedSince); len(ifModifiedSince) > 0 {
|
|
|
parsedIfModifiedSince, err := time.Parse(http.TimeFormat, ifModifiedSince)
|
|
|
if err == nil {
|
|
|
input.IfModifiedSince = &parsedIfModifiedSince
|
|
|
@@ -151,9 +149,9 @@ func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- client := t.getBucketClient(bucket)
|
|
|
+ client := s.getBucketClient(bucket)
|
|
|
|
|
|
- output, err := client.GetObject(req.Context(), input)
|
|
|
+ output, err := client.GetObject(ctx, input)
|
|
|
|
|
|
defer func() {
|
|
|
if err != nil && output != nil && output.Body != nil {
|
|
|
@@ -165,17 +163,17 @@ func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
|
// Check if the error is the region mismatch error.
|
|
|
// If so, create a new client with the correct region and retry the request.
|
|
|
if region := regionFromError(err); len(region) != 0 {
|
|
|
- client, err = t.createBucketClient(bucket, region)
|
|
|
+ client, err = s.createBucketClient(bucket, region)
|
|
|
if err != nil {
|
|
|
- return handleError(req, err)
|
|
|
+ return handleError(err)
|
|
|
}
|
|
|
|
|
|
- output, err = client.GetObject(req.Context(), input)
|
|
|
+ output, err = client.GetObject(ctx, input)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
if err != nil {
|
|
|
- return handleError(req, err)
|
|
|
+ return handleError(err)
|
|
|
}
|
|
|
|
|
|
contentLength := int64(-1)
|
|
|
@@ -183,11 +181,11 @@ func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
|
contentLength = *output.ContentLength
|
|
|
}
|
|
|
|
|
|
- if t.config.DecryptionClientEnabled {
|
|
|
+ if s.config.DecryptionClientEnabled {
|
|
|
if unencryptedContentLength := output.Metadata["X-Amz-Meta-X-Amz-Unencrypted-Content-Length"]; len(unencryptedContentLength) != 0 {
|
|
|
cl, err := strconv.ParseInt(unencryptedContentLength, 10, 64)
|
|
|
if err != nil {
|
|
|
- handleError(req, err)
|
|
|
+ return handleError(err)
|
|
|
}
|
|
|
contentLength = cl
|
|
|
}
|
|
|
@@ -220,23 +218,13 @@ func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
|
}
|
|
|
if output.ContentRange != nil {
|
|
|
header.Set(httpheaders.ContentRange, *output.ContentRange)
|
|
|
- statusCode = http.StatusPartialContent
|
|
|
- }
|
|
|
-
|
|
|
- return &http.Response{
|
|
|
- StatusCode: statusCode,
|
|
|
- Proto: "HTTP/1.0",
|
|
|
- ProtoMajor: 1,
|
|
|
- ProtoMinor: 0,
|
|
|
- Header: header,
|
|
|
- ContentLength: contentLength,
|
|
|
- Body: output.Body,
|
|
|
- Close: true,
|
|
|
- Request: req,
|
|
|
- }, nil
|
|
|
+ return response.NewPartialContent(header, output.Body), nil
|
|
|
+ }
|
|
|
+
|
|
|
+ return response.NewOK(header, output.Body), nil
|
|
|
}
|
|
|
|
|
|
-func (t *transport) getBucketClient(bucket string) s3Client {
|
|
|
+func (t *Storage) getBucketClient(bucket string) s3Client {
|
|
|
var client s3Client
|
|
|
|
|
|
func() {
|
|
|
@@ -252,7 +240,7 @@ func (t *transport) getBucketClient(bucket string) s3Client {
|
|
|
return t.defaultClient
|
|
|
}
|
|
|
|
|
|
-func (t *transport) createBucketClient(bucket, region string) (s3Client, error) {
|
|
|
+func (t *Storage) createBucketClient(bucket, region string) (s3Client, error) {
|
|
|
t.mu.Lock()
|
|
|
defer t.mu.Unlock()
|
|
|
|
|
|
@@ -311,7 +299,7 @@ func regionFromError(err error) string {
|
|
|
return rerr.Response.Header.Get("X-Amz-Bucket-Region")
|
|
|
}
|
|
|
|
|
|
-func handleError(req *http.Request, err error) (*http.Response, error) {
|
|
|
+func handleError(err error) (*response.Object, error) {
|
|
|
var rerr *awsHttp.ResponseError
|
|
|
if !errors.As(err, &rerr) {
|
|
|
return nil, ierrors.Wrap(err, 0)
|
|
|
@@ -321,15 +309,5 @@ func handleError(req *http.Request, err error) (*http.Response, error) {
|
|
|
return nil, ierrors.Wrap(err, 0)
|
|
|
}
|
|
|
|
|
|
- return &http.Response{
|
|
|
- StatusCode: rerr.Response.StatusCode,
|
|
|
- Proto: "HTTP/1.0",
|
|
|
- ProtoMajor: 1,
|
|
|
- ProtoMinor: 0,
|
|
|
- Header: http.Header{"Content-Type": {"text/plain"}},
|
|
|
- ContentLength: int64(len(err.Error())),
|
|
|
- Body: io.NopCloser(strings.NewReader(err.Error())),
|
|
|
- Close: false,
|
|
|
- Request: req,
|
|
|
- }, nil
|
|
|
+ return response.NewError(rerr.Response.StatusCode, err.Error()), nil
|
|
|
}
|