using System;
using System.Diagnostics;
using System.Net;
using System.Net.Http;
using System.Security.Cryptography;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using UnityEngine;
using Valve.Newtonsoft.Json;
using Valve.Newtonsoft.Json.Linq;

/// <summary>
/// Minimal InvokeAI REST client for Unity. It builds a text-to-image node graph, enqueues it as a batch,
/// polls the batch status until it finishes, and downloads the resulting image bytes.
/// </summary>
public class InvokeAiClient : MonoBehaviour
{
    /// <summary>The most recently awakened instance; cleared when that instance is destroyed.</summary>
    public static InvokeAiClient Instance { get; private set; }

    /// <summary>Absolute base URL of the InvokeAI server (assigned in the inspector).</summary>
    public string INVOKEAI_BASE_URL;

    /// <summary>Queue id used both for enqueueing and for status polling.</summary>
    public string DEFAULT_QUEUE_ID = "default";

    /// <summary>Main model key used by <see cref="GenerateImage(string)"/>; when empty, the first "sd-1" main model is used.</summary>
    public string MODEL_KEY;

    /// <summary>Fields every model record must carry; they are copied verbatim into graph nodes.</summary>
    private static readonly string[] ModelIdentifierFields = { "key", "hash", "name", "base", "type" };

    private HttpClient httpClient;

    /// <summary>The live HTTP client, or a descriptive error if Awake failed or the component was destroyed.</summary>
    private HttpClient Client =>
        httpClient ?? throw new InvalidOperationException(
            "InvokeAiClient is not initialized (Awake failed or the component was destroyed).");

    private void Awake()
    {
        // Validate before allocating the client so a misconfigured inspector field fails with a clear message
        // (instead of an opaque UriFormatException/ArgumentNullException) and no HttpClient is leaked.
        if (!Uri.TryCreate(INVOKEAI_BASE_URL, UriKind.Absolute, out Uri baseUri))
        {
            throw new InvalidOperationException(
                $"{nameof(INVOKEAI_BASE_URL)} must be an absolute URL (got '{INVOKEAI_BASE_URL}').");
        }

        httpClient = new HttpClient
        {
            Timeout = TimeSpan.FromSeconds(120),
            BaseAddress = baseUri
        };
        Instance = this;
    }

    private void OnDestroy()
    {
        // The client owns pooled connections; release them together with the component.
        httpClient?.Dispose();
        httpClient = null;
        if (ReferenceEquals(Instance, this))
        {
            Instance = null;
        }
    }

    /// <summary>Sends a GET request, requires a success status code and parses the body as a JSON object.</summary>
    /// <exception cref="HttpRequestException">The server answered with a non-success status code.</exception>
    private async Task<JObject> GetJsonAsync(string requestUri, CancellationToken cancellationToken)
    {
        using var resp = await Client.GetAsync(requestUri, cancellationToken).ConfigureAwait(false);
        resp.EnsureSuccessStatusCode();
        var json = await resp.Content.ReadAsStringAsync().ConfigureAwait(false);
        return JObject.Parse(json);
    }

    /// <summary>Lists installed models of <paramref name="modelType"/>; returns an empty array when the response has no "models" array.</summary>
    private async Task<JArray> ListModels(string modelType = "main", CancellationToken cancellationToken = default)
    {
        var root = await GetJsonAsync(
            $"/api/v2/models/?model_type={Uri.EscapeDataString(modelType)}", cancellationToken).ConfigureAwait(false);
        return root["models"] as JArray ?? new JArray();
    }

    /// <summary>Fetches the record for a model key, or returns <c>null</c> when the server answers 404.</summary>
    /// <exception cref="HttpRequestException">Any other non-success status code.</exception>
    private async Task<JObject> GetModelInfo(string modelKey, CancellationToken cancellationToken = default)
    {
        var requestUri = $"/api/v2/models/i/{Uri.EscapeDataString(modelKey)}";
        using var resp = await Client.GetAsync(requestUri, cancellationToken).ConfigureAwait(false);

        // Map "unknown key" to null so callers can raise their descriptive "not found" errors.
        if (resp.StatusCode == HttpStatusCode.NotFound)
        {
            return null;
        }

        resp.EnsureSuccessStatusCode();
        var json = await resp.Content.ReadAsStringAsync().ConfigureAwait(false);
        return JObject.Parse(json);
    }

