
stream.go replaced with handlers/stream

Viktor Sokolov, 1 month ago
Commit 8ba3e85913
6 files changed with 660 additions and 241 deletions
  1. handlers/stream/config.go (+1 -0)
  2. handlers/stream/factory.go (+0 -33)
  3. handlers/stream/handler.go (+63 -34)
  4. handlers/stream/handler_test.go (+453 -37)
  5. processing_handler.go (+7 -1)
  6. stream.go (+136 -136)

+ 1 - 0
handlers/stream/config.go

@@ -23,6 +23,7 @@ func NewConfigFromEnv() *Config {
 		CookiePassthrough: config.CookiePassthrough,
 		PassthroughRequestHeaders: []string{
 			httpheaders.IfNoneMatch,
+			httpheaders.IfModifiedSince,
 			httpheaders.AcceptEncoding,
 			httpheaders.Range,
 		},

+ 0 - 33
handlers/stream/factory.go

@@ -1,33 +0,0 @@
-package stream
-
-import (
-	"context"
-	"net/http"
-
-	"github.com/imgproxy/imgproxy/v3/headerwriter"
-	"github.com/imgproxy/imgproxy/v3/imagefetcher"
-)
-
-// Factory is a struct which stores dependencies for the image streaming service
-// NOTE: Probably, we'll use the same factory for all handlers in the future.
-type Factory struct {
-	config   *Config
-	hwConfig *headerwriter.Config
-	fetcher  *imagefetcher.Fetcher
-}
-
-// New creates a new handler instance with the provided configuration
-func New(config *Config, hwConfig *headerwriter.Config, fetcher *imagefetcher.Fetcher) *Factory {
-	return &Factory{config: config, fetcher: fetcher, hwConfig: hwConfig}
-}
-
-// Stream streams the image based on the provided request
-func (h *Factory) NewHandler(ctx context.Context, p *StreamingParams, rr http.ResponseWriter) *Handler {
-	return &Handler{
-		fetcher:  h.fetcher,
-		config:   h.config,
-		hwConfig: h.hwConfig,
-		params:   p,
-		res:      rr,
-	}
-}

+ 63 - 34
handlers/stream/handler.go

@@ -33,25 +33,55 @@ var (
 	}
 )
 
-// StreamingParams represents an image request params that will be processed by the image streamer
-type StreamingParams struct {
-	UserRequest       *http.Request              // Original user request to imgproxy
-	ImageURL          string                     // URL of the image to be streamed
-	ReqID             string                     // Unique identifier for the request
-	ProcessingOptions *options.ProcessingOptions // Processing options for the image
-}
-
 // Handler handles image passthrough requests, allowing images to be streamed directly
 type Handler struct {
 	fetcher  *imagefetcher.Fetcher // Fetcher instance to handle image fetching
 	config   *Config               // Configuration for the streamer
 	hwConfig *headerwriter.Config  // Configuration for header writing
-	params   *StreamingParams      // Streaming request
-	res      http.ResponseWriter   // Response writer to write the streamed image
+}
+
+// request holds the parameters and state for a single streaming request
+type request struct {
+	handler     *Handler
+	userRequest *http.Request
+	imageURL    string
+	reqID       string
+	po          *options.ProcessingOptions
+	rw          http.ResponseWriter
+}
+
+// New creates new handler object
+func New(config *Config, hwConfig *headerwriter.Config, fetcher *imagefetcher.Fetcher) *Handler {
+	return &Handler{
+		fetcher:  fetcher,
+		config:   config,
+		hwConfig: hwConfig,
+	}
 }
 
 // Stream handles the image passthrough request, streaming the image directly to the response writer
-func (s *Handler) Execute(ctx context.Context) error {
+func (s *Handler) Execute(
+	ctx context.Context,
+	userRequest *http.Request,
+	imageURL string,
+	reqID string,
+	po *options.ProcessingOptions,
+	rw http.ResponseWriter,
+) error {
+	stream := &request{
+		handler:     s,
+		userRequest: userRequest,
+		imageURL:    imageURL,
+		reqID:       reqID,
+		po:          po,
+		rw:          rw,
+	}
+
+	return stream.execute(ctx)
+}
+
+// execute handles the actual streaming logic
+func (s *request) execute(ctx context.Context) error {
 	stats.IncImagesInProgress()
 	defer stats.DecImagesInProgress()
 	defer monitoring.StartStreamingSegment(ctx)()
@@ -64,7 +94,7 @@ func (s *Handler) Execute(ctx context.Context) error {
 	}
 
 	// Build the request to fetch the image
-	r, err := s.fetcher.BuildRequest(ctx, s.params.ImageURL, requestHeaders, cookieJar)
+	r, err := s.handler.fetcher.BuildRequest(ctx, s.imageURL, requestHeaders, cookieJar)
 	defer r.Cancel()
 	if err != nil {
 		return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryStreaming))
@@ -80,18 +110,18 @@ func (s *Handler) Execute(ctx context.Context) error {
 	}
 
 	// Output streaming response headers
-	hw := headerwriter.New(s.hwConfig, res.Header, s.params.ImageURL)
-	hw.Passthrough(s.config.KeepResponseHeaders) // NOTE: priority? This is lowest as it was
+	hw := headerwriter.New(s.handler.hwConfig, res.Header, s.imageURL)
+	hw.Passthrough(s.handler.config.KeepResponseHeaders) // NOTE: priority? This is lowest as it was
 	hw.SetContentLength(int(res.ContentLength))
 	hw.SetCanonical()
-	hw.SetMaxAge(s.params.ProcessingOptions.Expires, 0)
-	hw.Write(s.res)
+	hw.SetMaxAge(s.po.Expires, 0)
+	hw.Write(s.rw)
 
 	// Write Content-Disposition header
-	s.writeContentDisposition(s.params.ImageURL, res)
+	s.writeContentDisposition(r.URL().Path, res)
 
 	// Copy the status code from the original response
-	s.res.WriteHeader(res.StatusCode)
+	s.rw.WriteHeader(res.StatusCode)
 
 	// Write the actual data
 	s.streamData(res)
@@ -100,21 +130,21 @@ func (s *Handler) Execute(ctx context.Context) error {
 }
 
 // getCookieJar returns non-empty cookie jar if cookie passthrough is enabled
-func (s *Handler) getCookieJar() (http.CookieJar, error) {
-	if !s.config.CookiePassthrough {
+func (s *request) getCookieJar() (http.CookieJar, error) {
+	if !s.handler.config.CookiePassthrough {
 		return nil, nil
 	}
 
-	return cookies.JarFromRequest(s.params.UserRequest)
+	return cookies.JarFromRequest(s.userRequest)
 }
 
 // getPassthroughRequestHeaders returns a new http.Header containing only
 // the headers that should be passed through from the user request
-func (s *Handler) getPassthroughRequestHeaders() http.Header {
+func (s *request) getPassthroughRequestHeaders() http.Header {
 	h := make(http.Header)
 
-	for _, key := range s.config.PassthroughRequestHeaders {
-		values := s.params.UserRequest.Header.Values(key)
+	for _, key := range s.handler.config.PassthroughRequestHeaders {
+		values := s.userRequest.Header.Values(key)
 
 		for _, value := range values {
 			h.Add(key, value)
@@ -125,38 +155,37 @@ func (s *Handler) getPassthroughRequestHeaders() http.Header {
 }
 
 // writeContentDisposition writes the headers to the response writer
-func (s *Handler) writeContentDisposition(imagePath string, serverResponse *http.Response) {
+func (s *request) writeContentDisposition(imagePath string, serverResponse *http.Response) {
 	// Try to set correct Content-Disposition file name and extension
 	if serverResponse.StatusCode >= 200 && serverResponse.StatusCode < 300 {
 		ct := serverResponse.Header.Get(httpheaders.ContentType)
-		po := s.params.ProcessingOptions
 
 		// Try to best guess the file name and extension
 		cd := httpheaders.ContentDispositionValue(
 			imagePath,
-			po.Filename,
+			s.po.Filename,
 			"",
 			ct,
-			po.ReturnAttachment,
+			s.po.ReturnAttachment,
 		)
 
 		// Write the Content-Disposition header
-		s.res.Header().Set(httpheaders.ContentDisposition, cd)
+		s.rw.Header().Set(httpheaders.ContentDisposition, cd)
 	}
 }
 
 // streamData copies the image data from the response body to the response writer
-func (s *Handler) streamData(res *http.Response) {
+func (s *request) streamData(res *http.Response) {
 	buf := streamBufPool.Get().(*[]byte)
 	defer streamBufPool.Put(buf)
 
-	_, copyerr := io.CopyBuffer(s.res, res.Body, *buf)
+	_, copyerr := io.CopyBuffer(s.rw, res.Body, *buf)
 
 	server.LogResponse(
-		s.params.ReqID, s.params.UserRequest, res.StatusCode, nil,
+		s.reqID, s.userRequest, res.StatusCode, nil,
 		log.Fields{
-			"image_url":          s.params.ImageURL,
-			"processing_options": s.params.ProcessingOptions,
+			"image_url":          s.imageURL,
+			"processing_options": s.po,
 		},
 	)
 

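Editor's note (not part of the commit): this hunk replaces the per-request StreamingParams struct with plain Execute arguments, which the handler wraps in an internal request struct; the Handler itself now carries only its dependencies. A minimal usage sketch, assuming the constructors visible elsewhere in this diff (transport.NewTransport, imagefetcher.NewFetcher, stream.New, and the two NewConfigFromEnv helpers); the image URL and request ID below are placeholders:

// Sketch only: mirrors the wiring used in handler_test.go and processing_handler.go.
package main

import (
	"context"
	"net/http"

	"github.com/imgproxy/imgproxy/v3/handlers/stream"
	"github.com/imgproxy/imgproxy/v3/headerwriter"
	"github.com/imgproxy/imgproxy/v3/imagefetcher"
	"github.com/imgproxy/imgproxy/v3/options"
	"github.com/imgproxy/imgproxy/v3/transport"
)

func streamExample(rw http.ResponseWriter, r *http.Request) error {
	tr, err := transport.NewTransport()
	if err != nil {
		return err
	}

	fetcher, err := imagefetcher.NewFetcher(tr, imagefetcher.NewConfigFromEnv())
	if err != nil {
		return err
	}

	// The handler is request-independent and can be built once and reused.
	handler := stream.New(stream.NewConfigFromEnv(), headerwriter.NewConfigFromEnv(), fetcher)

	// Fields that used to live in StreamingParams are now passed directly to Execute.
	po := &options.ProcessingOptions{}
	return handler.Execute(context.Background(), r, "http://example.com/image.png", "req-1", po, rw)
}
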
+ 453 - 37
handlers/stream/handler_test.go

@@ -2,11 +2,18 @@ package stream
 
 import (
 	"context"
+	"fmt"
+	"io"
 	"net/http"
 	"net/http/httptest"
 	"os"
 	"path/filepath"
+	"strconv"
 	"testing"
+	"time"
+
+	"github.com/sirupsen/logrus"
+	"github.com/stretchr/testify/suite"
 
 	"github.com/imgproxy/imgproxy/v3/config"
 	"github.com/imgproxy/imgproxy/v3/headerwriter"
@@ -14,83 +21,492 @@ import (
 	"github.com/imgproxy/imgproxy/v3/imagefetcher"
 	"github.com/imgproxy/imgproxy/v3/options"
 	"github.com/imgproxy/imgproxy/v3/transport"
-	"github.com/stretchr/testify/suite"
 )
 
 const (
 	testDataPath = "../../testdata"
 )
 
-type StreamerTestSuite struct {
+type HandlerTestSuite struct {
 	suite.Suite
-	ts      *httptest.Server
-	factory *Factory
+	handler *Handler
 }
 
-func (s *StreamerTestSuite) SetupSuite() {
+func (s *HandlerTestSuite) SetupSuite() {
 	config.Reset()
 	config.AllowLoopbackSourceAddresses = true
 
-	s.ts = httptest.NewServer(http.FileServer(http.Dir(testDataPath)))
+	// Silence logs during tests
+	logrus.SetOutput(io.Discard)
 }
 
-func (s *StreamerTestSuite) TearDownSuite() {
+func (s *HandlerTestSuite) TearDownSuite() {
 	config.Reset()
-	s.ts.Close()
+	logrus.SetOutput(os.Stdout)
 }
 
-func (s *StreamerTestSuite) SetupTest() {
+func (s *HandlerTestSuite) SetupTest() {
+	config.Reset()
+	config.AllowLoopbackSourceAddresses = true
+
 	tr, err := transport.NewTransport()
 	s.Require().NoError(err)
 
 	fetcher, err := imagefetcher.NewFetcher(tr, imagefetcher.NewConfigFromEnv())
 	s.Require().NoError(err)
 
-	s.factory = New(NewConfigFromEnv(), headerwriter.NewConfigFromEnv(), fetcher)
+	s.handler = New(NewConfigFromEnv(), headerwriter.NewConfigFromEnv(), fetcher)
 }
 
-func (s *StreamerTestSuite) TestStreamer() {
-	const testFilePath = "/test1.jpg"
+func (s *HandlerTestSuite) readTestFile(name string) []byte {
+	data, err := os.ReadFile(filepath.Join(testDataPath, name))
+	s.Require().NoError(err)
+	return data
+}
+
+// TestHandlerBasicRequest checks basic streaming request
+func (s *HandlerTestSuite) TestHandlerBasicRequest() {
+	data := s.readTestFile("test1.png")
 
-	// Read expected output from test data
-	expected, err := os.ReadFile(filepath.Join(testDataPath, testFilePath))
+	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set(httpheaders.ContentType, "image/png")
+		w.WriteHeader(200)
+		w.Write(data)
+	}))
+	defer ts.Close()
+
+	req := httptest.NewRequest("GET", "/", nil)
+	rw := httptest.NewRecorder()
+	po := &options.ProcessingOptions{}
+
+	err := s.handler.Execute(context.Background(), req, ts.URL, "request-1", po, rw)
 	s.Require().NoError(err)
 
-	// Prepare HTTP request and response recorder
-	req := httptest.NewRequest("GET", testFilePath, nil)
-	req.Header.Set(httpheaders.AcceptEncoding, "gzip")
+	res := rw.Result()
+	s.Require().Equal(200, res.StatusCode)
+	s.Require().Equal("image/png", res.Header.Get(httpheaders.ContentType))
 
-	// Override the test server handler to assert Accept-Encoding header
-	s.ts.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		// Check that the Accept-Encoding header is passed through from original request
+	// Verify we get the original image data
+	actual := rw.Body.Bytes()
+	s.Require().Equal(data, actual)
+}
+
+// TestHandlerResponseHeadersPassthrough checks that original response headers are
+// passed through to the client
+func (s *HandlerTestSuite) TestHandlerResponseHeadersPassthrough() {
+	data := s.readTestFile("test1.png")
+	contentLength := len(data)
+
+	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set(httpheaders.ContentType, "image/png")
+		w.Header().Set(httpheaders.ContentLength, strconv.Itoa(contentLength))
+		w.Header().Set(httpheaders.AcceptRanges, "bytes")
+		w.Header().Set(httpheaders.Etag, "etag")
+		w.Header().Set(httpheaders.LastModified, "Wed, 21 Oct 2015 07:28:00 GMT")
+		w.WriteHeader(200)
+		w.Write(data)
+	}))
+	defer ts.Close()
+
+	req := httptest.NewRequest("GET", "/", nil)
+	rw := httptest.NewRecorder()
+	po := &options.ProcessingOptions{}
+
+	err := s.handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
+	s.Require().NoError(err)
+
+	res := rw.Result()
+	s.Require().Equal(200, res.StatusCode)
+	s.Require().Equal("image/png", res.Header.Get(httpheaders.ContentType))
+	s.Require().Equal(strconv.Itoa(contentLength), res.Header.Get(httpheaders.ContentLength))
+	s.Require().Equal("bytes", res.Header.Get(httpheaders.AcceptRanges))
+	s.Require().Equal("etag", res.Header.Get(httpheaders.Etag))
+	s.Require().Equal("Wed, 21 Oct 2015 07:28:00 GMT", res.Header.Get(httpheaders.LastModified))
+}
+
+// TestHandlerRequestHeadersPassthrough checks that original request headers are passed through
+// to the server
+func (s *HandlerTestSuite) TestHandlerRequestHeadersPassthrough() {
+	etag := `"test-etag-123"`
+	data := s.readTestFile("test1.png")
+
+	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		// Verify that If-None-Match header is passed through
+		s.Equal(etag, r.Header.Get(httpheaders.IfNoneMatch))
 		s.Equal("gzip", r.Header.Get(httpheaders.AcceptEncoding))
-		http.ServeFile(w, r, filepath.Join(testDataPath, r.URL.Path))
-	})
+		s.Equal("bytes=*", r.Header.Get(httpheaders.Range))
+
+		w.Header().Set(httpheaders.Etag, etag)
+		w.WriteHeader(200)
+		w.Write(data)
+	}))
+	defer ts.Close()
 
+	req := httptest.NewRequest("GET", "/", nil)
+	req.Header.Set(httpheaders.IfNoneMatch, etag)
+	req.Header.Set(httpheaders.AcceptEncoding, "gzip")
+	req.Header.Set(httpheaders.Range, "bytes=*")
+
+	rw := httptest.NewRecorder()
+	po := &options.ProcessingOptions{}
+
+	err := s.handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
+	s.Require().NoError(err)
+
+	res := rw.Result()
+	s.Require().Equal(200, res.StatusCode)
+	s.Require().Equal(etag, res.Header.Get(httpheaders.Etag))
+}
+
+// TestHandlerContentDisposition checks that Content-Disposition header is set correctly
+func (s *HandlerTestSuite) TestHandlerContentDisposition() {
+	data := s.readTestFile("test1.png")
+
+	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set(httpheaders.ContentType, "image/png")
+		w.WriteHeader(200)
+		w.Write(data)
+	}))
+	defer ts.Close()
+
+	req := httptest.NewRequest("GET", "/", nil)
+	rw := httptest.NewRecorder()
 	po := &options.ProcessingOptions{
-		Filename: "xxx", // Override Content-Disposition
+		Filename:         "custom_name",
+		ReturnAttachment: true,
+	}
+
+	// Use a URL with a .png extension to help content disposition logic
+	imageURL := ts.URL + "/test.png"
+	err := s.handler.Execute(context.Background(), req, imageURL, "test-req-id", po, rw)
+	s.Require().NoError(err)
+
+	res := rw.Result()
+	s.Require().Equal(200, res.StatusCode)
+	s.Require().Contains(res.Header.Get(httpheaders.ContentDisposition), "custom_name.png")
+	s.Require().Contains(res.Header.Get(httpheaders.ContentDisposition), "attachment")
+}
+
+// TestHandlerCacheControl checks that Cache-Control header is set correctly in different cases
+func (s *HandlerTestSuite) TestHandlerCacheControl() {
+	type testCase struct {
+		name                    string
+		cacheControlPassthrough bool
+		setupOriginHeaders      func(http.ResponseWriter)
+		timestampOffset         *time.Duration // nil for no timestamp, otherwise the offset from now
+		expectedStatusCode      int
+		validate                func(*testing.T, *http.Response)
 	}
 
-	rr := httptest.NewRecorder()
+	// Duration variables for test cases
+	var (
+		oneHour          = time.Hour
+		thirtyMinutes    = 30 * time.Minute
+		fortyFiveMinutes = 45 * time.Minute
+		twoHours         = time.Hour * 2
+		oneMinuteDelta   = float64(time.Minute)
+	)
+
+	// Set this explicitly for testing purposes
+	config.TTL = 4242
+
+	testCases := []testCase{
+		{
+			name:                    "Passthrough",
+			cacheControlPassthrough: true,
+			setupOriginHeaders: func(w http.ResponseWriter) {
+				w.Header().Set(httpheaders.CacheControl, "max-age=3600, public")
+			},
+			timestampOffset:    nil,
+			expectedStatusCode: 200,
+			validate: func(t *testing.T, res *http.Response) {
+				s.Require().Equal("max-age=3600, public", res.Header.Get(httpheaders.CacheControl))
+			},
+		},
+		// Checks that expires gets convert to cache-control
+		{
+			name:                    "ExpiresPassthrough",
+			cacheControlPassthrough: true,
+			setupOriginHeaders: func(w http.ResponseWriter) {
+				w.Header().Set(httpheaders.Expires, time.Now().Add(oneHour).UTC().Format(http.TimeFormat))
+			},
+			timestampOffset:    nil,
+			expectedStatusCode: 200,
+			validate: func(t *testing.T, res *http.Response) {
+				// When expires is converted to cache-control, the expires header should be empty
+				s.Require().Empty(res.Header.Get(httpheaders.Expires))
+				s.Require().InDelta(oneHour, s.maxAgeValue(res), oneMinuteDelta)
+			},
+		},
+		// It would be set to something like default ttl
+		{
+			name:                    "PassthroughDisabled",
+			cacheControlPassthrough: false,
+			setupOriginHeaders: func(w http.ResponseWriter) {
+				w.Header().Set(httpheaders.CacheControl, "max-age=3600, public")
+			},
+			timestampOffset:    nil,
+			expectedStatusCode: 200,
+			validate: func(t *testing.T, res *http.Response) {
+				s.Require().Equal(s.maxAgeValue(res), time.Duration(config.TTL)*time.Second)
+			},
+		},
+		// When expires is set in processing options, but not present in the response
+		{
+			name:                    "WithProcessingOptionsExpires",
+			cacheControlPassthrough: false,
+			setupOriginHeaders:      func(w http.ResponseWriter) {}, // No origin headers
+			timestampOffset:         &oneHour,
+			expectedStatusCode:      200,
+			validate: func(t *testing.T, res *http.Response) {
+				s.Require().InDelta(oneHour, s.maxAgeValue(res), oneMinuteDelta)
+			},
+		},
+		// When expires is set in processing options, and is present in the response,
+		// and passthrough is enabled
+		{
+			name:                    "ProcessingOptionsOverridesOrigin",
+			cacheControlPassthrough: true,
+			setupOriginHeaders: func(w http.ResponseWriter) {
+				// Origin has a longer cache time
+				w.Header().Set(httpheaders.CacheControl, "max-age=7200, public")
+			},
+			timestampOffset:    &thirtyMinutes,
+			expectedStatusCode: 200,
+			validate: func(t *testing.T, res *http.Response) {
+				s.Require().InDelta(thirtyMinutes, s.maxAgeValue(res), oneMinuteDelta)
+			},
+		},
+		// When expires is not set in po, but both expires and cc are present in response,
+		// and passthrough is enabled
+		{
+			name:                    "BothHeadersPassthroughEnabled",
+			cacheControlPassthrough: true,
+			setupOriginHeaders: func(w http.ResponseWriter) {
+				// Origin has both Cache-Control and Expires headers
+				w.Header().Set(httpheaders.CacheControl, "max-age=1800, public")
+				w.Header().Set(httpheaders.Expires, time.Now().Add(oneHour).UTC().Format(http.TimeFormat))
+			},
+			timestampOffset:    nil,
+			expectedStatusCode: 200,
+			validate: func(t *testing.T, res *http.Response) {
+				// Cache-Control should take precedence over Expires when both are present
+				s.Require().InDelta(thirtyMinutes, s.maxAgeValue(res), oneMinuteDelta)
+				s.Require().Empty(res.Header.Get(httpheaders.Expires))
+			},
+		},
+		// When expires is set in PO AND both cache-control and expires are present in response,
+		// and passthrough is enabled
+		{
+			name:                    "ProcessingOptionsOverridesBothOriginHeaders",
+			cacheControlPassthrough: true,
+			setupOriginHeaders: func(w http.ResponseWriter) {
+				// Origin has both Cache-Control and Expires headers with longer cache times
+				w.Header().Set(httpheaders.CacheControl, "max-age=7200, public")
+				w.Header().Set(httpheaders.Expires, time.Now().Add(twoHours).UTC().Format(http.TimeFormat))
+			},
+			timestampOffset:    &fortyFiveMinutes, // Shorter than origin headers
+			expectedStatusCode: 200,
+			validate: func(t *testing.T, res *http.Response) {
+				s.Require().InDelta(fortyFiveMinutes, s.maxAgeValue(res), oneMinuteDelta)
+				s.Require().Empty(res.Header.Get(httpheaders.Expires))
+			},
+		},
+		// No headers set
+		{
+			name:                    "NoOriginHeaders",
+			cacheControlPassthrough: false,
+			setupOriginHeaders:      func(w http.ResponseWriter) {}, // Origin has no cache headers
+			timestampOffset:         nil,
+			expectedStatusCode:      200,
+			validate: func(t *testing.T, res *http.Response) {
+				s.Require().Equal(s.maxAgeValue(res), time.Duration(config.TTL)*time.Second)
+			},
+		},
+	}
+
+	for _, tc := range testCases {
+		s.Run(tc.name, func() {
+			// Set config values for this test
+			config.CacheControlPassthrough = tc.cacheControlPassthrough
+			config.TTL = 4242 // Set consistent TTL for testing
+
+			data := s.readTestFile("test1.png")
+
+			ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+				tc.setupOriginHeaders(w)
+				w.Header().Set(httpheaders.ContentType, "image/png")
+				w.WriteHeader(200)
+				w.Write(data)
+			}))
+			defer ts.Close()
+
+			// Create new handler with updated config for each test
+			tr, err := transport.NewTransport()
+			s.Require().NoError(err)
 
-	p := StreamingParams{
-		UserRequest:       req,
-		ImageURL:          s.ts.URL + testFilePath,
-		ReqID:             "test-req-id",
-		ProcessingOptions: po,
+			fetcher, err := imagefetcher.NewFetcher(tr, imagefetcher.NewConfigFromEnv())
+			s.Require().NoError(err)
+
+			handler := New(NewConfigFromEnv(), headerwriter.NewConfigFromEnv(), fetcher)
+
+			req := httptest.NewRequest("GET", "/", nil)
+			rw := httptest.NewRecorder()
+			po := &options.ProcessingOptions{}
+
+			if tc.timestampOffset != nil {
+				expires := time.Now().Add(*tc.timestampOffset)
+				po.Expires = &expires
+			}
+
+			err = handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
+			s.Require().NoError(err)
+
+			res := rw.Result()
+			s.Require().Equal(tc.expectedStatusCode, res.StatusCode)
+			tc.validate(s.T(), res)
+		})
+	}
+}
+
+// maxAgeValue parses max-age from cache-control
+func (s *HandlerTestSuite) maxAgeValue(res *http.Response) time.Duration {
+	cacheControl := res.Header.Get(httpheaders.CacheControl)
+	if cacheControl == "" {
+		return 0
 	}
+	var maxAge int
+	fmt.Sscanf(cacheControl, "max-age=%d", &maxAge)
+	return time.Duration(maxAge) * time.Second
+}
+
+// TestHandlerSecurityHeaders tests the security headers set by the streaming service.
+func (s *HandlerTestSuite) TestHandlerSecurityHeaders() {
+	data := s.readTestFile("test1.png")
+
+	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set(httpheaders.ContentType, "image/png")
+		w.WriteHeader(200)
+		w.Write(data)
+	}))
+	defer ts.Close()
+
+	req := httptest.NewRequest("GET", "/", nil)
+	rw := httptest.NewRecorder()
+	po := &options.ProcessingOptions{}
 
-	err = s.factory.NewHandler(context.Background(), &p, rr).Execute(context.Background())
+	err := s.handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
 	s.Require().NoError(err)
 
-	// Check response body
-	respBody := rr.Body.Bytes()
-	s.Require().Equal(expected, respBody)
+	res := rw.Result()
+	s.Require().Equal(200, res.StatusCode)
+	s.Require().Equal("script-src 'none'", res.Header.Get(httpheaders.ContentSecurityPolicy))
+}
+
+// TestHandlerErrorResponse tests the error responses from the streaming service.
+func (s *HandlerTestSuite) TestHandlerErrorResponse() {
+	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(404)
+		w.Write([]byte("Not Found"))
+	}))
+	defer ts.Close()
+
+	req := httptest.NewRequest("GET", "/", nil)
+	rw := httptest.NewRecorder()
+	po := &options.ProcessingOptions{}
+
+	err := s.handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
+	s.Require().NoError(err)
+
+	res := rw.Result()
+	s.Require().Equal(404, res.StatusCode)
+}
+
+// TestHandlerCookiePassthrough tests the cookie passthrough behavior of the streaming service.
+func (s *HandlerTestSuite) TestHandlerCookiePassthrough() {
+	// Enable cookie passthrough for this test
+	config.CookiePassthrough = true
+	defer func() {
+		config.CookiePassthrough = false // Reset after test
+	}()
+
+	// Create new handler with updated config
+	tr, err := transport.NewTransport()
+	s.Require().NoError(err)
 
-	// Check that Content-Disposition header is set correctly
-	s.Require().Equal("inline; filename=\"xxx.jpg\"", rr.Header().Get("Content-Disposition"))
+	fetcher, err := imagefetcher.NewFetcher(tr, imagefetcher.NewConfigFromEnv())
+	s.Require().NoError(err)
+
+	handler := New(NewConfigFromEnv(), headerwriter.NewConfigFromEnv(), fetcher)
+
+	data := s.readTestFile("test1.png")
+
+	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		// Verify cookies are passed through
+		cookie, cerr := r.Cookie("test_cookie")
+		if cerr == nil {
+			s.Equal("test_value", cookie.Value)
+		}
+
+		w.Header().Set(httpheaders.ContentType, "image/png")
+		w.WriteHeader(200)
+		w.Write(data)
+	}))
+	defer ts.Close()
+
+	req := httptest.NewRequest("GET", "/", nil)
+	req.Header.Set(httpheaders.Cookie, "test_cookie=test_value")
+	rw := httptest.NewRecorder()
+	po := &options.ProcessingOptions{}
+
+	err = handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
+	s.Require().NoError(err)
+
+	res := rw.Result()
+	s.Require().Equal(200, res.StatusCode)
+}
+
+// TestHandlerCanonicalHeader tests that the canonical header is set correctly
+func (s *HandlerTestSuite) TestHandlerCanonicalHeader() {
+	data := s.readTestFile("test1.png")
+
+	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set(httpheaders.ContentType, "image/png")
+		w.WriteHeader(200)
+		w.Write(data)
+	}))
+	defer ts.Close()
+
+	for _, sc := range []bool{true, false} {
+		config.SetCanonicalHeader = sc
+
+		// Create new handler with updated config
+		tr, err := transport.NewTransport()
+		s.Require().NoError(err)
+
+		fetcher, err := imagefetcher.NewFetcher(tr, imagefetcher.NewConfigFromEnv())
+		s.Require().NoError(err)
+
+		handler := New(NewConfigFromEnv(), headerwriter.NewConfigFromEnv(), fetcher)
+
+		req := httptest.NewRequest("GET", "/", nil)
+		rw := httptest.NewRecorder()
+		po := &options.ProcessingOptions{}
+
+		err = handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
+		s.Require().NoError(err)
+
+		res := rw.Result()
+		s.Require().Equal(200, res.StatusCode)
+
+		if sc {
+			s.Require().Contains(res.Header.Get(httpheaders.Link), fmt.Sprintf(`<%s>; rel="canonical"`, ts.URL))
+		} else {
+			s.Require().Empty(res.Header.Get(httpheaders.Link))
+		}
+	}
 }
 
-func TestStreamer(t *testing.T) {
-	suite.Run(t, new(StreamerTestSuite))
+func TestHandler(t *testing.T) {
+	suite.Run(t, new(HandlerTestSuite))
 }

+ 7 - 1
processing_handler.go

@@ -17,6 +17,8 @@ import (
 	"github.com/imgproxy/imgproxy/v3/cookies"
 	"github.com/imgproxy/imgproxy/v3/errorreport"
 	"github.com/imgproxy/imgproxy/v3/etag"
+	"github.com/imgproxy/imgproxy/v3/handlers/stream"
+	"github.com/imgproxy/imgproxy/v3/headerwriter"
 	"github.com/imgproxy/imgproxy/v3/httpheaders"
 	"github.com/imgproxy/imgproxy/v3/ierrors"
 	"github.com/imgproxy/imgproxy/v3/imagedata"
@@ -275,7 +277,11 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) err
 	}
 
 	if po.Raw {
-		return streamOriginImage(ctx, reqID, r, rw, po, imageURL)
+		// TODO: Move this up
+		cfg := stream.NewConfigFromEnv()
+		hwCfg := headerwriter.NewConfigFromEnv()
+		handler := stream.New(cfg, hwCfg, imagedata.Fetcher)
+		return handler.Execute(ctx, r, imageURL, reqID, po, rw)
 	}
 
 	// SVG is a special case. Though saving to svg is not supported, SVG->SVG is.
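
Editor's note (a sketch, not part of the commit): the TODO above asks for the handler construction to be hoisted out of the request path. One possible shape, assuming startup code in package main can reach imagedata.Fetcher as this hunk does; the names streamHandler and initStreamHandler are hypothetical:

// Hypothetical package-level wiring for processing_handler.go; imports as in this file.
var streamHandler *stream.Handler

func initStreamHandler() {
	streamHandler = stream.New(
		stream.NewConfigFromEnv(),
		headerwriter.NewConfigFromEnv(),
		imagedata.Fetcher,
	)
}

The po.Raw branch would then shrink to: return streamHandler.Execute(ctx, r, imageURL, reqID, po, rw).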

+ 136 - 136
stream.go

@@ -1,138 +1,138 @@
 package main
 
-import (
-	"context"
-	"io"
-	"net/http"
-	"strconv"
-	"sync"
-
-	log "github.com/sirupsen/logrus"
-
-	"github.com/imgproxy/imgproxy/v3/config"
-	"github.com/imgproxy/imgproxy/v3/cookies"
-	"github.com/imgproxy/imgproxy/v3/httpheaders"
-	"github.com/imgproxy/imgproxy/v3/ierrors"
-	"github.com/imgproxy/imgproxy/v3/imagedata"
-	"github.com/imgproxy/imgproxy/v3/monitoring"
-	"github.com/imgproxy/imgproxy/v3/monitoring/stats"
-	"github.com/imgproxy/imgproxy/v3/options"
-	"github.com/imgproxy/imgproxy/v3/server"
-)
-
-var (
-	streamReqHeaders = []string{
-		"If-None-Match",
-		"If-Modified-Since",
-		"Accept-Encoding",
-		"Range",
-	}
-
-	streamRespHeaders = []string{
-		"ETag",
-		"Content-Type",
-		"Content-Encoding",
-		"Content-Range",
-		"Accept-Ranges",
-		"Last-Modified",
-	}
-
-	streamBufPool = sync.Pool{
-		New: func() interface{} {
-			buf := make([]byte, 4096)
-			return &buf
-		},
-	}
-)
-
-func streamOriginImage(ctx context.Context, reqID string, r *http.Request, rw http.ResponseWriter, po *options.ProcessingOptions, imageURL string) error {
-	stats.IncImagesInProgress()
-	defer stats.DecImagesInProgress()
-	defer monitoring.StartStreamingSegment(ctx)()
-
-	var (
-		cookieJar http.CookieJar
-		err       error
-	)
-
-	imgRequestHeader := make(http.Header)
-
-	for _, k := range streamReqHeaders {
-		if v := r.Header.Get(k); len(v) != 0 {
-			imgRequestHeader.Set(k, v)
-		}
-	}
-
-	if config.CookiePassthrough {
-		cookieJar, err = cookies.JarFromRequest(r)
-		if err != nil {
-			return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryStreaming))
-		}
-	}
-
-	req, err := imagedata.Fetcher.BuildRequest(r.Context(), imageURL, imgRequestHeader, cookieJar)
-	defer req.Cancel()
-	if err != nil {
-		return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryStreaming))
-	}
-
-	res, err := req.Send()
-	if res != nil {
-		defer res.Body.Close()
-	}
-	if err != nil {
-		return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryStreaming))
-	}
-
-	for _, k := range streamRespHeaders {
-		vv := res.Header.Values(k)
-		for _, v := range vv {
-			rw.Header().Set(k, v)
-		}
-	}
-
-	if res.ContentLength >= 0 {
-		rw.Header().Set("Content-Length", strconv.Itoa(int(res.ContentLength)))
-	}
-
-	if res.StatusCode < 300 {
-		contentDisposition := httpheaders.ContentDispositionValue(
-			req.URL().Path,
-			po.Filename,
-			"",
-			rw.Header().Get(httpheaders.ContentType),
-			po.ReturnAttachment,
-		)
-		rw.Header().Set("Content-Disposition", contentDisposition)
-	}
-
-	setCacheControl(rw, po.Expires, res.Header)
-	setCanonical(rw, imageURL)
-	rw.Header().Set("Content-Security-Policy", "script-src 'none'")
-
-	rw.WriteHeader(res.StatusCode)
-
-	buf := streamBufPool.Get().(*[]byte)
-	defer streamBufPool.Put(buf)
-
-	_, copyerr := io.CopyBuffer(rw, res.Body, *buf)
-	if copyerr == http.ErrBodyNotAllowed {
-		// We can hit this for some statuses like 304 Not Modified.
-		// We can ignore this error.
-		copyerr = nil
-	}
-
-	server.LogResponse(
-		reqID, r, res.StatusCode, nil,
-		log.Fields{
-			"image_url":          imageURL,
-			"processing_options": po,
-		},
-	)
-
-	if copyerr != nil {
-		panic(http.ErrAbortHandler)
-	}
-
-	return nil
-}
+// import (
+// 	"context"
+// 	"io"
+// 	"net/http"
+// 	"strconv"
+// 	"sync"
+
+// 	log "github.com/sirupsen/logrus"
+
+// 	"github.com/imgproxy/imgproxy/v3/config"
+// 	"github.com/imgproxy/imgproxy/v3/cookies"
+// 	"github.com/imgproxy/imgproxy/v3/httpheaders"
+// 	"github.com/imgproxy/imgproxy/v3/ierrors"
+// 	"github.com/imgproxy/imgproxy/v3/imagedata"
+// 	"github.com/imgproxy/imgproxy/v3/monitoring"
+// 	"github.com/imgproxy/imgproxy/v3/monitoring/stats"
+// 	"github.com/imgproxy/imgproxy/v3/options"
+// 	"github.com/imgproxy/imgproxy/v3/server"
+// )
+
+// var (
+// 	streamReqHeaders = []string{
+// 		"If-None-Match",
+// 		"If-Modified-Since",
+// 		"Accept-Encoding",
+// 		"Range",
+// 	}
+
+// 	streamRespHeaders = []string{
+// 		"ETag",
+// 		"Content-Type",
+// 		"Content-Encoding",
+// 		"Content-Range",
+// 		"Accept-Ranges",
+// 		"Last-Modified",
+// 	}
+
+// 	streamBufPool = sync.Pool{
+// 		New: func() interface{} {
+// 			buf := make([]byte, 4096)
+// 			return &buf
+// 		},
+// 	}
+// )
+
+// func streamOriginImage(ctx context.Context, reqID string, r *http.Request, rw http.ResponseWriter, po *options.ProcessingOptions, imageURL string) error {
+// 	stats.IncImagesInProgress()
+// 	defer stats.DecImagesInProgress()
+// 	defer monitoring.StartStreamingSegment(ctx)()
+
+// 	var (
+// 		cookieJar http.CookieJar
+// 		err       error
+// 	)
+
+// 	imgRequestHeader := make(http.Header)
+
+// 	for _, k := range streamReqHeaders {
+// 		if v := r.Header.Get(k); len(v) != 0 {
+// 			imgRequestHeader.Set(k, v)
+// 		}
+// 	}
+
+// 	if config.CookiePassthrough {
+// 		cookieJar, err = cookies.JarFromRequest(r)
+// 		if err != nil {
+// 			return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryStreaming))
+// 		}
+// 	}
+
+// 	req, err := imagedata.Fetcher.BuildRequest(r.Context(), imageURL, imgRequestHeader, cookieJar)
+// 	defer req.Cancel()
+// 	if err != nil {
+// 		return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryStreaming))
+// 	}
+
+// 	res, err := req.Send()
+// 	if res != nil {
+// 		defer res.Body.Close()
+// 	}
+// 	if err != nil {
+// 		return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryStreaming))
+// 	}
+
+// 	for _, k := range streamRespHeaders {
+// 		vv := res.Header.Values(k)
+// 		for _, v := range vv {
+// 			rw.Header().Set(k, v)
+// 		}
+// 	}
+
+// 	if res.ContentLength >= 0 {
+// 		rw.Header().Set("Content-Length", strconv.Itoa(int(res.ContentLength)))
+// 	}
+
+// 	if res.StatusCode < 300 {
+// 		contentDisposition := httpheaders.ContentDispositionValue(
+// 			req.URL().Path,
+// 			po.Filename,
+// 			"",
+// 			rw.Header().Get(httpheaders.ContentType),
+// 			po.ReturnAttachment,
+// 		)
+// 		rw.Header().Set("Content-Disposition", contentDisposition)
+// 	}
+
+// 	setCacheControl(rw, po.Expires, res.Header)
+// 	setCanonical(rw, imageURL)
+// 	rw.Header().Set("Content-Security-Policy", "script-src 'none'")
+
+// 	rw.WriteHeader(res.StatusCode)
+
+// 	buf := streamBufPool.Get().(*[]byte)
+// 	defer streamBufPool.Put(buf)
+
+// 	_, copyerr := io.CopyBuffer(rw, res.Body, *buf)
+// 	if copyerr == http.ErrBodyNotAllowed {
+// 		// We can hit this for some statuses like 304 Not Modified.
+// 		// We can ignore this error.
+// 		copyerr = nil
+// 	}
+
+// 	server.LogResponse(
+// 		reqID, r, res.StatusCode, nil,
+// 		log.Fields{
+// 			"image_url":          imageURL,
+// 			"processing_options": po,
+// 		},
+// 	)
+
+// 	if copyerr != nil {
+// 		panic(http.ErrAbortHandler)
+// 	}
+
+// 	return nil
+// }