Procházet zdrojové kódy

Don't use io.ReadFull in io.Reader wrappers

DarthSim před 6 měsíci
rodič
revize
c95725b12f
3 změnil soubory, kde provedl 33 přidání a 11 odebrání
  1. 8 6
      asyncbuffer/buffer.go
  2. 6 5
      bufreader/bufreader.go
  3. 19 0
      ioutil/ioutil.go

+ 8 - 6
asyncbuffer/buffer.go

@@ -21,6 +21,8 @@ import (
 	"sync/atomic"
 
 	"github.com/sirupsen/logrus"
+
+	"github.com/imgproxy/imgproxy/v3/ioutil"
 )
 
 const (
@@ -155,19 +157,19 @@ func (ab *AsyncBuffer) readChunks() {
 		}
 
 		// Read data into the chunk's buffer
-		// There is no way to guarantee that ReadFull will abort on context cancellation,
+		// There is no way to guarantee that ab.r.Read will abort on context cancellation,
 		// unfortunately, this is how golang works.
-		n, err := io.ReadFull(ab.r, chunk.buf)
+		n, err := ioutil.TryReadFull(ab.r, chunk.buf)
 
 		// If it's not the EOF, we need to store the error
-		if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
+		if err != nil && err != io.EOF {
 			ab.err.Store(err)
 			chunkPool.Put(chunk)
 			return
 		}
 
 		// No bytes were read (n == 0), we can return the chunk to the pool
-		if err == io.EOF || n == 0 {
+		if n == 0 {
 			chunkPool.Put(chunk)
 			return
 		}
@@ -178,9 +180,9 @@ func (ab *AsyncBuffer) readChunks() {
 		// Store the reference to the chunk in the AsyncBuffer
 		ab.addChunk(chunk)
 
-		// We got ErrUnexpectedEOF meaning that some bytes were read, but this is the
+		// EOF at this point means that some bytes were read, but this is the
 		// end of the stream, so we can stop reading
-		if err == io.ErrUnexpectedEOF {
+		if err == io.EOF {
 			return
 		}
 	}

+ 6 - 5
bufreader/bufreader.go

@@ -4,6 +4,8 @@ package bufreader
 
 import (
 	"io"
+
+	"github.com/imgproxy/imgproxy/v3/ioutil"
 )
 
 // ReadPeeker is an interface that combines io.Reader and a method to peek at the next n bytes
@@ -72,14 +74,13 @@ func (br *Reader) fetch(need int) error {
 	}
 
 	b := make([]byte, need-len(br.buf))
-	n, err := io.ReadFull(br.r, b)
-	if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
-		return err
-	}
+	n, err := ioutil.TryReadFull(br.r, b)
 
-	if err == io.EOF || err == io.ErrUnexpectedEOF {
+	if err == io.EOF {
 		// If we reached EOF, we mark the reader as finished
 		br.finished = true
+	} else if err != nil {
+		return err
 	}
 
 	if n > 0 {

+ 19 - 0
ioutil/ioutil.go

@@ -0,0 +1,19 @@
+package ioutil
+
+import "io"
+
+// TryReadFull acts like io.ReadFull with a couple of differences:
+//  1. It doesn't return io.ErrUnexpectedEOF if the reader returns less data than requested.
+//     Instead, it returns the number of bytes read and the error from the last read operation.
+//  2. It always returns the number of bytes read regardless of the error.
+func TryReadFull(r io.Reader, b []byte) (n int, err error) {
+	var nn int
+	toRead := len(b)
+
+	for n < toRead && err == nil {
+		nn, err = r.Read(b[n:])
+		n += nn
+	}
+
+	return n, err
+}