    /// <summary>
    /// Polls the batch status every <paramref name="pollIntervalMs"/> ms until every item completed.
    /// Returns a result object whose "result.outputs.save_image.image.image_name" holds the newest non-intermediate
    /// image, or the raw status object when no image could be found.
    /// </summary>
    /// <exception cref="TimeoutException">The batch did not finish within <paramref name="timeoutSeconds"/>.</exception>
    /// <exception cref="InvalidOperationException">An item of the batch failed or was canceled.</exception>
    private async Task<JObject> WaitForCompletion(
        string batchId,
        int timeoutSeconds = 300,
        int pollIntervalMs = 1000,
        CancellationToken cancellationToken = default)
    {
        var sw = Stopwatch.StartNew();
        string queueId = DEFAULT_QUEUE_ID;
        var statusUrl = $"/api/v1/queue/{Uri.EscapeDataString(queueId)}/b/{Uri.EscapeDataString(batchId)}/status";

        while (true)
        {
            cancellationToken.ThrowIfCancellationRequested();
            if (sw.Elapsed.TotalSeconds > timeoutSeconds)
            {
                throw new TimeoutException($"Image generation timed out after {timeoutSeconds} seconds");
            }

            var statusData = await GetJsonAsync(statusUrl, cancellationToken).ConfigureAwait(false);

            int failedCount = statusData.Value<int?>("failed") ?? 0;
            if (failedCount > 0)
            {
                string queueStatus = await DescribeQueueStatus(queueId, cancellationToken).ConfigureAwait(false);
                throw new InvalidOperationException(
                    $"Image generation failed. Batch {batchId} has {failedCount} failed item(s). " +
                    $"Queue status: {queueStatus}"
                );
            }

            // A canceled item never counts as completed, so without this check the loop would spin until timeout.
            int canceledCount = statusData.Value<int?>("canceled") ?? 0;
            if (canceledCount > 0)
            {
                throw new InvalidOperationException(
                    $"Image generation was canceled. Batch {batchId} has {canceledCount} canceled item(s).");
            }

            int completed = statusData.Value<int?>("completed") ?? 0;
            int total = statusData.Value<int?>("total") ?? 0;
            if (completed == total && total > 0)
            {
                string imageName = await GetLatestImageName(cancellationToken).ConfigureAwait(false);
                return imageName != null ? BuildCompletedResult(batchId, imageName) : statusData;
            }

            await Task.Delay(pollIntervalMs, cancellationToken).ConfigureAwait(false);
        }
    }

    /// <summary>Best-effort pretty-printed queue status, used only to enrich failure messages.</summary>
    private async Task<string> DescribeQueueStatus(string queueId, CancellationToken cancellationToken)
    {
        try
        {
            var queueData = await GetJsonAsync(
                $"/api/v1/queue/{Uri.EscapeDataString(queueId)}/status", cancellationToken).ConfigureAwait(false);
            return queueData.ToString(Formatting.Indented);
        }
        catch (Exception ex) when (ex is HttpRequestException || ex is JsonException)
        {
            // Diagnostics must not mask the generation failure we are about to report.
            return $"<unavailable: {ex.Message}>";
        }
    }

    /// <summary>Name of the newest non-intermediate image on the server, or <c>null</c> if the list is empty.</summary>
    private async Task<string> GetLatestImageName(CancellationToken cancellationToken)
    {
        // NOTE(review): this is the newest image on the whole server, not one tied to our batch, so a generation
        // finishing concurrently (another client/tab) can race with ours. starred_first=false asks servers that
        // support starring not to sort starred images ahead of newer ones; verify against the server version in use.
        const string imagesPath = "/api/v1/images/?is_intermediate=false&starred_first=false&limit=1";
        var imagesData = await GetJsonAsync(imagesPath, cancellationToken).ConfigureAwait(false);
        return imagesData["items"] is JArray items && items.Count > 0
            ? items[0].Value<string>("image_name")
            : null;
    }

