azure.go 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. package azure
import (
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strconv"
	"strings"

	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
	"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
	"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
	"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
	"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob"
	"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blockblob"

	"github.com/imgproxy/imgproxy/v3/fetcher/transport/common"
	"github.com/imgproxy/imgproxy/v3/fetcher/transport/notmodified"
	"github.com/imgproxy/imgproxy/v3/httpheaders"
	"github.com/imgproxy/imgproxy/v3/httprange"
)
// transport is the http.RoundTripper returned by New. Its RoundTrip method
// answers requests by downloading blobs through the wrapped Azure client.
type transport struct {
	// client is the Azure Blob Storage client used for every download.
	client *azblob.Client
}
  23. func New(config *Config, trans *http.Transport) (http.RoundTripper, error) {
  24. if err := config.Validate(); err != nil {
  25. return nil, err
  26. }
  27. var (
  28. client *azblob.Client
  29. sharedKeyCredential *azblob.SharedKeyCredential
  30. defaultAzureCredential *azidentity.DefaultAzureCredential
  31. err error
  32. )
  33. endpoint := config.Endpoint
  34. if len(endpoint) == 0 {
  35. endpoint = fmt.Sprintf("https://%s.blob.core.windows.net", config.Name)
  36. }
  37. endpointURL, err := url.Parse(endpoint)
  38. if err != nil {
  39. return nil, err
  40. }
  41. opts := azblob.ClientOptions{
  42. ClientOptions: policy.ClientOptions{
  43. Transport: &http.Client{Transport: trans},
  44. },
  45. }
  46. if len(config.Key) > 0 {
  47. sharedKeyCredential, err = azblob.NewSharedKeyCredential(config.Name, config.Key)
  48. if err != nil {
  49. return nil, err
  50. }
  51. client, err = azblob.NewClientWithSharedKeyCredential(endpointURL.String(), sharedKeyCredential, &opts)
  52. } else {
  53. defaultAzureCredential, err = azidentity.NewDefaultAzureCredential(nil)
  54. if err != nil {
  55. return nil, err
  56. }
  57. client, err = azblob.NewClient(endpointURL.String(), defaultAzureCredential, &opts)
  58. }
  59. if err != nil {
  60. return nil, err
  61. }
  62. return transport{client}, nil
  63. }
  64. func (t transport) RoundTrip(req *http.Request) (*http.Response, error) {
  65. container, key, _ := common.GetBucketAndKey(req.URL)
  66. if len(container) == 0 || len(key) == 0 {
  67. body := strings.NewReader("Invalid ABS URL: container name or object key is empty")
  68. return &http.Response{
  69. StatusCode: http.StatusNotFound,
  70. Proto: "HTTP/1.0",
  71. ProtoMajor: 1,
  72. ProtoMinor: 0,
  73. Header: http.Header{httpheaders.ContentType: {"text/plain"}},
  74. ContentLength: int64(body.Len()),
  75. Body: io.NopCloser(body),
  76. Close: false,
  77. Request: req,
  78. }, nil
  79. }
  80. statusCode := http.StatusOK
  81. header := make(http.Header)
  82. opts := &blob.DownloadStreamOptions{}
  83. if r := req.Header.Get(httpheaders.Range); len(r) != 0 {
  84. start, end, err := httprange.Parse(r)
  85. if err != nil {
  86. return httprange.InvalidHTTPRangeResponse(req), nil
  87. }
  88. if end != 0 {
  89. length := end - start + 1
  90. if end <= 0 {
  91. length = blockblob.CountToEnd
  92. }
  93. opts.Range = blob.HTTPRange{
  94. Offset: start,
  95. Count: length,
  96. }
  97. }
  98. statusCode = http.StatusPartialContent
  99. }
  100. result, err := t.client.DownloadStream(req.Context(), container, key, opts)
  101. if err != nil {
  102. if azError, ok := err.(*azcore.ResponseError); !ok || azError.StatusCode < 100 || azError.StatusCode == 301 {
  103. return nil, err
  104. } else {
  105. body := strings.NewReader(azError.Error())
  106. return &http.Response{
  107. StatusCode: azError.StatusCode,
  108. Proto: "HTTP/1.0",
  109. ProtoMajor: 1,
  110. ProtoMinor: 0,
  111. Header: http.Header{"Content-Type": {"text/plain"}},
  112. ContentLength: int64(body.Len()),
  113. Body: io.NopCloser(body),
  114. Close: false,
  115. Request: req,
  116. }, nil
  117. }
  118. }
  119. if result.ETag != nil {
  120. etag := string(*result.ETag)
  121. header.Set(httpheaders.Etag, etag)
  122. }
  123. if result.LastModified != nil {
  124. lastModified := result.LastModified.Format(http.TimeFormat)
  125. header.Set(httpheaders.LastModified, lastModified)
  126. }
  127. if resp := notmodified.Response(req, header); resp != nil {
  128. if result.Body != nil {
  129. result.Body.Close()
  130. }
  131. return resp, nil
  132. }
  133. header.Set(httpheaders.AcceptRanges, "bytes")
  134. contentLength := int64(0)
  135. if result.ContentLength != nil {
  136. contentLength = *result.ContentLength
  137. header.Set(httpheaders.ContentLength, strconv.FormatInt(*result.ContentLength, 10))
  138. }
  139. if result.ContentType != nil {
  140. header.Set(httpheaders.ContentType, *result.ContentType)
  141. }
  142. if result.ContentRange != nil {
  143. header.Set(httpheaders.ContentRange, *result.ContentRange)
  144. }
  145. if result.CacheControl != nil {
  146. header.Set(httpheaders.CacheControl, *result.CacheControl)
  147. }
  148. return &http.Response{
  149. StatusCode: statusCode,
  150. Proto: "HTTP/1.0",
  151. ProtoMajor: 1,
  152. ProtoMinor: 0,
  153. Header: header,
  154. ContentLength: contentLength,
  155. Body: result.Body,
  156. Close: true,
  157. Request: req,
  158. }, nil
  159. }