incorporate InvokeAI into start_pipeline.py
This commit is contained in:
parent
fdd4ff827e
commit
590c62eadd
@ -8,7 +8,7 @@ load_dotenv()
|
|||||||
ACCOUNT_ID = os.environ["CLOUDFLARE_ACCOUNT_ID"]
|
ACCOUNT_ID = os.environ["CLOUDFLARE_ACCOUNT_ID"]
|
||||||
API_TOKEN = os.environ["CLOUDFLARE_API_TOKEN"]
|
API_TOKEN = os.environ["CLOUDFLARE_API_TOKEN"]
|
||||||
|
|
||||||
def text_to_image(prompt, output_path):
|
def text_to_image_cloudflare(prompt, output_path):
|
||||||
MODEL = "@cf/black-forest-labs/flux-1-schnell"
|
MODEL = "@cf/black-forest-labs/flux-1-schnell"
|
||||||
URL = f"https://api.cloudflare.com/client/v4/accounts/{ACCOUNT_ID}/ai/run/{MODEL}"
|
URL = f"https://api.cloudflare.com/client/v4/accounts/{ACCOUNT_ID}/ai/run/{MODEL}"
|
||||||
|
|
||||||
@ -34,6 +34,8 @@ def text_to_image(prompt, output_path):
|
|||||||
with open(output_path, "wb") as f:
|
with open(output_path, "wb") as f:
|
||||||
f.write(img_bytes)
|
f.write(img_bytes)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def refine_text_prompt(prompt):
|
def refine_text_prompt(prompt):
|
||||||
MODEL = "@cf/meta/llama-3.2-3b-instruct"
|
MODEL = "@cf/meta/llama-3.2-3b-instruct"
|
||||||
|
|||||||
@ -1,28 +1,81 @@
|
|||||||
import torch
|
import requests
|
||||||
from diffusers import StableDiffusionPipeline, StableDiffusion3Pipeline
|
|
||||||
import time
|
|
||||||
|
|
||||||
start_timestamp = time.time()
|
from invokeai_mcp_server import create_text2img_graph, enqueue_graph, wait_for_completion, get_image_url
|
||||||
#model = "stabilityai/stable-diffusion-3.5-medium" # generation time: 13 min
|
from urllib.parse import urljoin
|
||||||
model = "stabilityai/stable-diffusion-3-medium-diffusers" # generation time: 10 min
|
|
||||||
#model = "stabilityai/stable-diffusion-2" # generation time: 4 sec
|
|
||||||
|
|
||||||
pipe = StableDiffusion3Pipeline.from_pretrained(model, torch_dtype=torch.float16)
|
|
||||||
#pipe = StableDiffusionPipeline.from_pretrained(model, torch_dtype=torch.float16)
|
|
||||||
pipe = pipe.to("cuda")
|
|
||||||
|
|
||||||
model_loaded_timestamp = time.time()
|
INVOKEAI_BASE_URL = "http://127.0.0.1:9090"
|
||||||
model_load_time = model_loaded_timestamp - start_timestamp
|
|
||||||
print(f"model load time: {round(model_load_time)} seconds")
|
|
||||||
|
|
||||||
prompt = "A majestic broadsword with a golden pommel, no background"
|
|
||||||
image = pipe(
|
|
||||||
prompt,
|
|
||||||
guidance_scale=3.0,
|
|
||||||
).images[0]
|
|
||||||
|
|
||||||
image_name = "image7.png"
|
async def generate_image(arguments: dict):
    """Generate one image through InvokeAI and return its URL.

    Args:
        arguments: Generation settings. ``"prompt"`` is required; the
            optional keys and their fallbacks are ``negative_prompt`` (""),
            ``width`` (512), ``height`` (512), ``steps`` (30),
            ``cfg_scale`` (7.5), ``scheduler`` ("euler"), ``seed`` (None),
            ``model_key`` (None), ``lora_key`` (None), ``lora_weight`` (1.0)
            and ``vae_key`` (None).

    Returns:
        The value returned by ``get_image_url`` for the first output of type
        ``"image_output"``, joined onto ``INVOKEAI_BASE_URL``.

    Raises:
        KeyError: If ``"prompt"`` is missing from ``arguments``.
        RuntimeError: If the completed batch contains no image output.
    """
    text_prompt = arguments["prompt"]

    print(f"Generating image with prompt: {text_prompt[:50]}...")

    # Every optional setting is read straight from the caller's dict, with
    # the fallback applied when the key is absent.
    graph = await create_text2img_graph(
        prompt=text_prompt,
        negative_prompt=arguments.get("negative_prompt", ""),
        model_key=arguments.get("model_key"),
        lora_key=arguments.get("lora_key"),
        lora_weight=arguments.get("lora_weight", 1.0),
        vae_key=arguments.get("vae_key"),
        width=arguments.get("width", 512),
        height=arguments.get("height", 512),
        steps=arguments.get("steps", 30),
        cfg_scale=arguments.get("cfg_scale", 7.5),
        scheduler=arguments.get("scheduler", "euler"),
        seed=arguments.get("seed"),
    )

    # Submit the graph, then block (asynchronously) until the batch is done.
    enqueued = await enqueue_graph(graph)
    batch_id = enqueued["batch"]["batch_id"]

    print(f"Enqueued batch {batch_id}, waiting for completion...")

    completed = await wait_for_completion(batch_id)

    # Without a result/outputs section there is nothing to look through.
    if "result" not in completed or "outputs" not in completed["result"]:
        raise RuntimeError("Failed to generate image!")

    # The first node output tagged as an image wins.
    for node_output in completed["result"]["outputs"].values():
        if node_output.get("type") != "image_output":
            continue
        image_name = node_output["image"]["image_name"]
        relative_url = await get_image_url(image_name)
        return urljoin(INVOKEAI_BASE_URL, relative_url)

    raise RuntimeError("Failed to generate image!")
