Просмотр исходного кода

Use XML parser based on encoding/xml

DarthSim 4 месяцев назад
Родитель
Сommit
df374a6503

+ 0 - 1
go.mod

@@ -39,7 +39,6 @@ require (
 	github.com/prometheus/client_golang v1.23.2
 	github.com/shirou/gopsutil v3.21.11+incompatible
 	github.com/stretchr/testify v1.11.1
-	github.com/tdewolff/parse/v2 v2.8.3
 	github.com/trimmer-io/go-xmp v1.0.0
 	github.com/urfave/cli/v3 v3.4.1
 	go.opentelemetry.io/contrib/detectors/aws/ec2/v2 v2.0.0

+ 0 - 4
go.sum

@@ -429,10 +429,6 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
 github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
 github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
 github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
-github.com/tdewolff/parse/v2 v2.7.21 h1:OCuPFtGr4mXdnfKikQlUb0n654ROJANhBqCk+wioJ/A=
-github.com/tdewolff/parse/v2 v2.7.21/go.mod h1:I7TXO37t3aSG9SlPUBefAhgIF8nt7yYUwVGgETIoBcA=
-github.com/tdewolff/test v1.0.11 h1:FdLbwQVHxqG16SlkGveC0JVyrJN62COWTRyUFzfbtBE=
-github.com/tdewolff/test v1.0.11/go.mod h1:XPuWBzvdUzhCuxWO1ojpXsyzsA5bFoS3tO/Q3kFuTG8=
 github.com/tinylib/msgp v1.3.0 h1:ULuf7GPooDaIlbyvgAxBV/FI7ynli6LZ1/nVUNu+0ww=
 github.com/tinylib/msgp v1.3.0/go.mod h1:ykjzy2wzgrlvpDCRc4LA8UXy6D8bzMSuAF3WD57Gok0=
 github.com/tklauser/go-sysconf v0.3.15 h1:VE89k0criAymJ/Os65CSn1IXaol+1wrsFHEB8Ol49K4=

+ 13 - 22
imagetype/svg.go

@@ -1,14 +1,11 @@
 package imagetype
 
 import (
+	"encoding/xml"
 	"errors"
 	"io"
-	"strings"
 
 	"github.com/imgproxy/imgproxy/v3/bufreader"
-
-	"github.com/tdewolff/parse/v2"
-	"github.com/tdewolff/parse/v2/xml"
 )
 
 func init() {
@@ -18,33 +15,27 @@ func init() {
 }
 
 func IsSVG(r bufreader.ReadPeeker) (Type, error) {
-	l := xml.NewLexer(parse.NewInput(r))
+	dec := xml.NewDecoder(r)
+	dec.Strict = false
 
 	for {
-		tt, _ := l.Next()
-
-		switch tt {
-		case xml.ErrorToken:
-			err := l.Err()
-
-			if err == io.EOF || err == io.ErrUnexpectedEOF {
-				// EOF or unexpected EOF means we don't have enough data to determine the type
-				return Unknown, nil
-			}
-
-			var perr *parse.Error
+		tok, err := dec.RawToken()
+		if err == io.EOF || err == io.ErrUnexpectedEOF {
+			// EOF or unexpected EOF means we don't have enough data to determine the type
+			return Unknown, nil
+		}
+		if err != nil {
+			var perr *xml.SyntaxError
 			if errors.As(err, &perr) {
 				// If the error is a parse error, we can assume that the data is not SVG
 				return Unknown, nil
 			}
 
 			return Unknown, err
+		}
 
-		case xml.StartTagToken:
-			tag := strings.ToLower(string(l.Text()))
-			if tag == "svg" || tag == "svg:svg" {
-				return SVG, nil
-			}
+		if se, ok := tok.(xml.StartElement); ok && se.Name.Local == "svg" {
+			return SVG, nil
 		}
 	}
 }

+ 6 - 0
imagetype/svg_test.go

@@ -25,6 +25,12 @@ func TestSVGDetectSuccess(t *testing.T) {
 	typ, err = IsSVG(r)
 	require.NoError(t, err)
 	require.Equal(t, SVG, typ)
+
+	// Partial content; Simulate limit reader
+	r = bufreader.New(strings.NewReader(`<svg xmlns="http://www.w3.org/2000/svg">SomethingSomething...`))
+	typ, err = IsSVG(r)
+	require.NoError(t, err)
+	require.Equal(t, SVG, typ)
 }
 
 func TestSVGDetectNotSvg(t *testing.T) {

+ 31 - 0
processing/svg/parser/document.go

@@ -0,0 +1,31 @@
+package svgparser
+
+import (
+	"bufio"
+	"io"
+)
+
+type Document struct {
+	Node
+}
+
+func NewDocument(r io.ReadSeeker) (*Document, error) {
+	doc := &Document{}
+	if err := doc.readFrom(r); err != nil {
+		return nil, err
+	}
+
+	return doc, nil
+}
+
+func (doc *Document) WriteTo(w io.Writer) (int64, error) {
+	wc := writeCounter{Writer: w}
+	bw := bufio.NewWriter(&wc)
+	if err := doc.writeChildrenTo(bw); err != nil {
+		return 0, err
+	}
+	if err := bw.Flush(); err != nil {
+		return 0, err
+	}
+	return wc.Count, nil
+}

+ 73 - 0
processing/svg/parser/escape.go

@@ -0,0 +1,73 @@
+package svgparser
+
+import (
+	"bufio"
+	"unicode/utf8"
+)
+
+var (
+	escQuot = []byte("&quot;")
+	escApos = []byte("&apos;")
+	escAmp  = []byte("&amp;")
+	escLT   = []byte("&lt;")
+	escGT   = []byte("&gt;")
+	escTab  = []byte("&#x9;")
+	escNL   = []byte("&#xA;")
+	escCR   = []byte("&#xD;")
+	escFFFD = []byte("\uFFFD") // Unicode replacement character
+)
+
+// escapeString writes an escaped version of a string to the writer.
+// It's a copy of xml.EscapeText with changed escape replacements.
+func escapeString(w *bufio.Writer, s string) error {
+	var esc []byte
+	last := 0
+	for i := 0; i < len(s); {
+		r, width := utf8.DecodeRuneInString(s[i:])
+		i += width
+		switch r {
+		case '"':
+			esc = escQuot
+		case '\'':
+			esc = escApos
+		case '&':
+			esc = escAmp
+		case '<':
+			esc = escLT
+		case '>':
+			esc = escGT
+		case '\t':
+			esc = escTab
+		case '\n':
+			esc = escNL
+		case '\r':
+			esc = escCR
+		default:
+			if !isInCharacterRange(r) || (r == 0xFFFD && width == 1) {
+				esc = escFFFD
+				break
+			}
+			continue
+		}
+		if _, err := w.WriteString(s[last : i-width]); err != nil {
+			return err
+		}
+		if _, err := w.Write(esc); err != nil {
+			return err
+		}
+		last = i
+	}
+	if _, err := w.WriteString(s[last:]); err != nil {
+		return err
+	}
+	return nil
+}
+
+func isInCharacterRange(r rune) bool {
+	return r == 0x09 ||
+		r == 0x0A ||
+		r == 0x0D ||
+		r >= 0x20 && r <= 0xD7FF ||
+		r >= 0xE000 && r <= 0xFFFD ||
+		r >= 0x10000 && r <= 0x10FFFF
+}

+ 253 - 0
processing/svg/parser/node.go

@@ -0,0 +1,253 @@
+package svgparser
+
+import (
+	"bufio"
+	"encoding/xml"
+	"errors"
+	"fmt"
+	"io"
+
+	"golang.org/x/net/html/charset"
+)
+
+type Attr = xml.Attr
+
+type Node struct {
+	Parent   *Node
+	Name     xml.Name
+	Attrs    []Attr
+	Children []any
+}
+
+func (n *Node) readFrom(r io.ReadSeeker) error {
+	if n.Parent != nil {
+		return errors.New("cannot read child node")
+	}
+
+	dec := xml.NewDecoder(r)
+	dec.Strict = false
+	dec.CharsetReader = charset.NewReaderLabel
+
+	curNode := n
+
+	for {
+		// Save the current position to know where to read raw CData from.
+		pos := dec.InputOffset()
+
+		// Read raw token so decoder doesn't mess with attributes and namespaces.
+		tok, err := dec.RawToken()
+		if err == io.EOF {
+			break
+		}
+		if err != nil {
+			return err
+		}
+
+		switch t := tok.(type) {
+		case xml.StartElement:
+			// An element is opened, create a node for it
+			el := &Node{
+				Parent: curNode,
+				Name:   t.Name,
+				Attrs:  t.Attr,
+			}
+			// Append the node to the current node's children and make it current
+			curNode.Children = append(curNode.Children, el)
+			curNode = el
+
+		case xml.EndElement:
+			// If the current node has no parent, then we are at the root,
+			// which can't be closed.
+			if curNode.Parent == nil {
+				return fmt.Errorf(
+					"malformed XML: unexpected closing tag </%s> while no elements are opened",
+					fullName(t.Name),
+				)
+			}
+			// Closing tag name should match opened node name (which is current)
+			if curNode.Name.Local != t.Name.Local || curNode.Name.Space != t.Name.Space {
+				return fmt.Errorf(
+					"malformed XML: unexpected closing tag </%s> for opened <%s> element",
+					fullName(t.Name),
+					fullName(curNode.Name),
+				)
+			}
+			// The node is closed, return to its parent
+			curNode = curNode.Parent
+
+		case xml.CharData:
+			// We want CData as is, so read it raw
+			cdata, err := readRawCData(r, pos, dec.InputOffset()-pos)
+			if err != nil {
+				return err
+			}
+
+			curNode.Children = append(curNode.Children, cdata)
+
+		case xml.Directive:
+			curNode.Children = append(curNode.Children, t.Copy())
+
+		case xml.Comment:
+			curNode.Children = append(curNode.Children, t.Copy())
+
+		case xml.ProcInst:
+			curNode.Children = append(curNode.Children, t.Copy())
+		}
+	}
+
+	return nil
+}
+
+func (n *Node) writeTo(w *bufio.Writer) error {
+	if err := w.WriteByte('<'); err != nil {
+		return err
+	}
+	if err := writeFullName(w, n.Name); err != nil {
+		return err
+	}
+
+	n.writeAttrsTo(w)
+
+	if len(n.Children) == 0 {
+		if _, err := w.WriteString("/>"); err != nil {
+			return err
+		}
+		return nil
+	}
+
+	if err := w.WriteByte('>'); err != nil {
+		return err
+	}
+
+	if err := n.writeChildrenTo(w); err != nil {
+		return err
+	}
+
+	if _, err := w.WriteString("</"); err != nil {
+		return err
+	}
+	if err := writeFullName(w, n.Name); err != nil {
+		return err
+	}
+	if err := w.WriteByte('>'); err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (n *Node) writeAttrsTo(w *bufio.Writer) error {
+	for _, attr := range n.Attrs {
+		if err := w.WriteByte(' '); err != nil {
+			return err
+		}
+		if err := writeFullName(w, attr.Name); err != nil {
+			return err
+		}
+		if _, err := w.WriteString(`="`); err != nil {
+			return err
+		}
+		if len(attr.Value) > 2 && attr.Value[0] == '&' && attr.Value[len(attr.Value)-1] == ';' {
+			// Attribute value is an entity, write it as is
+			if _, err := w.WriteString(attr.Value); err != nil {
+				return err
+			}
+		} else {
+			// Escape the attribute value
+			if err := escapeString(w, attr.Value); err != nil {
+				return err
+			}
+		}
+		if err := w.WriteByte('"'); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+func (n *Node) writeChildrenTo(w *bufio.Writer) error {
+	for _, child := range n.Children {
+		switch c := child.(type) {
+		case *Node:
+			if err := c.writeTo(w); err != nil {
+				return err
+			}
+
+		case CData:
+			if _, err := w.Write([]byte(c)); err != nil {
+				return err
+			}
+
+		case Comment:
+			if _, err := w.WriteString("<!--"); err != nil {
+				return err
+			}
+			if _, err := w.Write([]byte(c)); err != nil {
+				return err
+			}
+			if _, err := w.WriteString("-->"); err != nil {
+				return err
+			}
+
+		case Directive:
+			if _, err := w.WriteString("<!"); err != nil {
+				return err
+			}
+			if _, err := w.Write([]byte(c)); err != nil {
+				return err
+			}
+			if err := w.WriteByte('>'); err != nil {
+				return err
+			}
+
+		case ProcInst:
+			if _, err := w.WriteString("<?"); err != nil {
+				return err
+			}
+			if _, err := w.WriteString(c.Target); err != nil {
+				return err
+			}
+			if len(c.Inst) > 0 {
+				if err := w.WriteByte(' '); err != nil {
+					return err
+				}
+				if _, err := w.Write([]byte(c.Inst)); err != nil {
+					return err
+				}
+			}
+			if _, err := w.WriteString("?>"); err != nil {
+				return err
+			}
+
+		default:
+			return fmt.Errorf("unknown child type: %T", c)
+		}
+	}
+
+	return nil
+}
+
+func fullName(name xml.Name) string {
+	if len(name.Space) == 0 {
+		return name.Local
+	}
+	return name.Space + ":" + name.Local
+}
+
+func writeFullName(w *bufio.Writer, name xml.Name) error {
+	if len(name.Space) > 0 {
+		if _, err := w.WriteString(name.Space); err != nil {
+			return err
+		}
+		if err := w.WriteByte(':'); err != nil {
+			return err
+		}
+	}
+
+	if _, err := w.WriteString(name.Local); err != nil {
+		return err
+	}
+
+	return nil
+}

+ 39 - 0
processing/svg/parser/tokens.go

@@ -0,0 +1,39 @@
+package svgparser
+
+import (
+	"encoding/xml"
+	"io"
+)
+
+type Directive = xml.Directive
+type Comment = xml.Comment
+type ProcInst = xml.ProcInst
+
+type CData []byte
+
+func readRawCData(r io.ReadSeeker, from, size int64) (CData, error) {
+	// Get the current position of the reader so we can return there
+	// after reading raw CData.
+	curPos, err := r.Seek(0, io.SeekCurrent)
+	if err != nil {
+		return nil, err
+	}
+
+	// Seek to the CData start
+	if _, err := r.Seek(from, io.SeekStart); err != nil {
+		return nil, err
+	}
+
+	// Read the raw CData.
+	cdata := make(CData, size)
+	if _, err := io.ReadFull(r, cdata); err != nil {
+		return nil, err
+	}
+
+	// Restore the reader position
+	if _, err := r.Seek(curPos, io.SeekStart); err != nil {
+		return nil, err
+	}
+
+	return cdata, nil
+}

+ 14 - 0
processing/svg/parser/write_couter.go

@@ -0,0 +1,14 @@
+package svgparser
+
+import "io"
+
+type writeCounter struct {
+	Writer io.Writer
+	Count  int64
+}
+
+func (wc *writeCounter) Write(p []byte) (int, error) {
+	n, err := wc.Writer.Write(p)
+	wc.Count += int64(n)
+	return n, err
+}

+ 81 - 54
processing/svg/svg.go

@@ -3,15 +3,12 @@ package svg
 import (
 	"bytes"
 	"errors"
-	"io"
-	"strings"
 	"sync"
 
 	"github.com/imgproxy/imgproxy/v3/imagedata"
 	"github.com/imgproxy/imgproxy/v3/imagetype"
 	"github.com/imgproxy/imgproxy/v3/options"
-	"github.com/tdewolff/parse/v2"
-	"github.com/tdewolff/parse/v2/xml"
+	svgparser "github.com/imgproxy/imgproxy/v3/processing/svg/parser"
 )
 
 // pool represents temorary pool for svg sanitized data
@@ -56,8 +53,13 @@ func (p *Processor) sanitize(data imagedata.ImageData) (imagedata.ImageData, err
 		return data, nil
 	}
 
-	r := data.Reader()
-	l := xml.NewLexer(parse.NewInput(r))
+	doc, err := svgparser.NewDocument(data.Reader())
+	if err != nil {
+		return nil, newSanitizeError(err)
+	}
+
+	// Sanitize the document's children
+	p.sanitizeChildren(&doc.Node)
 
 	buf, ok := pool.Get().(*bytes.Buffer)
 	if !ok {
@@ -69,67 +71,92 @@ func (p *Processor) sanitize(data imagedata.ImageData) (imagedata.ImageData, err
 		pool.Put(buf)
 	}
 
-	ignoreTag := 0
-
-	var curTagName string
+	// Write the sanitized document to the buffer
+	if _, err := doc.WriteTo(buf); err != nil {
+		cancel()
+		return nil, newSanitizeError(err)
+	}
 
-	for {
-		tt, tdata := l.Next()
+	// Create new ImageData from the sanitized buffer
+	newData := imagedata.NewFromBytesWithFormat(
+		imagetype.SVG,
+		buf.Bytes(),
+	)
+	newData.AddCancel(cancel)
 
-		if tt == xml.ErrorToken {
-			if l.Err() != io.EOF {
-				cancel()
-				return nil, newSanitizeError(l.Err())
-			}
-			break
-		}
+	return newData, nil
+}
 
-		if ignoreTag > 0 {
-			switch tt {
-			case xml.EndTagToken, xml.StartTagCloseVoidToken:
-				ignoreTag--
-			case xml.StartTagToken:
-				ignoreTag++
-			}
+// sanitizeChildren sanitizes all child elements of the given element.
+func (p *Processor) sanitizeChildren(el *svgparser.Node) {
+	if el == nil || len(el.Children) == 0 {
+		return
+	}
 
+	// Filter children in place
+	filteredChildren := el.Children[:0]
+	for _, toc := range el.Children {
+		childEl, ok := toc.(*svgparser.Node)
+		if !ok {
+			// Keep non-element nodes (text, comments, etc.)
+			filteredChildren = append(filteredChildren, toc)
 			continue
 		}
 
-		switch tt {
-		case xml.StartTagToken:
-			curTagName = strings.ToLower(string(l.Text()))
+		// Sanitize the child element.
+		// Keep this child if sanitizeElement returned true.
+		if p.sanitizeElement(childEl) {
+			filteredChildren = append(filteredChildren, childEl)
+		}
+	}
 
-			if curTagName == "script" {
-				ignoreTag++
-				continue
-			}
+	el.Children = filteredChildren
+}
 
-			buf.Write(tdata)
-		case xml.AttributeToken:
-			attrName := strings.ToLower(string(l.Text()))
+// sanitizeElement sanitizes a single SVG element.
+// It returns true if the element should be kept, false if it should be removed.
+func (p *Processor) sanitizeElement(el *svgparser.Node) bool {
+	if el == nil {
+		return false
+	}
 
-			if _, unsafe := unsafeAttrs[attrName]; unsafe {
-				continue
-			}
+	// Strip <script> tags
+	if el.Name.Local == "script" {
+		return false
+	}
 
-			if curTagName == "use" && (attrName == "href" || attrName == "xlink:href") {
-				val := strings.TrimSpace(strings.Trim(string(l.AttrVal()), `"'`))
-				if len(val) > 0 && val[0] != '#' {
-					continue
-				}
+	// Filter out unsafe attributes (such as on* events)
+	el.Attrs = filterAttributes(el.Attrs, func(attr svgparser.Attr) bool {
+		_, unsafe := unsafeAttrs[attr.Name.Local]
+		return !unsafe
+	})
+
+	// Special handling for <use> tags.
+	if el.Name.Local == "use" {
+		el.Attrs = filterAttributes(el.Attrs, func(attr svgparser.Attr) bool {
+			// Keep non-href attributes
+			if attr.Name.Local != "href" {
+				return true
 			}
-
-			buf.Write(tdata)
-		default:
-			buf.Write(tdata)
-		}
+			// Strip hrefs that are not internal references
+			return len(attr.Value) == 0 || attr.Value[0] == '#'
+		})
 	}
 
-	newData := imagedata.NewFromBytesWithFormat(
-		imagetype.SVG,
-		buf.Bytes(),
-	)
-	newData.AddCancel(cancel)
+	// Recurse into children
+	p.sanitizeChildren(el)
 
-	return newData, nil
+	// Keep this element
+	return true
+}
+
+// filterAttributes filters attributes based on the given predicate function.
+func filterAttributes(attrs []svgparser.Attr, f func(attr svgparser.Attr) bool) []svgparser.Attr {
+	filtered := attrs[:0]
+	for _, attr := range attrs {
+		if f(attr) {
+			filtered = append(filtered, attr)
+		}
+	}
+	return filtered
 }