Browse Source

use "clean" HTTP transport; Use context for downloading timeout control

DarthSim 2 years ago
parent
commit
24f4d43a0f
7 changed files with 105 additions and 51 deletions
  1. 46 26
      imagedata/download.go
  2. 43 0
      imagedata/error.go
  3. 4 4
      imagedata/image_data.go
  4. 6 3
      imagedata/read.go
  5. 0 14
      imagedata/timeout.go
  6. 1 1
      processing_handler.go
  7. 5 3
      stream.go

+ 46 - 26
imagedata/download.go

@@ -2,9 +2,11 @@ package imagedata
 
 import (
 	"compress/gzip"
+	"context"
 	"crypto/tls"
 	"fmt"
 	"io"
+	"net"
 	"net/http"
 	"net/http/cookiejar"
 	"time"
@@ -55,16 +57,25 @@ func (e *ErrorNotModified) Error() string {
 }
 
 func initDownloading() error {
-	transport := http.DefaultTransport.(*http.Transport).Clone()
-	transport.DisableCompression = true
-
-	if config.ClientKeepAliveTimeout > 0 {
-		transport.MaxIdleConns = config.Concurrency
-		transport.MaxIdleConnsPerHost = config.Concurrency
-		transport.IdleConnTimeout = time.Duration(config.ClientKeepAliveTimeout) * time.Second
-	} else {
-		transport.MaxIdleConns = 0
-		transport.MaxIdleConnsPerHost = 0
+	transport := &http.Transport{
+		Proxy: http.ProxyFromEnvironment,
+		DialContext: (&net.Dialer{
+			Timeout:   30 * time.Second,
+			KeepAlive: 30 * time.Second,
+			DualStack: true,
+		}).DialContext,
+		MaxIdleConns:          100,
+		MaxIdleConnsPerHost:   config.Concurrency + 1,
+		IdleConnTimeout:       time.Duration(config.ClientKeepAliveTimeout) * time.Second,
+		TLSHandshakeTimeout:   10 * time.Second,
+		ExpectContinueTimeout: 1 * time.Second,
+		ForceAttemptHTTP2:     true,
+		DisableCompression:    true,
+	}
+
+	if config.ClientKeepAliveTimeout <= 0 {
+		transport.MaxIdleConnsPerHost = -1
+		transport.DisableKeepAlives = true
 	}
 
 	if config.IgnoreSslVerification {
@@ -113,7 +124,6 @@ func initDownloading() error {
 	}
 
 	downloadClient = &http.Client{
-		Timeout:   time.Duration(config.DownloadTimeout) * time.Second,
 		Transport: transport,
 		CheckRedirect: func(req *http.Request, via []*http.Request) error {
 			redirects := len(via)
@@ -139,14 +149,18 @@ func headersToStore(res *http.Response) map[string]string {
 	return m
 }
 
-func BuildImageRequest(imageURL string, header http.Header, jar *cookiejar.Jar) (*http.Request, error) {
-	req, err := http.NewRequest("GET", imageURL, nil)
+func BuildImageRequest(ctx context.Context, imageURL string, header http.Header, jar *cookiejar.Jar) (*http.Request, context.CancelFunc, error) {
+	reqCtx, reqCancel := context.WithTimeout(ctx, time.Duration(config.DownloadTimeout)*time.Second)
+
+	req, err := http.NewRequestWithContext(reqCtx, "GET", imageURL, nil)
 	if err != nil {
-		return nil, ierrors.New(404, err.Error(), msgSourceImageIsUnreachable)
+		reqCancel()
+		return nil, func() {}, ierrors.New(404, err.Error(), msgSourceImageIsUnreachable)
 	}
 
 	if _, ok := enabledSchemes[req.URL.Scheme]; !ok {
-		return nil, ierrors.New(
+		reqCancel()
+		return nil, func() {}, ierrors.New(
 			404,
 			fmt.Sprintf("Unknown scheme: %s", req.URL.Scheme),
 			msgSourceImageIsUnreachable,
@@ -167,37 +181,41 @@ func BuildImageRequest(imageURL string, header http.Header, jar *cookiejar.Jar)
 		}
 	}
 
-	return req, nil
+	return req, reqCancel, nil
 }
 
 func SendRequest(req *http.Request) (*http.Response, error) {
 	res, err := downloadClient.Do(req)
 	if err != nil {
-		return nil, ierrors.New(500, checkTimeoutErr(err).Error(), msgSourceImageIsUnreachable)
+		return nil, wrapError(err)
 	}
 
 	return res, nil
 }
 
-func requestImage(imageURL string, opts DownloadOptions) (*http.Response, error) {
-	req, err := BuildImageRequest(imageURL, opts.Header, opts.CookieJar)
+func requestImage(ctx context.Context, imageURL string, opts DownloadOptions) (*http.Response, context.CancelFunc, error) {
+	req, reqCancel, err := BuildImageRequest(ctx, imageURL, opts.Header, opts.CookieJar)
 	if err != nil {
-		return nil, err
+		reqCancel()
+		return nil, func() {}, err
 	}
 
 	res, err := SendRequest(req)
 	if err != nil {
-		return nil, err
+		reqCancel()
+		return nil, func() {}, err
 	}
 
 	if res.StatusCode == http.StatusNotModified {
 		res.Body.Close()
-		return nil, &ErrorNotModified{Message: "Not Modified", Headers: headersToStore(res)}
+		reqCancel()
+		return nil, func() {}, &ErrorNotModified{Message: "Not Modified", Headers: headersToStore(res)}
 	}
 
 	if res.StatusCode != 200 {
 		body, _ := io.ReadAll(res.Body)
 		res.Body.Close()
+		reqCancel()
 
 		status := 404
 		if res.StatusCode >= 500 {
@@ -205,19 +223,21 @@ func requestImage(imageURL string, opts DownloadOptions) (*http.Response, error)
 		}
 
 		msg := fmt.Sprintf("Status: %d; %s", res.StatusCode, string(body))
-		return nil, ierrors.New(status, msg, msgSourceImageIsUnreachable)
+		return nil, func() {}, ierrors.New(status, msg, msgSourceImageIsUnreachable)
 	}
 
-	return res, nil
+	return res, reqCancel, nil
 }
 
-func download(imageURL string, opts DownloadOptions, secopts security.Options) (*ImageData, error) {
+func download(ctx context.Context, imageURL string, opts DownloadOptions, secopts security.Options) (*ImageData, error) {
 	// We use this for testing
 	if len(redirectAllRequestsTo) > 0 {
 		imageURL = redirectAllRequestsTo
 	}
 
-	res, err := requestImage(imageURL, opts)
+	res, reqCancel, err := requestImage(ctx, imageURL, opts)
+	defer reqCancel()
+
 	if res != nil {
 		defer res.Body.Close()
 	}

+ 43 - 0
imagedata/error.go

@@ -0,0 +1,43 @@
+package imagedata
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"net/http"
+
+	"github.com/imgproxy/imgproxy/v3/ierrors"
+)
+
+type httpError interface {
+	Timeout() bool
+}
+
+func wrapError(err error) error {
+	isTimeout := false
+
+	if errors.Is(err, context.Canceled) {
+		return ierrors.New(
+			499,
+			fmt.Sprintf("The image request is cancelled: %s", err),
+			msgSourceImageIsUnreachable,
+		)
+	} else if errors.Is(err, context.DeadlineExceeded) {
+		isTimeout = true
+	} else if httpErr, ok := err.(httpError); ok {
+		isTimeout = httpErr.Timeout()
+	}
+
+	if !isTimeout {
+		return err
+	}
+
+	ierr := ierrors.New(
+		http.StatusGatewayTimeout,
+		fmt.Sprintf("The image request timed out: %s", err),
+		msgSourceImageIsUnreachable,
+	)
+	ierr.Unexpected = true
+
+	return ierr
+}

+ 4 - 4
imagedata/image_data.go

@@ -70,7 +70,7 @@ func loadWatermark() (err error) {
 	}
 
 	if len(config.WatermarkURL) > 0 {
-		Watermark, err = Download(config.WatermarkURL, "watermark", DownloadOptions{Header: nil, CookieJar: nil}, security.DefaultOptions())
+		Watermark, err = Download(context.Background(), config.WatermarkURL, "watermark", DownloadOptions{Header: nil, CookieJar: nil}, security.DefaultOptions())
 		return
 	}
 
@@ -84,7 +84,7 @@ func loadFallbackImage() (err error) {
 	case len(config.FallbackImagePath) > 0:
 		FallbackImage, err = FromFile(config.FallbackImagePath, "fallback image", security.DefaultOptions())
 	case len(config.FallbackImageURL) > 0:
-		FallbackImage, err = Download(config.FallbackImageURL, "fallback image", DownloadOptions{Header: nil, CookieJar: nil}, security.DefaultOptions())
+		FallbackImage, err = Download(context.Background(), config.FallbackImageURL, "fallback image", DownloadOptions{Header: nil, CookieJar: nil}, security.DefaultOptions())
 	default:
 		FallbackImage, err = nil, nil
 	}
@@ -130,8 +130,8 @@ func FromFile(path, desc string, secopts security.Options) (*ImageData, error) {
 	return imgdata, nil
 }
 
-func Download(imageURL, desc string, opts DownloadOptions, secopts security.Options) (*ImageData, error) {
-	imgdata, err := download(imageURL, opts, secopts)
+func Download(ctx context.Context, imageURL, desc string, opts DownloadOptions, secopts security.Options) (*ImageData, error) {
+	imgdata, err := download(ctx, imageURL, opts, secopts)
 	if err != nil {
 		if nmErr, ok := err.(*ErrorNotModified); ok {
 			nmErr.Message = fmt.Sprintf("Can't download %s: %s", desc, nmErr.Message)

+ 6 - 3
imagedata/read.go

@@ -42,13 +42,14 @@ func readAndCheckImage(r io.Reader, contentLength int, secopts security.Options)
 			return nil, ErrSourceImageTypeNotSupported
 		}
 
-		return nil, checkTimeoutErr(err)
+		return nil, wrapError(err)
 	}
 
 	if err = security.CheckDimensions(meta.Width(), meta.Height(), 1, secopts); err != nil {
 		buf.Reset()
 		cancel()
-		return nil, err
+
+		return nil, wrapError(err)
 	}
 
 	if contentLength > buf.Cap() {
@@ -56,8 +57,10 @@ func readAndCheckImage(r io.Reader, contentLength int, secopts security.Options)
 	}
 
 	if err = br.Flush(); err != nil {
+		buf.Reset()
 		cancel()
-		return nil, checkTimeoutErr(err)
+
+		return nil, wrapError(err)
 	}
 
 	return &ImageData{

+ 0 - 14
imagedata/timeout.go

@@ -1,14 +0,0 @@
-package imagedata
-
-import "errors"
-
-type httpError interface {
-	Timeout() bool
-}
-
-func checkTimeoutErr(err error) error {
-	if httpErr, ok := err.(httpError); ok && httpErr.Timeout() {
-		return errors.New("The image request timed out")
-	}
-	return err
-}

+ 1 - 1
processing_handler.go

@@ -303,7 +303,7 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) {
 			checkErr(ctx, "download", err)
 		}
 
-		return imagedata.Download(imageURL, "source image", downloadOpts, po.SecurityOptions)
+		return imagedata.Download(ctx, imageURL, "source image", downloadOpts, po.SecurityOptions)
 	}()
 
 	if err == nil {

+ 5 - 3
stream.go

@@ -71,14 +71,16 @@ func streamOriginImage(ctx context.Context, reqID string, r *http.Request, rw ht
 		checkErr(ctx, "streaming", err)
 	}
 
-	req, err := imagedata.BuildImageRequest(imageURL, imgRequestHeader, cookieJar)
+	req, reqCancel, err := imagedata.BuildImageRequest(r.Context(), imageURL, imgRequestHeader, cookieJar)
+	defer reqCancel()
 	checkErr(ctx, "streaming", err)
 
 	res, err := imagedata.SendRequest(req)
+	if res != nil {
+		defer res.Body.Close()
+	}
 	checkErr(ctx, "streaming", err)
 
-	defer res.Body.Close()
-
 	for _, k := range streamRespHeaders {
 		vv := res.Header.Values(k)
 		for _, v := range vv {