Pārlūkot izejas kodu

Move panic handling out of router

DarthSim 3 gadi atpakaļ
vecāks
revīzija
86b646fe1b
2 mainītis faili ar 27 papildinājumiem un 28 dzēšanām
  1. 2 14
      router/router.go
  2. 25 14
      server.go

+ 2 - 14
router/router.go

@@ -19,7 +19,6 @@ var (
 )
 
 type RouteHandler func(string, http.ResponseWriter, *http.Request)
-type PanicHandler func(string, http.ResponseWriter, *http.Request, error)
 
 type route struct {
 	Method  string
@@ -29,9 +28,8 @@ type route struct {
 }
 
 type Router struct {
-	prefix       string
-	Routes       []*route
-	PanicHandler PanicHandler
+	prefix string
+	Routes []*route
 }
 
 func (r *route) isMatch(req *http.Request) bool {
@@ -95,16 +93,6 @@ func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
 		replaceRemoteAddr(req, ip)
 	}
 
-	defer func() {
-		if rerr := recover(); rerr != nil {
-			if err, ok := rerr.(error); ok && r.PanicHandler != nil {
-				r.PanicHandler(reqID, rw, req, err)
-			} else {
-				panic(rerr)
-			}
-		}
-	}()
-
 	LogRequest(reqID, req)
 
 	for _, rr := range r.Routes {

+ 25 - 14
server.go

@@ -26,12 +26,10 @@ var (
 func buildRouter() *router.Router {
 	r := router.New(config.PathPrefix)
 
-	r.PanicHandler = handlePanic
-
 	r.GET("/", handleLanding, true)
 	r.GET("/health", handleHealth, true)
 	r.GET("/favicon.ico", handleFavicon, true)
-	r.GET("/", withCORS(withSecret(handleProcessing)), false)
+	r.GET("/", withCORS(withPanicHandler(withSecret(handleProcessing))), false)
 	r.HEAD("/", withCORS(handleHead), false)
 	r.OPTIONS("/", withCORS(handleHead), false)
 
@@ -104,21 +102,34 @@ func withSecret(h router.RouteHandler) router.RouteHandler {
 	}
 }
 
-func handlePanic(reqID string, rw http.ResponseWriter, r *http.Request, err error) {
-	ierr := ierrors.Wrap(err, 3)
+func withPanicHandler(h router.RouteHandler) router.RouteHandler {
+	return func(reqID string, rw http.ResponseWriter, r *http.Request) {
+		defer func() {
+			if rerr := recover(); rerr != nil {
+				err, ok := rerr.(error)
+				if !ok {
+					panic(rerr)
+				}
 
-	if ierr.Unexpected {
-		errorreport.Report(err, r)
-	}
+				ierr := ierrors.Wrap(err, 3)
 
-	router.LogResponse(reqID, r, ierr.StatusCode, ierr)
+				if ierr.Unexpected {
+					errorreport.Report(err, r)
+				}
 
-	rw.WriteHeader(ierr.StatusCode)
+				router.LogResponse(reqID, r, ierr.StatusCode, ierr)
 
-	if config.DevelopmentErrorsMode {
-		rw.Write([]byte(ierr.Message))
-	} else {
-		rw.Write([]byte(ierr.PublicMessage))
+				rw.WriteHeader(ierr.StatusCode)
+
+				if config.DevelopmentErrorsMode {
+					rw.Write([]byte(ierr.Message))
+				} else {
+					rw.Write([]byte(ierr.PublicMessage))
+				}
+			}
+		}()
+
+		h(reqID, rw, r)
 	}
 }