1
0
Эх сурвалжийг харах

Add cache control headers to 304 response

DarthSim 3 жил өмнө
parent
commit
80331cd94e

+ 1 - 0
CHANGELOG.md

@@ -6,6 +6,7 @@
 
 ### Change
 - `dpr` processing option doesn't enlarge image unless `enlarge` is true.
+- `304 Not Modified` responses includes `Cache-Control`, `Expires`, and `Vary` headers.
 
 ### Fix
 - Fix Client Hints behavior. `Width` is physical size, so we should divide it by `DPR` value.

+ 3 - 10
ierrors/errors.go

@@ -62,13 +62,13 @@ func Wrap(err error, skip int) *Error {
 	return NewUnexpected(err.Error(), skip+1)
 }
 
-func WrapWithMessage(err error, skip int, msg string) *Error {
+func WrapWithPrefix(err error, skip int, prefix string) *Error {
 	if ierr, ok := err.(*Error); ok {
 		newErr := *ierr
-		ierr.Message = msg
+		newErr.Message = fmt.Sprintf("%s: %s", prefix, ierr.Message)
 		return &newErr
 	}
-	return NewUnexpected(err.Error(), skip+1)
+	return NewUnexpected(fmt.Sprintf("%s: %s", prefix, err), skip+1)
 }
 
 func callers(skip int) []uintptr {
@@ -87,10 +87,3 @@ func formatStack(stack []uintptr) string {
 
 	return strings.Join(lines, "\n")
 }
-
-func StatusCode(err error) int {
-	if ierr, ok := err.(*Error); ok {
-		return ierr.StatusCode
-	}
-	return 0
-}

+ 23 - 9
imagedata/download.go

@@ -29,12 +29,19 @@ var (
 
 	// For tests
 	redirectAllRequestsTo string
-
-	ErrNotModified = ierrors.New(http.StatusNotModified, "Not Modified", "Not Modified")
 )
 
 const msgSourceImageIsUnreachable = "Source image is unreachable"
 
+type ErrorNotModified struct {
+	Message string
+	Headers map[string]string
+}
+
+func (e *ErrorNotModified) Error() string {
+	return e.Message
+}
+
 func initDownloading() error {
 	transport := &http.Transport{
 		Proxy:               http.ProxyFromEnvironment,
@@ -84,6 +91,18 @@ func initDownloading() error {
 	return nil
 }
 
+func headersToStore(res *http.Response) map[string]string {
+	m := make(map[string]string)
+
+	for _, h := range imageHeadersToStore {
+		if val := res.Header.Get(h); len(val) != 0 {
+			m[h] = val
+		}
+	}
+
+	return m
+}
+
 func requestImage(imageURL string, header http.Header) (*http.Response, error) {
 	req, err := http.NewRequest("GET", imageURL, nil)
 	if err != nil {
@@ -104,7 +123,7 @@ func requestImage(imageURL string, header http.Header) (*http.Response, error) {
 	}
 
 	if res.StatusCode == http.StatusNotModified {
-		return nil, ErrNotModified
+		return nil, &ErrorNotModified{Message: "Not Modified", Headers: headersToStore(res)}
 	}
 
 	if res.StatusCode != 200 {
@@ -152,12 +171,7 @@ func download(imageURL string, header http.Header) (*ImageData, error) {
 		return nil, err
 	}
 
-	imgdata.Headers = make(map[string]string)
-	for _, h := range imageHeadersToStore {
-		if val := res.Header.Get(h); len(val) != 0 {
-			imgdata.Headers[h] = val
-		}
-	}
+	imgdata.Headers = headersToStore(res)
 
 	return imgdata, nil
 }

+ 5 - 1
imagedata/image_data.go

@@ -130,7 +130,11 @@ func FromFile(path, desc string) (*ImageData, error) {
 func Download(imageURL, desc string, header http.Header) (*ImageData, error) {
 	imgdata, err := download(imageURL, header)
 	if err != nil {
-		return nil, ierrors.WrapWithMessage(err, 1, fmt.Sprintf("Can't download %s: %s", desc, err))
+		if nmErr, ok := err.(*ErrorNotModified); ok {
+			nmErr.Message = fmt.Sprintf("Can't download %s: %s", desc, nmErr.Message)
+			return nil, nmErr
+		}
+		return nil, ierrors.WrapWithPrefix(err, 1, fmt.Sprintf("Can't download %s", desc))
 	}
 
 	return imgdata, nil

+ 42 - 32
processing_handler.go

@@ -48,35 +48,14 @@ func initProcessingHandler() {
 	headerVaryValue = strings.Join(vary, ", ")
 }
 
-func respondWithImage(reqID string, r *http.Request, rw http.ResponseWriter, resultData *imagedata.ImageData, po *options.ProcessingOptions, originURL string, originData *imagedata.ImageData) {
-	var contentDisposition string
-	if len(po.Filename) > 0 {
-		contentDisposition = resultData.Type.ContentDisposition(po.Filename)
-	} else {
-		contentDisposition = resultData.Type.ContentDispositionFromURL(originURL)
-	}
-
-	rw.Header().Set("Content-Type", resultData.Type.Mime())
-	rw.Header().Set("Content-Disposition", contentDisposition)
-
-	if po.Dpr != 1 {
-		rw.Header().Set("Content-DPR", strconv.FormatFloat(po.Dpr, 'f', 2, 32))
-	}
-
-	if config.SetCanonicalHeader {
-		if strings.HasPrefix(originURL, "https://") || strings.HasPrefix(originURL, "http://") {
-			linkHeader := fmt.Sprintf(`<%s>; rel="canonical"`, originURL)
-			rw.Header().Set("Link", linkHeader)
-		}
-	}
-
+func setCacheControl(rw http.ResponseWriter, originHeaders map[string]string) {
 	var cacheControl, expires string
 
-	if config.CacheControlPassthrough && originData.Headers != nil {
-		if val, ok := originData.Headers["Cache-Control"]; ok {
+	if config.CacheControlPassthrough && originHeaders != nil {
+		if val, ok := originHeaders["Cache-Control"]; ok {
 			cacheControl = val
 		}
-		if val, ok := originData.Headers["Expires"]; ok {
+		if val, ok := originHeaders["Expires"]; ok {
 			expires = val
 		}
 	}
@@ -92,10 +71,38 @@ func respondWithImage(reqID string, r *http.Request, rw http.ResponseWriter, res
 	if len(expires) > 0 {
 		rw.Header().Set("Expires", expires)
 	}
+}
 
+func setVary(rw http.ResponseWriter) {
 	if len(headerVaryValue) > 0 {
 		rw.Header().Set("Vary", headerVaryValue)
 	}
+}
+
+func respondWithImage(reqID string, r *http.Request, rw http.ResponseWriter, resultData *imagedata.ImageData, po *options.ProcessingOptions, originURL string, originData *imagedata.ImageData) {
+	var contentDisposition string
+	if len(po.Filename) > 0 {
+		contentDisposition = resultData.Type.ContentDisposition(po.Filename)
+	} else {
+		contentDisposition = resultData.Type.ContentDispositionFromURL(originURL)
+	}
+
+	rw.Header().Set("Content-Type", resultData.Type.Mime())
+	rw.Header().Set("Content-Disposition", contentDisposition)
+
+	if po.Dpr != 1 {
+		rw.Header().Set("Content-DPR", strconv.FormatFloat(po.Dpr, 'f', 2, 32))
+	}
+
+	if config.SetCanonicalHeader {
+		if strings.HasPrefix(originURL, "https://") || strings.HasPrefix(originURL, "http://") {
+			linkHeader := fmt.Sprintf(`<%s>; rel="canonical"`, originURL)
+			rw.Header().Set("Link", linkHeader)
+		}
+	}
+
+	setCacheControl(rw, originData.Headers)
+	setVary(rw)
 
 	if config.EnableDebugHeaders {
 		rw.Header().Set("X-Origin-Content-Length", strconv.Itoa(len(originData.Data)))
@@ -120,7 +127,10 @@ func respondWithImage(reqID string, r *http.Request, rw http.ResponseWriter, res
 	)
 }
 
-func respondWithNotModified(reqID string, r *http.Request, rw http.ResponseWriter, po *options.ProcessingOptions, originURL string) {
+func respondWithNotModified(reqID string, r *http.Request, rw http.ResponseWriter, po *options.ProcessingOptions, originURL string, originHeaders map[string]string) {
+	setCacheControl(rw, originHeaders)
+	setVary(rw)
+
 	rw.WriteHeader(304)
 	router.LogResponse(
 		reqID, r, 304, nil,
@@ -209,14 +219,14 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) {
 		defer metrics.StartDownloadingSegment(ctx)()
 		return imagedata.Download(imageURL, "source image", imgRequestHeader)
 	}()
-	switch {
-	case err == nil:
+
+	if err == nil {
 		defer originData.Close()
-	case ierrors.StatusCode(err) == http.StatusNotModified:
+	} else if nmErr, ok := err.(*imagedata.ErrorNotModified); ok && config.ETagEnabled {
 		rw.Header().Set("ETag", etagHandler.GenerateExpectedETag())
-		respondWithNotModified(reqID, r, rw, po, imageURL)
+		respondWithNotModified(reqID, r, rw, po, imageURL, nmErr.Headers)
 		return
-	default:
+	} else {
 		if ierr, ok := err.(*ierrors.Error); !ok || ierr.Unexpected {
 			errorreport.Report(err, r)
 		}
@@ -240,7 +250,7 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) {
 		rw.Header().Set("ETag", etagHandler.GenerateActualETag())
 
 		if imgDataMatch && etagHandler.ProcessingOptionsMatch() {
-			respondWithNotModified(reqID, r, rw, po, imageURL)
+			respondWithNotModified(reqID, r, rw, po, imageURL, originData.Headers)
 			return
 		}
 	}