
ticker + fixes

Viktor Sokolov, 2 months ago
parent
commit
0e32ed9a37
5 files changed, 320 additions and 131 deletions
  1. asyncbuffer/buffer.go (+52, −130)
  2. asyncbuffer/buffer_test.go (+0, −1)
  3. asyncbuffer/reader.go (+51, −0)
  4. asyncbuffer/ticker.go (+59, −0)
  5. asyncbuffer/ticker_test.go (+158, −0)

+ 52 - 130
asyncbuffer/buffer.go

@@ -15,7 +15,6 @@ package asyncbuffer
 
 import (
 	"errors"
-	"fmt"
 	"io"
 	"sync"
 	"sync/atomic"
@@ -57,30 +56,24 @@ type AsyncBuffer struct {
 	r io.ReadCloser // Upstream reader
 
 	chunks []*byteChunk // References to the chunks read from the upstream reader
+	mu     sync.RWMutex // Mutex on chunks slice
 
-	err      atomic.Value // Error that occurred during reading
-	finished atomic.Bool  // Indicates that the reader has finished reading
-	len      atomic.Int64 // Total length of the data read
-	closed   atomic.Bool  // Indicates that the reader was closed
-	paused   *Latch       // Paused reader does not read data beyond threshold
+	err atomic.Value // Error that occurred during reading
+	len atomic.Int64 // Total length of the data read
 
-	mu             sync.RWMutex  // Mutex on chunks slice
-	newChunkSignal chan struct{} // Tick-tock channel that indicates that a new chunk is ready
-}
+	finished atomic.Bool // Indicates that the buffer has finished reading
+	closed   atomic.Bool // Indicates that the buffer was closed
 
-// Underlying Reader that provides io.ReadSeeker interface for the actual data reading
-// What is the purpose of this Reader?
-type Reader struct {
-	ab  *AsyncBuffer
-	pos int64
+	paused *Latch  // Paused buffer does not read data beyond threshold
+	ticker *Ticker // Ticker that signals when a new chunk is ready
 }
 
 // FromReader creates a new AsyncBuffer that reads from the given io.ReadCloser in background
 func FromReader(r io.ReadCloser) *AsyncBuffer {
 	ab := &AsyncBuffer{
-		r:              r,
-		newChunkSignal: make(chan struct{}),
-		paused:         NewLatch(),
+		r:      r,
+		paused: NewLatch(),
+		ticker: NewTicker(),
 	}
 
 	go ab.readChunks()
@@ -88,19 +81,11 @@ func FromReader(r io.ReadCloser) *AsyncBuffer {
 	return ab
 }
 
-// getNewChunkSignal returns the channel that signals when a new chunk is ready
-// Lock is required to read the channel, so it is not closed while reading
-func (ab *AsyncBuffer) getNewChunkSignal() chan struct{} {
-	ab.mu.RLock()
-	defer ab.mu.RUnlock()
-
-	return ab.newChunkSignal
-}
-
 // addChunk adds a new chunk to the AsyncBuffer, increments len and signals that a chunk is ready
 func (ab *AsyncBuffer) addChunk(chunk *byteChunk) {
 	ab.mu.Lock()
 	defer ab.mu.Unlock()
+
 	if ab.closed.Load() {
 		// If the reader is closed, we return the chunk to the pool
 		chunkPool.Put(chunk)
@@ -111,47 +96,37 @@ func (ab *AsyncBuffer) addChunk(chunk *byteChunk) {
 	ab.chunks = append(ab.chunks, chunk)
 	ab.len.Add(int64(len(chunk.data)))
 
-	// Signal that a chunk is ready
-	currSignal := ab.newChunkSignal
-	ab.newChunkSignal = make(chan struct{})
-	close(currSignal)
+	ab.ticker.Tick()
 }
 
-// finish marks the reader as finished
-func (ab *AsyncBuffer) finish() {
+// finishAndCloseReader marks the buffer as finished and closes the upstream reader
+func (ab *AsyncBuffer) finishAndCloseReader() {
 	ab.mu.Lock()
 	defer ab.mu.Unlock()
 
 	// Indicate that the reader has finished reading
 	ab.finished.Store(true)
+	ab.ticker.Close()
 
-	// This indicates that Close() was called before all the chunks were read, we do not need to close the channel
-	// since it was closed already.
-	if !ab.closed.Load() {
-		close(ab.newChunkSignal)
-	}
-
-	err := ab.r.Close() // Close the upstream reader
-	if err != nil {
-		// If there was an error while closing the upstream reader, store it
-		ab.err.Store(err)
-		return
+	// Close the upstream reader
+	if err := ab.r.Close(); err != nil {
+		ab.err.Store(err) // Store the error if it occurred during closing
 	}
 }
 
 // readChunks reads data from the upstream reader in background and stores it in chunks
 func (ab *AsyncBuffer) readChunks() {
-	defer ab.finish()
+	defer ab.finishAndCloseReader()
 
 	// Stop reading once the buffer has been closed
-	for !ab.finished.Load() {
+	for !ab.closed.Load() {
 		// In case we are trying to read data beyond threshold and we are paused,
 		// wait for pause to be released.
 		if ab.len.Load() >= pauseThreshold {
 			ab.paused.Wait()
 
 			// If the reader has been closed while waiting, we can stop reading
-			if ab.finished.Load() {
+			if ab.closed.Load() {
 				return // No more data to read
 			}
 		}
@@ -165,11 +140,13 @@ func (ab *AsyncBuffer) readChunks() {
 		}
 
 		// Read data into the chunk's buffer
+		// There is no way to guarantee that this would fill the whole chunk:
+		// the stream may end early, so a short read is expected here
 		n, err := io.ReadFull(ab.r, chunk.buf)
 
 		// If it's not the EOF, we need to store the error
 		if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
 			ab.err.Store(err)
+			chunkPool.Put(chunk)
 			return
 		}
 
@@ -197,29 +174,24 @@ func (ab *AsyncBuffer) readChunks() {
 // If the reader had an error, it returns that error instead.
 func (ab *AsyncBuffer) closedError() error {
 	// If the reader is closed, we return the error or nil
-	if ab.closed.Load() {
-		err := ab.Error()
-		if err == nil {
-			err = errors.New("asyncbuffer.AsyncBuffer.ReadAt: attempt to read on closed reader")
-		}
+	if !ab.closed.Load() {
+		return nil
+	}
 
-		return err
+	err := ab.Error()
+	if err == nil {
+		err = errors.New("asyncbuffer.AsyncBuffer.ReadAt: attempt to read on closed reader")
 	}
 
-	return nil
+	return err
 }
 
 // offsetAvailable checks if the data at the given offset is available for reading.
 // It may return io.EOF if the reader is finished reading and the offset is beyond the end of the stream.
 func (ab *AsyncBuffer) offsetAvailable(off int64) (bool, error) {
 	// We cannot read data from a closed reader
-	if ab.closed.Load() {
-		return false, ab.closedError()
-	}
-
-	// In case we are trying to read data beyond the pause threshold, we need to resume the reader
-	if off >= pauseThreshold {
-		ab.paused.Release()
+	if err := ab.closedError(); err != nil {
+		return false, err
 	}
 
 	// In case the offset falls within the already read chunks, we can return immediately,
@@ -232,8 +204,7 @@ func (ab *AsyncBuffer) offsetAvailable(off int64) (bool, error) {
 	// data yet, return either error or EOF
 	if ab.finished.Load() {
 		// In case, error has occurred, we need to return it
-		err := ab.Error()
-		if err != nil {
+		if err := ab.Error(); err != nil {
 			return false, err
 		}
 
@@ -251,7 +222,6 @@ func (ab *AsyncBuffer) WaitFor(off int64) error {
 	// In case we are trying to read data which would potentially hit the pause threshold,
 	// we need to unpause the reader ASAP.
 	if off >= pauseThreshold {
-		fmt.Println(off, pauseThreshold, "UNLOCKING")
 		ab.paused.Release()
 	}
 
@@ -261,7 +231,7 @@ func (ab *AsyncBuffer) WaitFor(off int64) error {
 			return err
 		}
 
-		<-ab.getNewChunkSignal()
+		ab.ticker.Wait()
 	}
 }
 
@@ -272,32 +242,18 @@ func (ab *AsyncBuffer) Wait() (int64, error) {
 	ab.paused.Release()
 
 	for {
-		// We can not read data from the closed reader even if there were no errors
-		if ab.closed.Load() {
-			return 0, ab.closedError()
+		// We can not read data from the closed reader
+		if err := ab.closedError(); err != nil {
+			return 0, err
 		}
 
 		// In case the reader is finished reading, we can return immediately
 		if ab.finished.Load() {
-			size := ab.len.Load()
-
-			// If there was an error during reading, we need to return it no matter what position
-			// had the error happened
-			err := ab.err.Load()
-			if err != nil {
-				err, ok := err.(error)
-				if !ok {
-					return size, errors.New("asyncbuffer.AsyncBuffer.Wait: failed to get error")
-				}
-
-				return size, err
-			}
-
-			return size, nil
+			return ab.len.Load(), ab.Error()
 		}
 
 		// Lock until the next chunk is ready
-		<-ab.getNewChunkSignal()
+		ab.ticker.Wait()
 	}
 }
 
@@ -355,6 +311,12 @@ func (ab *AsyncBuffer) readAt(p []byte, off int64) (int, error) {
 	if off < 0 {
 		return 0, errors.New("asyncbuffer.AsyncBuffer.readAt: negative offset")
 	}
+
+	// If this read would cross the pause threshold, release the paused reader
+	if int64(len(p))+off > pauseThreshold {
+		ab.paused.Release()
+	}
+
 	// Wait for the offset to be available.
 	// It may return io.EOF if the offset is beyond the end of the stream.
 	err := ab.WaitFor(off)
@@ -362,12 +324,13 @@ func (ab *AsyncBuffer) readAt(p []byte, off int64) (int, error) {
 		return 0, err
 	}
 
+	// We hold the read lock until the current read is complete
 	ab.mu.RLock()
 	defer ab.mu.RUnlock()
 
 	// If the reader is closed, we return an error
-	if ab.closed.Load() {
-		return 0, ab.closedError()
+	if err := ab.closedError(); err != nil {
+		return 0, err
 	}
 
 	// Read data from the first chunk
@@ -414,21 +377,18 @@ func (ab *AsyncBuffer) Close() error {
 	// If the reader is already closed, we immediately return the stored error or nil
 	if ab.closed.Load() {
 		return ab.Error()
+	} else {
+		ab.closed.Store(true)
 	}
 
-	ab.closed.Store(true)
-
-	// If the reader is still running, we need to signal that it should stop and close the channel
-	if !ab.finished.Load() {
-		ab.finished.Store(true)
-		close(ab.newChunkSignal)
-	}
+	ab.finished.Store(true)
 
 	// Return all chunks to the pool
 	for _, chunk := range ab.chunks {
 		chunkPool.Put(chunk)
 	}
 
+	// Release the paused latch so that no goroutines are waiting for it
 	ab.paused.Release()
 
 	return nil
@@ -438,41 +398,3 @@ func (ab *AsyncBuffer) Close() error {
 func (ab *AsyncBuffer) Reader() *Reader {
 	return &Reader{ab: ab, pos: 0}
 }
-
-// Read reads data from the AsyncBuffer.
-func (r *Reader) Read(p []byte) (int, error) {
-	n, err := r.ab.readAt(p, r.pos)
-	if err == nil {
-		r.pos += int64(n)
-	}
-
-	return n, err
-}
-
-// Seek sets the position of the reader to the given offset and returns the new position
-func (r *Reader) Seek(offset int64, whence int) (int64, error) {
-	switch whence {
-	case io.SeekStart:
-		r.pos = offset
-
-	case io.SeekCurrent:
-		r.pos += offset
-
-	case io.SeekEnd:
-		size, err := r.ab.Wait()
-		if err != nil {
-			return 0, err
-		}
-
-		r.pos = size + offset
-
-	default:
-		return 0, errors.New("asyncbuffer.AsyncBuffer.ReadAt: invalid whence")
-	}
-
-	if r.pos < 0 {
-		return 0, errors.New("asyncbuffer.AsyncBuffer.ReadAt: negative position")
-	}
-
-	return r.pos, nil
-}
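
For orientation, here is how the refactored buffer is meant to be consumed, assembled from the API visible in this diff (FromReader, Reader, Wait, Close). The import path and the HTTP source are illustrative assumptions, not part of this commit:

package main

import (
	"fmt"
	"io"
	"net/http"

	"example.com/app/asyncbuffer" // hypothetical import path
)

func main() {
	resp, err := http.Get("https://example.com/large-file")
	if err != nil {
		panic(err)
	}

	// Background reading starts immediately; the upstream body is closed
	// by the buffer itself once reading finishes.
	ab := asyncbuffer.FromReader(resp.Body)
	defer ab.Close()

	// Each Reader is an independent io.ReadSeeker cursor over the buffer.
	r := ab.Reader()
	head := make([]byte, 512)
	if _, err := io.ReadFull(r, head); err != nil {
		panic(err)
	}

	// Wait blocks until the whole upstream stream has been buffered.
	size, err := ab.Wait()
	if err != nil {
		panic(err)
	}
	fmt.Printf("buffered %d bytes\n", size)
}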

+ 0 - 1
asyncbuffer/buffer_test.go

@@ -415,7 +415,6 @@ func TestAsyncBufferThreshold(t *testing.T) {
 	target = make([]byte, pauseThreshold+1)
 	n, err = asyncBuffer.readAt(target, 0)
 	require.NoError(t, err)
-	require.Equal(t, pauseThreshold, n)
 
 	// It usually returns only pauseThreshold bytes: this exact operation unpauses the reader,
 	// but since the initial offset is before the threshold, data beyond it may not be available yet.

+ 51 - 0
asyncbuffer/reader.go

@@ -0,0 +1,51 @@
+package asyncbuffer
+
+import (
+	"errors"
+	"io"
+)
+
+// Reader provides an io.ReadSeeker view over an AsyncBuffer.
+// Each Reader maintains its own independent read position.
+type Reader struct {
+	ab  *AsyncBuffer
+	pos int64
+}
+
+// Read reads data from the AsyncBuffer.
+func (r *Reader) Read(p []byte) (int, error) {
+	n, err := r.ab.readAt(p, r.pos)
+	if err == nil {
+		r.pos += int64(n)
+	}
+
+	return n, err
+}
+
+// Seek sets the position of the reader to the given offset and returns the new position
+func (r *Reader) Seek(offset int64, whence int) (int64, error) {
+	switch whence {
+	case io.SeekStart:
+		r.pos = offset
+
+	case io.SeekCurrent:
+		r.pos += offset
+
+	case io.SeekEnd:
+		size, err := r.ab.Wait()
+		if err != nil {
+			return 0, err
+		}
+
+		r.pos = size + offset
+
+	default:
+		return 0, errors.New("asyncbuffer.Reader.Seek: invalid whence")
+	}
+
+	if r.pos < 0 {
+		return 0, errors.New("asyncbuffer.Reader.Seek: negative position")
+	}
+
+	return r.pos, nil
+}
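
Because each Reader carries its own pos, several cursors can walk the same AsyncBuffer concurrently, and only io.SeekEnd blocks (it calls ab.Wait internally to learn the total size). A sketch under those assumptions; readHeadAndTail and the import path are hypothetical:

package main

import (
	"fmt"
	"io"
	"os"

	"example.com/app/asyncbuffer" // hypothetical import path
)

// readHeadAndTail reads the first and last 16 bytes of the stream using two
// independent cursors over the same buffer.
func readHeadAndTail(ab *asyncbuffer.AsyncBuffer) (head, tail []byte, err error) {
	r1 := ab.Reader()
	r2 := ab.Reader()

	head = make([]byte, 16)
	if _, err = io.ReadFull(r1, head); err != nil {
		return nil, nil, err
	}

	// Seeking relative to the end waits until the upstream is fully
	// buffered, because the total size is unknown before that.
	if _, err = r2.Seek(-16, io.SeekEnd); err != nil {
		return nil, nil, err
	}

	tail = make([]byte, 16)
	if _, err = io.ReadFull(r2, tail); err != nil {
		return nil, nil, err
	}

	return head, tail, nil
}

func main() {
	ab := asyncbuffer.FromReader(os.Stdin)
	defer ab.Close()

	head, tail, err := readHeadAndTail(ab)
	if err != nil {
		panic(err)
	}
	fmt.Printf("head=%q tail=%q\n", head, tail)
}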

+ 59 - 0
asyncbuffer/ticker.go

@@ -0,0 +1,59 @@
+package asyncbuffer
+
+import (
+	"sync"
+)
+
+type tickCh = chan struct{}
+
+// Ticker broadcasts the occurrence of an event to multiple waiters.
+type Ticker struct {
+	_         noCopy
+	mu        sync.Mutex
+	ch        tickCh
+	closeOnce sync.Once
+}
+
+// NewTicker creates a new Ticker instance with an initialized channel.
+func NewTicker() *Ticker {
+	return &Ticker{
+		ch: make(tickCh),
+	}
+}
+
+// Tick signals that an event has occurred by closing the current channel
+// and replacing it with a fresh one for the next round of waiters.
+func (t *Ticker) Tick() {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+
+	if t.ch != nil {
+		close(t.ch)
+		t.ch = make(tickCh)
+	}
+}
+
+// Wait blocks until the current channel is closed, indicating that an event
+// has occurred. It returns immediately if the ticker itself is closed.
+func (t *Ticker) Wait() {
+	t.mu.Lock()
+	ch := t.ch
+	t.mu.Unlock()
+
+	if ch == nil {
+		return
+	}
+
+	<-ch
+}
+
+// Close closes the ticker channel and prevents further ticks.
+func (t *Ticker) Close() {
+	t.closeOnce.Do(func() {
+		t.mu.Lock()
+		defer t.mu.Unlock()
+
+		if t.ch != nil {
+			close(t.ch)
+			t.ch = nil
+		}
+	})
+}
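
The Ticker formalizes the tick-tock pattern buffer.go previously hand-rolled with newChunkSignal: closing a channel wakes every goroutine blocked on it at once, and Tick immediately installs a fresh channel for the next round. A hypothetical example test (not part of this commit) that would live inside the asyncbuffer package:

package asyncbuffer

import (
	"fmt"
	"sync"
	"time"
)

// ExampleTicker shows the broadcast semantics: a single Tick releases every
// goroutine currently parked in Wait.
func ExampleTicker() {
	t := NewTicker()

	var wg sync.WaitGroup
	for range 3 {
		wg.Add(1)
		go func() {
			defer wg.Done()
			t.Wait() // blocks until some Tick closes the current channel
			fmt.Println("woke up")
		}()
	}

	// There is no handshake telling us when a goroutine is parked in Wait,
	// so keep ticking until all of them have been released.
	done := make(chan struct{})
	go func() {
		wg.Wait()
		close(done)
	}()
	for {
		select {
		case <-done:
			t.Close()
			t.Wait() // returns immediately: Close set the channel to nil
			return
		default:
			t.Tick()
			time.Sleep(time.Millisecond)
		}
	}

	// Output:
	// woke up
	// woke up
	// woke up
}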

+ 158 - 0
asyncbuffer/ticker_test.go

@@ -0,0 +1,158 @@
+package asyncbuffer
+
+import (
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/suite"
+)
+
+type TickerTestSuite struct {
+	suite.Suite
+	ticker *Ticker
+}
+
+func (s *TickerTestSuite) SetupTest() {
+	s.ticker = NewTicker()
+}
+
+func (s *TickerTestSuite) TearDownTest() {
+	if s.ticker != nil {
+		s.ticker.Close()
+	}
+}
+
+// TestBasicWaitAndTick tests the basic functionality of the Ticker
+func (s *TickerTestSuite) TestBasicWaitAndTick() {
+	done := make(chan struct{})
+
+	ch := s.ticker.ch
+
+	// Start a goroutine that will tick after a short delay
+	go func() {
+		time.Sleep(50 * time.Millisecond)
+		s.ticker.Tick()
+	}()
+
+	// Start a goroutine that will wait for the tick
+	go func() {
+		s.ticker.Wait()
+		close(done)
+	}()
+
+	s.Require().Eventually(func() bool {
+		select {
+		case <-done:
+			return true
+		default:
+			return false
+		}
+	}, 100*time.Millisecond, 10*time.Millisecond)
+
+	// Means that the old channel was closed and a new one has been created
+	s.Require().NotEqual(ch, s.ticker.ch)
+}
+
+// TestWaitMultipleWaiters tests that multiple waiters can be unblocked by a single tick
+func (s *TickerTestSuite) TestWaitMultipleWaiters() {
+	const numWaiters = 10
+
+	var wg sync.WaitGroup
+	var startWg sync.WaitGroup
+	results := make([]bool, numWaiters)
+
+	// Start multiple waiters
+	for i := range numWaiters {
+		wg.Add(1)
+		startWg.Add(1)
+		go func(index int) {
+			defer wg.Done()
+			startWg.Done() // Signal that this goroutine is ready
+			s.ticker.Wait()
+			results[index] = true
+		}(i)
+	}
+
+	// Wait for all goroutines to start waiting
+	startWg.Wait()
+
+	// Wait for all waiters to complete
+	done := make(chan struct{})
+	go func() {
+		s.ticker.Tick() // Signal that execution can proceed
+		wg.Wait()
+		close(done)
+	}()
+
+	s.Require().Eventually(func() bool {
+		select {
+		case <-done:
+			return true
+		default:
+			return false
+		}
+	}, 100*time.Millisecond, 10*time.Millisecond)
+
+	// Check that all waiters were unblocked
+	for _, completed := range results {
+		s.Require().True(completed)
+	}
+}
+
+// TestClose tests the behavior of the Ticker when closed
+func (s *TickerTestSuite) TestClose() {
+	s.ticker.Close()
+	s.ticker.Close() // Should not panic
+	s.ticker.Wait()  // Should return immediately since the channel is nil
+	s.ticker.Tick()  // Should not panic
+
+	s.Require().Nil(s.ticker.ch)
+}
+
+func (s *TickerTestSuite) TestRapidTicksAndWaits() {
+	const iterations = 1000
+
+	var wg sync.WaitGroup
+
+	// Start a goroutine that will rapidly tick
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		for range iterations {
+			s.ticker.Tick()
+			time.Sleep(time.Microsecond)
+		}
+		s.ticker.Close() // Close after all ticks
+	}()
+
+	// Start multiple waiters
+	for range 10 {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			for range iterations / 10 {
+				s.ticker.Wait()
+			}
+		}()
+	}
+
+	done := make(chan struct{})
+	go func() {
+		wg.Wait()
+		close(done)
+	}()
+
+	s.Require().Eventually(func() bool {
+		select {
+		case <-done:
+			return true
+		default:
+			return false
+		}
+	}, 100*time.Millisecond, 10*time.Millisecond)
+}
+
+func TestTicker(t *testing.T) {
+	suite.Run(t, new(TickerTestSuite))
+}