فهرست منبع

TestServer, AllowNetworks -> http.Transport (#1523)

Victor Sokolov 3 هفته پیش
والد
کامیت
2d6b5a8d5a

+ 30 - 55
auximageprovider/static_provider_test.go

@@ -3,77 +3,56 @@ package auximageprovider
 import (
 	"encoding/base64"
 	"io"
-	"net/http"
-	"net/http/httptest"
-	"os"
 	"strconv"
 	"testing"
 
 	"github.com/stretchr/testify/suite"
 
-	"github.com/imgproxy/imgproxy/v3/config"
 	"github.com/imgproxy/imgproxy/v3/fetcher"
 	"github.com/imgproxy/imgproxy/v3/httpheaders"
 	"github.com/imgproxy/imgproxy/v3/imagedata"
 	"github.com/imgproxy/imgproxy/v3/options"
+	"github.com/imgproxy/imgproxy/v3/testutil"
 )
 
 type ImageProviderTestSuite struct {
-	suite.Suite
+	testutil.LazySuite
 
-	server      *httptest.Server
 	testData    []byte
 	testDataB64 string
 
-	// Server state
-	status int
-	data   []byte
-	header http.Header
+	testServer testutil.LazyTestServer
+	idf        *imagedata.Factory
 }
 
 func (s *ImageProviderTestSuite) SetupSuite() {
-	config.Reset()
-	config.AllowLoopbackSourceAddresses = true
+	s.testData = testutil.NewTestDataProvider(s.T).Read("test1.jpg")
+	s.testDataB64 = base64.StdEncoding.EncodeToString(s.testData)
 
-	// Load test image data
-	f, err := os.Open("../testdata/test1.jpg")
-	s.Require().NoError(err)
-	defer f.Close()
+	fc := fetcher.NewDefaultConfig()
+	fc.Transport.HTTP.AllowLoopbackSourceAddresses = true
 
-	data, err := io.ReadAll(f)
+	f, err := fetcher.New(&fc)
 	s.Require().NoError(err)
 
-	s.testData = data
-	s.testDataB64 = base64.StdEncoding.EncodeToString(data)
+	s.idf = imagedata.NewFactory(f)
 
-	// Create test server
-	s.server = httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
-		for k, vv := range s.header {
-			for _, v := range vv {
-				rw.Header().Add(k, v)
-			}
-		}
-
-		data := s.data
-		if data == nil {
-			data = s.testData
-		}
-
-		rw.Header().Set(httpheaders.ContentLength, strconv.Itoa(len(data)))
-		rw.WriteHeader(s.status)
-		rw.Write(data)
-	}))
-}
+	s.testServer, _ = testutil.NewLazySuiteTestServer(
+		s,
+		func(srv *testutil.TestServer) error {
+			srv.SetHeaders(
+				httpheaders.ContentType, "image/jpeg",
+				httpheaders.ContentLength, strconv.Itoa(len(s.testData)),
+			).SetBody(s.testData)
 
-func (s *ImageProviderTestSuite) TearDownSuite() {
-	s.server.Close()
+			return nil
+		},
+	)
 }
 