    /// <summary>Wraps an image name in the result shape consumed by <see cref="GenerateImageUrl"/>.</summary>
    private static JObject BuildCompletedResult(string batchId, string imageName) => new JObject
    {
        ["batch_id"] = batchId,
        ["status"] = "completed",
        ["result"] = new JObject
        {
            ["outputs"] = new JObject
            {
                ["save_image"] = new JObject
                {
                    ["type"] = "image_output",
                    ["image"] = new JObject { ["image_name"] = imageName }
                }
            }
        }
    };

    /// <summary>Enqueues <paramref name="graph"/> as a single-run batch on the default queue and returns the parsed response.</summary>
    private async Task<JObject> EnqueueGraph(JToken graph, CancellationToken cancellationToken = default)
    {
        string queueId = DEFAULT_QUEUE_ID;
        var payload = new JObject
        {
            ["batch"] = new JObject
            {
                ["graph"] = graph,
                ["runs"] = 1,
                ["data"] = JValue.CreateNull()
            }
        };

        var url = $"/api/v1/queue/{Uri.EscapeDataString(queueId)}/enqueue_batch";
        using var content = new StringContent(payload.ToString(Formatting.None), Encoding.UTF8, "application/json");
        using var resp = await Client.PostAsync(url, content, cancellationToken).ConfigureAwait(false);
        resp.EnsureSuccessStatusCode();
        var json = await resp.Content.ReadAsStringAsync().ConfigureAwait(false);
        return JObject.Parse(json);
    }

    /// <summary>Throws <see cref="ArgumentException"/> if any of <paramref name="fields"/> is missing or JSON null.</summary>
    private static void RequireFields(JObject info, string nameOrKey, params string[] fields)
    {
        foreach (var f in fields)
        {
            var v = info[f];
            if (v == null || v.Type == JTokenType.Null)
            {
                throw new ArgumentException($"Model {nameOrKey} is missing required field: {f}");
            }
        }
    }

    /// <summary>Copies the identifying fields of a model record into a new object for a loader node.</summary>
    private static JObject ModelIdentifier(JObject info) => new JObject
    {
        ["key"] = info.Value<string>("key"),
        ["hash"] = info.Value<string>("hash"),
        ["name"] = info.Value<string>("name"),
        ["base"] = info.Value<string>("base"),
        ["type"] = info.Value<string>("type")
    };

    /// <summary>Cryptographically random seed in 0..4294967295, widened to long.</summary>
    private static long GenerateUInt32Seed()
    {
        Span<byte> bytes = stackalloc byte[4];
        RandomNumberGenerator.Fill(bytes);
        uint u = BitConverter.ToUInt32(bytes);
        return u; // unsigned source keeps the full 0..2^32-1 range
    }

    /// <summary>Builds a graph edge from <c>srcNode.srcField</c> to <c>dstNode.dstField</c>.</summary>
    private static JObject Edge(string srcNode, string srcField, string dstNode, string dstField) => new JObject
    {
        ["source"] = new JObject { ["node_id"] = srcNode, ["field"] = srcField },
        ["destination"] = new JObject { ["node_id"] = dstNode, ["field"] = dstField }
    };

