1
0
forked from cgvr/DeltaVR
Files
DeltaVR3DModelGeneration/Assets/_PROJECT/Scripts/ModeGeneration/InvokeAiClient.cs

541 lines
19 KiB
C#

using System;
using System.Diagnostics;
using System.Net.Http;
using System.Security.Cryptography;
using System.Text;
using System.Threading.Tasks;
using UnityEngine;
using Valve.Newtonsoft.Json;
using Valve.Newtonsoft.Json.Linq;
/// <summary>
/// Minimal client for an InvokeAI server's REST API: builds a text-to-image node graph,
/// enqueues it, polls the batch until it finishes, and downloads the resulting image bytes.
/// </summary>
/// <remarks>
/// Exposed as a scene singleton through <see cref="Instance"/>, assigned in <c>Awake</c>.
/// The public field names are Inspector-serialized and are kept unchanged for scene compatibility.
/// All internal awaits use <c>ConfigureAwait(false)</c>; the only Unity API touched after an await
/// is <c>UnityEngine.Debug.Log</c>.
/// </remarks>
public class InvokeAiClient : MonoBehaviour
{
    /// <summary>The most recently awakened client; cleared when that instance is destroyed.</summary>
    public static InvokeAiClient Instance { get; private set; }

    /// <summary>
    /// Absolute base URL of the InvokeAI server. API paths used below start with "/", so they are
    /// resolved against the host root of this URL (any path component of it is replaced).
    /// </summary>
    public string INVOKEAI_BASE_URL;

    /// <summary>Queue id used for enqueueing and for status polling.</summary>
    public string DEFAULT_QUEUE_ID = "default";

    /// <summary>
    /// Model key sent by <see cref="GenerateImage"/>. When null/empty, the first "main" model whose
    /// base is "sd-1" is picked from the server's model list.
    /// </summary>
    public string MODEL_KEY;

    // Fields every model reference in the graph must carry, in the order they are serialized.
    private static readonly string[] ModelIdentifierFields = { "key", "hash", "name", "base", "type" };

    private HttpClient httpClient;

    // Fails with a descriptive error instead of a NullReferenceException when Awake never ran or threw.
    private HttpClient Client =>
        httpClient ?? throw new InvalidOperationException(
            $"{nameof(InvokeAiClient)} is not initialized (Awake has not run or failed).");

    private void Awake()
    {
        // Validate before allocating the HttpClient so a bad Inspector value yields a clear message
        // (instead of a bare UriFormatException) and no client is leaked.
        if (string.IsNullOrWhiteSpace(INVOKEAI_BASE_URL) ||
            !Uri.TryCreate(INVOKEAI_BASE_URL, UriKind.Absolute, out Uri baseUri))
        {
            throw new InvalidOperationException(
                $"{nameof(InvokeAiClient)}: {nameof(INVOKEAI_BASE_URL)} must be an absolute URL, got '{INVOKEAI_BASE_URL}'.");
        }

        httpClient = new HttpClient
        {
            BaseAddress = baseUri,
            Timeout = TimeSpan.FromSeconds(120)
        };
        Instance = this;
    }

    private void OnDestroy()
    {
        if (Instance == this)
            Instance = null;

        // In-flight requests on a destroyed client fail with ObjectDisposedException.
        httpClient?.Dispose();
    }

    /// <summary>GETs a relative API path, throws on a non-2xx status, and parses the body as a JSON object.</summary>
    /// <exception cref="HttpRequestException">The server returned a non-success status code.</exception>
    private async Task<JObject> GetJsonAsync(string requestUri)
    {
        using var resp = await Client.GetAsync(requestUri).ConfigureAwait(false);
        resp.EnsureSuccessStatusCode();
        var json = await resp.Content.ReadAsStringAsync().ConfigureAwait(false);
        return JObject.Parse(json);
    }

