Selaa lähdekoodia

Add helpers for copying HTTP headers

DarthSim 1 kuukausi sitten
vanhempi
commit
5a945984a7

+ 1 - 8
handlers/stream/handler.go

@@ -147,14 +147,7 @@ func (s *request) getCookieJar() (http.CookieJar, error) {
 // the headers that should be passed through from the user request
 func (s *request) getImageRequestHeaders() http.Header {
 	h := make(http.Header)
-
-	for _, key := range s.handler.config.PassthroughRequestHeaders {
-		values := s.imageRequest.Header.Values(key)
-
-		for _, value := range values {
-			h.Add(key, value)
-		}
-	}
+	httpheaders.CopyFromRequest(s.imageRequest, h, s.handler.config.PassthroughRequestHeaders)
 
 	return h
 }

+ 4 - 24
headerwriter/writer.go

@@ -113,25 +113,13 @@ func (w *writer) SetVary() {
 
 // Passthrough copies specified headers from the original response headers to the response headers.
 func (w *writer) Passthrough(only []string) {
-	for _, key := range only {
-		values := w.originalResponseHeaders.Values(key)
-
-		for _, value := range values {
-			w.result.Add(key, value)
-		}
-	}
+	httpheaders.Copy(w.originalResponseHeaders, w.result, only)
 }
 
 // CopyFrom copies specified headers from the headers object. Please note that
 // all the past operations may overwrite those values.
 func (w *writer) CopyFrom(headers http.Header, only []string) {
-	for _, key := range only {
-		values := headers.Values(key)
-
-		for _, value := range values {
-			w.result.Add(key, value)
-		}
-	}
+	httpheaders.Copy(headers, w.result, only)
 }
 
 // SetContentLength sets the Content-Length header
@@ -217,14 +205,6 @@ func (w *writer) Write(rw http.ResponseWriter) {
 
 	w.setCSP()
 
-	for key, values := range w.result {
-		// Do not overwrite existing headers which were set outside the header writer
-		if len(rw.Header().Get(key)) > 0 {
-			continue
-		}
-
-		for _, value := range values {
-			rw.Header().Add(key, value)
-		}
-	}
+	// Copy all headers to the response without overwriting existing ones
+	httpheaders.CopyAll(w.result, rw.Header(), false)
 }

+ 59 - 0
httpheaders/copy.go

@@ -0,0 +1,59 @@
+package httpheaders
+
+import "net/http"
+
+// Copy copies specified headers from one header to another.
+func Copy(from, to http.Header, only []string) {
+	for _, key := range only {
+		key = http.CanonicalHeaderKey(key)
+		if values := from[key]; len(values) > 0 {
+			to[key] = append([]string(nil), values...)
+		}
+	}
+}
+
+// CopyAll copies all headers from one header to another.
+func CopyAll(from, to http.Header, overwrite bool) {
+	for key, values := range from {
+		// Keys in http.Header are already canonicalized, so no need for http.CanonicalHeaderKey here
+		if !overwrite && len(to.Values(key)) > 0 {
+			continue
+		}
+
+		if len(values) > 0 {
+			to[key] = append([]string(nil), values...)
+		}
+	}
+}
+
+// CopyFromRequest copies specified headers from the http.Request to the provided header.
+func CopyFromRequest(req *http.Request, header http.Header, only []string) {
+	for _, key := range only {
+		key = http.CanonicalHeaderKey(key)
+
+		if key == Host {
+			header.Set(key, req.Host)
+			continue
+		}
+
+		if values := req.Header[key]; len(values) > 0 {
+			header[key] = append([]string(nil), values...)
+		}
+	}
+}
+
+// CopyToRequest copies headers from the provided header to the http.Request.
+func CopyToRequest(header http.Header, req *http.Request) {
+	for key, values := range header {
+		if len(values) == 0 {
+			continue
+		}
+
+		// Keys in http.Header are already canonicalized, so no need for http.CanonicalHeaderKey here
+		if key == Host {
+			req.Host = values[0]
+		} else {
+			req.Header[key] = append([]string(nil), values...)
+		}
+	}
+}

+ 130 - 0
httpheaders/copy_test.go

@@ -0,0 +1,130 @@
+package httpheaders
+
+import (
+	"fmt"
+	"net/http"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestCopy(t *testing.T) {
+	from := http.Header{
+		"X-Test-1": {"value1", "value2"},
+		"X-Test-2": {"value3"},
+		"X-Test-3": {"value4"},
+		"X-Test-4": nil,
+	}
+
+	to := http.Header{
+		"X-Test-1": {"oldvalue"},
+		"X-Test-4": {"value5"},
+		"X-Test-5": {"value6"},
+	}
+
+	Copy(from, to, []string{"X-Test-1", "x-test-3", "X-Non-Existent"})
+
+	require.Equal(t, []string{"value1", "value2"}, to.Values("X-Test-1"))
+	require.Equal(t, []string{"value4"}, to.Values("X-Test-3"))
+	require.Equal(t, []string{"value5"}, to.Values("X-Test-4"))
+	require.Equal(t, []string{"value6"}, to.Values("X-Test-5"))
+	require.Empty(t, to.Values("X-Test-2"))
+}
+
+func TestCopyAll(t *testing.T) {
+	from := http.Header{
+		"X-Test-1": {"value1", "value2"},
+		"X-Test-2": {"value3"},
+		"X-Test-3": nil,
+	}
+
+	to := http.Header{
+		"X-Test-1": {"oldvalue"},
+		"X-Test-3": {"value4"},
+		"X-Test-4": {"value5"},
+	}
+
+	testCases := []struct {
+		overwrite bool
+		expected  http.Header
+	}{
+		{
+			overwrite: false,
+			expected: http.Header{
+				"X-Test-1": {"oldvalue"},
+				"X-Test-2": {"value3"},
+				"X-Test-3": {"value4"},
+				"X-Test-4": {"value5"},
+			},
+		},
+		{
+			overwrite: true,
+			expected: http.Header{
+				"X-Test-1": {"value1", "value2"},
+				"X-Test-2": {"value3"},
+				"X-Test-3": {"value4"},
+				"X-Test-4": {"value5"},
+			},
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(fmt.Sprintf("overwrite=%v", tc.overwrite), func(t *testing.T) {
+			toCopy := to.Clone() // Clone to avoid modifying the original 'to' header
+			CopyAll(from, toCopy, tc.overwrite)
+			require.Equal(t, tc.expected, toCopy)
+		})
+	}
+}
+
+func TestCopyFromRequest(t *testing.T) {
+	req, err := http.NewRequest("GET", "http://example.com", nil)
+	require.NoError(t, err)
+
+	req.Host = "example.com"
+	req.Header = http.Header{
+		"X-Test-1": {"value1", "value2"},
+		"X-Test-2": {"value3"},
+		"X-Test-3": nil,
+	}
+
+	header := http.Header{
+		"X-Test-1": {"oldvalue"},
+		"X-Test-3": {"value4"},
+		"X-Test-4": {"value5"},
+	}
+
+	CopyFromRequest(req, header, []string{"X-Test-1", "x-test-2", "host", "X-Non-Existent"})
+
+	require.Equal(t, []string{"value1", "value2"}, header.Values("X-Test-1"))
+	require.Equal(t, []string{"value3"}, header.Values("X-Test-2"))
+	require.Equal(t, []string{"value4"}, header.Values("X-Test-3"))
+	require.Equal(t, []string{"value5"}, header.Values("X-Test-4"))
+	require.Equal(t, []string{"example.com"}, header.Values("Host"))
+}
+
+func TestCopyToRequest(t *testing.T) {
+	req, err := http.NewRequest("GET", "http://example.com", nil)
+	require.NoError(t, err)
+
+	req.Header = http.Header{
+		"X-Test-1": {"oldvalue"},
+		"X-Test-3": {"value4"},
+		"X-Test-4": {"value5"},
+	}
+
+	header := http.Header{
+		"X-Test-1": {"value1", "value2"},
+		"X-Test-2": {"value3"},
+		"X-Test-3": nil,
+		"Host":     {"newhost.com"},
+	}
+
+	CopyToRequest(header, req)
+
+	require.Equal(t, []string{"value1", "value2"}, req.Header.Values("X-Test-1"))
+	require.Equal(t, []string{"value3"}, req.Header.Values("X-Test-2"))
+	require.Equal(t, []string{"value4"}, req.Header.Values("X-Test-3"))
+	require.Equal(t, []string{"value5"}, req.Header.Values("X-Test-4"))
+	require.Equal(t, "newhost.com", req.Host)
+}

+ 2 - 5
imagedata/image_data_test.go

@@ -17,6 +17,7 @@ import (
 	"github.com/stretchr/testify/suite"
 
 	"github.com/imgproxy/imgproxy/v3/config"
+	"github.com/imgproxy/imgproxy/v3/httpheaders"
 	"github.com/imgproxy/imgproxy/v3/ierrors"
 	"github.com/imgproxy/imgproxy/v3/imagetype"
 	"github.com/imgproxy/imgproxy/v3/testutil"
@@ -55,11 +56,7 @@ func (s *ImageDataTestSuite) SetupSuite() {
 			s.check(r)
 		}
 
-		for k, vv := range s.header {
-			for _, v := range vv {
-				rw.Header().Add(k, v)
-			}
-		}
+		httpheaders.CopyAll(s.header, rw.Header(), true)
 
 		data := s.data
 		if data == nil {

+ 1 - 5
imagefetcher/fetcher.go

@@ -80,11 +80,7 @@ func (f *Fetcher) BuildRequest(ctx context.Context, url string, header http.Head
 	req.Header.Set(httpheaders.UserAgent, config.UserAgent)
 
 	// Set headers
-	for k, v := range header {
-		if len(v) > 0 {
-			req.Header.Set(k, v[0])
-		}
-	}
+	httpheaders.CopyToRequest(header, req)
 
 	return &Request{f, req, cancel}, nil
 }