Browse Source

Added existing TTL check to SetIsFallbackImage

Viktor Sokolov 1 month ago
parent
commit
56342b404b
4 changed files with 13 additions and 16 deletions
  1. 1 1
      handlers/stream/handler.go
  2. 9 9
      headerwriter/writer.go
  3. 3 5
      headerwriter/writer_test.go
  4. 0 1
      httpheaders/headers.go

+ 1 - 1
handlers/stream/handler.go

@@ -114,7 +114,7 @@ func (s *request) execute(ctx context.Context) error {
 	hw.Passthrough(s.handler.config.PassthroughResponseHeaders) // NOTE: priority? This is lowest as it was
 	hw.SetContentLength(int(res.ContentLength))
 	hw.SetCanonical()
-	hw.SetForceExpires(s.po.Expires)
+	hw.SetExpires(s.po.Expires)
 	hw.Write(s.rw)
 
 	// Write Content-Disposition header

+ 9 - 9
headerwriter/writer.go

@@ -50,16 +50,16 @@ func New(config *Config, originalResponseHeaders http.Header, url string) *Write
 func (w *Writer) SetIsFallbackImage() {
 	// We set maxAge to FallbackImageTTL if it's explicitly passed
 	if w.config.FallbackImageTTL >= 0 {
-		w.maxAge = w.config.FallbackImageTTL
+		// However, we should not overwrite existing value if set (or greater than ours)
+		if w.maxAge < 0 || w.maxAge > w.config.FallbackImageTTL {
+			w.maxAge = w.config.FallbackImageTTL
+		}
 	}
-
-	w.result.Set(httpheaders.FallbackImage, "1")
 }
 
-// SetForceExpires sets the TTL from time
-func (w *Writer) SetForceExpires(force *time.Time) {
-	// Now, if force is passed as well
-	if force == nil {
+// SetExpires sets the TTL from time
+func (w *Writer) SetExpires(expires *time.Time) {
+	if expires == nil {
 		return
 	}
 
@@ -67,8 +67,8 @@ func (w *Writer) SetForceExpires(force *time.Time) {
 	currentMaxAgeTime := time.Now().Add(time.Duration(w.maxAge) * time.Second)
 
 	// If maxAge outlives expires or was not set, we'll use expires as maxAge.
-	if w.maxAge < 0 || force.Before(currentMaxAgeTime) {
-		w.maxAge = min(w.config.DefaultTTL, max(0, int(time.Until(*force).Seconds())))
+	if w.maxAge < 0 || expires.Before(currentMaxAgeTime) {
+		w.maxAge = min(w.config.DefaultTTL, max(0, int(time.Until(*expires).Seconds())))
 	}
 }
 

+ 3 - 5
headerwriter/writer_test.go

@@ -162,7 +162,6 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
 			res: http.Header{
 				httpheaders.CacheControl:          []string{"max-age=1, public"},
 				httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
-				httpheaders.FallbackImage:         []string{"1"},
 			},
 			config: Config{
 				DefaultTTL:       3600,
@@ -183,7 +182,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
 				DefaultTTL: math.MaxInt32,
 			},
 			fn: func(w *Writer) {
-				w.SetForceExpires(&expires)
+				w.SetExpires(&expires)
 			},
 		},
 		{
@@ -192,7 +191,6 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
 			res: http.Header{
 				httpheaders.CacheControl:          []string{fmt.Sprintf("max-age=%s, public", shortExpiresSeconds)},
 				httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
-				httpheaders.FallbackImage:         []string{"1"},
 			},
 			config: Config{
 				DefaultTTL:       math.MaxInt32,
@@ -200,7 +198,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
 			},
 			fn: func(w *Writer) {
 				w.SetIsFallbackImage()
-				w.SetForceExpires(&shortExpires)
+				w.SetExpires(&shortExpires)
 			},
 		},
 		{
@@ -286,7 +284,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
 				DefaultTTL: 3600,
 			},
 			fn: func(w *Writer) {
-				w.SetForceExpires(nil)
+				w.SetExpires(nil)
 			},
 		},
 		{

+ 0 - 1
httpheaders/headers.go

@@ -34,7 +34,6 @@ const (
 	Expect                          = "Expect"
 	ExpectCt                        = "Expect-Ct"
 	Expires                         = "Expires"
-	FallbackImage                   = "Fallback-Image"
 	Forwarded                       = "Forwarded"
 	Host                            = "Host"
 	IfMatch                         = "If-Match"