    /// <summary>Lists installed models of the given type.</summary>
    /// <returns>The response's "models" array, or an empty array when the field is missing.</returns>
    private async Task<JArray> ListModels(string modelType = "main")
    {
        var requestUri = $"/api/v2/models/?model_type={Uri.EscapeDataString(modelType)}";
        JObject root = await GetJsonAsync(requestUri).ConfigureAwait(false);
        return root["models"] as JArray ?? new JArray();
    }

    /// <summary>Fetches the metadata object for a single model key.</summary>
    /// <exception cref="HttpRequestException">The model lookup failed (e.g. unknown key).</exception>
    private Task<JObject> GetModelInfo(string modelKey)
    {
        return GetJsonAsync($"/api/v2/models/i/{Uri.EscapeDataString(modelKey)}");
    }

    /// <summary>
    /// Best-effort, human-readable queue status for error messages. Network failures here are
    /// reported inline so they never mask the error that triggered the lookup.
    /// </summary>
    private async Task<string> DescribeQueueStatus(string queueId)
    {
        try
        {
            JObject queueData = await GetJsonAsync($"/api/v1/queue/{Uri.EscapeDataString(queueId)}/status")
                .ConfigureAwait(false);
            return queueData.ToString(Formatting.Indented);
        }
        catch (Exception ex) when (ex is HttpRequestException || ex is TaskCanceledException)
        {
            return $"<unavailable: {ex.Message}>";
        }
    }

    /// <summary>
    /// Polls the batch status once per second until every item has completed, then resolves the output image.
    /// </summary>
    /// <param name="batchId">Batch id returned by the enqueue call.</param>
    /// <param name="timeoutSeconds">Overall wall-clock limit for polling.</param>
    /// <returns>
    /// A synthetic result shaped like <c>{ batch_id, status, result: { outputs: { save_image: { type: "image_output", image: { image_name } } } } }</c>,
    /// or the raw batch status object when no image could be found.
    /// </returns>
    /// <exception cref="TimeoutException">The batch did not finish within <paramref name="timeoutSeconds"/>.</exception>
    /// <exception cref="InvalidOperationException">An item in the batch failed or was canceled.</exception>
    private async Task<JObject> WaitForCompletion(string batchId, int timeoutSeconds = 300)
    {
        var sw = Stopwatch.StartNew();
        string queueId = DEFAULT_QUEUE_ID;
        string statusUrl = $"/api/v1/queue/{Uri.EscapeDataString(queueId)}/b/{Uri.EscapeDataString(batchId)}/status";

        while (true)
        {
            if (sw.Elapsed.TotalSeconds > timeoutSeconds)
                throw new TimeoutException($"Image generation timed out after {timeoutSeconds} seconds");

            JObject statusData = await GetJsonAsync(statusUrl).ConfigureAwait(false);

            int failedCount = statusData.Value<int?>("failed") ?? 0;
            if (failedCount > 0)
            {
                string queueStatus = await DescribeQueueStatus(queueId).ConfigureAwait(false);
                throw new InvalidOperationException(
                    $"Image generation failed. Batch {batchId} has {failedCount} failed item(s). " +
                    $"Queue status: {queueStatus}"
                );
            }

            // A canceled item never counts as "completed", so without this check the loop would
            // keep polling until the timeout.
            int canceledCount = statusData.Value<int?>("canceled") ?? 0;
            if (canceledCount > 0)
            {
                throw new InvalidOperationException(
                    $"Image generation canceled. Batch {batchId} has {canceledCount} canceled item(s).");
            }

            int completed = statusData.Value<int?>("completed") ?? 0;
            int total = statusData.Value<int?>("total") ?? 0;
            if (completed == total && total > 0)
            {
                // NOTE(review): this assumes the server lists the newest non-intermediate image first and
                // that it belongs to this batch. A concurrent generation by another client could be
                // returned instead — consider resolving the image via the batch's queue item results.
                JObject imagesData = await GetJsonAsync("/api/v1/images/?is_intermediate=false&limit=10")
                    .ConfigureAwait(false);
                if (imagesData["items"] is JArray items && items.Count > 0)
                {
                    var imageName = items[0].Value<string>("image_name");
                    return new JObject
                    {
                        ["batch_id"] = batchId,
                        ["status"] = "completed",
                        ["result"] = new JObject
                        {
                            ["outputs"] = new JObject
                            {
                                ["save_image"] = new JObject
                                {
                                    ["type"] = "image_output",
                                    ["image"] = new JObject
                                    {
                                        ["image_name"] = imageName
                                    }
                                }
                            }
                        }
                    };
                }

                // No images found: hand back the raw status so the caller can report it.
                return statusData;
            }

            await Task.Delay(1000).ConfigureAwait(false);
        }
    }