    /// <summary>
    /// Builds a text-to-image graph: model loader, optional LoRA and VAE loaders, positive/negative prompts,
    /// noise, denoise, latents-to-image and save-image nodes plus the edges connecting them.
    /// </summary>
    /// <exception cref="ArgumentException">No default model exists, a referenced model is unknown, or a record lacks required fields.</exception>
    private async Task<JObject> CreateText2ImgGraph(
        string prompt,
        string negativePrompt = "",
        string modelKey = null,
        string loraKey = null,
        double loraWeight = 1.0,
        string vaeKey = null,
        int width = 512,
        int height = 512,
        int steps = 30,
        double cfgScale = 7.5,
        string scheduler = "euler",
        long? seed = null,
        CancellationToken cancellationToken = default)
    {
        // 1) Default model: the first "main" model whose base is "sd-1".
        if (string.IsNullOrEmpty(modelKey))
        {
            var models = await ListModels("main", cancellationToken).ConfigureAwait(false);
            foreach (var token in models)
            {
                if (token is JObject m &&
                    string.Equals(m.Value<string>("base"), "sd-1", StringComparison.OrdinalIgnoreCase))
                {
                    modelKey = m.Value<string>("key");
                    break;
                }
            }

            if (string.IsNullOrEmpty(modelKey))
            {
                throw new ArgumentException("No suitable model found (sd-1)", nameof(modelKey));
            }
        }

        // 2) + 3) Fetch and validate the main model record.
        var modelInfo = await GetModelInfo(modelKey, cancellationToken).ConfigureAwait(false)
            ?? throw new ArgumentException($"Model {modelKey} not found", nameof(modelKey));
        RequireFields(modelInfo, modelKey, ModelIdentifierFields);

        // 4) Random 32-bit seed if none was supplied.
        seed ??= GenerateUInt32Seed();

        // 5) SDXL models need different loader/prompt node types.
        bool isSdxl = string.Equals(modelInfo.Value<string>("base"), "sdxl", StringComparison.OrdinalIgnoreCase);
        string promptNodeType = isSdxl ? "sdxl_compel_prompt" : "compel";

        // 6) Core nodes.
        var nodes = new JObject
        {
            ["model_loader"] = new JObject
            {
                ["type"] = isSdxl ? "sdxl_model_loader" : "main_model_loader",
                ["id"] = "model_loader",
                ["model"] = ModelIdentifier(modelInfo)
            },
            ["positive_prompt"] = new JObject
            {
                ["type"] = promptNodeType,
                ["id"] = "positive_prompt",
                ["prompt"] = prompt
            },
            ["negative_prompt"] = new JObject
            {
                ["type"] = promptNodeType,
                ["id"] = "negative_prompt",
                ["prompt"] = negativePrompt
            },
            ["noise"] = new JObject
            {
                ["type"] = "noise",
                ["id"] = "noise",
                ["seed"] = seed,
                ["width"] = width,
                ["height"] = height,
                ["use_cpu"] = false
            },
            ["denoise"] = new JObject
            {
                ["type"] = "denoise_latents",
                ["id"] = "denoise",
                ["steps"] = steps,
                ["cfg_scale"] = cfgScale,
                ["scheduler"] = scheduler,
                ["denoising_start"] = 0,
                ["denoising_end"] = 1
            },
            ["latents_to_image"] = new JObject
            {
                ["type"] = "l2i",
                ["id"] = "latents_to_image"
            },
            ["save_image"] = new JObject
            {
                ["type"] = "save_image",
                ["id"] = "save_image",
                ["is_intermediate"] = false
            }
        };

        // SDXL prompt nodes additionally take a "style" field.
        if (isSdxl)
        {
            ((JObject)nodes["positive_prompt"])["style"] = prompt;
            ((JObject)nodes["negative_prompt"])["style"] = "";
        }

        // 7) Optional LoRA.
        bool hasLora = !string.IsNullOrEmpty(loraKey);
        if (hasLora)
        {
            var loraInfo = await GetModelInfo(loraKey, cancellationToken).ConfigureAwait(false)
                ?? throw new ArgumentException($"LoRA model {loraKey} not found", nameof(loraKey));
            RequireFields(loraInfo, loraKey, ModelIdentifierFields);

            nodes["lora_loader"] = new JObject
            {
                ["type"] = "lora_loader",
                ["id"] = "lora_loader",
                ["lora"] = ModelIdentifier(loraInfo),
                ["weight"] = loraWeight
            };
        }

        // 8) Optional VAE override.
        bool hasVae = !string.IsNullOrEmpty(vaeKey);
        if (hasVae)
        {
            var vaeInfo = await GetModelInfo(vaeKey, cancellationToken).ConfigureAwait(false)
                ?? throw new ArgumentException($"VAE model {vaeKey} not found", nameof(vaeKey));
            RequireFields(vaeInfo, vaeKey, ModelIdentifierFields);

            nodes["vae_loader"] = new JObject
            {
                ["type"] = "vae_loader",
                ["id"] = "vae_loader",
                ["vae_model"] = ModelIdentifier(vaeInfo)
            };
        }

        // Edges: route UNet/CLIP through the LoRA loader when present, VAE through the VAE loader when present.
        var edges = new JArray();
        string unetSource = hasLora ? "lora_loader" : "model_loader";
        string clipSource = hasLora ? "lora_loader" : "model_loader";
        string vaeSource = hasVae ? "vae_loader" : "model_loader";

        if (hasLora)
        {
            edges.Add(Edge("model_loader", "unet", "lora_loader", "unet"));
            edges.Add(Edge("model_loader", "clip", "lora_loader", "clip"));
            // lora_loader has no clip2 input; SDXL clip2 is wired straight from model_loader below.
            // NOTE(review): so for SDXL the LoRA is not applied to clip2 — confirm whether an SDXL-specific
            // LoRA loader node should be used instead.
        }

        edges.Add(Edge(unetSource, "unet", "denoise", "unet"));
        edges.Add(Edge(clipSource, "clip", "positive_prompt", "clip"));
        edges.Add(Edge(clipSource, "clip", "negative_prompt", "clip"));

        if (isSdxl)
        {
            edges.Add(Edge("model_loader", "clip2", "positive_prompt", "clip2"));
            edges.Add(Edge("model_loader", "clip2", "negative_prompt", "clip2"));
        }

        edges.Add(Edge("positive_prompt", "conditioning", "denoise", "positive_conditioning"));
        edges.Add(Edge("negative_prompt", "conditioning", "denoise", "negative_conditioning"));
        edges.Add(Edge("noise", "noise", "denoise", "noise"));
        edges.Add(Edge("denoise", "latents", "latents_to_image", "latents"));
        edges.Add(Edge(vaeSource, "vae", "latents_to_image", "vae"));
        edges.Add(Edge("latents_to_image", "image", "save_image", "image"));

        return new JObject
        {
            ["id"] = "text2img_graph",
            ["nodes"] = nodes,
            ["edges"] = edges
        };
    }

