|
@@ -80,16 +80,15 @@ app.state.RAG_TEMPLATE = RAG_TEMPLATE
|
|
|
|
|
|
|
|
|
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
|
|
|
-app.state.RAG_EMBEDDING_MODEL_PATH = get_embedding_model_path(
|
|
|
- app.state.RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE
|
|
|
-)
|
|
|
|
|
|
|
|
|
app.state.TOP_K = 4
|
|
|
|
|
|
app.state.sentence_transformer_ef = (
|
|
|
embedding_functions.SentenceTransformerEmbeddingFunction(
|
|
|
- model_name=app.state.RAG_EMBEDDING_MODEL_PATH,
|
|
|
+ model_name=get_embedding_model_path(
|
|
|
+ app.state.RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE
|
|
|
+ ),
|
|
|
device=DEVICE_TYPE,
|
|
|
)
|
|
|
)
|
|
@@ -130,7 +129,6 @@ async def get_embedding_model(user=Depends(get_admin_user)):
|
|
|
return {
|
|
|
"status": True,
|
|
|
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
|
|
|
- "embedding_model_path": app.state.RAG_EMBEDDING_MODEL_PATH,
|
|
|
}
|
|
|
|
|
|
|
|
@@ -143,43 +141,32 @@ async def update_embedding_model(
|
|
|
form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
|
|
|
):
|
|
|
|
|
|
- log.debug(f"form_data.embedding_model: {form_data.embedding_model}")
|
|
|
log.info(
|
|
|
f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
|
|
|
)
|
|
|
|
|
|
- embedding_model_path = None
|
|
|
- sentence_transformer_ef = None
|
|
|
try:
|
|
|
- embedding_model_path = get_embedding_model_path(form_data.embedding_model, True)
|
|
|
- if app.state.RAG_EMBEDDING_MODEL_PATH != embedding_model_path:
|
|
|
- sentence_transformer_ef = (
|
|
|
- embedding_functions.SentenceTransformerEmbeddingFunction(
|
|
|
- model_name=embedding_model_path,
|
|
|
- device=DEVICE_TYPE,
|
|
|
- )
|
|
|
+ sentence_transformer_ef = (
|
|
|
+ embedding_functions.SentenceTransformerEmbeddingFunction(
|
|
|
+ model_name=get_embedding_model_path(form_data.embedding_model, True),
|
|
|
+ device=DEVICE_TYPE,
|
|
|
)
|
|
|
- except Exception as e:
|
|
|
- log.exception(f"Problem updating embedding model: {e}")
|
|
|
- raise HTTPException(
|
|
|
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
|
- detail=ERROR_MESSAGES.DEFAULT(e),
|
|
|
)
|
|
|
|
|
|
- if sentence_transformer_ef:
|
|
|
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
|
|
|
- app.state.RAG_EMBEDDING_MODEL_PATH = embedding_model_path
|
|
|
app.state.sentence_transformer_ef = sentence_transformer_ef
|
|
|
|
|
|
- log.debug(
|
|
|
- f"app.state.RAG_EMBEDDING_MODEL_PATH: {app.state.RAG_EMBEDDING_MODEL_PATH}"
|
|
|
- )
|
|
|
+ return {
|
|
|
+ "status": True,
|
|
|
+ "embedding_model": app.state.RAG_EMBEDDING_MODEL,
|
|
|
+ }
|
|
|
|
|
|
- return {
|
|
|
- "status": sentence_transformer_ef != None,
|
|
|
- "embedding_model": app.state.RAG_EMBEDDING_MODEL,
|
|
|
- "embedding_model_path": app.state.RAG_EMBEDDING_MODEL_PATH,
|
|
|
- }
|
|
|
+ except Exception as e:
|
|
|
+ log.exception(f"Problem updating embedding model: {e}")
|
|
|
+ raise HTTPException(
|
|
|
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
|
+ detail=ERROR_MESSAGES.DEFAULT(e),
|
|
|
+ )
|
|
|
|
|
|
|
|
|
@app.get("/config")
|