code_completion.go 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. package llm
  import (
	"context"
	"errors"
	"regexp"
	"strconv"
	"strings"
	"sync"

	"github.com/0xJacky/Nginx-UI/settings"
	"github.com/sashabaranov/go-openai"
	"github.com/uozi-tech/cosy/logger"
  )
  // Parameters of every code-completion chat request.
  const (
	// MaxTokens is the MaxTokens limit sent with each code-completion request.
	MaxTokens = 100
	// Temperature is the sampling temperature sent with each code-completion request.
	Temperature = 1
	// SystemPrompt is the system message placed before the user message in the
	// chat request sent by Send. Each fragment ends with a space so the
	// concatenated sentences do not run together.
	SystemPrompt = "You are a code completion assistant. " +
		"Complete the provided code snippet based on the context and instruction. " +
		"[IMPORTANT] Keep the original code indentation."
  )
  // Position is the editor cursor location of a completion request.
  //
  // Send writes Row and Column into the prompt verbatim; whether they are 0- or
  // 1-based is up to the client (NOTE(review): confirm against the editor side).
  type Position struct {
	Row    int `json:"row"`    // cursor line
	Column int `json:"column"` // cursor column within Row
  }
  25. // CodeCompletionRequest the code completion request
  26. type CodeCompletionRequest struct {
  27. RequestID string `json:"request_id"`
  28. UserID uint64 `json:"user_id"`
  29. Context string `json:"context"`
  30. Code string `json:"code"`
  31. Suffix string `json:"suffix"`
  32. Language string `json:"language"`
  33. Position Position `json:"position"`
  34. }
  // requestContext maps a user ID to the cancel func of that user's in-flight
  // code-completion request, so a newer request from the same user can cancel
  // the older one. Read and write it only while holding mutex.
  var requestContext = map[uint64]context.CancelFunc{}

  // mutex guards requestContext.
  var mutex sync.Mutex
  39. func (c *CodeCompletionRequest) Send() (completedCode string, err error) {
  40. if cancel, ok := requestContext[c.UserID]; ok {
  41. logger.Infof("Code completion request cancelled for user %d", c.UserID)
  42. cancel()
  43. }
  44. mutex.Lock()
  45. ctx, cancel := context.WithCancel(context.Background())
  46. defer cancel()
  47. requestContext[c.UserID] = cancel
  48. mutex.Unlock()
  49. defer func() {
  50. mutex.Lock()
  51. delete(requestContext, c.UserID)
  52. mutex.Unlock()
  53. }()
  54. openaiClient, err := GetClient()
  55. if err != nil {
  56. return
  57. }
  58. // Build user prompt with code and instruction
  59. userPrompt := "Here is a file written in " + c.Language + ":\n```\n" + c.Context + "\n```\n"
  60. userPrompt += "I'm editing at row " + strconv.Itoa(c.Position.Row) + ", column " + strconv.Itoa(c.Position.Column) + ".\n"
  61. userPrompt += "Code before cursor:\n```\n" + c.Code + "\n```\n"
  62. if c.Suffix != "" {
  63. userPrompt += "Code after cursor:\n```\n" + c.Suffix + "\n```\n"
  64. }
  65. userPrompt += "Instruction: Only provide the completed code that should be inserted at the cursor position without explanations. " +
  66. "The code should be syntactically correct and follow best practices for " + c.Language + "."
  67. messages := []openai.ChatCompletionMessage{
  68. {
  69. Role: openai.ChatMessageRoleSystem,
  70. Content: SystemPrompt,
  71. },
  72. {
  73. Role: openai.ChatMessageRoleUser,
  74. Content: userPrompt,
  75. },
  76. }
  77. req := openai.ChatCompletionRequest{
  78. Model: settings.OpenAISettings.GetCodeCompletionModel(),
  79. Messages: messages,
  80. MaxTokens: MaxTokens,
  81. Temperature: Temperature,
  82. }
  83. // Make a direct (non-streaming) call to the API
  84. response, err := openaiClient.CreateChatCompletion(ctx, req)
  85. if err != nil {
  86. return
  87. }
  88. completedCode = response.Choices[0].Message.Content
  89. // extract the last word of the code
  90. lastWord := extractLastWord(c.Code)
  91. completedCode = cleanCompletionResponse(completedCode, lastWord)
  92. logger.Infof("Code completion response: %s", completedCode)
  93. return
  94. }
  // completionLastWordRe matches the run of ASCII letters, digits and
  // underscores at the very end of a string. Without the (?m) flag, Go's `$`
  // matches only at end of text. Code ending in whitespace, punctuation or a
  // non-ASCII character therefore has no match. The pattern is compiled once
  // instead of on every call.
  var completionLastWordRe = regexp.MustCompile(`[a-zA-Z0-9_]+$`)

  // extractLastWord returns the trailing identifier-like word of code (ASCII
  // letters, digits, underscores). It returns "" when code is empty or ends with
  // any other character. Send uses it to strip a completion prefix that repeats
  // what the user has already typed.
  func extractLastWord(code string) string {
	if code == "" {
		return ""
	}
	return completionLastWordRe.FindString(code)
  }
  // Patterns used by cleanCompletionResponse, compiled once at package
  // initialisation instead of on every call.
  var (
	// completionThinkTagRe matches a complete <think>...</think> block, across newlines.
	completionThinkTagRe = regexp.MustCompile(`<think>[\s\S]*?</think>`)
	// completionCodeBlockRe matches the first fenced Markdown code block with an
	// optional alphabetic language tag and captures the code between the fences.
	completionCodeBlockRe = regexp.MustCompile("```(?:[a-zA-Z]+)?\n((?:.|\n)*?)\n```")
  )

  // cleanCompletionResponse turns a raw model reply into the text to insert at
  // the cursor:
  //
  //  1. Every complete <think>...</think> block is removed. If an opening
  //     <think> is still left (the reply was cut off before its closing tag),
  //     everything from that tag onwards is dropped.
  //  2. If a fenced code block remains, only the first block's body is kept.
  //     The result is trimmed of surrounding whitespace, then of surrounding
  //     backticks.
  //  3. If the result starts with lastWord (the partial word before the
  //     cursor), that prefix is removed so the user's typing is not duplicated.
  func cleanCompletionResponse(response string, lastWord string) (cleanResp string) {
	cleanResp = completionThinkTagRe.ReplaceAllString(response, "")
	// With a small MaxTokens budget a reasoning model can be truncated inside
	// its <think> section; what follows the unmatched tag is reasoning, not code.
	if i := strings.Index(cleanResp, "<think>"); i >= 0 {
		cleanResp = cleanResp[:i]
	}

	if m := completionCodeBlockRe.FindStringSubmatch(cleanResp); len(m) > 1 {
		cleanResp = m[1]
	}
	cleanResp = strings.TrimSpace(cleanResp)

	// Remove stray inline-code backticks, e.g. a reply of `foo`.
	cleanResp = strings.Trim(cleanResp, "`")

	// TrimPrefix is a no-op for an empty lastWord or a non-matching prefix.
	cleanResp = strings.TrimPrefix(cleanResp, lastWord)
	return cleanResp
  }