Browse Source

Refactor tracing

Svyatoslav Kryukov 4 years ago
parent
commit
5ba463749c
8 changed files with 72 additions and 69 deletions
  1. 3 5
      bufpool.go
  2. 2 8
      download.go
  3. 2 4
      main.go
  4. 26 12
      newrelic.go
  5. 2 8
      process.go
  6. 9 21
      processing_handler.go
  7. 26 4
      prometheus.go
  8. 2 7
      timer.go

+ 3 - 5
bufpool.go

@@ -65,10 +65,8 @@ func (p *bufPool) calibrateAndClean() {
 		runtime.GC()
 	}
 
-	if prometheusEnabled {
-		setPrometheusBufferDefaultSize(p.name, p.defaultSize)
-		setPrometheusBufferMaxSize(p.name, p.maxSize)
-	}
+	setPrometheusBufferDefaultSize(p.name, p.defaultSize)
+	setPrometheusBufferMaxSize(p.name, p.maxSize)
 }
 
 func (p *bufPool) Get(size int) *bytes.Buffer {
@@ -143,7 +141,7 @@ func (p *bufPool) Put(buf *bytes.Buffer) {
 		if b == nil {
 			p.buffers[i] = buf
 
-			if prometheusEnabled && buf.Cap() > 0 {
+			if buf.Cap() > 0 {
 				observePrometheusBufferSize(p.name, buf.Cap())
 			}
 

+ 2 - 8
download.go

@@ -214,14 +214,8 @@ func downloadImage(imageURL string) (*imageData, error) {
 func downloadImageCtx(ctx context.Context) (context.Context, context.CancelFunc, error) {
 	imageURL := getImageURL(ctx)
 
-	if newRelicEnabled {
-		newRelicCancel := startNewRelicSegment(ctx, "Downloading image")
-		defer newRelicCancel()
-	}
-
-	if prometheusEnabled {
-		defer startPrometheusDuration(prometheusDownloadDuration)()
-	}
+	defer startNewRelicSegment(ctx, "Downloading image")()
+	defer startPrometheusDuration(prometheusDownloadDuration)()
 
 	imgdata, err := downloadImage(imageURL)
 	if err != nil {

+ 2 - 4
main.go

@@ -73,10 +73,8 @@ func run() error {
 
 	ctx, cancel := context.WithCancel(context.Background())
 
-	if prometheusEnabled {
-		if err := startPrometheusServer(cancel); err != nil {
-			return err
-		}
+	if err := startPrometheusServer(cancel); err != nil {
+		return err
 	}
 
 	s, err := startServer(cancel)

+ 26 - 12
newrelic.go

@@ -9,12 +9,14 @@ import (
 	"github.com/newrelic/go-agent/v3/newrelic"
 )
 
+const (
+	newRelicTransactionCtxKey = ctxKey("newRelicTransaction")
+)
+
 var (
 	newRelicEnabled = false
 
 	newRelicApp *newrelic.Application
-
-	newRelicTransactionCtxKey = ctxKey("newRelicTransaction")
 )
 
 func initNewrelic() error {
@@ -44,6 +46,10 @@ func initNewrelic() error {
 }
 
 func startNewRelicTransaction(ctx context.Context, rw http.ResponseWriter, r *http.Request) (context.Context, context.CancelFunc, http.ResponseWriter) {
+	if !newRelicEnabled {
+		return ctx, func() {}, rw
+	}
+
 	txn := newRelicApp.StartTransaction("request")
 	txn.SetWebRequestHTTP(r)
 	newRw := txn.SetWebResponse(rw)
@@ -52,23 +58,31 @@ func startNewRelicTransaction(ctx context.Context, rw http.ResponseWriter, r *ht
 }
 
 func startNewRelicSegment(ctx context.Context, name string) context.CancelFunc {
+	if !newRelicEnabled {
+		return func() {}
+	}
+
 	txn := ctx.Value(newRelicTransactionCtxKey).(*newrelic.Transaction)
 	segment := txn.StartSegment(name)
 	return func() { segment.End() }
 }
 
 func sendErrorToNewRelic(ctx context.Context, err error) {
-	txn := ctx.Value(newRelicTransactionCtxKey).(*newrelic.Transaction)
-	txn.NoticeError(err)
+	if newRelicEnabled {
+		txn := ctx.Value(newRelicTransactionCtxKey).(*newrelic.Transaction)
+		txn.NoticeError(err)
+	}
 }
 
 func sendTimeoutToNewRelic(ctx context.Context, d time.Duration) {
-	txn := ctx.Value(newRelicTransactionCtxKey).(*newrelic.Transaction)
-	txn.NoticeError(newrelic.Error{
-		Message: "Timeout",
-		Class:   "Timeout",
-		Attributes: map[string]interface{}{
-			"time": d.Seconds(),
-		},
-	})
+	if newRelicEnabled {
+		txn := ctx.Value(newRelicTransactionCtxKey).(*newrelic.Transaction)
+		txn.NoticeError(newrelic.Error{
+			Message: "Timeout",
+			Class:   "Timeout",
+			Attributes: map[string]interface{}{
+				"time": d.Seconds(),
+			},
+		})
+	}
 }

+ 2 - 8
process.go

@@ -762,14 +762,8 @@ func processImage(ctx context.Context) ([]byte, context.CancelFunc, error) {
 	runtime.LockOSThread()
 	defer runtime.UnlockOSThread()
 
-	if newRelicEnabled {
-		newRelicCancel := startNewRelicSegment(ctx, "Processing image")
-		defer newRelicCancel()
-	}
-
-	if prometheusEnabled {
-		defer startPrometheusDuration(prometheusProcessingDuration)()
-	}
+	defer startNewRelicSegment(ctx, "Processing image")()
+	defer startPrometheusDuration(prometheusProcessingDuration)()
 
 	defer vipsCleanup()
 

+ 9 - 21
processing_handler.go

@@ -121,16 +121,12 @@ func respondWithNotModified(ctx context.Context, reqID string, r *http.Request,
 func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) {
 	ctx := r.Context()
 
-	if newRelicEnabled {
-		var newRelicCancel context.CancelFunc
-		ctx, newRelicCancel, rw = startNewRelicTransaction(ctx, rw, r)
-		defer newRelicCancel()
-	}
+	var newRelicCancel context.CancelFunc
+	ctx, newRelicCancel, rw = startNewRelicTransaction(ctx, rw, r)
+	defer newRelicCancel()
 
-	if prometheusEnabled {
-		prometheusRequestsTotal.Inc()
-		defer startPrometheusDuration(prometheusRequestDuration)()
-	}
+	incrementPrometheusRequestsTotal()
+	defer startPrometheusDuration(prometheusRequestDuration)()
 
 	select {
 	case processingSem <- struct{}{}:
@@ -150,12 +146,8 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) {
 	ctx, downloadcancel, err := downloadImageCtx(ctx)
 	defer downloadcancel()
 	if err != nil {
-		if newRelicEnabled {
-			sendErrorToNewRelic(ctx, err)
-		}
-		if prometheusEnabled {
-			incrementPrometheusErrorsTotal("download")
-		}
+		sendErrorToNewRelic(ctx, err)
+		incrementPrometheusErrorsTotal("download")
 
 		if fallbackImage == nil {
 			panic(err)
@@ -202,12 +194,8 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) {
 	imageData, processcancel, err := processImage(ctx)
 	defer processcancel()
 	if err != nil {
-		if newRelicEnabled {
-			sendErrorToNewRelic(ctx, err)
-		}
-		if prometheusEnabled {
-			incrementPrometheusErrorsTotal("processing")
-		}
+		sendErrorToNewRelic(ctx, err)
+		incrementPrometheusErrorsTotal("processing")
 		panic(err)
 	}
 

+ 26 - 4
prometheus.go

@@ -116,6 +116,10 @@ func initPrometheus() {
 }
 
 func startPrometheusServer(cancel context.CancelFunc) error {
+	if !prometheusEnabled {
+		return nil
+	}
+
 	s := http.Server{Handler: promhttp.Handler()}
 
 	l, err := listenReuseport("tcp", conf.PrometheusBind)
@@ -135,6 +139,10 @@ func startPrometheusServer(cancel context.CancelFunc) error {
 }
 
 func startPrometheusDuration(m prometheus.Histogram) func() {
+	if !prometheusEnabled {
+		return func() {}
+	}
+
 	t := time.Now()
 	return func() {
 		m.Observe(time.Since(t).Seconds())
@@ -142,17 +150,31 @@ func startPrometheusDuration(m prometheus.Histogram) func() {
 }
 
 func incrementPrometheusErrorsTotal(t string) {
-	prometheusErrorsTotal.With(prometheus.Labels{"type": t}).Inc()
+	if prometheusEnabled {
+		prometheusErrorsTotal.With(prometheus.Labels{"type": t}).Inc()
+	}
+}
+
+func incrementPrometheusRequestsTotal() {
+	if prometheusEnabled {
+		prometheusRequestsTotal.Inc()
+	}
 }
 
 func observePrometheusBufferSize(t string, size int) {
-	prometheusBufferSize.With(prometheus.Labels{"type": t}).Observe(float64(size))
+	if prometheusEnabled {
+		prometheusBufferSize.With(prometheus.Labels{"type": t}).Observe(float64(size))
+	}
 }
 
 func setPrometheusBufferDefaultSize(t string, size int) {
-	prometheusBufferDefaultSize.With(prometheus.Labels{"type": t}).Set(float64(size))
+	if prometheusEnabled {
+		prometheusBufferDefaultSize.With(prometheus.Labels{"type": t}).Set(float64(size))
+	}
 }
 
 func setPrometheusBufferMaxSize(t string, size int) {
-	prometheusBufferMaxSize.With(prometheus.Labels{"type": t}).Set(float64(size))
+	if prometheusEnabled {
+		prometheusBufferMaxSize.With(prometheus.Labels{"type": t}).Set(float64(size))
+	}
 }

+ 2 - 7
timer.go

@@ -25,13 +25,8 @@ func checkTimeout(ctx context.Context) {
 			panic(newError(499, fmt.Sprintf("Request was cancelled after %v", d), "Cancelled"))
 		}
 
-		if newRelicEnabled {
-			sendTimeoutToNewRelic(ctx, d)
-		}
-
-		if prometheusEnabled {
-			incrementPrometheusErrorsTotal("timeout")
-		}
+		sendTimeoutToNewRelic(ctx, d)
+		incrementPrometheusErrorsTotal("timeout")
 
 		panic(newError(503, fmt.Sprintf("Timeout after %v", d), "Timeout"))
 	default: