瀏覽代碼

add chatgpt-api-response-timeout-secs flag, set this to 20 mins in test

Alex Cheema 9 月之前
父節點
當前提交
e49924e1b9
共有 3 個文件被更改,包括 28 次插入和 10 次删除
  1. 24 7
      .github/workflows/test.yml
  2. 2 2
      exo/api/chatgpt_api.py
  3. 2 1
      main.py

+ 24 - 7
.github/workflows/test.yml

@@ -78,11 +78,11 @@ jobs:
    - name: Run chatgpt api integration test
      run: |
        # Start first instance
-        DEBUG_DISCOVERY=9 DEBUG=9 python3 main.py --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 > output1.log 2>&1 &
+        DEBUG_DISCOVERY=9 DEBUG=9 python3 main.py --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 --chatgpt-api-response-timeout-secs 1200 > output1.log 2>&1 &
        PID1=$!

        # Start second instance
-        DEBUG_DISCOVERY=9 DEBUG=9 python3 main.py --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 > output2.log 2>&1 &
+        DEBUG_DISCOVERY=9 DEBUG=9 python3 main.py --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 --chatgpt-api-response-timeout-secs 1200 > output2.log 2>&1 &
        PID2=$!

        # Wait for discovery
@@ -96,22 +96,39 @@ jobs:
              "messages": [{"role": "user", "content": "Placeholder to load model..."}],
              "temperature": 0.7
            }'
+        curl -s http://localhost:8001/v1/chat/completions \
+            -H "Content-Type: application/json" \
+            -d '{
+              "model": "llama-3-8b",
+              "messages": [{"role": "user", "content": "Placeholder to load model..."}],
+              "temperature": 0.7
+            }'
+
+        response_1=$(curl -s http://localhost:8000/v1/chat/completions \
+          -H "Content-Type: application/json" \
+          -d '{
+            "model": "llama-3-8b",
+            "messages": [{"role": "user", "content": "Who was the king of pop?"}],
+            "temperature": 0.7
+          }')
+        echo "Response 1: $response_1"

-        response=$(curl -s http://localhost:8000/v1/chat/completions \
+        response_2=$(curl -s http://localhost:8000/v1/chat/completions \
          -H "Content-Type: application/json" \
          -d '{
            "model": "llama-3-8b",
            "messages": [{"role": "user", "content": "Who was the king of pop?"}],
            "temperature": 0.7
          }')
-        echo "Response: $response"
+        echo "Response 2: $response_2"

-        if ! echo "$response" | grep -q "Michael Jackson"; then
+        if ! echo "$response_1" | grep -q "Michael Jackson" || ! echo "$response_2" | grep -q "Michael Jackson"; then
          echo "Test failed: Response does not contain 'Michael Jackson'"
-          echo "Response: $response"
+          echo "Response 1: $response_1"
+          echo "Response 2: $response_2"
          exit 1
        else
-          echo "Test passed: Response contains 'Michael Jackson'"
+          echo "Test passed: Response from both nodes contains 'Michael Jackson'"
        fi

        # Stop both instances

+ 2 - 2
exo/api/chatgpt_api.py

@@ -117,10 +117,10 @@ def build_prompt(tokenizer, messages: List[Message]):


 class ChatGPTAPI:
-    def __init__(self, node: Node, inference_engine_classname: str):
+    def __init__(self, node: Node, inference_engine_classname: str, response_timeout_secs: int = 90):
         self.node = node
         self.inference_engine_classname = inference_engine_classname
-        self.response_timeout_secs = 90
+        self.response_timeout_secs = response_timeout_secs
         self.app = web.Application()
         self.prev_token_lens: Dict[str, int] = {}
         self.stream_tasks: Dict[str, asyncio.Task] = {}

+ 2 - 1
main.py

@@ -22,6 +22,7 @@ parser.add_argument("--listen-port", type=int, default=5678, help="Listening por
 parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
 parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
 parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
+parser.add_argument("--chatgpt-api-response-timeout-secs", type=int, default=90, help="ChatGPT API response timeout in seconds")
 parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use")
 args = parser.parse_args()

@@ -57,7 +58,7 @@ discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.b
 node = StandardNode(args.node_id, None, inference_engine, discovery, partitioning_strategy=RingMemoryWeightedPartitioningStrategy(), chatgpt_api_endpoint=f"http://localhost:{args.chatgpt_api_port}/v1/chat/completions", web_chat_url=f"http://localhost:{args.chatgpt_api_port}")
 server = GRPCServer(node, args.node_host, args.node_port)
 node.server = server
-api = ChatGPTAPI(node, inference_engine.__class__.__name__)
+api = ChatGPTAPI(node, inference_engine.__class__.__name__, response_timeout_secs=args.chatgpt_api_response_timeout_secs)

 node.on_token.register("main_log").on_next(lambda _, tokens , __: print(inference_engine.tokenizer.decode(tokens) if hasattr(inference_engine, "tokenizer") else tokens))
 node.on_token.register("main_log").on_next(lambda _, tokens , __: print(inference_engine.tokenizer.decode(tokens) if hasattr(inference_engine, "tokenizer") else tokens))