-func (s *ImageProviderTestSuite) SetupTest() {
-	s.status = http.StatusOK
-	s.data = nil
-	s.header = http.Header{}
-	s.header.Set(httpheaders.ContentType, "image/jpeg")
+func (s *ImageProviderTestSuite) SetupSubTest() {
+	// We use t.Run() a lot, so we need to reset lazy objects at the beginning of each subtest
+	s.ResetLazyObjects()
 }
 
 // Helper function to read data from ImageData
@@ -114,7 +93,7 @@ func (s *ImageProviderTestSuite) TestNewProvider() {
 		},
 		{
 			name:   "URL",
-			config: &StaticConfig{URL: s.server.URL},
+			config: &StaticConfig{URL: s.testServer().URL()},
 			validateFunc: func(provider Provider) {
 				s.Equal(s.testData, s.readImageData(provider))
 			},
@@ -149,10 +128,12 @@ func (s *ImageProviderTestSuite) TestNewProvider() {
 		},
 		{
 			name:   "HeadersPassedThrough",
-			config: &StaticConfig{URL: s.server.URL},
+			config: &StaticConfig{URL: s.testServer().URL()},
 			setupFunc: func() {
-				s.header.Set("X-Custom-Header", "test-value")
-				s.header.Set(httpheaders.CacheControl, "max-age=3600")
+				s.testServer().SetHeaders(
+					"X-Custom-Header", "test-value",
+					httpheaders.CacheControl, "max-age=3600",
+				)
 			},
 			validateFunc: func(provider Provider) {
 				imgData, headers, err := provider.Get(s.T().Context(), &options.ProcessingOptions{})
@@ -167,19 +148,13 @@ func (s *ImageProviderTestSuite) TestNewProvider() {
 		},
 	}
 
-	fc := fetcher.NewDefaultConfig()
-	f, err := fetcher.New(&fc)
-	s.Require().NoError(err)
-
-	idf := imagedata.NewFactory(f)
-
 	for _, tt := range tests {
 		s.T().Run(tt.name, func(t *testing.T) {
 			if tt.setupFunc != nil {
 				tt.setupFunc()
 			}
 
-			provider, err := NewStaticProvider(s.T().Context(), tt.config, "test image", idf)
+			provider, err := NewStaticProvider(s.T().Context(), tt.config, "test image", s.idf)
 
 			if tt.expectError {
 				s.Require().Error(err)

+ 3 - 3
config/config.go

@@ -58,7 +58,7 @@ var (
 	PngUnlimited                bool
 	SvgUnlimited                bool
 	MaxResultDimension          int
-	AllowedProcessiongOptions   []string
+	AllowedProcessingOptions    []string
 	AllowSecurityOptions        bool
 
 	JpegProgressive       bool
@@ -267,7 +267,7 @@ func Reset() {
 	PngUnlimited = false
 	SvgUnlimited = false
 	MaxResultDimension = 0
-	AllowedProcessiongOptions = make([]string, 0)
+	AllowedProcessingOptions = make([]string, 0)
 	AllowSecurityOptions = false
 
 	JpegProgressive = false
@@ -502,7 +502,7 @@ func Configure() error {
 	configurators.Bool(&SvgUnlimited, "IMGPROXY_SVG_UNLIMITED")
 
 	configurators.Int(&MaxResultDimension, "IMGPROXY_MAX_RESULT_DIMENSION")
-	configurators.StringSlice(&AllowedProcessiongOptions, "IMGPROXY_ALLOWED_PROCESSING_OPTIONS")
+	configurators.StringSlice(&AllowedProcessingOptions, "IMGPROXY_ALLOWED_PROCESSING_OPTIONS")
 
 	configurators.Bool(&AllowSecurityOptions, "IMGPROXY_ALLOW_SECURITY_OPTIONS")
 

+ 11 - 2
fetcher/errors.go

@@ -4,10 +4,11 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"io"
 	"net/http"
 
+	"github.com/imgproxy/imgproxy/v3/fetcher/transport/generichttp"
 	"github.com/imgproxy/imgproxy/v3/ierrors"
-	"github.com/imgproxy/imgproxy/v3/security"
 )
 
 const msgSourceImageIsUnreachable = "Source image is unreachable"
@@ -157,13 +158,21 @@ func (e NotModifiedError) Headers() http.Header {
 func WrapError(err error) error {
 	isTimeout := false
 
-	var secArrdErr security.SourceAddressError
+	var secArrdErr generichttp.SourceAddressError
 
 	switch {
 	case errors.Is(err, context.DeadlineExceeded):
 		isTimeout = true
 	case errors.Is(err, context.Canceled):
 		return newImageRequestCanceledError(err)
+	case err == io.ErrUnexpectedEOF:
+		return ierrors.Wrap(
+			newImageRequestError(err),
+			1,
+			ierrors.WithPublicMessage("source image is corrupted"),
+			ierrors.WithShouldReport(false),
+			ierrors.WithStatusCode(http.StatusUnprocessableEntity),
+		)
 	case errors.As(err, &secArrdErr):
 		return ierrors.Wrap(
 			err,

+ 13 - 4
fetcher/transport/generichttp/config.go

@@ -10,15 +10,21 @@ import (
 
 // Config holds the configuration for the generic HTTP transport
 type Config struct {
-	ClientKeepAliveTimeout time.Duration
-	IgnoreSslVerification  bool
+	ClientKeepAliveTimeout        time.Duration
+	IgnoreSslVerification         bool
+	AllowLoopbackSourceAddresses  bool
+	AllowLinkLocalSourceAddresses bool
+	AllowPrivateSourceAddresses   bool
 }
 
 // NewDefaultConfig returns a new default configuration for the generic HTTP transport
 func NewDefaultConfig() Config {
 	return Config{
-		ClientKeepAliveTimeout: 90 * time.Second,
-		IgnoreSslVerification:  false,
+		ClientKeepAliveTimeout:        90 * time.Second,
+		IgnoreSslVerification:         false,
+		AllowLoopbackSourceAddresses:  false,
+		AllowLinkLocalSourceAddresses: false,
+		AllowPrivateSourceAddresses:   true,
 	}
 }
 
@@ -28,6 +34,9 @@ func LoadConfigFromEnv(c *Config) (*Config, error) {
 
 	c.ClientKeepAliveTimeout = time.Duration(config.ClientKeepAliveTimeout) * time.Second
 	c.IgnoreSslVerification = config.IgnoreSslVerification
+	c.AllowLinkLocalSourceAddresses = config.AllowLinkLocalSourceAddresses
+	c.AllowLoopbackSourceAddresses = config.AllowLoopbackSourceAddresses
+	c.AllowPrivateSourceAddresses = config.AllowPrivateSourceAddresses
 
 	return c, nil
 }

+ 23 - 0
fetcher/transport/generichttp/errors.go

@@ -0,0 +1,23 @@
+package generichttp
+
+import (
+	"net/http"
+
+	"github.com/imgproxy/imgproxy/v3/ierrors"
+)
+
+type (
+	SourceAddressError string
+)
+
+func newSourceAddressError(msg string) error {
+	return ierrors.Wrap(
+		SourceAddressError(msg),
+		1,
+		ierrors.WithStatusCode(http.StatusNotFound),
+		ierrors.WithPublicMessage("Invalid source URL"),
+		ierrors.WithShouldReport(false),
+	)
+}
+
+func (e SourceAddressError) Error() string { return string(e) }

+ 28 - 2
fetcher/transport/generichttp/generic_http.go

@@ -3,12 +3,12 @@ package generichttp
 
 import (
 	"crypto/tls"
+	"fmt"
 	"net"
 	"net/http"
 	"syscall"
 	"time"
 
-	"github.com/imgproxy/imgproxy/v3/security"
 	"golang.org/x/net/http2"
 )
 
@@ -25,7 +25,7 @@ func New(verifyNetworks bool, config *Config) (*http.Transport, error) {
 
 	if verifyNetworks {
 		dialer.Control = func(network, address string, c syscall.RawConn) error {
-			return security.VerifySourceNetwork(address)
+			return verifySourceNetwork(address, config)
 		}
 	}
 
@@ -66,3 +66,29 @@ func New(verifyNetworks bool, config *Config) (*http.Transport, error) {
 
 	return transport, nil
 }
+
+func verifySourceNetwork(addr string, config *Config) error {
+	host, _, err := net.SplitHostPort(addr)
+	if err != nil {
+		host = addr
+	}
+
+	ip := net.ParseIP(host)
+	if ip == nil {
+		return newSourceAddressError(fmt.Sprintf("Invalid source address: %s", addr))
+	}
+
+	if !config.AllowLoopbackSourceAddresses && (ip.IsLoopback() || ip.IsUnspecified()) {
+		return newSourceAddressError(fmt.Sprintf("Loopback source address is not allowed: %s", addr))
+	}
+
+	if !config.AllowLinkLocalSourceAddresses && (ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast()) {
+		return newSourceAddressError(fmt.Sprintf("Link-local source address is not allowed: %s", addr))
+	}
+
+	if !config.AllowPrivateSourceAddresses && ip.IsPrivate() {
+		return newSourceAddressError(fmt.Sprintf("Private source address is not allowed: %s", addr))
+	}
+
+	return nil
+}

+ 3 - 14
security/source_test.go → fetcher/transport/generichttp/generic_http_test.go

@@ -1,9 +1,8 @@
-package security
+package generichttp
 
 import (
 	"testing"
 
-	"github.com/imgproxy/imgproxy/v3/config"
 	"github.com/stretchr/testify/require"
 )
 
@@ -100,24 +99,14 @@ func TestVerifySourceNetwork(t *testing.T) {
 
 	for _, tc := range testCases {
 		t.Run(tc.name, func(t *testing.T) {
-			// Backup original config
-			originalLoopback := config.AllowLoopbackSourceAddresses
-			originalLinkLocal := config.AllowLinkLocalSourceAddresses
-			originalPrivate := config.AllowPrivateSourceAddresses
-
-			// Restore original config after test
-			defer func() {
-				config.AllowLoopbackSourceAddresses = originalLoopback
-				config.AllowLinkLocalSourceAddresses = originalLinkLocal
-				config.AllowPrivateSourceAddresses = originalPrivate
-			}()
+			config := NewDefaultConfig()
 
 			// Override config for the test
 			config.AllowLoopbackSourceAddresses = tc.allowLoopback
 			config.AllowLinkLocalSourceAddresses = tc.allowLinkLocal
 			config.AllowPrivateSourceAddresses = tc.allowPrivate
 
-			err := VerifySourceNetwork(tc.addr)
+			err := verifySourceNetwork(tc.addr, &config)
 
 			if tc.expectErr {
 				require.Error(t, err)

+ 0 - 1
handlers/processing/handler_test.go

@@ -1 +0,0 @@
-package processing

+ 85 - 139
handlers/stream/handler_test.go

@@ -6,7 +6,6 @@ import (
 	"net/http"
 	"net/http/httptest"
 	"os"
-	"path/filepath"
 	"strconv"
 	"testing"
 	"time"
@@ -22,23 +21,24 @@ import (
 	"github.com/imgproxy/imgproxy/v3/testutil"
 )
 
-const (
-	testDataPath = "../../testdata"
-)
-
 type HandlerTestSuite struct {
 	testutil.LazySuite
 
+	testData *testutil.TestDataProvider
+
 	rwConf    testutil.LazyObj[*responsewriter.Config]
 	rwFactory testutil.LazyObj[*responsewriter.Factory]
 
 	config  testutil.LazyObj[*Config]
 	handler testutil.LazyObj[*Handler]
+
+	testServer testutil.LazyTestServer
 }
 
 func (s *HandlerTestSuite) SetupSuite() {
 	config.Reset()
-	config.AllowLoopbackSourceAddresses = true
+
+	s.testData = testutil.NewTestDataProvider(s.T)
 
 	s.rwConf, _ = testutil.NewLazySuiteObj(
 		s,
@@ -67,6 +67,7 @@ func (s *HandlerTestSuite) SetupSuite() {
 		s,
 		func() (*Handler, error) {
 			fc := fetcher.NewDefaultConfig()
+			fc.Transport.HTTP.AllowLoopbackSourceAddresses = true
 
 			fetcher, err := fetcher.New(&fc)
 			s.Require().NoError(err)
@@ -75,36 +76,27 @@ func (s *HandlerTestSuite) SetupSuite() {
 		},
 	)
 
+	s.testServer, _ = testutil.NewLazySuiteTestServer(s)
+
 	// Silence logs during tests
 	logrus.SetOutput(io.Discard)
 }
 
 func (s *HandlerTestSuite) TearDownSuite() {
-	config.Reset()
 	logrus.SetOutput(os.Stdout)
 }
 
-func (s *HandlerTestSuite) SetupTest() {
-	config.Reset()
-	config.AllowLoopbackSourceAddresses = true
-}
-
 func (s *HandlerTestSuite) SetupSubTest() {
 	// We use t.Run() a lot, so we need to reset lazy objects at the beginning of each subtest
 	s.ResetLazyObjects()
 }
 
-func (s *HandlerTestSuite) readTestFile(name string) []byte {
-	data, err := os.ReadFile(filepath.Join(testDataPath, name))
-	s.Require().NoError(err)
-	return data
-}
-
 func (s *HandlerTestSuite) execute(
 	imageURL string,
 	header http.Header,
 	po *options.ProcessingOptions,
-) *httptest.ResponseRecorder {
+) *http.Response {
+	imageURL = s.testServer().URL() + imageURL
 	req := httptest.NewRequest("GET", "/", nil)
 	httpheaders.CopyAll(header, req.Header, true)
 
@@ -115,51 +107,42 @@ func (s *HandlerTestSuite) execute(
 	err := s.handler().Execute(ctx, req, imageURL, "test-req-id", po, rww)
 	s.Require().NoError(err)
 
-	return rw
+	return rw.Result()
 }
 
 // TestHandlerBasicRequest checks basic streaming request
 func (s *HandlerTestSuite) TestHandlerBasicRequest() {
-	data := s.readTestFile("test1.png")
+	data := s.testData.Read("test1.png")
 
-	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		w.Header().Set(httpheaders.ContentType, "image/png")
-		w.WriteHeader(200)
-		w.Write(data)
-	}))
-	defer ts.Close()
+	s.testServer().SetHeaders(httpheaders.ContentType, "image/png").SetBody(data)
 
-	rw := s.execute(ts.URL, nil, &options.ProcessingOptions{})
+	res := s.execute("", nil, &options.ProcessingOptions{})
 
-	res := rw.Result()
 	s.Require().Equal(200, res.StatusCode)
 	s.Require().Equal("image/png", res.Header.Get(httpheaders.ContentType))
 
 	// Verify we get the original image data
-	actual := rw.Body.Bytes()
+	actual, err := io.ReadAll(res.Body)
+	s.Require().NoError(err)
 	s.Require().Equal(data, actual)
 }
 
 // TestHandlerResponseHeadersPassthrough checks that original response headers are
 // passed through to the client
 func (s *HandlerTestSuite) TestHandlerResponseHeadersPassthrough() {
-	data := s.readTestFile("test1.png")
+	data := s.testData.Read("test1.png")
 	contentLength := len(data)
 
-	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		w.Header().Set(httpheaders.ContentType, "image/png")
-		w.Header().Set(httpheaders.ContentLength, strconv.Itoa(contentLength))
-		w.Header().Set(httpheaders.AcceptRanges, "bytes")
-		w.Header().Set(httpheaders.Etag, "etag")
-		w.Header().Set(httpheaders.LastModified, "Wed, 21 Oct 2015 07:28:00 GMT")
-		w.WriteHeader(200)
-		w.Write(data)
-	}))
-	defer ts.Close()
+	s.testServer().SetHeaders(
+		httpheaders.ContentType, "image/png",
+		httpheaders.ContentLength, strconv.Itoa(contentLength),
+		httpheaders.AcceptRanges, "bytes",
+		httpheaders.Etag, "etag",
+		httpheaders.LastModified, "Wed, 21 Oct 2015 07:28:00 GMT",
+	).SetBody(data)
 
-	rw := s.execute(ts.URL, nil, &options.ProcessingOptions{})
+	res := s.execute("", nil, &options.ProcessingOptions{})
 
-	res := rw.Result()
 	s.Require().Equal(200, res.StatusCode)
 	s.Require().Equal("image/png", res.Header.Get(httpheaders.ContentType))
 	s.Require().Equal(strconv.Itoa(contentLength), res.Header.Get(httpheaders.ContentLength))
@@ -172,42 +155,34 @@ func (s *HandlerTestSuite) TestHandlerResponseHeadersPassthrough() {
 // to the server
 func (s *HandlerTestSuite) TestHandlerRequestHeadersPassthrough() {
 	etag := `"test-etag-123"`
-	data := s.readTestFile("test1.png")
-
-	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		// Verify that If-None-Match header is passed through
-		s.Equal(etag, r.Header.Get(httpheaders.IfNoneMatch))
-		s.Equal("gzip", r.Header.Get(httpheaders.AcceptEncoding))
-		s.Equal("bytes=*", r.Header.Get(httpheaders.Range))
-
-		w.Header().Set(httpheaders.Etag, etag)
-		w.WriteHeader(200)
-		w.Write(data)
-	}))
-	defer ts.Close()
+	data := s.testData.Read("test1.png")
+
+	s.testServer().
+		SetBody(data).
+		SetHeaders(httpheaders.Etag, etag).
+		SetHook(func(r *http.Request, rw http.ResponseWriter) {
+			// Verify that If-None-Match header is passed through
+			s.Equal(etag, r.Header.Get(httpheaders.IfNoneMatch))
+			s.Equal("gzip", r.Header.Get(httpheaders.AcceptEncoding))
+			s.Equal("bytes=*", r.Header.Get(httpheaders.Range))
+		})
 
 	h := make(http.Header)
 	h.Set(httpheaders.IfNoneMatch, etag)
 	h.Set(httpheaders.AcceptEncoding, "gzip")
 	h.Set(httpheaders.Range, "bytes=*")
 
-	rw := s.execute(ts.URL, h, &options.ProcessingOptions{})
+	res := s.execute("", h, &options.ProcessingOptions{})
 
-	res := rw.Result()
 	s.Require().Equal(200, res.StatusCode)
 	s.Require().Equal(etag, res.Header.Get(httpheaders.Etag))
 }
 
 // TestHandlerContentDisposition checks that Content-Disposition header is set correctly
 func (s *HandlerTestSuite) TestHandlerContentDisposition() {
-	data := s.readTestFile("test1.png")
+	data := s.testData.Read("test1.png")
 
-	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		w.Header().Set(httpheaders.ContentType, "image/png")
-		w.WriteHeader(200)
-		w.Write(data)
-	}))
-	defer ts.Close()
+	s.testServer().SetHeaders(httpheaders.ContentType, "image/png").SetBody(data)
 
 	po := &options.ProcessingOptions{
 		Filename:         "custom_name",
@@ -215,10 +190,8 @@ func (s *HandlerTestSuite) TestHandlerContentDisposition() {
 	}
 
 	// Use a URL with a .png extension to help content disposition logic
-	imageURL := ts.URL + "/test.png"
-	rw := s.execute(imageURL, nil, po)
+	res := s.execute("/test.png", nil, po)
 
-	res := rw.Result()
 	s.Require().Equal(200, res.StatusCode)
 	s.Require().Contains(res.Header.Get(httpheaders.ContentDisposition), "custom_name.png")
 	s.Require().Contains(res.Header.Get(httpheaders.ContentDisposition), "attachment")
@@ -229,7 +202,7 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
 	type testCase struct {
 		name                    string
 		cacheControlPassthrough bool
-		setupOriginHeaders      func(http.ResponseWriter)
+		setupOriginHeaders      func()
 		timestampOffset         *time.Duration // nil for no timestamp, otherwise the offset from now
 		expectedStatusCode      int
 		validate                func(*testing.T, *http.Response)
@@ -250,8 +223,8 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
 		{
 			name:                    "Passthrough",
 			cacheControlPassthrough: true,
-			setupOriginHeaders: func(w http.ResponseWriter) {
-				w.Header().Set(httpheaders.CacheControl, "max-age=3600, public")
+			setupOriginHeaders: func() {
+				s.testServer().SetHeaders(httpheaders.CacheControl, "max-age=3600, public")
 			},
 			timestampOffset:    nil,
 			expectedStatusCode: 200,
@@ -263,8 +236,8 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
 		{
 			name:                    "ExpiresPassthrough",
 			cacheControlPassthrough: true,
-			setupOriginHeaders: func(w http.ResponseWriter) {
-				w.Header().Set(httpheaders.Expires, time.Now().Add(oneHour).UTC().Format(http.TimeFormat))
+			setupOriginHeaders: func() {
+				s.testServer().SetHeaders(httpheaders.Expires, time.Now().Add(oneHour).UTC().Format(http.TimeFormat))
 			},
 			timestampOffset:    nil,
 			expectedStatusCode: 200,
@@ -278,8 +251,8 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
 		{
 			name:                    "PassthroughDisabled",
 			cacheControlPassthrough: false,
-			setupOriginHeaders: func(w http.ResponseWriter) {
-				w.Header().Set(httpheaders.CacheControl, "max-age=3600, public")
+			setupOriginHeaders: func() {
+				s.testServer().SetHeaders(httpheaders.CacheControl, "max-age=3600, public")
 			},
 			timestampOffset:    nil,
 			expectedStatusCode: 200,
@@ -291,7 +264,6 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
 		{
 			name:                    "WithProcessingOptionsExpires",
 			cacheControlPassthrough: false,
-			setupOriginHeaders:      func(w http.ResponseWriter) {}, // No origin headers
 			timestampOffset:         &oneHour,
 			expectedStatusCode:      200,
 			validate: func(t *testing.T, res *http.Response) {
@@ -303,9 +275,9 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
 		{
 			name:                    "ProcessingOptionsOverridesOrigin",
 			cacheControlPassthrough: true,
-			setupOriginHeaders: func(w http.ResponseWriter) {
+			setupOriginHeaders: func() {
 				// Origin has a longer cache time
-				w.Header().Set(httpheaders.CacheControl, "max-age=7200, public")
+				s.testServer().SetHeaders(httpheaders.CacheControl, "max-age=7200, public")
 			},
 			timestampOffset:    &thirtyMinutes,
 			expectedStatusCode: 200,
@@ -318,10 +290,10 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
 		{
 			name:                    "BothHeadersPassthroughEnabled",
 			cacheControlPassthrough: true,
-			setupOriginHeaders: func(w http.ResponseWriter) {
+			setupOriginHeaders: func() {
 				// Origin has both Cache-Control and Expires headers
-				w.Header().Set(httpheaders.CacheControl, "max-age=1800, public")
-				w.Header().Set(httpheaders.Expires, time.Now().Add(oneHour).UTC().Format(http.TimeFormat))
+				s.testServer().SetHeaders(httpheaders.CacheControl, "max-age=1800, public")
+				s.testServer().SetHeaders(httpheaders.Expires, time.Now().Add(oneHour).UTC().Format(http.TimeFormat))
 			},
 			timestampOffset:    nil,
 			expectedStatusCode: 200,
@@ -336,10 +308,10 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
 		{
 			name:                    "ProcessingOptionsOverridesBothOriginHeaders",
 			cacheControlPassthrough: true,
-			setupOriginHeaders: func(w http.ResponseWriter) {
+			setupOriginHeaders: func() {
 				// Origin has both Cache-Control and Expires headers with longer cache times
-				w.Header().Set(httpheaders.CacheControl, "max-age=7200, public")
-				w.Header().Set(httpheaders.Expires, time.Now().Add(twoHours).UTC().Format(http.TimeFormat))
+				s.testServer().SetHeaders(httpheaders.CacheControl, "max-age=7200, public")
+				s.testServer().SetHeaders(httpheaders.Expires, time.Now().Add(twoHours).UTC().Format(http.TimeFormat))
 			},
 			timestampOffset:    &fortyFiveMinutes, // Shorter than origin headers
 			expectedStatusCode: 200,
@@ -352,7 +324,6 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
 		{
 			name:                    "NoOriginHeaders",
 			cacheControlPassthrough: false,
-			setupOriginHeaders:      func(w http.ResponseWriter) {}, // Origin has no cache headers
 			timestampOffset:         nil,
 			expectedStatusCode:      200,
 			validate: func(t *testing.T, res *http.Response) {
@@ -363,15 +334,13 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
 
 	for _, tc := range testCases {
 		s.Run(tc.name, func() {
-			data := s.readTestFile("test1.png")
+			data := s.testData.Read("test1.png")
+
+			if tc.setupOriginHeaders != nil {
+				tc.setupOriginHeaders()
+			}
 
-			ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-				tc.setupOriginHeaders(w)
-				w.Header().Set(httpheaders.ContentType, "image/png")
-				w.WriteHeader(200)
-				w.Write(data)
-			}))
-			defer ts.Close()
+			s.testServer().SetHeaders(httpheaders.ContentType, "image/png").SetBody(data)
 
 			s.rwConf().CacheControlPassthrough = tc.cacheControlPassthrough
 			s.rwConf().DefaultTTL = 4242
@@ -383,9 +352,7 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
 				po.Expires = &expires
 			}
 
-			rw := s.execute(ts.URL, nil, po)
-
-			res := rw.Result()
+			res := s.execute("", nil, po)
 			s.Require().Equal(tc.expectedStatusCode, res.StatusCode)
 			tc.validate(s.T(), res)
 		})
@@ -405,85 +372,64 @@ func (s *HandlerTestSuite) maxAgeValue(res *http.Response) time.Duration {
 
 // TestHandlerSecurityHeaders tests the security headers set by the streaming service.
 func (s *HandlerTestSuite) TestHandlerSecurityHeaders() {
-	data := s.readTestFile("test1.png")
+	data := s.testData.Read("test1.png")
 
-	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		w.Header().Set(httpheaders.ContentType, "image/png")
-		w.WriteHeader(200)
-		w.Write(data)
-	}))
-	defer ts.Close()
+	s.testServer().SetHeaders(httpheaders.ContentType, "image/png").SetBody(data)
 
-	rw := s.execute(ts.URL, nil, &options.ProcessingOptions{})
+	res := s.execute("", nil, &options.ProcessingOptions{})
 
-	res := rw.Result()
-	s.Require().Equal(200, res.StatusCode)
+	s.Require().Equal(http.StatusOK, res.StatusCode)
 	s.Require().Equal("script-src 'none'", res.Header.Get(httpheaders.ContentSecurityPolicy))
 }
 
 // TestHandlerErrorResponse tests the error responses from the streaming service.
 func (s *HandlerTestSuite) TestHandlerErrorResponse() {
-	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		w.WriteHeader(404)
-		w.Write([]byte("Not Found"))
-	}))
-	defer ts.Close()
+	s.testServer().SetStatusCode(http.StatusNotFound).SetBody([]byte("Not Found"))
 
-	rw := s.execute(ts.URL, nil, &options.ProcessingOptions{})
+	res := s.execute("", nil, &options.ProcessingOptions{})
 
-	res := rw.Result()
-	s.Require().Equal(404, res.StatusCode)
+	s.Require().Equal(http.StatusNotFound, res.StatusCode)
 }
 
 // TestHandlerCookiePassthrough tests the cookie passthrough behavior of the streaming service.
 func (s *HandlerTestSuite) TestHandlerCookiePassthrough() {
 	s.config().CookiePassthrough = true
 
-	data := s.readTestFile("test1.png")
+	data := s.testData.Read("test1.png")
 
-	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		// Verify cookies are passed through
-		cookie, cerr := r.Cookie("test_cookie")
-		if cerr == nil {
-			s.Equal("test_value", cookie.Value)
-		}
-
-		w.Header().Set(httpheaders.ContentType, "image/png")
-		w.WriteHeader(200)
-		w.Write(data)
-	}))
-	defer ts.Close()
+	s.testServer().
+		SetHeaders(httpheaders.Cookie, "test_cookie=test_value").
+		SetHook(func(r *http.Request, rw http.ResponseWriter) {
+			// Verify cookies are passed through
+			cookie, cerr := r.Cookie("test_cookie")
+			if cerr == nil {
+				s.Equal("test_value", cookie.Value)
+			}
+		}).SetBody(data)
 
 	h := make(http.Header)
 	h.Set(httpheaders.Cookie, "test_cookie=test_value")
 
-	rw := s.execute(ts.URL, h, &options.ProcessingOptions{})
+	res := s.execute("", h, &options.ProcessingOptions{})
 
-	res := rw.Result()
 	s.Require().Equal(200, res.StatusCode)
 }
 
 // TestHandlerCanonicalHeader tests that the canonical header is set correctly
 func (s *HandlerTestSuite) TestHandlerCanonicalHeader() {
-	data := s.readTestFile("test1.png")
+	data := s.testData.Read("test1.png")
 
-	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		w.Header().Set(httpheaders.ContentType, "image/png")
-		w.WriteHeader(200)
-		w.Write(data)
-	}))
-	defer ts.Close()
+	s.testServer().SetHeaders(httpheaders.ContentType, "image/png").SetBody(data)
 
 	for _, sc := range []bool{true, false} {
 		s.rwConf().SetCanonicalHeader = sc
 
-		rw := s.execute(ts.URL, nil, &options.ProcessingOptions{})
+		res := s.execute("", nil, &options.ProcessingOptions{})
 
-		res := rw.Result()
 		s.Require().Equal(200, res.StatusCode)
 
 		if sc {
-			s.Require().Contains(res.Header.Get(httpheaders.Link), fmt.Sprintf(`<%s>; rel="canonical"`, ts.URL))
+			s.Require().Contains(res.Header.Get(httpheaders.Link), fmt.Sprintf(`<%s>; rel="canonical"`, s.testServer().URL()))
 		} else {
 			s.Require().Empty(res.Header.Get(httpheaders.Link))
 		}

+ 93 - 107
imagedata/image_data_test.go

@@ -6,17 +6,13 @@ import (
 	"context"
 	"encoding/base64"
 	"fmt"
-	"io"
 	"net"
 	"net/http"
-	"net/http/httptest"
-	"os"
 	"strconv"
 	"testing"
 
 	"github.com/stretchr/testify/suite"
 
-	"github.com/imgproxy/imgproxy/v3/config"
 	"github.com/imgproxy/imgproxy/v3/fetcher"
 	"github.com/imgproxy/imgproxy/v3/httpheaders"
 	"github.com/imgproxy/imgproxy/v3/ierrors"
@@ -25,88 +21,70 @@ import (
 )
 
 type ImageDataTestSuite struct {
-	suite.Suite
+	testutil.LazySuite
 
-	server *httptest.Server
+	fetcherCfg testutil.LazyObj[*fetcher.Config]
+	factory    testutil.LazyObj[*Factory]
+	testServer testutil.LazyTestServer
 
-	status  int
-	data    []byte
-	header  http.Header
-	check   func(*http.Request)
-	factory *Factory
-
-	defaultData []byte
+	data []byte
 }
 
 func (s *ImageDataTestSuite) SetupSuite() {
-	config.Reset()
-	config.ClientKeepAliveTimeout = 0
-
-	f, err := os.Open("../testdata/test1.jpg")
-	s.Require().NoError(err)
-	defer f.Close()
-
-	data, err := io.ReadAll(f)
-	s.Require().NoError(err)
-
-	s.defaultData = data
-
-	s.server = httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
-		if s.check != nil {
-			s.check(r)
-		}
-
-		httpheaders.CopyAll(s.header, rw.Header(), true)
-
-		data := s.data
-		if data == nil {
-			data = s.defaultData
-		}
-
-		rw.Header().Set("Content-Length", strconv.Itoa(len(data)))
-
-		rw.WriteHeader(s.status)
-		rw.Write(data)
-	}))
-
-	c, err := fetcher.LoadConfigFromEnv(nil)
-	s.Require().NoError(err)
+	s.data = testutil.NewTestDataProvider(s.T).Read("test1.jpg")
 
-	fetcher, err := fetcher.New(c)
-	s.Require().NoError(err)
+	s.fetcherCfg, _ = testutil.NewLazySuiteObj(
+		s,
+		func() (*fetcher.Config, error) {
+			c := fetcher.NewDefaultConfig()
+			c.Transport.HTTP.AllowLoopbackSourceAddresses = true
+			c.Transport.HTTP.ClientKeepAliveTimeout = 0
 
-	s.factory = NewFactory(fetcher)
-}
+			return &c, nil
+		},
+	)
+
+	s.factory, _ = testutil.NewLazySuiteObj(
+		s,
+		func() (*Factory, error) {
+			fetcher, err := fetcher.New(s.fetcherCfg())
+			if err != nil {
+				return nil, err
+			}
 
-func (s *ImageDataTestSuite) TearDownSuite() {
-	s.server.Close()
+			return NewFactory(fetcher), nil
+		},
+	)
+
+	s.testServer, _ = testutil.NewLazySuiteTestServer(
+		s,
+		func(srv *testutil.TestServer) error {
+			// Default headers and body for 200 OK response
+			srv.SetHeaders(
+				httpheaders.ContentType, "image/jpeg",
+				httpheaders.ContentLength, strconv.Itoa(len(s.data)),
+			).SetBody(s.data)
+
+			return nil
+		},
+	)
 }
 
-func (s *ImageDataTestSuite) SetupTest() {
-	config.Reset()
-	config.AllowLoopbackSourceAddresses = true
-
-	s.status = http.StatusOK
-	s.data = nil
-	s.check = nil
-
-	s.header = http.Header{}
-	s.header.Set("Content-Type", "image/jpeg")
-
+func (s *ImageDataTestSuite) SetupSubTest() {
+	// We use t.Run() a lot, so we need to reset lazy objects at the beginning of each subtest
+	s.ResetLazyObjects()
 }
 
 func (s *ImageDataTestSuite) TestDownloadStatusOK() {
-	imgdata, _, err := s.factory.DownloadSync(context.Background(), s.server.URL, "Test image", DownloadOptions{})
+	imgdata, _, err := s.factory().DownloadSync(context.Background(), s.testServer().URL(), "Test image", DownloadOptions{})
 
 	s.Require().NoError(err)
 	s.Require().NotNil(imgdata)
-	s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.defaultData), imgdata.Reader()))
+	s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.data), imgdata.Reader()))
 	s.Require().Equal(imagetype.JPEG, imgdata.Format())
 }
 
 func (s *ImageDataTestSuite) TestDownloadStatusPartialContent() {
-	s.status = http.StatusPartialContent
-
 	testCases := []struct {
 		name         string
 		contentRange string
@@ -114,17 +92,17 @@ func (s *ImageDataTestSuite) TestDownloadStatusPartialContent() {
 	}{
 		{
 			name:         "Full Content-Range",
-			contentRange: fmt.Sprintf("bytes 0-%d/%d", len(s.defaultData)-1, len(s.defaultData)),
+			contentRange: fmt.Sprintf("bytes 0-%d/%d", len(s.data)-1, len(s.data)),
 			expectErr:    false,
 		},
 		{
 			name:         "Partial Content-Range, early end",
-			contentRange: fmt.Sprintf("bytes 0-%d/%d", len(s.defaultData)-2, len(s.defaultData)),
+			contentRange: fmt.Sprintf("bytes 0-%d/%d", len(s.data)-2, len(s.data)),
 			expectErr:    true,
 		},
 		{
 			name:         "Partial Content-Range, late start",
-			contentRange: fmt.Sprintf("bytes 1-%d/%d", len(s.defaultData)-1, len(s.defaultData)),
+			contentRange: fmt.Sprintf("bytes 1-%d/%d", len(s.data)-1, len(s.data)),
 			expectErr:    true,
 		},
 		{
@@ -139,39 +117,41 @@ func (s *ImageDataTestSuite) TestDownloadStatusPartialContent() {
 		},
 		{
 			name:         "Unknown Content-Range range",
-			contentRange: fmt.Sprintf("bytes */%d", len(s.defaultData)),
+			contentRange: fmt.Sprintf("bytes */%d", len(s.data)),
 			expectErr:    true,
 		},
 		{
 			name:         "Unknown Content-Range size, full range",
-			contentRange: fmt.Sprintf("bytes 0-%d/*", len(s.defaultData)-1),
+			contentRange: fmt.Sprintf("bytes 0-%d/*", len(s.data)-1),
 			expectErr:    false,
 		},
 		{
 			name:         "Unknown Content-Range size, early end",
-			contentRange: fmt.Sprintf("bytes 0-%d/*", len(s.defaultData)-2),
+			contentRange: fmt.Sprintf("bytes 0-%d/*", len(s.data)-2),
 			expectErr:    true,
 		},
 		{
 			name:         "Unknown Content-Range size, late start",
-			contentRange: fmt.Sprintf("bytes 1-%d/*", len(s.defaultData)-1),
+			contentRange: fmt.Sprintf("bytes 1-%d/*", len(s.data)-1),
 			expectErr:    true,
 		},
 	}
 
 	for _, tc := range testCases {
 		s.Run(tc.name, func() {
-			s.header.Set("Content-Range", tc.contentRange)
+			s.testServer().
+				SetHeaders(httpheaders.ContentRange, tc.contentRange).
+				SetStatusCode(http.StatusPartialContent)
 
-			imgdata, _, err := s.factory.DownloadSync(context.Background(), s.server.URL, "Test image", DownloadOptions{})
+			imgdata, _, err := s.factory().DownloadSync(context.Background(), s.testServer().URL(), "Test image", DownloadOptions{})
 
 			if tc.expectErr {
 				s.Require().Error(err)
-				s.Require().Equal(404, ierrors.Wrap(err, 0).StatusCode())
+				s.Require().Equal(http.StatusNotFound, ierrors.Wrap(err, 0).StatusCode())
 			} else {
 				s.Require().NoError(err)
 				s.Require().NotNil(imgdata)
-				s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.defaultData), imgdata.Reader()))
+				s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.data), imgdata.Reader()))
 				s.Require().Equal(imagetype.JPEG, imgdata.Format())
 			}
 		})
@@ -179,11 +159,12 @@ func (s *ImageDataTestSuite) TestDownloadStatusPartialContent() {
 }
 
 func (s *ImageDataTestSuite) TestDownloadStatusNotFound() {
-	s.status = http.StatusNotFound
-	s.data = []byte("Not Found")
-	s.header.Set("Content-Type", "text/plain")
+	s.testServer().
+		SetStatusCode(http.StatusNotFound).
+		SetBody([]byte("Not Found")).
+		SetHeaders(httpheaders.ContentType, "text/plain")
 
-	imgdata, _, err := s.factory.DownloadSync(context.Background(), s.server.URL, "Test image", DownloadOptions{})
+	imgdata, _, err := s.factory().DownloadSync(context.Background(), s.testServer().URL(), "Test image", DownloadOptions{})
 
 	s.Require().Error(err)
 	s.Require().Equal(404, ierrors.Wrap(err, 0).StatusCode())
@@ -191,11 +172,12 @@ func (s *ImageDataTestSuite) TestDownloadStatusNotFound() {
 }
 
 func (s *ImageDataTestSuite) TestDownloadStatusForbidden() {
-	s.status = http.StatusForbidden
-	s.data = []byte("Forbidden")
-	s.header.Set("Content-Type", "text/plain")
+	s.testServer().
+		SetStatusCode(http.StatusForbidden).
+		SetBody([]byte("Forbidden")).
+		SetHeaders(httpheaders.ContentType, "text/plain")
 
-	imgdata, _, err := s.factory.DownloadSync(context.Background(), s.server.URL, "Test image", DownloadOptions{})
+	imgdata, _, err := s.factory().DownloadSync(context.Background(), s.testServer().URL(), "Test image", DownloadOptions{})
 
 	s.Require().Error(err)
 	s.Require().Equal(404, ierrors.Wrap(err, 0).StatusCode())
@@ -203,11 +185,12 @@ func (s *ImageDataTestSuite) TestDownloadStatusForbidden() {
 }
 
 func (s *ImageDataTestSuite) TestDownloadStatusInternalServerError() {
-	s.status = http.StatusInternalServerError
-	s.data = []byte("Internal Server Error")
-	s.header.Set("Content-Type", "text/plain")
+	s.testServer().
+		SetStatusCode(http.StatusInternalServerError).
+		SetBody([]byte("Internal Server Error")).
+		SetHeaders(httpheaders.ContentType, "text/plain")
 
-	imgdata, _, err := s.factory.DownloadSync(context.Background(), s.server.URL, "Test image", DownloadOptions{})
+	imgdata, _, err := s.factory().DownloadSync(context.Background(), s.testServer().URL(), "Test image", DownloadOptions{})
 
 	s.Require().Error(err)
 	s.Require().Equal(500, ierrors.Wrap(err, 0).StatusCode())
@@ -221,7 +204,7 @@ func (s *ImageDataTestSuite) TestDownloadUnreachable() {
 
 	serverURL := fmt.Sprintf("http://%s", l.Addr().String())
 
-	imgdata, _, err := s.factory.DownloadSync(context.Background(), serverURL, "Test image", DownloadOptions{})
+	imgdata, _, err := s.factory().DownloadSync(context.Background(), serverURL, "Test image", DownloadOptions{})
 
 	s.Require().Error(err)
 	s.Require().Equal(500, ierrors.Wrap(err, 0).StatusCode())
@@ -229,19 +212,19 @@ func (s *ImageDataTestSuite) TestDownloadUnreachable() {
 }
 
 func (s *ImageDataTestSuite) TestDownloadInvalidImage() {
-	s.data = []byte("invalid")
+	s.testServer().SetBody([]byte("invalid"))
 
-	imgdata, _, err := s.factory.DownloadSync(context.Background(), s.server.URL, "Test image", DownloadOptions{})
+	imgdata, _, err := s.factory().DownloadSync(context.Background(), s.testServer().URL(), "Test image", DownloadOptions{})
 
 	s.Require().Error(err)
-	s.Require().Equal(422, ierrors.Wrap(err, 0).StatusCode())
+	s.Require().Equal(http.StatusUnprocessableEntity, ierrors.Wrap(err, 0).StatusCode())
 	s.Require().Nil(imgdata)
 }
 
 func (s *ImageDataTestSuite) TestDownloadSourceAddressNotAllowed() {
-	config.AllowLoopbackSourceAddresses = false
+	s.fetcherCfg().Transport.HTTP.AllowLoopbackSourceAddresses = false
 
-	imgdata, _, err := s.factory.DownloadSync(context.Background(), s.server.URL, "Test image", DownloadOptions{})
+	imgdata, _, err := s.factory().DownloadSync(context.Background(), s.testServer().URL(), "Test image", DownloadOptions{})
 
 	s.Require().Error(err)
 	s.Require().Equal(404, ierrors.Wrap(err, 0).StatusCode())
@@ -249,11 +232,10 @@ func (s *ImageDataTestSuite) TestDownloadSourceAddressNotAllowed() {
 }
 
 func (s *ImageDataTestSuite) TestDownloadImageFileTooLarge() {
-	imgdata, _, err := s.factory.DownloadSync(context.Background(), s.server.URL, "Test image", DownloadOptions{
+	imgdata, _, err := s.factory().DownloadSync(context.Background(), s.testServer().URL(), "Test image", DownloadOptions{
 		MaxSrcFileSize: 1,
 	})
 
-	fmt.Println(err)
 	s.Require().Error(err)
 	s.Require().Equal(422, ierrors.Wrap(err, 0).StatusCode())
 	s.Require().Nil(imgdata)
@@ -263,39 +245,43 @@ func (s *ImageDataTestSuite) TestDownloadGzip() {
 	buf := new(bytes.Buffer)
 
 	enc := gzip.NewWriter(buf)
-	_, err := enc.Write(s.defaultData)
+	_, err := enc.Write(s.data)
 	s.Require().NoError(err)
 	err = enc.Close()
 	s.Require().NoError(err)
 
-	s.data = buf.Bytes()
-	s.header.Set("Content-Encoding", "gzip")
+	s.testServer().
+		SetBody(buf.Bytes()).
+		SetHeaders(
+			httpheaders.ContentEncoding, "gzip",
+			httpheaders.ContentLength, strconv.Itoa(buf.Len()), // Update Content-Length
+		)
 
-	imgdata, _, err := s.factory.DownloadSync(context.Background(), s.server.URL, "Test image", DownloadOptions{})
+	imgdata, _, err := s.factory().DownloadSync(context.Background(), s.testServer().URL(), "Test image", DownloadOptions{})
 
 	s.Require().NoError(err)
 	s.Require().NotNil(imgdata)
-	s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.defaultData), imgdata.Reader()))
+	s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.data), imgdata.Reader()))
 	s.Require().Equal(imagetype.JPEG, imgdata.Format())
 }
 
 func (s *ImageDataTestSuite) TestFromFile() {
-	imgdata, err := s.factory.NewFromPath("../testdata/test1.jpg")
+	imgdata, err := s.factory().NewFromPath("../testdata/test1.jpg")
 
 	s.Require().NoError(err)
 	s.Require().NotNil(imgdata)
-	s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.defaultData), imgdata.Reader()))
+	s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.data), imgdata.Reader()))
 	s.Require().Equal(imagetype.JPEG, imgdata.Format())
 }
 
 func (s *ImageDataTestSuite) TestFromBase64() {
-	b64 := base64.StdEncoding.EncodeToString(s.defaultData)
+	b64 := base64.StdEncoding.EncodeToString(s.data)
 
-	imgdata, err := s.factory.NewFromBase64(b64)
+	imgdata, err := s.factory().NewFromBase64(b64)
 
 	s.Require().NoError(err)
 	s.Require().NotNil(imgdata)
-	s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.defaultData), imgdata.Reader()))
+	s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.data), imgdata.Reader()))
 	s.Require().Equal(imagetype.JPEG, imgdata.Format())
 }
 

+ 12 - 14
integration/processing_handler_test.go

@@ -27,10 +27,7 @@ type ProcessingHandlerTestSuite struct {
 
 func (s *ProcessingHandlerTestSuite) SetupTest() {
 	config.Reset() // We reset config only at the start of each test
-
-	// NOTE: This must be moved to security config
-	config.AllowLoopbackSourceAddresses = true
-	// NOTE: end note
+	s.Config().Fetcher.Transport.HTTP.AllowLoopbackSourceAddresses = true
 }
 
 func (s *ProcessingHandlerTestSuite) SetupSubTest() {
@@ -142,13 +139,13 @@ func (s *ProcessingHandlerTestSuite) TestSourceNetworkValidation() {
 
 	// We wrap this in a subtest to reset s.router()
 	s.Run("AllowLoopbackSourceAddressesTrue", func() {
-		config.AllowLoopbackSourceAddresses = true
+		s.Config().Fetcher.Transport.HTTP.AllowLoopbackSourceAddresses = true
 		res := s.GET(url)
 		s.Require().Equal(http.StatusOK, res.StatusCode)
 	})
 
 	s.Run("AllowLoopbackSourceAddressesFalse", func() {
-		config.AllowLoopbackSourceAddresses = false
+		s.Config().Fetcher.Transport.HTTP.AllowLoopbackSourceAddresses = false
 		res := s.GET(url)
 		s.Require().Equal(http.StatusNotFound, res.StatusCode)
 	})
@@ -256,7 +253,7 @@ func (s *ProcessingHandlerTestSuite) TestCacheControlPassthroughCacheControl() {
 }
 
 func (s *ProcessingHandlerTestSuite) TestCacheControlPassthroughExpires() {
-	config.CacheControlPassthrough = true
+	s.Config().Server.ResponseWriter.CacheControlPassthrough = true
 
 	ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
 		rw.Header().Set(httpheaders.Expires, time.Now().Add(1239*time.Second).UTC().Format(http.TimeFormat))
@@ -290,7 +287,7 @@ func (s *ProcessingHandlerTestSuite) TestCacheControlPassthroughDisabled() {
 }
 
 func (s *ProcessingHandlerTestSuite) TestETagDisabled() {
-	config.ETagEnabled = false
+	s.Config().Handlers.Processing.ETagEnabled = false
 
 	res := s.GET("/unsafe/rs:fill:4:4/plain/local:///test1.png")
 
@@ -299,7 +296,7 @@ func (s *ProcessingHandlerTestSuite) TestETagDisabled() {
 }
 
 func (s *ProcessingHandlerTestSuite) TestETagDataMatch() {
-	config.ETagEnabled = true
+	s.Config().Handlers.Processing.ETagEnabled = true
 
 	etag := `"loremipsumdolor"`
 
@@ -321,7 +318,8 @@ func (s *ProcessingHandlerTestSuite) TestETagDataMatch() {
 }
 
 func (s *ProcessingHandlerTestSuite) TestLastModifiedEnabled() {
-	config.LastModifiedEnabled = true
+	s.Config().Handlers.Processing.LastModifiedEnabled = true
+
 	ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
 		rw.Header().Set(httpheaders.LastModified, "Wed, 21 Oct 2015 07:28:00 GMT")
 		rw.WriteHeader(200)
@@ -335,7 +333,7 @@ func (s *ProcessingHandlerTestSuite) TestLastModifiedEnabled() {
 }
 
 func (s *ProcessingHandlerTestSuite) TestLastModifiedDisabled() {
-	config.LastModifiedEnabled = false
+	s.Config().Handlers.Processing.LastModifiedEnabled = false
 	ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
 		rw.Header().Set(httpheaders.LastModified, "Wed, 21 Oct 2015 07:28:00 GMT")
 		rw.WriteHeader(200)
@@ -349,7 +347,7 @@ func (s *ProcessingHandlerTestSuite) TestLastModifiedDisabled() {
 }
 
 func (s *ProcessingHandlerTestSuite) TestModifiedSinceReqExactMatchLastModifiedDisabled() {
-	config.LastModifiedEnabled = false
+	s.Config().Handlers.Processing.LastModifiedEnabled = false
 	data := s.TestData.Read("test1.png")
 	lastModified := "Wed, 21 Oct 2015 07:28:00 GMT"
 	ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
@@ -368,7 +366,7 @@ func (s *ProcessingHandlerTestSuite) TestModifiedSinceReqExactMatchLastModifiedD
 }
 
 func (s *ProcessingHandlerTestSuite) TestModifiedSinceReqExactMatchLastModifiedEnabled() {
-	config.LastModifiedEnabled = true
+	s.Config().Handlers.Processing.LastModifiedEnabled = true
 	lastModified := "Wed, 21 Oct 2015 07:28:00 GMT"
 	ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
 		modifiedSince := r.Header.Get(httpheaders.IfModifiedSince)
@@ -386,7 +384,7 @@ func (s *ProcessingHandlerTestSuite) TestModifiedSinceReqExactMatchLastModifiedE
 
 func (s *ProcessingHandlerTestSuite) TestModifiedSinceReqCompareMoreRecentLastModifiedDisabled() {
 	data := s.TestData.Read("test1.png")
-	config.LastModifiedEnabled = false
+	s.Config().Handlers.Processing.LastModifiedEnabled = false
 	ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
 		modifiedSince := r.Header.Get(httpheaders.IfModifiedSince)
 		s.Empty(modifiedSince)

+ 2 - 2
options/processing_options.go

@@ -1082,10 +1082,10 @@ func applyURLOption(po *ProcessingOptions, name string, args []string, usedPrese
 }
 
 func applyURLOptions(po *ProcessingOptions, options urlOptions, allowAll bool, usedPresets ...string) error {
-	allowAll = allowAll || len(config.AllowedProcessiongOptions) == 0
+	allowAll = allowAll || len(config.AllowedProcessingOptions) == 0
 
 	for _, opt := range options {
-		if !allowAll && !slices.Contains(config.AllowedProcessiongOptions, opt.Name) {
+		if !allowAll && !slices.Contains(config.AllowedProcessingOptions, opt.Name) {
 			return newForbiddenOptionError("processing", opt.Name)
 		}
 

+ 1 - 1
options/processing_options_test.go

@@ -646,7 +646,7 @@ func (s *ProcessingOptionsTestSuite) TestParseBase64URLOnlyPresets() {
 }
 
 func (s *ProcessingOptionsTestSuite) TestParseAllowedOptions() {
-	config.AllowedProcessiongOptions = []string{"w", "h", "pr"}
+	config.AllowedProcessingOptions = []string{"w", "h", "pr"}
 
 	presets["test1"] = urlOptions{
 		urlOption{Name: "blur", Args: []string{"0.2"}},

+ 0 - 13
security/errors.go

@@ -13,7 +13,6 @@ type (
 	ImageResolutionError string
 	SecurityOptionsError struct{}
 	SourceURLError       string
-	SourceAddressError   string
 )
 
 func newSignatureError(msg string) error {
@@ -75,15 +74,3 @@ func newSourceURLError(imageURL string) error {
 }
 
 func (e SourceURLError) Error() string { return string(e) }
-
-func newSourceAddressError(msg string) error {
-	return ierrors.Wrap(
-		SourceAddressError(msg),
-		1,
-		ierrors.WithStatusCode(http.StatusNotFound),
-		ierrors.WithPublicMessage("Invalid source URL"),
-		ierrors.WithShouldReport(false),
-	)
-}
-
-func (e SourceAddressError) Error() string { return string(e) }

+ 0 - 29
security/source.go

@@ -1,9 +1,6 @@
 package security
 
 import (
-	"fmt"
-	"net"
-
 	"github.com/imgproxy/imgproxy/v3/config"
 )
 
@@ -20,29 +17,3 @@ func VerifySourceURL(imageURL string) error {
 
 	return newSourceURLError(imageURL)
 }
-
-func VerifySourceNetwork(addr string) error {
-	host, _, err := net.SplitHostPort(addr)
-	if err != nil {
-		host = addr
-	}
-
-	ip := net.ParseIP(host)
-	if ip == nil {
-		return newSourceAddressError(fmt.Sprintf("Invalid source address: %s", addr))
-	}
-
-	if !config.AllowLoopbackSourceAddresses && (ip.IsLoopback() || ip.IsUnspecified()) {
-		return newSourceAddressError(fmt.Sprintf("Loopback source address is not allowed: %s", addr))
-	}
-
-	if !config.AllowLinkLocalSourceAddresses && (ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast()) {
-		return newSourceAddressError(fmt.Sprintf("Link-local source address is not allowed: %s", addr))
-	}
-
-	if !config.AllowPrivateSourceAddresses && ip.IsPrivate() {
-		return newSourceAddressError(fmt.Sprintf("Private source address is not allowed: %s", addr))
-	}
-
-	return nil
-}

+ 1 - 1
testutil/lasy_suite.go

@@ -51,7 +51,7 @@ func NewLazySuiteObj[T any](
 	// Get the [LazySuite] instance
 	lazy := s.Lazy()
 	// Create the [LazyObj] instance
-	obj, cancel := NewLazyObj(lazy, newFn, dropFn...)
+	obj, cancel := newLazyObj(lazy, newFn, dropFn...)
 	// Add cleanup function to the resets list
 	lazy.resets = append(lazy.resets, cancel)
 

+ 2 - 2
testutil/lazy_obj.go

@@ -23,10 +23,10 @@ type LazyObjNew[T any] func() (T, error)
 // If the object was not yet initialized, the callback is not called.
 type LazyObjDrop[T any] func(T) error
 
-// NewLazyObj creates a new [LazyObj] that initializes the object on the first call.
+// newLazyObj creates a new [LazyObj] that initializes the object on the first call.
 // It returns a function that can be called to get the object and a cancel function
 // that can be called to reset the object.
-func NewLazyObj[T any](
+func newLazyObj[T any](
 	s LazyObjT,
 	newFn LazyObjNew[T],
 	dropFn ...LazyObjDrop[T],

+ 0 - 3
testutil/test_data_provider.go

@@ -26,9 +26,6 @@ type TestDataProvider struct {
 
 // New creates a new TestDataProvider
 func NewTestDataProvider(t TestDataProviderT) *TestDataProvider {
-	// if h, ok := t.(interface{ Helper() }); ok {
-	// 	h.Helper()
-	// }
 	t().Helper()
 
 	path, err := findProjectRoot()

+ 122 - 0
testutil/test_server.go

@@ -0,0 +1,122 @@
+package testutil
+
+import (
+	"context"
+	"net/http"
+	"net/http/httptest"
+
+	"github.com/imgproxy/imgproxy/v3/httpheaders"
+	"github.com/stretchr/testify/require"
+)
+
+// TestServerHookFunc is a function type for in-request hooks
+type TestServerHookFunc func(r *http.Request, rw http.ResponseWriter)
+
+// Sugar alias
+type LazyTestServer = LazyObj[*TestServer]
+
+// TestServer is a syntax sugar wrapper over httptest.Server
+type TestServer struct {
+	testServer *httptest.Server
+	status     int
+	data       []byte
+	header     http.Header
+	hook       TestServerHookFunc
+}
+
+// NewLazySuiteTestServer creates a lazy TestServer object for use in test suites
+func NewLazySuiteTestServer(
+	l LazySuiteFrom,
+	init ...func(*TestServer) error,
+) (LazyObj[*TestServer], context.CancelFunc) {
+	return NewLazySuiteObj(
+		l,
+		func() (*TestServer, error) {
+			s := NewTestServer()
+
+			if len(init) > 0 {
+				for _, fn := range init {
+					if fn == nil {
+						continue
+					}
+
+					err := fn(s)
+					require.NoError(l.Lazy().T(), err, "Failed to reset test server")
+				}
+			}
+
+			return s, nil
+		},
+		func(s *TestServer) error {
+			s.Close()
+			return nil
+		},
+	)
+}
+
+// New creates and starts new http.TestServer
+func NewTestServer() *TestServer {
+	ts := &TestServer{
+		status: http.StatusOK,
+		header: make(http.Header),
+		data:   nil,
+		hook:   nil,
+	}
+
+	return ts.start()
+}
+
+// SetStatusCode sets the status code that will be returned by the server
+func (s *TestServer) SetStatusCode(status int) *TestServer {
+	s.status = status
+	return s
+}
+
+// SetBody sets the body that will be returned by the server
+func (s *TestServer) SetBody(data []byte) *TestServer {
+	s.data = data
+	return s
+}
+
+// WithHeader adds headers that will be returned by the server.
+// Odd arguments are treated as keys, even arguments as values.
+func (s *TestServer) SetHeaders(kv ...string) *TestServer {
+	for i := 0; i+1 < len(kv); i += 2 {
+		key := kv[i]
+		value := kv[i+1]
+		s.header.Set(key, value)
+	}
+
+	return s
+}
+
+// SetHook sets a function that will be called on each request. It is called
+// after headsers are set, but before status and body are written.
+func (s *TestServer) SetHook(f TestServerHookFunc) *TestServer {
+	s.hook = f
+	return s
+}
+
+// Start starts the server
+func (s *TestServer) start() *TestServer {
+	s.testServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		httpheaders.CopyAll(s.header, w.Header(), true)
+		if s.hook != nil {
+			s.hook(r, w)
+		}
+		w.WriteHeader(s.status)
+		w.Write(s.data)
+	}))
+
+	return s
+}
+
+// Close stops the server
+func (s *TestServer) Close() {
+	s.testServer.Close()
+}
+
+// URL returns the server URL
+func (s *TestServer) URL() string {
+	return s.testServer.URL
+}