    /// <summary>Resolves the download URL (as reported by the server's "image_url" field) for an image name.</summary>
    /// <exception cref="InvalidOperationException">The response carried no "image_url".</exception>
    private async Task<string> GetImageUrl(string imageName, CancellationToken cancellationToken = default)
    {
        var root = await GetJsonAsync(
            $"/api/v1/images/i/{Uri.EscapeDataString(imageName)}/urls", cancellationToken).ConfigureAwait(false);
        var url = root.Value<string>("image_url");

        // A null/empty URL would make the later download silently fetch the client's BaseAddress instead.
        if (string.IsNullOrEmpty(url))
        {
            throw new InvalidOperationException($"Server returned no 'image_url' for image {imageName}.");
        }

        return url;
    }

    /// <summary>
    /// Runs a full text-to-image generation from a loose argument object ("prompt" is required; other keys
    /// fall back to the defaults below) and returns the generated image's URL.
    /// </summary>
    private async Task<string> GenerateImageUrl(JObject arguments, CancellationToken cancellationToken = default)
    {
        if (arguments == null)
            throw new ArgumentNullException(nameof(arguments));

        // --- Extract parameters (with defaults) ---
        string prompt = arguments.Value<string>("prompt")
            ?? throw new ArgumentException("Argument 'prompt' is required.", nameof(arguments));
        string negativePrompt = arguments.Value<string>("negative_prompt") ?? "";
        int width = arguments.Value<int?>("width") ?? 512;
        int height = arguments.Value<int?>("height") ?? 512;
        int steps = arguments.Value<int?>("steps") ?? 30;
        double cfgScale = arguments.Value<double?>("cfg_scale") ?? 7.5;
        string scheduler = arguments.Value<string>("scheduler") ?? "euler";
        long? seed = arguments.Value<long?>("seed");
        string modelKey = arguments.Value<string>("model_key");
        string loraKey = arguments.Value<string>("lora_key");
        double loraWeight = arguments.Value<double?>("lora_weight") ?? 1.0;
        string vaeKey = arguments.Value<string>("vae_key");

        // --- Create graph ---
        JObject graph = await CreateText2ImgGraph(
            prompt: prompt,
            negativePrompt: negativePrompt,
            modelKey: modelKey,
            loraKey: loraKey,
            loraWeight: loraWeight,
            vaeKey: vaeKey,
            width: width,
            height: height,
            steps: steps,
            cfgScale: cfgScale,
            scheduler: scheduler,
            seed: seed,
            cancellationToken: cancellationToken
        ).ConfigureAwait(false);

        // --- Enqueue ---
        JObject enqueueResult = await EnqueueGraph(graph, cancellationToken).ConfigureAwait(false);
        string batchId = enqueueResult.SelectToken("batch.batch_id")?.Value<string>();
        if (string.IsNullOrEmpty(batchId))
            throw new InvalidOperationException("Enqueue response did not contain 'batch.batch_id'.");

        UnityEngine.Debug.Log($"Enqueued batch {batchId}, waiting for completion...");

        // --- Wait for completion ---
        JObject completed = await WaitForCompletion(batchId, cancellationToken: cancellationToken).ConfigureAwait(false);

        // --- Extract image output ---
        if (completed.SelectToken("result.outputs") is JObject outputs)
        {
            foreach (var prop in outputs.Properties())
            {
                if (prop.Value is JObject output && output.Value<string>("type") == "image_output")
                {
                    string imageName = output.SelectToken("image.image_name")?.Value<string>();
                    if (string.IsNullOrEmpty(imageName))
                        continue;

                    return await GetImageUrl(imageName, cancellationToken).ConfigureAwait(false);
                }
            }
        }

        throw new InvalidOperationException("Failed to generate image: no image_output found in result.");
    }

