middlewares.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. package server
  2. import (
  3. "context"
  4. "crypto/subtle"
  5. "errors"
  6. "fmt"
  7. "net/http"
  8. "github.com/imgproxy/imgproxy/v3/errorreport"
  9. "github.com/imgproxy/imgproxy/v3/httpheaders"
  10. "github.com/imgproxy/imgproxy/v3/ierrors"
  11. )
  12. const (
  13. categoryTimeout = "timeout"
  14. )
  15. // WithMonitoring wraps RouteHandler with monitoring handling.
  16. func (r *Router) WithMonitoring(h RouteHandler) RouteHandler {
  17. if !r.monitoring.Enabled() {
  18. return h
  19. }
  20. return func(reqID string, rw ResponseWriter, req *http.Request) error {
  21. ctx, cancel, newRw := r.monitoring.StartRequest(req.Context(), rw.HTTPResponseWriter(), req)
  22. defer cancel()
  23. // Replace rw.ResponseWriter with new one returned from monitoring
  24. rw.SetHTTPResponseWriter(newRw)
  25. return h(reqID, rw, req.WithContext(ctx))
  26. }
  27. }
  28. // WithCORS wraps RouteHandler with CORS handling
  29. func (r *Router) WithCORS(h RouteHandler) RouteHandler {
  30. if len(r.config.CORSAllowOrigin) == 0 {
  31. return h
  32. }
  33. return func(reqID string, rw ResponseWriter, req *http.Request) error {
  34. rw.Header().Set(httpheaders.AccessControlAllowOrigin, r.config.CORSAllowOrigin)
  35. rw.Header().Set(httpheaders.AccessControlAllowMethods, "GET, OPTIONS")
  36. return h(reqID, rw, req)
  37. }
  38. }
  39. // WithSecret wraps RouteHandler with secret handling
  40. func (r *Router) WithSecret(h RouteHandler) RouteHandler {
  41. if len(r.config.Secret) == 0 {
  42. return h
  43. }
  44. authHeader := fmt.Appendf(nil, "Bearer %s", r.config.Secret)
  45. return func(reqID string, rw ResponseWriter, req *http.Request) error {
  46. if subtle.ConstantTimeCompare([]byte(req.Header.Get(httpheaders.Authorization)), authHeader) == 1 {
  47. return h(reqID, rw, req)
  48. } else {
  49. return newInvalidSecretError()
  50. }
  51. }
  52. }
  53. // WithPanic recovers panic and converts it to normal error
  54. func (r *Router) WithPanic(h RouteHandler) RouteHandler {
  55. return func(reqID string, rw ResponseWriter, r *http.Request) (retErr error) {
  56. defer func() {
  57. // try to recover from panic
  58. rerr := recover()
  59. if rerr == nil {
  60. return
  61. }
  62. // abort handler is an exception of net/http, we should simply repanic it.
  63. // it will supress the stack trace
  64. if rerr == http.ErrAbortHandler {
  65. panic(rerr)
  66. }
  67. // let's recover error value from panic if it has panicked with error
  68. err, ok := rerr.(error)
  69. if !ok {
  70. err = fmt.Errorf("panic: %v", err)
  71. }
  72. retErr = ierrors.Wrap(err, 1)
  73. }()
  74. return h(reqID, rw, r)
  75. }
  76. }
  77. // WithReportError handles error reporting.
  78. // It should be placed after `WithMonitoring`, but before `WithPanic`.
  79. func (r *Router) WithReportError(h RouteHandler) RouteHandler {
  80. return func(reqID string, rw ResponseWriter, req *http.Request) error {
  81. // Open the error context
  82. ctx := errorreport.StartRequest(req)
  83. req = req.WithContext(ctx)
  84. errorreport.SetMetadata(req, "Request ID", reqID)
  85. // Call the underlying handler passing the context downwards
  86. err := h(reqID, rw, req)
  87. if err == nil {
  88. return nil
  89. }
  90. // Wrap a resulting error into ierrors.Error
  91. ierr := ierrors.Wrap(err, 0)
  92. // Get the error category
  93. errCat := ierr.Category()
  94. // Exception: any context.DeadlineExceeded error is timeout
  95. if errors.Is(ierr, context.DeadlineExceeded) {
  96. errCat = categoryTimeout
  97. }
  98. // We do not need to send any canceled context
  99. if !errors.Is(ierr, context.Canceled) {
  100. r.monitoring.SendError(ctx, errCat, err)
  101. }
  102. // Report error to error collectors
  103. if ierr.ShouldReport() {
  104. r.errorReporter.Report(ierr, req)
  105. }
  106. // Log response and format the error output
  107. LogResponse(reqID, req, ierr.StatusCode(), ierr)
  108. // Error message: either is public message or full development error
  109. rw.Header().Set(httpheaders.ContentType, "text/plain")
  110. rw.WriteHeader(ierr.StatusCode())
  111. if r.config.DevelopmentErrorsMode {
  112. rw.Write([]byte(ierr.Error()))
  113. } else {
  114. rw.Write([]byte(ierr.PublicMessage()))
  115. }
  116. return nil
  117. }
  118. }