forked from cgvr/DeltaVR
85 lines
2.7 KiB
Python
85 lines
2.7 KiB
Python
import asyncio
from urllib.parse import urljoin

import requests

from invokeai_mcp_server import create_text2img_graph, enqueue_graph, wait_for_completion, get_image_url


# Root URL of the InvokeAI server; image paths returned by the server are
# resolved against it in generate_image().
INVOKEAI_BASE_URL: str = "http://127.0.0.1:9090"


async def generate_image(arguments: dict) -> str:
    """Generate one image through an InvokeAI text-to-image graph and return its URL.

    Args:
        arguments: Generation parameters. ``"prompt"`` is required. Optional
            keys and the defaults used when they are absent:
            ``negative_prompt`` (``""``), ``width`` (512), ``height`` (512),
            ``steps`` (30), ``cfg_scale`` (7.5), ``scheduler`` (``"euler"``),
            ``lora_weight`` (1.0). ``seed``, ``model_key``, ``lora_key`` and
            ``vae_key`` are passed to ``create_text2img_graph`` as ``None``
            when absent.

    Returns:
        ``urljoin(INVOKEAI_BASE_URL, <path from get_image_url>)``. If
        ``get_image_url`` already returns an absolute URL, ``urljoin`` leaves
        it unchanged.

    Raises:
        KeyError: If ``arguments`` has no ``"prompt"`` key.
        RuntimeError: If the finished batch reports no ``image_output`` node.
    """
    prompt = arguments["prompt"]
    print(f"Generating image with prompt: {prompt[:50]}...")

    graph = await create_text2img_graph(
        prompt=prompt,
        negative_prompt=arguments.get("negative_prompt", ""),
        model_key=arguments.get("model_key"),
        lora_key=arguments.get("lora_key"),
        lora_weight=arguments.get("lora_weight", 1.0),
        vae_key=arguments.get("vae_key"),
        width=arguments.get("width", 512),
        height=arguments.get("height", 512),
        steps=arguments.get("steps", 30),
        cfg_scale=arguments.get("cfg_scale", 7.5),
        scheduler=arguments.get("scheduler", "euler"),
        seed=arguments.get("seed"),
    )

    # Enqueue the graph and block (asynchronously) until the batch finishes.
    result = await enqueue_graph(graph)
    batch_id = result["batch"]["batch_id"]
    print(f"Enqueued batch {batch_id}, waiting for completion...")
    completed = await wait_for_completion(batch_id)

    # NOTE(review): assumes the completed item exposes node outputs as a dict
    # under completed["result"]["outputs"], keyed by node id; the node ids
    # themselves are not needed, only the output payloads.
    outputs = {}
    if "result" in completed and "outputs" in completed["result"]:
        outputs = completed["result"]["outputs"]

    for output in outputs.values():
        if output.get("type") == "image_output":
            image_url = await get_image_url(output["image"]["image_name"])
            return urljoin(INVOKEAI_BASE_URL, image_url)

    # Include the batch id so the failed run can be located in InvokeAI's queue.
    raise RuntimeError(f"Failed to generate image! No image output found for batch {batch_id}.")



def download_file(url, filepath, *, timeout=60):
    """Fetch ``url`` with HTTP GET and write the response body to ``filepath``.

    Args:
        url: URL to download.
        filepath: Destination path; created or overwritten in binary mode.
        timeout: Seconds passed to ``requests.get`` as its timeout (applies to
            connecting and to each read). Without a timeout a stalled server
            would block the caller indefinitely.

    Raises:
        RuntimeError: If the server responds with any status other than 200.
            In that case nothing is written to ``filepath``.
    """
    # stream=True writes the body to disk in chunks instead of buffering the
    # whole image in memory; the ``with`` releases the connection afterwards.
    with requests.get(url, stream=True, timeout=timeout) as response:
        if response.status_code != 200:
            raise RuntimeError(f"Failed to download image. Status code: {response.status_code}")

        with open(filepath, "wb") as file:
            for chunk in response.iter_content(chunk_size=64 * 1024):
                file.write(chunk)



# InvokeAI main-model keys used by this script. List the installed models via
# GET http://127.0.0.1:9090/api/v2/models/?model_type=main
_JUGGERNAUT_XL_V9_KEY = "79401292-0a6b-428d-b7d7-f1e86caeba2b"  # Juggernaut XL v9
_DREAMSHAPER_8_KEY = "735f6485-6703-498f-929e-07cf0bbbd179"  # Dreamshaper 8 (alternative)


async def text_to_image_invoke_ai(prompt, output_path, *, model_key=_JUGGERNAUT_XL_V9_KEY):
    """Generate an image for ``prompt`` with InvokeAI and save it to ``output_path``.

    Args:
        prompt: Text prompt forwarded to ``generate_image``.
        output_path: File path the generated image is downloaded to.
        model_key: InvokeAI main-model key. Defaults to Juggernaut XL v9;
            ``_DREAMSHAPER_8_KEY`` is a known alternative.

    Raises:
        Whatever ``generate_image`` or ``download_file`` raise (e.g.
        ``RuntimeError`` on a failed generation or a non-200 download).
    """
    args = {
        "prompt": prompt,
        "model_key": model_key,
    }

    image_url = await generate_image(args)
    print("got image url: ", image_url)

    # download_file performs blocking I/O via ``requests``; run it in the
    # default executor so it does not stall the event loop.
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, download_file, image_url, output_path)