瀏覽代碼

fix streaming, change default model to llava

Alex Cheema 1 年之前
父節點
當前提交
8503543894
共有 1 個文件被更改,包括 94 次插入30 次删除
  1. 94 30
      examples/astra/astra/ContentView.swift

+ 94 - 30
examples/astra/astra/ContentView.swift

@@ -2,6 +2,7 @@ import SwiftUI
 import WhisperKit
 import AVFoundation
 import Foundation
+import Combine
 
 struct ContentView: View {
     @State private var whisperKit: WhisperKit?
@@ -18,13 +19,15 @@ struct ContentView: View {
     @State private var currentMemo = ""
     @State private var lastVoiceActivityTime = Date()
     @State private var silenceTimer: Timer?
-    @State private var voiceActivityThreshold: Float = 0.1 // Lower this value
+    @State private var voiceActivityThreshold: Float = 0.33
     @State private var silenceTimeThreshold = 1.0
     @State private var debugText = ""
     @State private var apiEndpoint = "http://192.168.212.74:8000/v1/chat/completions"
     @State private var audioBuffer: [Float] = []
     @State private var bufferDuration: Double = 0.5 // 0.5 seconds buffer
     @State private var isInitialTranscription = true
+    @State private var streamingResponse = ""
+    @State private var cancellables = Set<AnyCancellable>()
 
     var body: some View {
         VStack {
@@ -62,6 +65,17 @@ struct ContentView: View {
             Slider(value: $voiceActivityThreshold, in: 0.01...1.0) {
                 Text("Voice Activity Threshold: \(voiceActivityThreshold, specifier: "%.2f")")
             }
+
+            Text("API Response:")
+                .font(.headline)
+                .padding(.top)
+
+            ScrollView {
+                Text(streamingResponse)
+                    .padding()
+            }
+            .frame(height: 200)
+            .border(Color.gray, width: 1)
         }
         .onAppear {
             setupWhisperKit()
@@ -250,42 +264,92 @@ struct ContentView: View {
     }
 
     private func sendMemoToAPI(_ memo: String) {
-        guard let url = URL(string: apiEndpoint) else {
-            print("Invalid API endpoint URL")
-            return
-        }
+        Task {
+            do {
+                print("Starting API request for memo: \(memo.prefix(50))...")
+
+                guard let url = URL(string: apiEndpoint) else {
+                    print("Invalid API endpoint URL: \(apiEndpoint)")
+                    return
+                }
 
-        let payload: [String: Any] = [
-            "model": "llama-3.1-8b",
-            "messages": [["role": "user", "content": memo]],
-            "temperature": 0.7,
-            "stream": true
-        ]
+                let payload: [String: Any] = [
+                    "model": "llava-1.5-7b-hf",
+                    "messages": [
+                        ["role": "system", "content": ["type": "text", "text": "You are a helpful chat assistant being used with Whisper voice transcription. Please assist the user with their queries."]],
+                        ["role": "user", "content": ["type": "text", "text": memo]]
+                    ],
+                    "temperature": 0.7,
+                    "stream": true
+                ]
+                // let payload: [String: Any] = [
+                //     "model": "llama-3.1-8b",
+                //     "messages": [["role": "system", "content": "You are a helpful chat assistant being used with Whisper voice transcription. Please assist the user with their queries."], ["role": "user", "content": memo]],
+                //     "temperature": 0.7,
+                //     "stream": true
+                // ]
+
+                guard let jsonData = try? JSONSerialization.data(withJSONObject: payload) else {
+                    print("Failed to serialize JSON payload")
+                    return
+                }
 
-        guard let jsonData = try? JSONSerialization.data(withJSONObject: payload) else {
-            print("Failed to serialize JSON payload")
-            return
-        }
+                var request = URLRequest(url: url)
+                request.httpMethod = "POST"
+                request.setValue("application/json", forHTTPHeaderField: "Content-Type")
+                request.httpBody = jsonData
 
-        var request = URLRequest(url: url)
-        request.httpMethod = "POST"
-        request.setValue("application/json", forHTTPHeaderField: "Content-Type")
-        request.httpBody = jsonData
+                print("Sending request to \(url.absoluteString)")
 
-        URLSession.shared.dataTask(with: request) { data, response, error in
-            if let error = error {
-                print("Error sending memo to API: \(error)")
-                return
-            }
+                // Reset the streaming response
+                await MainActor.run {
+                    self.streamingResponse = ""
+                }
 
-            if let httpResponse = response as? HTTPURLResponse {
-                print("API response status code: \(httpResponse.statusCode)")
+                let (bytes, response) = try await URLSession.shared.bytes(for: request)
+
+                guard let httpResponse = response as? HTTPURLResponse else {
+                    print("Invalid response")
+                    return
+                }
+
+                print("Response status code: \(httpResponse.statusCode)")
+
+                for try await line in bytes.lines {
+                    print("Received line: \(line)")
+                    await processStreamLine(line)
+                }
+
+                print("Stream completed")
+            } catch {
+                print("Error: \(error.localizedDescription)")
             }
+        }
+    }
 
-            if let data = data, let responseString = String(data: data, encoding: .utf8) {
-                print("API response: \(responseString)")
+    private func processStreamLine(_ line: String) async {
+        let jsonString: String
+        if line.hasPrefix("data: ") {
+            jsonString = String(line.dropFirst(6))
+        } else {
+            jsonString = line
+        }
+
+        if jsonString.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty {
+            return
+        }
+
+        if let jsonData = jsonString.data(using: .utf8),
+           let json = try? JSONSerialization.jsonObject(with: jsonData, options: []) as? [String: Any],
+           let choices = json["choices"] as? [[String: Any]],
+           let firstChoice = choices.first,
+           let delta = firstChoice["delta"] as? [String: String],
+           let content = delta["content"] {
+            print("Extracted content: \(content)")
+            await MainActor.run {
+                self.streamingResponse += content
             }
-        }.resume()
+        }
     }
 
     private func loadModel(_ model: String) async throws -> Bool {
@@ -308,4 +372,4 @@ struct ContentView: View {
             return false
         }
     }
-}
+}