|
@@ -1,7 +1,6 @@
|
|
|
package main
|
|
|
|
|
|
import (
|
|
|
- "bytes"
|
|
|
"context"
|
|
|
"crypto/tls"
|
|
|
"fmt"
|
|
@@ -21,7 +20,6 @@ import (
|
|
|
|
|
|
var (
|
|
|
downloadClient *http.Client
|
|
|
- imageTypeCtxKey = ctxKey("imageType")
|
|
|
imageDataCtxKey = ctxKey("imageData")
|
|
|
|
|
|
errSourceDimensionsTooBig = newError(422, "Source image dimensions are too big", "Invalid source image")
|
|
@@ -34,8 +32,21 @@ const msgSourceImageIsUnreachable = "Source image is unreachable"
|
|
|
|
|
|
var downloadBufPool *bufPool
|
|
|
|
|
|
+type imageData struct {
|
|
|
+ Data []byte
|
|
|
+ Type imageType
|
|
|
+
|
|
|
+ cancel context.CancelFunc
|
|
|
+}
|
|
|
+
|
|
|
+func (d *imageData) Close() {
|
|
|
+ if d.cancel != nil {
|
|
|
+ d.cancel()
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
type limitReader struct {
|
|
|
- r io.ReadCloser
|
|
|
+ r io.Reader
|
|
|
left int
|
|
|
}
|
|
|
|
|
@@ -50,10 +61,6 @@ func (lr *limitReader) Read(p []byte) (n int, err error) {
|
|
|
return
|
|
|
}
|
|
|
|
|
|
-func (lr *limitReader) Close() error {
|
|
|
- return lr.r.Close()
|
|
|
-}
|
|
|
-
|
|
|
func initDownloading() {
|
|
|
transport := &http.Transport{
|
|
|
Proxy: http.ProxyFromEnvironment,
|
|
@@ -120,45 +127,56 @@ func checkTypeAndDimensions(r io.Reader) (imageType, error) {
|
|
|
return imgtype, nil
|
|
|
}
|
|
|
|
|
|
-func readAndCheckImage(ctx context.Context, res *http.Response) (context.Context, context.CancelFunc, error) {
|
|
|
- var contentLength int
|
|
|
-
|
|
|
- if res.ContentLength > 0 {
|
|
|
- contentLength = int(res.ContentLength)
|
|
|
-
|
|
|
- if conf.MaxSrcFileSize > 0 && contentLength > conf.MaxSrcFileSize {
|
|
|
- return ctx, func() {}, errSourceFileTooBig
|
|
|
- }
|
|
|
+func readAndCheckImage(r io.Reader, contentLength int) (*imageData, error) {
|
|
|
+ if conf.MaxSrcFileSize > 0 && contentLength > conf.MaxSrcFileSize {
|
|
|
+ return nil, errSourceFileTooBig
|
|
|
}
|
|
|
|
|
|
buf := downloadBufPool.Get(contentLength)
|
|
|
- cancel := func() {
|
|
|
- downloadBufPool.Put(buf)
|
|
|
+ cancel := func() { downloadBufPool.Put(buf) }
|
|
|
+
|
|
|
+ if conf.MaxSrcFileSize > 0 {
|
|
|
+ r = &limitReader{r: r, left: conf.MaxSrcFileSize}
|
|
|
}
|
|
|
|
|
|
- body := res.Body
|
|
|
+ imgtype, err := checkTypeAndDimensions(io.TeeReader(r, buf))
|
|
|
+ if err != nil {
|
|
|
+ cancel()
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
|
|
|
- if conf.MaxSrcFileSize > 0 {
|
|
|
- body = &limitReader{r: body, left: conf.MaxSrcFileSize}
|
|
|
+ if _, err = buf.ReadFrom(r); err != nil {
|
|
|
+ cancel()
|
|
|
+ return nil, newError(404, err.Error(), msgSourceImageIsUnreachable)
|
|
|
}
|
|
|
|
|
|
- imgtype, err := checkTypeAndDimensions(io.TeeReader(body, buf))
|
|
|
+ return &imageData{buf.Bytes(), imgtype, cancel}, nil
|
|
|
+}
|
|
|
+
|
|
|
+func requestImage(imageURL string) (*http.Response, error) {
|
|
|
+ req, err := http.NewRequest("GET", imageURL, nil)
|
|
|
if err != nil {
|
|
|
- return ctx, cancel, err
|
|
|
+ return nil, newError(404, err.Error(), msgSourceImageIsUnreachable).MarkAsUnexpected()
|
|
|
}
|
|
|
|
|
|
- if _, err = buf.ReadFrom(body); err != nil {
|
|
|
- return ctx, cancel, newError(404, err.Error(), msgSourceImageIsUnreachable)
|
|
|
+ req.Header.Set("User-Agent", conf.UserAgent)
|
|
|
+
|
|
|
+ res, err := downloadClient.Do(req)
|
|
|
+ if err != nil {
|
|
|
+ return res, newError(404, err.Error(), msgSourceImageIsUnreachable).MarkAsUnexpected()
|
|
|
}
|
|
|
|
|
|
- ctx = context.WithValue(ctx, imageTypeCtxKey, imgtype)
|
|
|
- ctx = context.WithValue(ctx, imageDataCtxKey, buf)
|
|
|
+ if res.StatusCode != 200 {
|
|
|
+ body, _ := ioutil.ReadAll(res.Body)
|
|
|
+ msg := fmt.Sprintf("Can't download image; Status: %d; %s", res.StatusCode, string(body))
|
|
|
+ return res, newError(404, msg, msgSourceImageIsUnreachable).MarkAsUnexpected()
|
|
|
+ }
|
|
|
|
|
|
- return ctx, cancel, nil
|
|
|
+ return res, nil
|
|
|
}
|
|
|
|
|
|
func downloadImage(ctx context.Context) (context.Context, context.CancelFunc, error) {
|
|
|
- url := getImageURL(ctx)
|
|
|
+ imageURL := getImageURL(ctx)
|
|
|
|
|
|
if newRelicEnabled {
|
|
|
newRelicCancel := startNewRelicSegment(ctx, "Downloading image")
|
|
@@ -169,34 +187,24 @@ func downloadImage(ctx context.Context) (context.Context, context.CancelFunc, er
|
|
|
defer startPrometheusDuration(prometheusDownloadDuration)()
|
|
|
}
|
|
|
|
|
|
- req, err := http.NewRequest("GET", url, nil)
|
|
|
- if err != nil {
|
|
|
- return ctx, func() {}, newError(404, err.Error(), msgSourceImageIsUnreachable).MarkAsUnexpected()
|
|
|
- }
|
|
|
-
|
|
|
- req.Header.Set("User-Agent", conf.UserAgent)
|
|
|
-
|
|
|
- res, err := downloadClient.Do(req)
|
|
|
+ res, err := requestImage(imageURL)
|
|
|
if res != nil {
|
|
|
defer res.Body.Close()
|
|
|
}
|
|
|
if err != nil {
|
|
|
- return ctx, func() {}, newError(404, err.Error(), msgSourceImageIsUnreachable).MarkAsUnexpected()
|
|
|
+ return ctx, func() {}, err
|
|
|
}
|
|
|
|
|
|
- if res.StatusCode != 200 {
|
|
|
- body, _ := ioutil.ReadAll(res.Body)
|
|
|
- msg := fmt.Sprintf("Can't download image; Status: %d; %s", res.StatusCode, string(body))
|
|
|
- return ctx, func() {}, newError(404, msg, msgSourceImageIsUnreachable).MarkAsUnexpected()
|
|
|
+ imgdata, err := readAndCheckImage(res.Body, int(res.ContentLength))
|
|
|
+ if err != nil {
|
|
|
+ return ctx, func() {}, err
|
|
|
}
|
|
|
|
|
|
- return readAndCheckImage(ctx, res)
|
|
|
-}
|
|
|
+ ctx = context.WithValue(ctx, imageDataCtxKey, imgdata)
|
|
|
|
|
|
-func getImageType(ctx context.Context) imageType {
|
|
|
- return ctx.Value(imageTypeCtxKey).(imageType)
|
|
|
+ return ctx, imgdata.Close, err
|
|
|
}
|
|
|
|
|
|
-func getImageData(ctx context.Context) *bytes.Buffer {
|
|
|
- return ctx.Value(imageDataCtxKey).(*bytes.Buffer)
|
|
|
+func getImageData(ctx context.Context) *imageData {
|
|
|
+ return ctx.Value(imageDataCtxKey).(*imageData)
|
|
|
}
|