Преглед на файлове

Rebuild headerwriter to server.ResponseWriter

DarthSim преди 3 седмици
родител
ревизия
1f6d007948

+ 0 - 7
config.go

@@ -6,7 +6,6 @@ import (
 	"github.com/imgproxy/imgproxy/v3/fetcher"
 	processinghandler "github.com/imgproxy/imgproxy/v3/handlers/processing"
 	streamhandler "github.com/imgproxy/imgproxy/v3/handlers/stream"
-	"github.com/imgproxy/imgproxy/v3/headerwriter"
 	"github.com/imgproxy/imgproxy/v3/semaphores"
 	"github.com/imgproxy/imgproxy/v3/server"
 )
@@ -19,7 +18,6 @@ type HandlerConfigs struct {
 
 // Config represents an instance configuration
 type Config struct {
-	HeaderWriter   headerwriter.Config
 	Semaphores     semaphores.Config
 	FallbackImage  auximageprovider.StaticConfig
 	WatermarkImage auximageprovider.StaticConfig
@@ -31,7 +29,6 @@ type Config struct {
 // NewDefaultConfig creates a new default configuration
 func NewDefaultConfig() Config {
 	return Config{
-		HeaderWriter:   headerwriter.NewDefaultConfig(),
 		Semaphores:     semaphores.NewDefaultConfig(),
 		FallbackImage:  auximageprovider.NewDefaultStaticConfig(),
 		WatermarkImage: auximageprovider.NewDefaultStaticConfig(),
@@ -62,10 +59,6 @@ func LoadConfigFromEnv(c *Config) (*Config, error) {
 		return nil, err
 	}
 
-	if _, err = headerwriter.LoadConfigFromEnv(&c.HeaderWriter); err != nil {
-		return nil, err
-	}
-
 	if _, err = semaphores.LoadConfigFromEnv(&c.Semaphores); err != nil {
 		return nil, err
 	}

+ 1 - 1
handlers/health/handler.go

@@ -22,7 +22,7 @@ func New() *Handler {
 // Execute handles the health request
 func (h *Handler) Execute(
 	reqID string,
-	rw http.ResponseWriter,
+	rw server.ResponseWriter,
 	req *http.Request,
 ) error {
 	var (

+ 8 - 1
handlers/health/handler_test.go

@@ -6,11 +6,18 @@ import (
 	"testing"
 
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 
 	"github.com/imgproxy/imgproxy/v3/httpheaders"
+	"github.com/imgproxy/imgproxy/v3/server/responsewriter"
 )
 
 func TestHealthHandler(t *testing.T) {
+	// Create responsewriter.Factory
+	rwConf := responsewriter.NewDefaultConfig()
+	rwf, err := responsewriter.NewFactory(&rwConf)
+	require.NoError(t, err)
+
 	// Create a ResponseRecorder to record the response
 	rr := httptest.NewRecorder()
 
@@ -18,7 +25,7 @@ func TestHealthHandler(t *testing.T) {
 	h := New()
 
 	// Call the handler function directly (no need for actual HTTP request)
-	h.Execute("test-req-id", rr, nil)
+	h.Execute("test-req-id", rwf.NewWriter(rr), nil)
 
 	// Check that we get a valid response (either 200 or 500 depending on vips state)
 	assert.True(t, rr.Code == http.StatusOK || rr.Code == http.StatusInternalServerError)

+ 2 - 1
handlers/landing/handler.go

@@ -5,6 +5,7 @@ import (
 	"net/http"
 
 	"github.com/imgproxy/imgproxy/v3/httpheaders"
+	"github.com/imgproxy/imgproxy/v3/server"
 )
 
 //go:embed body.html
@@ -21,7 +22,7 @@ func New() *Handler {
 // Execute handles the landing request
 func (h *Handler) Execute(
 	reqID string,
-	rw http.ResponseWriter,
+	rw server.ResponseWriter,
 	req *http.Request,
 ) error {
 	rw.Header().Set(httpheaders.ContentType, "text/html")

+ 2 - 4
handlers/processing/handler.go

@@ -9,7 +9,6 @@ import (
 	"github.com/imgproxy/imgproxy/v3/errorreport"
 	"github.com/imgproxy/imgproxy/v3/handlers"
 	"github.com/imgproxy/imgproxy/v3/handlers/stream"
-	"github.com/imgproxy/imgproxy/v3/headerwriter"
 	"github.com/imgproxy/imgproxy/v3/ierrors"
 	"github.com/imgproxy/imgproxy/v3/imagedata"
 	"github.com/imgproxy/imgproxy/v3/monitoring"
@@ -17,11 +16,11 @@ import (
 	"github.com/imgproxy/imgproxy/v3/options"
 	"github.com/imgproxy/imgproxy/v3/security"
 	"github.com/imgproxy/imgproxy/v3/semaphores"
+	"github.com/imgproxy/imgproxy/v3/server"
 )
 
 // HandlerContext provides access to shared handler dependencies
 type HandlerContext interface {
-	HeaderWriter() *headerwriter.Writer
 	Semaphores() *semaphores.Semaphores
 	FallbackImage() auximageprovider.Provider
 	WatermarkImage() auximageprovider.Provider
@@ -56,7 +55,7 @@ func New(
 // Execute handles the image processing request
 func (h *Handler) Execute(
 	reqID string,
-	rw http.ResponseWriter,
+	rw server.ResponseWriter,
 	req *http.Request,
 ) error {
 	// Increment the number of requests in progress
@@ -86,7 +85,6 @@ func (h *Handler) Execute(
 		po:             po,
 		imageURL:       imageURL,
 		monitoringMeta: mm,
-		hwr:            h.HeaderWriter().NewRequest(),
 	}
 
 	return hReq.execute(ctx)

+ 4 - 6
handlers/processing/request.go

@@ -7,7 +7,6 @@ import (
 
 	"github.com/imgproxy/imgproxy/v3/fetcher"
 	"github.com/imgproxy/imgproxy/v3/handlers"
-	"github.com/imgproxy/imgproxy/v3/headerwriter"
 	"github.com/imgproxy/imgproxy/v3/ierrors"
 	"github.com/imgproxy/imgproxy/v3/imagetype"
 	"github.com/imgproxy/imgproxy/v3/monitoring"
@@ -23,12 +22,11 @@ type request struct {
 
 	reqID          string
 	req            *http.Request
-	rw             http.ResponseWriter
+	rw             server.ResponseWriter
 	config         *Config
 	po             *options.ProcessingOptions
 	imageURL       string
 	monitoringMeta monitoring.Meta
-	hwr            *headerwriter.Request
 }
 
 // execute handles the actual processing logic
@@ -84,13 +82,13 @@ func (r *request) execute(ctx context.Context) error {
 	var nmErr fetcher.NotModifiedError
 
 	if errors.As(err, &nmErr) {
-		r.hwr.SetOriginHeaders(nmErr.Headers())
+		r.rw.SetOriginHeaders(nmErr.Headers())
 
 		return r.respondWithNotModified()
 	}
 
 	// Prepare to write image response headers
-	r.hwr.SetOriginHeaders(originHeaders)
+	r.rw.SetOriginHeaders(originHeaders)
 
 	// If error is not related to NotModified, respond with fallback image and replace image data
 	if err != nil {
@@ -123,7 +121,7 @@ func (r *request) execute(ctx context.Context) error {
 		return ierrors.Wrap(err, 0, ierrors.WithCategory(handlers.CategoryProcessing))
 	}
 
-	// Write debug headers. It seems unlogical to move they to headerwriter since they're
+	// Write debug headers. It seems unlogical to move they to responsewriter since they're
 	// not used anywhere else.
 	err = r.writeDebugHeaders(result, originData)
 	if err != nil {

+ 14 - 18
handlers/processing/request_methods.go

@@ -123,8 +123,8 @@ func (r *request) handleDownloadError(
 	headers.Del(httpheaders.Expires)
 	headers.Del(httpheaders.LastModified)
 
-	r.hwr.SetOriginHeaders(headers)
-	r.hwr.SetIsFallbackImage()
+	r.rw.SetOriginHeaders(headers)
+	r.rw.SetIsFallbackImage()
 
 	return data, statusCode, nil
 }
@@ -186,19 +186,17 @@ func (r *request) writeDebugHeaders(result *processing.Result, originData imaged
 
 // respondWithNotModified writes not-modified response
 func (r *request) respondWithNotModified() error {
-	r.hwr.SetExpires(r.po.Expires)
-	r.hwr.SetVary()
+	r.rw.SetExpires(r.po.Expires)
+	r.rw.SetVary()
 
 	if r.config.LastModifiedEnabled {
-		r.hwr.Passthrough(httpheaders.LastModified)
+		r.rw.Passthrough(httpheaders.LastModified)
 	}
 
 	if r.config.ETagEnabled {
-		r.hwr.Passthrough(httpheaders.Etag)
+		r.rw.Passthrough(httpheaders.Etag)
 	}
 
-	r.hwr.Write(r.rw)
-
 	r.rw.WriteHeader(http.StatusNotModified)
 
 	server.LogResponse(
@@ -221,29 +219,27 @@ func (r *request) respondWithImage(statusCode int, resultData imagedata.ImageDat
 		return ierrors.Wrap(err, 0, ierrors.WithCategory(handlers.CategoryImageDataSize))
 	}
 
-	r.hwr.SetContentType(resultData.Format().Mime())
-	r.hwr.SetContentLength(resultSize)
-	r.hwr.SetContentDisposition(
+	r.rw.SetContentType(resultData.Format().Mime())
+	r.rw.SetContentLength(resultSize)
+	r.rw.SetContentDisposition(
 		r.imageURL,
 		r.po.Filename,
 		resultData.Format().Ext(),
 		"",
 		r.po.ReturnAttachment,
 	)
-	r.hwr.SetExpires(r.po.Expires)
-	r.hwr.SetVary()
-	r.hwr.SetCanonical(r.imageURL)
+	r.rw.SetExpires(r.po.Expires)
+	r.rw.SetVary()
+	r.rw.SetCanonical(r.imageURL)
 
 	if r.config.LastModifiedEnabled {
-		r.hwr.Passthrough(httpheaders.LastModified)
+		r.rw.Passthrough(httpheaders.LastModified)
 	}
 
 	if r.config.ETagEnabled {
-		r.hwr.Passthrough(httpheaders.Etag)
+		r.rw.Passthrough(httpheaders.Etag)
 	}
 
-	r.hwr.Write(r.rw)
-
 	r.rw.WriteHeader(statusCode)
 
 	_, err = io.Copy(r.rw, resultData.Reader())

+ 15 - 22
handlers/stream/handler.go

@@ -6,16 +6,16 @@ import (
 	"net/http"
 	"sync"
 
+	log "github.com/sirupsen/logrus"
+
 	"github.com/imgproxy/imgproxy/v3/cookies"
 	"github.com/imgproxy/imgproxy/v3/fetcher"
-	"github.com/imgproxy/imgproxy/v3/headerwriter"
 	"github.com/imgproxy/imgproxy/v3/httpheaders"
 	"github.com/imgproxy/imgproxy/v3/ierrors"
 	"github.com/imgproxy/imgproxy/v3/monitoring"
 	"github.com/imgproxy/imgproxy/v3/monitoring/stats"
 	"github.com/imgproxy/imgproxy/v3/options"
 	"github.com/imgproxy/imgproxy/v3/server"
-	log "github.com/sirupsen/logrus"
 )
 
 const (
@@ -35,9 +35,8 @@ var (
 
 // Handler handles image passthrough requests, allowing images to be streamed directly
 type Handler struct {
-	config  *Config              // Configuration for the streamer
-	fetcher *fetcher.Fetcher     // Fetcher instance to handle image fetching
-	hw      *headerwriter.Writer // Configured HeaderWriter instance
+	config  *Config          // Configuration for the streamer
+	fetcher *fetcher.Fetcher // Fetcher instance to handle image fetching
 }
 
 // request holds the parameters and state for a single streaming request
@@ -47,12 +46,11 @@ type request struct {
 	imageURL     string
 	reqID        string
 	po           *options.ProcessingOptions
-	rw           http.ResponseWriter
-	hw           *headerwriter.Request
+	rw           server.ResponseWriter
 }
 
 // New creates new handler object
-func New(config *Config, hw *headerwriter.Writer, fetcher *fetcher.Fetcher) (*Handler, error) {
+func New(config *Config, fetcher *fetcher.Fetcher) (*Handler, error) {
 	if err := config.Validate(); err != nil {
 		return nil, err
 	}
@@ -60,7 +58,6 @@ func New(config *Config, hw *headerwriter.Writer, fetcher *fetcher.Fetcher) (*Ha
 	return &Handler{
 		fetcher: fetcher,
 		config:  config,
-		hw:      hw,
 	}, nil
 }
 
@@ -71,7 +68,7 @@ func (s *Handler) Execute(
 	imageURL string,
 	reqID string,
 	po *options.ProcessingOptions,
-	rw http.ResponseWriter,
+	rw server.ResponseWriter,
 ) error {
 	stream := &request{
 		handler:      s,
@@ -80,7 +77,6 @@ func (s *Handler) Execute(
 		reqID:        reqID,
 		po:           po,
 		rw:           rw,
-		hw:           s.hw.NewRequest(),
 	}
 
 	return stream.execute(ctx)
@@ -118,17 +114,14 @@ func (s *request) execute(ctx context.Context) error {
 	}
 
 	// Output streaming response headers
-	s.hw.SetOriginHeaders(res.Header)
-	s.hw.Passthrough(s.handler.config.PassthroughResponseHeaders...) // NOTE: priority? This is lowest as it was
-	s.hw.SetContentLength(int(res.ContentLength))
-	s.hw.SetCanonical(s.imageURL)
-	s.hw.SetExpires(s.po.Expires)
+	s.rw.SetOriginHeaders(res.Header)
+	s.rw.Passthrough(s.handler.config.PassthroughResponseHeaders...) // NOTE: priority? This is lowest as it was
+	s.rw.SetContentLength(int(res.ContentLength))
+	s.rw.SetCanonical(s.imageURL)
+	s.rw.SetExpires(s.po.Expires)
 
 	// Set the Content-Disposition header
-	s.setContentDisposition(r.URL().Path, res, s.hw)
-
-	// Write headers from writer
-	s.hw.Write(s.rw)
+	s.setContentDisposition(r.URL().Path, res)
 
 	// Copy the status code from the original response
 	s.rw.WriteHeader(res.StatusCode)
@@ -158,7 +151,7 @@ func (s *request) getImageRequestHeaders() http.Header {
 }
 
 // setContentDisposition writes the headers to the response writer
-func (s *request) setContentDisposition(imagePath string, serverResponse *http.Response, hw *headerwriter.Request) {
+func (s *request) setContentDisposition(imagePath string, serverResponse *http.Response) {
 	// Try to set correct Content-Disposition file name and extension
 	if serverResponse.StatusCode < 200 || serverResponse.StatusCode >= 300 {
 		return
@@ -166,7 +159,7 @@ func (s *request) setContentDisposition(imagePath string, serverResponse *http.R
 
 	ct := serverResponse.Header.Get(httpheaders.ContentType)
 
-	hw.SetContentDisposition(
+	s.rw.SetContentDisposition(
 		imagePath,
 		s.po.Filename,
 		"",

+ 84 - 117
handlers/stream/handler_test.go

@@ -1,7 +1,6 @@
 package stream
 
 import (
-	"context"
 	"fmt"
 	"io"
 	"net/http"
@@ -17,9 +16,10 @@ import (
 
 	"github.com/imgproxy/imgproxy/v3/config"
 	"github.com/imgproxy/imgproxy/v3/fetcher"
-	"github.com/imgproxy/imgproxy/v3/headerwriter"
 	"github.com/imgproxy/imgproxy/v3/httpheaders"
 	"github.com/imgproxy/imgproxy/v3/options"
+	"github.com/imgproxy/imgproxy/v3/server/responsewriter"
+	"github.com/imgproxy/imgproxy/v3/testutil"
 )
 
 const (
@@ -27,14 +27,54 @@ const (
 )
 
 type HandlerTestSuite struct {
-	suite.Suite
-	handler *Handler
+	testutil.LazySuite
+
+	rwConf    testutil.LazyObj[*responsewriter.Config]
+	rwFactory testutil.LazyObj[*responsewriter.Factory]
+
+	config  testutil.LazyObj[*Config]
+	handler testutil.LazyObj[*Handler]
 }
 
 func (s *HandlerTestSuite) SetupSuite() {
 	config.Reset()
 	config.AllowLoopbackSourceAddresses = true
 
+	s.rwConf, _ = testutil.NewLazySuiteObj(
+		s,
+		func() (*responsewriter.Config, error) {
+			c := responsewriter.NewDefaultConfig()
+			return &c, nil
+		},
+	)
+
+	s.rwFactory, _ = testutil.NewLazySuiteObj(
+		s,
+		func() (*responsewriter.Factory, error) {
+			return responsewriter.NewFactory(s.rwConf())
+		},
+	)
+
+	s.config, _ = testutil.NewLazySuiteObj(
+		s,
+		func() (*Config, error) {
+			c := NewDefaultConfig()
+			return &c, nil
+		},
+	)
+
+	s.handler, _ = testutil.NewLazySuiteObj(
+		s,
+		func() (*Handler, error) {
+			fc := fetcher.NewDefaultConfig()
+
+			fetcher, err := fetcher.New(&fc)
+			s.Require().NoError(err)
+
+			return New(s.config(), fetcher)
+		},
+	)
+
 	// Silence logs during tests
 	logrus.SetOutput(io.Discard)
 }
@@ -47,27 +87,35 @@ func (s *HandlerTestSuite) TearDownSuite() {
 func (s *HandlerTestSuite) SetupTest() {
 	config.Reset()
 	config.AllowLoopbackSourceAddresses = true
+}
 
-	fc := fetcher.NewDefaultConfig()
+func (s *HandlerTestSuite) SetupSubTest() {
+	// We use t.Run() a lot, so we need to reset lazy objects at the beginning of each subtest
+	s.ResetLazyObjects()
+}
 
-	fetcher, err := fetcher.New(&fc)
+func (s *HandlerTestSuite) readTestFile(name string) []byte {
+	data, err := os.ReadFile(filepath.Join(testDataPath, name))
 	s.Require().NoError(err)
+	return data
+}
 
-	cfg := NewDefaultConfig()
+func (s *HandlerTestSuite) execute(
+	imageURL string,
+	header http.Header,
+	po *options.ProcessingOptions,
+) *httptest.ResponseRecorder {
+	req := httptest.NewRequest("GET", "/", nil)
+	httpheaders.CopyAll(header, req.Header, true)
 
-	hwc := headerwriter.NewDefaultConfig()
-	hw, err := headerwriter.New(&hwc)
-	s.Require().NoError(err)
+	ctx := s.T().Context()
+	rw := httptest.NewRecorder()
+	rww := s.rwFactory().NewWriter(rw)
 
-	h, err := New(&cfg, hw, fetcher)
+	err := s.handler().Execute(ctx, req, imageURL, "test-req-id", po, rww)
 	s.Require().NoError(err)
-	s.handler = h
-}
 
-func (s *HandlerTestSuite) readTestFile(name string) []byte {
-	data, err := os.ReadFile(filepath.Join(testDataPath, name))
-	s.Require().NoError(err)
-	return data
+	return rw
 }
 
 // TestHandlerBasicRequest checks basic streaming request
@@ -81,12 +129,7 @@ func (s *HandlerTestSuite) TestHandlerBasicRequest() {
 	}))
 	defer ts.Close()
 
-	req := httptest.NewRequest("GET", "/", nil)
-	rw := httptest.NewRecorder()
-	po := &options.ProcessingOptions{}
-
-	err := s.handler.Execute(context.Background(), req, ts.URL, "request-1", po, rw)
-	s.Require().NoError(err)
+	rw := s.execute(ts.URL, nil, &options.ProcessingOptions{})
 
 	res := rw.Result()
 	s.Require().Equal(200, res.StatusCode)
@@ -114,12 +157,7 @@ func (s *HandlerTestSuite) TestHandlerResponseHeadersPassthrough() {
 	}))
 	defer ts.Close()
 
-	req := httptest.NewRequest("GET", "/", nil)
-	rw := httptest.NewRecorder()
-	po := &options.ProcessingOptions{}
-
-	err := s.handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
-	s.Require().NoError(err)
+	rw := s.execute(ts.URL, nil, &options.ProcessingOptions{})
 
 	res := rw.Result()
 	s.Require().Equal(200, res.StatusCode)
@@ -148,16 +186,12 @@ func (s *HandlerTestSuite) TestHandlerRequestHeadersPassthrough() {
 	}))
 	defer ts.Close()
 
-	req := httptest.NewRequest("GET", "/", nil)
-	req.Header.Set(httpheaders.IfNoneMatch, etag)
-	req.Header.Set(httpheaders.AcceptEncoding, "gzip")
-	req.Header.Set(httpheaders.Range, "bytes=*")
-
-	rw := httptest.NewRecorder()
-	po := &options.ProcessingOptions{}
+	h := make(http.Header)
+	h.Set(httpheaders.IfNoneMatch, etag)
+	h.Set(httpheaders.AcceptEncoding, "gzip")
+	h.Set(httpheaders.Range, "bytes=*")
 
-	err := s.handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
-	s.Require().NoError(err)
+	rw := s.execute(ts.URL, h, &options.ProcessingOptions{})
 
 	res := rw.Result()
 	s.Require().Equal(200, res.StatusCode)
@@ -175,8 +209,6 @@ func (s *HandlerTestSuite) TestHandlerContentDisposition() {
 	}))
 	defer ts.Close()
 
-	req := httptest.NewRequest("GET", "/", nil)
-	rw := httptest.NewRecorder()
 	po := &options.ProcessingOptions{
 		Filename:         "custom_name",
 		ReturnAttachment: true,
@@ -184,8 +216,7 @@ func (s *HandlerTestSuite) TestHandlerContentDisposition() {
 
 	// Use a URL with a .png extension to help content disposition logic
 	imageURL := ts.URL + "/test.png"
-	err := s.handler.Execute(context.Background(), req, imageURL, "test-req-id", po, rw)
-	s.Require().NoError(err)
+	rw := s.execute(imageURL, nil, po)
 
 	res := rw.Result()
 	s.Require().Equal(200, res.StatusCode)
@@ -342,25 +373,9 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
 			}))
 			defer ts.Close()
 
-			fc, err := fetcher.LoadConfigFromEnv(nil)
-			s.Require().NoError(err)
-
-			fetcher, err := fetcher.New(fc)
-			s.Require().NoError(err)
-
-			cfg := NewDefaultConfig()
-			hwc := headerwriter.NewDefaultConfig()
-			hwc.CacheControlPassthrough = tc.cacheControlPassthrough
-			hwc.DefaultTTL = 4242
+			s.rwConf().CacheControlPassthrough = tc.cacheControlPassthrough
+			s.rwConf().DefaultTTL = 4242
 
-			hw, err := headerwriter.New(&hwc)
-			s.Require().NoError(err)
-
-			handler, err := New(&cfg, hw, fetcher)
-			s.Require().NoError(err)
-
-			req := httptest.NewRequest("GET", "/", nil)
-			rw := httptest.NewRecorder()
 			po := &options.ProcessingOptions{}
 
 			if tc.timestampOffset != nil {
@@ -368,8 +383,7 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
 				po.Expires = &expires
 			}
 
-			err = handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
-			s.Require().NoError(err)
+			rw := s.execute(ts.URL, nil, po)
 
 			res := rw.Result()
 			s.Require().Equal(tc.expectedStatusCode, res.StatusCode)
@@ -400,12 +414,7 @@ func (s *HandlerTestSuite) TestHandlerSecurityHeaders() {
 	}))
 	defer ts.Close()
 
-	req := httptest.NewRequest("GET", "/", nil)
-	rw := httptest.NewRecorder()
-	po := &options.ProcessingOptions{}
-
-	err := s.handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
-	s.Require().NoError(err)
+	rw := s.execute(ts.URL, nil, &options.ProcessingOptions{})
 
 	res := rw.Result()
 	s.Require().Equal(200, res.StatusCode)
@@ -420,12 +429,7 @@ func (s *HandlerTestSuite) TestHandlerErrorResponse() {
 	}))
 	defer ts.Close()
 
-	req := httptest.NewRequest("GET", "/", nil)
-	rw := httptest.NewRecorder()
-	po := &options.ProcessingOptions{}
-
-	err := s.handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
-	s.Require().NoError(err)
+	rw := s.execute(ts.URL, nil, &options.ProcessingOptions{})
 
 	res := rw.Result()
 	s.Require().Equal(404, res.StatusCode)
@@ -433,21 +437,7 @@ func (s *HandlerTestSuite) TestHandlerErrorResponse() {
 
 // TestHandlerCookiePassthrough tests the cookie passthrough behavior of the streaming service.
 func (s *HandlerTestSuite) TestHandlerCookiePassthrough() {
-	fc, err := fetcher.LoadConfigFromEnv(nil)
-	s.Require().NoError(err)
-
-	fetcher, err := fetcher.New(fc)
-	s.Require().NoError(err)
-
-	cfg := NewDefaultConfig()
-	cfg.CookiePassthrough = true
-
-	hwc := headerwriter.NewDefaultConfig()
-	hw, err := headerwriter.New(&hwc)
-	s.Require().NoError(err)
-
-	handler, err := New(&cfg, hw, fetcher)
-	s.Require().NoError(err)
+	s.config().CookiePassthrough = true
 
 	data := s.readTestFile("test1.png")
 
@@ -464,13 +454,10 @@ func (s *HandlerTestSuite) TestHandlerCookiePassthrough() {
 	}))
 	defer ts.Close()
 
-	req := httptest.NewRequest("GET", "/", nil)
-	req.Header.Set(httpheaders.Cookie, "test_cookie=test_value")
-	rw := httptest.NewRecorder()
-	po := &options.ProcessingOptions{}
+	h := make(http.Header)
+	h.Set(httpheaders.Cookie, "test_cookie=test_value")
 
-	err = handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
-	s.Require().NoError(err)
+	rw := s.execute(ts.URL, h, &options.ProcessingOptions{})
 
 	res := rw.Result()
 	s.Require().Equal(200, res.StatusCode)
@@ -488,29 +475,9 @@ func (s *HandlerTestSuite) TestHandlerCanonicalHeader() {
 	defer ts.Close()
 
 	for _, sc := range []bool{true, false} {
-		fc, err := fetcher.LoadConfigFromEnv(nil)
-		s.Require().NoError(err)
-
-		fetcher, err := fetcher.New(fc)
-		s.Require().NoError(err)
-
-		cfg := NewDefaultConfig()
-		hwc := headerwriter.NewDefaultConfig()
-
-		hwc.SetCanonicalHeader = sc
-
-		hw, err := headerwriter.New(&hwc)
-		s.Require().NoError(err)
-
-		handler, err := New(&cfg, hw, fetcher)
-		s.Require().NoError(err)
-
-		req := httptest.NewRequest("GET", "/", nil)
-		rw := httptest.NewRecorder()
-		po := &options.ProcessingOptions{}
+		s.rwConf().SetCanonicalHeader = sc
 
-		err = handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
-		s.Require().NoError(err)
+		rw := s.execute(ts.URL, nil, &options.ProcessingOptions{})
 
 		res := rw.Result()
 		s.Require().Equal(200, res.StatusCode)

+ 0 - 62
headerwriter/config.go

@@ -1,62 +0,0 @@
-package headerwriter
-
-import (
-	"fmt"
-
-	"github.com/imgproxy/imgproxy/v3/config"
-	"github.com/imgproxy/imgproxy/v3/ensure"
-)
-
-// Config is the package-local configuration
-type Config struct {
-	SetCanonicalHeader      bool // Indicates whether to set the canonical header
-	DefaultTTL              int  // Default Cache-Control max-age= value for cached images
-	FallbackImageTTL        int  // TTL for images served as fallbacks
-	CacheControlPassthrough bool // Passthrough the Cache-Control from the original response
-	EnableClientHints       bool // Enable Vary header
-	SetVaryAccept           bool // Whether to include Accept in Vary header
-}
-
-// NewDefaultConfig returns a new Config instance with default values.
-func NewDefaultConfig() Config {
-	return Config{
-		SetCanonicalHeader:      false,
-		DefaultTTL:              31536000,
-		FallbackImageTTL:        0,
-		CacheControlPassthrough: false,
-		EnableClientHints:       false,
-		SetVaryAccept:           false,
-	}
-}
-
-// LoadConfigFromEnv overrides configuration variables from environment
-func LoadConfigFromEnv(c *Config) (*Config, error) {
-	c = ensure.Ensure(c, NewDefaultConfig)
-
-	c.SetCanonicalHeader = config.SetCanonicalHeader
-	c.DefaultTTL = config.TTL
-	c.FallbackImageTTL = config.FallbackImageTTL
-	c.CacheControlPassthrough = config.CacheControlPassthrough
-	c.EnableClientHints = config.EnableClientHints
-	c.SetVaryAccept = config.AutoWebp ||
-		config.EnforceWebp ||
-		config.AutoAvif ||
-		config.EnforceAvif ||
-		config.AutoJxl ||
-		config.EnforceJxl
-
-	return c, nil
-}
-
-// Validate checks config for errors
-func (c *Config) Validate() error {
-	if c.DefaultTTL < 0 {
-		return fmt.Errorf("image TTL should be greater than or equal to 0, now - %d", c.DefaultTTL)
-	}
-
-	if c.FallbackImageTTL < 0 {
-		return fmt.Errorf("fallback image TTL should be greater than or equal to 0, now - %d", c.FallbackImageTTL)
-	}
-
-	return nil
-}

+ 0 - 214
headerwriter/writer.go

@@ -1,214 +0,0 @@
-// headerwriter is responsible for writing processing/stream response headers
-package headerwriter
-
-import (
-	"fmt"
-	"net/http"
-	"strconv"
-	"strings"
-	"time"
-
-	"github.com/imgproxy/imgproxy/v3/httpheaders"
-)
-
-// Writer is a struct that creates header writer factories.
-type Writer struct {
-	config    *Config
-	varyValue string
-}
-
-// Request is a private struct that builds HTTP response headers for a specific request.
-type Request struct {
-	writer        *Writer
-	originHeaders http.Header // Original response headers
-	result        http.Header // Headers to be written to the response
-	maxAge        int         // Current max age for Cache-Control header
-}
-
-// New creates a new header writer factory with the provided config.
-func New(config *Config) (*Writer, error) {
-	if err := config.Validate(); err != nil {
-		return nil, err
-	}
-
-	vary := make([]string, 0)
-
-	if config.SetVaryAccept {
-		vary = append(vary, "Accept")
-	}
-
-	if config.EnableClientHints {
-		vary = append(vary, "Sec-CH-DPR", "DPR", "Sec-CH-Width", "Width")
-	}
-
-	varyValue := strings.Join(vary, ", ")
-
-	return &Writer{
-		config:    config,
-		varyValue: varyValue,
-	}, nil
-}
-
-// NewRequest creates a new header writer instance for a specific request with the provided origin headers and URL.
-func (w *Writer) NewRequest() *Request {
-	return &Request{
-		writer:        w,
-		result:        make(http.Header),
-		maxAge:        -1,
-		originHeaders: make(http.Header),
-	}
-}
-
-// SetOriginHeaders sets the origin headers for the request.
-func (r *Request) SetOriginHeaders(h http.Header) {
-	r.originHeaders = h
-}
-
-// SetIsFallbackImage sets the Fallback-Image header to
-// indicate that the fallback image was used.
-func (r *Request) SetIsFallbackImage() {
-	// We set maxAge to FallbackImageTTL if it's explicitly passed
-	if r.writer.config.FallbackImageTTL < 0 {
-		return
-	}
-
-	// However, we should not overwrite existing value if set (or greater than ours)
-	if r.maxAge < 0 || r.maxAge > r.writer.config.FallbackImageTTL {
-		r.maxAge = r.writer.config.FallbackImageTTL
-	}
-}
-
-// SetExpires sets the TTL from time
-func (r *Request) SetExpires(expires *time.Time) {
-	if expires == nil {
-		return
-	}
-
-	// Convert current maxAge to time
-	currentMaxAgeTime := time.Now().Add(time.Duration(r.maxAge) * time.Second)
-
-	// If maxAge outlives expires or was not set, we'll use expires as maxAge.
-	if r.maxAge < 0 || expires.Before(currentMaxAgeTime) {
-		r.maxAge = min(r.writer.config.DefaultTTL, max(0, int(time.Until(*expires).Seconds())))
-	}
-}
-
-// SetVary sets the Vary header
-func (r *Request) SetVary() {
-	if len(r.writer.varyValue) > 0 {
-		r.result.Set(httpheaders.Vary, r.writer.varyValue)
-	}
-}
-
-// SetContentDisposition sets the Content-Disposition header, passthrough to ContentDispositionValue
-func (r *Request) SetContentDisposition(originURL, filename, ext, contentType string, returnAttachment bool) {
-	value := httpheaders.ContentDispositionValue(
-		originURL,
-		filename,
-		ext,
-		contentType,
-		returnAttachment,
-	)
-
-	if value != "" {
-		r.result.Set(httpheaders.ContentDisposition, value)
-	}
-}
-
-// Passthrough copies specified headers from the original response headers to the response headers.
-func (r *Request) Passthrough(only ...string) {
-	httpheaders.Copy(r.originHeaders, r.result, only)
-}
-
-// CopyFrom copies specified headers from the headers object. Please note that
-// all the past operations may overwrite those values.
-func (r *Request) CopyFrom(headers http.Header, only []string) {
-	httpheaders.Copy(headers, r.result, only)
-}
-
-// SetContentLength sets the Content-Length header
-func (r *Request) SetContentLength(contentLength int) {
-	if contentLength < 0 {
-		return
-	}
-
-	r.result.Set(httpheaders.ContentLength, strconv.Itoa(contentLength))
-}
-
-// SetContentType sets the Content-Type header
-func (r *Request) SetContentType(mime string) {
-	r.result.Set(httpheaders.ContentType, mime)
-}
-
-// writeCanonical sets the Link header with the canonical URL.
-// It is mandatory for any response if enabled in the configuration.
-func (r *Request) SetCanonical(url string) {
-	if !r.writer.config.SetCanonicalHeader {
-		return
-	}
-
-	if strings.HasPrefix(url, "https://") || strings.HasPrefix(url, "http://") {
-		value := fmt.Sprintf(`<%s>; rel="canonical"`, url)
-		r.result.Set(httpheaders.Link, value)
-	}
-}
-
-// setCacheControl sets the Cache-Control header with the specified value.
-func (r *Request) setCacheControl(value int) bool {
-	if value <= 0 {
-		return false
-	}
-
-	r.result.Set(httpheaders.CacheControl, fmt.Sprintf("max-age=%d, public", value))
-	return true
-}
-
-// setCacheControlNoCache sets the Cache-Control header to no-cache (default).
-func (r *Request) setCacheControlNoCache() {
-	r.result.Set(httpheaders.CacheControl, "no-cache")
-}
-
-// setCacheControlPassthrough sets the Cache-Control header from the request
-// if passthrough is enabled in the configuration.
-func (r *Request) setCacheControlPassthrough() bool {
-	if !r.writer.config.CacheControlPassthrough || r.maxAge > 0 {
-		return false
-	}
-
-	if val := r.originHeaders.Get(httpheaders.CacheControl); val != "" {
-		r.result.Set(httpheaders.CacheControl, val)
-		return true
-	}
-
-	if val := r.originHeaders.Get(httpheaders.Expires); val != "" {
-		if t, err := time.Parse(http.TimeFormat, val); err == nil {
-			maxAge := max(0, int(time.Until(t).Seconds()))
-			return r.setCacheControl(maxAge)
-		}
-	}
-
-	return false
-}
-
-// setCSP sets the Content-Security-Policy header to prevent script execution.
-func (r *Request) setCSP() {
-	r.result.Set(httpheaders.ContentSecurityPolicy, "script-src 'none'")
-}
-
-// Write writes the headers to the response writer. It does not overwrite
-// target headers, which were set outside the header writer.
-func (r *Request) Write(rw http.ResponseWriter) {
-	// Then, let's try to set Cache-Control using priority order
-	switch {
-	case r.setCacheControl(r.maxAge): // First, try set explicit
-	case r.setCacheControlPassthrough(): // Try to pick up from request headers
-	case r.setCacheControl(r.writer.config.DefaultTTL): // Fallback to default value
-	default:
-		r.setCacheControlNoCache() // By default we use no-cache
-	}
-
-	r.setCSP()
-
-	// Copy all headers to the response without overwriting existing ones
-	httpheaders.CopyAll(r.result, rw.Header(), false)
-}

+ 1 - 13
imgproxy.go

@@ -11,7 +11,6 @@ import (
 	landinghandler "github.com/imgproxy/imgproxy/v3/handlers/landing"
 	processinghandler "github.com/imgproxy/imgproxy/v3/handlers/processing"
 	streamhandler "github.com/imgproxy/imgproxy/v3/handlers/stream"
-	"github.com/imgproxy/imgproxy/v3/headerwriter"
 	"github.com/imgproxy/imgproxy/v3/imagedata"
 	"github.com/imgproxy/imgproxy/v3/memory"
 	"github.com/imgproxy/imgproxy/v3/monitoring/prometheus"
@@ -34,7 +33,6 @@ type ImgproxyHandlers struct {
 
 // Imgproxy holds all the components needed for imgproxy to function
 type Imgproxy struct {
-	headerWriter     *headerwriter.Writer
 	semaphores       *semaphores.Semaphores
 	fallbackImage    auximageprovider.Provider
 	watermarkImage   auximageprovider.Provider
@@ -46,11 +44,6 @@ type Imgproxy struct {
 
 // New creates a new imgproxy instance
 func New(ctx context.Context, config *Config) (*Imgproxy, error) {
-	headerWriter, err := headerwriter.New(&config.HeaderWriter)
-	if err != nil {
-		return nil, err
-	}
-
 	fetcher, err := fetcher.New(&config.Fetcher)
 	if err != nil {
 		return nil, err
@@ -74,7 +67,6 @@ func New(ctx context.Context, config *Config) (*Imgproxy, error) {
 	}
 
 	imgproxy := &Imgproxy{
-		headerWriter:     headerWriter,
 		semaphores:       semaphores,
 		fallbackImage:    fallbackImage,
 		watermarkImage:   watermarkImage,
@@ -86,7 +78,7 @@ func New(ctx context.Context, config *Config) (*Imgproxy, error) {
 	imgproxy.handlers.Health = healthhandler.New()
 	imgproxy.handlers.Landing = landinghandler.New()
 
-	imgproxy.handlers.Stream, err = streamhandler.New(&config.Handlers.Stream, headerWriter, fetcher)
+	imgproxy.handlers.Stream, err = streamhandler.New(&config.Handlers.Stream, fetcher)
 	if err != nil {
 		return nil, err
 	}
@@ -180,10 +172,6 @@ func (i *Imgproxy) startMemoryTicker(ctx context.Context) {
 	}
 }
 
-func (i *Imgproxy) HeaderWriter() *headerwriter.Writer {
-	return i.headerWriter
-}
-
 func (i *Imgproxy) Semaphores() *semaphores.Semaphores {
 	return i.semaphores
 }

+ 1 - 1
integration/processing_handler_test.go

@@ -238,7 +238,7 @@ func (s *ProcessingHandlerTestSuite) TestErrorSavingToSVG() {
 }
 
 func (s *ProcessingHandlerTestSuite) TestCacheControlPassthroughCacheControl() {
-	s.Config().HeaderWriter.CacheControlPassthrough = true
+	s.Config().Server.ResponseWriter.CacheControlPassthrough = true
 
 	ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
 		rw.Header().Set(httpheaders.CacheControl, "max-age=1234, public")

+ 9 - 6
server/config.go

@@ -8,6 +8,7 @@ import (
 
 	"github.com/imgproxy/imgproxy/v3/config"
 	"github.com/imgproxy/imgproxy/v3/ensure"
+	"github.com/imgproxy/imgproxy/v3/server/responsewriter"
 )
 
 // Config represents HTTP server config
@@ -18,7 +19,6 @@ type Config struct {
 	PathPrefix            string        // Path prefix for the server
 	MaxClients            int           // Maximum number of concurrent clients
 	ReadRequestTimeout    time.Duration // Timeout for reading requests
-	WriteResponseTimeout  time.Duration // Timeout for writing responses
 	KeepAliveTimeout      time.Duration // Timeout for keep-alive connections
 	GracefulTimeout       time.Duration // Timeout for graceful shutdown
 	CORSAllowOrigin       string        // CORS allowed origin
@@ -27,6 +27,8 @@ type Config struct {
 	SocketReusePort       bool          // Enable SO_REUSEPORT socket option
 	HealthCheckPath       string        // Health check path from config
 
+	ResponseWriter responsewriter.Config // Response writer config
+
 	// TODO: We are not sure where to put it yet
 	FreeMemoryInterval time.Duration // Interval for freeing memory
 	LogMemStats        bool          // Log memory stats
@@ -41,7 +43,6 @@ func NewDefaultConfig() Config {
 		MaxClients:            2048,
 		ReadRequestTimeout:    10 * time.Second,
 		KeepAliveTimeout:      10 * time.Second,
-		WriteResponseTimeout:  10 * time.Second,
 		GracefulTimeout:       20 * time.Second,
 		CORSAllowOrigin:       "",
 		Secret:                "",
@@ -50,6 +51,8 @@ func NewDefaultConfig() Config {
 		HealthCheckPath:       "",
 		FreeMemoryInterval:    10 * time.Second,
 		LogMemStats:           false,
+
+		ResponseWriter: responsewriter.NewDefaultConfig(),
 	}
 }
 
@@ -72,6 +75,10 @@ func LoadConfigFromEnv(c *Config) (*Config, error) {
 	c.FreeMemoryInterval = time.Duration(config.FreeMemoryInterval) * time.Second
 	c.LogMemStats = len(os.Getenv("IMGPROXY_LOG_MEM_STATS")) > 0
 
+	if _, err := responsewriter.LoadConfigFromEnv(&c.ResponseWriter); err != nil {
+		return nil, err
+	}
+
 	return c, nil
 }
 
@@ -89,10 +96,6 @@ func (c *Config) Validate() error {
 		return fmt.Errorf("read request timeout should be greater than 0, now - %d", c.ReadRequestTimeout)
 	}
 
-	if c.WriteResponseTimeout <= 0 {
-		return fmt.Errorf("write response timeout should be greater than 0, now - %d", c.WriteResponseTimeout)
-	}
-
 	if c.KeepAliveTimeout < 0 {
 		return fmt.Errorf("keep alive timeout should be greater than or equal to 0, now - %d", c.KeepAliveTimeout)
 	}

+ 9 - 6
server/middlewares.go

@@ -23,10 +23,13 @@ func (r *Router) WithMonitoring(h RouteHandler) RouteHandler {
 		return h
 	}
 
-	return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
-		ctx, cancel, rw := monitoring.StartRequest(req.Context(), rw, req)
+	return func(reqID string, rw ResponseWriter, req *http.Request) error {
+		ctx, cancel, newRw := monitoring.StartRequest(req.Context(), rw.HTTPResponseWriter(), req)
 		defer cancel()
 
+		// Replace rw.ResponseWriter with new one returned from monitoring
+		rw.SetHTTPResponseWriter(newRw)
+
 		return h(reqID, rw, req.WithContext(ctx))
 	}
 }
@@ -37,7 +40,7 @@ func (r *Router) WithCORS(h RouteHandler) RouteHandler {
 		return h
 	}
 
-	return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
+	return func(reqID string, rw ResponseWriter, req *http.Request) error {
 		rw.Header().Set(httpheaders.AccessControlAllowOrigin, r.config.CORSAllowOrigin)
 		rw.Header().Set(httpheaders.AccessControlAllowMethods, "GET, OPTIONS")
 
@@ -53,7 +56,7 @@ func (r *Router) WithSecret(h RouteHandler) RouteHandler {
 
 	authHeader := fmt.Appendf(nil, "Bearer %s", r.config.Secret)
 
-	return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
+	return func(reqID string, rw ResponseWriter, req *http.Request) error {
 		if subtle.ConstantTimeCompare([]byte(req.Header.Get(httpheaders.Authorization)), authHeader) == 1 {
 			return h(reqID, rw, req)
 		} else {
@@ -64,7 +67,7 @@ func (r *Router) WithSecret(h RouteHandler) RouteHandler {
 
 // WithPanic recovers panic and converts it to normal error
 func (r *Router) WithPanic(h RouteHandler) RouteHandler {
-	return func(reqID string, rw http.ResponseWriter, r *http.Request) (retErr error) {
+	return func(reqID string, rw ResponseWriter, r *http.Request) (retErr error) {
 		defer func() {
 			// try to recover from panic
 			rerr := recover()
@@ -94,7 +97,7 @@ func (r *Router) WithPanic(h RouteHandler) RouteHandler {
 // WithReportError handles error reporting.
 // It should be placed after `WithMonitoring`, but before `WithPanic`.
 func (r *Router) WithReportError(h RouteHandler) RouteHandler {
-	return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
+	return func(reqID string, rw ResponseWriter, req *http.Request) error {
 		// Open the error context
 		ctx := errorreport.StartRequest(req)
 		req = req.WithContext(ctx)

+ 87 - 0
server/responsewriter/config.go

@@ -0,0 +1,87 @@
+package responsewriter
+
+import (
+	"fmt"
+	"strings"
+	"time"
+
+	"github.com/imgproxy/imgproxy/v3/config"
+	"github.com/imgproxy/imgproxy/v3/ensure"
+)
+
+// Config holds configuration for response writer
+type Config struct {
+	SetCanonicalHeader      bool          // Indicates whether to set the canonical header
+	DefaultTTL              int           // Default Cache-Control max-age= value for cached images
+	FallbackImageTTL        int           // TTL for images served as fallbacks
+	CacheControlPassthrough bool          // Passthrough the Cache-Control from the original response
+	VaryValue               string        // Value for Vary header
+	WriteResponseTimeout    time.Duration // Timeout for response write operations
+}
+
+// NewDefaultConfig returns a new Config instance with default values.
+func NewDefaultConfig() Config {
+	return Config{
+		SetCanonicalHeader:      false,
+		DefaultTTL:              31536000,
+		FallbackImageTTL:        0,
+		CacheControlPassthrough: false,
+		VaryValue:               "",
+		WriteResponseTimeout:    10 * time.Second,
+	}
+}
+
+// LoadConfigFromEnv overrides configuration variables from environment
+func LoadConfigFromEnv(c *Config) (*Config, error) {
+	c = ensure.Ensure(c, NewDefaultConfig)
+
+	c.SetCanonicalHeader = config.SetCanonicalHeader
+	c.DefaultTTL = config.TTL
+	c.FallbackImageTTL = config.FallbackImageTTL
+	c.CacheControlPassthrough = config.CacheControlPassthrough
+	c.WriteResponseTimeout = time.Duration(config.WriteResponseTimeout) * time.Second
+
+	vary := make([]string, 0)
+
+	if c.envEnableFormatDetection() {
+		vary = append(vary, "Accept")
+	}
+
+	if c.envEnableClientHints() {
+		vary = append(vary, "Sec-CH-DPR", "DPR", "Sec-CH-Width", "Width")
+	}
+
+	c.VaryValue = strings.Join(vary, ", ")
+
+	return c, nil
+}
+
+func (c *Config) envEnableFormatDetection() bool {
+	return config.AutoWebp ||
+		config.EnforceWebp ||
+		config.AutoAvif ||
+		config.EnforceAvif ||
+		config.AutoJxl ||
+		config.EnforceJxl
+}
+
+func (c *Config) envEnableClientHints() bool {
+	return config.EnableClientHints
+}
+
+// Validate checks config for errors
+func (c *Config) Validate() error {
+	if c.DefaultTTL < 0 {
+		return fmt.Errorf("image TTL should be greater than or equal to 0, now - %d", c.DefaultTTL)
+	}
+
+	if c.FallbackImageTTL < 0 {
+		return fmt.Errorf("fallback image TTL should be greater than or equal to 0, now - %d", c.FallbackImageTTL)
+	}
+
+	if c.WriteResponseTimeout <= 0 {
+		return fmt.Errorf("write response timeout should be greater than 0, now - %d", c.WriteResponseTimeout)
+	}
+
+	return nil
+}

+ 114 - 0
server/responsewriter/config_test.go

@@ -0,0 +1,114 @@
+package responsewriter
+
+import (
+	"fmt"
+	"io"
+	"os"
+	"testing"
+
+	"github.com/imgproxy/imgproxy/v3/config"
+	"github.com/sirupsen/logrus"
+	"github.com/stretchr/testify/suite"
+)
+
+type ResponseWriterConfigSuite struct {
+	suite.Suite
+}
+
+func (s *ResponseWriterConfigSuite) SetupSuite() {
+	logrus.SetOutput(io.Discard)
+}
+
+func (s *ResponseWriterConfigSuite) TearDownSuite() {
+	logrus.SetOutput(os.Stdout)
+}
+
+func (s *ResponseWriterConfigSuite) TestLoadingVaryValueFromEnv() {
+	defaultEnv := map[string]string{
+		"IMGPROXY_AUTO_WEBP":           "",
+		"IMGPROXY_ENFORCE_WEBP":        "",
+		"IMGPROXY_AUTO_AVIF":           "",
+		"IMGPROXY_ENFORCE_AVIF":        "",
+		"IMGPROXY_AUTO_JXL":            "",
+		"IMGPROXY_ENFORCE_JXL":         "",
+		"IMGPROXY_ENABLE_CLIENT_HINTS": "",
+	}
+
+	testCases := []struct {
+		name     string
+		env      map[string]string
+		expected string
+	}{
+		{
+			name:     "AutoWebP",
+			env:      map[string]string{"IMGPROXY_AUTO_WEBP": "true"},
+			expected: "Accept",
+		},
+		{
+			name:     "EnforceWebP",
+			env:      map[string]string{"IMGPROXY_ENFORCE_WEBP": "true"},
+			expected: "Accept",
+		},
+		{
+			name:     "AutoAVIF",
+			env:      map[string]string{"IMGPROXY_AUTO_AVIF": "true"},
+			expected: "Accept",
+		},
+		{
+			name:     "EnforceAVIF",
+			env:      map[string]string{"IMGPROXY_ENFORCE_AVIF": "true"},
+			expected: "Accept",
+		},
+		{
+			name:     "AutoJXL",
+			env:      map[string]string{"IMGPROXY_AUTO_JXL": "true"},
+			expected: "Accept",
+		},
+		{
+			name:     "EnforceJXL",
+			env:      map[string]string{"IMGPROXY_ENFORCE_JXL": "true"},
+			expected: "Accept",
+		},
+		{
+			name:     "EnableClientHints",
+			env:      map[string]string{"IMGPROXY_ENABLE_CLIENT_HINTS": "true"},
+			expected: "Sec-CH-DPR, DPR, Sec-CH-Width, Width",
+		},
+		{
+			name: "Combined",
+			env: map[string]string{
+				"IMGPROXY_AUTO_WEBP":           "true",
+				"IMGPROXY_ENABLE_CLIENT_HINTS": "true",
+			},
+			expected: "Accept, Sec-CH-DPR, DPR, Sec-CH-Width, Width",
+		},
+	}
+
+	for _, tc := range testCases {
+		s.Run(fmt.Sprintf("%v", tc.env), func() {
+			// Set default environment variables
+			for key, value := range defaultEnv {
+				s.T().Setenv(key, value)
+			}
+			// Set environment variables
+			for key, value := range tc.env {
+				s.T().Setenv(key, value)
+			}
+
+			// TODO: Remove when we removed global config
+			config.Reset()
+			config.Configure()
+
+			// Load config
+			cfg, err := LoadConfigFromEnv(nil)
+
+			// Assert expected values
+			s.Require().NoError(err)
+			s.Require().Equal(tc.expected, cfg.VaryValue)
+		})
+	}
+}
+
+func TestResponseWriterConfig(t *testing.T) {
+	suite.Run(t, new(ResponseWriterConfigSuite))
+}

+ 30 - 0
server/responsewriter/factory.go

@@ -0,0 +1,30 @@
+package responsewriter
+
+import "net/http"
+
+// Factory is a struct that creates response writers.
+type Factory struct {
+	config *Config
+}
+
+func NewFactory(config *Config) (*Factory, error) {
+	if err := config.Validate(); err != nil {
+		return nil, err
+	}
+
+	return &Factory{config}, nil
+}
+
+// NewWriter wraps [http.ResponseWriter] into [Writer].
+func (f *Factory) NewWriter(rw http.ResponseWriter) *Writer {
+	w := &Writer{
+		config:        f.config,
+		result:        make(http.Header),
+		originHeaders: make(http.Header),
+		maxAge:        -1,
+	}
+
+	w.SetHTTPResponseWriter(rw)
+
+	return w
+}

+ 226 - 0
server/responsewriter/writer.go

@@ -0,0 +1,226 @@
+package responsewriter
+
+import (
+	"fmt"
+	"net/http"
+	"strconv"
+	"strings"
+	"sync"
+	"time"
+
+	"github.com/imgproxy/imgproxy/v3/httpheaders"
+)
+
+// Just aliases for [http.ResponseWriter] and [http.ResponseController].
+// We need them to make them private in [Writer] so they can't be accessed directly.
+type httpResponseWriter = http.ResponseWriter
+type httpResponseController = *http.ResponseController
+
+// Writer is an implementation of [http.ResponseWriter] with additional
+// functionality for managing response headers.
+type Writer struct {
+	httpResponseWriter
+	httpResponseController
+
+	config        *Config     // Configuration for the writer
+	originHeaders http.Header // Original response headers
+	result        http.Header // Headers to be written to the response
+	maxAge        int         // Current max age for Cache-Control header
+
+	beforeWriteOnce sync.Once
+}
+
+// HTTPResponseWriter returns the underlying http.ResponseWriter.
+func (w *Writer) HTTPResponseWriter() http.ResponseWriter {
+	return w.httpResponseWriter
+}
+
+// SetHTTPResponseWriter replaces the underlying http.ResponseWriter.
+func (w *Writer) SetHTTPResponseWriter(rw http.ResponseWriter) {
+	w.httpResponseWriter = rw
+	w.httpResponseController = http.NewResponseController(rw)
+}
+
+// SetOriginHeaders sets the origin headers for the request.
+func (w *Writer) SetOriginHeaders(h http.Header) {
+	w.originHeaders = h
+}
+
+// SetIsFallbackImage sets the Fallback-Image header to
+// indicate that the fallback image was used.
+func (w *Writer) SetIsFallbackImage() {
+	// We set maxAge to FallbackImageTTL if it's explicitly passed
+	if w.config.FallbackImageTTL < 0 {
+		return
+	}
+
+	// However, we should not overwrite existing value if set (or greater than ours)
+	if w.maxAge < 0 || w.maxAge > w.config.FallbackImageTTL {
+		w.maxAge = w.config.FallbackImageTTL
+	}
+}
+
+// SetExpires sets the TTL from time
+func (w *Writer) SetExpires(expires *time.Time) {
+	if expires == nil {
+		return
+	}
+
+	// Convert current maxAge to time
+	currentMaxAgeTime := time.Now().Add(time.Duration(w.maxAge) * time.Second)
+
+	// If maxAge outlives expires or was not set, we'll use expires as maxAge.
+	if w.maxAge < 0 || expires.Before(currentMaxAgeTime) {
+		w.maxAge = min(w.config.DefaultTTL, max(0, int(time.Until(*expires).Seconds())))
+	}
+}
+
+// SetVary sets the Vary header
+func (w *Writer) SetVary() {
+	if val := w.config.VaryValue; len(val) > 0 {
+		w.result.Set(httpheaders.Vary, val)
+	}
+}
+
+// SetContentDisposition sets the Content-Disposition header, passthrough to ContentDispositionValue
+func (w *Writer) SetContentDisposition(originURL, filename, ext, contentType string, returnAttachment bool) {
+	value := httpheaders.ContentDispositionValue(
+		originURL,
+		filename,
+		ext,
+		contentType,
+		returnAttachment,
+	)
+
+	if value != "" {
+		w.result.Set(httpheaders.ContentDisposition, value)
+	}
+}
+
+// Passthrough copies specified headers from the original response headers to the response headers.
+func (w *Writer) Passthrough(only ...string) {
+	httpheaders.Copy(w.originHeaders, w.result, only)
+}
+
+// CopyFrom copies specified headers from the headers object. Please note that
+// all the past operations may overwrite those values.
+func (w *Writer) CopyFrom(headers http.Header, only []string) {
+	httpheaders.Copy(headers, w.result, only)
+}
+
+// SetContentLength sets the Content-Length header
+func (w *Writer) SetContentLength(contentLength int) {
+	if contentLength < 0 {
+		return
+	}
+
+	w.result.Set(httpheaders.ContentLength, strconv.Itoa(contentLength))
+}
+
+// SetContentType sets the Content-Type header
+func (w *Writer) SetContentType(mime string) {
+	w.result.Set(httpheaders.ContentType, mime)
+}
+
+// writeCanonical sets the Link header with the canonical URL.
+// It is mandatory for any response if enabled in the configuration.
+func (w *Writer) SetCanonical(url string) {
+	if !w.config.SetCanonicalHeader {
+		return
+	}
+
+	if strings.HasPrefix(url, "https://") || strings.HasPrefix(url, "http://") {
+		value := fmt.Sprintf(`<%s>; rel="canonical"`, url)
+		w.result.Set(httpheaders.Link, value)
+	}
+}
+
+// setCacheControl sets the Cache-Control header with the specified value.
+func (w *Writer) setCacheControl(value int) bool {
+	if value <= 0 {
+		return false
+	}
+
+	w.result.Set(httpheaders.CacheControl, fmt.Sprintf("max-age=%d, public", value))
+	return true
+}
+
+// setCacheControlNoCache sets the Cache-Control header to no-cache (default).
+func (w *Writer) setCacheControlNoCache() {
+	w.result.Set(httpheaders.CacheControl, "no-cache")
+}
+
+// setCacheControlPassthrough sets the Cache-Control header from the request
+// if passthrough is enabled in the configuration.
+func (w *Writer) setCacheControlPassthrough() bool {
+	if !w.config.CacheControlPassthrough || w.maxAge > 0 {
+		return false
+	}
+
+	if val := w.originHeaders.Get(httpheaders.CacheControl); val != "" {
+		w.result.Set(httpheaders.CacheControl, val)
+		return true
+	}
+
+	if val := w.originHeaders.Get(httpheaders.Expires); val != "" {
+		if t, err := time.Parse(http.TimeFormat, val); err == nil {
+			maxAge := max(0, int(time.Until(t).Seconds()))
+			return w.setCacheControl(maxAge)
+		}
+	}
+
+	return false
+}
+
+// setCSP sets the Content-Security-Policy header to prevent script execution.
+func (w *Writer) setCSP() {
+	w.result.Set(httpheaders.ContentSecurityPolicy, "script-src 'none'")
+}
+
+// flushHeaders writes the headers to the response writer. It does not overwrite
+// target headers, which were set outside the header writer.
+func (w *Writer) flushHeaders() {
+	// Then, let's try to set Cache-Control using priority order
+	switch {
+	case w.setCacheControl(w.maxAge): // First, try set explicit
+	case w.setCacheControlPassthrough(): // Try to pick up from request headers
+	case w.setCacheControl(w.config.DefaultTTL): // Fallback to default value
+	default:
+		w.setCacheControlNoCache() // By default we use no-cache
+	}
+
+	w.setCSP()
+
+	// Copy all headers to the response without overwriting existing ones
+	httpheaders.CopyAll(w.result, w.Header(), false)
+}
+
+// beforeWrite is called before [WriteHeader] and [Write]
+func (w *Writer) beforeWrite() {
+	w.beforeWriteOnce.Do(func() {
+		// We're going to start writing response.
+		// Set write deadline.
+		w.SetWriteDeadline(time.Now().Add(w.config.WriteResponseTimeout))
+
+		// Flush headers before we write anything
+		w.flushHeaders()
+	})
+}
+
+// WriteHeader writes the HTTP response header.
+//
+// It ensures that all headers are flushed before writing the status code.
+func (w *Writer) WriteHeader(statusCode int) {
+	w.beforeWrite()
+
+	w.httpResponseWriter.WriteHeader(statusCode)
+}
+
+// Write writes the HTTP response body.
+//
+// It ensures that all headers are flushed before writing the body.
+func (w *Writer) Write(b []byte) (int, error) {
+	w.beforeWrite()
+
+	return w.httpResponseWriter.Write(b)
+}

+ 60 - 70
headerwriter/writer_test.go → server/responsewriter/writer_test.go

@@ -1,4 +1,4 @@
-package headerwriter
+package responsewriter
 
 import (
 	"fmt"
@@ -13,7 +13,7 @@ import (
 	"github.com/stretchr/testify/suite"
 )
 
-type HeaderWriterSuite struct {
+type ResponseWriterSuite struct {
 	suite.Suite
 }
 
@@ -22,16 +22,18 @@ type writerTestCase struct {
 	req    http.Header
 	res    http.Header
 	config Config
-	fn     func(*Request)
+	fn     func(*Writer)
 }
 
-func (s *HeaderWriterSuite) TestHeaderCases() {
+func (s *ResponseWriterSuite) TestHeaderCases() {
 	expires := time.Date(2030, 8, 1, 0, 0, 0, 0, time.UTC)
 	expiresSeconds := strconv.Itoa(int(time.Until(expires).Seconds()))
 
 	shortExpires := time.Now().Add(10 * time.Second)
 	shortExpiresSeconds := strconv.Itoa(int(time.Until(shortExpires).Seconds()))
 
+	writeResponseTimeout := 10 * time.Second
+
 	tt := []writerTestCase{
 		{
 			name: "MinimalHeaders",
@@ -44,8 +46,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
 				SetCanonicalHeader:      false,
 				DefaultTTL:              0,
 				CacheControlPassthrough: false,
-				EnableClientHints:       false,
-				SetVaryAccept:           false,
+				WriteResponseTimeout:    writeResponseTimeout,
 			},
 		},
 		{
@@ -60,6 +61,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
 			config: Config{
 				CacheControlPassthrough: true,
 				DefaultTTL:              3600,
+				WriteResponseTimeout:    writeResponseTimeout,
 			},
 		},
 		{
@@ -74,6 +76,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
 			config: Config{
 				CacheControlPassthrough: true,
 				DefaultTTL:              3600,
+				WriteResponseTimeout:    writeResponseTimeout,
 			},
 		},
 		{
@@ -88,6 +91,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
 			config: Config{
 				CacheControlPassthrough: true,
 				DefaultTTL:              3600,
+				WriteResponseTimeout:    writeResponseTimeout,
 			},
 		},
 		{
@@ -99,10 +103,11 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
 				httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
 			},
 			config: Config{
-				SetCanonicalHeader: true,
-				DefaultTTL:         3600,
+				SetCanonicalHeader:   true,
+				DefaultTTL:           3600,
+				WriteResponseTimeout: writeResponseTimeout,
 			},
-			fn: func(w *Request) {
+			fn: func(w *Writer) {
 				w.SetCanonical("https://example.com/image.jpg")
 			},
 		},
@@ -114,8 +119,9 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
 				httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
 			},
 			config: Config{
-				SetCanonicalHeader: true,
-				DefaultTTL:         3600,
+				SetCanonicalHeader:   true,
+				DefaultTTL:           3600,
+				WriteResponseTimeout: writeResponseTimeout,
 			},
 		},
 		{
@@ -126,10 +132,11 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
 				httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
 			},
 			config: Config{
-				SetCanonicalHeader: false,
-				DefaultTTL:         3600,
+				SetCanonicalHeader:   false,
+				DefaultTTL:           3600,
+				WriteResponseTimeout: writeResponseTimeout,
 			},
-			fn: func(w *Request) {
+			fn: func(w *Writer) {
 				w.SetCanonical("https://example.com/image.jpg")
 			},
 		},
@@ -141,10 +148,11 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
 				httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
 			},
 			config: Config{
-				DefaultTTL:       3600,
-				FallbackImageTTL: 1,
+				DefaultTTL:           3600,
+				FallbackImageTTL:     1,
+				WriteResponseTimeout: writeResponseTimeout,
 			},
-			fn: func(w *Request) {
+			fn: func(w *Writer) {
 				w.SetIsFallbackImage()
 			},
 		},
@@ -156,9 +164,10 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
 				httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
 			},
 			config: Config{
-				DefaultTTL: math.MaxInt32,
+				DefaultTTL:           math.MaxInt32,
+				WriteResponseTimeout: writeResponseTimeout,
 			},
-			fn: func(w *Request) {
+			fn: func(w *Writer) {
 				w.SetExpires(&expires)
 			},
 		},
@@ -170,10 +179,11 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
 				httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
 			},
 			config: Config{
-				DefaultTTL:       math.MaxInt32,
-				FallbackImageTTL: 600,
+				DefaultTTL:           math.MaxInt32,
+				FallbackImageTTL:     600,
+				WriteResponseTimeout: writeResponseTimeout,
 			},
-			fn: func(w *Request) {
+			fn: func(w *Writer) {
 				w.SetIsFallbackImage()
 				w.SetExpires(&shortExpires)
 			},
@@ -187,10 +197,10 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
 				httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
 			},
 			config: Config{
-				EnableClientHints: true,
-				SetVaryAccept:     true,
+				VaryValue:            "Accept, Sec-CH-DPR, DPR, Sec-CH-Width, Width",
+				WriteResponseTimeout: writeResponseTimeout,
 			},
-			fn: func(w *Request) {
+			fn: func(w *Writer) {
 				w.SetVary()
 			},
 		},
@@ -204,8 +214,10 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
 				httpheaders.CacheControl:          []string{"no-cache"},
 				httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
 			},
-			config: Config{},
-			fn: func(w *Request) {
+			config: Config{
+				WriteResponseTimeout: writeResponseTimeout,
+			},
+			fn: func(w *Writer) {
 				w.Passthrough("X-Test")
 			},
 		},
@@ -217,8 +229,10 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
 				httpheaders.CacheControl:          []string{"no-cache"},
 				httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
 			},
-			config: Config{},
-			fn: func(w *Request) {
+			config: Config{
+				WriteResponseTimeout: writeResponseTimeout,
+			},
+			fn: func(w *Writer) {
 				h := http.Header{}
 				h.Set("X-From", "baz")
 				w.CopyFrom(h, []string{"X-From"})
@@ -232,8 +246,10 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
 				httpheaders.CacheControl:          []string{"no-cache"},
 				httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
 			},
-			config: Config{},
-			fn: func(w *Request) {
+			config: Config{
+				WriteResponseTimeout: writeResponseTimeout,
+			},
+			fn: func(w *Writer) {
 				w.SetContentLength(123)
 			},
 		},
@@ -245,8 +261,10 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
 				httpheaders.CacheControl:          []string{"no-cache"},
 				httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
 			},
-			config: Config{},
-			fn: func(w *Request) {
+			config: Config{
+				WriteResponseTimeout: writeResponseTimeout,
+			},
+			fn: func(w *Writer) {
 				w.SetContentType("image/png")
 			},
 		},
@@ -258,58 +276,30 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
 				httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
 			},
 			config: Config{
-				DefaultTTL: 3600,
+				DefaultTTL:           3600,
+				WriteResponseTimeout: writeResponseTimeout,
 			},
-			fn: func(w *Request) {
+			fn: func(w *Writer) {
 				w.SetExpires(nil)
 			},
 		},
-		{
-			name: "WriteVaryAcceptOnly",
-			req:  http.Header{},
-			res: http.Header{
-				httpheaders.Vary:                  []string{"Accept"},
-				httpheaders.CacheControl:          []string{"no-cache"},
-				httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
-			},
-			config: Config{
-				SetVaryAccept: true,
-			},
-			fn: func(w *Request) {
-				w.SetVary()
-			},
-		},
-		{
-			name: "WriteVaryClientHintsOnly",
-			req:  http.Header{},
-			res: http.Header{
-				httpheaders.Vary:                  []string{"Sec-CH-DPR, DPR, Sec-CH-Width, Width"},
-				httpheaders.CacheControl:          []string{"no-cache"},
-				httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
-			},
-			config: Config{
-				EnableClientHints: true,
-			},
-			fn: func(w *Request) {
-				w.SetVary()
-			},
-		},
 	}
 
 	for _, tc := range tt {
 		s.Run(tc.name, func() {
-			factory, err := New(&tc.config)
+			factory, err := NewFactory(&tc.config)
 			s.Require().NoError(err)
 
-			writer := factory.NewRequest()
+			r := httptest.NewRecorder()
+
+			writer := factory.NewWriter(r)
 			writer.SetOriginHeaders(tc.req)
 
 			if tc.fn != nil {
 				tc.fn(writer)
 			}
 
-			r := httptest.NewRecorder()
-			writer.Write(r)
+			writer.WriteHeader(http.StatusOK)
 
 			s.Require().Equal(tc.res, r.Header())
 		})
@@ -317,5 +307,5 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
 }
 
 func TestHeaderWriter(t *testing.T) {
-	suite.Run(t, new(HeaderWriterSuite))
+	suite.Run(t, new(ResponseWriterSuite))
 }

+ 24 - 10
server/router.go

@@ -11,6 +11,7 @@ import (
 	nanoid "github.com/matoous/go-nanoid/v2"
 
 	"github.com/imgproxy/imgproxy/v3/httpheaders"
+	"github.com/imgproxy/imgproxy/v3/server/responsewriter"
 )
 
 const (
@@ -23,8 +24,10 @@ var (
 	requestIDRe = regexp.MustCompile(`^[A-Za-z0-9_\-]+$`)
 )
 
+type ResponseWriter = *responsewriter.Writer
+
 // RouteHandler is a function that handles HTTP requests.
-type RouteHandler func(string, http.ResponseWriter, *http.Request) error
+type RouteHandler func(string, ResponseWriter, *http.Request) error
 
 // Middleware is a function that wraps a RouteHandler with additional functionality.
 type Middleware func(next RouteHandler) RouteHandler
@@ -40,6 +43,9 @@ type route struct {
 
 // Router is responsible for routing HTTP requests
 type Router struct {
+	// Response writers factory
+	rwFactory *responsewriter.Factory
+
 	// config represents the server configuration
 	config *Config
 
@@ -53,7 +59,15 @@ func NewRouter(config *Config) (*Router, error) {
 		return nil, err
 	}
 
-	return &Router{config: config}, nil
+	rwf, err := responsewriter.NewFactory(&config.ResponseWriter)
+	if err != nil {
+		return nil, err
+	}
+
+	return &Router{
+		rwFactory: rwf,
+		config:    config,
+	}, nil
 }
 
 // add adds an abitary route to the router
@@ -114,8 +128,8 @@ func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
 	req, timeoutCancel := startRequestTimer(req)
 	defer timeoutCancel()
 
-	// Create the response writer which times out on write
-	rw = newTimeoutResponse(rw, r.config.WriteResponseTimeout)
+	// Create the [ResponseWriter]
+	rww := r.rwFactory.NewWriter(rw)
 
 	// Get/create request ID
 	reqID := r.getRequestID(req)
@@ -123,8 +137,8 @@ func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
 	// Replace request IP from headers
 	r.replaceRemoteAddr(req)
 
-	rw.Header().Set(httpheaders.Server, defaultServerName)
-	rw.Header().Set(httpheaders.XRequestID, reqID)
+	rww.Header().Set(httpheaders.Server, defaultServerName)
+	rww.Header().Set(httpheaders.XRequestID, reqID)
 
 	for _, rr := range r.routes {
 		if !rr.isMatch(req) {
@@ -138,18 +152,18 @@ func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
 			LogRequest(reqID, req)
 		}
 
-		rr.handler(reqID, rw, req)
+		rr.handler(reqID, rww, req)
 		return
 	}
 
 	// Means that we have not found matching route
 	LogRequest(reqID, req)
 	LogResponse(reqID, req, http.StatusNotFound, newRouteNotDefinedError(req.URL.Path))
-	r.NotFoundHandler(reqID, rw, req)
+	r.NotFoundHandler(reqID, rww, req)
 }
 
 // NotFoundHandler is default 404 handler
-func (r *Router) NotFoundHandler(reqID string, rw http.ResponseWriter, req *http.Request) error {
+func (r *Router) NotFoundHandler(reqID string, rw ResponseWriter, req *http.Request) error {
 	rw.Header().Set(httpheaders.ContentType, "text/plain")
 	rw.WriteHeader(http.StatusNotFound)
 	rw.Write([]byte{' '}) // Write a single byte to make AWS Lambda happy
@@ -158,7 +172,7 @@ func (r *Router) NotFoundHandler(reqID string, rw http.ResponseWriter, req *http
 }
 
 // OkHandler is a default 200 OK handler
-func (r *Router) OkHandler(reqID string, rw http.ResponseWriter, req *http.Request) error {
+func (r *Router) OkHandler(reqID string, rw ResponseWriter, req *http.Request) error {
 	rw.Header().Set(httpheaders.ContentType, "text/plain")
 	rw.WriteHeader(http.StatusOK)
 	rw.Write([]byte{' '}) // Write a single byte to make AWS Lambda happy

+ 11 - 11
server/router_test.go

@@ -30,7 +30,7 @@ func (s *RouterTestSuite) TestHTTPMethods() {
 	var capturedMethod string
 	var capturedPath string
 
-	getHandler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
+	getHandler := func(reqID string, rw ResponseWriter, req *http.Request) error {
 		capturedMethod = req.Method
 		capturedPath = req.URL.Path
 		rw.WriteHeader(200)
@@ -38,7 +38,7 @@ func (s *RouterTestSuite) TestHTTPMethods() {
 		return nil
 	}
 
-	optionsHandler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
+	optionsHandler := func(reqID string, rw ResponseWriter, req *http.Request) error {
 		capturedMethod = req.Method
 		capturedPath = req.URL.Path
 		rw.WriteHeader(200)
@@ -46,7 +46,7 @@ func (s *RouterTestSuite) TestHTTPMethods() {
 		return nil
 	}
 
-	headHandler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
+	headHandler := func(reqID string, rw ResponseWriter, req *http.Request) error {
 		capturedMethod = req.Method
 		capturedPath = req.URL.Path
 		rw.WriteHeader(200)
@@ -114,20 +114,20 @@ func (s *RouterTestSuite) TestMiddlewareOrder() {
 	var order []string
 
 	middleware1 := func(next RouteHandler) RouteHandler {
-		return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
+		return func(reqID string, rw ResponseWriter, req *http.Request) error {
 			order = append(order, "middleware1")
 			return next(reqID, rw, req)
 		}
 	}
 
 	middleware2 := func(next RouteHandler) RouteHandler {
-		return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
+		return func(reqID string, rw ResponseWriter, req *http.Request) error {
 			order = append(order, "middleware2")
 			return next(reqID, rw, req)
 		}
 	}
 
-	handler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
+	handler := func(reqID string, rw ResponseWriter, req *http.Request) error {
 		order = append(order, "handler")
 		rw.WriteHeader(200)
 		return nil
@@ -146,7 +146,7 @@ func (s *RouterTestSuite) TestMiddlewareOrder() {
 
 // TestServeHTTP tests ServeHTTP method
 func (s *RouterTestSuite) TestServeHTTP() {
-	handler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
+	handler := func(reqID string, rw ResponseWriter, req *http.Request) error {
 		rw.Header().Set("Custom-Header", "test-value")
 		rw.WriteHeader(200)
 		rw.Write([]byte("success"))
@@ -169,7 +169,7 @@ func (s *RouterTestSuite) TestServeHTTP() {
 
 // TestRequestID checks request ID generation and validation
 func (s *RouterTestSuite) TestRequestID() {
-	handler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
+	handler := func(reqID string, rw ResponseWriter, req *http.Request) error {
 		rw.WriteHeader(200)
 		return nil
 	}
@@ -209,7 +209,7 @@ func (s *RouterTestSuite) TestRequestID() {
 
 // TestLambdaRequestIDExtraction checks AWS lambda request id extraction
 func (s *RouterTestSuite) TestLambdaRequestIDExtraction() {
-	handler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
+	handler := func(reqID string, rw ResponseWriter, req *http.Request) error {
 		rw.WriteHeader(200)
 		return nil
 	}
@@ -229,7 +229,7 @@ func (s *RouterTestSuite) TestLambdaRequestIDExtraction() {
 // Test IP address handling
 func (s *RouterTestSuite) TestReplaceIP() {
 	var capturedRemoteAddr string
-	handler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
+	handler := func(reqID string, rw ResponseWriter, req *http.Request) error {
 		capturedRemoteAddr = req.RemoteAddr
 		rw.WriteHeader(200)
 		return nil
@@ -298,7 +298,7 @@ func (s *RouterTestSuite) TestReplaceIP() {
 // TestRouteOrder checks exact/non-exact insertion order
 func (s *RouterTestSuite) TestRouteOrder() {
 
-	h := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
+	h := func(reqID string, rw ResponseWriter, req *http.Request) error {
 		return nil
 	}
 

+ 17 - 13
server/server_test.go

@@ -29,10 +29,14 @@ func (s *ServerTestSuite) SetupTest() {
 	s.blankRouter = r
 }
 
-func (s *ServerTestSuite) mockHandler(reqID string, rw http.ResponseWriter, r *http.Request) error {
+func (s *ServerTestSuite) mockHandler(reqID string, rw ResponseWriter, r *http.Request) error {
 	return nil
 }
 
+func (s *ServerTestSuite) wrapRW(rw http.ResponseWriter) ResponseWriter {
+	return s.blankRouter.rwFactory.NewWriter(rw)
+}
+
 func (s *ServerTestSuite) TestStartServerWithInvalidBind() {
 	ctx, cancel := context.WithCancel(s.T().Context())
 
@@ -121,7 +125,7 @@ func (s *ServerTestSuite) TestWithCORS() {
 			req := httptest.NewRequest("GET", "/test", nil)
 			rw := httptest.NewRecorder()
 
-			wrappedHandler("test-req-id", rw, req)
+			wrappedHandler("test-req-id", s.wrapRW(rw), req)
 
 			s.Equal(tt.expectedOrigin, rw.Header().Get(httpheaders.AccessControlAllowOrigin))
 			s.Equal(tt.expectedMethods, rw.Header().Get(httpheaders.AccessControlAllowMethods))
@@ -170,7 +174,7 @@ func (s *ServerTestSuite) TestWithSecret() {
 			}
 			rw := httptest.NewRecorder()
 
-			err = wrappedHandler("test-req-id", rw, req)
+			err = wrappedHandler("test-req-id", s.wrapRW(rw), req)
 
 			if tt.expectError {
 				s.Require().Error(err)
@@ -182,7 +186,7 @@ func (s *ServerTestSuite) TestWithSecret() {
 }
 
 func (s *ServerTestSuite) TestIntoSuccess() {
-	mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error {
+	mockHandler := func(reqID string, rw ResponseWriter, r *http.Request) error {
 		rw.WriteHeader(http.StatusOK)
 		return nil
 	}
@@ -192,14 +196,14 @@ func (s *ServerTestSuite) TestIntoSuccess() {
 	req := httptest.NewRequest("GET", "/test", nil)
 	rw := httptest.NewRecorder()
 
-	wrappedHandler("test-req-id", rw, req)
+	wrappedHandler("test-req-id", s.wrapRW(rw), req)
 
 	s.Equal(http.StatusOK, rw.Code)
 }
 
 func (s *ServerTestSuite) TestIntoWithError() {
 	testError := errors.New("test error")
-	mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error {
+	mockHandler := func(reqID string, rw ResponseWriter, r *http.Request) error {
 		return testError
 	}
 
@@ -208,7 +212,7 @@ func (s *ServerTestSuite) TestIntoWithError() {
 	req := httptest.NewRequest("GET", "/test", nil)
 	rw := httptest.NewRecorder()
 
-	wrappedHandler("test-req-id", rw, req)
+	wrappedHandler("test-req-id", s.wrapRW(rw), req)
 
 	s.Equal(http.StatusInternalServerError, rw.Code)
 	s.Equal("text/plain", rw.Header().Get(httpheaders.ContentType))
@@ -216,7 +220,7 @@ func (s *ServerTestSuite) TestIntoWithError() {
 
 func (s *ServerTestSuite) TestIntoPanicWithError() {
 	testError := errors.New("panic error")
-	mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error {
+	mockHandler := func(reqID string, rw ResponseWriter, r *http.Request) error {
 		panic(testError)
 	}
 
@@ -226,7 +230,7 @@ func (s *ServerTestSuite) TestIntoPanicWithError() {
 	rw := httptest.NewRecorder()
 
 	s.NotPanics(func() {
-		err := wrappedHandler("test-req-id", rw, req)
+		err := wrappedHandler("test-req-id", s.wrapRW(rw), req)
 		s.Require().Error(err, "panic error")
 	})
 
@@ -234,7 +238,7 @@ func (s *ServerTestSuite) TestIntoPanicWithError() {
 }
 
 func (s *ServerTestSuite) TestIntoPanicWithAbortHandler() {
-	mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error {
+	mockHandler := func(reqID string, rw ResponseWriter, r *http.Request) error {
 		panic(http.ErrAbortHandler)
 	}
 
@@ -245,12 +249,12 @@ func (s *ServerTestSuite) TestIntoPanicWithAbortHandler() {
 
 	// Should re-panic with ErrAbortHandler
 	s.Panics(func() {
-		wrappedHandler("test-req-id", rw, req)
+		wrappedHandler("test-req-id", s.wrapRW(rw), req)
 	})
 }
 
 func (s *ServerTestSuite) TestIntoPanicWithNonError() {
-	mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error {
+	mockHandler := func(reqID string, rw ResponseWriter, r *http.Request) error {
 		panic("string panic")
 	}
 
@@ -261,7 +265,7 @@ func (s *ServerTestSuite) TestIntoPanicWithNonError() {
 
 	// Should re-panic with non-error panics
 	s.NotPanics(func() {
-		err := wrappedHandler("test-req-id", rw, req)
+		err := wrappedHandler("test-req-id", s.wrapRW(rw), req)
 		s.Require().Error(err, "string panic")
 	})
 }

+ 0 - 47
server/timeout_response.go

@@ -1,47 +0,0 @@
-package server
-
-import (
-	"net/http"
-	"time"
-)
-
-// timeoutResponse manages response writer with timeout. It has
-// timeout on all write methods.
-type timeoutResponse struct {
-	http.ResponseWriter
-	controller *http.ResponseController
-	timeout    time.Duration
-}
-
-// newTimeoutResponse creates a new timeoutResponse
-func newTimeoutResponse(rw http.ResponseWriter, timeout time.Duration) http.ResponseWriter {
-	return &timeoutResponse{
-		ResponseWriter: rw,
-		controller:     http.NewResponseController(rw),
-		timeout:        timeout,
-	}
-}
-
-// Write implements http.ResponseWriter.Write
-func (rw *timeoutResponse) Write(b []byte) (int, error) {
-	var (
-		n   int
-		err error
-	)
-	rw.withWriteDeadline(func() {
-		n, err = rw.ResponseWriter.Write(b)
-	})
-	return n, err
-}
-
-// withWriteDeadline executes a Write* function with a deadline
-func (rw *timeoutResponse) withWriteDeadline(f func()) {
-	deadline := time.Now().Add(rw.timeout)
-
-	// Set write deadline
-	rw.controller.SetWriteDeadline(deadline)
-
-	// Reset write deadline after method has finished
-	defer rw.controller.SetWriteDeadline(time.Time{})
-	f()
-}