Browse Source

fix: ensure the list sort query is validated to prevent SQL injection

Credits to @jorgectf for the advisories.
Hintay 1 year ago
parent
commit
ec93ab05a3
2 changed files with 34 additions and 10 deletions
  1. 20 8
      api/cosy/sort.go
  2. 14 2
      model/model.go

+ 20 - 8
api/cosy/sort.go

@@ -2,27 +2,39 @@ package cosy
 
 import (
 	"fmt"
+	"github.com/0xJacky/Nginx-UI/internal/logger"
 	"github.com/gin-gonic/gin"
 	"gorm.io/gorm"
+	"gorm.io/gorm/schema"
+	"sync"
 )
 
 func (c *Ctx[T]) SortOrder() func(db *gorm.DB) *gorm.DB {
 	return func(db *gorm.DB) *gorm.DB {
 		sort := c.ctx.DefaultQuery("order", "desc")
-		order := fmt.Sprintf("%s %s", DefaultQuery(c.ctx, "sort_by", c.itemKey), sort)
-		return db.Order(order)
+		if sort != "desc" && sort != "asc" {
+			sort = "desc"
+		}
+
+		// check if the order field is valid
+		// todo: maybe we can use more generic way to check if the sort_by is valid
+		order := DefaultQuery(c.ctx, "sort_by", c.itemKey)
+		s, _ := schema.Parse(c.Model, &sync.Map{}, schema.NamingStrategy{})
+		if _, ok := s.FieldsByDBName[order]; ok {
+			order = fmt.Sprintf("%s %s", order, sort)
+			return db.Order(order)
+		} else {
+			logger.Error("invalid order field:", order)
+		}
+
+		return db
 	}
 }
 
 func (c *Ctx[T]) OrderAndPaginate() func(db *gorm.DB) *gorm.DB {
 	return func(db *gorm.DB) *gorm.DB {
-		sort := c.ctx.DefaultQuery("order", "desc")
-
-		order := fmt.Sprintf("%s %s", DefaultQuery(c.ctx, "sort_by", c.itemKey), sort)
-		db = db.Order(order)
-
+		db = c.SortOrder()(db)
 		_, offset, pageSize := GetPagingParams(c.ctx)
-
 		return db.Offset(offset).Limit(pageSize)
 	}
 }

+ 14 - 2
model/model.go

@@ -10,8 +10,10 @@ import (
 	"gorm.io/gen"
 	"gorm.io/gorm"
 	gormlogger "gorm.io/gorm/logger"
+	"gorm.io/gorm/schema"
 	"path"
 	"strings"
+	"sync"
 	"time"
 )
 
@@ -100,9 +102,19 @@ func SortOrder(c *gin.Context) func(db *gorm.DB) *gorm.DB {
 func OrderAndPaginate(c *gin.Context) func(db *gorm.DB) *gorm.DB {
 	return func(db *gorm.DB) *gorm.DB {
 		sort := c.DefaultQuery("order", "desc")
+		if sort != "desc" && sort != "asc" {
+			sort = "desc"
+		}
 
-		order := fmt.Sprintf("`%s` %s", DefaultQuery(c, "sort_by", "id"), sort)
-		db = db.Order(order)
+		// check if the order field is valid
+		order := c.DefaultQuery("sort_by", "id")
+		s, _ := schema.Parse(db.Model, &sync.Map{}, schema.NamingStrategy{})
+		if _, ok := s.FieldsByName[order]; ok {
+			order = fmt.Sprintf("%s %s", order, sort)
+			db = db.Order(order)
+		} else {
+			logger.Error("invalid order field: ", order)
+		}
 
 		page := cast.ToInt(c.Query("page"))
 		if page == 0 {