    /// <summary>Enqueues a single run of <paramref name="graph"/> on <see cref="DEFAULT_QUEUE_ID"/>.</summary>
    /// <returns>The parsed enqueue response (expected to contain <c>batch.batch_id</c>).</returns>
    private async Task<JObject> EnqueueGraph(JToken graph)
    {
        string queueId = DEFAULT_QUEUE_ID;
        var payload = new JObject
        {
            ["batch"] = new JObject
            {
                ["graph"] = graph,
                ["runs"] = 1,
                ["data"] = JValue.CreateNull()
            }
        };
        var url = $"/api/v1/queue/{Uri.EscapeDataString(queueId)}/enqueue_batch";
        using var content = new StringContent(payload.ToString(Formatting.None), Encoding.UTF8, "application/json");
        using var resp = await Client.PostAsync(url, content).ConfigureAwait(false);
        resp.EnsureSuccessStatusCode();
        var json = await resp.Content.ReadAsStringAsync().ConfigureAwait(false);
        return JObject.Parse(json);
    }

    /// <summary>Throws if any of <paramref name="fields"/> is missing or JSON null in <paramref name="info"/>.</summary>
    /// <exception cref="ArgumentException">A required field is missing.</exception>
    private static void RequireFields(JObject info, string nameOrKey, params string[] fields)
    {
        foreach (var f in fields)
        {
            var v = info[f];
            if (v == null || v.Type == JTokenType.Null)
                throw new ArgumentException($"Model {nameOrKey} is missing required field: {f}");
        }
    }

    /// <summary>Copies the identifying fields (key, hash, name, base, type) of a model into a new object.</summary>
    private static JObject ModelIdentifier(JObject info)
    {
        var id = new JObject();
        foreach (var field in ModelIdentifierFields)
            id[field] = info.Value<string>(field);
        return id;
    }

    /// <summary>Cryptographically random seed covering the full unsigned 32-bit range (0..4294967295).</summary>
    private static long GenerateUInt32Seed()
    {
        Span<byte> bytes = stackalloc byte[4];
        RandomNumberGenerator.Fill(bytes);
        uint u = BitConverter.ToUInt32(bytes);
        return (long)u;
    }

    /// <summary>Builds a graph edge from <c>srcNode.srcField</c> to <c>dstNode.dstField</c>.</summary>
    private static JObject Edge(string srcNode, string srcField, string dstNode, string dstField) => new JObject
    {
        ["source"] = new JObject { ["node_id"] = srcNode, ["field"] = srcField },
        ["destination"] = new JObject { ["node_id"] = dstNode, ["field"] = dstField }
    };

