Browse Source

Add expected data length to asyncbuffer to allow seeking from the end without waiting for reader

DarthSim 1 tháng trước
mục cha
commit
0ddefe1b85
4 tập tin đã thay đổi với 95 bổ sung và 30 xóa
  1. 31 12
      asyncbuffer/buffer.go
  2. 57 14
      asyncbuffer/buffer_test.go
  3. 6 3
      asyncbuffer/reader.go
  4. 1 1
      imagedata/factory.go

+ 31 - 12
asyncbuffer/buffer.go

@@ -58,13 +58,14 @@ var chunkPool = sync.Pool{
 // AsyncBuffer is a wrapper around io.Reader that reads data in chunks
 // in background and allows reading from synchronously.
 type AsyncBuffer struct {
-	r io.ReadCloser // Upstream reader
+	r       io.ReadCloser // Upstream reader
+	dataLen int           // Expected length of the data in r, <= 0 means unknown length
 
 	chunks []*byteChunk // References to the chunks read from the upstream reader
 	mu     sync.RWMutex // Mutex on chunks slice
 
-	err atomic.Value // Error that occurred during reading
-	len atomic.Int64 // Total length of the data read
+	err       atomic.Value // Error that occurred during reading
+	bytesRead atomic.Int64 // Total length of the data read
 
 	finished atomic.Bool // Indicates that the buffer has finished reading
 	closed   atomic.Bool // Indicates that the buffer was closed
@@ -78,9 +79,14 @@ type AsyncBuffer struct {
 
 // New creates a new AsyncBuffer that reads from the given io.ReadCloser in background
 // and closes it when finished.
-func New(r io.ReadCloser, finishFn ...context.CancelFunc) *AsyncBuffer {
+//
+//	r - io.ReadCloser to read data from
+//	dataLen - expected length of the data in r, <= 0 means unknown length
+//	finishFn - optional functions to call when the buffer is finished reading
+func New(r io.ReadCloser, dataLen int, finishFn ...context.CancelFunc) *AsyncBuffer {
 	ab := &AsyncBuffer{
 		r:         r,
+		dataLen:   dataLen,
 		paused:    NewLatch(),
 		chunkCond: NewCond(),
 		finishFn:  finishFn,
@@ -102,7 +108,8 @@ func (ab *AsyncBuffer) callFinishFn() {
 	})
 }
 
-// addChunk adds a new chunk to the AsyncBuffer, increments len and signals that a chunk is ready
+// addChunk adds a new chunk to the AsyncBuffer, increments bytesRead
+// and signals that a chunk is ready
 func (ab *AsyncBuffer) addChunk(chunk *byteChunk) {
 	ab.mu.Lock()
 	defer ab.mu.Unlock()
@@ -115,7 +122,7 @@ func (ab *AsyncBuffer) addChunk(chunk *byteChunk) {
 
 	// Store the chunk, increase chunk size, increase length of the data read
 	ab.chunks = append(ab.chunks, chunk)
-	ab.len.Add(int64(len(chunk.data)))
+	ab.bytesRead.Add(int64(len(chunk.data)))
 
 	ab.chunkCond.Tick()
 }
@@ -132,14 +139,26 @@ func (ab *AsyncBuffer) readChunks() {
 			logrus.WithField("source", "asyncbuffer.AsyncBuffer.readChunks").Warningf("error closing upstream reader: %s", err)
 		}
 
+		if ab.bytesRead.Load() < int64(ab.dataLen) && ab.err.Load() == nil {
+			// If the reader has finished reading and we have not read enough data,
+			// set err to io.ErrUnexpectedEOF
+			ab.err.Store(io.ErrUnexpectedEOF)
+		}
+
 		ab.callFinishFn()
 	}()
 
+	r := ab.r.(io.Reader)
+	if ab.dataLen > 0 {
+		// If the data length is known, we read only that much data
+		r = io.LimitReader(r, int64(ab.dataLen))
+	}
+
 	// Stop reading if the reader is closed
 	for !ab.closed.Load() {
 		// In case we are trying to read data beyond threshold and we are paused,
 		// wait for pause to be released.
-		if ab.len.Load() >= pauseThreshold {
+		if ab.bytesRead.Load() >= pauseThreshold {
 			ab.paused.Wait()
 
 			// If the reader has been closed while waiting, we can stop reading
@@ -157,9 +176,9 @@ func (ab *AsyncBuffer) readChunks() {
 		}
 
 		// Read data into the chunk's buffer
-		// There is no way to guarantee that ab.r.Read will abort on context cancellation,
+		// There is no way to guarantee that r.Read will abort on context cancellation,
 		// unfortunately, this is how golang works.
-		n, err := ioutil.TryReadFull(ab.r, chunk.buf)
+		n, err := ioutil.TryReadFull(r, chunk.buf)
 
 		// If it's not the EOF, we need to store the error
 		if err != nil && err != io.EOF {
@@ -214,7 +233,7 @@ func (ab *AsyncBuffer) offsetAvailable(off int64) (bool, error) {
 
 	// In case the offset falls within the already read chunks, we can return immediately,
 	// even if error has occurred in the future
-	if off < ab.len.Load() {
+	if off < ab.bytesRead.Load() {
 		return true, nil
 	}
 
@@ -267,7 +286,7 @@ func (ab *AsyncBuffer) Wait() (int, error) {
 
 		// In case the reader is finished reading, we can return immediately
 		if ab.finished.Load() {
-			return int(ab.len.Load()), ab.Error()
+			return int(ab.bytesRead.Load()), ab.Error()
 		}
 
 		// Lock until the next chunk is ready
@@ -296,7 +315,7 @@ func (ab *AsyncBuffer) Error() error {
 // (eg. offset is beyond the end of the stream).
 func (ab *AsyncBuffer) readChunkAt(p []byte, off int64) int {
 	// If the chunk is not available, we return 0
-	if off >= ab.len.Load() {
+	if off >= ab.bytesRead.Load() {
 		return 0
 	}
 

+ 57 - 14
asyncbuffer/buffer_test.go

@@ -109,7 +109,7 @@ func generateSourceData(t *testing.T, size int) ([]byte, io.ReadSeekCloser) {
 func TestAsyncBufferReadAt(t *testing.T) {
 	// Let's use source buffer which is 4.5 chunks long
 	source, bytesReader := generateSourceData(t, chunkSize*4+halfChunkSize)
-	asyncBuffer := New(bytesReader)
+	asyncBuffer := New(bytesReader, -1)
 	defer asyncBuffer.Close()
 
 	asyncBuffer.Wait() // Wait for all chunks to be read since we're going to read all data
@@ -169,7 +169,7 @@ func TestAsyncBufferReadAt(t *testing.T) {
 // TestAsyncBufferRead tests reading from AsyncBuffer using ReadAt method
 func TestAsyncBufferReadAtSmallBuffer(t *testing.T) {
 	source, bytesReader := generateSourceData(t, 20)
-	asyncBuffer := New(bytesReader)
+	asyncBuffer := New(bytesReader, -1)
 	defer asyncBuffer.Close()
 
 	// First, let's read all the data
@@ -199,7 +199,7 @@ func TestAsyncBufferReader(t *testing.T) {
 	source, bytesReader := generateSourceData(t, chunkSize*4+halfChunkSize)
 
 	// Create an AsyncBuffer with the byte slice
-	asyncBuffer := New(bytesReader)
+	asyncBuffer := New(bytesReader, -1)
 	defer asyncBuffer.Close()
 
 	// Let's wait for all chunks to be read
@@ -267,7 +267,7 @@ func TestAsyncBufferClose(t *testing.T) {
 	_, bytesReader := generateSourceData(t, chunkSize*4+halfChunkSize)
 
 	// Create an AsyncBuffer with the byte slice
-	asyncBuffer := New(bytesReader)
+	asyncBuffer := New(bytesReader, -1)
 
 	reader1 := asyncBuffer.Reader()
 	reader2 := asyncBuffer.Reader()
@@ -294,7 +294,7 @@ func TestAsyncBufferReadAtErrAtSomePoint(t *testing.T) {
 	// Let's use source buffer which is 4.5 chunks long
 	source, bytesReader := generateSourceData(t, chunkSize*4+halfChunkSize)
 	slowReader := &erraticReader{reader: bytesReader, failAt: chunkSize*3 + 5} // fails at last chunk
-	asyncBuffer := New(slowReader)
+	asyncBuffer := New(slowReader, -1)
 	defer asyncBuffer.Close()
 
 	// Let's wait for all chunks to be read
@@ -327,7 +327,7 @@ func TestAsyncBufferReadAsync(t *testing.T) {
 	// Let's use source buffer which is 4.5 chunks long
 	source, bytesReader := generateSourceData(t, chunkSize*3)
 	blockingReader := newBlockingReader(bytesReader)
-	asyncBuffer := New(blockingReader)
+	asyncBuffer := New(blockingReader, -1)
 	defer asyncBuffer.Close()
 
 	// flush the first chunk to allow reading
@@ -367,11 +367,54 @@ func TestAsyncBufferReadAsync(t *testing.T) {
 	assert.Equal(t, 0, n)
 }
 
+// TestAsyncBufferWithDataLenAndExactReaderSize tests that AsyncBuffer doesn't
+// return an error when the expected data length is set and matches the reader size
+func TestAsyncBufferWithDataLenAndExactReaderSize(t *testing.T) {
+	source, bytesReader := generateSourceData(t, chunkSize*4+halfChunkSize)
+	asyncBuffer := New(bytesReader, len(source))
+	defer asyncBuffer.Close()
+
+	// Let's wait for all chunks to be read
+	size, err := asyncBuffer.Wait()
+	require.NoError(t, err, "AsyncBuffer failed to wait for all chunks")
+	assert.Equal(t, len(source), size)
+}
+
+// TestAsyncBufferWithDataLenAndShortReaderSize tests that AsyncBuffer returns
+// io.ErrUnexpectedEOF when the expected data length is set and the reader size
+// is shorter than the expected data length
+func TestAsyncBufferWithDataLenAndShortReaderSize(t *testing.T) {
+	source, bytesReader := generateSourceData(t, chunkSize*4+halfChunkSize)
+	asyncBuffer := New(bytesReader, len(source)+100) // 100 bytes more than the source
+	defer asyncBuffer.Close()
+
+	// Let's wait for all chunks to be read
+	size, err := asyncBuffer.Wait()
+	require.Equal(t, len(source), size)
+	require.ErrorIs(t, err, io.ErrUnexpectedEOF,
+		"AsyncBuffer should return io.ErrUnexpectedEOF when data length is longer than reader size")
+}
+
+// TestAsyncBufferWithDataLenAndLongerReaderSize tests that AsyncBuffer doesn't
+// read more data than specified by the expected data length and doesn't return an error
+// when the reader size is longer than the expected data length
+func TestAsyncBufferWithDataLenAndLongerReaderSize(t *testing.T) {
+	source, bytesReader := generateSourceData(t, chunkSize*4+halfChunkSize)
+	asyncBuffer := New(bytesReader, len(source)-100) // 100 bytes less than the source
+	defer asyncBuffer.Close()
+
+	// Let's wait for all chunks to be read
+	size, err := asyncBuffer.Wait()
+	require.NoError(t, err, "AsyncBuffer failed to wait for all chunks")
+	assert.Equal(t, len(source)-100, size,
+		"AsyncBuffer should read only the specified amount of data when data length is set")
+}
+
 // TestAsyncBufferReadAllCompability tests that ReadAll methods works as expected
 func TestAsyncBufferReadAllCompability(t *testing.T) {
 	source, err := os.ReadFile("../testdata/test1.jpg")
 	require.NoError(t, err)
-	asyncBuffer := New(nopSeekCloser{bytes.NewReader(source)})
+	asyncBuffer := New(nopSeekCloser{bytes.NewReader(source)}, -1)
 	defer asyncBuffer.Close()
 
 	b, err := io.ReadAll(asyncBuffer.Reader())
@@ -381,7 +424,7 @@ func TestAsyncBufferReadAllCompability(t *testing.T) {
 
 func TestAsyncBufferThreshold(t *testing.T) {
 	_, bytesReader := generateSourceData(t, pauseThreshold*2)
-	asyncBuffer := New(bytesReader)
+	asyncBuffer := New(bytesReader, -1)
 	defer asyncBuffer.Close()
 
 	target := make([]byte, chunkSize)
@@ -391,12 +434,12 @@ func TestAsyncBufferThreshold(t *testing.T) {
 
 	// Ensure that buffer hits the pause threshold
 	require.Eventually(t, func() bool {
-		return asyncBuffer.len.Load() >= pauseThreshold
+		return asyncBuffer.bytesRead.Load() >= pauseThreshold
 	}, 300*time.Millisecond, 10*time.Millisecond)
 
 	// Ensure that buffer never reaches the end of the stream
 	require.Never(t, func() bool {
-		return asyncBuffer.len.Load() >= pauseThreshold*2-1
+		return asyncBuffer.bytesRead.Load() >= pauseThreshold*2-1
 	}, 300*time.Millisecond, 10*time.Millisecond)
 
 	// Let's hit the pause threshold
@@ -407,7 +450,7 @@ func TestAsyncBufferThreshold(t *testing.T) {
 
 	// Ensure that buffer never reaches the end of the stream
 	require.Never(t, func() bool {
-		return asyncBuffer.len.Load() >= pauseThreshold*2-1
+		return asyncBuffer.bytesRead.Load() >= pauseThreshold*2-1
 	}, 300*time.Millisecond, 10*time.Millisecond)
 
 	// Let's hit the pause threshold
@@ -421,13 +464,13 @@ func TestAsyncBufferThreshold(t *testing.T) {
 
 	// Ensure that buffer hits the end of the stream
 	require.Eventually(t, func() bool {
-		return asyncBuffer.len.Load() >= pauseThreshold*2
+		return asyncBuffer.bytesRead.Load() >= pauseThreshold*2
 	}, 300*time.Millisecond, 10*time.Millisecond)
 }
 
 func TestAsyncBufferThresholdInstantBeyondAccess(t *testing.T) {
 	_, bytesReader := generateSourceData(t, pauseThreshold*2)
-	asyncBuffer := New(bytesReader)
+	asyncBuffer := New(bytesReader, -1)
 	defer asyncBuffer.Close()
 
 	target := make([]byte, chunkSize)
@@ -437,6 +480,6 @@ func TestAsyncBufferThresholdInstantBeyondAccess(t *testing.T) {
 
 	// Ensure that buffer hits the end of the stream
 	require.Eventually(t, func() bool {
-		return asyncBuffer.len.Load() >= pauseThreshold*2
+		return asyncBuffer.bytesRead.Load() >= pauseThreshold*2
 	}, 300*time.Millisecond, 10*time.Millisecond)
 }

+ 6 - 3
asyncbuffer/reader.go

@@ -32,9 +32,12 @@ func (r *Reader) Seek(offset int64, whence int) (int64, error) {
 		r.pos += offset
 
 	case io.SeekEnd:
-		size, err := r.ab.Wait()
-		if err != nil {
-			return 0, err
+		size := r.ab.dataLen
+		if size <= 0 {
+			var err error
+			if size, err = r.ab.Wait(); err != nil {
+				return 0, err
+			}
 		}
 
 		r.pos = int64(size) + offset

+ 1 - 1
imagedata/factory.go

@@ -150,7 +150,7 @@ func downloadAsync(ctx context.Context, imageURL string, opts DownloadOptions) (
 		return nil, h, err
 	}
 
-	b := asyncbuffer.New(res.Body, opts.DownloadFinished)
+	b := asyncbuffer.New(res.Body, int(res.ContentLength), opts.DownloadFinished)
 
 	format, err := imagetype.Detect(b.Reader())
 	if err != nil {