Parcourir la source

perf(ChatGPT): advanced streaming and text transformation, support reasoner model

Jacky il y a 3 mois
Parent
commit
2af29eb80f

+ 56 - 16
api/openai/openai.go

@@ -2,15 +2,17 @@ package openai
 
 import (
 	"context"
+	"errors"
 	"fmt"
 	"github.com/0xJacky/Nginx-UI/internal/chatbot"
 	"github.com/0xJacky/Nginx-UI/settings"
 	"github.com/gin-gonic/gin"
-	"errors"
 	"github.com/sashabaranov/go-openai"
 	"github.com/uozi-tech/cosy"
 	"github.com/uozi-tech/cosy/logger"
 	"io"
+	"strings"
+	"time"
 )
 
 const ChatGPTInitPrompt = `You are a assistant who can help users write and optimise the configurations of Nginx,
@@ -83,31 +85,69 @@ func MakeChatCompletionRequest(c *gin.Context) {
 	msgChan := make(chan string)
 	go func() {
 		defer close(msgChan)
-		for {
-			response, err := stream.Recv()
-			if errors.Is(err, io.EOF) {
-				return
-			}
+		messageCh := make(chan string)
 
-			if err != nil {
-				logger.Errorf("Stream error: %v\n", err)
-				return
+		// 消息接收协程
+		go func() {
+			defer close(messageCh)
+			for {
+				response, err := stream.Recv()
+				if errors.Is(err, io.EOF) {
+					return
+				}
+				if err != nil {
+					messageCh <- fmt.Sprintf("error: %v", err)
+					logger.Errorf("Stream error: %v\n", err)
+					return
+				}
+				messageCh <- response.Choices[0].Delta.Content
 			}
+		}()
+
+		ticker := time.NewTicker(500 * time.Millisecond)
+		defer ticker.Stop()
 
-			message := fmt.Sprintf("%s", response.Choices[0].Delta.Content)
+		var buffer strings.Builder
 
-			msgChan <- message
+		for {
+			select {
+			case msg, ok := <-messageCh:
+				if !ok {
+					if buffer.Len() > 0 {
+						msgChan <- buffer.String()
+					}
+					return
+				}
+				if strings.HasPrefix(msg, "error: ") {
+					msgChan <- msg
+					return
+				}
+				buffer.WriteString(msg)
+			case <-ticker.C:
+				if buffer.Len() > 0 {
+					msgChan <- buffer.String()
+					buffer.Reset()
+				}
+			}
 		}
 	}()
 
 	c.Stream(func(w io.Writer) bool {
-		if m, ok := <-msgChan; ok {
+		m, ok := <-msgChan
+		if !ok {
+			return false
+		}
+		if strings.HasPrefix(m, "error: ") {
 			c.SSEvent("message", gin.H{
-				"type":    "message",
-				"content": m,
+				"type":    "error",
+				"content": strings.TrimPrefix(m, "error: "),
 			})
-			return true
+			return false
 		}
-		return false
+		c.SSEvent("message", gin.H{
+			"type":    "message",
+			"content": m,
+		})
+		return true
 	})
 }

+ 16 - 0
app/.idea/inspectionProfiles/Project_Default.xml

@@ -2,5 +2,21 @@
   <profile version="1.0">
     <option name="myName" value="Project Default" />
     <inspection_tool class="Eslint" enabled="true" level="ERROR" enabled_by_default="true" editorAttributes="ERRORS_ATTRIBUTES" />
+    <inspection_tool class="HtmlUnknownTag" enabled="true" level="WARNING" enabled_by_default="true">
+      <option name="myValues">
+        <value>
+          <list size="7">
+            <item index="0" class="java.lang.String" itemvalue="nobr" />
+            <item index="1" class="java.lang.String" itemvalue="noembed" />
+            <item index="2" class="java.lang.String" itemvalue="comment" />
+            <item index="3" class="java.lang.String" itemvalue="noscript" />
+            <item index="4" class="java.lang.String" itemvalue="embed" />
+            <item index="5" class="java.lang.String" itemvalue="script" />
+            <item index="6" class="java.lang.String" itemvalue="think" />
+          </list>
+        </value>
+      </option>
+      <option name="myCustomValuesEnabled" value="true" />
+    </inspection_tool>
   </profile>
 </component>

+ 188 - 76
app/src/components/ChatGPT/ChatGPT.vue

@@ -30,149 +30,241 @@ const messages = defineModel<ChatComplicationMessage[]>('historyMessages', {
 const loading = ref(false)
 const askBuffer = ref('')
 
+// Global buffer for accumulation
+let buffer = ''
+
+// Track last chunk to avoid immediate repeated content
+let lastChunkStr = ''
+
+// 定义一个用于跟踪代码块状态的类型
+interface CodeBlockState {
+  isInCodeBlock: boolean
+  backtickCount: number
+}
+
+const codeBlockState: CodeBlockState = reactive({
+  isInCodeBlock: false, // if in ``` code block
+  backtickCount: 0, // count of ```
+})
+
+/**
+ * transformReasonerThink: if <think> appears but is not paired with </think>, it will be automatically supplemented, and the entire text will be converted to a Markdown quote
+ */
+function transformReasonerThink(rawText: string): string {
+  // 1. Count number of <think> vs </think>
+  const openThinkRegex = /<think>/gi
+  const closeThinkRegex = /<\/think>/gi
+
+  const openCount = (rawText.match(openThinkRegex) || []).length
+  const closeCount = (rawText.match(closeThinkRegex) || []).length
+
+  // 2. If open tags exceed close tags, append missing </think> at the end
+  if (openCount > closeCount) {
+    const diff = openCount - closeCount
+    rawText += '</think>'.repeat(diff)
+  }
+
+  // 3. Replace <think>...</think> blocks with Markdown blockquote ("> ...")
+  return rawText.replace(/<think>([\s\S]*?)<\/think>/g, (match, p1) => {
+    // Split the inner text by line, prefix each with "> "
+    const lines = p1.trim().split('\n')
+    const blockquoted = lines.map(line => `> ${line}`).join('\n')
+    // Return the replaced Markdown quote
+    return `\n${blockquoted}\n`
+  })
+}
+
+/**
+ * transformText: transform the text
+ */
+function transformText(rawText: string): string {
+  return transformReasonerThink(rawText)
+}
+
+/**
+ * scrollToBottom: Scroll container to bottom
+ */
+function scrollToBottom() {
+  const container = document.querySelector('.right-settings .ant-card-body')
+  if (container)
+    container.scrollTop = container.scrollHeight
+}
+
+/**
+ * updateCodeBlockState: The number of unnecessary scans is reduced by changing the scanning method of incremental content
+ */
+function updateCodeBlockState(chunk: string) {
+  // count all ``` in chunk
+  // note to distinguish how many "backticks" are not paired
+
+  const regex = /```/g
+
+  while (regex.exec(chunk) !== null) {
+    codeBlockState.backtickCount++
+    // if backtickCount is even -> closed
+    codeBlockState.isInCodeBlock = codeBlockState.backtickCount % 2 !== 0
+  }
+}
+
+/**
+ * applyChunk: Process one SSE chunk and type out content character by character
+ * @param input   A chunk of data (Uint8Array) from SSE
+ * @param targetMsg  The assistant-type message object being updated
+ */
+
+async function applyChunk(input: Uint8Array, targetMsg: ChatComplicationMessage) {
+  const decoder = new TextDecoder('utf-8')
+  const raw = decoder.decode(input)
+  // SSE default split by segment
+  const lines = raw.split('\n\n')
+
+  for (const line of lines) {
+    if (!line.startsWith('event:message\ndata:'))
+      continue
+
+    const dataStr = line.slice('event:message\ndata:'.length)
+    if (!dataStr)
+      continue
+
+    const content = JSON.parse(dataStr).content as string
+    if (!content || content.trim() === '')
+      continue
+    if (content === lastChunkStr)
+      continue
+
+    lastChunkStr = content
+
+    // Only detect substrings
+    // 1. This can be processed in batches according to actual needs, reducing the number of character processing times
+    updateCodeBlockState(content)
+
+    for (const c of content) {
+      buffer += c
+      // codeBlockState.isInCodeBlock check if in code block
+      targetMsg.content = buffer
+      await nextTick()
+      await new Promise(resolve => setTimeout(resolve, 20))
+      scrollToBottom()
+    }
+  }
+}
+
+/**
+ * request: Send messages to server, receive SSE, and process by typing out chunk by chunk
+ */
 async function request() {
   loading.value = true
 
-  const t = ref({
+  // Add an "assistant" message object
+  const t = ref<ChatComplicationMessage>({
     role: 'assistant',
     content: '',
   })
 
-  const user = useUserStore()
+  messages.value.push(t.value)
 
-  const { token } = storeToRefs(user)
-
-  messages.value = [...messages.value!, t.value]
+  // Reset buffer flags each time
+  buffer = ''
+  lastChunkStr = ''
 
   await nextTick()
-
   scrollToBottom()
 
+  const user = useUserStore()
+  const { token } = storeToRefs(user)
+
   const res = await fetch(urlJoin(window.location.pathname, '/api/chatgpt'), {
     method: 'POST',
-    headers: { Accept: 'text/event-stream', Authorization: token.value },
-    body: JSON.stringify({ filepath: props.path, messages: messages.value?.slice(0, messages.value?.length - 1) }),
+    headers: {
+      Accept: 'text/event-stream',
+      Authorization: token.value,
+    },
+    body: JSON.stringify({
+      filepath: props.path,
+      messages: messages.value.slice(0, messages.value.length - 1),
+    }),
   })
 
-  const reader = res.body!.getReader()
-
-  let buffer = ''
+  if (!res.body) {
+    loading.value = false
+    return
+  }
 
-  let hasCodeBlockIndicator = false
+  const reader = res.body.getReader()
 
   while (true) {
     try {
       const { done, value } = await reader.read()
       if (done) {
+        // SSE stream ended
         setTimeout(() => {
           scrollToBottom()
-        }, 500)
-        loading.value = false
-        storeRecord()
+        }, 300)
         break
       }
-      apply(value!)
+      if (value) {
+        // Process each chunk
+        await applyChunk(value, t.value)
+      }
     }
     catch {
+      // In case of error
       break
     }
   }
 
-  function apply(input: Uint8Array) {
-    const decoder = new TextDecoder('utf-8')
-    const raw = decoder.decode(input)
-
-    // console.log(input, raw)
-
-    const line = raw.split('\n\n')
-
-    line?.forEach(v => {
-      const data = v.slice('event:message\ndata:'.length)
-      if (!data)
-        return
-
-      const content = JSON.parse(data).content
-
-      if (!hasCodeBlockIndicator)
-        hasCodeBlockIndicator = content.includes('`')
-
-      for (const c of content) {
-        buffer += c
-        if (hasCodeBlockIndicator) {
-          if (isCodeBlockComplete(buffer)) {
-            t.value.content = buffer
-            hasCodeBlockIndicator = false
-          }
-          else {
-            t.value.content = `${buffer}\n\`\`\``
-          }
-        }
-        else {
-          t.value.content = buffer
-        }
-      }
-
-      // keep container scroll to bottom
-      scrollToBottom()
-    })
-  }
-
-  function isCodeBlockComplete(text: string) {
-    const codeBlockRegex = /```/g
-    const matches = text.match(codeBlockRegex)
-    if (matches)
-      return matches.length % 2 === 0
-    else
-      return true
-  }
-
-  function scrollToBottom() {
-    const container = document.querySelector('.right-settings .ant-card-body')
-    if (container)
-      container.scrollTop = container.scrollHeight
-  }
+  loading.value = false
+  storeRecord()
 }
 
+/**
+ * send: Add user message into messages then call request
+ */
 async function send() {
   if (!messages.value)
     messages.value = []
 
   if (messages.value.length === 0) {
+    // The first message
     messages.value = [{
       role: 'user',
       content: `${props.content}\n\nCurrent Language Code: ${current.value}`,
     }]
   }
   else {
-    messages.value = [...messages.value, {
+    // Append user's new message
+    messages.value.push({
       role: 'user',
       content: askBuffer.value,
-    }]
+    })
     askBuffer.value = ''
   }
 
   await nextTick()
-
   await request()
 }
 
+// Markdown renderer
 const marked = new Marked(
   markedHighlight({
     langPrefix: 'hljs language-',
     highlight(code, lang) {
       const language = hljs.getLanguage(lang) ? lang : 'nginx'
-
       return hljs.highlight(code, { language }).value
     },
   }),
 )
 
+// Basic marked options
 marked.setOptions({
   pedantic: false,
   gfm: true,
   breaks: false,
 })
 
+/**
+ * storeRecord: Save chat history
+ */
 function storeRecord() {
   openai.store_record({
     file_name: props.path,
@@ -180,6 +272,9 @@ function storeRecord() {
   })
 }
 
+/**
+ * clearRecord: Clears all messages
+ */
 function clearRecord() {
   openai.store_record({
     file_name: props.path,
@@ -188,16 +283,23 @@ function clearRecord() {
   messages.value = []
 }
 
+// Manage editing
 const editingIdx = ref(-1)
 
+/**
+ * regenerate: Removes messages after index and re-request the answer
+ */
 async function regenerate(index: number) {
   editingIdx.value = -1
-  messages.value = messages.value?.slice(0, index)
+  messages.value = messages.value.slice(0, index)
   await nextTick()
   await request()
 }
 
-const show = computed(() => !messages.value || messages.value?.length === 0)
+/**
+ * show: If empty, display start button
+ */
+const show = computed(() => !messages.value || messages.value.length === 0)
 </script>
 
 <template>
@@ -216,6 +318,7 @@ const show = computed(() => !messages.value || messages.value?.length === 0)
       {{ $gettext('Ask ChatGPT for Help') }}
     </AButton>
   </div>
+
   <div
     v-else
     class="chatgpt-container"
@@ -231,7 +334,7 @@ const show = computed(() => !messages.value || messages.value?.length === 0)
             <template #content>
               <div
                 v-if="item.role === 'assistant' || editingIdx !== index"
-                v-dompurify-html="marked.parse(item.content)"
+                v-dompurify-html="marked.parse(transformText(item.content))"
                 class="content"
               />
               <AInput
@@ -263,6 +366,7 @@ const show = computed(() => !messages.value || messages.value?.length === 0)
         </AListItem>
       </template>
     </AList>
+
     <div class="input-msg">
       <div class="control-btn">
         <ASpace v-show="!loading">
@@ -314,6 +418,14 @@ const show = computed(() => !messages.value || messages.value?.length === 0)
       :deep(.hljs) {
         border-radius: 5px;
       }
+
+      :deep(blockquote) {
+        display: block;
+        opacity: 0.6;
+        margin: 0.5em 0;
+        padding-left: 1em;
+        border-left: 3px solid #ccc;
+      }
     }
 
     :deep(.ant-list-item) {

+ 9 - 1
app/src/main.ts

@@ -16,7 +16,15 @@ const app = createApp(App)
 pinia.use(piniaPluginPersistedstate)
 app.use(pinia)
 app.use(gettext)
-app.use(VueDOMPurifyHTML)
+app.use(VueDOMPurifyHTML, {
+  hooks: {
+    uponSanitizeElement: (node, data) => {
+      if (node.tagName && node.tagName.toLowerCase() === 'think') {
+        data.allowedTags.think = true
+      }
+    },
+  },
+})
 
 // after pinia created
 const settings = useSettingsStore()