-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy path__init__.py
More file actions
333 lines (296 loc) · 14.4 KB
/
__init__.py
File metadata and controls
333 lines (296 loc) · 14.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
# ComfyUI custom node for FlowMatch Euler Scheduler
#
# This node creates a FlowMatchEulerDiscreteScheduler with configurable parameters
# so it can be used with compatible sampler nodes.
#
# Also registers the scheduler in ComfyUI's scheduler list with default config.
#
# Place this file into: ComfyUI/custom_nodes/
# Then restart ComfyUI. It will show up as "FlowMatch Euler Discrete Scheduler (Custom)"
import math
import torch
import numpy as np # <-- Required for robust slicing of PyTorch tensors

# Hard dependency: the whole point of this extension is exposing this scheduler,
# so a missing diffusers install fails loudly with install instructions.
try:
    from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
except ImportError as e:
    print("=" * 80)
    print("ERROR: Failed to import FlowMatchEulerDiscreteScheduler from diffusers")
    print("Please ensure dependencies are installed by running:")
    print(" pip install -r requirements.txt")
    print("=" * 80)
    raise ImportError(
        "FlowMatchEulerDiscreteScheduler not found. "
        "Please install dependencies from requirements.txt"
    ) from e

# Optional dependency: only used by the experimental VQ-Diffusion handler/node below.
try:
    from diffusers import VQDiffusionScheduler
except ImportError:
    VQDiffusionScheduler = None  # sentinel; checked before every use
    print("[FlowMatch Scheduler] Warning: VQDiffusionScheduler not found in diffusers.")

from comfy.samplers import SchedulerHandler, SCHEDULER_HANDLERS, SCHEDULER_NAMES

# Import Nunchaku compatibility patches (auto-applies on import)
try:
    from . import nunchaku_compat
except Exception as e:
    print(f"[FlowMatch Scheduler] Warning: Could not load Nunchaku compatibility: {e}")
# Default config used when registering the scheduler in ComfyUI's global list.
# Matches FlowMatchEulerDiscreteScheduler's stock settings, with dynamic
# shifting enabled and base/max shift both set to ln(3).
default_config = dict(
    base_image_seq_len=256,
    base_shift=math.log(3),
    invert_sigmas=False,
    max_image_seq_len=8192,
    max_shift=math.log(3),
    num_train_timesteps=1000,
    shift=1.0,
    shift_terminal=None,
    stochastic_sampling=False,
    time_shift_type="exponential",
    use_beta_sigmas=False,
    use_dynamic_shifting=True,
    use_exponential_sigmas=False,
    use_karras_sigmas=False,
)
def flow_match_euler_scheduler_handler(model_sampling, steps):
    """ComfyUI scheduler handler: produce sigmas from the default FlowMatch config.

    Falls back to CPU when ``model_sampling`` exposes no ``device`` attribute.
    NOTE(review): ``mu`` is hard-coded to 0.0; with the exponential time-shift
    this appears to neutralize dynamic shifting — confirm against diffusers.
    """
    target = getattr(model_sampling, "device", "cpu")
    scheduler = FlowMatchEulerDiscreteScheduler.from_config(default_config)
    scheduler.set_timesteps(steps, device=target, mu=0.0)
    return scheduler.sigmas
def vq_diffusion_scheduler_handler(model_sampling, steps):
    """Experimental handler for VQDiffusionScheduler (discrete latent models).

    VQ-Diffusion has no continuous sigma schedule, so a dummy linear ramp is
    returned purely to keep ComfyUI's sampler plumbing from crashing.
    """
    if VQDiffusionScheduler is None:
        raise ImportError("VQDiffusionScheduler is not available.")
    # VQDiffusionScheduler requires num_vec_classes.
    print("[FlowMatch Scheduler] WARNING: VQDiffusionScheduler is for discrete models (VQ-Diffusion).")
    print("It does not produce 'sigmas' for continuous diffusion.")
    print("Returning dummy linear sigmas to prevent crash, but sampling will likely fail with standard models.")
    # Dummy initialization
    # scheduler = VQDiffusionScheduler(num_vec_classes=4096, num_train_timesteps=1000)
    # Placeholder: steps+1 evenly spaced values from 1.0 down to 0.0.
    dummy = torch.linspace(1.0, 0.0, steps + 1)
    target = getattr(model_sampling, "device", None)
    return dummy if target is None else dummy.to(target)
