1
0
Эх сурвалжийг харах

added support for the new Flux image gen model using ComfyUI

this commit adds three environment variables:

- COMFYUI_FLUX: determines whether Flux is used, the workflow is completely different so this is necessary.
- COMFYUI_FLUX_WEIGHT_DTYPE: sets the weight precision for Flux. you will probably want to set this to "fp8_e4m3fn" as the fp16 weights take up about 24GB of VRAM. optional, defaults to "default".
- COMFYUI_FLUX_FP8_CLIP: Flux requires two CLIP models downloaded, one of which is available in fp8 and fp16. set to true if you are using the fp8 CLIP weights.
John Karabudak 11 сар өмнө
parent
commit
ad6e8edcd3

+ 15 - 1
backend/apps/images/main.py

@@ -42,6 +42,9 @@ from config import (
     COMFYUI_SAMPLER,
     COMFYUI_SCHEDULER,
     COMFYUI_SD3,
+    COMFYUI_FLUX,
+    COMFYUI_FLUX_WEIGHT_DTYPE,
+    COMFYUI_FLUX_FP8_CLIP,
     IMAGES_OPENAI_API_BASE_URL,
     IMAGES_OPENAI_API_KEY,
     IMAGE_GENERATION_MODEL,
@@ -85,7 +88,9 @@ app.state.config.COMFYUI_CFG_SCALE = COMFYUI_CFG_SCALE
 app.state.config.COMFYUI_SAMPLER = COMFYUI_SAMPLER
 app.state.config.COMFYUI_SCHEDULER = COMFYUI_SCHEDULER
 app.state.config.COMFYUI_SD3 = COMFYUI_SD3
-
+app.state.config.COMFYUI_FLUX = COMFYUI_FLUX
+app.state.config.COMFYUI_FLUX_WEIGHT_DTYPE = COMFYUI_FLUX_WEIGHT_DTYPE
+app.state.config.COMFYUI_FLUX_FP8_CLIP = COMFYUI_FLUX_FP8_CLIP
 
 def get_automatic1111_api_auth():
     if app.state.config.AUTOMATIC1111_API_AUTH == None:
@@ -497,6 +502,15 @@ async def image_generations(
             if app.state.config.COMFYUI_SD3 is not None:
                 data["sd3"] = app.state.config.COMFYUI_SD3
 
+            if app.state.config.COMFYUI_FLUX is not None:
+                data["flux"] = app.state.config.COMFYUI_FLUX
+
+            if app.state.config.COMFYUI_FLUX_WEIGHT_DTYPE is not None:
+                data["flux_weight_dtype"] = app.state.config.COMFYUI_FLUX_WEIGHT_DTYPE
+
+            if app.state.config.COMFYUI_FLUX_FP8_CLIP is not None:
+                data["flux_fp8_clip"] = app.state.config.COMFYUI_FLUX_FP8_CLIP
+
             data = ImageGenerationPayload(**data)
 
             res = comfyui_generate_image(

+ 163 - 8
backend/apps/images/utils/comfyui.py

@@ -125,6 +125,135 @@ COMFYUI_DEFAULT_PROMPT = """
 }
 """
 
+FLUX_DEFAULT_PROMPT = """
+{
+    "5": {
+        "inputs": {
+            "width": 1024,
+            "height": 1024,
+            "batch_size": 1
+        },
+        "class_type": "EmptyLatentImage"
+    },
+    "6": {
+        "inputs": {
+            "text": "Input Text Here",
+            "clip": [
+                "11",
+                0
+            ]
+        },
+        "class_type": "CLIPTextEncode"
+    },
+    "8": {
+        "inputs": {
+            "samples": [
+                "13",
+                0
+            ],
+            "vae": [
+                "10",
+                0
+            ]
+        },
+        "class_type": "VAEDecode"
+    },
+    "9": {
+        "inputs": {
+            "filename_prefix": "ComfyUI",
+            "images": [
+                "8",
+                0
+            ]
+        },
+        "class_type": "SaveImage"
+    },
+    "10": {
+        "inputs": {
+            "vae_name": "ae.sft"
+        },
+        "class_type": "VAELoader"
+    },
+    "11": {
+        "inputs": {
+            "clip_name1": "clip_l.safetensors",
+            "clip_name2": "t5xxl_fp16.safetensors",
+            "type": "flux"
+        },
+        "class_type": "DualCLIPLoader"
+    },
+    "12": {
+        "inputs": {
+            "unet_name": "flux1-dev.sft",
+            "weight_dtype": "default"
+        },
+        "class_type": "UNETLoader"
+    },
+    "13": {
+        "inputs": {
+            "noise": [
+                "25",
+                0
+            ],
+            "guider": [
+                "22",
+                0
+            ],
+            "sampler": [
+                "16",
+                0
+            ],
+            "sigmas": [
+                "17",
+                0
+            ],
+            "latent_image": [
+                "5",
+                0
+            ]
+        },
+        "class_type": "SamplerCustomAdvanced"
+    },
+    "16": {
+        "inputs": {
+            "sampler_name": "euler"
+        },
+        "class_type": "KSamplerSelect"
+    },
+    "17": {
+        "inputs": {
+            "scheduler": "simple",
+            "steps": 20,
+            "denoise": 1,
+            "model": [
+                "12",
+                0
+            ]
+        },
+        "class_type": "BasicScheduler"
+    },
+    "22": {
+        "inputs": {
+            "model": [
+                "12",
+                0
+            ],
+            "conditioning": [
+                "6",
+                0
+            ]
+        },
+        "class_type": "BasicGuider"
+    },
+    "25": {
+        "inputs": {
+            "noise_seed": 778937779713005
+        },
+        "class_type": "RandomNoise"
+    }
+}
+"""
+
 
 def queue_prompt(prompt, client_id, base_url):
     log.info("queue_prompt")
@@ -194,6 +323,9 @@ class ImageGenerationPayload(BaseModel):
     sampler: Optional[str] = None
     scheduler: Optional[str] = None
     sd3: Optional[bool] = None
+    flux: Optional[bool] = None
+    flux_weight_dtype: Optional[str] = None
+    flux_fp8_clip: Optional[bool] = None
 
 
 def comfyui_generate_image(
@@ -215,21 +347,44 @@ def comfyui_generate_image(
     if payload.sd3:
         comfyui_prompt["5"]["class_type"] = "EmptySD3LatentImage"
 
+    if payload.steps:
+        comfyui_prompt["3"]["inputs"]["steps"] = payload.steps
+
     comfyui_prompt["4"]["inputs"]["ckpt_name"] = model
+    comfyui_prompt["7"]["inputs"]["text"] = payload.negative_prompt
+    comfyui_prompt["3"]["inputs"]["seed"] = (
+        payload.seed if payload.seed else random.randint(0, 18446744073709551614)
+    )
+
+    # as Flux uses a completely different workflow, we must treat it specially
+    if payload.flux:
+        comfyui_prompt = json.loads(FLUX_DEFAULT_PROMPT)
+        comfyui_prompt["12"]["inputs"]["unet_name"] = model
+        comfyui_prompt["25"]["inputs"]["noise_seed"] = (
+            payload.seed if payload.seed else random.randint(0, 18446744073709551614)
+        )
+
+        if payload.sampler:
+            comfyui_prompt["16"]["inputs"]["sampler_name"] = payload.sampler
+
+        if payload.steps:
+            comfyui_prompt["17"]["inputs"]["steps"] = payload.steps
+
+        if payload.scheduler:
+            comfyui_prompt["17"]["inputs"]["scheduler"] = payload.scheduler
+
+        if payload.flux_weight_dtype:
+            comfyui_prompt["12"]["inputs"]["weight_dtype"] = payload.flux_weight_dtype
+        
+        if payload.flux_fp8_clip:
+            comfyui_prompt["11"]["inputs"]["clip_name2"] = "t5xxl_fp8_e4m3fn.safetensors"
+
     comfyui_prompt["5"]["inputs"]["batch_size"] = payload.n
     comfyui_prompt["5"]["inputs"]["width"] = payload.width
     comfyui_prompt["5"]["inputs"]["height"] = payload.height
 
     # set the text prompt for our positive CLIPTextEncode
     comfyui_prompt["6"]["inputs"]["text"] = payload.prompt
-    comfyui_prompt["7"]["inputs"]["text"] = payload.negative_prompt
-
-    if payload.steps:
-        comfyui_prompt["3"]["inputs"]["steps"] = payload.steps
-
-    comfyui_prompt["3"]["inputs"]["seed"] = (
-        payload.seed if payload.seed else random.randint(0, 18446744073709551614)
-    )
 
     try:
         ws = websocket.WebSocket()

+ 18 - 0
backend/config.py

@@ -1302,6 +1302,24 @@ COMFYUI_SD3 = PersistentConfig(
     os.environ.get("COMFYUI_SD3", "").lower() == "true",
 )
 
+COMFYUI_FLUX = PersistentConfig(
+    "COMFYUI_FLUX",
+    "image_generation.comfyui.flux",
+    os.environ.get("COMFYUI_FLUX", "").lower() == "true",
+)
+
+COMFYUI_FLUX_WEIGHT_DTYPE = PersistentConfig(
+    "COMFYUI_FLUX_WEIGHT_DTYPE",
+    "image_generation.comfyui.flux_weight_dtype",
+    os.getenv("COMFYUI_FLUX_WEIGHT_DTYPE", ""),
+)
+
+COMFYUI_FLUX_FP8_CLIP = PersistentConfig(
+    "COMFYUI_FLUX_FP8_CLIP",
+    "image_generation.comfyui.flux_fp8_clip",
+    os.getenv("COMFYUI_FLUX_FP8_CLIP", ""),
+)
+
 IMAGES_OPENAI_API_BASE_URL = PersistentConfig(
     "IMAGES_OPENAI_API_BASE_URL",
     "image_generation.openai.api_base_url",