# Based on: https://github.com/coinstax/invokeai-mcp-server

import asyncio
import json
import logging
import os
import random
from typing import Optional
from urllib.parse import urljoin

import httpx
import requests
from dotenv import load_dotenv

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("invokeai-mcp")

load_dotenv()

# Base URL of the running InvokeAI instance, e.g. http://127.0.0.1:9090.
# Raises KeyError at import time if INVOKEAI_URL is not set.
INVOKEAI_BASE_URL = os.environ["INVOKEAI_URL"]
DEFAULT_QUEUE_ID = "default"

# Shared HTTP client (created lazily)
http_client: Optional[httpx.AsyncClient] = None


def get_client() -> httpx.AsyncClient:
    """Get or create the shared HTTP client."""
    global http_client
    if http_client is None:
        http_client = httpx.AsyncClient(base_url=INVOKEAI_BASE_URL, timeout=120.0)
    return http_client
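

# Cleanup sketch (not part of the original script): closes the shared client,
# e.g. on shutdown. httpx.AsyncClient.aclose() is the documented way to
# release pooled connections.
async def close_client() -> None:
    global http_client
    if http_client is not None:
        await http_client.aclose()
        http_client = None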


async def text_to_image_invoke_ai(prompt, output_path):
    """Generate an image from a text prompt and save it to output_path."""
    # See available model keys via GET http://127.0.0.1:9090/api/v2/models/?model_type=main
    # (model keys are specific to each InvokeAI installation).
    args = {
        "prompt": prompt,
        "width": 512,
        "height": 512,
        "model_key": "79401292-0a6b-428d-b7d7-f1e86caeba2b"  # Juggernaut XL v9
        # "model_key": "735f6485-6703-498f-929e-07cf0bbbd179"  # Dreamshaper 8
    }
    image_url = await generate_image(args)
    logger.info("Got image URL: %s", image_url)
    download_file(image_url, output_path)


async def wait_for_completion(batch_id: str, queue_id: str = DEFAULT_QUEUE_ID, timeout: int = 300) -> dict:
    """Wait for a batch to complete and return the most recent image."""
    client = get_client()
    start_time = asyncio.get_running_loop().time()

    while True:
        # Check if we've exceeded the timeout
        if asyncio.get_running_loop().time() - start_time > timeout:
            raise TimeoutError(f"Image generation timed out after {timeout} seconds")

        # Get batch status
        response = await client.get(f"/api/v1/queue/{queue_id}/b/{batch_id}/status")
        response.raise_for_status()
        status_data = response.json()

        # Check for failures
        failed_count = status_data.get("failed", 0)
        if failed_count > 0:
            # Try to get error details from the queue
            queue_status_response = await client.get(f"/api/v1/queue/{queue_id}/status")
            queue_status_response.raise_for_status()
            queue_data = queue_status_response.json()

            raise RuntimeError(
                f"Image generation failed. Batch {batch_id} has {failed_count} failed item(s). "
                f"Queue status: {json.dumps(queue_data, indent=2)}"
            )

        # Check completion
        completed = status_data.get("completed", 0)
        total = status_data.get("total", 0)

        if completed == total and total > 0:
            # Get the most recent non-intermediate images. Note: this assumes
            # no other generation finished in the meantime; concurrent use of
            # the queue could return a different batch's image.
            images_response = await client.get("/api/v1/images/?is_intermediate=false&limit=10")
            images_response.raise_for_status()
            images_data = images_response.json()

            # Return the most recent image (first in the list), wrapped in the
            # same shape generate_image() expects from a completed batch
            if images_data.get("items"):
                return {
                    "batch_id": batch_id,
                    "status": "completed",
                    "result": {
                        "outputs": {
                            "save_image": {
                                "type": "image_output",
                                "image": {
                                    "image_name": images_data["items"][0]["image_name"]
                                }
                            }
                        }
                    }
                }

            # If no images were found, return the raw status
            return status_data

        # Wait before checking again
        await asyncio.sleep(1)


async def generate_image(arguments: dict):
    """Build a text2img graph from the given arguments, run it, and return the image URL."""
    # Extract parameters
    prompt = arguments["prompt"]
    negative_prompt = arguments.get("negative_prompt", "")
    width = arguments.get("width", 512)
    height = arguments.get("height", 512)
    steps = arguments.get("steps", 30)
    cfg_scale = arguments.get("cfg_scale", 7.5)
    scheduler = arguments.get("scheduler", "euler")
    seed = arguments.get("seed")
    model_key = arguments.get("model_key")
    lora_key = arguments.get("lora_key")
    lora_weight = arguments.get("lora_weight", 1.0)
    vae_key = arguments.get("vae_key")

    logger.info("Generating image with prompt: %s...", prompt[:50])

    # Create graph
    graph = await create_text2img_graph(
        prompt=prompt,
        negative_prompt=negative_prompt,
        model_key=model_key,
        lora_key=lora_key,
        lora_weight=lora_weight,
        vae_key=vae_key,
        width=width,
        height=height,
        steps=steps,
        cfg_scale=cfg_scale,
        scheduler=scheduler,
        seed=seed
    )

    # Enqueue and wait for completion
    result = await enqueue_graph(graph)
    batch_id = result["batch"]["batch_id"]

    logger.info("Enqueued batch %s, waiting for completion...", batch_id)

    completed = await wait_for_completion(batch_id)

    # Extract the image name from the result
    if "result" in completed and "outputs" in completed["result"]:
        outputs = completed["result"]["outputs"]
        # Find the image output
        for node_id, output in outputs.items():
            if output.get("type") == "image_output":
                image_name = output["image"]["image_name"]
                image_url = await get_image_url(image_name)
                return urljoin(INVOKEAI_BASE_URL, image_url)

    raise RuntimeError("Failed to generate image!")


def download_file(url, filepath):
    """Download url to filepath (blocking)."""
    # Timeout added so a stalled server can't hang the download indefinitely
    response = requests.get(url, timeout=60)

    if response.status_code == 200:
        with open(filepath, "wb") as file:
            file.write(response.content)
    else:
        raise RuntimeError(f"Failed to download image. Status code: {response.status_code}")
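

# Optional async variant (a sketch, not part of the original script): reuses
# the shared httpx client instead of a blocking requests call, so it won't
# stall the event loop when invoked from async code. httpx uses an absolute
# URL as-is, ignoring the client's base_url, so full image URLs work here.
async def download_file_async(url: str, filepath: str) -> None:
    client = get_client()
    response = await client.get(url)
    response.raise_for_status()
    with open(filepath, "wb") as file:
        file.write(response.content)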


async def get_image_url(image_name: str) -> str:
    """Get the URL for an image."""
    client = get_client()
    response = await client.get(f"/api/v1/images/i/{image_name}/urls")
    response.raise_for_status()
    data = response.json()
    return data.get("image_url", "")


async def enqueue_graph(graph: dict, queue_id: str = DEFAULT_QUEUE_ID) -> dict:
    """Enqueue a graph for processing."""
    client = get_client()

    batch = {
        "batch": {
            "graph": graph,
            "runs": 1,
            "data": None
        }
    }

    response = await client.post(
        f"/api/v1/queue/{queue_id}/enqueue_batch",
        json=batch
    )
    response.raise_for_status()
    return response.json()


async def list_models(model_type: str = "main") -> list:
    """List available models."""
    client = get_client()
    response = await client.get("/api/v2/models/", params={"model_type": model_type})
    response.raise_for_status()
    data = response.json()
    return data.get("models", [])
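

# Convenience sketch (not part of the original script): prints the available
# main-model keys so you can fill in "model_key" in text_to_image_invoke_ai().
# The field names ("key", "name", "base") match the /api/v2/models/ response
# shape already relied on in create_text2img_graph().
async def print_model_keys(model_type: str = "main") -> None:
    for model in await list_models(model_type):
        print(f'{model.get("key")}  {model.get("name")} ({model.get("base")})')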


async def get_model_info(model_key: str) -> Optional[dict]:
    """Get information about a specific model."""
    client = get_client()
    try:
        response = await client.get(f"/api/v2/models/i/{model_key}")
        response.raise_for_status()
        model_data = response.json()

        # Ensure we have a valid dictionary
        if not isinstance(model_data, dict):
            logger.error(f"Model info for {model_key} is not a dictionary: {type(model_data)}")
            return None

        return model_data
    except Exception as e:
        logger.error(f"Error fetching model info for {model_key}: {e}")
        return None


async def create_text2img_graph(
    prompt: str,
    negative_prompt: str = "",
    model_key: Optional[str] = None,
    lora_key: Optional[str] = None,
    lora_weight: float = 1.0,
    vae_key: Optional[str] = None,
    width: int = 512,
    height: int = 512,
    steps: int = 30,
    cfg_scale: float = 7.5,
    scheduler: str = "euler",
    seed: Optional[int] = None
) -> dict:
    """Create a text-to-image generation graph with optional LoRA and VAE support."""

    # If no model was specified, fall back to the first available sd-1 model
    if model_key is None:
        models = await list_models("main")
        for model in models:
            if model.get("base") == "sd-1":
                model_key = model["key"]
                break
        if model_key is None:
            raise ValueError("No suitable model found")

    # Get model information
    model_info = await get_model_info(model_key)
    if not model_info:
        raise ValueError(f"Model {model_key} not found")

    # Validate that the model info has the required fields
    if not isinstance(model_info, dict):
        raise ValueError(f"Model {model_key} returned invalid data type: {type(model_info)}")

    required_fields = ["key", "hash", "name", "base", "type"]
    for field in required_fields:
        if field not in model_info or model_info[field] is None:
            raise ValueError(f"Model {model_key} is missing required field: {field}")

    # Generate a random seed if none was provided
    if seed is None:
        seed = random.randint(0, 2**32 - 1)

    # Detect whether this is an SDXL model
    is_sdxl = model_info["base"] == "sdxl"

    # Build nodes dictionary
    nodes = {
        # Main model loader - use sdxl_model_loader for SDXL models
        "model_loader": {
            "type": "sdxl_model_loader" if is_sdxl else "main_model_loader",
            "id": "model_loader",
            "model": {
                "key": model_info["key"],
                "hash": model_info["hash"],
                "name": model_info["name"],
                "base": model_info["base"],
                "type": model_info["type"]
            }
        },

        # Positive prompt encoding - use sdxl_compel_prompt for SDXL
        "positive_prompt": {
            "type": "sdxl_compel_prompt" if is_sdxl else "compel",
            "id": "positive_prompt",
            "prompt": prompt,
            **({"style": prompt} if is_sdxl else {})
        },

        # Negative prompt encoding - use sdxl_compel_prompt for SDXL
        "negative_prompt": {
            "type": "sdxl_compel_prompt" if is_sdxl else "compel",
            "id": "negative_prompt",
            "prompt": negative_prompt,
            **({"style": ""} if is_sdxl else {})
        },

        # Noise generation
        "noise": {
            "type": "noise",
            "id": "noise",
            "seed": seed,
            "width": width,
            "height": height,
            "use_cpu": False
        },

        # Denoise latents (main generation step)
        "denoise": {
            "type": "denoise_latents",
            "id": "denoise",
            "steps": steps,
            "cfg_scale": cfg_scale,
            "scheduler": scheduler,
            "denoising_start": 0,
            "denoising_end": 1
        },

        # Convert latents to image
        "latents_to_image": {
            "type": "l2i",
            "id": "latents_to_image"
        },

        # Save image
        "save_image": {
            "type": "save_image",
            "id": "save_image",
            "is_intermediate": False
        }
    }

    # Add a LoRA loader if requested
    if lora_key is not None:
        lora_info = await get_model_info(lora_key)
        if not lora_info:
            raise ValueError(f"LoRA model {lora_key} not found")

        # Validate that the LoRA info has the required fields
        required_fields = ["key", "hash", "name", "base", "type"]
        for field in required_fields:
            if field not in lora_info or lora_info[field] is None:
                raise ValueError(f"LoRA model {lora_key} is missing required field: {field}")

        nodes["lora_loader"] = {
            "type": "lora_loader",
            "id": "lora_loader",
            "lora": {
                "key": lora_info["key"],
                "hash": lora_info["hash"],
                "name": lora_info["name"],
                "base": lora_info["base"],
                "type": lora_info["type"]
            },
            "weight": lora_weight
        }

    # Add a VAE loader if requested (overrides the model's built-in VAE)
    if vae_key is not None:
        vae_info = await get_model_info(vae_key)
        if not vae_info:
            raise ValueError(f"VAE model {vae_key} not found")

        # Validate that the VAE info has the required fields
        required_fields = ["key", "hash", "name", "base", "type"]
        for field in required_fields:
            if field not in vae_info or vae_info[field] is None:
                raise ValueError(f"VAE model {vae_key} is missing required field: {field}")

        nodes["vae_loader"] = {
            "type": "vae_loader",
            "id": "vae_loader",
            "vae_model": {
                "key": vae_info["key"],
                "hash": vae_info["hash"],
                "name": vae_info["name"],
                "base": vae_info["base"],
                "type": vae_info["type"]
            }
        }

    # Build edges
    edges = []

    # Determine the source for UNet and CLIP (model_loader or lora_loader)
    unet_source = "lora_loader" if lora_key is not None else "model_loader"
    clip_source = "lora_loader" if lora_key is not None else "model_loader"
    # Determine the source for VAE (vae_loader if specified, otherwise model_loader)
    vae_source = "vae_loader" if vae_key is not None else "model_loader"

    # If using a LoRA, connect model_loader to lora_loader first.
    # Note: lora_loader doesn't have a clip2 field, so for SDXL we route clip2
    # directly from model_loader (see below).
    if lora_key is not None:
        edges.extend([
            {
                "source": {"node_id": "model_loader", "field": "unet"},
                "destination": {"node_id": "lora_loader", "field": "unet"}
            },
            {
                "source": {"node_id": "model_loader", "field": "clip"},
                "destination": {"node_id": "lora_loader", "field": "clip"}
            }
        ])

    # Connect UNet and CLIP to downstream nodes
    edges.extend([
        # Connect UNet to denoise
        {
            "source": {"node_id": unet_source, "field": "unet"},
            "destination": {"node_id": "denoise", "field": "unet"}
        },
        # Connect CLIP to the prompts
        {
            "source": {"node_id": clip_source, "field": "clip"},
            "destination": {"node_id": "positive_prompt", "field": "clip"}
        },
        {
            "source": {"node_id": clip_source, "field": "clip"},
            "destination": {"node_id": "negative_prompt", "field": "clip"}
        },
    ])

    # For SDXL models, also connect clip2.
    # Note: clip2 always comes from model_loader, even when using a LoRA
    # (lora_loader doesn't support clip2).
    if is_sdxl:
        edges.extend([
            {
                "source": {"node_id": "model_loader", "field": "clip2"},
                "destination": {"node_id": "positive_prompt", "field": "clip2"}
            },
            {
                "source": {"node_id": "model_loader", "field": "clip2"},
                "destination": {"node_id": "negative_prompt", "field": "clip2"}
            },
        ])

    edges.extend([
        # Connect prompts to denoise
        {
            "source": {"node_id": "positive_prompt", "field": "conditioning"},
            "destination": {"node_id": "denoise", "field": "positive_conditioning"}
        },
        {
            "source": {"node_id": "negative_prompt", "field": "conditioning"},
            "destination": {"node_id": "denoise", "field": "negative_conditioning"}
        },

        # Connect noise to denoise
        {
            "source": {"node_id": "noise", "field": "noise"},
            "destination": {"node_id": "denoise", "field": "noise"}
        },

        # Connect denoise to latents_to_image, plus the VAE used for decoding
        {
            "source": {"node_id": "denoise", "field": "latents"},
            "destination": {"node_id": "latents_to_image", "field": "latents"}
        },
        {
            "source": {"node_id": vae_source, "field": "vae"},
            "destination": {"node_id": "latents_to_image", "field": "vae"}
        },

        # Connect latents_to_image to save_image
        {
            "source": {"node_id": "latents_to_image", "field": "image"},
            "destination": {"node_id": "save_image", "field": "image"}
        }
    ])

    graph = {
        "id": "text2img_graph",
        "nodes": nodes,
        "edges": edges
    }

    return graph
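

# Example entry point (a sketch, not part of the original script; the prompt
# and output path are placeholders). Assumes INVOKEAI_URL points at a running
# InvokeAI instance and that the model key hard-coded in
# text_to_image_invoke_ai() exists on that installation.
if __name__ == "__main__":
    asyncio.run(text_to_image_invoke_ai("a watercolor fox in the snow", "output.png"))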