# Register the schedulers in ComfyUI
# Guarded membership check keeps a module reload from registering duplicates.
if "FlowMatchEulerDiscreteScheduler" not in SCHEDULER_HANDLERS:
    handler = SchedulerHandler(handler=flow_match_euler_scheduler_handler, use_ms=True)
    SCHEDULER_HANDLERS["FlowMatchEulerDiscreteScheduler"] = handler
    SCHEDULER_NAMES.append("FlowMatchEulerDiscreteScheduler")
# Explicitly add to KSampler.SCHEDULERS to ensure compatibility with nodes
# that might replace the list object (like RES4LYF)
try:
    from comfy.samplers import KSampler
    if "FlowMatchEulerDiscreteScheduler" not in KSampler.SCHEDULERS:
        KSampler.SCHEDULERS.append("FlowMatchEulerDiscreteScheduler")
except ImportError:
    # Older/alternative ComfyUI builds may not expose KSampler here; non-fatal.
    pass
# if "VQDiffusionScheduler" not in SCHEDULER_HANDLERS:
# SCHEDULER_HANDLERS["VQDiffusionScheduler"] = SchedulerHandler(handler=vq_diffusion_scheduler_handler, use_ms=True)
# SCHEDULER_NAMES.append("VQDiffusionScheduler")
# try:
# from comfy.samplers import KSampler
# if "VQDiffusionScheduler" not in KSampler.SCHEDULERS:
# KSampler.SCHEDULERS.append("VQDiffusionScheduler")
# except ImportError:
# pass
class FlowMatchEulerSchedulerNode:
    """ComfyUI node wrapping FlowMatchEulerDiscreteScheduler with every config
    parameter exposed as a widget, plus start/end trimming of the resulting
    sigma schedule. Outputs a SIGMAS tensor for custom-sigma sampler nodes.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Declare all node inputs (all required; booleans as enable/disable combos)."""
        return {
            "required": {
                "steps": ("INT", {
                    "default": 9,
                    "min": 1,
                    "max": 10000,
                    "tooltip": "Total number of diffusion steps to generate the full sigma schedule."
                }),
                "start_at_step": ("INT", {
                    "default": 0,
                    "min": 0,
                    "max": 10000,
                    "tooltip": "The starting step (index) of the sigma schedule to use. Set to 0 to start at the beginning (first step)."
                }),
                "end_at_step": ("INT", {
                    "default": 9999,
                    "min": 0,
                    "max": 10000,
                    "tooltip": "The ending step (index) of the sigma schedule to use. Set higher than 'steps' to use all steps."
                }),
                "base_image_seq_len": ("INT", {
                    "default": 256,
                    "tooltip": "Base sequence length for dynamic shifting. Should match model's training resolution (e.g., 256 for 512x512 images)."
                }),
                "base_shift": ("FLOAT", {
                    "default": 0.5,
                    "step": 0.01,
                    "tooltip": "Stabilizes generation. Higher values = more consistent/predictable outputs. Z-Image-Turbo uses default 0.5."
                }),
                "invert_sigmas": (["disable", "enable"], {
                    "default": "disable",
                    "tooltip": "Reverses the sigma schedule. Keep disabled unless experimenting with advanced techniques."
                }),
                "max_image_seq_len": ("INT", {
                    "default": 8192,
                    "tooltip": "Maximum sequence length for dynamic shifting. Affects how the scheduler adapts to large images."
                }),
                "max_shift": ("FLOAT", {
                    "default": 1.15,
                    "step": 0.01,
                    "tooltip": "Maximum variation allowed. Higher = more exaggerated/stylized results. Z-Image-Turbo uses default 1.15."
                }),
                "num_train_timesteps": ("INT", {
                    "default": 1000,
                    "tooltip": "Timesteps the model was trained with. Should match your model's config (typically 1000)."
                }),
                "shift": ("FLOAT", {
                    "default": 3.0,
                    "step": 0.01,
                    "tooltip": "Global timestep schedule shift. Z-Image-Turbo uses 3.0 for optimal performance with the Turbo model."
                }),
                "shift_terminal": ("FLOAT", {
                    "default": 0.0,
                    "step": 0.01,
                    "tooltip": "End value for shifted schedule. Set to 0.0 to disable. Advanced parameter for timestep schedule control."
                }),
                "stochastic_sampling": (["disable", "enable"], {
                    "default": "disable",
                    "tooltip": "Adds controlled randomness to each step. Enable for more varied outputs (similar to ancestral samplers)."
                }),
                "time_shift_type": (["exponential", "linear"], {
                    "default": "exponential",
                    "tooltip": "Method for resolution-dependent shifting. Use 'exponential' for most cases, 'linear' for experiments."
                }),
                "use_beta_sigmas": (["disable", "enable"], {
                    "default": "disable",
                    "tooltip": "Uses beta distribution for sigmas. Experimental alternative noise schedule."
                }),
                "use_dynamic_shifting": (["disable", "enable"], {
                    "default": "disable",
                    "tooltip": "Auto-adjusts timesteps based on image resolution. Z-Image-Turbo disables this for consistent Turbo performance."
                }),
                "use_exponential_sigmas": (["disable", "enable"], {
                    "default": "disable",
                    "tooltip": "Uses exponential sigma spacing. Try enabling for different noise distribution characteristics."
                }),
                "use_karras_sigmas": (["disable", "enable"], {
                    "default": "disable",
                    "tooltip": "Uses Karras noise schedule for smoother results. Similar to DPM++ samplers, often improves quality."
                }),
                "device": (["auto", "cuda", "cpu"], {
                    "default": "auto",
                    "tooltip": "Device for sigma computation. 'auto' detects GPU if available, otherwise CPU. Using GPU avoids CPU->GPU transfers."
                }),
            }
        }

    RETURN_TYPES = ("SIGMAS",)
    RETURN_NAMES = ("sigmas",)
    FUNCTION = "create"
    CATEGORY = "sampling/schedulers"
    DESCRIPTION = "FlowMatch Euler Discrete Scheduler with full parameter control and ability to trim the schedule (start_at_step/end_at_step)."

    def create(
        self,
        steps,
        start_at_step,
        end_at_step,
        base_image_seq_len,
        base_shift,
        invert_sigmas,
        max_image_seq_len,
        max_shift,
        num_train_timesteps,
        shift,
        shift_terminal,
        stochastic_sampling,
        time_shift_type,
        use_beta_sigmas,
        use_dynamic_shifting,
        use_exponential_sigmas,
        use_karras_sigmas,
        device="auto",
    ):
        """Build the scheduler from widget values and return a (possibly
        trimmed) 1-tuple containing the sigma tensor.
        """
        # Convert string combo values to boolean
        config = {
            "base_image_seq_len": base_image_seq_len,
            "base_shift": base_shift,
            "invert_sigmas": invert_sigmas == "enable",
            "max_image_seq_len": max_image_seq_len,
            "max_shift": max_shift,
            "num_train_timesteps": num_train_timesteps,
            "shift": shift,
            # A widget value of exactly 0.0 means "disabled" -> None for diffusers.
            "shift_terminal": shift_terminal if shift_terminal != 0.0 else None,
            "stochastic_sampling": stochastic_sampling == "enable",
            "time_shift_type": time_shift_type,
            "use_beta_sigmas": use_beta_sigmas == "enable",
            "use_dynamic_shifting": use_dynamic_shifting == "enable",
            "use_exponential_sigmas": use_exponential_sigmas == "enable",
            "use_karras_sigmas": use_karras_sigmas == "enable",
        }
        scheduler = FlowMatchEulerDiscreteScheduler.from_config(config)
        # 1. Generate the full sigma schedule
        # Determine device to use for sigma computation
        if device == "auto":
            # Auto-detect: use CUDA if available, otherwise CPU
            target_device = "cuda" if torch.cuda.is_available() else "cpu"
            print(f"[FlowMatch Scheduler] Auto-detected device: {target_device.upper()}")
        else:
            target_device = device
            print(f"[FlowMatch Scheduler] Using manually specified device: {target_device.upper()}")
        # Set timesteps and get sigmas for the specified number of steps
        # Using the model's device avoids unnecessary CPU->GPU transfers during sampling
        # NOTE(review): mu is hard-coded to 0.0 even when use_dynamic_shifting is
        # enabled — with the exponential shift this looks like a no-op; confirm.
        scheduler.set_timesteps(steps, device=target_device, mu=0.0)
        sigmas = scheduler.sigmas
        # 2. Apply start_at_step and end_at_step (Slicing the sigmas tensor)
        # Determine the exclusive end index for the slice
        # end_at_step is the step index (e.g., 5). We use 5+1=6 for the slice end index.
        end_index = min(end_at_step + 1, len(sigmas))
        # Slice the tensor: [start:end]
        sigmas_sliced = sigmas[start_at_step:end_index]
        # Check for empty schedule resulting from slicing
        if sigmas_sliced.numel() == 0:
            print("Warning: start_at_step/end_at_step resulted in an empty sigma schedule. Using full schedule as fallback.")
            sigmas_sliced = sigmas
        return (sigmas_sliced,)