    /// <summary>
    /// Builds an InvokeAI text-to-image graph: model loader → prompts/noise → denoise → latents-to-image → save.
    /// SDXL models (base "sdxl") get the SDXL loader/prompt node types and clip2 wiring.
    /// </summary>
    /// <param name="modelKey">Main model key; when null/empty, the first "sd-1" main model is used.</param>
    /// <param name="loraKey">Optional LoRA model key, inserted between the model loader and its consumers.</param>
    /// <param name="vaeKey">Optional VAE override feeding latents-to-image.</param>
    /// <param name="seed">Noise seed; a random unsigned 32-bit value when null.</param>
    /// <exception cref="ArgumentException">No default model was found, or a model lacks required fields.</exception>
    /// <exception cref="HttpRequestException">A model lookup failed.</exception>
    private async Task<JObject> CreateText2ImgGraph(
        string prompt,
        string negativePrompt = "",
        string modelKey = null,
        string loraKey = null,
        double loraWeight = 1.0,
        string vaeKey = null,
        int width = 512,
        int height = 512,
        int steps = 30,
        double cfgScale = 7.5,
        string scheduler = "euler",
        long? seed = null)
    {
        // 1) Default model: the first "main" model whose base is "sd-1".
        if (string.IsNullOrEmpty(modelKey))
        {
            JArray models = await ListModels("main").ConfigureAwait(false);
            foreach (var token in models)
            {
                if (token is JObject m && string.Equals(m.Value<string>("base"), "sd-1", StringComparison.OrdinalIgnoreCase))
                {
                    modelKey = m.Value<string>("key");
                    break;
                }
            }
            if (string.IsNullOrEmpty(modelKey))
                throw new ArgumentException("No suitable model found (sd-1)", nameof(modelKey));
        }

        // 2) Model metadata. GetModelInfo throws on HTTP errors and JObject.Parse never yields null,
        //    so only field presence needs validating.
        JObject modelInfo = await GetModelInfo(modelKey).ConfigureAwait(false);
        RequireFields(modelInfo, modelKey, ModelIdentifierFields);

        // 3) Random 32-bit seed if none was provided.
        if (seed == null)
            seed = GenerateUInt32Seed();

        // 4) SDXL detection drives node types and clip2 wiring.
        bool isSdxl = string.Equals(modelInfo.Value<string>("base"), "sdxl", StringComparison.OrdinalIgnoreCase);
        string promptNodeType = isSdxl ? "sdxl_compel_prompt" : "compel";

        var positivePromptNode = new JObject
        {
            ["type"] = promptNodeType,
            ["id"] = "positive_prompt",
            ["prompt"] = prompt
        };
        var negativePromptNode = new JObject
        {
            ["type"] = promptNodeType,
            ["id"] = "negative_prompt",
            ["prompt"] = negativePrompt
        };
        if (isSdxl)
        {
            // SDXL prompt nodes additionally take a style string.
            positivePromptNode["style"] = prompt;
            negativePromptNode["style"] = "";
        }

        // 5) Core nodes.
        var nodes = new JObject
        {
            ["model_loader"] = new JObject
            {
                ["type"] = isSdxl ? "sdxl_model_loader" : "main_model_loader",
                ["id"] = "model_loader",
                ["model"] = ModelIdentifier(modelInfo)
            },
            ["positive_prompt"] = positivePromptNode,
            ["negative_prompt"] = negativePromptNode,
            ["noise"] = new JObject
            {
                ["type"] = "noise",
                ["id"] = "noise",
                ["seed"] = seed,
                ["width"] = width,
                ["height"] = height,
                ["use_cpu"] = false
            },
            ["denoise"] = new JObject
            {
                ["type"] = "denoise_latents",
                ["id"] = "denoise",
                ["steps"] = steps,
                ["cfg_scale"] = cfgScale,
                ["scheduler"] = scheduler,
                ["denoising_start"] = 0,
                ["denoising_end"] = 1
            },
            ["latents_to_image"] = new JObject
            {
                ["type"] = "l2i",
                ["id"] = "latents_to_image"
            },
            ["save_image"] = new JObject
            {
                ["type"] = "save_image",
                ["id"] = "save_image",
                ["is_intermediate"] = false
            }
        };

        // 6) Optional LoRA.
        // NOTE(review): "lora_loader" only rewires unet/clip; SDXL LoRAs may need an SDXL-specific
        // loader that also handles clip2 — confirm against the target InvokeAI version.
        bool hasLora = !string.IsNullOrEmpty(loraKey);
        if (hasLora)
        {
            JObject loraInfo = await GetModelInfo(loraKey).ConfigureAwait(false);
            RequireFields(loraInfo, loraKey, ModelIdentifierFields);
            nodes["lora_loader"] = new JObject
            {
                ["type"] = "lora_loader",
                ["id"] = "lora_loader",
                ["lora"] = ModelIdentifier(loraInfo),
                ["weight"] = loraWeight
            };
        }

        // 7) Optional VAE override.
        bool hasVae = !string.IsNullOrEmpty(vaeKey);
        if (hasVae)
        {
            JObject vaeInfo = await GetModelInfo(vaeKey).ConfigureAwait(false);
            RequireFields(vaeInfo, vaeKey, ModelIdentifierFields);
            nodes["vae_loader"] = new JObject
            {
                ["type"] = "vae_loader",
                ["id"] = "vae_loader",
                ["vae_model"] = ModelIdentifier(vaeInfo)
            };
        }

        // 8) Edges.
        var edges = new JArray();
        string unetSource = hasLora ? "lora_loader" : "model_loader";
        string clipSource = hasLora ? "lora_loader" : "model_loader";
        string vaeSource = hasVae ? "vae_loader" : "model_loader";

        if (hasLora)
        {
            // The LoRA loader sits between the model loader and the unet/clip consumers.
            edges.Add(Edge("model_loader", "unet", "lora_loader", "unet"));
            edges.Add(Edge("model_loader", "clip", "lora_loader", "clip"));
        }

        edges.Add(Edge(unetSource, "unet", "denoise", "unet"));
        edges.Add(Edge(clipSource, "clip", "positive_prompt", "clip"));
        edges.Add(Edge(clipSource, "clip", "negative_prompt", "clip"));

        if (isSdxl)
        {
            // clip2 always comes straight from the model loader.
            edges.Add(Edge("model_loader", "clip2", "positive_prompt", "clip2"));
            edges.Add(Edge("model_loader", "clip2", "negative_prompt", "clip2"));
        }

        edges.Add(Edge("positive_prompt", "conditioning", "denoise", "positive_conditioning"));
        edges.Add(Edge("negative_prompt", "conditioning", "denoise", "negative_conditioning"));
        edges.Add(Edge("noise", "noise", "denoise", "noise"));
        edges.Add(Edge("denoise", "latents", "latents_to_image", "latents"));
        edges.Add(Edge(vaeSource, "vae", "latents_to_image", "vae"));
        edges.Add(Edge("latents_to_image", "image", "save_image", "image"));

        // 9) Final graph.
        return new JObject
        {
            ["id"] = "text2img_graph",
            ["nodes"] = nodes,
            ["edges"] = edges
        };
    }

