Browse Source

transport isolated, imagefetcher introduced (#1465)

Victor Sokolov 2 months ago
parent
commit
dd3b430f87

+ 1 - 0
go.mod

@@ -205,6 +205,7 @@ require (
 	go.uber.org/atomic v1.11.0 // indirect
 	go.uber.org/multierr v1.11.0 // indirect
 	go.uber.org/zap v1.27.0 // indirect
+	go.withmatt.com/httpheaders v1.0.0 // indirect
 	go.yaml.in/yaml/v2 v2.4.2 // indirect
 	go.yaml.in/yaml/v3 v3.0.4 // indirect
 	golang.org/x/crypto v0.39.0 // indirect

+ 2 - 0
go.sum

@@ -560,6 +560,8 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
 go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
 go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
 go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
+go.withmatt.com/httpheaders v1.0.0 h1:xZhtLWyIWCd8FT3CvUBRQLhQpgZaMmHNfIIT0wwNc1A=
+go.withmatt.com/httpheaders v1.0.0/go.mod h1:bKAYNgm9s2ViHIoGOnMKo4F2zJXBdvpfGuSEJQYF8pQ=
 go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
 go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
 go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=

+ 47 - 247
imagedata/download.go

@@ -1,51 +1,36 @@
 package imagedata
 
 import (
-	"compress/gzip"
 	"context"
-	"io"
 	"net/http"
-	"net/http/cookiejar"
-	"regexp"
-	"strconv"
-	"strings"
-	"time"
+	"slices"
 
 	"github.com/imgproxy/imgproxy/v3/config"
 	"github.com/imgproxy/imgproxy/v3/ierrors"
+	"github.com/imgproxy/imgproxy/v3/imagefetcher"
 	"github.com/imgproxy/imgproxy/v3/security"
-
-	defaultTransport "github.com/imgproxy/imgproxy/v3/transport"
-	azureTransport "github.com/imgproxy/imgproxy/v3/transport/azure"
-	transportCommon "github.com/imgproxy/imgproxy/v3/transport/common"
-	fsTransport "github.com/imgproxy/imgproxy/v3/transport/fs"
-	gcsTransport "github.com/imgproxy/imgproxy/v3/transport/gcs"
-	s3Transport "github.com/imgproxy/imgproxy/v3/transport/s3"
-	swiftTransport "github.com/imgproxy/imgproxy/v3/transport/swift"
+	"github.com/imgproxy/imgproxy/v3/transport"
+	"go.withmatt.com/httpheaders"
 )
 
 var (
-	downloadClient *http.Client
-
-	enabledSchemes = map[string]struct{}{
-		"http":  {},
-		"https": {},
-	}
-
-	imageHeadersToStore = []string{
-		"Cache-Control",
-		"Expires",
-		"ETag",
-		"Last-Modified",
-	}
-
-	contentRangeRe = regexp.MustCompile(`^bytes ((\d+)-(\d+)|\*)/(\d+|\*)$`)
+	Fetcher *imagefetcher.Fetcher
 
 	// For tests
 	redirectAllRequestsTo string
-)
 
-const msgSourceImageIsUnreachable = "Source image is unreachable"
+	// keepResponseHeaders is a list of HTTP headers that should be preserved in the response
+	keepResponseHeaders = []string{
+		httpheaders.CacheControl,
+		httpheaders.Expires,
+		httpheaders.LastModified,
+		// NOTE:
+		// httpheaders.Etag == "Etag".
+		// HTTP header names are case-insensitive, but in several places we still
+		// rely on their exact spelling. We should migrate to http.Header and its
+		// canonicalizing methods everywhere.
+		httpheaders.Etag,
+	}
+)
 
 type DownloadOptions struct {
 	Header    http.Header
@@ -53,224 +38,40 @@ type DownloadOptions struct {
 }
 
 func initDownloading() error {
-	transport, err := defaultTransport.New(true)
+	ts, err := transport.NewTransport()
 	if err != nil {
 		return err
 	}
 
-	registerProtocol := func(scheme string, rt http.RoundTripper) {
-		transport.RegisterProtocol(scheme, rt)
-		enabledSchemes[scheme] = struct{}{}
-	}
-
-	if config.LocalFileSystemRoot != "" {
-		registerProtocol("local", fsTransport.New())
-	}
-
-	if config.S3Enabled {
-		if t, err := s3Transport.New(); err != nil {
-			return err
-		} else {
-			registerProtocol("s3", t)
-		}
-	}
-
-	if config.GCSEnabled {
-		if t, err := gcsTransport.New(); err != nil {
-			return err
-		} else {
-			registerProtocol("gs", t)
-		}
-	}
-
-	if config.ABSEnabled {
-		if t, err := azureTransport.New(); err != nil {
-			return err
-		} else {
-			registerProtocol("abs", t)
-		}
-	}
-
-	if config.SwiftEnabled {
-		if t, err := swiftTransport.New(); err != nil {
-			return err
-		} else {
-			registerProtocol("swift", t)
-		}
-	}
-
-	downloadClient = &http.Client{
-		Transport: transport,
-		CheckRedirect: func(req *http.Request, via []*http.Request) error {
-			redirects := len(via)
-			if redirects >= config.MaxRedirects {
-				return newImageTooManyRedirectsError(redirects)
-			}
-			return nil
-		},
-	}
-
-	return nil
-}
-
-func headersToStore(res *http.Response) map[string]string {
-	m := make(map[string]string)
-
-	for _, h := range imageHeadersToStore {
-		if val := res.Header.Get(h); len(val) != 0 {
-			m[h] = val
-		}
-	}
-
-	return m
-}
-
-func BuildImageRequest(ctx context.Context, imageURL string, header http.Header, jar http.CookieJar) (*http.Request, context.CancelFunc, error) {
-	reqCtx, reqCancel := context.WithTimeout(ctx, time.Duration(config.DownloadTimeout)*time.Second)
-
-	imageURL = transportCommon.EscapeURL(imageURL)
-
-	req, err := http.NewRequestWithContext(reqCtx, "GET", imageURL, nil)
+	Fetcher, err = imagefetcher.NewFetcher(ts, config.MaxRedirects)
 	if err != nil {
-		reqCancel()
-		return nil, func() {}, newImageRequestError(err)
-	}
-
-	if _, ok := enabledSchemes[req.URL.Scheme]; !ok {
-		reqCancel()
-		return nil, func() {}, newImageRequstSchemeError(req.URL.Scheme)
-	}
-
-	if jar != nil {
-		for _, cookie := range jar.Cookies(req.URL) {
-			req.AddCookie(cookie)
-		}
+		return ierrors.Wrap(err, 0, ierrors.WithPrefix("can't create image fetcher"))
 	}
 
-	req.Header.Set("User-Agent", config.UserAgent)
-
-	for k, v := range header {
-		if len(v) > 0 {
-			req.Header.Set(k, v[0])
-		}
-	}
-
-	return req, reqCancel, nil
+	return nil
 }
 
-func SendRequest(req *http.Request) (*http.Response, error) {
-	var client *http.Client
-	if req.URL.Scheme == "http" || req.URL.Scheme == "https" {
-		clientCopy := *downloadClient
-
-		jar, err := cookiejar.New(nil)
-		if err != nil {
-			return nil, err
-		}
-		clientCopy.Jar = jar
-		client = &clientCopy
-	} else {
-		client = downloadClient
-	}
-
-	for {
-		res, err := client.Do(req)
-		if err == nil {
-			return res, nil
-		}
-
-		if res != nil && res.Body != nil {
-			res.Body.Close()
-		}
-
-		if strings.Contains(err.Error(), "client connection lost") {
-			select {
-			case <-req.Context().Done():
-				return nil, err
-			case <-time.After(100 * time.Microsecond):
-				continue
-			}
-		}
-
-		return nil, wrapError(err)
+func download(ctx context.Context, imageURL string, opts DownloadOptions, secopts security.Options) (*ImageData, error) {
+	// We use this for testing
+	if len(redirectAllRequestsTo) > 0 {
+		imageURL = redirectAllRequestsTo
 	}
-}
 
-func requestImage(ctx context.Context, imageURL string, opts DownloadOptions) (*http.Response, context.CancelFunc, error) {
-	req, reqCancel, err := BuildImageRequest(ctx, imageURL, opts.Header, opts.CookieJar)
+	req, err := Fetcher.BuildRequest(ctx, imageURL, opts.Header, opts.CookieJar)
 	if err != nil {
-		reqCancel()
-		return nil, func() {}, err
+		return nil, err
 	}
+	defer req.Cancel()
 
-	res, err := SendRequest(req)
+	res, err := req.FetchImage()
 	if err != nil {
-		reqCancel()
-		return nil, func() {}, err
-	}
-
-	if res.StatusCode == http.StatusNotModified {
-		res.Body.Close()
-		reqCancel()
-		return nil, func() {}, newNotModifiedError(headersToStore(res))
-	}
-
-	// If the source responds with 206, check if the response contains entire image.
-	// If not, return an error.
-	if res.StatusCode == http.StatusPartialContent {
-		contentRange := res.Header.Get("Content-Range")
-		rangeParts := contentRangeRe.FindStringSubmatch(contentRange)
-		if len(rangeParts) == 0 {
-			res.Body.Close()
-			reqCancel()
-			return nil, func() {}, newImagePartialResponseError("Partial response with invalid Content-Range header")
-		}
-
-		if rangeParts[1] == "*" || rangeParts[2] != "0" {
-			res.Body.Close()
-			reqCancel()
-			return nil, func() {}, newImagePartialResponseError("Partial response with incomplete content")
-		}
-
-		contentLengthStr := rangeParts[4]
-		if contentLengthStr == "*" {
-			contentLengthStr = res.Header.Get("Content-Length")
-		}
-
-		contentLength, _ := strconv.Atoi(contentLengthStr)
-		rangeEnd, _ := strconv.Atoi(rangeParts[3])
-
-		if contentLength <= 0 || rangeEnd != contentLength-1 {
+		if res != nil {
 			res.Body.Close()
-			reqCancel()
-			return nil, func() {}, newImagePartialResponseError("Partial response with incomplete content")
 		}
-	} else if res.StatusCode != http.StatusOK {
-		var body string
-
-		if strings.HasPrefix(res.Header.Get("Content-Type"), "text/") {
-			bbody, _ := io.ReadAll(io.LimitReader(res.Body, 1024))
-			body = string(bbody)
-		}
-
-		res.Body.Close()
-		reqCancel()
-
-		return nil, func() {}, newImageResponseStatusError(res.StatusCode, body)
-	}
-
-	return res, reqCancel, nil
-}
-
-func download(ctx context.Context, imageURL string, opts DownloadOptions, secopts security.Options) (*ImageData, error) {
-	// We use this for testing
-	if len(redirectAllRequestsTo) > 0 {
-		imageURL = redirectAllRequestsTo
+		return nil, err
 	}
 
-	res, reqCancel, err := requestImage(ctx, imageURL, opts)
-	defer reqCancel()
-
+	res, err = security.LimitResponseSize(res, secopts)
 	if res != nil {
 		defer res.Body.Close()
 	}
@@ -278,27 +79,26 @@ func download(ctx context.Context, imageURL string, opts DownloadOptions, secopt
 		return nil, err
 	}
 
-	body := res.Body
-	contentLength := int(res.ContentLength)
+	imgdata, err := readAndCheckImage(res.Body, int(res.ContentLength), secopts)
+	if err != nil {
+		return nil, ierrors.Wrap(err, 0)
+	}
 
-	if res.Header.Get("Content-Encoding") == "gzip" {
-		gzipBody, errGzip := gzip.NewReader(res.Body)
-		if gzipBody != nil {
-			defer gzipBody.Close()
+	h := make(map[string]string)
+	for k := range res.Header {
+		if !slices.Contains(keepResponseHeaders, k) {
+			continue
 		}
-		if errGzip != nil {
-			return nil, err
-		}
-		body = gzipBody
-		contentLength = 0
-	}
 
-	imgdata, err := readAndCheckImage(body, contentLength, secopts)
-	if err != nil {
-		return nil, ierrors.Wrap(err, 0)
+		// TODO: Fix Etag/ETag inconsistency
+		if k == "Etag" {
+			h["ETag"] = res.Header.Get(k)
+		} else {
+			h[k] = res.Header.Get(k)
+		}
 	}
 
-	imgdata.Headers = headersToStore(res)
+	imgdata.Headers = h
 
 	return imgdata, nil
 }

+ 4 - 9
imagedata/read.go

@@ -8,6 +8,7 @@ import (
 	"github.com/imgproxy/imgproxy/v3/bufpool"
 	"github.com/imgproxy/imgproxy/v3/bufreader"
 	"github.com/imgproxy/imgproxy/v3/config"
+	"github.com/imgproxy/imgproxy/v3/imagefetcher"
 	"github.com/imgproxy/imgproxy/v3/imagemeta"
 	"github.com/imgproxy/imgproxy/v3/security"
 )
@@ -19,15 +20,9 @@ func initRead() {
 }
 
 func readAndCheckImage(r io.Reader, contentLength int, secopts security.Options) (*ImageData, error) {
-	if err := security.CheckFileSize(contentLength, secopts); err != nil {
-		return nil, err
-	}
-
 	buf := downloadBufPool.Get(contentLength, false)
 	cancel := func() { downloadBufPool.Put(buf) }
 
-	r = security.LimitFileSize(r, secopts)
-
 	br := bufreader.New(r, buf)
 
 	meta, err := imagemeta.DecodeMeta(br)
@@ -35,14 +30,14 @@ func readAndCheckImage(r io.Reader, contentLength int, secopts security.Options)
 		buf.Reset()
 		cancel()
 
-		return nil, wrapError(err)
+		return nil, imagefetcher.WrapError(err)
 	}
 
 	if err = security.CheckDimensions(meta.Width(), meta.Height(), 1, secopts); err != nil {
 		buf.Reset()
 		cancel()
 
-		return nil, wrapError(err)
+		return nil, imagefetcher.WrapError(err)
 	}
 
 	downloadBufPool.GrowBuffer(buf, contentLength)
@@ -51,7 +46,7 @@ func readAndCheckImage(r io.Reader, contentLength int, secopts security.Options)
 		buf.Reset()
 		cancel()
 
-		return nil, wrapError(err)
+		return nil, imagefetcher.WrapError(err)
 	}
 
 	return &ImageData{

+ 8 - 5
imagedata/errors.go → imagefetcher/errors.go

@@ -1,4 +1,4 @@
-package imagedata
+package imagefetcher
 
 import (
 	"context"
@@ -10,6 +10,8 @@ import (
 	"github.com/imgproxy/imgproxy/v3/security"
 )
 
+const msgSourceImageIsUnreachable = "Source image is unreachable"
+
 type (
 	ImageRequestError          struct{ error }
 	ImageRequstSchemeError     string
@@ -20,7 +22,7 @@ type (
 	ImageRequestTimeoutError   struct{ error }
 
 	NotModifiedError struct {
-		headers map[string]string
+		headers http.Header
 	}
 
 	httpError interface {
@@ -135,7 +137,7 @@ func (e ImageRequestTimeoutError) Error() string {
 
 func (e ImageRequestTimeoutError) Unwrap() error { return e.error }
 
-func newNotModifiedError(headers map[string]string) error {
+func newNotModifiedError(headers http.Header) error {
 	return ierrors.Wrap(
 		NotModifiedError{headers},
 		1,
@@ -147,11 +149,12 @@ func newNotModifiedError(headers map[string]string) error {
 
 func (e NotModifiedError) Error() string { return "Not modified" }
 
-func (e NotModifiedError) Headers() map[string]string {
+func (e NotModifiedError) Headers() http.Header {
 	return e.headers
 }
 
-func wrapError(err error) error {
+// NOTE: make private when we remove download functions from imagedata package
+func WrapError(err error) error {
 	isTimeout := false
 
 	var secArrdErr security.SourceAddressError

+ 86 - 0
imagefetcher/fetcher.go

@@ -0,0 +1,86 @@
+// imagefetcher is responsible for downloading images using HTTP requests through various protocols
+// defined in transport package
+package imagefetcher
+
+import (
+	"context"
+	"net/http"
+	"time"
+
+	"github.com/imgproxy/imgproxy/v3/config"
+	"github.com/imgproxy/imgproxy/v3/transport"
+	"github.com/imgproxy/imgproxy/v3/transport/common"
+	"go.withmatt.com/httpheaders"
+)
+
+const (
+	connectionLostError = "client connection lost" // Error message indicating a lost connection
+	bounceDelay         = 100 * time.Microsecond   // Delay before retrying a request
+)
+
+// Fetcher is a struct that holds the HTTP client and transport for fetching images
+type Fetcher struct {
+	transport    *transport.Transport // Transport used for making HTTP requests
+	maxRedirects int                  // Maximum number of redirects allowed
+}
+
+// NewFetcher creates a new ImageFetcher with the provided transport
+func NewFetcher(transport *transport.Transport, maxRedirects int) (*Fetcher, error) {
+	return &Fetcher{transport, maxRedirects}, nil
+}
+
+// checkRedirect is a method that checks if the number of redirects exceeds the maximum allowed
+func (f *Fetcher) checkRedirect(req *http.Request, via []*http.Request) error {
+	redirects := len(via)
+	if redirects >= f.maxRedirects {
+		return newImageTooManyRedirectsError(redirects)
+	}
+	return nil
+}
+
+// newHttpClient returns new HTTP client
+func (f *Fetcher) newHttpClient() *http.Client {
+	return &http.Client{
+		Transport:     f.transport.Transport(), // Connection pool is there
+		CheckRedirect: f.checkRedirect,
+	}
+}
+
+// BuildRequest creates a new Request with the provided context, URL, headers, and cookie jar
+func (f *Fetcher) BuildRequest(ctx context.Context, url string, header http.Header, jar http.CookieJar) (*Request, error) {
+	url = common.EscapeURL(url)
+
+	// Set request timeout and get cancel function
+	ctx, cancel := context.WithTimeout(ctx, time.Duration(config.DownloadTimeout)*time.Second)
+
+	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
+	if err != nil {
+		cancel()
+		return nil, newImageRequestError(err)
+	}
+
+	// Check if the URL scheme is supported
+	if !f.transport.IsProtocolRegistered(req.URL.Scheme) {
+		cancel()
+		return nil, newImageRequstSchemeError(req.URL.Scheme)
+	}
+
+	// Add cookies from the jar to the request (if any)
+	if jar != nil {
+		for _, cookie := range jar.Cookies(req.URL) {
+			req.AddCookie(cookie)
+		}
+	}
+
+	// Set user agent header
+	req.Header.Set(httpheaders.UserAgent, config.UserAgent)
+
+	// Set headers
+	for k, v := range header {
+		if len(v) > 0 {
+			req.Header.Set(k, v[0])
+		}
+	}
+
+	return &Request{f, req, cancel}, nil
+}

+ 204 - 0
imagefetcher/request.go

@@ -0,0 +1,204 @@
+package imagefetcher
+
+import (
+	"compress/gzip"
+	"context"
+	"io"
+	"net/http"
+	"net/http/cookiejar"
+	"net/url"
+	"regexp"
+	"strconv"
+	"strings"
+	"time"
+
+	"go.withmatt.com/httpheaders"
+)
+
+var (
+	// contentRangeRe Content-Range header regex to check if the response is a partial content response
+	contentRangeRe = regexp.MustCompile(`^bytes ((\d+)-(\d+)|\*)/(\d+|\*)$`)
+)
+
+// Request is a struct that holds the request and cancel function for an image fetcher request
+type Request struct {
+	fetcher *Fetcher           // Parent ImageFetcher instance
+	request *http.Request      // HTTP request to fetch the image
+	cancel  context.CancelFunc // Request context cancel function
+}
+
+// Send sends the generic request and returns the http.Response or an error
+func (r *Request) Send() (*http.Response, error) {
+	client := r.fetcher.newHttpClient()
+
+	// Let's add a cookie jar to the client if the request URL is HTTP or HTTPS
+	// This is necessary to pass cookie challenge for some servers.
+	if r.request.URL.Scheme == "http" || r.request.URL.Scheme == "https" {
+		jar, err := cookiejar.New(nil)
+		if err != nil {
+			return nil, err
+		}
+		client.Jar = jar
+	}
+
+	for {
+		// Try request
+		res, err := client.Do(r.request)
+		if err == nil {
+			return res, nil // Return successful response
+		}
+
+		// Close the response body if request was unsuccessful
+		if res != nil && res.Body != nil {
+			res.Body.Close()
+		}
+
+		// Retry if the error is due to a lost connection
+		if strings.Contains(err.Error(), connectionLostError) {
+			select {
+			case <-r.request.Context().Done():
+				return nil, err
+			case <-time.After(bounceDelay):
+				continue
+			}
+		}
+
+		return nil, WrapError(err)
+	}
+}
+
+// FetchImage fetches the image using the request and returns the response or an error.
+// It checks for the NotModified status and handles partial content responses.
+func (r *Request) FetchImage() (*http.Response, error) {
+	res, err := r.Send()
+	if err != nil {
+		r.cancel()
+		return nil, err
+	}
+
+	// Closes the response body and cancels request context
+	cancel := func() {
+		res.Body.Close()
+		r.cancel()
+	}
+
+	// If the source image was not modified, close the body and NotModifiedError
+	if res.StatusCode == http.StatusNotModified {
+		cancel()
+		return nil, newNotModifiedError(res.Header)
+	}
+
+	// If the source responds with 206, check if the response contains an entire image.
+	// If not, return an error.
+	if res.StatusCode == http.StatusPartialContent {
+		err = checkPartialContentResponse(res)
+		if err != nil {
+			cancel()
+			return nil, err
+		}
+	} else if res.StatusCode != http.StatusOK {
+		body := extractErraticBody(res)
+		cancel()
+		return nil, newImageResponseStatusError(res.StatusCode, body)
+	}
+
+	// If the response is gzip encoded, wrap it in a gzip reader
+	err = wrapGzipBody(res)
+	if err != nil {
+		cancel()
+		return nil, err
+	}
+
+	// Wrap the response body in a bodyReader to ensure the request context
+	// is cancelled when the body is closed
+	res.Body = &bodyReader{
+		body:    res.Body,
+		request: r,
+	}
+
+	return res, nil
+}
+
+// Cancel cancels the request context
+func (r *Request) Cancel() {
+	r.cancel()
+}
+
+// URL returns the actual URL of the request
+func (r *Request) URL() *url.URL {
+	return r.request.URL
+}
+
+// checkPartialContentResponse if the response is a partial content response,
+// we check if it contains the entire image.
+func checkPartialContentResponse(res *http.Response) error {
+	contentRange := res.Header.Get(httpheaders.ContentRange)
+	rangeParts := contentRangeRe.FindStringSubmatch(contentRange)
+
+	if len(rangeParts) == 0 {
+		return newImagePartialResponseError("Partial response with invalid Content-Range header")
+	}
+
+	if rangeParts[1] == "*" || rangeParts[2] != "0" {
+		return newImagePartialResponseError("Partial response with incomplete content")
+	}
+
+	contentLengthStr := rangeParts[4]
+	if contentLengthStr == "*" {
+		contentLengthStr = res.Header.Get(httpheaders.ContentLength)
+	}
+
+	contentLength, _ := strconv.Atoi(contentLengthStr)
+	rangeEnd, _ := strconv.Atoi(rangeParts[3])
+
+	if contentLength <= 0 || rangeEnd != contentLength-1 {
+		return newImagePartialResponseError("Partial response with incomplete content")
+	}
+
+	return nil
+}
+
+// extractErraticBody extracts the error body from the response if it is a text-based content type
+func extractErraticBody(res *http.Response) string {
+	if strings.HasPrefix(res.Header.Get(httpheaders.ContentType), "text/") {
+		bbody, _ := io.ReadAll(io.LimitReader(res.Body, 1024))
+		return string(bbody)
+	}
+
+	return ""
+}
+
+// wrapGzipBody wraps the response body in a gzip reader if the Content-Encoding is gzip.
+// We set DisableCompression: true to avoid sending the Accept-Encoding: gzip header,
+// since we do not want to compress image data (which is usually already compressed).
+// However, some servers still send gzip-encoded responses regardless.
+func wrapGzipBody(res *http.Response) error {
+	if res.Header.Get(httpheaders.ContentEncoding) == "gzip" {
+		gzipBody, err := gzip.NewReader(res.Body)
+		if err != nil {
+			return err
+		}
+		res.Body = gzipBody
+		res.Header.Del(httpheaders.ContentEncoding)
+	}
+
+	return nil
+}
+
+// bodyReader is a wrapper around io.ReadCloser which closes original request context
+// when the body is closed.
+type bodyReader struct {
+	body    io.ReadCloser // The body to read from
+	request *Request
+}
+
+// Read reads data from the response body into the provided byte slice
+func (r *bodyReader) Read(p []byte) (int, error) {
+	return r.body.Read(p)
+}
+
+// Close closes the response body and cancels the request context
+func (r *bodyReader) Close() error {
+	defer r.request.cancel()
+	return r.body.Close()
+}

+ 9 - 2
processing_handler.go

@@ -20,6 +20,7 @@ import (
 	"github.com/imgproxy/imgproxy/v3/etag"
 	"github.com/imgproxy/imgproxy/v3/ierrors"
 	"github.com/imgproxy/imgproxy/v3/imagedata"
+	"github.com/imgproxy/imgproxy/v3/imagefetcher"
 	"github.com/imgproxy/imgproxy/v3/imagetype"
 	"github.com/imgproxy/imgproxy/v3/imath"
 	"github.com/imgproxy/imgproxy/v3/metrics"
@@ -348,7 +349,7 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) {
 		return imagedata.Download(ctx, imageURL, "source image", downloadOpts, po.SecurityOptions)
 	}()
 
-	var nmErr imagedata.NotModifiedError
+	var nmErr imagefetcher.NotModifiedError
 
 	switch {
 	case err == nil:
@@ -358,7 +359,13 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) {
 		if config.ETagEnabled && len(etagHandler.ImageEtagExpected()) != 0 {
 			rw.Header().Set("ETag", etagHandler.GenerateExpectedETag())
 		}
-		respondWithNotModified(reqID, r, rw, po, imageURL, nmErr.Headers())
+
+		h := make(map[string]string)
+		for k := range nmErr.Headers() {
+			h[k] = nmErr.Headers().Get(k)
+		}
+
+		respondWithNotModified(reqID, r, rw, po, imageURL, h)
 		return
 
 	default:

+ 0 - 38
security/file_size.go

@@ -1,38 +0,0 @@
-package security
-
-import (
-	"io"
-)
-
-type hardLimitReader struct {
-	r    io.Reader
-	left int
-}
-
-func (lr *hardLimitReader) Read(p []byte) (n int, err error) {
-	if lr.left <= 0 {
-		return 0, newFileSizeError()
-	}
-	if len(p) > lr.left {
-		p = p[0:lr.left]
-	}
-	n, err = lr.r.Read(p)
-	lr.left -= n
-	return
-}
-
-func CheckFileSize(size int, opts Options) error {
-	if opts.MaxSrcFileSize > 0 && size > opts.MaxSrcFileSize {
-		return newFileSizeError()
-	}
-
-	return nil
-}
-
-func LimitFileSize(r io.Reader, opts Options) io.Reader {
-	if opts.MaxSrcFileSize > 0 {
-		return &hardLimitReader{r: r, left: opts.MaxSrcFileSize}
-	}
-
-	return r
-}

+ 51 - 0
security/response_limit.go

@@ -0,0 +1,51 @@
+package security
+
+import (
+	"io"
+	"net/http"
+)
+
+// hardLimitReadCloser is a wrapper around io.ReadCloser
+// that limits the number of bytes it can read from the upstream reader.
+type hardLimitReadCloser struct {
+	r    io.ReadCloser
+	left int
+}
+
+func (lr *hardLimitReadCloser) Read(p []byte) (n int, err error) {
+	if lr.left <= 0 {
+		return 0, newFileSizeError()
+	}
+	if len(p) > lr.left {
+		p = p[0:lr.left]
+	}
+	n, err = lr.r.Read(p)
+	lr.left -= n
+	return
+}
+
+func (lr *hardLimitReadCloser) Close() error {
+	return lr.r.Close()
+}
+
+// LimitResponseSize limits the size of the response body to MaxSrcFileSize (if set).
+// First, it tries to use Content-Length header to check the limit.
+// If Content-Length is not set, it limits the size of the response body by wrapping
+// body reader with hard limit reader.
+func LimitResponseSize(r *http.Response, opts Options) (*http.Response, error) {
+	if opts.MaxSrcFileSize == 0 {
+		return r, nil
+	}
+
+	// If Content-Length was set, limit the size of the response body before reading it
+	size := int(r.ContentLength)
+
+	if size > opts.MaxSrcFileSize {
+		// Close the body: the caller only closes the response we return,
+		// and we return nil here, so the connection would otherwise leak.
+		r.Body.Close()
+		return nil, newFileSizeError()
+	}
+
+	// hard-limit the response body reader
+	r.Body = &hardLimitReadCloser{r: r.Body, left: opts.MaxSrcFileSize}
+
+	return r, nil
+}

+ 4 - 4
stream.go

@@ -69,11 +69,11 @@ func streamOriginImage(ctx context.Context, reqID string, r *http.Request, rw ht
 		checkErr(ctx, "streaming", err)
 	}
 
-	req, reqCancel, err := imagedata.BuildImageRequest(r.Context(), imageURL, imgRequestHeader, cookieJar)
-	defer reqCancel()
+	req, err := imagedata.Fetcher.BuildRequest(r.Context(), imageURL, imgRequestHeader, cookieJar)
+	// req is nil when BuildRequest fails; guard the defer so it can't
+	// nil-pointer panic after checkErr aborts the handler below.
+	if req != nil {
+		defer req.Cancel()
+	}
 	checkErr(ctx, "streaming", err)
 
-	res, err := imagedata.SendRequest(req)
+	res, err := req.Send()
 	if res != nil {
 		defer res.Body.Close()
 	}
@@ -93,7 +93,7 @@ func streamOriginImage(ctx context.Context, reqID string, r *http.Request, rw ht
 	if res.StatusCode < 300 {
 		var filename, ext, mimetype string
 
-		_, filename = filepath.Split(req.URL.Path)
+		_, filename = filepath.Split(req.URL().Path)
 		ext = filepath.Ext(filename)
 
 		if len(po.Filename) > 0 {

+ 2 - 2
transport/azure/azure.go

@@ -18,8 +18,8 @@ import (
 
 	"github.com/imgproxy/imgproxy/v3/config"
 	"github.com/imgproxy/imgproxy/v3/httprange"
-	defaultTransport "github.com/imgproxy/imgproxy/v3/transport"
 	"github.com/imgproxy/imgproxy/v3/transport/common"
+	"github.com/imgproxy/imgproxy/v3/transport/generichttp"
 	"github.com/imgproxy/imgproxy/v3/transport/notmodified"
 )
 
@@ -49,7 +49,7 @@ func New() (http.RoundTripper, error) {
 		return nil, err
 	}
 
-	trans, err := defaultTransport.New(false)
+	trans, err := generichttp.New(false)
 	if err != nil {
 		return nil, err
 	}

+ 2 - 2
transport/gcs/gcs.go

@@ -17,8 +17,8 @@ import (
 	"github.com/imgproxy/imgproxy/v3/config"
 	"github.com/imgproxy/imgproxy/v3/httprange"
 	"github.com/imgproxy/imgproxy/v3/ierrors"
-	defaultTransport "github.com/imgproxy/imgproxy/v3/transport"
 	"github.com/imgproxy/imgproxy/v3/transport/common"
+	"github.com/imgproxy/imgproxy/v3/transport/generichttp"
 	"github.com/imgproxy/imgproxy/v3/transport/notmodified"
 )
 
@@ -30,7 +30,7 @@ type transport struct {
 }
 
 func buildHTTPClient(opts ...option.ClientOption) (*http.Client, error) {
-	trans, err := defaultTransport.New(false)
+	trans, err := generichttp.New(false)
 	if err != nil {
 		return nil, err
 	}

+ 59 - 0
transport/generichttp/generic_http.go

@@ -0,0 +1,59 @@
+// Generic HTTP transport for imgproxy
+package generichttp
+
+import (
+	"crypto/tls"
+	"net"
+	"net/http"
+	"syscall"
+	"time"
+
+	"github.com/imgproxy/imgproxy/v3/config"
+	"github.com/imgproxy/imgproxy/v3/security"
+	"golang.org/x/net/http2"
+)
+
+func New(verifyNetworks bool) (*http.Transport, error) {
+	dialer := &net.Dialer{
+		Timeout:   30 * time.Second,
+		KeepAlive: 30 * time.Second,
+		DualStack: true,
+	}
+
+	if verifyNetworks {
+		dialer.Control = func(network, address string, c syscall.RawConn) error {
+			return security.VerifySourceNetwork(address)
+		}
+	}
+
+	transport := &http.Transport{
+		Proxy:                 http.ProxyFromEnvironment,
+		DialContext:           dialer.DialContext,
+		MaxIdleConns:          100,
+		MaxIdleConnsPerHost:   config.Workers + 1,
+		IdleConnTimeout:       time.Duration(config.ClientKeepAliveTimeout) * time.Second,
+		TLSHandshakeTimeout:   10 * time.Second,
+		ExpectContinueTimeout: 1 * time.Second,
+		ForceAttemptHTTP2:     false,
+		DisableCompression:    true,
+	}
+
+	if config.ClientKeepAliveTimeout <= 0 {
+		transport.MaxIdleConnsPerHost = -1
+		transport.DisableKeepAlives = true
+	}
+
+	if config.IgnoreSslVerification {
+		transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
+	}
+
+	transport2, err := http2.ConfigureTransports(transport)
+	if err != nil {
+		return nil, err
+	}
+
+	transport2.PingTimeout = 5 * time.Second
+	transport2.ReadIdleTimeout = time.Second
+
+	return transport, nil
+}

+ 2 - 2
transport/s3/s3.go

@@ -22,8 +22,8 @@ import (
 
 	"github.com/imgproxy/imgproxy/v3/config"
 	"github.com/imgproxy/imgproxy/v3/ierrors"
-	defaultTransport "github.com/imgproxy/imgproxy/v3/transport"
 	"github.com/imgproxy/imgproxy/v3/transport/common"
+	"github.com/imgproxy/imgproxy/v3/transport/generichttp"
 )
 
 type s3Client interface {
@@ -49,7 +49,7 @@ func New() (http.RoundTripper, error) {
 		return nil, ierrors.Wrap(err, 0, ierrors.WithPrefix("can't load AWS S3 config"))
 	}
 
-	trans, err := defaultTransport.New(false)
+	trans, err := generichttp.New(false)
 	if err != nil {
 		return nil, err
 	}

+ 2 - 2
transport/swift/swift.go

@@ -12,8 +12,8 @@ import (
 
 	"github.com/imgproxy/imgproxy/v3/config"
 	"github.com/imgproxy/imgproxy/v3/ierrors"
-	defaultTransport "github.com/imgproxy/imgproxy/v3/transport"
 	"github.com/imgproxy/imgproxy/v3/transport/common"
+	"github.com/imgproxy/imgproxy/v3/transport/generichttp"
 	"github.com/imgproxy/imgproxy/v3/transport/notmodified"
 )
 
@@ -22,7 +22,7 @@ type transport struct {
 }
 
 func New() (http.RoundTripper, error) {
-	trans, err := defaultTransport.New(false)
+	trans, err := generichttp.New(false)
 	if err != nil {
 		return nil, err
 	}

+ 84 - 37
transport/transport.go

@@ -1,59 +1,106 @@
+// Package transport provides a custom HTTP transport that supports multiple protocols
+// such as S3, GCS, ABS, Swift, and local file system.
 package transport
 
 import (
-	"crypto/tls"
-	"net"
 	"net/http"
-	"syscall"
-	"time"
-
-	"golang.org/x/net/http2"
 
 	"github.com/imgproxy/imgproxy/v3/config"
-	"github.com/imgproxy/imgproxy/v3/security"
+	"github.com/imgproxy/imgproxy/v3/transport/generichttp"
+
+	azureTransport "github.com/imgproxy/imgproxy/v3/transport/azure"
+	fsTransport "github.com/imgproxy/imgproxy/v3/transport/fs"
+	gcsTransport "github.com/imgproxy/imgproxy/v3/transport/gcs"
+	s3Transport "github.com/imgproxy/imgproxy/v3/transport/s3"
+	swiftTransport "github.com/imgproxy/imgproxy/v3/transport/swift"
 )
 
-func New(verifyNetworks bool) (*http.Transport, error) {
-	dialer := &net.Dialer{
-		Timeout:   30 * time.Second,
-		KeepAlive: 30 * time.Second,
-		DualStack: true,
+// Transport is a wrapper around http.Transport which allows to track registered protocols
+type Transport struct {
+	transport *http.Transport
+	schemes   map[string]struct{}
+}
+
+// NewTransport creates a new HTTP transport with no protocols registered
+func NewTransport() (*Transport, error) {
+	transport, err := generichttp.New(true)
+	if err != nil {
+		return nil, err
 	}
 
-	if verifyNetworks {
-		dialer.Control = func(network, address string, c syscall.RawConn) error {
-			return security.VerifySourceNetwork(address)
-		}
+	// http and https are always registered
+	schemes := map[string]struct{}{
+		"http":  {},
+		"https": {},
 	}
 
-	transport := &http.Transport{
-		Proxy:                 http.ProxyFromEnvironment,
-		DialContext:           dialer.DialContext,
-		MaxIdleConns:          100,
-		MaxIdleConnsPerHost:   config.Workers + 1,
-		IdleConnTimeout:       time.Duration(config.ClientKeepAliveTimeout) * time.Second,
-		TLSHandshakeTimeout:   10 * time.Second,
-		ExpectContinueTimeout: 1 * time.Second,
-		ForceAttemptHTTP2:     false,
-		DisableCompression:    true,
+	t := &Transport{
+		transport,
+		schemes,
+	}
+
+	err = t.registerAllProtocols()
+	if err != nil {
+		return nil, err
+	}
+
+	return t, nil
+}
+
+// Transport returns the underlying http.Transport
+func (t *Transport) Transport() *http.Transport {
+	return t.transport
+}
+
+// RegisterProtocol registers a new transport protocol with the transport
+func (t *Transport) RegisterProtocol(scheme string, rt http.RoundTripper) {
+	t.transport.RegisterProtocol(scheme, rt)
+	t.schemes[scheme] = struct{}{}
+}
+
+// IsProtocolRegistered checks if a protocol is registered in the transport
+func (t *Transport) IsProtocolRegistered(scheme string) bool {
+	_, ok := t.schemes[scheme]
+	return ok
+}
+
+// RegisterAllProtocols registers all enabled protocols in the given transport
+func (t *Transport) registerAllProtocols() error {
+	if config.LocalFileSystemRoot != "" {
+		t.RegisterProtocol("local", fsTransport.New())
 	}
 
-	if config.ClientKeepAliveTimeout <= 0 {
-		transport.MaxIdleConnsPerHost = -1
-		transport.DisableKeepAlives = true
+	if config.S3Enabled {
+		if tr, err := s3Transport.New(); err != nil {
+			return err
+		} else {
+			t.RegisterProtocol("s3", tr)
+		}
 	}
 
-	if config.IgnoreSslVerification {
-		transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
+	if config.GCSEnabled {
+		if tr, err := gcsTransport.New(); err != nil {
+			return err
+		} else {
+			t.RegisterProtocol("gs", tr)
+		}
 	}
 
-	transport2, err := http2.ConfigureTransports(transport)
-	if err != nil {
-		return nil, err
+	if config.ABSEnabled {
+		if tr, err := azureTransport.New(); err != nil {
+			return err
+		} else {
+			t.RegisterProtocol("abs", tr)
+		}
 	}
 
-	transport2.PingTimeout = 5 * time.Second
-	transport2.ReadIdleTimeout = time.Second
+	if config.SwiftEnabled {
+		if tr, err := swiftTransport.New(); err != nil {
+			return err
+		} else {
+			t.RegisterProtocol("swift", tr)
+		}
+	}
 
-	return transport, nil
+	return nil
 }