Browse Source

Polish Azure Blob Storage support

DarthSim 2 years ago
parent
commit
894cb55cbe
1 changed files with 10 additions and 7 deletions
  1. 10 7
      transport/azure/azure.go

+ 10 - 7
transport/azure/azure.go

@@ -9,15 +9,15 @@ import (
 	"strconv"
 	"strings"
 
-	"github.com/imgproxy/imgproxy/v3/httprange"
-
 	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
 	"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
 	"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
 	"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob"
 	"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blockblob"
+
 	"github.com/imgproxy/imgproxy/v3/config"
 	"github.com/imgproxy/imgproxy/v3/ctxreader"
+	"github.com/imgproxy/imgproxy/v3/httprange"
 )
 
 type transport struct {
@@ -27,12 +27,12 @@ type transport struct {
 func New() (http.RoundTripper, error) {
 	var (
 		client                 *azblob.Client
+		sharedKeyCredential    *azblob.SharedKeyCredential
 		defaultAzureCredential *azidentity.DefaultAzureCredential
 		err                    error
-		sharedKeyCredential    *azblob.SharedKeyCredential
 	)
 
-	if config.ABSName == "" {
+	if len(config.ABSName) == 0 {
 		return nil, errors.New("IMGPROXY_ABS_NAME must be set")
 	}
 
@@ -40,12 +40,13 @@ func New() (http.RoundTripper, error) {
 	if len(endpoint) == 0 {
 		endpoint = fmt.Sprintf("https://%s.blob.core.windows.net", config.ABSName)
 	}
+
 	endpointURL, err := url.Parse(endpoint)
 	if err != nil {
 		return nil, err
 	}
 
-	if config.ABSKey != "" {
+	if len(config.ABSKey) > 0 {
 		sharedKeyCredential, err = azblob.NewSharedKeyCredential(config.ABSName, config.ABSKey)
 		if err != nil {
 			return nil, err
@@ -72,7 +73,7 @@ func (t transport) RoundTrip(req *http.Request) (*http.Response, error) {
 	container := req.URL.Host
 	key := req.URL.Path
 
-	var statusCode = http.StatusOK
+	statusCode := http.StatusOK
 
 	header := make(http.Header)
 	opts := &blob.DownloadStreamOptions{}
@@ -139,7 +140,9 @@ func (t transport) RoundTrip(req *http.Request) (*http.Response, error) {
 
 	header.Set("Accept-Ranges", "bytes")
 
+	contentLength := int64(0)
 	if result.ContentLength != nil {
+		contentLength = *result.ContentLength
 		header.Set("Content-Length", strconv.FormatInt(*result.ContentLength, 10))
 	}
 
@@ -161,7 +164,7 @@ func (t transport) RoundTrip(req *http.Request) (*http.Response, error) {
 		ProtoMajor:    1,
 		ProtoMinor:    0,
 		Header:        header,
-		ContentLength: *result.ContentLength,
+		ContentLength: contentLength,
 		Body:          ctxreader.New(req.Context(), result.Body, true),
 		Close:         true,
 		Request:       req,