Преглед изворни кода

added Stable Diffusion 3 support alongside ComfyUI configuration

this commit adds four environment variables:

- COMFYUI_CFG_SCALE
- COMFYUI_SAMPLER
- COMFYUI_SCHEDULER
- COMFYUI_SD3 (merely setting this at all will enable SD3 mode)
John Karabudak пре 10 месеци
родитељ
комит
ea074fa9bf
3 измењених фајлова са 60 додато и 1 уклоњено
  1. 20 0
      backend/apps/images/main.py
  2. 16 1
      backend/apps/images/utils/comfyui.py
  3. 24 0
      backend/config.py

+ 20 - 0
backend/apps/images/main.py

@@ -37,6 +37,10 @@ from config import (
     ENABLE_IMAGE_GENERATION,
     AUTOMATIC1111_BASE_URL,
     COMFYUI_BASE_URL,
+    COMFYUI_CFG_SCALE,
+    COMFYUI_SAMPLER,
+    COMFYUI_SCHEDULER,
+    COMFYUI_SD3,
     IMAGES_OPENAI_API_BASE_URL,
     IMAGES_OPENAI_API_KEY,
     IMAGE_GENERATION_MODEL,
@@ -78,6 +82,10 @@ app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
 
 app.state.config.IMAGE_SIZE = IMAGE_SIZE
 app.state.config.IMAGE_STEPS = IMAGE_STEPS
+app.state.config.COMFYUI_CFG_SCALE = COMFYUI_CFG_SCALE
+app.state.config.COMFYUI_SAMPLER = COMFYUI_SAMPLER
+app.state.config.COMFYUI_SCHEDULER = COMFYUI_SCHEDULER
+app.state.config.COMFYUI_SD3 = COMFYUI_SD3
 
 
 @app.get("/config")
@@ -457,6 +465,18 @@ def generate_image(
             if form_data.negative_prompt is not None:
                 data["negative_prompt"] = form_data.negative_prompt
 
+            if app.state.config.COMFYUI_CFG_SCALE:
+                data["cfg_scale"] = app.state.config.COMFYUI_CFG_SCALE
+
+            if app.state.config.COMFYUI_SAMPLER is not None:
+                data["sampler"] = app.state.config.COMFYUI_SAMPLER
+
+            if app.state.config.COMFYUI_SCHEDULER is not None:
+                data["scheduler"] = app.state.config.COMFYUI_SCHEDULER
+
+            if app.state.config.COMFYUI_SD3 is not None:
+                data["sd3"] = app.state.config.COMFYUI_SD3
+
             data = ImageGenerationPayload(**data)
 
             res = comfyui_generate_image(

+ 16 - 1
backend/apps/images/utils/comfyui.py

@@ -190,7 +190,10 @@ class ImageGenerationPayload(BaseModel):
     width: int
     height: int
     n: int = 1
-
+    cfg_scale: Optional[float] = None
+    sampler: Optional[str] = None
+    scheduler: Optional[str] = None
+    sd3: Optional[bool] = None
 
 def comfyui_generate_image(
     model: str, payload: ImageGenerationPayload, client_id, base_url
@@ -199,6 +202,18 @@ def comfyui_generate_image(
 
     comfyui_prompt = json.loads(COMFYUI_DEFAULT_PROMPT)
 
+    if payload.cfg_scale:
+        comfyui_prompt["3"]["inputs"]["cfg"] = payload.cfg_scale
+
+    if payload.sampler:
+        comfyui_prompt["3"]["inputs"]["sampler"] = payload.sampler
+
+    if payload.scheduler:
+        comfyui_prompt["3"]["inputs"]["scheduler"] = payload.scheduler
+
+    if payload.sd3:
+        comfyui_prompt["5"]["class_type"] = "EmptySD3LatentImage"
+
     comfyui_prompt["4"]["inputs"]["ckpt_name"] = model
     comfyui_prompt["5"]["inputs"]["batch_size"] = payload.n
     comfyui_prompt["5"]["inputs"]["width"] = payload.width

+ 24 - 0
backend/config.py

@@ -1000,6 +1000,30 @@ COMFYUI_BASE_URL = PersistentConfig(
     os.getenv("COMFYUI_BASE_URL", ""),
 )
 
+COMFYUI_CFG_SCALE = PersistentConfig(
+    "COMFYUI_CFG_SCALE",
+    "image_generation.comfyui.cfg_scale",
+    os.getenv("COMFYUI_CFG_SCALE", ""),
+)
+
+COMFYUI_SAMPLER = PersistentConfig(
+    "COMFYUI_SAMPLER",
+    "image_generation.comfyui.sampler",
+    os.getenv("COMFYUI_SAMPLER", ""),
+)
+
+COMFYUI_SCHEDULER = PersistentConfig(
+    "COMFYUI_SCHEDULER",
+    "image_generation.comfyui.scheduler",
+    os.getenv("COMFYUI_SCHEDULER", ""),
+)
+
+COMFYUI_SD3 = PersistentConfig(
+    "COMFYUI_SD3",
+    "image_generation.comfyui.sd3",
+    os.environ.get("COMFYUI_SD3", "").lower() == "true",
+)
+
 IMAGES_OPENAI_API_BASE_URL = PersistentConfig(
     "IMAGES_OPENAI_API_BASE_URL",
     "image_generation.openai.api_base_url",