|
||||||
|
|
||||||
|
def download_file(url, filepath, timeout=30):
    """Download ``url`` and write the response body to ``filepath``.

    Args:
        url: Address of the file to fetch.
        filepath: Destination path; opened in binary write mode, so an
            existing file is overwritten.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` waits indefinitely, so a stalled server
            would hang the whole pipeline.

    Raises:
        RuntimeError: If the server answers with a status code other than 200.
        requests.exceptions.RequestException: If the request itself fails
            (connection error, timeout, ...).
    """
    response = requests.get(url, timeout=timeout)

    # Reject anything but a plain 200 before touching the destination file,
    # so a failed download never leaves a truncated or empty file behind.
    if response.status_code != 200:
        raise RuntimeError(f"Failed to download image. Status code: {response.status_code}")

    with open(filepath, "wb") as file:
        file.write(response.content)
|
||||||
|
|
||||||
|
|
||||||
|
async def text_to_image_invoke_ai(prompt, output_path):
    """Generate an image for ``prompt`` with InvokeAI and save it locally.

    Args:
        prompt: Text prompt passed to ``generate_image``; all other
            generation settings are left at that function's defaults.
        output_path: Where the downloaded image is written.

    Returns:
        True once the image has been written, matching the Cloudflare
        backend so the two can be swapped without changing the caller.

    Raises:
        RuntimeError: Propagated from ``generate_image`` or
            ``download_file`` when generation or download fails.
    """
    image_url = await generate_image({"prompt": prompt})
    print("got image url: ", image_url)

    # NOTE(review): download_file is synchronous and blocks the event loop
    # while the image is fetched. That is harmless while the pipeline runs a
    # single task, but should move to a worker thread if it ever runs
    # concurrently with other coroutines.
    download_file(image_url, output_path)
    return True
