v1.0.0
This commit is contained in:
243
skills/creative/comfyui/scripts/run_batch.py
Executable file
243
skills/creative/comfyui/scripts/run_batch.py
Executable file
@@ -0,0 +1,243 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
run_batch.py — Run a workflow many times, varying parameters per run.
|
||||
|
||||
Two modes:
|
||||
1. --count N --randomize-seed
|
||||
Submit N runs, each with a fresh random seed. Use for quick variations.
|
||||
2. --sweep '{"seed": [1,2,3], "steps": [20,30]}'
|
||||
Cartesian product of values. With cloud subscription, runs in parallel
|
||||
up to your tier's concurrent-job limit.
|
||||
|
||||
Both modes write each run's outputs into output-dir/run_NNN/.
|
||||
|
||||
Examples:
|
||||
python3 run_batch.py --workflow flux_dev.json \
|
||||
--args '{"prompt": "a cat"}' \
|
||||
--count 8 --randomize-seed \
|
||||
--output-dir ./outputs/cat-batch
|
||||
|
||||
python3 run_batch.py --workflow sdxl.json \
|
||||
--args '{"prompt": "abstract"}' \
|
||||
--sweep '{"seed": [1,2,3], "steps": [20, 40]}' \
|
||||
--output-dir ./outputs/sweep
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import itertools
|
||||
import json
|
||||
import sys
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent))
|
||||
from _common import ( # noqa: E402
|
||||
DEFAULT_LOCAL_HOST, ENV_API_KEY, coerce_seed, emit_json, log,
|
||||
looks_like_video_workflow, resolve_api_key, unwrap_workflow,
|
||||
)
|
||||
from run_workflow import ( # noqa: E402
|
||||
ComfyRunner, download_outputs, inject_params,
|
||||
)
|
||||
from extract_schema import extract_schema # noqa: E402
|
||||
|
||||
|
||||
def expand_sweep(sweep: dict, base_args: dict, count: int, randomize_seed: bool) -> list[dict]:
|
||||
"""Generate a list of args dicts for each run."""
|
||||
if sweep:
|
||||
# Cartesian product
|
||||
keys = list(sweep.keys())
|
||||
values = [sweep[k] if isinstance(sweep[k], list) else [sweep[k]] for k in keys]
|
||||
runs = []
|
||||
for combo in itertools.product(*values):
|
||||
ar = dict(base_args)
|
||||
for k, v in zip(keys, combo):
|
||||
ar[k] = v
|
||||
runs.append(ar)
|
||||
return runs
|
||||
# Count mode
|
||||
runs = []
|
||||
for _ in range(count):
|
||||
ar = dict(base_args)
|
||||
if randomize_seed:
|
||||
ar["seed"] = coerce_seed(None)
|
||||
runs.append(ar)
|
||||
return runs
|
||||
|
||||
|
||||
def execute_one(
|
||||
runner: ComfyRunner, workflow: dict, schema: dict, args: dict,
|
||||
*, output_dir: Path, timeout: int, ws: bool,
|
||||
) -> dict:
|
||||
wf, warnings = inject_params(workflow, schema, args)
|
||||
sub = runner.submit(wf)
|
||||
if "_http_error" in sub:
|
||||
return {"status": "error", "error": "submission HTTP error",
|
||||
"details": sub.get("body"), "args": args}
|
||||
pid = sub.get("prompt_id")
|
||||
if not pid:
|
||||
return {"status": "error", "error": "no prompt_id", "response": sub, "args": args}
|
||||
if sub.get("node_errors"):
|
||||
return {"status": "error", "error": "validation failed",
|
||||
"node_errors": sub["node_errors"], "args": args}
|
||||
|
||||
if ws:
|
||||
result = runner.monitor_ws(pid, timeout=timeout)
|
||||
else:
|
||||
result = runner.poll_status(pid, timeout=timeout)
|
||||
|
||||
if result["status"] != "success":
|
||||
return {
|
||||
"status": result["status"],
|
||||
"prompt_id": pid,
|
||||
"details": result.get("data"),
|
||||
"args": args,
|
||||
}
|
||||
|
||||
outputs = result.get("outputs") or runner.get_outputs(pid)
|
||||
downloaded = download_outputs(runner, outputs, output_dir, preserve_subfolder=False)
|
||||
return {
|
||||
"status": "success",
|
||||
"prompt_id": pid,
|
||||
"args": args,
|
||||
"outputs": downloaded,
|
||||
"warnings": warnings,
|
||||
}
|
||||
|
||||
|
||||
def main(argv: list[str] | None = None) -> int:
|
||||
p = argparse.ArgumentParser(
|
||||
description="Submit a workflow many times with varying parameters.",
|
||||
)
|
||||
p.add_argument("--workflow", required=True)
|
||||
p.add_argument("--args", default="{}", help="Base parameters JSON")
|
||||
p.add_argument("--count", type=int, default=0,
|
||||
help="Number of runs (use with --randomize-seed)")
|
||||
p.add_argument("--sweep", default="",
|
||||
help='JSON dict of param→list of values. Cartesian product. '
|
||||
'e.g. \'{"seed":[1,2,3],"cfg":[5,8]}\'')
|
||||
p.add_argument("--randomize-seed", action="store_true",
|
||||
help="In --count mode, vary seed per run")
|
||||
p.add_argument("--host", default=DEFAULT_LOCAL_HOST)
|
||||
p.add_argument("--api-key", help=f"or set ${ENV_API_KEY}")
|
||||
p.add_argument("--partner-key")
|
||||
p.add_argument("--parallel", type=int, default=1,
|
||||
help="Concurrent submissions (cloud: up to your tier limit). "
|
||||
"Default 1 (sequential)")
|
||||
p.add_argument("--output-dir", default="./outputs/batch")
|
||||
p.add_argument("--timeout", type=int, default=0)
|
||||
p.add_argument("--ws", action="store_true")
|
||||
p.add_argument("--continue-on-error", action="store_true",
|
||||
help="Don't stop the batch when a run fails")
|
||||
args = p.parse_args(argv)
|
||||
|
||||
if args.count <= 0 and not args.sweep:
|
||||
emit_json({"error": "Specify --count N or --sweep '{...}'"})
|
||||
return 1
|
||||
|
||||
base_args = json.loads(args.args) if args.args.strip() else {}
|
||||
sweep = json.loads(args.sweep) if args.sweep.strip() else {}
|
||||
|
||||
# Validate sweep shape
|
||||
if sweep:
|
||||
if not isinstance(sweep, dict):
|
||||
emit_json({"error": "--sweep must be a JSON object {param: [values]}"})
|
||||
return 1
|
||||
empty = [k for k, v in sweep.items() if isinstance(v, list) and len(v) == 0]
|
||||
if empty:
|
||||
emit_json({"error": f"--sweep parameters have empty value lists: {empty}"})
|
||||
return 1
|
||||
# If user passed BOTH --sweep and --count/--randomize-seed, --sweep wins
|
||||
if args.count or args.randomize_seed:
|
||||
log("--sweep set; ignoring --count / --randomize-seed (sweep defines the runs)")
|
||||
|
||||
wf_path = Path(args.workflow).expanduser()
|
||||
if not wf_path.exists():
|
||||
emit_json({"error": f"Workflow not found: {args.workflow}"})
|
||||
return 1
|
||||
try:
|
||||
with wf_path.open() as f:
|
||||
workflow = unwrap_workflow(json.load(f))
|
||||
except (ValueError, json.JSONDecodeError) as e:
|
||||
emit_json({"error": str(e)})
|
||||
return 1
|
||||
|
||||
schema = extract_schema(workflow)
|
||||
runs = expand_sweep(sweep, base_args, args.count, args.randomize_seed)
|
||||
log(f"Planned {len(runs)} run(s)")
|
||||
|
||||
api_key = resolve_api_key(args.api_key)
|
||||
runner = ComfyRunner(host=args.host, api_key=api_key, partner_key=args.partner_key)
|
||||
|
||||
ok, info = runner.check_server()
|
||||
if not ok:
|
||||
emit_json({"error": "Cannot reach server", "details": info, "host": args.host})
|
||||
return 1
|
||||
|
||||
timeout = args.timeout
|
||||
if timeout <= 0:
|
||||
timeout = 900 if looks_like_video_workflow(workflow) else 300
|
||||
|
||||
base_dir = Path(args.output_dir).expanduser()
|
||||
base_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
results: list[dict] = []
|
||||
failures = 0
|
||||
|
||||
if args.parallel > 1:
|
||||
with ThreadPoolExecutor(max_workers=args.parallel) as ex:
|
||||
future_to_idx = {}
|
||||
for i, ar in enumerate(runs):
|
||||
run_dir = base_dir / f"run_{i:04d}"
|
||||
fut = ex.submit(
|
||||
execute_one, runner, workflow, schema, ar,
|
||||
output_dir=run_dir, timeout=timeout, ws=args.ws,
|
||||
)
|
||||
future_to_idx[fut] = i
|
||||
for fut in as_completed(future_to_idx):
|
||||
i = future_to_idx[fut]
|
||||
try:
|
||||
r = fut.result()
|
||||
except Exception as e:
|
||||
r = {"status": "error", "error": str(e), "args": runs[i]}
|
||||
r["index"] = i
|
||||
results.append(r)
|
||||
if r["status"] != "success":
|
||||
failures += 1
|
||||
log(f" run {i} → {r['status']}: {r.get('error','?')}")
|
||||
if not args.continue_on_error:
|
||||
log(" --continue-on-error not set; aborting batch")
|
||||
break
|
||||
else:
|
||||
log(f" run {i} → success: {len(r.get('outputs', []))} files")
|
||||
else:
|
||||
for i, ar in enumerate(runs):
|
||||
run_dir = base_dir / f"run_{i:04d}"
|
||||
r = execute_one(runner, workflow, schema, ar,
|
||||
output_dir=run_dir, timeout=timeout, ws=args.ws)
|
||||
r["index"] = i
|
||||
results.append(r)
|
||||
if r["status"] != "success":
|
||||
failures += 1
|
||||
log(f" run {i} → {r['status']}: {r.get('error','?')}")
|
||||
if not args.continue_on_error:
|
||||
log(" --continue-on-error not set; aborting batch")
|
||||
break
|
||||
else:
|
||||
log(f" run {i} → success: {len(r.get('outputs', []))} files")
|
||||
|
||||
results.sort(key=lambda x: x.get("index", 0))
|
||||
emit_json({
|
||||
"status": "success" if failures == 0 else "partial",
|
||||
"total": len(runs),
|
||||
"completed": sum(1 for r in results if r["status"] == "success"),
|
||||
"failed": failures,
|
||||
"output_dir": str(base_dir),
|
||||
"results": results,
|
||||
})
|
||||
return 0 if failures == 0 else 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
Reference in New Issue
Block a user