main.py

from fastapi import FastAPI, Request, Depends, status, Response
from fastapi.routing import APIRoute
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse

from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.responses import StreamingResponse

import asyncio
import json
import logging
import subprocess

from utils.utils import get_http_authorization_cred, get_current_user
from config import SRC_LOG_LEVELS, ENV
from config import (
    MODEL_FILTER_ENABLED,
    MODEL_FILTER_LIST,
)

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["LITELLM"])


app = FastAPI()

origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

async def run_background_process(command):
    process = await asyncio.create_subprocess_exec(
        *command.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE
    )
    return process


async def start_litellm_background():
    # Command to run in the background
    command = "litellm --config ./data/litellm/config.yaml"
    await run_background_process(command)


@app.on_event("startup")
async def startup_event():
    asyncio.create_task(start_litellm_background())


app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
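
# Note: litellm is launched above as a background subprocess when this app starts,
# and the model-filter settings are cached on app.state so the response middleware
# defined below can read them on every request.
#
# Sketch (not part of the original file): if the proxy process ever needs to be
# stopped cleanly, the handle returned by run_background_process() could be kept
# and terminated on shutdown, e.g.:
#
#     @app.on_event("shutdown")
#     async def shutdown_event():
#         process = getattr(app.state, "LITELLM_PROCESS", None)  # hypothetical attribute
#         if process:
#             process.terminate()
#
# (this assumes the startup task stores the process handle on app.state.LITELLM_PROCESS).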

@app.middleware("http")
async def auth_middleware(request: Request, call_next):
    auth_header = request.headers.get("Authorization", "")
    request.state.user = None

    try:
        user = get_current_user(get_http_authorization_cred(auth_header))
        log.debug(f"user: {user}")
        request.state.user = user
    except Exception as e:
        return JSONResponse(status_code=400, content={"detail": str(e)})

    response = await call_next(request)
    return response
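
# Example request (illustrative; host and port are placeholders, not taken from
# this file):
#
#     curl -H "Authorization: Bearer <token>" http://localhost:8000/models
#
# The middleware above resolves the bearer token to a user and stores it on
# request.state.user; requests whose token cannot be validated are answered with
# a 400 JSON error before they reach any route.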

@app.get("/")
async def get_status():
    return {"status": True}

class ModifyModelsResponseMiddleware(BaseHTTPMiddleware):
    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        response = await call_next(request)
        user = request.state.user

        if "/models" in request.url.path:
            if isinstance(response, StreamingResponse):
                # Read the content of the streaming response
                body = b""
                async for chunk in response.body_iterator:
                    body += chunk

                data = json.loads(body.decode("utf-8"))

                if app.state.MODEL_FILTER_ENABLED:
                    if user and user.role == "user":
                        data["data"] = list(
                            filter(
                                lambda model: model["id"]
                                in app.state.MODEL_FILTER_LIST,
                                data["data"],
                            )
                        )

                # Modified Flag
                data["modified"] = True
                return JSONResponse(content=data)

        return response


app.add_middleware(ModifyModelsResponseMiddleware)
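
# Example of the effect (illustrative; assumes an OpenAI-style model listing from
# the upstream /models handler): with MODEL_FILTER_ENABLED set and a "user"-role
# requester, a streamed payload such as
#
#     {"data": [{"id": "gpt-3.5-turbo"}, {"id": "gpt-4"}]}
#
# is rewritten so that "data" keeps only the ids present in MODEL_FILTER_LIST,
# and a "modified": true flag is added to the re-serialized JSON body.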

# from litellm.proxy.proxy_server import ProxyConfig, initialize
# from litellm.proxy.proxy_server import app

# proxy_config = ProxyConfig()


# async def config():
#     router, model_list, general_settings = await proxy_config.load_config(
#         router=None, config_file_path="./data/litellm/config.yaml"
#     )
#     await initialize(config="./data/litellm/config.yaml", telemetry=False)


# async def startup():
#     await config()


# @app.on_event("startup")
# async def on_startup():
#     await startup()
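
# Running this app directly (a sketch, not part of the original file; it assumes
# uvicorn is installed, this file is importable as "main", and the sibling
# utils/config modules are on the import path; the port is an arbitrary example):
#
#     uvicorn main:app --host 0.0.0.0 --port 8000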