Explorar el Código

Add priority to imagetype detectors

DarthSim hace 1 mes
padre
commit
61418f55b8
Se han modificado 3 ficheros con 66 adiciones y 28 borrados
  1. 24 8
      imagetype/registry.go
  2. 39 18
      imagetype/registry_test.go
  3. 3 2
      imagetype/svg.go

+ 24 - 8
imagetype/registry.go

@@ -2,6 +2,7 @@ package imagetype
 
 import (
 	"io"
+	"slices"
 
 	"github.com/imgproxy/imgproxy/v3/bufreader"
 )
@@ -30,9 +31,15 @@ type TypeDesc struct {
 // DetectFunc is a function that detects the image type from byte data
 type DetectFunc func(r bufreader.ReadPeeker) (Type, error)
 
+// detector is a struct that holds a detection function and its priority
+type detector struct {
+	priority int        // priority of the detector, lower is better
+	fn       DetectFunc // function that detects the image type
+}
+
 // Registry holds the type registry
 type registry struct {
-	detectors   []DetectFunc
+	detectors   []detector
 	types       []*TypeDesc
 	typesByName map[string]Type // maps type string to Type
 }
@@ -106,8 +113,8 @@ func (r *registry) GetTypeByName(name string) (Type, bool) {
 
 // RegisterDetector registers a custom detector function
 // Detectors are tried in the order they were registered
-func RegisterDetector(detector DetectFunc) {
-	globalRegistry.RegisterDetector(detector)
+func RegisterDetector(priority int, fn DetectFunc) {
+	globalRegistry.RegisterDetector(priority, fn)
 }
 
 // RegisterMagicBytes registers magic bytes for a specific image type
@@ -123,13 +130,22 @@ func Detect(r io.Reader) (Type, error) {
 }
 
 // RegisterDetector registers a custom detector function on this registry instance
-func (r *registry) RegisterDetector(detector DetectFunc) {
-	r.detectors = append(r.detectors, detector)
+func (r *registry) RegisterDetector(priority int, fn DetectFunc) {
+	r.detectors = append(r.detectors, detector{
+		priority: priority,
+		fn:       fn,
+	})
+	// Sort detectors by priority.
+	// We don't expect a huge number of detectors, and detectors should be registered at startup,
+	// so sorting them at each registration is okay.
+	slices.SortStableFunc(r.detectors, func(a, b detector) int {
+		return a.priority - b.priority
+	})
 }
 
 // RegisterMagicBytes registers magic bytes for a specific image type on this registry instance
 func (r *registry) RegisterMagicBytes(typ Type, signature ...[]byte) {
-	r.detectors = append(r.detectors, func(r bufreader.ReadPeeker) (Type, error) {
+	r.RegisterDetector(-1, func(r bufreader.ReadPeeker) (Type, error) {
 		for _, sig := range signature {
 			b, err := r.Peek(len(sig))
 			if err != nil {
@@ -149,9 +165,9 @@ func (r *registry) RegisterMagicBytes(typ Type, signature ...[]byte) {
 func (r *registry) Detect(re io.Reader) (Type, error) {
 	br := bufreader.New(io.LimitReader(re, maxDetectionLimit))
 
-	for _, fn := range r.detectors {
+	for _, d := range r.detectors {
 		br.Rewind()
-		typ, err := fn(br)
+		typ, err := d.fn(br)
 		if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
 			return Unknown, newTypeDetectionError(err)
 		}

+ 39 - 18
imagetype/registry_test.go

@@ -1,7 +1,10 @@
 package imagetype
 
 import (
+	"bytes"
 	"io"
+	"reflect"
+	"runtime"
 	"strings"
 	"testing"
 
@@ -145,24 +148,37 @@ func TestRegisterDetector(t *testing.T) {
 	// Create a test registry to avoid interfering with global state
 	testRegistry := NewRegistry()
 
-	// Create a test detector function
-	testDetector := func(r bufreader.ReadPeeker) (Type, error) {
-		b, err := r.Peek(2)
-		if err != nil {
-			return Unknown, err
-		}
-		if len(b) >= 2 && b[0] == 0xFF && b[1] == 0xD8 {
-			return JPEG, nil
-		}
-		return Unknown, newUnknownFormatError()
+	functionsEqual := func(fn1, fn2 DetectFunc) {
+		// Compare function names to check if they are the same
+		fnName1 := runtime.FuncForPC(reflect.ValueOf(fn1).Pointer()).Name()
+		fnName2 := runtime.FuncForPC(reflect.ValueOf(fn2).Pointer()).Name()
+		require.Equal(t, fnName1, fnName2)
 	}
 
-	// Register the detector using the method
-	testRegistry.RegisterDetector(testDetector)
-
-	// Verify the detector is registered
-	require.Len(t, testRegistry.detectors, 1)
-	require.NotNil(t, testRegistry.detectors[0])
+	// Create a test detector functions
+	testDetector1 := func(r bufreader.ReadPeeker) (Type, error) { return JPEG, nil }
+	testDetector2 := func(r bufreader.ReadPeeker) (Type, error) { return PNG, nil }
+	testDetector3 := func(r bufreader.ReadPeeker) (Type, error) { return GIF, nil }
+	testDetector4 := func(r bufreader.ReadPeeker) (Type, error) { return SVG, nil }
+
+	// Register the detectors using the method
+	testRegistry.RegisterDetector(0, testDetector1)
+	testRegistry.RegisterDetector(0, testDetector2)
+	testRegistry.RegisterDetector(10, testDetector3)
+	testRegistry.RegisterDetector(5, testDetector4)
+
+	// Verify the detectors are registered
+	require.Len(t, testRegistry.detectors, 4)
+
+	// Verify the order of detectors based on priority
+	require.Equal(t, 0, testRegistry.detectors[0].priority)
+	functionsEqual(testDetector1, testRegistry.detectors[0].fn)
+	require.Equal(t, 0, testRegistry.detectors[1].priority)
+	functionsEqual(testDetector2, testRegistry.detectors[1].fn)
+	require.Equal(t, 5, testRegistry.detectors[2].priority)
+	functionsEqual(testDetector4, testRegistry.detectors[2].fn)
+	require.Equal(t, 10, testRegistry.detectors[3].priority)
+	functionsEqual(testDetector3, testRegistry.detectors[3].fn)
 }
 
 func TestRegisterMagicBytes(t *testing.T) {
@@ -177,6 +193,11 @@ func TestRegisterMagicBytes(t *testing.T) {
 
 	// Verify the magic bytes are registered
 	require.Len(t, testRegistry.detectors, 1)
+	require.Equal(t, -1, testRegistry.detectors[0].priority)
+
+	typ, err := testRegistry.Detect(bufreader.New(bytes.NewReader(jpegMagic)))
+	require.NoError(t, err)
+	require.Equal(t, JPEG, typ)
 }
 
 func TestDetectionErrorReturns(t *testing.T) {
@@ -186,12 +207,12 @@ func TestDetectionErrorReturns(t *testing.T) {
 	detErr := error(nil)
 
 	// The first detector will fail with detErr
-	testRegistry.RegisterDetector(func(r bufreader.ReadPeeker) (Type, error) {
+	testRegistry.RegisterDetector(0, func(r bufreader.ReadPeeker) (Type, error) {
 		return Unknown, detErr
 	})
 
 	// The second detector will succeed
-	testRegistry.RegisterDetector(func(r bufreader.ReadPeeker) (Type, error) {
+	testRegistry.RegisterDetector(1, func(r bufreader.ReadPeeker) (Type, error) {
 		return JPEG, nil
 	})
 

+ 3 - 2
imagetype/svg.go

@@ -12,8 +12,9 @@ import (
 )
 
 func init() {
-	// Register SVG detector (needs at least 1000 bytes to reliably detect SVG)
-	RegisterDetector(IsSVG)
+	// Register SVG detector.
+	// We register it with a priority of 100 to run it after magic number detectors
+	RegisterDetector(100, IsSVG)
 }
 
 func IsSVG(r bufreader.ReadPeeker) (Type, error) {