class VQDiffusionSchedulerNode:
    """Experimental ComfyUI node for VQDiffusionScheduler.

    VQ-Diffusion is a discrete-latent scheduler and cannot produce real
    sigmas; this node emits a dummy linear schedule for compatibility.
    """

    RETURN_TYPES = ("SIGMAS",)
    RETURN_NAMES = ("sigmas",)
    FUNCTION = "create"
    CATEGORY = "sampling/schedulers"
    DESCRIPTION = "VQ Diffusion Scheduler (Experimental). For VQ-Diffusion models. Returns dummy sigmas for compatibility."

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's three integer inputs."""
        steps_spec = ("INT", {"default": 20, "min": 1, "max": 10000})
        classes_spec = ("INT", {"default": 4096, "min": 1, "max": 65536, "tooltip": "Number of vector classes for VQ model."})
        timesteps_spec = ("INT", {"default": 1000})
        return {
            "required": {
                "steps": steps_spec,
                "num_vec_classes": classes_spec,
                "num_train_timesteps": timesteps_spec,
            }
        }

    def create(self, steps, num_vec_classes, num_train_timesteps):
        """Return a 1-tuple with a dummy linear sigma ramp of steps+1 values."""
        if VQDiffusionScheduler is None:
            raise ImportError("VQDiffusionScheduler not found.")
        print("[FlowMatch Scheduler] Creating VQDiffusionScheduler (Experimental)")
        print("[FlowMatch Scheduler] WARNING: Returning dummy sigmas. This scheduler is for discrete latent models.")
        # We don't actually use the scheduler to generate sigmas because it can't.
        # We just return the dummy sigmas.
        # Default to CPU; the sampler will move it if needed — no model context here.
        schedule = torch.linspace(1.0, 0.0, steps + 1)
        return (schedule,)
