api.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. package eventstreamapi
  2. import (
  3. "fmt"
  4. "io"
  5. "github.com/aws/aws-sdk-go/aws"
  6. "github.com/aws/aws-sdk-go/private/protocol"
  7. "github.com/aws/aws-sdk-go/private/protocol/eventstream"
  8. )
  9. // Unmarshaler provides the interface for unmarshaling a EventStream
  10. // message into a SDK type.
  11. type Unmarshaler interface {
  12. UnmarshalEvent(protocol.PayloadUnmarshaler, eventstream.Message) error
  13. }
  // EventStream header names and values with specific meaning to async API
  // functionality.
  const (
  	// MessageTypeHeader names the header that identifies the kind of message.
  	MessageTypeHeader = ":message-type"

  	// Values of MessageTypeHeader recognized by EventReader.ReadEvent.
  	EventMessageType     = "event"
  	ErrorMessageType     = "error"
  	ExceptionMessageType = "exception"
  )

  // Headers carried by "event" messages.
  const (
  	// EventTypeHeader identifies the message's event type, e.g. "Stats".
  	EventTypeHeader = ":event-type"
  )

  // Headers carried by "error" messages.
  const (
  	ErrorCodeHeader    = ":error-code"
  	ErrorMessageHeader = ":error-message"
  )

  // Headers carried by "exception" messages.
  const (
  	ExceptionTypeHeader = ":exception-type"
  )
  28. // EventReader provides reading from the EventStream of an reader.
  29. type EventReader struct {
  30. reader io.ReadCloser
  31. decoder *eventstream.Decoder
  32. unmarshalerForEventType func(string) (Unmarshaler, error)
  33. payloadUnmarshaler protocol.PayloadUnmarshaler
  34. payloadBuf []byte
  35. }
  36. // NewEventReader returns a EventReader built from the reader and unmarshaler
  37. // provided. Use ReadStream method to start reading from the EventStream.
  38. func NewEventReader(
  39. reader io.ReadCloser,
  40. payloadUnmarshaler protocol.PayloadUnmarshaler,
  41. unmarshalerForEventType func(string) (Unmarshaler, error),
  42. ) *EventReader {
  43. return &EventReader{
  44. reader: reader,
  45. decoder: eventstream.NewDecoder(reader),
  46. payloadUnmarshaler: payloadUnmarshaler,
  47. unmarshalerForEventType: unmarshalerForEventType,
  48. payloadBuf: make([]byte, 10*1024),
  49. }
  50. }
  51. // UseLogger instructs the EventReader to use the logger and log level
  52. // specified.
  53. func (r *EventReader) UseLogger(logger aws.Logger, logLevel aws.LogLevelType) {
  54. if logger != nil && logLevel.Matches(aws.LogDebugWithEventStreamBody) {
  55. r.decoder.UseLogger(logger)
  56. }
  57. }
  58. // ReadEvent attempts to read a message from the EventStream and return the
  59. // unmarshaled event value that the message is for.
  60. //
  61. // For EventStream API errors check if the returned error satisfies the
  62. // awserr.Error interface to get the error's Code and Message components.
  63. //
  64. // EventUnmarshalers called with EventStream messages must take copies of the
  65. // message's Payload. The payload will is reused between events read.
  66. func (r *EventReader) ReadEvent() (event interface{}, err error) {
  67. msg, err := r.decoder.Decode(r.payloadBuf)
  68. if err != nil {
  69. return nil, err
  70. }
  71. defer func() {
  72. // Reclaim payload buffer for next message read.
  73. r.payloadBuf = msg.Payload[0:0]
  74. }()
  75. typ, err := GetHeaderString(msg, MessageTypeHeader)
  76. if err != nil {
  77. return nil, err
  78. }
  79. switch typ {
  80. case EventMessageType:
  81. return r.unmarshalEventMessage(msg)
  82. case ExceptionMessageType:
  83. err = r.unmarshalEventException(msg)
  84. return nil, err
  85. case ErrorMessageType:
  86. return nil, r.unmarshalErrorMessage(msg)
  87. default:
  88. return nil, fmt.Errorf("unknown eventstream message type, %v", typ)
  89. }
  90. }
  91. func (r *EventReader) unmarshalEventMessage(
  92. msg eventstream.Message,
  93. ) (event interface{}, err error) {
  94. eventType, err := GetHeaderString(msg, EventTypeHeader)
  95. if err != nil {
  96. return nil, err
  97. }
  98. ev, err := r.unmarshalerForEventType(eventType)
  99. if err != nil {
  100. return nil, err
  101. }
  102. err = ev.UnmarshalEvent(r.payloadUnmarshaler, msg)
  103. if err != nil {
  104. return nil, err
  105. }
  106. return ev, nil
  107. }
  108. func (r *EventReader) unmarshalEventException(
  109. msg eventstream.Message,
  110. ) (err error) {
  111. eventType, err := GetHeaderString(msg, ExceptionTypeHeader)
  112. if err != nil {
  113. return err
  114. }
  115. ev, err := r.unmarshalerForEventType(eventType)
  116. if err != nil {
  117. return err
  118. }
  119. err = ev.UnmarshalEvent(r.payloadUnmarshaler, msg)
  120. if err != nil {
  121. return err
  122. }
  123. var ok bool
  124. err, ok = ev.(error)
  125. if !ok {
  126. err = messageError{
  127. code: "SerializationError",
  128. msg: fmt.Sprintf(
  129. "event stream exception %s mapped to non-error %T, %v",
  130. eventType, ev, ev,
  131. ),
  132. }
  133. }
  134. return err
  135. }
  136. func (r *EventReader) unmarshalErrorMessage(msg eventstream.Message) (err error) {
  137. var msgErr messageError
  138. msgErr.code, err = GetHeaderString(msg, ErrorCodeHeader)
  139. if err != nil {
  140. return err
  141. }
  142. msgErr.msg, err = GetHeaderString(msg, ErrorMessageHeader)
  143. if err != nil {
  144. return err
  145. }
  146. return msgErr
  147. }
  148. // Close closes the EventReader's EventStream reader.
  149. func (r *EventReader) Close() error {
  150. return r.reader.Close()
  151. }
  152. // GetHeaderString returns the value of the header as a string. If the header
  153. // is not set or the value is not a string an error will be returned.
  154. func GetHeaderString(msg eventstream.Message, headerName string) (string, error) {
  155. headerVal := msg.Headers.Get(headerName)
  156. if headerVal == nil {
  157. return "", fmt.Errorf("error header %s not present", headerName)
  158. }
  159. v, ok := headerVal.Get().(string)
  160. if !ok {
  161. return "", fmt.Errorf("error header value is not a string, %T", headerVal)
  162. }
  163. return v, nil
  164. }