Browse Source

Simple http router for better looking server code

DarthSim 6 years ago
parent
commit
ab9363e989
2 changed files with 140 additions and 107 deletions
  1. 86 0
      router.go
  2. 54 107
      server.go

+ 86 - 0
router.go

@@ -0,0 +1,86 @@
+package main
+
+import (
+	"net/http"
+	"regexp"
+	"strings"
+
+	nanoid "github.com/matoous/go-nanoid"
+)
+
+const (
+	xRequestIDHeader = "X-Request-ID"
+)
+
+var (
+	requestIDRe = regexp.MustCompile(`^[A-Za-z0-9_\-]+$`)
+)
+
+type routeHandler func(string, http.ResponseWriter, *http.Request)
+type panicHandler func(string, http.ResponseWriter, *http.Request, error)
+
+type route struct {
+	Method  string
+	Prefix  string
+	Handler routeHandler
+}
+
+type router struct {
+	Routes       []*route
+	PanicHandler panicHandler
+}
+
+func newRouter() *router {
+	return &router{
+		Routes: make([]*route, 0),
+	}
+}
+
+func (r *router) Add(method, prefix string, handler routeHandler) {
+	r.Routes = append(
+		r.Routes,
+		&route{Method: method, Prefix: prefix, Handler: handler},
+	)
+}
+
+func (r *router) GET(prefix string, handler routeHandler) {
+	r.Add(http.MethodGet, prefix, handler)
+}
+
+func (r *router) OPTIONS(prefix string, handler routeHandler) {
+	r.Add(http.MethodOptions, prefix, handler)
+}
+
+func (r *router) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
+	reqID := req.Header.Get(xRequestIDHeader)
+
+	if len(reqID) == 0 || !requestIDRe.MatchString(reqID) {
+		reqID, _ = nanoid.Nanoid()
+	}
+
+	rw.Header().Set("Server", "imgproxy")
+	rw.Header().Set(xRequestIDHeader, reqID)
+
+	defer func() {
+		if rerr := recover(); rerr != nil {
+			if err, ok := rerr.(error); ok && r.PanicHandler != nil {
+				r.PanicHandler(reqID, rw, req, err)
+			} else {
+				panic(rerr)
+			}
+		}
+	}()
+
+	logRequest(reqID, req)
+
+	for _, rr := range r.Routes {
+		if rr.Method == req.Method && strings.HasPrefix(req.URL.Path, rr.Prefix) {
+			rr.Handler(reqID, rw, req)
+			return
+		}
+	}
+
+	logWarning("Route for %s is not defined", req.URL.Path)
+
+	rw.WriteHeader(404)
+}

+ 54 - 107
server.go

@@ -8,19 +8,15 @@ import (
 	"net/http"
 	"net/url"
 	"path/filepath"
-	"regexp"
 	"strconv"
 	"strings"
 	"time"
 
-	nanoid "github.com/matoous/go-nanoid"
 	"golang.org/x/net/netutil"
 )
 
 const (
-	healthPath                         = "/health"
 	contextDispositionFilenameFallback = "image"
-	xRequestIDHeader                   = "X-Request-ID"
 )
 
 var (
@@ -40,34 +36,40 @@ var (
 		imageTypeICO:  "inline; filename=\"%s.ico\"",
 	}
 
-	authHeaderMust []byte
-
 	imgproxyIsRunningMsg = []byte("imgproxy is running")
 
 	errInvalidMethod = newError(422, "Invalid request method", "Method doesn't allowed")
 	errInvalidSecret = newError(403, "Invalid secret", "Forbidden")
 
-	requestIDRe = regexp.MustCompile(`^[A-Za-z0-9_\-]+$`)
-
 	responseGzipBufPool *bufPool
 	responseGzipPool    *gzipPool
+
+	processingSem chan struct{}
 )
 