    /// <summary>Resolves the download URL of a generated image.</summary>
    /// <exception cref="InvalidOperationException">The server response has no <c>image_url</c>.</exception>
    private async Task<string> GetImageUrl(string imageName)
    {
        var requestUri = $"/api/v1/images/i/{Uri.EscapeDataString(imageName)}/urls";
        JObject root = await GetJsonAsync(requestUri).ConfigureAwait(false);
        string imageUrl = root.Value<string>("image_url");

        // A null URL would make the later request target BaseAddress instead of the image.
        if (string.IsNullOrEmpty(imageUrl))
            throw new InvalidOperationException($"Server returned no image_url for image '{imageName}'.");
        return imageUrl;
    }

    /// <summary>
    /// Runs a full text-to-image generation described by <paramref name="arguments"/> and returns the image URL.
    /// </summary>
    /// <param name="arguments">
    /// JSON arguments: "prompt" (required) plus optional "negative_prompt", "width", "height", "steps",
    /// "cfg_scale", "scheduler", "seed", "model_key", "lora_key", "lora_weight", "vae_key".
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="arguments"/> is null.</exception>
    /// <exception cref="ArgumentException">"prompt" is missing.</exception>
    /// <exception cref="InvalidOperationException">Enqueueing or generation did not yield an image.</exception>
    private async Task<string> GenerateImageUrl(JObject arguments)
    {
        if (arguments == null) throw new ArgumentNullException(nameof(arguments));

        string prompt = arguments.Value<string>("prompt")
            ?? throw new ArgumentException("Argument 'prompt' is required.", nameof(arguments));
        string negativePrompt = arguments.Value<string>("negative_prompt") ?? "";
        int width = arguments.Value<int?>("width") ?? 512;
        int height = arguments.Value<int?>("height") ?? 512;
        int steps = arguments.Value<int?>("steps") ?? 30;
        double cfgScale = arguments.Value<double?>("cfg_scale") ?? 7.5;
        string scheduler = arguments.Value<string>("scheduler") ?? "euler";
        long? seed = arguments.Value<long?>("seed");
        string modelKey = arguments.Value<string>("model_key");
        string loraKey = arguments.Value<string>("lora_key");
        double loraWeight = arguments.Value<double?>("lora_weight") ?? 1.0;
        string vaeKey = arguments.Value<string>("vae_key");

        JObject graph = await CreateText2ImgGraph(
            prompt: prompt,
            negativePrompt: negativePrompt,
            modelKey: modelKey,
            loraKey: loraKey,
            loraWeight: loraWeight,
            vaeKey: vaeKey,
            width: width,
            height: height,
            steps: steps,
            cfgScale: cfgScale,
            scheduler: scheduler,
            seed: seed
        ).ConfigureAwait(false);

        JObject enqueueResult = await EnqueueGraph(graph).ConfigureAwait(false);
        string batchId = enqueueResult.SelectToken("batch.batch_id")?.Value<string>();
        if (string.IsNullOrEmpty(batchId))
            throw new InvalidOperationException("Enqueue response did not contain 'batch.batch_id'.");

        UnityEngine.Debug.Log($"Enqueued batch {batchId}, waiting for completion...");

        JObject completed = await WaitForCompletion(batchId).ConfigureAwait(false);

        // Use the first image_output that actually names an image.
        if (completed.SelectToken("result.outputs") is JObject outputs)
        {
            foreach (var prop in outputs.Properties())
            {
                var output = prop.Value as JObject;
                if (output?.Value<string>("type") != "image_output")
                    continue;

                string imageName = output.SelectToken("image.image_name")?.Value<string>();
                if (string.IsNullOrEmpty(imageName))
                    continue;

                return await GetImageUrl(imageName).ConfigureAwait(false);
            }
        }

        throw new InvalidOperationException("Failed to generate image: no image_output found in result.");
    }

