forked from cgvr/DeltaVR
refactor image generation script + env vars
This commit is contained in:
@@ -1,5 +1,5 @@
|
|||||||
3D_GENERATION_URL=
|
INVOKEAI_URL=
|
||||||
MODEL_FOLDER=
|
TRELLIS_URL=
|
||||||
|
|
||||||
CLOUDFLARE_ACCOUNT_ID=
|
CLOUDFLARE_ACCOUNT_ID=
|
||||||
CLOUDFLARE_API_TOKEN=
|
CLOUDFLARE_API_TOKEN=
|
||||||
@@ -1,11 +1,111 @@
|
|||||||
|
# Based on: https://github.com/coinstax/invokeai-mcp-server
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import httpx
|
||||||
|
import os
|
||||||
|
|
||||||
from invokeai_mcp_server import create_text2img_graph, enqueue_graph, wait_for_completion, get_image_url
|
from typing import Optional
|
||||||
from urllib.parse import urljoin
|
from urllib.parse import urljoin
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
# Configure logging
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger("invokeai-mcp")
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
INVOKEAI_BASE_URL = os.environ["INVOKEAI_URL"]
|
||||||
|
DEFAULT_QUEUE_ID = "default"
|
||||||
|
|
||||||
|
# HTTP client
|
||||||
|
http_client: Optional[httpx.AsyncClient] = None
|
||||||
|
|
||||||
|
|
||||||
INVOKEAI_BASE_URL = "http://127.0.0.1:9090"
|
|
||||||
|
|
||||||
|
def get_client() -> httpx.AsyncClient:
|
||||||
|
"""Get or create HTTP client."""
|
||||||
|
global http_client
|
||||||
|
if http_client is None:
|
||||||
|
http_client = httpx.AsyncClient(base_url=INVOKEAI_BASE_URL, timeout=120.0)
|
||||||
|
return http_client
|
||||||
|
|
||||||
|
|
||||||
|
async def text_to_image_invoke_ai(prompt, output_path):
|
||||||
|
# see available model keys via GET http://127.0.0.1:9090/api/v2/models/?model_type=main
|
||||||
|
args = {
|
||||||
|
"prompt": prompt,
|
||||||
|
"width": 512,
|
||||||
|
"height": 512,
|
||||||
|
"model_key": "79401292-0a6b-428d-b7d7-f1e86caeba2b" # Juggernaut XL v9
|
||||||
|
#"model_key": "735f6485-6703-498f-929e-07cf0bbbd179" # Dreamshaper 8
|
||||||
|
}
|
||||||
|
image_url = await generate_image(args)
|
||||||
|
print("got image url: ", image_url)
|
||||||
|
download_file(image_url, output_path)
|
||||||
|
|
||||||
|
async def wait_for_completion(batch_id: str, queue_id: str = DEFAULT_QUEUE_ID, timeout: int = 300) -> dict:
|
||||||
|
"""Wait for a batch to complete and return the most recent image."""
|
||||||
|
client = get_client()
|
||||||
|
start_time = asyncio.get_event_loop().time()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
# Check if we've exceeded timeout
|
||||||
|
if asyncio.get_event_loop().time() - start_time > timeout:
|
||||||
|
raise TimeoutError(f"Image generation timed out after {timeout} seconds")
|
||||||
|
|
||||||
|
# Get batch status
|
||||||
|
response = await client.get(f"/api/v1/queue/{queue_id}/b/{batch_id}/status")
|
||||||
|
response.raise_for_status()
|
||||||
|
status_data = response.json()
|
||||||
|
|
||||||
|
# Check for failures
|
||||||
|
failed_count = status_data.get("failed", 0)
|
||||||
|
if failed_count > 0:
|
||||||
|
# Try to get error details from the queue
|
||||||
|
queue_status_response = await client.get(f"/api/v1/queue/{queue_id}/status")
|
||||||
|
queue_status_response.raise_for_status()
|
||||||
|
queue_data = queue_status_response.json()
|
||||||
|
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Image generation failed. Batch {batch_id} has {failed_count} failed item(s). "
|
||||||
|
f"Queue status: {json.dumps(queue_data, indent=2)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check completion
|
||||||
|
completed = status_data.get("completed", 0)
|
||||||
|
total = status_data.get("total", 0)
|
||||||
|
|
||||||
|
if completed == total and total > 0:
|
||||||
|
# Get most recent non-intermediate image
|
||||||
|
images_response = await client.get("/api/v1/images/?is_intermediate=false&limit=10")
|
||||||
|
images_response.raise_for_status()
|
||||||
|
images_data = images_response.json()
|
||||||
|
|
||||||
|
# Return the most recent image (first in the list)
|
||||||
|
if images_data.get("items"):
|
||||||
|
return {
|
||||||
|
"batch_id": batch_id,
|
||||||
|
"status": "completed",
|
||||||
|
"result": {
|
||||||
|
"outputs": {
|
||||||
|
"save_image": {
|
||||||
|
"type": "image_output",
|
||||||
|
"image": {
|
||||||
|
"image_name": images_data["items"][0]["image_name"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# If no images found, return status
|
||||||
|
return status_data
|
||||||
|
|
||||||
|
# Wait before checking again
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
async def generate_image(arguments: dict):
|
async def generate_image(arguments: dict):
|
||||||
|
|
||||||
@@ -71,16 +171,319 @@ def download_file(url, filepath):
|
|||||||
else:
|
else:
|
||||||
raise RuntimeError(f"Failed to download image. Status code: {response.status_code}")
|
raise RuntimeError(f"Failed to download image. Status code: {response.status_code}")
|
||||||
|
|
||||||
|
async def get_image_url(image_name: str) -> str:
|
||||||
|
"""Get the URL for an image."""
|
||||||
|
client = get_client()
|
||||||
|
response = await client.get(f"/api/v1/images/i/{image_name}/urls")
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
return data.get("image_url", "")
|
||||||
|
|
||||||
async def text_to_image_invoke_ai(prompt, output_path):
|
|
||||||
# see available model keys via GET http://127.0.0.1:9090/api/v2/models/?model_type=main
|
async def enqueue_graph(graph: dict, queue_id: str = DEFAULT_QUEUE_ID) -> dict:
|
||||||
args = {
|
"""Enqueue a graph for processing."""
|
||||||
"prompt": prompt,
|
client = get_client()
|
||||||
"width": 512,
|
|
||||||
"height": 512,
|
batch = {
|
||||||
"model_key": "79401292-0a6b-428d-b7d7-f1e86caeba2b" # Juggernaut XL v9
|
"batch": {
|
||||||
#"model_key": "735f6485-6703-498f-929e-07cf0bbbd179" # Dreamshaper 8
|
"graph": graph,
|
||||||
|
"runs": 1,
|
||||||
|
"data": None
|
||||||
|
}
|
||||||
}
|
}
|
||||||
image_url = await generate_image(args)
|
|
||||||
print("got image url: ", image_url)
|
response = await client.post(
|
||||||
download_file(image_url, output_path)
|
f"/api/v1/queue/{queue_id}/enqueue_batch",
|
||||||
|
json=batch
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
async def list_models(model_type: str = "main") -> list:
|
||||||
|
"""List available models."""
|
||||||
|
client = get_client()
|
||||||
|
response = await client.get("/api/v2/models/", params={"model_type": model_type})
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
return data.get("models", [])
|
||||||
|
|
||||||
|
async def get_model_info(model_key: str) -> Optional[dict]:
|
||||||
|
"""Get information about a specific model."""
|
||||||
|
client = get_client()
|
||||||
|
try:
|
||||||
|
response = await client.get(f"/api/v2/models/i/{model_key}")
|
||||||
|
response.raise_for_status()
|
||||||
|
model_data = response.json()
|
||||||
|
|
||||||
|
# Ensure we have a valid dictionary
|
||||||
|
if not isinstance(model_data, dict):
|
||||||
|
logger.error(f"Model info for {model_key} is not a dictionary: {type(model_data)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
return model_data
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching model info for {model_key}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def create_text2img_graph(
|
||||||
|
prompt: str,
|
||||||
|
negative_prompt: str = "",
|
||||||
|
model_key: Optional[str] = None,
|
||||||
|
lora_key: Optional[str] = None,
|
||||||
|
lora_weight: float = 1.0,
|
||||||
|
vae_key: Optional[str] = None,
|
||||||
|
width: int = 512,
|
||||||
|
height: int = 512,
|
||||||
|
steps: int = 30,
|
||||||
|
cfg_scale: float = 7.5,
|
||||||
|
scheduler: str = "euler",
|
||||||
|
seed: Optional[int] = None
|
||||||
|
) -> dict:
|
||||||
|
"""Create a text-to-image generation graph with optional LoRA and VAE support."""
|
||||||
|
|
||||||
|
# Use default model if not specified
|
||||||
|
if model_key is None:
|
||||||
|
# Try to find an sd-1 model
|
||||||
|
models = await list_models("main")
|
||||||
|
for model in models:
|
||||||
|
if model.get("base") == "sd-1":
|
||||||
|
model_key = model["key"]
|
||||||
|
break
|
||||||
|
if model_key is None:
|
||||||
|
raise ValueError("No suitable model found")
|
||||||
|
|
||||||
|
# Get model information
|
||||||
|
model_info = await get_model_info(model_key)
|
||||||
|
if not model_info:
|
||||||
|
raise ValueError(f"Model {model_key} not found")
|
||||||
|
|
||||||
|
# Validate model info has required fields
|
||||||
|
if not isinstance(model_info, dict):
|
||||||
|
raise ValueError(f"Model {model_key} returned invalid data type: {type(model_info)}")
|
||||||
|
|
||||||
|
required_fields = ["key", "hash", "name", "base", "type"]
|
||||||
|
for field in required_fields:
|
||||||
|
if field not in model_info or model_info[field] is None:
|
||||||
|
raise ValueError(f"Model {model_key} is missing required field: {field}")
|
||||||
|
|
||||||
|
# Generate random seed if not provided
|
||||||
|
if seed is None:
|
||||||
|
import random
|
||||||
|
seed = random.randint(0, 2**32 - 1)
|
||||||
|
|
||||||
|
# Detect if this is an SDXL model
|
||||||
|
is_sdxl = model_info["base"] == "sdxl"
|
||||||
|
|
||||||
|
# Build nodes dictionary
|
||||||
|
nodes = {
|
||||||
|
# Main model loader - use sdxl_model_loader for SDXL models
|
||||||
|
"model_loader": {
|
||||||
|
"type": "sdxl_model_loader" if is_sdxl else "main_model_loader",
|
||||||
|
"id": "model_loader",
|
||||||
|
"model": {
|
||||||
|
"key": model_info["key"],
|
||||||
|
"hash": model_info["hash"],
|
||||||
|
"name": model_info["name"],
|
||||||
|
"base": model_info["base"],
|
||||||
|
"type": model_info["type"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
# Positive prompt encoding - use sdxl_compel_prompt for SDXL
|
||||||
|
"positive_prompt": {
|
||||||
|
"type": "sdxl_compel_prompt" if is_sdxl else "compel",
|
||||||
|
"id": "positive_prompt",
|
||||||
|
"prompt": prompt,
|
||||||
|
**({"style": prompt} if is_sdxl else {})
|
||||||
|
},
|
||||||
|
|
||||||
|
# Negative prompt encoding - use sdxl_compel_prompt for SDXL
|
||||||
|
"negative_prompt": {
|
||||||
|
"type": "sdxl_compel_prompt" if is_sdxl else "compel",
|
||||||
|
"id": "negative_prompt",
|
||||||
|
"prompt": negative_prompt,
|
||||||
|
**({"style": ""} if is_sdxl else {})
|
||||||
|
},
|
||||||
|
|
||||||
|
# Noise generation
|
||||||
|
"noise": {
|
||||||
|
"type": "noise",
|
||||||
|
"id": "noise",
|
||||||
|
"seed": seed,
|
||||||
|
"width": width,
|
||||||
|
"height": height,
|
||||||
|
"use_cpu": False
|
||||||
|
},
|
||||||
|
|
||||||
|
# Denoise latents (main generation step)
|
||||||
|
"denoise": {
|
||||||
|
"type": "denoise_latents",
|
||||||
|
"id": "denoise",
|
||||||
|
"steps": steps,
|
||||||
|
"cfg_scale": cfg_scale,
|
||||||
|
"scheduler": scheduler,
|
||||||
|
"denoising_start": 0,
|
||||||
|
"denoising_end": 1
|
||||||
|
},
|
||||||
|
|
||||||
|
# Convert latents to image
|
||||||
|
"latents_to_image": {
|
||||||
|
"type": "l2i",
|
||||||
|
"id": "latents_to_image"
|
||||||
|
},
|
||||||
|
|
||||||
|
# Save image
|
||||||
|
"save_image": {
|
||||||
|
"type": "save_image",
|
||||||
|
"id": "save_image",
|
||||||
|
"is_intermediate": False
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add LoRA loader if requested
|
||||||
|
if lora_key is not None:
|
||||||
|
lora_info = await get_model_info(lora_key)
|
||||||
|
if not lora_info:
|
||||||
|
raise ValueError(f"LoRA model {lora_key} not found")
|
||||||
|
|
||||||
|
# Validate LoRA info has required fields
|
||||||
|
required_fields = ["key", "hash", "name", "base", "type"]
|
||||||
|
for field in required_fields:
|
||||||
|
if field not in lora_info or lora_info[field] is None:
|
||||||
|
raise ValueError(f"LoRA model {lora_key} is missing required field: {field}")
|
||||||
|
|
||||||
|
nodes["lora_loader"] = {
|
||||||
|
"type": "lora_loader",
|
||||||
|
"id": "lora_loader",
|
||||||
|
"lora": {
|
||||||
|
"key": lora_info["key"],
|
||||||
|
"hash": lora_info["hash"],
|
||||||
|
"name": lora_info["name"],
|
||||||
|
"base": lora_info["base"],
|
||||||
|
"type": lora_info["type"]
|
||||||
|
},
|
||||||
|
"weight": lora_weight
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add VAE loader if requested (to override model's built-in VAE)
|
||||||
|
if vae_key is not None:
|
||||||
|
vae_info = await get_model_info(vae_key)
|
||||||
|
if not vae_info:
|
||||||
|
raise ValueError(f"VAE model {vae_key} not found")
|
||||||
|
|
||||||
|
# Validate VAE info has required fields
|
||||||
|
required_fields = ["key", "hash", "name", "base", "type"]
|
||||||
|
for field in required_fields:
|
||||||
|
if field not in vae_info or vae_info[field] is None:
|
||||||
|
raise ValueError(f"VAE model {vae_key} is missing required field: {field}")
|
||||||
|
|
||||||
|
nodes["vae_loader"] = {
|
||||||
|
"type": "vae_loader",
|
||||||
|
"id": "vae_loader",
|
||||||
|
"vae_model": {
|
||||||
|
"key": vae_info["key"],
|
||||||
|
"hash": vae_info["hash"],
|
||||||
|
"name": vae_info["name"],
|
||||||
|
"base": vae_info["base"],
|
||||||
|
"type": vae_info["type"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Build edges
|
||||||
|
edges = []
|
||||||
|
|
||||||
|
# Determine source for UNet and CLIP (model_loader or lora_loader)
|
||||||
|
unet_source = "lora_loader" if lora_key is not None else "model_loader"
|
||||||
|
clip_source = "lora_loader" if lora_key is not None else "model_loader"
|
||||||
|
# Determine source for VAE (vae_loader if specified, otherwise model_loader)
|
||||||
|
vae_source = "vae_loader" if vae_key is not None else "model_loader"
|
||||||
|
|
||||||
|
# If using LoRA, connect model_loader to lora_loader first
|
||||||
|
if lora_key is not None:
|
||||||
|
edges.extend([
|
||||||
|
{
|
||||||
|
"source": {"node_id": "model_loader", "field": "unet"},
|
||||||
|
"destination": {"node_id": "lora_loader", "field": "unet"}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"source": {"node_id": "model_loader", "field": "clip"},
|
||||||
|
"destination": {"node_id": "lora_loader", "field": "clip"}
|
||||||
|
}
|
||||||
|
])
|
||||||
|
# Note: lora_loader doesn't have a clip2 field, so for SDXL we route clip2 directly from model_loader
|
||||||
|
|
||||||
|
# Connect UNet and CLIP to downstream nodes
|
||||||
|
edges.extend([
|
||||||
|
# Connect UNet to denoise
|
||||||
|
{
|
||||||
|
"source": {"node_id": unet_source, "field": "unet"},
|
||||||
|
"destination": {"node_id": "denoise", "field": "unet"}
|
||||||
|
},
|
||||||
|
# Connect CLIP to prompts
|
||||||
|
{
|
||||||
|
"source": {"node_id": clip_source, "field": "clip"},
|
||||||
|
"destination": {"node_id": "positive_prompt", "field": "clip"}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"source": {"node_id": clip_source, "field": "clip"},
|
||||||
|
"destination": {"node_id": "negative_prompt", "field": "clip"}
|
||||||
|
},
|
||||||
|
])
|
||||||
|
|
||||||
|
# For SDXL models, also connect clip2
|
||||||
|
# Note: clip2 always comes from model_loader, even when using LoRA (lora_loader doesn't support clip2)
|
||||||
|
if is_sdxl:
|
||||||
|
edges.extend([
|
||||||
|
{
|
||||||
|
"source": {"node_id": "model_loader", "field": "clip2"},
|
||||||
|
"destination": {"node_id": "positive_prompt", "field": "clip2"}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"source": {"node_id": "model_loader", "field": "clip2"},
|
||||||
|
"destination": {"node_id": "negative_prompt", "field": "clip2"}
|
||||||
|
},
|
||||||
|
])
|
||||||
|
|
||||||
|
edges.extend([
|
||||||
|
|
||||||
|
# Connect prompts to denoise
|
||||||
|
{
|
||||||
|
"source": {"node_id": "positive_prompt", "field": "conditioning"},
|
||||||
|
"destination": {"node_id": "denoise", "field": "positive_conditioning"}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"source": {"node_id": "negative_prompt", "field": "conditioning"},
|
||||||
|
"destination": {"node_id": "denoise", "field": "negative_conditioning"}
|
||||||
|
},
|
||||||
|
|
||||||
|
# Connect noise to denoise
|
||||||
|
{
|
||||||
|
"source": {"node_id": "noise", "field": "noise"},
|
||||||
|
"destination": {"node_id": "denoise", "field": "noise"}
|
||||||
|
},
|
||||||
|
|
||||||
|
# Connect denoise to latents_to_image
|
||||||
|
{
|
||||||
|
"source": {"node_id": "denoise", "field": "latents"},
|
||||||
|
"destination": {"node_id": "latents_to_image", "field": "latents"}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"source": {"node_id": vae_source, "field": "vae"},
|
||||||
|
"destination": {"node_id": "latents_to_image", "field": "vae"}
|
||||||
|
},
|
||||||
|
|
||||||
|
# Connect latents_to_image to save_image
|
||||||
|
{
|
||||||
|
"source": {"node_id": "latents_to_image", "field": "image"},
|
||||||
|
"destination": {"node_id": "save_image", "field": "image"}
|
||||||
|
}
|
||||||
|
])
|
||||||
|
|
||||||
|
graph = {
|
||||||
|
"id": "text2img_graph",
|
||||||
|
"nodes": nodes,
|
||||||
|
"edges": edges
|
||||||
|
}
|
||||||
|
|
||||||
|
return graph
|
||||||
@@ -1,4 +1,3 @@
|
|||||||
import subprocess
|
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import requests
|
import requests
|
||||||
@@ -7,29 +6,7 @@ import base64
|
|||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
MODEL_FOLDER = os.environ["MODEL_FOLDER"]
|
API_URL = os.environ["TRELLIS_URL"]
|
||||||
API_URL = os.environ["3D_GENERATION_URL"]
|
|
||||||
|
|
||||||
|
|
||||||
def image_to_3d_subprocess(image_path, output_path):
|
|
||||||
venv_python = MODEL_FOLDER + r"\.venv\Scripts\python.exe"
|
|
||||||
script_path = MODEL_FOLDER + r"\run.py"
|
|
||||||
|
|
||||||
args = [image_path, "--output-dir", output_path]
|
|
||||||
command = [venv_python, script_path] + args
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Run the subprocess
|
|
||||||
result = subprocess.run(command, capture_output=True, text=True)
|
|
||||||
|
|
||||||
# Print output and errors
|
|
||||||
print("STDOUT:\n", result.stdout)
|
|
||||||
print("STDERR:\n", result.stderr)
|
|
||||||
print("Return Code:", result.returncode)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error occurred: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def generate_no_preview(image_base64: str):
|
def generate_no_preview(image_base64: str):
|
||||||
|
|||||||
Reference in New Issue
Block a user