Переглянути джерело

Wrap custom transport bodies with a context reader

DarthSim 2 роки тому
батько
коміт
1403886840

+ 44 - 0
ctxreader/ctxreader.go

@@ -0,0 +1,44 @@
+package ctxreader
+
+import (
+	"context"
+	"io"
+	"sync"
+	"sync/atomic"
+)
+
+type ctxReader struct {
+	r         io.ReadCloser
+	err       atomic.Value
+	closeOnce sync.Once
+}
+
+func (r *ctxReader) Read(p []byte) (int, error) {
+	if err := r.err.Load(); err != nil {
+		return 0, err.(error)
+	}
+	return r.r.Read(p)
+}
+
+func (r *ctxReader) Close() (err error) {
+	r.closeOnce.Do(func() { err = r.r.Close() })
+	return
+}
+
+func New(ctx context.Context, r io.ReadCloser, closeOnDone bool) io.ReadCloser {
+	if ctx.Done() == nil {
+		return r
+	}
+
+	ctxr := ctxReader{r: r}
+
+	go func(ctx context.Context) {
+		<-ctx.Done()
+		ctxr.err.Store(ctx.Err())
+		if closeOnDone {
+			ctxr.closeOnce.Do(func() { ctxr.r.Close() })
+		}
+	}(ctx)
+
+	return &ctxr
+}

+ 114 - 0
ctxreader/ctxreader_test.go

@@ -0,0 +1,114 @@
+package ctxreader
+
+import (
+	"context"
+	"crypto/rand"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+	"github.com/stretchr/testify/suite"
+)
+
+type testReader struct {
+	closed bool
+}
+
+func (r *testReader) Read(p []byte) (int, error) {
+	return rand.Reader.Read(p)
+}
+
+func (r *testReader) Close() error {
+	r.closed = true
+	return nil
+}
+
+type CtxReaderTestSuite struct {
+	suite.Suite
+}
+
+func (s *CtxReaderTestSuite) TestReadUntilCanceled() {
+	ctx, cancel := context.WithCancel(context.Background())
+
+	r := New(ctx, &testReader{}, false)
+	p := make([]byte, 1024)
+
+	_, err := r.Read(p)
+	require.Nil(s.T(), err)
+
+	cancel()
+	time.Sleep(time.Second)
+
+	_, err = r.Read(p)
+	require.Equal(s.T(), err, context.Canceled)
+}
+
+func (s *CtxReaderTestSuite) TestReturnOriginalOnBackgroundContext() {
+	rr := &testReader{}
+	r := New(context.Background(), rr, false)
+
+	require.Equal(s.T(), rr, r)
+}
+
+func (s *CtxReaderTestSuite) TestClose() {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	rr := &testReader{}
+	New(ctx, rr, true).Close()
+
+	require.True(s.T(), rr.closed)
+}
+
+func (s *CtxReaderTestSuite) TestCloseOnCancel() {
+	ctx, cancel := context.WithCancel(context.Background())
+
+	rr := &testReader{}
+	New(ctx, rr, true)
+
+	cancel()
+	time.Sleep(time.Second)
+
+	require.True(s.T(), rr.closed)
+}
+
+func (s *CtxReaderTestSuite) TestDontCloseOnCancel() {
+	ctx, cancel := context.WithCancel(context.Background())
+
+	rr := &testReader{}
+	New(ctx, rr, false)
+
+	cancel()
+	time.Sleep(time.Second)
+
+	require.False(s.T(), rr.closed)
+}
+
+func TestCtxReader(t *testing.T) {
+	suite.Run(t, new(CtxReaderTestSuite))
+}
+
+func BenchmarkRawReader(b *testing.B) {
+	r := testReader{}
+
+	b.ResetTimer()
+
+	p := make([]byte, 1024)
+	for i := 0; i < b.N; i++ {
+		r.Read(p)
+	}
+}
+
+func BenchmarkCtxReader(b *testing.B) {
+	ctx, cancel := context.WithTimeout(context.Background(), time.Hour)
+	defer cancel()
+
+	r := New(ctx, &testReader{}, true)
+
+	b.ResetTimer()
+
+	p := make([]byte, 1024)
+	for i := 0; i < b.N; i++ {
+		r.Read(p)
+	}
+}

+ 2 - 1
transport/fs/fs.go

@@ -14,6 +14,7 @@ import (
 	"strings"
 
 	"github.com/imgproxy/imgproxy/v3/config"
+	"github.com/imgproxy/imgproxy/v3/ctxreader"
 	"github.com/imgproxy/imgproxy/v3/httprange"
 )
 
@@ -103,7 +104,7 @@ func (t transport) RoundTrip(req *http.Request) (resp *http.Response, err error)
 		ProtoMinor:    0,
 		Header:        header,
 		ContentLength: size,
-		Body:          body,
+		Body:          ctxreader.New(req.Context(), body, true),
 		Close:         true,
 		Request:       req,
 	}, nil

+ 2 - 1
transport/gcs/gcs.go

@@ -12,6 +12,7 @@ import (
 	"google.golang.org/api/option"
 
 	"github.com/imgproxy/imgproxy/v3/config"
+	"github.com/imgproxy/imgproxy/v3/ctxreader"
 	"github.com/imgproxy/imgproxy/v3/httprange"
 )
 
@@ -141,7 +142,7 @@ func (t transport) RoundTrip(req *http.Request) (*http.Response, error) {
 		ProtoMinor:    0,
 		Header:        header,
 		ContentLength: reader.Attrs.Size,
-		Body:          reader,
+		Body:          ctxreader.New(req.Context(), reader, true),
 		Close:         true,
 		Request:       req,
 	}, nil

+ 2 - 1
transport/swift/swift.go

@@ -12,6 +12,7 @@ import (
 	"github.com/ncw/swift/v2"
 
 	"github.com/imgproxy/imgproxy/v3/config"
+	"github.com/imgproxy/imgproxy/v3/ctxreader"
 )
 
 type transport struct {
@@ -105,7 +106,7 @@ func (t transport) RoundTrip(req *http.Request) (resp *http.Response, err error)
 		ProtoMajor: 1,
 		ProtoMinor: 0,
 		Header:     header,
-		Body:       object,
+		Body:       ctxreader.New(req.Context(), object, true),
 		Close:      true,
 		Request:    req,
 	}, nil