writer.go 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. package responsewriter
  2. import (
  3. "fmt"
  4. "net/http"
  5. "strconv"
  6. "strings"
  7. "sync"
  8. "time"
  9. "github.com/imgproxy/imgproxy/v3/httpheaders"
  10. )
  // Private aliases for [http.ResponseWriter] and [*http.ResponseController].
// Embedding them in [Writer] under these unexported names keeps the embedded
// fields themselves unexported, so code outside the package cannot reach
// them by field name.
type (
	httpResponseWriter     = http.ResponseWriter
	httpResponseController = *http.ResponseController
)
  15. // Writer is an implementation of [http.ResponseWriter] with additional
  16. // functionality for managing response headers.
  17. type Writer struct {
  18. httpResponseWriter
  19. httpResponseController
  20. config *Config // Configuration for the writer
  21. originHeaders http.Header // Original response headers
  22. result http.Header // Headers to be written to the response
  23. maxAge int // Current max age for Cache-Control header
  24. beforeWriteOnce sync.Once
  25. }
  26. // HTTPResponseWriter returns the underlying http.ResponseWriter.
  27. func (w *Writer) HTTPResponseWriter() http.ResponseWriter {
  28. return w.httpResponseWriter
  29. }
  30. // SetHTTPResponseWriter replaces the underlying http.ResponseWriter.
  31. func (w *Writer) SetHTTPResponseWriter(rw http.ResponseWriter) {
  32. w.httpResponseWriter = rw
  33. w.httpResponseController = http.NewResponseController(rw)
  34. }
  35. // SetOriginHeaders sets the origin headers for the request.
  36. func (w *Writer) SetOriginHeaders(h http.Header) {
  37. w.originHeaders = h
  38. }
  39. // SetIsFallbackImage sets the Fallback-Image header to
  40. // indicate that the fallback image was used.
  41. func (w *Writer) SetIsFallbackImage() {
  42. // We set maxAge to FallbackImageTTL if it's explicitly passed
  43. if w.config.FallbackImageTTL < 0 {
  44. return
  45. }
  46. // However, we should not overwrite existing value if set (or greater than ours)
  47. if w.maxAge < 0 || w.maxAge > w.config.FallbackImageTTL {
  48. w.maxAge = w.config.FallbackImageTTL
  49. }
  50. }
  51. // SetExpires sets the TTL from time
  52. func (w *Writer) SetExpires(expires *time.Time) {
  53. if expires == nil {
  54. return
  55. }
  56. // Convert current maxAge to time
  57. currentMaxAgeTime := time.Now().Add(time.Duration(w.maxAge) * time.Second)
  58. // If maxAge outlives expires or was not set, we'll use expires as maxAge.
  59. if w.maxAge < 0 || expires.Before(currentMaxAgeTime) {
  60. w.maxAge = min(w.config.DefaultTTL, max(0, int(time.Until(*expires).Seconds())))
  61. }
  62. }
  63. // SetVary sets the Vary header
  64. func (w *Writer) SetVary() {
  65. if val := w.config.VaryValue; len(val) > 0 {
  66. w.result.Set(httpheaders.Vary, val)
  67. }
  68. }
  69. // SetContentDisposition sets the Content-Disposition header, passthrough to ContentDispositionValue
  70. func (w *Writer) SetContentDisposition(originURL, filename, ext, contentType string, returnAttachment bool) {
  71. value := httpheaders.ContentDispositionValue(
  72. originURL,
  73. filename,
  74. ext,
  75. contentType,
  76. returnAttachment,
  77. )
  78. if value != "" {
  79. w.result.Set(httpheaders.ContentDisposition, value)
  80. }
  81. }
  82. // Passthrough copies specified headers from the original response headers to the response headers.
  83. func (w *Writer) Passthrough(only ...string) {
  84. httpheaders.Copy(w.originHeaders, w.result, only)
  85. }
  86. // CopyFrom copies specified headers from the headers object. Please note that
  87. // all the past operations may overwrite those values.
  88. func (w *Writer) CopyFrom(headers http.Header, only []string) {
  89. httpheaders.Copy(headers, w.result, only)
  90. }
  91. // SetContentLength sets the Content-Length header
  92. func (w *Writer) SetContentLength(contentLength int) {
  93. if contentLength < 0 {
  94. return
  95. }
  96. w.result.Set(httpheaders.ContentLength, strconv.Itoa(contentLength))
  97. }
  98. // SetContentType sets the Content-Type header
  99. func (w *Writer) SetContentType(mime string) {
  100. w.result.Set(httpheaders.ContentType, mime)
  101. }
  102. // writeCanonical sets the Link header with the canonical URL.
  103. // It is mandatory for any response if enabled in the configuration.
  104. func (w *Writer) SetCanonical(url string) {
  105. if !w.config.SetCanonicalHeader {
  106. return
  107. }
  108. if strings.HasPrefix(url, "https://") || strings.HasPrefix(url, "http://") {
  109. value := fmt.Sprintf(`<%s>; rel="canonical"`, url)
  110. w.result.Set(httpheaders.Link, value)
  111. }
  112. }
  113. // setCacheControl sets the Cache-Control header with the specified value.
  114. func (w *Writer) setCacheControl(value int) bool {
  115. if value <= 0 {
  116. return false
  117. }
  118. w.result.Set(httpheaders.CacheControl, fmt.Sprintf("max-age=%d, public", value))
  119. return true
  120. }
  121. // setCacheControlNoCache sets the Cache-Control header to no-cache (default).
  122. func (w *Writer) setCacheControlNoCache() {
  123. w.result.Set(httpheaders.CacheControl, "no-cache")
  124. }
  125. // setCacheControlPassthrough sets the Cache-Control header from the request
  126. // if passthrough is enabled in the configuration.
  127. func (w *Writer) setCacheControlPassthrough() bool {
  128. if !w.config.CacheControlPassthrough || w.maxAge > 0 {
  129. return false
  130. }
  131. if val := w.originHeaders.Get(httpheaders.CacheControl); val != "" {
  132. w.result.Set(httpheaders.CacheControl, val)
  133. return true
  134. }
  135. if val := w.originHeaders.Get(httpheaders.Expires); val != "" {
  136. if t, err := time.Parse(http.TimeFormat, val); err == nil {
  137. maxAge := max(0, int(time.Until(t).Seconds()))
  138. return w.setCacheControl(maxAge)
  139. }
  140. }
  141. return false
  142. }
  143. // setCSP sets the Content-Security-Policy header to prevent script execution.
  144. func (w *Writer) setCSP() {
  145. w.result.Set(httpheaders.ContentSecurityPolicy, "script-src 'none'")
  146. }
  147. // flushHeaders writes the headers to the response writer. It does not overwrite
  148. // target headers, which were set outside the header writer.
  149. func (w *Writer) flushHeaders() {
  150. // Then, let's try to set Cache-Control using priority order
  151. switch {
  152. case w.setCacheControl(w.maxAge): // First, try set explicit
  153. case w.setCacheControlPassthrough(): // Try to pick up from request headers
  154. case w.setCacheControl(w.config.DefaultTTL): // Fallback to default value
  155. default:
  156. w.setCacheControlNoCache() // By default we use no-cache
  157. }
  158. w.setCSP()
  159. // Copy all headers to the response without overwriting existing ones
  160. httpheaders.CopyAll(w.result, w.Header(), false)
  161. }
  162. // beforeWrite is called before [WriteHeader] and [Write]
  163. func (w *Writer) beforeWrite() {
  164. w.beforeWriteOnce.Do(func() {
  165. // We're going to start writing response.
  166. // Set write deadline.
  167. w.SetWriteDeadline(time.Now().Add(w.config.WriteResponseTimeout))
  168. // Flush headers before we write anything
  169. w.flushHeaders()
  170. })
  171. }
  172. // WriteHeader writes the HTTP response header.
  173. //
  174. // It ensures that all headers are flushed before writing the status code.
  175. func (w *Writer) WriteHeader(statusCode int) {
  176. w.beforeWrite()
  177. w.httpResponseWriter.WriteHeader(statusCode)
  178. }
  179. // Write writes the HTTP response body.
  180. //
  181. // It ensures that all headers are flushed before writing the body.
  182. func (w *Writer) Write(b []byte) (int, error) {
  183. w.beforeWrite()
  184. return w.httpResponseWriter.Write(b)
  185. }