middlewares.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. package server
  2. import (
  3. "context"
  4. "crypto/subtle"
  5. "errors"
  6. "fmt"
  7. "net/http"
  8. "github.com/imgproxy/imgproxy/v3/errorreport"
  9. "github.com/imgproxy/imgproxy/v3/httpheaders"
  10. "github.com/imgproxy/imgproxy/v3/ierrors"
  11. "github.com/imgproxy/imgproxy/v3/metrics"
  12. )
  13. const (
  14. categoryTimeout = "timeout"
  15. )
  16. // WithMetrics wraps RouteHandler with metrics handling.
  17. func (r *Router) WithMetrics(h RouteHandler) RouteHandler {
  18. if !metrics.Enabled() {
  19. return h
  20. }
  21. return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
  22. ctx, metricsCancel, rw := metrics.StartRequest(req.Context(), rw, req)
  23. defer metricsCancel()
  24. return h(reqID, rw, req.WithContext(ctx))
  25. }
  26. }
  27. // WithCORS wraps RouteHandler with CORS handling
  28. func (r *Router) WithCORS(h RouteHandler) RouteHandler {
  29. if len(r.config.CORSAllowOrigin) == 0 {
  30. return h
  31. }
  32. return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
  33. rw.Header().Set(httpheaders.AccessControlAllowOrigin, r.config.CORSAllowOrigin)
  34. rw.Header().Set(httpheaders.AccessControlAllowMethods, "GET, OPTIONS")
  35. return h(reqID, rw, req)
  36. }
  37. }
  38. // WithSecret wraps RouteHandler with secret handling
  39. func (r *Router) WithSecret(h RouteHandler) RouteHandler {
  40. if len(r.config.Secret) == 0 {
  41. return h
  42. }
  43. authHeader := fmt.Appendf(nil, "Bearer %s", r.config.Secret)
  44. return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
  45. if subtle.ConstantTimeCompare([]byte(req.Header.Get(httpheaders.Authorization)), authHeader) == 1 {
  46. return h(reqID, rw, req)
  47. } else {
  48. return newInvalidSecretError()
  49. }
  50. }
  51. }
  52. // WithPanic recovers panic and converts it to normal error
  53. func (r *Router) WithPanic(h RouteHandler) RouteHandler {
  54. return func(reqID string, rw http.ResponseWriter, r *http.Request) (retErr error) {
  55. defer func() {
  56. // try to recover from panic
  57. rerr := recover()
  58. if rerr == nil {
  59. return
  60. }
  61. // abort handler is an exception of net/http, we should simply repanic it.
  62. // it will supress the stack trace
  63. if rerr == http.ErrAbortHandler {
  64. panic(rerr)
  65. }
  66. // let's recover error value from panic if it has panicked with error
  67. err, ok := rerr.(error)
  68. if !ok {
  69. err = fmt.Errorf("panic: %v", err)
  70. }
  71. retErr = err
  72. }()
  73. return h(reqID, rw, r)
  74. }
  75. }
  76. // WithReportError handles error reporting.
  77. // It should be placed after `WithMetrics`, but before `WithPanic`.
  78. func (r *Router) WithReportError(h RouteHandler) RouteHandler {
  79. return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
  80. // Open the error context
  81. ctx := errorreport.StartRequest(req)
  82. req = req.WithContext(ctx)
  83. errorreport.SetMetadata(req, "Request ID", reqID)
  84. // Call the underlying handler passing the context downwards
  85. err := h(reqID, rw, req)
  86. if err == nil {
  87. return nil
  88. }
  89. // Wrap a resulting error into ierrors.Error
  90. ierr := ierrors.Wrap(err, 0)
  91. // Get the error category
  92. errCat := ierr.Category()
  93. // Exception: any context.DeadlineExceeded error is timeout
  94. if errors.Is(ierr, context.DeadlineExceeded) {
  95. errCat = categoryTimeout
  96. }
  97. // We do not need to send any canceled context
  98. if !errors.Is(ierr, context.Canceled) {
  99. metrics.SendError(ctx, errCat, err)
  100. }
  101. // Report error to error collectors
  102. if ierr.ShouldReport() {
  103. errorreport.Report(ierr, req)
  104. }
  105. // Log response and format the error output
  106. LogResponse(reqID, req, ierr.StatusCode(), ierr)
  107. // Error message: either is public message or full development error
  108. rw.Header().Set(httpheaders.ContentType, "text/plain")
  109. rw.WriteHeader(ierr.StatusCode())
  110. if r.config.DevelopmentErrorsMode {
  111. rw.Write([]byte(ierr.Error()))
  112. } else {
  113. rw.Write([]byte(ierr.PublicMessage()))
  114. }
  115. return nil
  116. }
  117. }