middlewares.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. package server
  2. import (
  3. "context"
  4. "crypto/subtle"
  5. "errors"
  6. "fmt"
  7. "net/http"
  8. "github.com/imgproxy/imgproxy/v3/errorreport"
  9. "github.com/imgproxy/imgproxy/v3/httpheaders"
  10. "github.com/imgproxy/imgproxy/v3/ierrors"
  11. "github.com/imgproxy/imgproxy/v3/monitoring"
  12. )
  13. const (
  14. categoryTimeout = "timeout"
  15. )
  16. // WithMonitoring wraps RouteHandler with monitoring handling.
  17. func (r *Router) WithMonitoring(h RouteHandler) RouteHandler {
  18. if !monitoring.Enabled() {
  19. return h
  20. }
  21. return func(reqID string, rw ResponseWriter, req *http.Request) error {
  22. ctx, cancel, newRw := monitoring.StartRequest(req.Context(), rw.HTTPResponseWriter(), req)
  23. defer cancel()
  24. // Replace rw.ResponseWriter with new one returned from monitoring
  25. rw.SetHTTPResponseWriter(newRw)
  26. return h(reqID, rw, req.WithContext(ctx))
  27. }
  28. }
  29. // WithCORS wraps RouteHandler with CORS handling
  30. func (r *Router) WithCORS(h RouteHandler) RouteHandler {
  31. if len(r.config.CORSAllowOrigin) == 0 {
  32. return h
  33. }
  34. return func(reqID string, rw ResponseWriter, req *http.Request) error {
  35. rw.Header().Set(httpheaders.AccessControlAllowOrigin, r.config.CORSAllowOrigin)
  36. rw.Header().Set(httpheaders.AccessControlAllowMethods, "GET, OPTIONS")
  37. return h(reqID, rw, req)
  38. }
  39. }
  40. // WithSecret wraps RouteHandler with secret handling
  41. func (r *Router) WithSecret(h RouteHandler) RouteHandler {
  42. if len(r.config.Secret) == 0 {
  43. return h
  44. }
  45. authHeader := fmt.Appendf(nil, "Bearer %s", r.config.Secret)
  46. return func(reqID string, rw ResponseWriter, req *http.Request) error {
  47. if subtle.ConstantTimeCompare([]byte(req.Header.Get(httpheaders.Authorization)), authHeader) == 1 {
  48. return h(reqID, rw, req)
  49. } else {
  50. return newInvalidSecretError()
  51. }
  52. }
  53. }
  54. // WithPanic recovers panic and converts it to normal error
  55. func (r *Router) WithPanic(h RouteHandler) RouteHandler {
  56. return func(reqID string, rw ResponseWriter, r *http.Request) (retErr error) {
  57. defer func() {
  58. // try to recover from panic
  59. rerr := recover()
  60. if rerr == nil {
  61. return
  62. }
  63. // abort handler is an exception of net/http, we should simply repanic it.
  64. // it will supress the stack trace
  65. if rerr == http.ErrAbortHandler {
  66. panic(rerr)
  67. }
  68. // let's recover error value from panic if it has panicked with error
  69. err, ok := rerr.(error)
  70. if !ok {
  71. err = fmt.Errorf("panic: %v", err)
  72. }
  73. retErr = ierrors.Wrap(err, 1)
  74. }()
  75. return h(reqID, rw, r)
  76. }
  77. }
  78. // WithReportError handles error reporting.
  79. // It should be placed after `WithMonitoring`, but before `WithPanic`.
  80. func (r *Router) WithReportError(h RouteHandler) RouteHandler {
  81. return func(reqID string, rw ResponseWriter, req *http.Request) error {
  82. // Open the error context
  83. ctx := errorreport.StartRequest(req)
  84. req = req.WithContext(ctx)
  85. errorreport.SetMetadata(req, "Request ID", reqID)
  86. // Call the underlying handler passing the context downwards
  87. err := h(reqID, rw, req)
  88. if err == nil {
  89. return nil
  90. }
  91. // Wrap a resulting error into ierrors.Error
  92. ierr := ierrors.Wrap(err, 0)
  93. // Get the error category
  94. errCat := ierr.Category()
  95. // Exception: any context.DeadlineExceeded error is timeout
  96. if errors.Is(ierr, context.DeadlineExceeded) {
  97. errCat = categoryTimeout
  98. }
  99. // We do not need to send any canceled context
  100. if !errors.Is(ierr, context.Canceled) {
  101. monitoring.SendError(ctx, errCat, err)
  102. }
  103. // Report error to error collectors
  104. if ierr.ShouldReport() {
  105. errorreport.Report(ierr, req)
  106. }
  107. // Log response and format the error output
  108. LogResponse(reqID, req, ierr.StatusCode(), ierr)
  109. // Error message: either is public message or full development error
  110. rw.Header().Set(httpheaders.ContentType, "text/plain")
  111. rw.WriteHeader(ierr.StatusCode())
  112. if r.config.DevelopmentErrorsMode {
  113. rw.Write([]byte(ierr.Error()))
  114. } else {
  115. rw.Write([]byte(ierr.PublicMessage()))
  116. }
  117. return nil
  118. }
  119. }