|
||||||
|
|||||||
@ -1,72 +0,0 @@
|
|||||||
from invokeai_mcp_server import create_text2img_graph, enqueue_graph, wait_for_completion, get_image_url
|
|
||||||
from urllib.parse import urljoin
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
INVOKEAI_BASE_URL = "http://127.0.0.1:9090"
|
|
||||||
|
|
||||||
|
|
||||||
async def generate_image(arguments: dict):
|
|
||||||
|
|
||||||
# Extract parameters
|
|
||||||
prompt = arguments["prompt"]
|
|
||||||
negative_prompt = arguments.get("negative_prompt", "")
|
|
||||||
width = arguments.get("width", 512)
|
|
||||||
height = arguments.get("height", 512)
|
|
||||||
steps = arguments.get("steps", 30)
|
|
||||||
cfg_scale = arguments.get("cfg_scale", 7.5)
|
|
||||||
scheduler = arguments.get("scheduler", "euler")
|
|
||||||
seed = arguments.get("seed")
|
|
||||||
model_key = arguments.get("model_key")
|
|
||||||
lora_key = arguments.get("lora_key")
|
|
||||||
lora_weight = arguments.get("lora_weight", 1.0)
|
|
||||||
vae_key = arguments.get("vae_key")
|
|
||||||
|
|
||||||
#logger.info(f"Generating image with prompt: {prompt[:50]}...")
|
|
||||||
|
|
||||||
# Create graph
|
|
||||||
graph = await create_text2img_graph(
|
|
||||||
prompt=prompt,
|
|
||||||
negative_prompt=negative_prompt,
|
|
||||||
model_key=model_key,
|
|
||||||
lora_key=lora_key,
|
|
||||||
lora_weight=lora_weight,
|
|
||||||
vae_key=vae_key,
|
|
||||||
width=width,
|
|
||||||
height=height,
|
|
||||||
steps=steps,
|
|
||||||
cfg_scale=cfg_scale,
|
|
||||||
scheduler=scheduler,
|
|
||||||
seed=seed
|
|
||||||
)
|
|
||||||
|
|
||||||
# Enqueue and wait for completion
|
|
||||||
result = await enqueue_graph(graph)
|
|
||||||
batch_id = result["batch"]["batch_id"]
|
|
||||||
|
|
||||||
#logger.info(f"Enqueued batch {batch_id}, waiting for completion...")
|
|
||||||
|
|
||||||
completed = await wait_for_completion(batch_id)
|
|
||||||
|
|
||||||
# Extract image name from result
|
|
||||||
if "result" in completed and "outputs" in completed["result"]:
|
|
||||||
outputs = completed["result"]["outputs"]
|
|
||||||
# Find the image output
|
|
||||||
for node_id, output in outputs.items():
|
|
||||||
if output.get("type") == "image_output":
|
|
||||||
image_name = output["image"]["image_name"]
|
|
||||||
image_url = await get_image_url(image_name)
|
|
||||||
|
|
||||||
text=f"Image generated successfully!\n\nImage Name: {image_name}\nImage URL: {image_url}\n\nYou can view the image at: {urljoin(INVOKEAI_BASE_URL, f'/api/v1/images/i/{image_name}/full')}"
|
|
||||||
print(text)
|
|
||||||
|
|
||||||
# Fallback if we couldn't find image output
|
|
||||||
#text=f"Image generation completed but output format was unexpected. Batch ID: {batch_id}\n\nResult: {json.dumps(completed, indent=2)}"
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
args = {
|
|
||||||
"prompt": "a golden katana with a fancy pommel"
|
|
||||||
}
|
|
||||||
await generate_image(args)
|
|
||||||
|
|
||||||
asyncio.run(main())
|
|
||||||
122
3d-generation-pipeline/notebooks/local_image_generation.ipynb
Normal file
122
3d-generation-pipeline/notebooks/local_image_generation.ipynb
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "50e24baa",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from invokeai_mcp_server import create_text2img_graph, enqueue_graph, wait_for_completion, get_image_url\n",
|
||||||
|
"from urllib.parse import urljoin\n",
|
||||||
|
"\n",
|
||||||
|
"import asyncio"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "0407cd9a",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"INVOKEAI_BASE_URL = \"http://127.0.0.1:9090\"\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"async def generate_image(arguments: dict):\n",
|
||||||
|
"\n",
|
||||||
|
" # Extract parameters\n",
|
||||||
|
" prompt = arguments[\"prompt\"]\n",
|
||||||
|
" negative_prompt = arguments.get(\"negative_prompt\", \"\")\n",
|
||||||
|
" width = arguments.get(\"width\", 512)\n",
|
||||||
|
" height = arguments.get(\"height\", 512)\n",
|
||||||
|
" steps = arguments.get(\"steps\", 30)\n",
|
||||||
|
" cfg_scale = arguments.get(\"cfg_scale\", 7.5)\n",
|
||||||
|
" scheduler = arguments.get(\"scheduler\", \"euler\")\n",
|
||||||
|
" seed = arguments.get(\"seed\")\n",
|
||||||
|
" model_key = arguments.get(\"model_key\")\n",
|
||||||
|
" lora_key = arguments.get(\"lora_key\")\n",
|
||||||
|
" lora_weight = arguments.get(\"lora_weight\", 1.0)\n",
|
||||||
|
" vae_key = arguments.get(\"vae_key\")\n",
|
||||||
|
"\n",
|
||||||
|
" #logger.info(f\"Generating image with prompt: {prompt[:50]}...\")\n",
|
||||||
|
"\n",
|
||||||
|
" # Create graph\n",
|
||||||
|
" graph = await create_text2img_graph(\n",
|
||||||
|
" prompt=prompt,\n",
|
||||||
|
" negative_prompt=negative_prompt,\n",
|
||||||
|
" model_key=model_key,\n",
|
||||||
|
" lora_key=lora_key,\n",
|
||||||
|
" lora_weight=lora_weight,\n",
|
||||||
|
" vae_key=vae_key,\n",
|
||||||
|
" width=width,\n",
|
||||||
|
" height=height,\n",
|
||||||
|
" steps=steps,\n",
|
||||||
|
" cfg_scale=cfg_scale,\n",
|
||||||
|
" scheduler=scheduler,\n",
|
||||||
|
" seed=seed\n",
|
||||||
|
" )\n",
|
||||||
|
"\n",
|
||||||
|
" # Enqueue and wait for completion\n",
|
||||||
|
" result = await enqueue_graph(graph)\n",
|
||||||
|
" batch_id = result[\"batch\"][\"batch_id\"]\n",
|
||||||
|
"\n",
|
||||||
|
" #logger.info(f\"Enqueued batch {batch_id}, waiting for completion...\")\n",
|
||||||
|
"\n",
|
||||||
|
" completed = await wait_for_completion(batch_id)\n",
|
||||||
|
"\n",
|
||||||
|
" # Extract image name from result\n",
|
||||||
|
" if \"result\" in completed and \"outputs\" in completed[\"result\"]:\n",
|
||||||
|
" outputs = completed[\"result\"][\"outputs\"]\n",
|
||||||
|
" # Find the image output\n",
|
||||||
|
" for node_id, output in outputs.items():\n",
|
||||||
|
" if output.get(\"type\") == \"image_output\":\n",
|
||||||
|
" image_name = output[\"image\"][\"image_name\"]\n",
|
||||||
|
" image_url = await get_image_url(image_name)\n",
|
||||||
|
"\n",
|
||||||
|
" text=f\"Image generated successfully!\\n\\nImage Name: {image_name}\\nImage URL: {image_url}\\n\\nYou can view the image at: {urljoin(INVOKEAI_BASE_URL, f'/api/v1/images/i/{image_name}/full')}\"\n",
|
||||||
|
" print(text)\n",
|
||||||
|
"\n",
|
||||||
|
" # Fallback if we couldn't find image output\n",
|
||||||
|
" #text=f\"Image generation completed but output format was unexpected. Batch ID: {batch_id}\\n\\nResult: {json.dumps(completed, indent=2)}\""
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "6cf9d879",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"async def main():\n",
|
||||||
|
" args = {\n",
|
||||||
|
" \"prompt\": \"a golden katana with a fancy pommel\"\n",
|
||||||
|
" }\n",
|
||||||
|
" await generate_image(args)\n",
|
||||||
|
"\n",
|
||||||
|
"asyncio.run(main())"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": ".venv",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.10.11"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
||||||
@ -1,11 +1,13 @@
|
|||||||
import os
|
import os
|
||||||
import argparse
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
from cloudflare_api import text_to_image, refine_text_prompt
|
from cloudflare_api import text_to_image_cloudflare, refine_text_prompt
|
||||||
|
from generate_image_local import text_to_image_invoke_ai
|
||||||
from generate_model_local import image_to_3d_api, image_to_3d_subprocess
|
from generate_model_local import image_to_3d_api, image_to_3d_subprocess
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
@ -17,7 +19,7 @@ def get_timestamp():
|
|||||||
return datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
|
return datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
|
||||||
|
|
||||||
|
|
||||||
def main():
|
async def main():
|
||||||
parser = argparse.ArgumentParser(description="Text to 3D model pipeline")
|
parser = argparse.ArgumentParser(description="Text to 3D model pipeline")
|
||||||
parser.add_argument("--prompt", type=str, required=True, help="User text prompt")
|
parser.add_argument("--prompt", type=str, required=True, help="User text prompt")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@ -35,13 +37,15 @@ def main():
|
|||||||
timestamp = get_timestamp()
|
timestamp = get_timestamp()
|
||||||
pipeline_folder = Path(PIPELINE_FOLDER)
|
pipeline_folder = Path(PIPELINE_FOLDER)
|
||||||
image_path = pipeline_folder / "images" / f"{timestamp}.jpg"
|
image_path = pipeline_folder / "images" / f"{timestamp}.jpg"
|
||||||
text_to_image(image_generation_prompt, image_path)
|
# TODO: use Invoke AI or Cloudflare, depending on env var
|
||||||
|
#text_to_image_cloudflare(image_generation_prompt, image_path)
|
||||||
|
await text_to_image_invoke_ai(image_generation_prompt, image_path)
|
||||||
|
|
||||||
print(f"Generated image file: {image_path}")
|
print(f"Generated image file: {image_path}")
|
||||||
model_path = pipeline_folder / "models" / timestamp
|
model_path = pipeline_folder / "models" / timestamp
|
||||||
model_file = image_to_3d_api(image_path, model_path)
|
model_file = image_to_3d_api(image_path, model_path)
|
||||||
#model_file_path = model_path / "0" / "mesh.glb"
|
|
||||||
print(f"Generated 3D model file: {model_file}")
|
print(f"Generated 3D model file: {model_file}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
asyncio.run(main())
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user