forked from cgvr/DeltaVR

refactor image generation script + env vars

2025-12-17 13:23:59 +02:00
parent 90781191b7
commit 2c19602e9b
3 changed files with 419 additions and 39 deletions

View File

@@ -1,5 +1,5 @@
-3D_GENERATION_URL=
-MODEL_FOLDER=
INVOKEAI_URL=
TRELLIS_URL=
CLOUDFLARE_ACCOUNT_ID=
CLOUDFLARE_API_TOKEN=
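
For context, both scripts consume these variables through python-dotenv. A minimal sketch of the lookup (names taken from this diff; the .env location is whatever load_dotenv finds in the working directory):

import os
from dotenv import load_dotenv

load_dotenv()  # loads key=value pairs from .env into the process environment
invokeai_url = os.environ["INVOKEAI_URL"]  # was hardcoded to http://127.0.0.1:9090 before this commit
trellis_url = os.environ["TRELLIS_URL"]    # was 3D_GENERATION_URL before this commit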

View File

@@ -1,11 +1,111 @@
# Based on: https://github.com/coinstax/invokeai-mcp-server
import requests
import asyncio
import json
import logging
import httpx
import os
-from invokeai_mcp_server import create_text2img_graph, enqueue_graph, wait_for_completion, get_image_url
from typing import Optional
from urllib.parse import urljoin
from dotenv import load_dotenv
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("invokeai-mcp")
load_dotenv()
-INVOKEAI_BASE_URL = "http://127.0.0.1:9090"
INVOKEAI_BASE_URL = os.environ["INVOKEAI_URL"]
DEFAULT_QUEUE_ID = "default"
# HTTP client
http_client: Optional[httpx.AsyncClient] = None
def get_client() -> httpx.AsyncClient:
"""Get or create HTTP client."""
global http_client
if http_client is None:
http_client = httpx.AsyncClient(base_url=INVOKEAI_BASE_URL, timeout=120.0)
return http_client
async def text_to_image_invoke_ai(prompt, output_path):
# see available model keys via GET http://127.0.0.1:9090/api/v2/models/?model_type=main
args = {
"prompt": prompt,
"width": 512,
"height": 512,
"model_key": "79401292-0a6b-428d-b7d7-f1e86caeba2b" # Juggernaut XL v9
#"model_key": "735f6485-6703-498f-929e-07cf0bbbd179" # Dreamshaper 8
}
image_url = await generate_image(args)
print("got image url: ", image_url)
download_file(image_url, output_path)
async def wait_for_completion(batch_id: str, queue_id: str = DEFAULT_QUEUE_ID, timeout: int = 300) -> dict:
"""Wait for a batch to complete and return the most recent image."""
client = get_client()
start_time = asyncio.get_event_loop().time()
while True:
# Check if we've exceeded timeout
if asyncio.get_event_loop().time() - start_time > timeout:
raise TimeoutError(f"Image generation timed out after {timeout} seconds")
# Get batch status
response = await client.get(f"/api/v1/queue/{queue_id}/b/{batch_id}/status")
response.raise_for_status()
status_data = response.json()
# Check for failures
failed_count = status_data.get("failed", 0)
if failed_count > 0:
# Try to get error details from the queue
queue_status_response = await client.get(f"/api/v1/queue/{queue_id}/status")
queue_status_response.raise_for_status()
queue_data = queue_status_response.json()
raise RuntimeError(
f"Image generation failed. Batch {batch_id} has {failed_count} failed item(s). "
f"Queue status: {json.dumps(queue_data, indent=2)}"
)
# Check completion
completed = status_data.get("completed", 0)
total = status_data.get("total", 0)
if completed == total and total > 0:
# Get most recent non-intermediate image
images_response = await client.get("/api/v1/images/?is_intermediate=false&limit=10")
images_response.raise_for_status()
images_data = images_response.json()
# Return the most recent image (first in the list)
if images_data.get("items"):
return {
"batch_id": batch_id,
"status": "completed",
"result": {
"outputs": {
"save_image": {
"type": "image_output",
"image": {
"image_name": images_data["items"][0]["image_name"]
}
}
}
}
}
# If no images found, return status
return status_data
# Wait before checking again
await asyncio.sleep(1)
async def generate_image(arguments: dict):
@@ -71,16 +171,319 @@ def download_file(url, filepath):
else:
raise RuntimeError(f"Failed to download image. Status code: {response.status_code}")
async def get_image_url(image_name: str) -> str:
"""Get the URL for an image."""
client = get_client()
response = await client.get(f"/api/v1/images/i/{image_name}/urls")
response.raise_for_status()
data = response.json()
return data.get("image_url", "")
-async def text_to_image_invoke_ai(prompt, output_path):
-# see available model keys via GET http://127.0.0.1:9090/api/v2/models/?model_type=main
-args = {
-"prompt": prompt,
-"width": 512,
-"height": 512,
-"model_key": "79401292-0a6b-428d-b7d7-f1e86caeba2b" # Juggernaut XL v9
-#"model_key": "735f6485-6703-498f-929e-07cf0bbbd179" # Dreamshaper 8
-}
-image_url = await generate_image(args)
-print("got image url: ", image_url)
-download_file(image_url, output_path)
async def enqueue_graph(graph: dict, queue_id: str = DEFAULT_QUEUE_ID) -> dict:
"""Enqueue a graph for processing."""
client = get_client()
batch = {
"batch": {
"graph": graph,
"runs": 1,
"data": None
}
}
response = await client.post(
f"/api/v1/queue/{queue_id}/enqueue_batch",
json=batch
)
response.raise_for_status()
return response.json()
async def list_models(model_type: str = "main") -> list:
"""List available models."""
client = get_client()
response = await client.get("/api/v2/models/", params={"model_type": model_type})
response.raise_for_status()
data = response.json()
return data.get("models", [])
async def get_model_info(model_key: str) -> Optional[dict]:
"""Get information about a specific model."""
client = get_client()
try:
response = await client.get(f"/api/v2/models/i/{model_key}")
response.raise_for_status()
model_data = response.json()
# Ensure we have a valid dictionary
if not isinstance(model_data, dict):
logger.error(f"Model info for {model_key} is not a dictionary: {type(model_data)}")
return None
return model_data
except Exception as e:
logger.error(f"Error fetching model info for {model_key}: {e}")
return None
async def create_text2img_graph(
prompt: str,
negative_prompt: str = "",
model_key: Optional[str] = None,
lora_key: Optional[str] = None,
lora_weight: float = 1.0,
vae_key: Optional[str] = None,
width: int = 512,
height: int = 512,
steps: int = 30,
cfg_scale: float = 7.5,
scheduler: str = "euler",
seed: Optional[int] = None
) -> dict:
"""Create a text-to-image generation graph with optional LoRA and VAE support."""
# Use default model if not specified
if model_key is None:
# Try to find an sd-1 model
models = await list_models("main")
for model in models:
if model.get("base") == "sd-1":
model_key = model["key"]
break
if model_key is None:
raise ValueError("No suitable model found")
# Get model information
model_info = await get_model_info(model_key)
if not model_info:
raise ValueError(f"Model {model_key} not found")
# Validate model info has required fields
if not isinstance(model_info, dict):
raise ValueError(f"Model {model_key} returned invalid data type: {type(model_info)}")
required_fields = ["key", "hash", "name", "base", "type"]
for field in required_fields:
if field not in model_info or model_info[field] is None:
raise ValueError(f"Model {model_key} is missing required field: {field}")
# Generate random seed if not provided
if seed is None:
import random
seed = random.randint(0, 2**32 - 1)
# Detect if this is an SDXL model
is_sdxl = model_info["base"] == "sdxl"
# Build nodes dictionary
nodes = {
# Main model loader - use sdxl_model_loader for SDXL models
"model_loader": {
"type": "sdxl_model_loader" if is_sdxl else "main_model_loader",
"id": "model_loader",
"model": {
"key": model_info["key"],
"hash": model_info["hash"],
"name": model_info["name"],
"base": model_info["base"],
"type": model_info["type"]
}
},
# Positive prompt encoding - use sdxl_compel_prompt for SDXL
"positive_prompt": {
"type": "sdxl_compel_prompt" if is_sdxl else "compel",
"id": "positive_prompt",
"prompt": prompt,
**({"style": prompt} if is_sdxl else {})
},
# Negative prompt encoding - use sdxl_compel_prompt for SDXL
"negative_prompt": {
"type": "sdxl_compel_prompt" if is_sdxl else "compel",
"id": "negative_prompt",
"prompt": negative_prompt,
**({"style": ""} if is_sdxl else {})
},
# Noise generation
"noise": {
"type": "noise",
"id": "noise",
"seed": seed,
"width": width,
"height": height,
"use_cpu": False
},
# Denoise latents (main generation step)
"denoise": {
"type": "denoise_latents",
"id": "denoise",
"steps": steps,
"cfg_scale": cfg_scale,
"scheduler": scheduler,
"denoising_start": 0,
"denoising_end": 1
},
# Convert latents to image
"latents_to_image": {
"type": "l2i",
"id": "latents_to_image"
},
# Save image
"save_image": {
"type": "save_image",
"id": "save_image",
"is_intermediate": False
}
}
# Add LoRA loader if requested
if lora_key is not None:
lora_info = await get_model_info(lora_key)
if not lora_info:
raise ValueError(f"LoRA model {lora_key} not found")
# Validate LoRA info has required fields
required_fields = ["key", "hash", "name", "base", "type"]
for field in required_fields:
if field not in lora_info or lora_info[field] is None:
raise ValueError(f"LoRA model {lora_key} is missing required field: {field}")
nodes["lora_loader"] = {
"type": "lora_loader",
"id": "lora_loader",
"lora": {
"key": lora_info["key"],
"hash": lora_info["hash"],
"name": lora_info["name"],
"base": lora_info["base"],
"type": lora_info["type"]
},
"weight": lora_weight
}
# Add VAE loader if requested (to override model's built-in VAE)
if vae_key is not None:
vae_info = await get_model_info(vae_key)
if not vae_info:
raise ValueError(f"VAE model {vae_key} not found")
# Validate VAE info has required fields
required_fields = ["key", "hash", "name", "base", "type"]
for field in required_fields:
if field not in vae_info or vae_info[field] is None:
raise ValueError(f"VAE model {vae_key} is missing required field: {field}")
nodes["vae_loader"] = {
"type": "vae_loader",
"id": "vae_loader",
"vae_model": {
"key": vae_info["key"],
"hash": vae_info["hash"],
"name": vae_info["name"],
"base": vae_info["base"],
"type": vae_info["type"]
}
}
# Build edges
edges = []
# Determine source for UNet and CLIP (model_loader or lora_loader)
unet_source = "lora_loader" if lora_key is not None else "model_loader"
clip_source = "lora_loader" if lora_key is not None else "model_loader"
# Determine source for VAE (vae_loader if specified, otherwise model_loader)
vae_source = "vae_loader" if vae_key is not None else "model_loader"
# If using LoRA, connect model_loader to lora_loader first
if lora_key is not None:
edges.extend([
{
"source": {"node_id": "model_loader", "field": "unet"},
"destination": {"node_id": "lora_loader", "field": "unet"}
},
{
"source": {"node_id": "model_loader", "field": "clip"},
"destination": {"node_id": "lora_loader", "field": "clip"}
}
])
# Note: lora_loader doesn't have a clip2 field, so for SDXL we route clip2 directly from model_loader
# Connect UNet and CLIP to downstream nodes
edges.extend([
# Connect UNet to denoise
{
"source": {"node_id": unet_source, "field": "unet"},
"destination": {"node_id": "denoise", "field": "unet"}
},
# Connect CLIP to prompts
{
"source": {"node_id": clip_source, "field": "clip"},
"destination": {"node_id": "positive_prompt", "field": "clip"}
},
{
"source": {"node_id": clip_source, "field": "clip"},
"destination": {"node_id": "negative_prompt", "field": "clip"}
},
])
# For SDXL models, also connect clip2
# Note: clip2 always comes from model_loader, even when using LoRA (lora_loader doesn't support clip2)
if is_sdxl:
edges.extend([
{
"source": {"node_id": "model_loader", "field": "clip2"},
"destination": {"node_id": "positive_prompt", "field": "clip2"}
},
{
"source": {"node_id": "model_loader", "field": "clip2"},
"destination": {"node_id": "negative_prompt", "field": "clip2"}
},
])
edges.extend([
# Connect prompts to denoise
{
"source": {"node_id": "positive_prompt", "field": "conditioning"},
"destination": {"node_id": "denoise", "field": "positive_conditioning"}
},
{
"source": {"node_id": "negative_prompt", "field": "conditioning"},
"destination": {"node_id": "denoise", "field": "negative_conditioning"}
},
# Connect noise to denoise
{
"source": {"node_id": "noise", "field": "noise"},
"destination": {"node_id": "denoise", "field": "noise"}
},
# Connect denoise to latents_to_image
{
"source": {"node_id": "denoise", "field": "latents"},
"destination": {"node_id": "latents_to_image", "field": "latents"}
},
{
"source": {"node_id": vae_source, "field": "vae"},
"destination": {"node_id": "latents_to_image", "field": "vae"}
},
# Connect latents_to_image to save_image
{
"source": {"node_id": "latents_to_image", "field": "image"},
"destination": {"node_id": "save_image", "field": "image"}
}
])
graph = {
"id": "text2img_graph",
"nodes": nodes,
"edges": edges
}
return graph
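
For reference, a minimal sketch of driving the refactored text-to-image path end to end; the module name image_generation is hypothetical, while text_to_image_invoke_ai and the INVOKEAI_URL requirement come from this diff:

import asyncio

from image_generation import text_to_image_invoke_ai  # hypothetical module name

# Assumes INVOKEAI_URL is set in .env and an InvokeAI server is reachable there.
# Generates a 512x512 image for the prompt and downloads it to chair.png.
asyncio.run(text_to_image_invoke_ai("a wooden chair, studio lighting", "chair.png"))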

View File

@@ -1,4 +1,3 @@
-import subprocess
import os
import time
import requests
@@ -7,29 +6,7 @@ import base64
from dotenv import load_dotenv
load_dotenv()
-MODEL_FOLDER = os.environ["MODEL_FOLDER"]
-API_URL = os.environ["3D_GENERATION_URL"]
API_URL = os.environ["TRELLIS_URL"]
-def image_to_3d_subprocess(image_path, output_path):
-venv_python = MODEL_FOLDER + r"\.venv\Scripts\python.exe"
-script_path = MODEL_FOLDER + r"\run.py"
-args = [image_path, "--output-dir", output_path]
-command = [venv_python, script_path] + args
-try:
-# Run the subprocess
-result = subprocess.run(command, capture_output=True, text=True)
-# Print output and errors
-print("STDOUT:\n", result.stdout)
-print("STDERR:\n", result.stderr)
-print("Return Code:", result.returncode)
-except Exception as e:
-print(f"Error occurred: {e}")
def generate_no_preview(image_base64: str):
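
For reference, a minimal sketch of building the image_base64 argument that generate_no_preview expects; plain base64 of the image bytes is an assumption consistent with the import base64 already in this file:

import base64

# Hypothetical caller: read an image from disk and base64-encode it (assumption)
with open("chair.png", "rb") as f:
    image_base64 = base64.b64encode(f.read()).decode("utf-8")

generate_no_preview(image_base64)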