# ComfyUI entry points: node registry and human-readable display names.
NODE_CLASS_MAPPINGS = {
    "FlowMatchEulerDiscreteScheduler (Custom)": FlowMatchEulerSchedulerNode,
    # "VQDiffusionScheduler": VQDiffusionSchedulerNode,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "FlowMatchEulerDiscreteScheduler (Custom)": "FlowMatch Euler Discrete Scheduler (Custom)",
    # "VQDiffusionScheduler": "VQ Diffusion Scheduler (Experimental)",
}

# Merge in the metadata-extraction nodes shipped with this package (hard requirement).
from .extract_metadata_node import NODE_CLASS_MAPPINGS as METADATA_NODE_MAPPINGS
from .extract_metadata_node import NODE_DISPLAY_NAME_MAPPINGS as METADATA_DISPLAY_MAPPINGS
NODE_CLASS_MAPPINGS.update(METADATA_NODE_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS.update(METADATA_DISPLAY_MAPPINGS)

# Import Nunchaku nodes (best-effort: their absence is non-fatal)
try:
    from .nunchaku_compat import NODE_CLASS_MAPPINGS as NUNCHAKU_NODES
    from .nunchaku_compat import NODE_DISPLAY_NAME_MAPPINGS as NUNCHAKU_NAMES
    NODE_CLASS_MAPPINGS.update(NUNCHAKU_NODES)
    NODE_DISPLAY_NAME_MAPPINGS.update(NUNCHAKU_NAMES)
except Exception as e:
    print(f"[FlowMatch Scheduler] Could not load Nunchaku nodes: {e}")