    /// <summary>
    /// Generates a 512×512 image for <paramref name="prompt"/> using <see cref="MODEL_KEY"/> and returns the
    /// downloaded image bytes as served by InvokeAI.
    /// </summary>
    /// <exception cref="ArgumentException"><paramref name="prompt"/> is null, or no suitable model exists.</exception>
    /// <exception cref="InvalidOperationException">The client is not initialized or generation yielded no image.</exception>
    /// <exception cref="TimeoutException">Generation did not finish in time.</exception>
    /// <exception cref="HttpRequestException">An HTTP call failed.</exception>
    public async Task<byte[]> GenerateImage(string prompt)
    {
        var args = new JObject
        {
            ["prompt"] = prompt,
            ["width"] = 512,
            ["height"] = 512,
            ["model_key"] = MODEL_KEY,
        };

        UnityEngine.Debug.Log("Starting image generation...");
        string imageUrl = await GenerateImageUrl(args).ConfigureAwait(false);
        UnityEngine.Debug.Log("Image URL ready: " + imageUrl);

        using var req = new HttpRequestMessage(HttpMethod.Get, imageUrl);
        using var resp = await Client.SendAsync(req, HttpCompletionOption.ResponseHeadersRead).ConfigureAwait(false);
        resp.EnsureSuccessStatusCode();
        return await resp.Content.ReadAsByteArrayAsync().ConfigureAwait(false);
    }
}