-type httpHandler struct {
-	sem chan struct{}
-}
+func buildRouter() *router {
+	r := newRouter()
 
-func newHTTPHandler() *httpHandler {
-	return &httpHandler{make(chan struct{}, conf.Concurrency)}
+	r.PanicHandler = handlePanic
+
+	r.GET("/health", handleHealth)
+	r.GET("/", withCORS(withSecret(handleProcessing)))
+	r.OPTIONS("/", withCORS(handleOptions))
+
+	return r
 }
 
 func startServer() *http.Server {
+	processingSem = make(chan struct{}, conf.Concurrency)
+
 	l, err := net.Listen("tcp", conf.Bind)
 	if err != nil {
 		logFatal(err.Error())
 	}
+	l = netutil.LimitListener(l, conf.MaxClients)
+
 	s := &http.Server{
-		Handler:        newHTTPHandler(),
+		Handler:        buildRouter(),
 		ReadTimeout:    time.Duration(conf.ReadTimeout) * time.Second,
 		MaxHeaderBytes: 1 << 20,
 	}
@@ -79,7 +81,7 @@ func startServer() *http.Server {
 
 	go func() {
 		logNotice("Starting server at %s", conf.Bind)
-		if err := s.Serve(netutil.LimitListener(l, conf.MaxClients)); err != nil && err != http.ErrServerClosed {
+		if err := s.Serve(l); err != nil && err != http.ErrServerClosed {
 			logFatal(err.Error())
 		}
 	}()
@@ -96,13 +98,6 @@ func shutdownServer(s *http.Server) {
 	s.Shutdown(ctx)
 }
 
-func writeCORS(rw http.ResponseWriter) {
-	if len(conf.AllowOrigin) > 0 {
-		rw.Header().Set("Access-Control-Allow-Origin", conf.AllowOrigin)
-		rw.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
-	}
-}
-
 func contentDisposition(imageURL string, imgtype imageType) string {
 	url, err := url.Parse(imageURL)
 	if err != nil {
@@ -183,104 +178,55 @@ func respondWithError(reqID string, rw http.ResponseWriter, err *imgproxyError)
 	}
 }
 
-func respondWithOptions(reqID string, rw http.ResponseWriter) {
-	logResponse(reqID, 200, "Respond with options")
-	rw.WriteHeader(200)
-}
-
-func respondWithHealth(reqID string, rw http.ResponseWriter) {
-	logResponse(reqID, 200, string(imgproxyIsRunningMsg))
-	rw.WriteHeader(200)
-	rw.Write(imgproxyIsRunningMsg)
-}
+func withCORS(h routeHandler) routeHandler {
+	return func(reqID string, rw http.ResponseWriter, r *http.Request) {
+		if len(conf.AllowOrigin) > 0 {
+			rw.Header().Set("Access-Control-Allow-Origin", conf.AllowOrigin)
+			rw.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
+		}
 
-func respondWithNotModified(reqID string, rw http.ResponseWriter) {
-	logResponse(reqID, 200, "Not modified")
-	rw.WriteHeader(304)
+		h(reqID, rw, r)
+	}
 }
 
-func generateRequestID(rw http.ResponseWriter, r *http.Request) (reqID string) {
-	reqID = r.Header.Get(xRequestIDHeader)
-
-	if len(reqID) == 0 || !requestIDRe.MatchString(reqID) {
-		reqID, _ = nanoid.Nanoid()
+func withSecret(h routeHandler) routeHandler {
+	if len(conf.Secret) == 0 {
+		return h
 	}
 
-	rw.Header().Set(xRequestIDHeader, reqID)
+	authHeader := []byte(fmt.Sprintf("Bearer %s", conf.Secret))
 
-	return
-}
-
-func prepareAuthHeaderMust() []byte {
-	if len(authHeaderMust) == 0 {
-		authHeaderMust = []byte(fmt.Sprintf("Bearer %s", conf.Secret))
+	return func(reqID string, rw http.ResponseWriter, r *http.Request) {
+		if subtle.ConstantTimeCompare([]byte(r.Header.Get("Authorization")), authHeader) == 1 {
+			h(reqID, rw, r)
+		} else {
+			respondWithError(reqID, rw, errInvalidSecret)
+		}
 	}
-
-	return authHeaderMust
 }
 
-func checkSecret(r *http.Request) bool {
-	if len(conf.Secret) == 0 {
-		return true
-	}
+func handlePanic(reqID string, rw http.ResponseWriter, r *http.Request, err error) {
+	reportError(err, r)
 
-	return subtle.ConstantTimeCompare(
-		[]byte(r.Header.Get("Authorization")),
-		prepareAuthHeaderMust(),
-	) == 1
+	if ierr, ok := err.(*imgproxyError); ok {
+		respondWithError(reqID, rw, ierr)
+	} else {
+		respondWithError(reqID, rw, newUnexpectedError(err.Error(), 3))
+	}
 }
 
-func (h *httpHandler) lock() {
-	h.sem <- struct{}{}
+func handleHealth(reqID string, rw http.ResponseWriter, r *http.Request) {
+	logResponse(reqID, 200, string(imgproxyIsRunningMsg))
+	rw.WriteHeader(200)
+	rw.Write(imgproxyIsRunningMsg)
 }
 
-func (h *httpHandler) unlock() {
-	<-h.sem
+func handleOptions(reqID string, rw http.ResponseWriter, r *http.Request) {
+	logResponse(reqID, 200, "Respond with options")
+	rw.WriteHeader(200)
 }
 
-func (h *httpHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
-	rw.Header().Set("Server", "imgproxy")
-
-	reqID := generateRequestID(rw, r)
-
-	defer func() {
-		if rerr := recover(); rerr != nil {
-			if err, ok := rerr.(error); ok {
-				reportError(err, r)
-
-				if ierr, ok := err.(*imgproxyError); ok {
-					respondWithError(reqID, rw, ierr)
-				} else {
-					respondWithError(reqID, rw, newUnexpectedError(err.Error(), 3))
-				}
-			} else {
-				panic(rerr)
-			}
-		}
-	}()
-
-	logRequest(reqID, r)
-
-	writeCORS(rw)
-
-	if r.Method == http.MethodOptions {
-		respondWithOptions(reqID, rw)
-		return
-	}
-
-	if r.Method != http.MethodGet {
-		panic(errInvalidMethod)
-	}
-
-	if r.URL.Path == healthPath {
-		respondWithHealth(reqID, rw)
-		return
-	}
-
-	if !checkSecret(r) {
-		panic(errInvalidSecret)
-	}
-
+func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) {
 	ctx := context.Background()
 
 	if newRelicEnabled {
@@ -294,8 +240,8 @@ func (h *httpHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
 		defer startPrometheusDuration(prometheusRequestDuration)()
 	}
 
-	h.lock()
-	defer h.unlock()
+	processingSem <- struct{}{}
+	defer func() { <-processingSem }()
 
 	ctx, timeoutCancel := startTimer(ctx, time.Duration(conf.WriteTimeout)*time.Second)
 	defer timeoutCancel()
@@ -324,7 +270,8 @@ func (h *httpHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
 		rw.Header().Set("ETag", eTag)
 
 		if eTag == r.Header.Get("If-None-Match") {
-			respondWithNotModified(reqID, rw)
+			logResponse(reqID, 304, "Not modified")
+			rw.WriteHeader(304)
 			return
 		}
 	}