diff --git a/3d-generation-pipeline/.env.example b/3d-generation-pipeline/.env.example
index 976f48cb..ca034486 100644
--- a/3d-generation-pipeline/.env.example
+++ b/3d-generation-pipeline/.env.example
@@ -1,5 +1,5 @@
-3D_GENERATION_URL=
-MODEL_FOLDER=
+INVOKEAI_URL=
+TRELLIS_URL=
 CLOUDFLARE_ACCOUNT_ID=
 CLOUDFLARE_API_TOKEN=
\ No newline at end of file
diff --git a/3d-generation-pipeline/generate_image_local.py b/3d-generation-pipeline/generate_image_local.py
index f8415f25..2a7c4dd8 100644
--- a/3d-generation-pipeline/generate_image_local.py
+++ b/3d-generation-pipeline/generate_image_local.py
@@ -1,11 +1,111 @@
+# Based on: https://github.com/coinstax/invokeai-mcp-server
+
 import requests
+import asyncio
+import json
+import logging
+import httpx
+import os
-from invokeai_mcp_server import create_text2img_graph, enqueue_graph, wait_for_completion, get_image_url
+from typing import Optional
 from urllib.parse import urljoin
+from dotenv import load_dotenv
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger("invokeai-mcp")
+
+load_dotenv()
+
+INVOKEAI_BASE_URL = os.environ["INVOKEAI_URL"]
+DEFAULT_QUEUE_ID = "default"
+
+# HTTP client
+http_client: Optional[httpx.AsyncClient] = None
 
-INVOKEAI_BASE_URL = "http://127.0.0.1:9090"
+
+def get_client() -> httpx.AsyncClient:
+    """Get or create HTTP client."""
+    global http_client
+    if http_client is None:
+        http_client = httpx.AsyncClient(base_url=INVOKEAI_BASE_URL, timeout=120.0)
+    return http_client
+
+
+async def text_to_image_invoke_ai(prompt, output_path):
+    # see available model keys via GET http://127.0.0.1:9090/api/v2/models/?model_type=main
+    args = {
+        "prompt": prompt,
+        "width": 512,
+        "height": 512,
+        "model_key": "79401292-0a6b-428d-b7d7-f1e86caeba2b" # Juggernaut XL v9
+        #"model_key": "735f6485-6703-498f-929e-07cf0bbbd179" # Dreamshaper 8
+    }
+    image_url = await generate_image(args)
+    print("got image url: ", image_url)
+    download_file(image_url, output_path)
+
+async def wait_for_completion(batch_id: str, queue_id: str = DEFAULT_QUEUE_ID, timeout: int = 300) -> dict:
+    """Wait for a batch to complete and return the most recent image."""
+    client = get_client()
+    start_time = asyncio.get_event_loop().time()
+
+    while True:
+        # Check if we've exceeded timeout
+        if asyncio.get_event_loop().time() - start_time > timeout:
+            raise TimeoutError(f"Image generation timed out after {timeout} seconds")
+
+        # Get batch status
+        response = await client.get(f"/api/v1/queue/{queue_id}/b/{batch_id}/status")
+        response.raise_for_status()
+        status_data = response.json()
+
+        # Check for failures
+        failed_count = status_data.get("failed", 0)
+        if failed_count > 0:
+            # Try to get error details from the queue
+            queue_status_response = await client.get(f"/api/v1/queue/{queue_id}/status")
+            queue_status_response.raise_for_status()
+            queue_data = queue_status_response.json()
+
+            raise RuntimeError(
+                f"Image generation failed. Batch {batch_id} has {failed_count} failed item(s). "
+                f"Queue status: {json.dumps(queue_data, indent=2)}"
+            )
+
+        # Check completion
+        completed = status_data.get("completed", 0)
+        total = status_data.get("total", 0)
+
+        if completed == total and total > 0:
+            # Get most recent non-intermediate image
+            images_response = await client.get("/api/v1/images/?is_intermediate=false&limit=10")
+            images_response.raise_for_status()
+            images_data = images_response.json()
+
+            # Return the most recent image (first in the list)
+            if images_data.get("items"):
+                return {
+                    "batch_id": batch_id,
+                    "status": "completed",
+                    "result": {
+                        "outputs": {
+                            "save_image": {
+                                "type": "image_output",
+                                "image": {
+                                    "image_name": images_data["items"][0]["image_name"]
+                                }
+                            }
+                        }
+                    }
+                }
+
+            # If no images found, return status
+            return status_data
+
+        # Wait before checking again
+        await asyncio.sleep(1)
 
 
 async def generate_image(arguments: dict):
@@ -70,17 +170,320 @@ def download_file(url, filepath):
             file.write(response.content)
     else:
         raise RuntimeError(f"Failed to download image. Status code: {response.status_code}")
+
+async def get_image_url(image_name: str) -> str:
+    """Get the URL for an image."""
+    client = get_client()
+    response = await client.get(f"/api/v1/images/i/{image_name}/urls")
+    response.raise_for_status()
+    data = response.json()
+    return data.get("image_url", "")
+
+async def enqueue_graph(graph: dict, queue_id: str = DEFAULT_QUEUE_ID) -> dict:
+    """Enqueue a graph for processing."""
+    client = get_client()
 
-async def text_to_image_invoke_ai(prompt, output_path):
-    # see available model keys via GET http://127.0.0.1:9090/api/v2/models/?model_type=main
-    args = {
-        "prompt": prompt,
-        "width": 512,
-        "height": 512,
-        "model_key": "79401292-0a6b-428d-b7d7-f1e86caeba2b" # Juggernaut XL v9
-        #"model_key": "735f6485-6703-498f-929e-07cf0bbbd179" # Dreamshaper 8
+    batch = {
+        "batch": {
+            "graph": graph,
+            "runs": 1,
+            "data": None
+        }
     }
-    image_url = await generate_image(args)
-    print("got image url: ", image_url)
-    download_file(image_url, output_path)
+
+    response = await client.post(
+        f"/api/v1/queue/{queue_id}/enqueue_batch",
+        json=batch
+    )
+    response.raise_for_status()
+    return response.json()
+
+async def list_models(model_type: str = "main") -> list:
+    """List available models."""
+    client = get_client()
+    response = await client.get("/api/v2/models/", params={"model_type": model_type})
+    response.raise_for_status()
+    data = response.json()
+    return data.get("models", [])
+
+async def get_model_info(model_key: str) -> Optional[dict]:
+    """Get information about a specific model."""
+    client = get_client()
+    try:
+        response = await client.get(f"/api/v2/models/i/{model_key}")
+        response.raise_for_status()
+        model_data = response.json()
+
+        # Ensure we have a valid dictionary
+        if not isinstance(model_data, dict):
+            logger.error(f"Model info for {model_key} is not a dictionary: {type(model_data)}")
+            return None
+
+        return model_data
+    except Exception as e:
+        logger.error(f"Error fetching model info for {model_key}: {e}")
+        return None
+
+
+async def create_text2img_graph(
+    prompt: str,
+    negative_prompt: str = "",
+    model_key: Optional[str] = None,
+    lora_key: Optional[str] = None,
+    lora_weight: float = 1.0,
+    vae_key: Optional[str] = None,
+    width: int = 512,
+    height: int = 512,
+    steps: int = 30,
+    cfg_scale: float = 7.5,
+    scheduler: str = "euler",
+    seed: Optional[int] = None
+) -> dict:
+    """Create a text-to-image generation graph with optional LoRA and VAE support."""
+
+    # Use default model if not specified
+    if model_key is None:
+        # Try to find an sd-1 model
+        models = await list_models("main")
+        for model in models:
+            if model.get("base") == "sd-1":
+                model_key = model["key"]
+                break
+        if model_key is None:
+            raise ValueError("No suitable model found")
+
+    # Get model information
+    model_info = await get_model_info(model_key)
+    if not model_info:
+        raise ValueError(f"Model {model_key} not found")
+
+    # Validate model info has required fields
+    if not isinstance(model_info, dict):
+        raise ValueError(f"Model {model_key} returned invalid data type: {type(model_info)}")
+
+    required_fields = ["key", "hash", "name", "base", "type"]
+    for field in required_fields:
+        if field not in model_info or model_info[field] is None:
+            raise ValueError(f"Model {model_key} is missing required field: {field}")
+
+    # Generate random seed if not provided
+    if seed is None:
+        import random
+        seed = random.randint(0, 2**32 - 1)
+
+    # Detect if this is an SDXL model
+    is_sdxl = model_info["base"] == "sdxl"
+
+    # Build nodes dictionary
+    nodes = {
+        # Main model loader - use sdxl_model_loader for SDXL models
+        "model_loader": {
+            "type": "sdxl_model_loader" if is_sdxl else "main_model_loader",
+            "id": "model_loader",
+            "model": {
+                "key": model_info["key"],
+                "hash": model_info["hash"],
+                "name": model_info["name"],
+                "base": model_info["base"],
+                "type": model_info["type"]
+            }
+        },
+
+        # Positive prompt encoding - use sdxl_compel_prompt for SDXL
+        "positive_prompt": {
+            "type": "sdxl_compel_prompt" if is_sdxl else "compel",
+            "id": "positive_prompt",
+            "prompt": prompt,
+            **({"style": prompt} if is_sdxl else {})
+        },
+
+        # Negative prompt encoding - use sdxl_compel_prompt for SDXL
+        "negative_prompt": {
+            "type": "sdxl_compel_prompt" if is_sdxl else "compel",
+            "id": "negative_prompt",
+            "prompt": negative_prompt,
+            **({"style": ""} if is_sdxl else {})
+        },
+
+        # Noise generation
+        "noise": {
+            "type": "noise",
+            "id": "noise",
+            "seed": seed,
+            "width": width,
+            "height": height,
+            "use_cpu": False
+        },
+
+        # Denoise latents (main generation step)
+        "denoise": {
+            "type": "denoise_latents",
+            "id": "denoise",
+            "steps": steps,
+            "cfg_scale": cfg_scale,
+            "scheduler": scheduler,
+            "denoising_start": 0,
+            "denoising_end": 1
+        },
+
+        # Convert latents to image
+        "latents_to_image": {
+            "type": "l2i",
+            "id": "latents_to_image"
+        },
+
+        # Save image
+        "save_image": {
+            "type": "save_image",
+            "id": "save_image",
+            "is_intermediate": False
+        }
+    }
+
+    # Add LoRA loader if requested
+    if lora_key is not None:
+        lora_info = await get_model_info(lora_key)
+        if not lora_info:
+            raise ValueError(f"LoRA model {lora_key} not found")
+
+        # Validate LoRA info has required fields
+        required_fields = ["key", "hash", "name", "base", "type"]
+        for field in required_fields:
+            if field not in lora_info or lora_info[field] is None:
+                raise ValueError(f"LoRA model {lora_key} is missing required field: {field}")
+
+        nodes["lora_loader"] = {
+            "type": "lora_loader",
+            "id": "lora_loader",
+            "lora": {
+                "key": lora_info["key"],
+                "hash": lora_info["hash"],
+                "name": lora_info["name"],
+                "base": lora_info["base"],
+                "type": lora_info["type"]
+            },
+            "weight": lora_weight
+        }
+
+    # Add VAE loader if requested (to override model's built-in VAE)
+    if vae_key is not None:
+        vae_info = await get_model_info(vae_key)
+        if not vae_info:
+            raise ValueError(f"VAE model {vae_key} not found")
+
+        # Validate VAE info has required fields
+        required_fields = ["key", "hash", "name", "base", "type"]
+        for field in required_fields:
+            if field not in vae_info or vae_info[field] is None:
+                raise ValueError(f"VAE model {vae_key} is missing required field: {field}")
+
+        nodes["vae_loader"] = {
+            "type": "vae_loader",
+            "id": "vae_loader",
+            "vae_model": {
+                "key": vae_info["key"],
+                "hash": vae_info["hash"],
+                "name": vae_info["name"],
+                "base": vae_info["base"],
+                "type": vae_info["type"]
+            }
+        }
+
+    # Build edges
+    edges = []
+
+    # Determine source for UNet and CLIP (model_loader or lora_loader)
+    unet_source = "lora_loader" if lora_key is not None else "model_loader"
+    clip_source = "lora_loader" if lora_key is not None else "model_loader"
+    # Determine source for VAE (vae_loader if specified, otherwise model_loader)
+    vae_source = "vae_loader" if vae_key is not None else "model_loader"
+
+    # If using LoRA, connect model_loader to lora_loader first
+    if lora_key is not None:
+        edges.extend([
+            {
+                "source": {"node_id": "model_loader", "field": "unet"},
+                "destination": {"node_id": "lora_loader", "field": "unet"}
+            },
+            {
+                "source": {"node_id": "model_loader", "field": "clip"},
+                "destination": {"node_id": "lora_loader", "field": "clip"}
+            }
+        ])
+        # Note: lora_loader doesn't have a clip2 field, so for SDXL we route clip2 directly from model_loader
+
+    # Connect UNet and CLIP to downstream nodes
+    edges.extend([
+        # Connect UNet to denoise
+        {
+            "source": {"node_id": unet_source, "field": "unet"},
+            "destination": {"node_id": "denoise", "field": "unet"}
+        },
+        # Connect CLIP to prompts
+        {
+            "source": {"node_id": clip_source, "field": "clip"},
+            "destination": {"node_id": "positive_prompt", "field": "clip"}
+        },
+        {
+            "source": {"node_id": clip_source, "field": "clip"},
+            "destination": {"node_id": "negative_prompt", "field": "clip"}
+        },
+    ])
+
+    # For SDXL models, also connect clip2
+    # Note: clip2 always comes from model_loader, even when using LoRA (lora_loader doesn't support clip2)
+    if is_sdxl:
+        edges.extend([
+            {
+                "source": {"node_id": "model_loader", "field": "clip2"},
+                "destination": {"node_id": "positive_prompt", "field": "clip2"}
+            },
+            {
+                "source": {"node_id": "model_loader", "field": "clip2"},
+                "destination": {"node_id": "negative_prompt", "field": "clip2"}
+            },
+        ])
+
+    edges.extend([
+
+        # Connect prompts to denoise
+        {
+            "source": {"node_id": "positive_prompt", "field": "conditioning"},
+            "destination": {"node_id": "denoise", "field": "positive_conditioning"}
+        },
+        {
+            "source": {"node_id": "negative_prompt", "field": "conditioning"},
+            "destination": {"node_id": "denoise", "field": "negative_conditioning"}
+        },
+
+        # Connect noise to denoise
+        {
+            "source": {"node_id": "noise", "field": "noise"},
+            "destination": {"node_id": "denoise", "field": "noise"}
+        },
+
+        # Connect denoise to latents_to_image
+        {
+            "source": {"node_id": "denoise", "field": "latents"},
+            "destination": {"node_id": "latents_to_image", "field": "latents"}
+        },
+        {
+            "source": {"node_id": vae_source, "field": "vae"},
+            "destination": {"node_id": "latents_to_image", "field": "vae"}
+        },
+
+        # Connect latents_to_image to save_image
+        {
+            "source": {"node_id": "latents_to_image", "field": "image"},
+            "destination": {"node_id": "save_image", "field": "image"}
+        }
+    ])
+
+    graph = {
+        "id": "text2img_graph",
+        "nodes": nodes,
+        "edges": edges
+    }
+
+    return graph
\ No newline at end of file
diff --git a/3d-generation-pipeline/generate_model_local.py b/3d-generation-pipeline/generate_model_local.py
index 76474287..57ebbe74 100644
--- a/3d-generation-pipeline/generate_model_local.py
+++ b/3d-generation-pipeline/generate_model_local.py
@@ -1,4 +1,3 @@
-import subprocess
 import os
 import time
 import requests
@@ -7,29 +6,7 @@
 import base64
 from dotenv import load_dotenv
 load_dotenv()
-MODEL_FOLDER = os.environ["MODEL_FOLDER"]
-API_URL = os.environ["3D_GENERATION_URL"]
-
-
-def image_to_3d_subprocess(image_path, output_path):
-    venv_python = MODEL_FOLDER + r"\.venv\Scripts\python.exe"
-    script_path = MODEL_FOLDER + r"\run.py"
-
-    args = [image_path, "--output-dir", output_path]
-    command = [venv_python, script_path] + args
-
-    try:
-        # Run the subprocess
-        result = subprocess.run(command, capture_output=True, text=True)
-
-        # Print output and errors
-        print("STDOUT:\n", result.stdout)
-        print("STDERR:\n", result.stderr)
-        print("Return Code:", result.returncode)
-
-    except Exception as e:
-        print(f"Error occurred: {e}")
-
+API_URL = os.environ["TRELLIS_URL"]
 
 def generate_no_preview(image_base64: str):
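
Usage sketch (not part of the patch): a minimal driver for the new InvokeAI text-to-image path added in generate_image_local.py. It assumes an InvokeAI instance is reachable at the URL set in INVOKEAI_URL and that the hard-coded model_key exists on that install; the prompt and output path below are placeholders.

    import asyncio
    from generate_image_local import text_to_image_invoke_ai

    # Generates a 512x512 image through the InvokeAI queue and saves it to the given path.
    asyncio.run(text_to_image_invoke_ai("a low-poly wooden crate", "output/crate.png"))

Model keys for other installed checkpoints can be discovered with the new list_models() helper or via GET /api/v2/models/?model_type=main.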