openai.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. package openai
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "strings"
  8. "time"
  9. "github.com/0xJacky/Nginx-UI/internal/llm"
  10. "github.com/0xJacky/Nginx-UI/settings"
  11. "github.com/gin-gonic/gin"
  12. "github.com/sashabaranov/go-openai"
  13. "github.com/uozi-tech/cosy"
  14. "github.com/uozi-tech/cosy/logger"
  15. )
  // ChatGPTInitPrompt is the system message that opens the conversation built by
// MakeChatCompletionRequest. It asks the model to help with Nginx configuration,
// explains that the first user message carries the opened file's content plus
// the current language code (CLC), and asks for the first reply in the CLC's
// language.
//
// NOTE(review): the wording has grammatical slips ("a assistant", "You suppose
// to use"). Fixing them changes what the model sees, so do that deliberately.
const ChatGPTInitPrompt = "You are a assistant who can help users write and optimise the configurations of Nginx,\n" +
	"the first user message contains the content of the configuration file which is currently opened by the user and\n" +
	"the current language code(CLC). You suppose to use the language corresponding to the CLC to give the first reply.\n" +
	"Later the language environment depends on the user message.\n" +
	"The first reply should involve the key information of the file and ask user what can you help them."
  21. func MakeChatCompletionRequest(c *gin.Context) {
  22. var json struct {
  23. Filepath string `json:"filepath"`
  24. Messages []openai.ChatCompletionMessage `json:"messages"`
  25. }
  26. if !cosy.BindAndValid(c, &json) {
  27. return
  28. }
  29. messages := []openai.ChatCompletionMessage{
  30. {
  31. Role: openai.ChatMessageRoleSystem,
  32. Content: ChatGPTInitPrompt,
  33. },
  34. }
  35. messages = append(messages, json.Messages...)
  36. if json.Filepath != "" {
  37. messages = llm.ChatCompletionWithContext(json.Filepath, messages)
  38. }
  39. // SSE server
  40. c.Writer.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
  41. c.Writer.Header().Set("Cache-Control", "no-cache")
  42. c.Writer.Header().Set("Connection", "keep-alive")
  43. c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
  44. openaiClient, err := llm.GetClient()
  45. if err != nil {
  46. c.Stream(func(w io.Writer) bool {
  47. c.SSEvent("message", gin.H{
  48. "type": "error",
  49. "content": err.Error(),
  50. })
  51. return false
  52. })
  53. return
  54. }
  55. ctx := context.Background()
  56. req := openai.ChatCompletionRequest{
  57. Model: settings.OpenAISettings.Model,
  58. Messages: messages,
  59. Stream: true,
  60. }
  61. stream, err := openaiClient.CreateChatCompletionStream(ctx, req)
  62. if err != nil {
  63. logger.Errorf("CompletionStream error: %v\n", err)
  64. c.Stream(func(w io.Writer) bool {
  65. c.SSEvent("message", gin.H{
  66. "type": "error",
  67. "content": err.Error(),
  68. })
  69. return false
  70. })
  71. return
  72. }
  73. defer stream.Close()
  74. msgChan := make(chan string)
  75. go func() {
  76. defer close(msgChan)
  77. messageCh := make(chan string)
  78. // 消息接收协程
  79. go func() {
  80. defer close(messageCh)
  81. for {
  82. response, err := stream.Recv()
  83. if errors.Is(err, io.EOF) {
  84. return
  85. }
  86. if err != nil {
  87. messageCh <- fmt.Sprintf("error: %v", err)
  88. logger.Errorf("Stream error: %v\n", err)
  89. return
  90. }
  91. messageCh <- response.Choices[0].Delta.Content
  92. }
  93. }()
  94. ticker := time.NewTicker(500 * time.Millisecond)
  95. defer ticker.Stop()
  96. var buffer strings.Builder
  97. for {
  98. select {
  99. case msg, ok := <-messageCh:
  100. if !ok {
  101. if buffer.Len() > 0 {
  102. msgChan <- buffer.String()
  103. }
  104. return
  105. }
  106. if strings.HasPrefix(msg, "error: ") {
  107. msgChan <- msg
  108. return
  109. }
  110. buffer.WriteString(msg)
  111. case <-ticker.C:
  112. if buffer.Len() > 0 {
  113. msgChan <- buffer.String()
  114. buffer.Reset()
  115. }
  116. }
  117. }
  118. }()
  119. c.Stream(func(w io.Writer) bool {
  120. m, ok := <-msgChan
  121. if !ok {
  122. return false
  123. }
  124. if strings.HasPrefix(m, "error: ") {
  125. c.SSEvent("message", gin.H{
  126. "type": "error",
  127. "content": strings.TrimPrefix(m, "error: "),
  128. })
  129. return false
  130. }
  131. c.SSEvent("message", gin.H{
  132. "type": "message",
  133. "content": m,
  134. })
  135. return true
  136. })
  137. }