|
@@ -239,6 +239,12 @@ async def get_embedding_config(request: Request, user=Depends(get_admin_user)):
|
|
|
"url": request.app.state.config.RAG_OLLAMA_BASE_URL,
|
|
|
"key": request.app.state.config.RAG_OLLAMA_API_KEY,
|
|
|
},
|
|
|
+ "azure_openai_config": {
|
|
|
+ "url": request.app.state.config.RAG_AZURE_OPENAI_BASE_URL,
|
|
|
+ "key": request.app.state.config.RAG_AZURE_OPENAI_API_KEY,
|
|
|
+ "deployment": request.app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT,
|
|
|
+ "version": request.app.state.config.RAG_AZURE_OPENAI_VERSION,
|
|
|
+ },
|
|
|
}
|
|
|
|
|
|
|
|
@@ -252,9 +258,17 @@ class OllamaConfigForm(BaseModel):
|
|
|
key: str
|
|
|
|
|
|
|
|
|
+class AzureOpenAIConfigForm(BaseModel):
|
|
|
+ url: str
|
|
|
+ key: str
|
|
|
+ deployment: str
|
|
|
+ version: str
|
|
|
+
|
|
|
+
|
|
|
class EmbeddingModelUpdateForm(BaseModel):
|
|
|
openai_config: Optional[OpenAIConfigForm] = None
|
|
|
ollama_config: Optional[OllamaConfigForm] = None
|
|
|
+ azure_openai_config: Optional[AzureOpenAIConfigForm] = None
|
|
|
embedding_engine: str
|
|
|
embedding_model: str
|
|
|
embedding_batch_size: Optional[int] = 1
|
|
@@ -271,7 +285,7 @@ async def update_embedding_config(
|
|
|
request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
|
|
|
request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
|
|
|
|
|
|
- if request.app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
|
|
|
+ if request.app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai", "azure_openai"]:
|
|
|
if form_data.openai_config is not None:
|
|
|
request.app.state.config.RAG_OPENAI_API_BASE_URL = (
|
|
|
form_data.openai_config.url
|
|
@@ -288,6 +302,20 @@ async def update_embedding_config(
|
|
|
form_data.ollama_config.key
|
|
|
)
|
|
|
|
|
|
+ if form_data.azure_openai_config is not None:
|
|
|
+ request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = (
|
|
|
+ form_data.azure_openai_config.url
|
|
|
+ )
|
|
|
+ request.app.state.config.RAG_AZURE_OPENAI_API_KEY = (
|
|
|
+ form_data.azure_openai_config.key
|
|
|
+ )
|
|
|
+ request.app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT = (
|
|
|
+ form_data.azure_openai_config.deployment
|
|
|
+ )
|
|
|
+ request.app.state.config.RAG_AZURE_OPENAI_VERSION = (
|
|
|
+ form_data.azure_openai_config.version
|
|
|
+ )
|
|
|
+
|
|
|
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = (
|
|
|
form_data.embedding_batch_size
|
|
|
)
|
|
@@ -304,14 +332,32 @@ async def update_embedding_config(
|
|
|
(
|
|
|
request.app.state.config.RAG_OPENAI_API_BASE_URL
|
|
|
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
|
|
- else request.app.state.config.RAG_OLLAMA_BASE_URL
|
|
|
+ else (
|
|
|
+ request.app.state.config.RAG_OLLAMA_BASE_URL
|
|
|
+ if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama"
|
|
|
+ else request.app.state.config.RAG_AZURE_OPENAI_BASE_URL
|
|
|
+ )
|
|
|
),
|
|
|
(
|
|
|
request.app.state.config.RAG_OPENAI_API_KEY
|
|
|
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
|
|
- else request.app.state.config.RAG_OLLAMA_API_KEY
|
|
|
+ else (
|
|
|
+ request.app.state.config.RAG_OLLAMA_API_KEY
|
|
|
+ if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama"
|
|
|
+ else request.app.state.config.RAG_AZURE_OPENAI_API_KEY
|
|
|
+ )
|
|
|
),
|
|
|
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
|
|
+ (
|
|
|
+ request.app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT
|
|
|
+ if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
|
|
|
+ else None
|
|
|
+ ),
|
|
|
+ (
|
|
|
+ request.app.state.config.RAG_AZURE_OPENAI_VERSION
|
|
|
+ if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
|
|
|
+ else None
|
|
|
+ ),
|
|
|
)
|
|
|
|
|
|
return {
|
|
@@ -327,6 +373,12 @@ async def update_embedding_config(
|
|
|
"url": request.app.state.config.RAG_OLLAMA_BASE_URL,
|
|
|
"key": request.app.state.config.RAG_OLLAMA_API_KEY,
|
|
|
},
|
|
|
+ "azure_openai_config": {
|
|
|
+ "url": request.app.state.config.RAG_AZURE_OPENAI_BASE_URL,
|
|
|
+ "key": request.app.state.config.RAG_AZURE_OPENAI_API_KEY,
|
|
|
+ "deployment": request.app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT,
|
|
|
+ "version": request.app.state.config.RAG_AZURE_OPENAI_VERSION,
|
|
|
+ },
|
|
|
}
|
|
|
except Exception as e:
|
|
|
log.exception(f"Problem updating embedding model: {e}")
|
|
@@ -1043,14 +1095,32 @@ def save_docs_to_vector_db(
|
|
|
(
|
|
|
request.app.state.config.RAG_OPENAI_API_BASE_URL
|
|
|
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
|
|
- else request.app.state.config.RAG_OLLAMA_BASE_URL
|
|
|
+ else (
|
|
|
+ request.app.state.config.RAG_OLLAMA_BASE_URL
|
|
|
+ if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama"
|
|
|
+ else request.app.state.config.RAG_AZURE_OPENAI_BASE_URL
|
|
|
+ )
|
|
|
),
|
|
|
(
|
|
|
request.app.state.config.RAG_OPENAI_API_KEY
|
|
|
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
|
|
- else request.app.state.config.RAG_OLLAMA_API_KEY
|
|
|
+ else (
|
|
|
+ request.app.state.config.RAG_OLLAMA_API_KEY
|
|
|
+ if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama"
|
|
|
+ else request.app.state.config.RAG_AZURE_OPENAI_API_KEY
|
|
|
+ )
|
|
|
),
|
|
|
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
|
|
+ (
|
|
|
+ request.app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT
|
|
|
+ if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
|
|
|
+ else None
|
|
|
+ ),
|
|
|
+ (
|
|
|
+ request.app.state.config.RAG_AZURE_OPENAI_VERSION
|
|
|
+ if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
|
|
|
+ else None
|
|
|
+ ),
|
|
|
)
|
|
|
|
|
|
embeddings = embedding_function(
|