    /// <summary>Generates a 512x512 image for <paramref name="prompt"/> using <see cref="MODEL_KEY"/> and returns its encoded bytes.</summary>
    public Task<byte[]> GenerateImage(string prompt) => GenerateImage(prompt, CancellationToken.None);

    /// <summary>
    /// Generates a 512x512 image for <paramref name="prompt"/> using <see cref="MODEL_KEY"/> and returns its encoded bytes.
    /// <paramref name="cancellationToken"/> aborts HTTP calls and polling.
    /// </summary>
    public async Task<byte[]> GenerateImage(string prompt, CancellationToken cancellationToken)
    {
        var args = new JObject
        {
            ["prompt"] = prompt,
            ["width"] = 512,
            ["height"] = 512,
            ["model_key"] = MODEL_KEY,
        };

        UnityEngine.Debug.Log("Starting image generation...");
        // No ConfigureAwait(false) here: this is the public entry point, so continuations return to the caller's
        // (typically Unity main-thread) context.
        string imageUrl = await GenerateImageUrl(args, cancellationToken);
        UnityEngine.Debug.Log("Image URL ready: " + imageUrl);

        using var req = new HttpRequestMessage(HttpMethod.Get, imageUrl);
        using var resp = await Client.SendAsync(req, HttpCompletionOption.ResponseHeadersRead, cancellationToken);
        resp.EnsureSuccessStatusCode();
        return await resp.Content.ReadAsByteArrayAsync();
    }
}