Browse Source

DRY image downloading

DarthSim 4 years ago
parent
commit
27dbed077f
3 changed files with 55 additions and 45 deletions
  1. 41 28
      download.go
  2. 5 12
      image_data.go
  3. 9 5
      processing_handler.go

+ 41 - 28
download.go

@@ -16,9 +16,12 @@ import (
 var (
 	downloadClient *http.Client
 
-	imageDataCtxKey          = ctxKey("imageData")
-	cacheControlHeaderCtxKey = ctxKey("cacheControlHeader")
-	expiresHeaderCtxKey      = ctxKey("expiresHeader")
+	imageDataCtxKey = ctxKey("imageData")
+
+	imageHeadersToStore = []string{
+		"Cache-Control",
+		"Expires",
+	}
 
 	errSourceResolutionTooBig      = newError(422, "Source image resolution is too big", "Invalid source image")
 	errSourceFileTooBig            = newError(422, "Source image file is too big", "Invalid source image")
@@ -147,10 +150,14 @@ func readAndCheckImage(r io.Reader, contentLength int) (*imageData, error) {
 
 	if _, err = buf.ReadFrom(r); err != nil {
 		cancel()
-		return nil, newError(404, err.Error(), msgSourceImageIsUnreachable)
+		return nil, newError(404, err.Error(), msgSourceImageIsUnreachable).SetUnexpected(conf.ReportDownloadingErrors)
 	}
 
-	return &imageData{buf.Bytes(), imgtype, cancel}, nil
+	return &imageData{
+		Data:   buf.Bytes(),
+		Type:   imgtype,
+		cancel: cancel,
+	}, nil
 }
 
 func requestImage(imageURL string) (*http.Response, error) {
@@ -168,6 +175,8 @@ func requestImage(imageURL string) (*http.Response, error) {
 
 	if res.StatusCode != 200 {
 		body, _ := ioutil.ReadAll(res.Body)
+		res.Body.Close()
+
 		msg := fmt.Sprintf("Can't download image; Status: %d; %s", res.StatusCode, string(body))
 		return res, newError(404, msg, msgSourceImageIsUnreachable).SetUnexpected(conf.ReportDownloadingErrors)
 	}
@@ -175,7 +184,31 @@ func requestImage(imageURL string) (*http.Response, error) {
 	return res, nil
 }
 
-func downloadImage(ctx context.Context) (context.Context, context.CancelFunc, error) {
+func downloadImage(imageURL string) (*imageData, error) {
+	res, err := requestImage(imageURL)
+	if res != nil {
+		defer res.Body.Close()
+	}
+	if err != nil {
+		return nil, err
+	}
+
+	imgdata, err := readAndCheckImage(res.Body, int(res.ContentLength))
+	if err != nil {
+		return nil, err
+	}
+
+	imgdata.Headers = make(map[string]string)
+	for _, h := range imageHeadersToStore {
+		if val := res.Header.Get(h); len(val) != 0 {
+			imgdata.Headers[h] = val
+		}
+	}
+
+	return imgdata, nil
+}
+
+func downloadImageCtx(ctx context.Context) (context.Context, context.CancelFunc, error) {
 	imageURL := getImageURL(ctx)
 
 	if newRelicEnabled {
@@ -187,36 +220,16 @@ func downloadImage(ctx context.Context) (context.Context, context.CancelFunc, er
 		defer startPrometheusDuration(prometheusDownloadDuration)()
 	}
 
-	res, err := requestImage(imageURL)
-	if res != nil {
-		defer res.Body.Close()
-	}
-	if err != nil {
-		return ctx, func() {}, err
-	}
-
-	imgdata, err := readAndCheckImage(res.Body, int(res.ContentLength))
+	imgdata, err := downloadImage(imageURL)
 	if err != nil {
 		return ctx, func() {}, err
 	}
 
 	ctx = context.WithValue(ctx, imageDataCtxKey, imgdata)
-	ctx = context.WithValue(ctx, cacheControlHeaderCtxKey, res.Header.Get("Cache-Control"))
-	ctx = context.WithValue(ctx, expiresHeaderCtxKey, res.Header.Get("Expires"))
 
-	return ctx, imgdata.Close, err
+	return ctx, imgdata.Close, nil
 }
 
 func getImageData(ctx context.Context) *imageData {
 	return ctx.Value(imageDataCtxKey).(*imageData)
 }
-
-func getCacheControlHeader(ctx context.Context) string {
-	str, _ := ctx.Value(cacheControlHeaderCtxKey).(string)
-	return str
-}
-
-func getExpiresHeader(ctx context.Context) string {
-	str, _ := ctx.Value(expiresHeaderCtxKey).(string)
-	return str
-}

+ 5 - 12
image_data.go

@@ -9,8 +9,9 @@ import (
 )
 
 type imageData struct {
-	Data []byte
-	Type imageType
+	Data    []byte
+	Type    imageType
+	Headers map[string]string
 
 	cancel context.CancelFunc
 }
@@ -87,18 +88,10 @@ func fileImageData(path, desc string) (*imageData, error) {
 }
 
 func remoteImageData(imageURL, desc string) (*imageData, error) {
-	res, err := requestImage(imageURL)
-	if res != nil {
-		defer res.Body.Close()
-	}
-	if err != nil {
-		return nil, fmt.Errorf("Can't download %s: %s", desc, err)
-	}
-
-	imgdata, err := readAndCheckImage(res.Body, int(res.ContentLength))
+	imgdata, err := downloadImage(imageURL)
 	if err != nil {
 		return nil, fmt.Errorf("Can't download %s: %s", desc, err)
 	}
 
-	return imgdata, err
+	return imgdata, nil
 }

+ 9 - 5
processing_handler.go

@@ -42,6 +42,7 @@ func initProcessingHandler() error {
 
 func respondWithImage(ctx context.Context, reqID string, r *http.Request, rw http.ResponseWriter, data []byte) {
 	po := getProcessingOptions(ctx)
+	imgdata := getImageData(ctx)
 
 	var contentDisposition string
 	if len(po.Filename) > 0 {
@@ -63,9 +64,13 @@ func respondWithImage(ctx context.Context, reqID string, r *http.Request, rw htt
 
 	var cacheControl, expires string
 
-	if conf.CacheControlPassthrough {
-		cacheControl = getCacheControlHeader(ctx)
-		expires = getExpiresHeader(ctx)
+	if conf.CacheControlPassthrough && imgdata.Headers != nil {
+		if val, ok := imgdata.Headers["Cache-Control"]; ok {
+			cacheControl = val
+		}
+		if val, ok := imgdata.Headers["Expires"]; ok {
+			expires = val
+		}
 	}
 
 	if len(cacheControl) == 0 && len(expires) == 0 {
@@ -85,7 +90,6 @@ func respondWithImage(ctx context.Context, reqID string, r *http.Request, rw htt
 	}
 
 	if conf.EnableDebugHeaders {
-		imgdata := getImageData(ctx)
 		rw.Header().Set("X-Origin-Content-Length", strconv.Itoa(len(imgdata.Data)))
 	}
 
@@ -135,7 +139,7 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) {
 		panic(err)
 	}
 
-	ctx, downloadcancel, err := downloadImage(ctx)
+	ctx, downloadcancel, err := downloadImageCtx(ctx)
 	defer downloadcancel()
 	if err != nil {
 		if newRelicEnabled {