azure.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. package azure
  2. import (
  3. "errors"
  4. "fmt"
  5. "io"
  6. "net/http"
  7. "net/url"
  8. "strconv"
  9. "strings"
  10. "github.com/Azure/azure-sdk-for-go/sdk/azcore"
  11. "github.com/Azure/azure-sdk-for-go/sdk/azidentity"
  12. "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
  13. "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob"
  14. "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blockblob"
  15. "github.com/imgproxy/imgproxy/v3/config"
  16. "github.com/imgproxy/imgproxy/v3/ctxreader"
  17. "github.com/imgproxy/imgproxy/v3/httprange"
  18. )
  19. type transport struct {
  20. client *azblob.Client
  21. }
  22. func New() (http.RoundTripper, error) {
  23. var (
  24. client *azblob.Client
  25. sharedKeyCredential *azblob.SharedKeyCredential
  26. defaultAzureCredential *azidentity.DefaultAzureCredential
  27. err error
  28. )
  29. if len(config.ABSName) == 0 {
  30. return nil, errors.New("IMGPROXY_ABS_NAME must be set")
  31. }
  32. endpoint := config.ABSEndpoint
  33. if len(endpoint) == 0 {
  34. endpoint = fmt.Sprintf("https://%s.blob.core.windows.net", config.ABSName)
  35. }
  36. endpointURL, err := url.Parse(endpoint)
  37. if err != nil {
  38. return nil, err
  39. }
  40. if len(config.ABSKey) > 0 {
  41. sharedKeyCredential, err = azblob.NewSharedKeyCredential(config.ABSName, config.ABSKey)
  42. if err != nil {
  43. return nil, err
  44. }
  45. client, err = azblob.NewClientWithSharedKeyCredential(endpointURL.String(), sharedKeyCredential, nil)
  46. } else {
  47. defaultAzureCredential, err = azidentity.NewDefaultAzureCredential(nil)
  48. if err != nil {
  49. return nil, err
  50. }
  51. client, err = azblob.NewClient(endpointURL.String(), defaultAzureCredential, nil)
  52. }
  53. if err != nil {
  54. return nil, err
  55. }
  56. return transport{client}, nil
  57. }
  58. func (t transport) RoundTrip(req *http.Request) (*http.Response, error) {
  59. container := req.URL.Host
  60. key := req.URL.Path
  61. statusCode := http.StatusOK
  62. header := make(http.Header)
  63. opts := &blob.DownloadStreamOptions{}
  64. if r := req.Header.Get("Range"); len(r) != 0 {
  65. start, end, err := httprange.Parse(r)
  66. if err != nil {
  67. return httprange.InvalidHTTPRangeResponse(req), err
  68. }
  69. if end != 0 {
  70. length := end - start + 1
  71. if end <= 0 {
  72. length = blockblob.CountToEnd
  73. }
  74. opts.Range = blob.HTTPRange{
  75. Offset: start,
  76. Count: length,
  77. }
  78. }
  79. statusCode = http.StatusPartialContent
  80. }
  81. result, err := t.client.DownloadStream(req.Context(), container, strings.TrimPrefix(key, "/"), opts)
  82. if err != nil {
  83. if azError, ok := err.(*azcore.ResponseError); !ok || azError.StatusCode < 100 || azError.StatusCode == 301 {
  84. return nil, err
  85. } else {
  86. body := strings.NewReader(azError.Error())
  87. return &http.Response{
  88. StatusCode: azError.StatusCode,
  89. Proto: "HTTP/1.0",
  90. ProtoMajor: 1,
  91. ProtoMinor: 0,
  92. Header: header,
  93. ContentLength: int64(body.Len()),
  94. Body: io.NopCloser(body),
  95. Close: false,
  96. Request: req,
  97. }, nil
  98. }
  99. }
  100. if config.ETagEnabled && result.ETag != nil {
  101. azETag := string(*result.ETag)
  102. header.Set("ETag", azETag)
  103. if etag := req.Header.Get("If-None-Match"); len(etag) > 0 && azETag == etag {
  104. return &http.Response{
  105. StatusCode: http.StatusNotModified,
  106. Proto: "HTTP/1.0",
  107. ProtoMajor: 1,
  108. ProtoMinor: 0,
  109. Header: header,
  110. ContentLength: 0,
  111. Body: nil,
  112. Close: false,
  113. Request: req,
  114. }, nil
  115. }
  116. }
  117. header.Set("Accept-Ranges", "bytes")
  118. contentLength := int64(0)
  119. if result.ContentLength != nil {
  120. contentLength = *result.ContentLength
  121. header.Set("Content-Length", strconv.FormatInt(*result.ContentLength, 10))
  122. }
  123. if result.ContentType != nil {
  124. header.Set("Content-Type", *result.ContentType)
  125. }
  126. if result.ContentRange != nil {
  127. header.Set("Content-Range", *result.ContentRange)
  128. }
  129. if result.CacheControl != nil {
  130. header.Set("Cache-Control", *result.CacheControl)
  131. }
  132. return &http.Response{
  133. StatusCode: statusCode,
  134. Proto: "HTTP/1.0",
  135. ProtoMajor: 1,
  136. ProtoMinor: 0,
  137. Header: header,
  138. ContentLength: contentLength,
  139. Body: ctxreader.New(req.Context(), result.Body, true),
  140. Close: true,
  141. Request: req,
  142. }, nil
  143. }