forked from cgvr/DeltaVR
refactor image generation script + env vars
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
3D_GENERATION_URL=
|
||||
MODEL_FOLDER=
|
||||
INVOKEAI_URL=
|
||||
TRELLIS_URL=
|
||||
|
||||
CLOUDFLARE_ACCOUNT_ID=
|
||||
CLOUDFLARE_API_TOKEN=
|
||||
@@ -1,11 +1,111 @@
|
||||
# Based on: https://github.com/coinstax/invokeai-mcp-server
|
||||
|
||||
import requests
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import httpx
|
||||
import os
|
||||
|
||||
from invokeai_mcp_server import create_text2img_graph, enqueue_graph, wait_for_completion, get_image_url
|
||||
from typing import Optional
|
||||
from urllib.parse import urljoin
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("invokeai-mcp")
|
||||
|
||||
load_dotenv()
|
||||
|
||||
INVOKEAI_BASE_URL = os.environ["INVOKEAI_URL"]
|
||||
DEFAULT_QUEUE_ID = "default"
|
||||
|
||||
# HTTP client
|
||||
http_client: Optional[httpx.AsyncClient] = None
|
||||
|
||||
|
||||
INVOKEAI_BASE_URL = "http://127.0.0.1:9090"
|
||||
|
||||
def get_client() -> httpx.AsyncClient:
|
||||
"""Get or create HTTP client."""
|
||||
global http_client
|
||||
if http_client is None:
|
||||
http_client = httpx.AsyncClient(base_url=INVOKEAI_BASE_URL, timeout=120.0)
|
||||
return http_client
|
||||
|
||||
|
||||
async def text_to_image_invoke_ai(prompt, output_path):
|
||||
# see available model keys via GET http://127.0.0.1:9090/api/v2/models/?model_type=main
|
||||
args = {
|
||||
"prompt": prompt,
|
||||
"width": 512,
|
||||
"height": 512,
|
||||
"model_key": "79401292-0a6b-428d-b7d7-f1e86caeba2b" # Juggernaut XL v9
|
||||
#"model_key": "735f6485-6703-498f-929e-07cf0bbbd179" # Dreamshaper 8
|
||||
}
|
||||
image_url = await generate_image(args)
|
||||
print("got image url: ", image_url)
|
||||
download_file(image_url, output_path)
|
||||
|
||||
async def wait_for_completion(batch_id: str, queue_id: str = DEFAULT_QUEUE_ID, timeout: int = 300) -> dict:
|
||||
"""Wait for a batch to complete and return the most recent image."""
|
||||
client = get_client()
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
|
||||
while True:
|
||||
# Check if we've exceeded timeout
|
||||
if asyncio.get_event_loop().time() - start_time > timeout:
|
||||
raise TimeoutError(f"Image generation timed out after {timeout} seconds")
|
||||
|
||||
# Get batch status
|
||||
response = await client.get(f"/api/v1/queue/{queue_id}/b/{batch_id}/status")
|
||||
response.raise_for_status()
|
||||
status_data = response.json()
|
||||
|
||||
# Check for failures
|
||||
failed_count = status_data.get("failed", 0)
|
||||
if failed_count > 0:
|
||||
# Try to get error details from the queue
|
||||
queue_status_response = await client.get(f"/api/v1/queue/{queue_id}/status")
|
||||
queue_status_response.raise_for_status()
|
||||
queue_data = queue_status_response.json()
|
||||
|
||||
raise RuntimeError(
|
||||
f"Image generation failed. Batch {batch_id} has {failed_count} failed item(s). "
|
||||
f"Queue status: {json.dumps(queue_data, indent=2)}"
|
||||
)
|
||||
|
||||
# Check completion
|
||||
completed = status_data.get("completed", 0)
|
||||
total = status_data.get("total", 0)
|
||||
|
||||
if completed == total and total > 0:
|
||||
# Get most recent non-intermediate image
|
||||
images_response = await client.get("/api/v1/images/?is_intermediate=false&limit=10")
|
||||
images_response.raise_for_status()
|
||||
images_data = images_response.json()
|
||||
|
||||
# Return the most recent image (first in the list)
|
||||
if images_data.get("items"):
|
||||
return {
|
||||
"batch_id": batch_id,
|
||||
"status": "completed",
|
||||
"result": {
|
||||
"outputs": {
|
||||
"save_image": {
|
||||
"type": "image_output",
|
||||
"image": {
|
||||
"image_name": images_data["items"][0]["image_name"]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# If no images found, return status
|
||||
return status_data
|
||||
|
||||
# Wait before checking again
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def generate_image(arguments: dict):
|
||||
|
||||
@@ -71,16 +171,319 @@ def download_file(url, filepath):
|
||||
else:
|
||||
raise RuntimeError(f"Failed to download image. Status code: {response.status_code}")
|
||||
|
||||
async def get_image_url(image_name: str) -> str:
|
||||
"""Get the URL for an image."""
|
||||
client = get_client()
|
||||
response = await client.get(f"/api/v1/images/i/{image_name}/urls")
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("image_url", "")
|
||||
|
||||
async def text_to_image_invoke_ai(prompt, output_path):
|
||||
# see available model keys via GET http://127.0.0.1:9090/api/v2/models/?model_type=main
|
||||
args = {
|
||||
"prompt": prompt,
|
||||
"width": 512,
|
||||
"height": 512,
|
||||
"model_key": "79401292-0a6b-428d-b7d7-f1e86caeba2b" # Juggernaut XL v9
|
||||
#"model_key": "735f6485-6703-498f-929e-07cf0bbbd179" # Dreamshaper 8
|
||||
|
||||
async def enqueue_graph(graph: dict, queue_id: str = DEFAULT_QUEUE_ID) -> dict:
|
||||
"""Enqueue a graph for processing."""
|
||||
client = get_client()
|
||||
|
||||
batch = {
|
||||
"batch": {
|
||||
"graph": graph,
|
||||
"runs": 1,
|
||||
"data": None
|
||||
}
|
||||
}
|
||||
image_url = await generate_image(args)
|
||||
print("got image url: ", image_url)
|
||||
download_file(image_url, output_path)
|
||||
|
||||
response = await client.post(
|
||||
f"/api/v1/queue/{queue_id}/enqueue_batch",
|
||||
json=batch
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def list_models(model_type: str = "main") -> list:
|
||||
"""List available models."""
|
||||
client = get_client()
|
||||
response = await client.get("/api/v2/models/", params={"model_type": model_type})
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("models", [])
|
||||
|
||||
async def get_model_info(model_key: str) -> Optional[dict]:
|
||||
"""Get information about a specific model."""
|
||||
client = get_client()
|
||||
try:
|
||||
response = await client.get(f"/api/v2/models/i/{model_key}")
|
||||
response.raise_for_status()
|
||||
model_data = response.json()
|
||||
|
||||
# Ensure we have a valid dictionary
|
||||
if not isinstance(model_data, dict):
|
||||
logger.error(f"Model info for {model_key} is not a dictionary: {type(model_data)}")
|
||||
return None
|
||||
|
||||
return model_data
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching model info for {model_key}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def create_text2img_graph(
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
model_key: Optional[str] = None,
|
||||
lora_key: Optional[str] = None,
|
||||
lora_weight: float = 1.0,
|
||||
vae_key: Optional[str] = None,
|
||||
width: int = 512,
|
||||
height: int = 512,
|
||||
steps: int = 30,
|
||||
cfg_scale: float = 7.5,
|
||||
scheduler: str = "euler",
|
||||
seed: Optional[int] = None
|
||||
) -> dict:
|
||||
"""Create a text-to-image generation graph with optional LoRA and VAE support."""
|
||||
|
||||
# Use default model if not specified
|
||||
if model_key is None:
|
||||
# Try to find an sd-1 model
|
||||
models = await list_models("main")
|
||||
for model in models:
|
||||
if model.get("base") == "sd-1":
|
||||
model_key = model["key"]
|
||||
break
|
||||
if model_key is None:
|
||||
raise ValueError("No suitable model found")
|
||||
|
||||
# Get model information
|
||||
model_info = await get_model_info(model_key)
|
||||
if not model_info:
|
||||
raise ValueError(f"Model {model_key} not found")
|
||||
|
||||
# Validate model info has required fields
|
||||
if not isinstance(model_info, dict):
|
||||
raise ValueError(f"Model {model_key} returned invalid data type: {type(model_info)}")
|
||||
|
||||
required_fields = ["key", "hash", "name", "base", "type"]
|
||||
for field in required_fields:
|
||||
if field not in model_info or model_info[field] is None:
|
||||
raise ValueError(f"Model {model_key} is missing required field: {field}")
|
||||
|
||||
# Generate random seed if not provided
|
||||
if seed is None:
|
||||
import random
|
||||
seed = random.randint(0, 2**32 - 1)
|
||||
|
||||
# Detect if this is an SDXL model
|
||||
is_sdxl = model_info["base"] == "sdxl"
|
||||
|
||||
# Build nodes dictionary
|
||||
nodes = {
|
||||
# Main model loader - use sdxl_model_loader for SDXL models
|
||||
"model_loader": {
|
||||
"type": "sdxl_model_loader" if is_sdxl else "main_model_loader",
|
||||
"id": "model_loader",
|
||||
"model": {
|
||||
"key": model_info["key"],
|
||||
"hash": model_info["hash"],
|
||||
"name": model_info["name"],
|
||||
"base": model_info["base"],
|
||||
"type": model_info["type"]
|
||||
}
|
||||
},
|
||||
|
||||
# Positive prompt encoding - use sdxl_compel_prompt for SDXL
|
||||
"positive_prompt": {
|
||||
"type": "sdxl_compel_prompt" if is_sdxl else "compel",
|
||||
"id": "positive_prompt",
|
||||
"prompt": prompt,
|
||||
**({"style": prompt} if is_sdxl else {})
|
||||
},
|
||||
|
||||
# Negative prompt encoding - use sdxl_compel_prompt for SDXL
|
||||
"negative_prompt": {
|
||||
"type": "sdxl_compel_prompt" if is_sdxl else "compel",
|
||||
"id": "negative_prompt",
|
||||
"prompt": negative_prompt,
|
||||
**({"style": ""} if is_sdxl else {})
|
||||
},
|
||||
|
||||
# Noise generation
|
||||
"noise": {
|
||||
"type": "noise",
|
||||
"id": "noise",
|
||||
"seed": seed,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"use_cpu": False
|
||||
},
|
||||
|
||||
# Denoise latents (main generation step)
|
||||
"denoise": {
|
||||
"type": "denoise_latents",
|
||||
"id": "denoise",
|
||||
"steps": steps,
|
||||
"cfg_scale": cfg_scale,
|
||||
"scheduler": scheduler,
|
||||
"denoising_start": 0,
|
||||
"denoising_end": 1
|
||||
},
|
||||
|
||||
# Convert latents to image
|
||||
"latents_to_image": {
|
||||
"type": "l2i",
|
||||
"id": "latents_to_image"
|
||||
},
|
||||
|
||||
# Save image
|
||||
"save_image": {
|
||||
"type": "save_image",
|
||||
"id": "save_image",
|
||||
"is_intermediate": False
|
||||
}
|
||||
}
|
||||
|
||||
# Add LoRA loader if requested
|
||||
if lora_key is not None:
|
||||
lora_info = await get_model_info(lora_key)
|
||||
if not lora_info:
|
||||
raise ValueError(f"LoRA model {lora_key} not found")
|
||||
|
||||
# Validate LoRA info has required fields
|
||||
required_fields = ["key", "hash", "name", "base", "type"]
|
||||
for field in required_fields:
|
||||
if field not in lora_info or lora_info[field] is None:
|
||||
raise ValueError(f"LoRA model {lora_key} is missing required field: {field}")
|
||||
|
||||
nodes["lora_loader"] = {
|
||||
"type": "lora_loader",
|
||||
"id": "lora_loader",
|
||||
"lora": {
|
||||
"key": lora_info["key"],
|
||||
"hash": lora_info["hash"],
|
||||
"name": lora_info["name"],
|
||||
"base": lora_info["base"],
|
||||
"type": lora_info["type"]
|
||||
},
|
||||
"weight": lora_weight
|
||||
}
|
||||
|
||||
# Add VAE loader if requested (to override model's built-in VAE)
|
||||
if vae_key is not None:
|
||||
vae_info = await get_model_info(vae_key)
|
||||
if not vae_info:
|
||||
raise ValueError(f"VAE model {vae_key} not found")
|
||||
|
||||
# Validate VAE info has required fields
|
||||
required_fields = ["key", "hash", "name", "base", "type"]
|
||||
for field in required_fields:
|
||||
if field not in vae_info or vae_info[field] is None:
|
||||
raise ValueError(f"VAE model {vae_key} is missing required field: {field}")
|
||||
|
||||
nodes["vae_loader"] = {
|
||||
"type": "vae_loader",
|
||||
"id": "vae_loader",
|
||||
"vae_model": {
|
||||
"key": vae_info["key"],
|
||||
"hash": vae_info["hash"],
|
||||
"name": vae_info["name"],
|
||||
"base": vae_info["base"],
|
||||
"type": vae_info["type"]
|
||||
}
|
||||
}
|
||||
|
||||
# Build edges
|
||||
edges = []
|
||||
|
||||
# Determine source for UNet and CLIP (model_loader or lora_loader)
|
||||
unet_source = "lora_loader" if lora_key is not None else "model_loader"
|
||||
clip_source = "lora_loader" if lora_key is not None else "model_loader"
|
||||
# Determine source for VAE (vae_loader if specified, otherwise model_loader)
|
||||
vae_source = "vae_loader" if vae_key is not None else "model_loader"
|
||||
|
||||
# If using LoRA, connect model_loader to lora_loader first
|
||||
if lora_key is not None:
|
||||
edges.extend([
|
||||
{
|
||||
"source": {"node_id": "model_loader", "field": "unet"},
|
||||
"destination": {"node_id": "lora_loader", "field": "unet"}
|
||||
},
|
||||
{
|
||||
"source": {"node_id": "model_loader", "field": "clip"},
|
||||
"destination": {"node_id": "lora_loader", "field": "clip"}
|
||||
}
|
||||
])
|
||||
# Note: lora_loader doesn't have a clip2 field, so for SDXL we route clip2 directly from model_loader
|
||||
|
||||
# Connect UNet and CLIP to downstream nodes
|
||||
edges.extend([
|
||||
# Connect UNet to denoise
|
||||
{
|
||||
"source": {"node_id": unet_source, "field": "unet"},
|
||||
"destination": {"node_id": "denoise", "field": "unet"}
|
||||
},
|
||||
# Connect CLIP to prompts
|
||||
{
|
||||
"source": {"node_id": clip_source, "field": "clip"},
|
||||
"destination": {"node_id": "positive_prompt", "field": "clip"}
|
||||
},
|
||||
{
|
||||
"source": {"node_id": clip_source, "field": "clip"},
|
||||
"destination": {"node_id": "negative_prompt", "field": "clip"}
|
||||
},
|
||||
])
|
||||
|
||||
# For SDXL models, also connect clip2
|
||||
# Note: clip2 always comes from model_loader, even when using LoRA (lora_loader doesn't support clip2)
|
||||
if is_sdxl:
|
||||
edges.extend([
|
||||
{
|
||||
"source": {"node_id": "model_loader", "field": "clip2"},
|
||||
"destination": {"node_id": "positive_prompt", "field": "clip2"}
|
||||
},
|
||||
{
|
||||
"source": {"node_id": "model_loader", "field": "clip2"},
|
||||
"destination": {"node_id": "negative_prompt", "field": "clip2"}
|
||||
},
|
||||
])
|
||||
|
||||
edges.extend([
|
||||
|
||||
# Connect prompts to denoise
|
||||
{
|
||||
"source": {"node_id": "positive_prompt", "field": "conditioning"},
|
||||
"destination": {"node_id": "denoise", "field": "positive_conditioning"}
|
||||
},
|
||||
{
|
||||
"source": {"node_id": "negative_prompt", "field": "conditioning"},
|
||||
"destination": {"node_id": "denoise", "field": "negative_conditioning"}
|
||||
},
|
||||
|
||||
# Connect noise to denoise
|
||||
{
|
||||
"source": {"node_id": "noise", "field": "noise"},
|
||||
"destination": {"node_id": "denoise", "field": "noise"}
|
||||
},
|
||||
|
||||
# Connect denoise to latents_to_image
|
||||
{
|
||||
"source": {"node_id": "denoise", "field": "latents"},
|
||||
"destination": {"node_id": "latents_to_image", "field": "latents"}
|
||||
},
|
||||
{
|
||||
"source": {"node_id": vae_source, "field": "vae"},
|
||||
"destination": {"node_id": "latents_to_image", "field": "vae"}
|
||||
},
|
||||
|
||||
# Connect latents_to_image to save_image
|
||||
{
|
||||
"source": {"node_id": "latents_to_image", "field": "image"},
|
||||
"destination": {"node_id": "save_image", "field": "image"}
|
||||
}
|
||||
])
|
||||
|
||||
graph = {
|
||||
"id": "text2img_graph",
|
||||
"nodes": nodes,
|
||||
"edges": edges
|
||||
}
|
||||
|
||||
return graph
|
||||
@@ -1,4 +1,3 @@
|
||||
import subprocess
|
||||
import os
|
||||
import time
|
||||
import requests
|
||||
@@ -7,29 +6,7 @@ import base64
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
MODEL_FOLDER = os.environ["MODEL_FOLDER"]
|
||||
API_URL = os.environ["3D_GENERATION_URL"]
|
||||
|
||||
|
||||
def image_to_3d_subprocess(image_path, output_path):
|
||||
venv_python = MODEL_FOLDER + r"\.venv\Scripts\python.exe"
|
||||
script_path = MODEL_FOLDER + r"\run.py"
|
||||
|
||||
args = [image_path, "--output-dir", output_path]
|
||||
command = [venv_python, script_path] + args
|
||||
|
||||
try:
|
||||
# Run the subprocess
|
||||
result = subprocess.run(command, capture_output=True, text=True)
|
||||
|
||||
# Print output and errors
|
||||
print("STDOUT:\n", result.stdout)
|
||||
print("STDERR:\n", result.stderr)
|
||||
print("Return Code:", result.returncode)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error occurred: {e}")
|
||||
|
||||
API_URL = os.environ["TRELLIS_URL"]
|
||||
|
||||
|
||||
def generate_no_preview(image_base64: str):
|
||||
|
||||
Reference in New Issue
Block a user