From 1b3b3db1bfb3a15b6301d7ccc3ce0fd44ce0a335 Mon Sep 17 00:00:00 2001 From: henrisel Date: Mon, 22 Dec 2025 13:19:28 +0200 Subject: [PATCH] port InvokeAI API client to Unity, use it in ImageGenerationBox --- .../generate_image_local.py | 135 ++--- .../ModelGeneration/ImageGenerationBox.prefab | 303 ++++++++++ .../ImageGenerationBox.prefab.meta | 7 + .../VoiceTranscriptionBox.prefab | 11 +- .../_PROJECT/Scenes/DeltaBuilding_base.unity | 4 +- .../ModeGeneration/ImageGenerationBox.cs | 75 +++ .../ModeGeneration/ImageGenerationBox.cs.meta | 11 + .../Scripts/ModeGeneration/InvokeAiClient.cs | 542 ++++++++++++++++++ .../ModeGeneration/InvokeAiClient.cs.meta | 11 + Packages/manifest.json | 1 + Packages/packages-lock.json | 7 + 11 files changed, 1033 insertions(+), 74 deletions(-) create mode 100644 Assets/_PROJECT/Prefabs/ModelGeneration/ImageGenerationBox.prefab create mode 100644 Assets/_PROJECT/Prefabs/ModelGeneration/ImageGenerationBox.prefab.meta create mode 100644 Assets/_PROJECT/Scripts/ModeGeneration/ImageGenerationBox.cs create mode 100644 Assets/_PROJECT/Scripts/ModeGeneration/ImageGenerationBox.cs.meta create mode 100644 Assets/_PROJECT/Scripts/ModeGeneration/InvokeAiClient.cs create mode 100644 Assets/_PROJECT/Scripts/ModeGeneration/InvokeAiClient.cs.meta diff --git a/3d-generation-pipeline/generate_image_local.py b/3d-generation-pipeline/generate_image_local.py index 18c31226..22aafcfb 100644 --- a/3d-generation-pipeline/generate_image_local.py +++ b/3d-generation-pipeline/generate_image_local.py @@ -34,18 +34,6 @@ def get_client() -> httpx.AsyncClient: return http_client -async def text_to_image_invoke_ai(prompt, output_path): - # see available model keys via GET http://INVOKEAI_BASE_URL:9090/api/v2/models/?model_type=main - args = { - "prompt": prompt, - "width": 512, - "height": 512, - "model_key": INVOKEAI_MODEL_KEY - } - image_url = await generate_image(args) - print("got image url: ", image_url) - download_file(image_url, output_path) - async def wait_for_completion(batch_id: str, queue_id: str = DEFAULT_QUEUE_ID, timeout: int = 300) -> dict: """Wait for a batch to complete and return the most recent image.""" client = get_client() @@ -107,60 +95,7 @@ async def wait_for_completion(batch_id: str, queue_id: str = DEFAULT_QUEUE_ID, t # Wait before checking again await asyncio.sleep(1) -async def generate_image(arguments: dict): - # Extract parameters - prompt = arguments["prompt"] - negative_prompt = arguments.get("negative_prompt", "") - width = arguments.get("width", 512) - height = arguments.get("height", 512) - steps = arguments.get("steps", 30) - cfg_scale = arguments.get("cfg_scale", 7.5) - scheduler = arguments.get("scheduler", "euler") - seed = arguments.get("seed") - model_key = arguments.get("model_key") - lora_key = arguments.get("lora_key") - lora_weight = arguments.get("lora_weight", 1.0) - vae_key = arguments.get("vae_key") - - print(f"Generating image with prompt: {prompt[:50]}...") - - # Create graph - graph = await create_text2img_graph( - prompt=prompt, - negative_prompt=negative_prompt, - model_key=model_key, - lora_key=lora_key, - lora_weight=lora_weight, - vae_key=vae_key, - width=width, - height=height, - steps=steps, - cfg_scale=cfg_scale, - scheduler=scheduler, - seed=seed - ) - - # Enqueue and wait for completion - result = await enqueue_graph(graph) - batch_id = result["batch"]["batch_id"] - - print(f"Enqueued batch {batch_id}, waiting for completion...") - - completed = await wait_for_completion(batch_id) - - # Extract image name from result - if "result" in completed and "outputs" in completed["result"]: - outputs = completed["result"]["outputs"] - # Find the image output - for node_id, output in outputs.items(): - if output.get("type") == "image_output": - image_name = output["image"]["image_name"] - image_url = await get_image_url(image_name) - - return urljoin(INVOKEAI_BASE_URL, image_url) - - raise RuntimeError("Failed to generate image!") def download_file(url, filepath): response = requests.get(url) @@ -486,4 +421,72 @@ async def create_text2img_graph( "edges": edges } - return graph \ No newline at end of file + return graph + +async def generate_image(arguments: dict): + + # Extract parameters + prompt = arguments["prompt"] + negative_prompt = arguments.get("negative_prompt", "") + width = arguments.get("width", 512) + height = arguments.get("height", 512) + steps = arguments.get("steps", 30) + cfg_scale = arguments.get("cfg_scale", 7.5) + scheduler = arguments.get("scheduler", "euler") + seed = arguments.get("seed") + model_key = arguments.get("model_key") + lora_key = arguments.get("lora_key") + lora_weight = arguments.get("lora_weight", 1.0) + vae_key = arguments.get("vae_key") + + print(f"Generating image with prompt: {prompt[:50]}...") + + # Create graph + graph = await create_text2img_graph( + prompt=prompt, + negative_prompt=negative_prompt, + model_key=model_key, + lora_key=lora_key, + lora_weight=lora_weight, + vae_key=vae_key, + width=width, + height=height, + steps=steps, + cfg_scale=cfg_scale, + scheduler=scheduler, + seed=seed + ) + + # Enqueue and wait for completion + result = await enqueue_graph(graph) + batch_id = result["batch"]["batch_id"] + + print(f"Enqueued batch {batch_id}, waiting for completion...") + + completed = await wait_for_completion(batch_id) + + # Extract image name from result + if "result" in completed and "outputs" in completed["result"]: + outputs = completed["result"]["outputs"] + # Find the image output + for node_id, output in outputs.items(): + if output.get("type") == "image_output": + image_name = output["image"]["image_name"] + image_url = await get_image_url(image_name) + + return urljoin(INVOKEAI_BASE_URL, image_url) + + raise RuntimeError("Failed to generate image!") + + +async def text_to_image_invoke_ai(prompt, output_path): + # see available model keys via GET http://INVOKEAI_BASE_URL:9090/api/v2/models/?model_type=main + args = { + "prompt": prompt, + "width": 512, + "height": 512, + "model_key": INVOKEAI_MODEL_KEY + } + image_url = await generate_image(args) + print("got image url: ", image_url) + download_file(image_url, output_path) diff --git a/Assets/_PROJECT/Prefabs/ModelGeneration/ImageGenerationBox.prefab b/Assets/_PROJECT/Prefabs/ModelGeneration/ImageGenerationBox.prefab new file mode 100644 index 00000000..18dce247 --- /dev/null +++ b/Assets/_PROJECT/Prefabs/ModelGeneration/ImageGenerationBox.prefab @@ -0,0 +1,303 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!1 &2138134584281388958 +GameObject: + m_ObjectHideFlags: 0 + m_CorrespondingSourceObject: {fileID: 0} + m_PrefabInstance: {fileID: 0} + m_PrefabAsset: {fileID: 0} + serializedVersion: 6 + m_Component: + - component: {fileID: 2929509286471910883} + - component: {fileID: 4538985267174058603} + - component: {fileID: 6899506685386251279} + m_Layer: 0 + m_Name: Image + m_TagString: Untagged + m_Icon: {fileID: 0} + m_NavMeshLayer: 0 + m_StaticEditorFlags: 0 + m_IsActive: 1 +--- !u!224 &2929509286471910883 +RectTransform: + m_ObjectHideFlags: 0 + m_CorrespondingSourceObject: {fileID: 0} + m_PrefabInstance: {fileID: 0} + m_PrefabAsset: {fileID: 0} + m_GameObject: {fileID: 2138134584281388958} + m_LocalRotation: {x: 0, y: 0, z: 0, w: 1} + m_LocalPosition: {x: 0, y: 0, z: 0} + m_LocalScale: {x: 1, y: 1, z: 1} + m_ConstrainProportionsScale: 0 + m_Children: [] + m_Father: {fileID: 8725989738242400994} + m_RootOrder: -1 + m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0} + m_AnchorMin: {x: 0.5, y: 0.5} + m_AnchorMax: {x: 0.5, y: 0.5} + m_AnchoredPosition: {x: 0, y: 0} + m_SizeDelta: {x: 100, y: 100} + m_Pivot: {x: 0.5, y: 0.5} +--- !u!222 &4538985267174058603 +CanvasRenderer: + m_ObjectHideFlags: 0 + m_CorrespondingSourceObject: {fileID: 0} + m_PrefabInstance: {fileID: 0} + m_PrefabAsset: {fileID: 0} + m_GameObject: {fileID: 2138134584281388958} + m_CullTransparentMesh: 1 +--- !u!114 &6899506685386251279 +MonoBehaviour: + m_ObjectHideFlags: 0 + m_CorrespondingSourceObject: {fileID: 0} + m_PrefabInstance: {fileID: 0} + m_PrefabAsset: {fileID: 0} + m_GameObject: {fileID: 2138134584281388958} + m_Enabled: 1 + m_EditorHideFlags: 0 + m_Script: {fileID: 11500000, guid: fe87c0e1cc204ed48ad3b37840f39efc, type: 3} + m_Name: + m_EditorClassIdentifier: + m_Material: {fileID: 0} + m_Color: {r: 1, g: 1, b: 1, a: 1} + m_RaycastTarget: 1 + m_RaycastPadding: {x: 0, y: 0, z: 0, w: 0} + m_Maskable: 1 + m_OnCullStateChanged: + m_PersistentCalls: + m_Calls: [] + m_Sprite: {fileID: 0} + m_Type: 0 + m_PreserveAspect: 0 + m_FillCenter: 1 + m_FillMethod: 4 + m_FillAmount: 1 + m_FillClockwise: 1 + m_FillOrigin: 0 + m_UseSpriteMesh: 0 + m_PixelsPerUnitMultiplier: 1 +--- !u!1 &6494138504010530631 +GameObject: + m_ObjectHideFlags: 0 + m_CorrespondingSourceObject: {fileID: 0} + m_PrefabInstance: {fileID: 0} + m_PrefabAsset: {fileID: 0} + serializedVersion: 6 + m_Component: + - component: {fileID: 8725989738242400994} + - component: {fileID: 6997284789583139794} + - component: {fileID: 6770534627625930609} + - component: {fileID: 21121296870448173} + m_Layer: 0 + m_Name: Canvas + m_TagString: Untagged + m_Icon: {fileID: 0} + m_NavMeshLayer: 0 + m_StaticEditorFlags: 0 + m_IsActive: 1 +--- !u!224 &8725989738242400994 +RectTransform: + m_ObjectHideFlags: 0 + m_CorrespondingSourceObject: {fileID: 0} + m_PrefabInstance: {fileID: 0} + m_PrefabAsset: {fileID: 0} + m_GameObject: {fileID: 6494138504010530631} + m_LocalRotation: {x: -0, y: -1, z: -0, w: -0.00000035762784} + m_LocalPosition: {x: 0, y: 0, z: 0.5} + m_LocalScale: {x: 0.01, y: 0.01, z: 0.01} + m_ConstrainProportionsScale: 0 + m_Children: + - {fileID: 2929509286471910883} + m_Father: {fileID: 1000498446801613149} + m_RootOrder: -1 + m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0} + m_AnchorMin: {x: 0, y: 0} + m_AnchorMax: {x: 0, y: 0} + m_AnchoredPosition: {x: 0, y: 1.2} + m_SizeDelta: {x: 100, y: 100} + m_Pivot: {x: 0.5, y: 0.5} +--- !u!223 &6997284789583139794 +Canvas: + m_ObjectHideFlags: 0 + m_CorrespondingSourceObject: {fileID: 0} + m_PrefabInstance: {fileID: 0} + m_PrefabAsset: {fileID: 0} + m_GameObject: {fileID: 6494138504010530631} + m_Enabled: 1 + serializedVersion: 3 + m_RenderMode: 2 + m_Camera: {fileID: 0} + m_PlaneDistance: 100 + m_PixelPerfect: 0 + m_ReceivesEvents: 1 + m_OverrideSorting: 0 + m_OverridePixelPerfect: 0 + m_SortingBucketNormalizedSize: 0 + m_AdditionalShaderChannelsFlag: 0 + m_UpdateRectTransformForStandalone: 0 + m_SortingLayerID: 0 + m_SortingOrder: 0 + m_TargetDisplay: 0 +--- !u!114 &6770534627625930609 +MonoBehaviour: + m_ObjectHideFlags: 0 + m_CorrespondingSourceObject: {fileID: 0} + m_PrefabInstance: {fileID: 0} + m_PrefabAsset: {fileID: 0} + m_GameObject: {fileID: 6494138504010530631} + m_Enabled: 1 + m_EditorHideFlags: 0 + m_Script: {fileID: 11500000, guid: 0cd44c1031e13a943bb63640046fad76, type: 3} + m_Name: + m_EditorClassIdentifier: + m_UiScaleMode: 0 + m_ReferencePixelsPerUnit: 100 + m_ScaleFactor: 1 + m_ReferenceResolution: {x: 800, y: 600} + m_ScreenMatchMode: 0 + m_MatchWidthOrHeight: 0 + m_PhysicalUnit: 3 + m_FallbackScreenDPI: 96 + m_DefaultSpriteDPI: 96 + m_DynamicPixelsPerUnit: 1 + m_PresetInfoIsWorld: 1 +--- !u!114 &21121296870448173 +MonoBehaviour: + m_ObjectHideFlags: 0 + m_CorrespondingSourceObject: {fileID: 0} + m_PrefabInstance: {fileID: 0} + m_PrefabAsset: {fileID: 0} + m_GameObject: {fileID: 6494138504010530631} + m_Enabled: 1 + m_EditorHideFlags: 0 + m_Script: {fileID: 11500000, guid: dc42784cf147c0c48a680349fa168899, type: 3} + m_Name: + m_EditorClassIdentifier: + m_IgnoreReversedGraphics: 1 + m_BlockingObjects: 0 + m_BlockingMask: + serializedVersion: 2 + m_Bits: 4294967295 +--- !u!1 &8617702063501079407 +GameObject: + m_ObjectHideFlags: 0 + m_CorrespondingSourceObject: {fileID: 0} + m_PrefabInstance: {fileID: 0} + m_PrefabAsset: {fileID: 0} + serializedVersion: 6 + m_Component: + - component: {fileID: 1000498446801613149} + - component: {fileID: 2692232214199165587} + - component: {fileID: 3054822165453666587} + - component: {fileID: 6212693736535064192} + - component: {fileID: 643945743491794782} + m_Layer: 0 + m_Name: ImageGenerationBox + m_TagString: Untagged + m_Icon: {fileID: 0} + m_NavMeshLayer: 0 + m_StaticEditorFlags: 0 + m_IsActive: 1 +--- !u!4 &1000498446801613149 +Transform: + m_ObjectHideFlags: 0 + m_CorrespondingSourceObject: {fileID: 0} + m_PrefabInstance: {fileID: 0} + m_PrefabAsset: {fileID: 0} + m_GameObject: {fileID: 8617702063501079407} + m_LocalRotation: {x: -0, y: 1, z: -0, w: -0.00000035762784} + m_LocalPosition: {x: -77.521, y: 5.092, z: -13.493} + m_LocalScale: {x: 0.75, y: 0.75, z: 0.75} + m_ConstrainProportionsScale: 1 + m_Children: + - {fileID: 8725989738242400994} + m_Father: {fileID: 0} + m_RootOrder: 0 + m_LocalEulerAnglesHint: {x: 0, y: 180, z: 0} +--- !u!33 &2692232214199165587 +MeshFilter: + m_ObjectHideFlags: 0 + m_CorrespondingSourceObject: {fileID: 0} + m_PrefabInstance: {fileID: 0} + m_PrefabAsset: {fileID: 0} + m_GameObject: {fileID: 8617702063501079407} + m_Mesh: {fileID: 10202, guid: 0000000000000000e000000000000000, type: 0} +--- !u!23 &3054822165453666587 +MeshRenderer: + m_ObjectHideFlags: 0 + m_CorrespondingSourceObject: {fileID: 0} + m_PrefabInstance: {fileID: 0} + m_PrefabAsset: {fileID: 0} + m_GameObject: {fileID: 8617702063501079407} + m_Enabled: 1 + m_CastShadows: 1 + m_ReceiveShadows: 1 + m_DynamicOccludee: 1 + m_StaticShadowCaster: 0 + m_MotionVectors: 1 + m_LightProbeUsage: 1 + m_ReflectionProbeUsage: 1 + m_RayTracingMode: 2 + m_RayTraceProcedural: 0 + m_RenderingLayerMask: 1 + m_RendererPriority: 0 + m_Materials: + - {fileID: 2100000, guid: 707a698b0ec80454a8c68700bca72941, type: 2} + m_StaticBatchInfo: + firstSubMesh: 0 + subMeshCount: 0 + m_StaticBatchRoot: {fileID: 0} + m_ProbeAnchor: {fileID: 0} + m_LightProbeVolumeOverride: {fileID: 0} + m_ScaleInLightmap: 1 + m_ReceiveGI: 1 + m_PreserveUVs: 0 + m_IgnoreNormalsForChartDetection: 0 + m_ImportantGI: 0 + m_StitchLightmapSeams: 1 + m_SelectedEditorRenderState: 3 + m_MinimumChartSize: 4 + m_AutoUVMaxDistance: 0.5 + m_AutoUVMaxAngle: 89 + m_LightmapParameters: {fileID: 0} + m_SortingLayerID: 0 + m_SortingLayer: 0 + m_SortingOrder: 0 + m_AdditionalVertexStreams: {fileID: 0} +--- !u!65 &6212693736535064192 +BoxCollider: + m_ObjectHideFlags: 0 + m_CorrespondingSourceObject: {fileID: 0} + m_PrefabInstance: {fileID: 0} + m_PrefabAsset: {fileID: 0} + m_GameObject: {fileID: 8617702063501079407} + m_Material: {fileID: 0} + m_IncludeLayers: + serializedVersion: 2 + m_Bits: 0 + m_ExcludeLayers: + serializedVersion: 2 + m_Bits: 0 + m_LayerOverridePriority: 0 + m_IsTrigger: 1 + m_ProvidesContacts: 0 + m_Enabled: 1 + serializedVersion: 3 + m_Size: {x: 1, y: 1, z: 1} + m_Center: {x: 0, y: 0, z: 0} +--- !u!114 &643945743491794782 +MonoBehaviour: + m_ObjectHideFlags: 0 + m_CorrespondingSourceObject: {fileID: 0} + m_PrefabInstance: {fileID: 0} + m_PrefabAsset: {fileID: 0} + m_GameObject: {fileID: 8617702063501079407} + m_Enabled: 1 + m_EditorHideFlags: 0 + m_Script: {fileID: 11500000, guid: ea7eedaa608bac7449ba7c5a36697607, type: 3} + m_Name: + m_EditorClassIdentifier: + inactiveMaterial: {fileID: 2100000, guid: 707a698b0ec80454a8c68700bca72941, type: 2} + loadingMaterial: {fileID: 2100000, guid: 33390c6f2eb32df47809c60975868a0c, type: 2} + voiceTranscriptionTestBox: {fileID: 0} + UIImage: {fileID: 6899506685386251279} diff --git a/Assets/_PROJECT/Prefabs/ModelGeneration/ImageGenerationBox.prefab.meta b/Assets/_PROJECT/Prefabs/ModelGeneration/ImageGenerationBox.prefab.meta new file mode 100644 index 00000000..a77af7e8 --- /dev/null +++ b/Assets/_PROJECT/Prefabs/ModelGeneration/ImageGenerationBox.prefab.meta @@ -0,0 +1,7 @@ +fileFormatVersion: 2 +guid: bb349299ccb9f2046b015f7b15478f54 +PrefabImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/Assets/_PROJECT/Prefabs/ModelGeneration/VoiceTranscriptionBox.prefab b/Assets/_PROJECT/Prefabs/ModelGeneration/VoiceTranscriptionBox.prefab index fa5c4ccd..906474ca 100644 --- a/Assets/_PROJECT/Prefabs/ModelGeneration/VoiceTranscriptionBox.prefab +++ b/Assets/_PROJECT/Prefabs/ModelGeneration/VoiceTranscriptionBox.prefab @@ -27,7 +27,7 @@ RectTransform: m_PrefabAsset: {fileID: 0} m_GameObject: {fileID: 669736891457552810} m_LocalRotation: {x: -0, y: -0, z: -0, w: 1} - m_LocalPosition: {x: 0, y: 0, z: 0.857} + m_LocalPosition: {x: 0, y: 0, z: 0.5} m_LocalScale: {x: 0.01, y: 0.01, z: 0.01} m_ConstrainProportionsScale: 0 m_Children: @@ -37,8 +37,8 @@ RectTransform: m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0} m_AnchorMin: {x: 0, y: 0} m_AnchorMax: {x: 0, y: 0} - m_AnchoredPosition: {x: 0, y: 1.204} - m_SizeDelta: {x: 400, y: 100} + m_AnchoredPosition: {x: 0, y: 1.2} + m_SizeDelta: {x: 250, y: 100} m_Pivot: {x: 0.5, y: 0.5} --- !u!223 &6879637936960607693 Canvas: @@ -214,7 +214,7 @@ Transform: m_Children: - {fileID: 4986844661789441171} m_Father: {fileID: 0} - m_RootOrder: 36 + m_RootOrder: 0 m_LocalEulerAnglesHint: {x: 0, y: 180, z: 0} --- !u!33 &3449909605412981856 MeshFilter: @@ -334,7 +334,6 @@ MonoBehaviour: whisper: {fileID: 0} microphoneRecord: {fileID: 4391541691227486968} outputText: {fileID: 4513192310212875305} - currentTextOutput: --- !u!1 &5819114791296431922 GameObject: m_ObjectHideFlags: 0 @@ -371,7 +370,7 @@ RectTransform: m_AnchorMin: {x: 0.5, y: 0.5} m_AnchorMax: {x: 0.5, y: 0.5} m_AnchoredPosition: {x: 0, y: 0} - m_SizeDelta: {x: 380, y: 80} + m_SizeDelta: {x: 240, y: 80} m_Pivot: {x: 0.5, y: 0.5} --- !u!222 &525409158048368038 CanvasRenderer: diff --git a/Assets/_PROJECT/Scenes/DeltaBuilding_base.unity b/Assets/_PROJECT/Scenes/DeltaBuilding_base.unity index 10b79f50..6d96c48a 100644 --- a/Assets/_PROJECT/Scenes/DeltaBuilding_base.unity +++ b/Assets/_PROJECT/Scenes/DeltaBuilding_base.unity @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:2359b488e4de5293f5fa23a2a331147989b2e3bc3fe2dc789888ba91d6b8b4b4 -size 63188794 +oid sha256:9350888e1de24412c17b98e0860e69b6891be95789d7a56b558f01fa431d1de4 +size 63207713 diff --git a/Assets/_PROJECT/Scripts/ModeGeneration/ImageGenerationBox.cs b/Assets/_PROJECT/Scripts/ModeGeneration/ImageGenerationBox.cs new file mode 100644 index 00000000..36901b86 --- /dev/null +++ b/Assets/_PROJECT/Scripts/ModeGeneration/ImageGenerationBox.cs @@ -0,0 +1,75 @@ +using System; +using Unity.XR.CoreUtils; +using UnityEngine; +using UnityEngine.UI; + +public class ImageGenerationBox : MonoBehaviour +{ + public Material inactiveMaterial; + public Material loadingMaterial; + + public VoiceTranscriptionBox voiceTranscriptionTestBox; + public Image UIImage; + public string promptSuffix = ", single object, front and side fully visible, realistic style, plain neutral background, clear details, soft studio lighting, true-to-scale"; + + private MeshRenderer meshRenderer; + private bool isLoading; + + // Start is called before the first frame update + void Start() + { + meshRenderer = GetComponent(); + } + + // Update is called once per frame + void Update() + { + + } + + async void OnTriggerEnter(Collider other) + { + if (isLoading) return; + + KbmController controller = other.GetComponent(); + XROrigin playerOrigin = other.GetComponent(); + if (controller != null || playerOrigin != null) + { + string inputPrompt = voiceTranscriptionTestBox.LastTextOutput; + string refinedPrompt = inputPrompt + promptSuffix; + + isLoading = true; + meshRenderer.material = loadingMaterial; + + byte[] imageBytes = await InvokeAiClient.Instance.GenerateImage(refinedPrompt); + Sprite sprite = CreateSprite(imageBytes); + UIImage.sprite = sprite; + + isLoading = false; + meshRenderer.material = inactiveMaterial; + } + } + + private Sprite CreateSprite(byte[] imageBytes) + { + var tex = new Texture2D(2, 2, TextureFormat.RGBA32, false); + // ImageConversion.LoadImage returns bool (true = success) + if (!ImageConversion.LoadImage(tex, imageBytes, markNonReadable: false)) + { + Destroy(tex); + throw new InvalidOperationException("Failed to decode image bytes into Texture2D."); + } + + tex.filterMode = FilterMode.Bilinear; + tex.wrapMode = TextureWrapMode.Clamp; + + var sprite = Sprite.Create( + tex, + new Rect(0, 0, tex.width, tex.height), + new Vector2(0.5f, 0.5f), + pixelsPerUnit: 100f + ); + + return sprite; + } +} diff --git a/Assets/_PROJECT/Scripts/ModeGeneration/ImageGenerationBox.cs.meta b/Assets/_PROJECT/Scripts/ModeGeneration/ImageGenerationBox.cs.meta new file mode 100644 index 00000000..f0e54ff4 --- /dev/null +++ b/Assets/_PROJECT/Scripts/ModeGeneration/ImageGenerationBox.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: ea7eedaa608bac7449ba7c5a36697607 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/Assets/_PROJECT/Scripts/ModeGeneration/InvokeAiClient.cs b/Assets/_PROJECT/Scripts/ModeGeneration/InvokeAiClient.cs new file mode 100644 index 00000000..0b7b144f --- /dev/null +++ b/Assets/_PROJECT/Scripts/ModeGeneration/InvokeAiClient.cs @@ -0,0 +1,542 @@ + +using System; +using System.Diagnostics; +using System.Net.Http; +using System.Security.Cryptography; +using System.Text; +using System.Threading.Tasks; +using UnityEngine; +using Valve.Newtonsoft.Json; +using Valve.Newtonsoft.Json.Linq; + + +public class InvokeAiClient : MonoBehaviour +{ + public static InvokeAiClient Instance { get; private set; } + + public string INVOKEAI_BASE_URL; + public string DEFAULT_QUEUE_ID = "default"; + public string MODEL_KEY; + + private HttpClient httpClient; + + private void Awake() + { + httpClient = new HttpClient + { + Timeout = TimeSpan.FromSeconds(120) + }; + httpClient.BaseAddress = new Uri(INVOKEAI_BASE_URL); + + Instance = this; + } + + // Start is called before the first frame update + void Start() + { + + } + + // Update is called once per frame + void Update() + { + + } + + private async Task ListModels(string modelType = "main") + { + var requestUri = $"/api/v2/models/?model_type={Uri.EscapeDataString(modelType)}"; + using var resp = await httpClient.GetAsync(requestUri).ConfigureAwait(false); + resp.EnsureSuccessStatusCode(); + + var json = await resp.Content.ReadAsStringAsync().ConfigureAwait(false); + var root = JObject.Parse(json); + return (JArray) root["models"]; + } + + private async Task GetModelInfo(string modelKey) + { + var requestUri = $"/api/v2/models/i/{Uri.EscapeDataString(modelKey)}"; + using var resp = await httpClient.GetAsync(requestUri).ConfigureAwait(false); + resp.EnsureSuccessStatusCode(); + + var json = await resp.Content.ReadAsStringAsync().ConfigureAwait(false); + return JObject.Parse(json); + } + + private async Task GetImageUrl(string imageName) + { + var requestUri = $"/api/v1/images/i/{Uri.EscapeDataString(imageName)}/urls"; + UnityEngine.Debug.Log("Get image URL: " + requestUri); + using var resp = await httpClient.GetAsync(requestUri).ConfigureAwait(false); + resp.EnsureSuccessStatusCode(); + + var json = await resp.Content.ReadAsStringAsync().ConfigureAwait(false); + var root = JObject.Parse(json); + return root.Value("image_url"); + } + + + private async Task WaitForCompletion(string batchId, int timeoutSeconds = 300) + { + var sw = Stopwatch.StartNew(); + string queueId = DEFAULT_QUEUE_ID; + + while (true) + { + if (sw.Elapsed.TotalSeconds > timeoutSeconds) + throw new TimeoutException($"Image generation timed out after {timeoutSeconds} seconds"); + + // Get batch status + var statusUrl = $"/api/v1/queue/{Uri.EscapeDataString(queueId)}/b/{Uri.EscapeDataString(batchId)}/status"; + using var statusResp = await httpClient.GetAsync(statusUrl).ConfigureAwait(false); + statusResp.EnsureSuccessStatusCode(); + + var statusJson = await statusResp.Content.ReadAsStringAsync().ConfigureAwait(false); + var statusData = JObject.Parse(statusJson); + + // Check for failures + int failedCount = statusData.Value("failed") ?? 0; + if (failedCount > 0) + { + var queueStatusUrl = $"/api/v1/queue/{Uri.EscapeDataString(queueId)}/status"; + using var queueResp = await httpClient.GetAsync(queueStatusUrl).ConfigureAwait(false); + queueResp.EnsureSuccessStatusCode(); + + var queueJson = await queueResp.Content.ReadAsStringAsync().ConfigureAwait(false); + var queueData = JObject.Parse(queueJson); + + throw new InvalidOperationException( + $"Image generation failed. Batch {batchId} has {failedCount} failed item(s). " + + $"Queue status: {queueData.ToString(Formatting.Indented)}" + ); + } + + // Check completion + int completed = statusData.Value("completed") ?? 0; + int total = statusData.Value("total") ?? 0; + + if (completed == total && total > 0) + { + // Get most recent non-intermediate image + const string imagesPath = "/api/v1/images/?is_intermediate=false&limit=10"; + using var imagesResp = await httpClient.GetAsync(imagesPath).ConfigureAwait(false); + imagesResp.EnsureSuccessStatusCode(); + + var imagesJson = await imagesResp.Content.ReadAsStringAsync().ConfigureAwait(false); + var imagesData = JObject.Parse(imagesJson); + + var items = imagesData["items"] as JArray; + if (items != null && items.Count > 0) + { + var imageName = items[0].Value("image_name"); + + // Return result object mirroring your Python structure + var result = new JObject + { + ["batch_id"] = batchId, + ["status"] = "completed", + ["result"] = new JObject + { + ["outputs"] = new JObject + { + ["save_image"] = new JObject + { + ["type"] = "image_output", + ["image"] = new JObject + { + ["image_name"] = imageName + } + } + } + } + }; + + return result; + } + + // If no images found, return the status object + return statusData; + } + + // Wait before checking again + await Task.Delay(1000).ConfigureAwait(false); + } + } + + + private async Task EnqueueGraph(JToken graph) + { + string queueId = DEFAULT_QUEUE_ID; + + // Build request JSON dynamically + var payload = new JObject + { + ["batch"] = new JObject + { + ["graph"] = graph, // graph can be any JSON structure + ["runs"] = 1, + ["data"] = JValue.CreateNull() + } + }; + + var url = $"/api/v1/queue/{Uri.EscapeDataString(queueId)}/enqueue_batch"; + using var content = new StringContent(payload.ToString(Formatting.None), Encoding.UTF8, "application/json"); + + using var resp = await httpClient.PostAsync(url, content).ConfigureAwait(false); + resp.EnsureSuccessStatusCode(); + + var json = await resp.Content.ReadAsStringAsync().ConfigureAwait(false); + return JObject.Parse(json); + } + + + private static void RequireFields(JObject info, string nameOrKey, params string[] fields) + { + foreach (var f in fields) + { + var v = info[f]; + if (v == null || v.Type == JTokenType.Null) + throw new ArgumentException($"Model {nameOrKey} is missing required field: {f}"); + } + } + + private static long GenerateUInt32Seed() + { + Span bytes = stackalloc byte[4]; + RandomNumberGenerator.Fill(bytes); + uint u = BitConverter.ToUInt32(bytes); + return (long)u; // preserve full 0..4294967295 range + } + + + private static JObject Edge(string srcNode, string srcField, string dstNode, string dstField) => new JObject + { + ["source"] = new JObject { ["node_id"] = srcNode, ["field"] = srcField }, + ["destination"] = new JObject { ["node_id"] = dstNode, ["field"] = dstField } + }; + + + + private async Task CreateText2ImgGraph( + string prompt, + string negativePrompt = "", + string modelKey = null, + string loraKey = null, + double loraWeight = 1.0, + string vaeKey = null, + int width = 512, + int height = 512, + int steps = 30, + double cfgScale = 7.5, + string scheduler = "euler", + long? seed = null) + { + // 1) Use default model if not specified: pick first "sd-1" from main list + if (string.IsNullOrEmpty(modelKey)) + { + var models = await ListModels("main"); + foreach (var token in models) + { + if (token is JObject m && string.Equals(m.Value("base"), "sd-1", StringComparison.OrdinalIgnoreCase)) + { + modelKey = m.Value("key"); + break; + } + } + + if (string.IsNullOrEmpty(modelKey)) + throw new ArgumentException("No suitable model found (sd-1)", nameof(modelKey)); + } + + // 2) Get model information + var modelInfo = await GetModelInfo(modelKey); + if (modelInfo == null) + throw new ArgumentException($"Model {modelKey} not found", nameof(modelKey)); + if (modelInfo.Type != JTokenType.Object) + throw new ArgumentException($"Model {modelKey} returned invalid data type: {modelInfo.Type}", nameof(modelKey)); + + // 3) Validate required fields + RequireFields(modelInfo, modelKey, "key", "hash", "name", "base", "type"); + + // 4) Generate random 32-bit seed if not provided (0..2^32-1) + if (seed == null) + seed = GenerateUInt32Seed(); + + // 5) Detect SDXL + bool isSdxl = string.Equals(modelInfo.Value("base"), "sdxl", StringComparison.OrdinalIgnoreCase); + + // 6) Build nodes + var nodes = new JObject + { + // Main model loader + ["model_loader"] = new JObject + { + ["type"] = isSdxl ? "sdxl_model_loader" : "main_model_loader", + ["id"] = "model_loader", + ["model"] = new JObject + { + ["key"] = modelInfo.Value("key"), + ["hash"] = modelInfo.Value("hash"), + ["name"] = modelInfo.Value("name"), + ["base"] = modelInfo.Value("base"), + ["type"] = modelInfo.Value("type") + } + }, + + // Positive prompt + ["positive_prompt"] = new JObject + { + ["type"] = isSdxl ? "sdxl_compel_prompt" : "compel", + ["id"] = "positive_prompt", + ["prompt"] = prompt + }, + + // Negative prompt + ["negative_prompt"] = new JObject + { + ["type"] = isSdxl ? "sdxl_compel_prompt" : "compel", + ["id"] = "negative_prompt", + ["prompt"] = negativePrompt + }, + + // Noise generation + ["noise"] = new JObject + { + ["type"] = "noise", + ["id"] = "noise", + ["seed"] = seed, + ["width"] = width, + ["height"] = height, + ["use_cpu"] = false + }, + + // Denoise latents + ["denoise"] = new JObject + { + ["type"] = "denoise_latents", + ["id"] = "denoise", + ["steps"] = steps, + ["cfg_scale"] = cfgScale, + ["scheduler"] = scheduler, + ["denoising_start"] = 0, + ["denoising_end"] = 1 + }, + + // Latents to image + ["latents_to_image"] = new JObject + { + ["type"] = "l2i", + ["id"] = "latents_to_image" + }, + + // Save image + ["save_image"] = new JObject + { + ["type"] = "save_image", + ["id"] = "save_image", + ["is_intermediate"] = false + } + }; + + // SDXL: add style fields (matches your Python **kwargs expansions) + if (isSdxl) + { + (nodes["positive_prompt"] as JObject)["style"] = prompt; + (nodes["negative_prompt"] as JObject)["style"] = ""; + } + + // 7) Optional: LoRA + if (!string.IsNullOrEmpty(loraKey)) + { + var loraInfo = await GetModelInfo(loraKey); + if (loraInfo == null) + throw new ArgumentException($"LoRA model {loraKey} not found", nameof(loraKey)); + RequireFields(loraInfo, loraKey, "key", "hash", "name", "base", "type"); + + nodes["lora_loader"] = new JObject + { + ["type"] = "lora_loader", + ["id"] = "lora_loader", + ["lora"] = new JObject + { + ["key"] = loraInfo.Value("key"), + ["hash"] = loraInfo.Value("hash"), + ["name"] = loraInfo.Value("name"), + ["base"] = loraInfo.Value("base"), + ["type"] = loraInfo.Value("type") + }, + ["weight"] = loraWeight + }; + } + + // 8) Optional: VAE override + if (!string.IsNullOrEmpty(vaeKey)) + { + var vaeInfo = await GetModelInfo(vaeKey); + if (vaeInfo == null) + throw new ArgumentException($"VAE model {vaeKey} not found", nameof(vaeKey)); + RequireFields(vaeInfo, vaeKey, "key", "hash", "name", "base", "type"); + + nodes["vae_loader"] = new JObject + { + ["type"] = "vae_loader", + ["id"] = "vae_loader", + ["vae_model"] = new JObject + { + ["key"] = vaeInfo.Value("key"), + ["hash"] = vaeInfo.Value("hash"), + ["name"] = vaeInfo.Value("name"), + ["base"] = vaeInfo.Value("base"), + ["type"] = vaeInfo.Value("type") + } + }; + } + + + var edges = new JArray(); + + // Determine sources + bool hasLora = !string.IsNullOrEmpty(loraKey); + string unetSource = hasLora ? "lora_loader" : "model_loader"; + string clipSource = hasLora ? "lora_loader" : "model_loader"; + string vaeSource = !string.IsNullOrEmpty(vaeKey) ? "vae_loader" : "model_loader"; + + // If using LoRA, connect model_loader -> lora_loader (unet & clip) + if (hasLora) + { + edges.Add(Edge("model_loader", "unet", "lora_loader", "unet")); + edges.Add(Edge("model_loader", "clip", "lora_loader", "clip")); + // Note: lora_loader doesn't have clip2; SDXL clip2 comes from model_loader directly (handled below) + } + + // Connect UNet to denoise + edges.Add(Edge(unetSource, "unet", "denoise", "unet")); + + // Connect CLIP to prompts + edges.Add(Edge(clipSource, "clip", "positive_prompt", "clip")); + edges.Add(Edge(clipSource, "clip", "negative_prompt", "clip")); + + // SDXL: connect clip2 from model_loader to both prompts + if (isSdxl) + { + edges.Add(Edge("model_loader", "clip2", "positive_prompt", "clip2")); + edges.Add(Edge("model_loader", "clip2", "negative_prompt", "clip2")); + } + + // Prompts -> denoise conditioning + edges.Add(Edge("positive_prompt", "conditioning", "denoise", "positive_conditioning")); + edges.Add(Edge("negative_prompt", "conditioning", "denoise", "negative_conditioning")); + + // Noise -> denoise + edges.Add(Edge("noise", "noise", "denoise", "noise")); + + // Denoise -> l2i, and VAE -> l2i + edges.Add(Edge("denoise", "latents", "latents_to_image", "latents")); + edges.Add(Edge(vaeSource, "vae", "latents_to_image", "vae")); + + // l2i -> save_image + edges.Add(Edge("latents_to_image", "image", "save_image", "image")); + + // 7) Return final graph object + var graph = new JObject + { + ["id"] = "text2img_graph", + ["nodes"] = nodes, + ["edges"] = edges + }; + + return graph; + } + + + private async Task GenerateImageUrl(JObject arguments) + { + if (arguments == null) throw new ArgumentNullException(nameof(arguments)); + + // --- Extract parameters (with defaults) --- + string prompt = arguments.Value("prompt") + ?? throw new ArgumentException("Argument 'prompt' is required.", nameof(arguments)); + string negativePrompt = arguments.Value("negative_prompt") ?? ""; + int width = arguments.Value("width") ?? 512; + int height = arguments.Value("height") ?? 512; + int steps = arguments.Value("steps") ?? 30; + double cfgScale = arguments.Value("cfg_scale") ?? 7.5; + string scheduler = arguments.Value("scheduler") ?? "euler"; + long? seed = arguments.Value("seed"); + string modelKey = arguments.Value("model_key"); + string loraKey = arguments.Value("lora_key"); + double loraWeight = arguments.Value("lora_weight") ?? 1.0; + string vaeKey = arguments.Value("vae_key"); + + + // --- Create graph --- + JObject graph = await CreateText2ImgGraph( + prompt: prompt, + negativePrompt: negativePrompt, + modelKey: modelKey, + loraKey: loraKey, + loraWeight: loraWeight, + vaeKey: vaeKey, + width: width, + height: height, + steps: steps, + cfgScale: cfgScale, + scheduler: scheduler, + seed: seed + ); + + // --- Enqueue --- + JObject enqueueResult = await EnqueueGraph(graph); + string batchId = enqueueResult.SelectToken("batch.batch_id")?.Value(); + if (string.IsNullOrEmpty(batchId)) + throw new InvalidOperationException("Enqueue response did not contain 'batch.batch_id'."); + + UnityEngine.Debug.Log($"Enqueued batch {batchId}, waiting for completion..."); + + // --- Wait for completion --- + JObject completed = await WaitForCompletion(batchId); + + // --- Extract image output --- + var outputs = completed.SelectToken("result.outputs") as JObject; + if (outputs != null) + { + foreach (var prop in outputs.Properties()) + { + var output = prop.Value as JObject; + if (output?.Value("type") == "image_output") + { + string imageName = output.SelectToken("image.image_name")?.Value(); + if (string.IsNullOrEmpty(imageName)) + continue; + + // Resolve relative URL for the image (API-dependent) + string imageRelativeUrl = await GetImageUrl(imageName); + return imageRelativeUrl; + } + } + } + + throw new InvalidOperationException("Failed to generate image: no image_output found in result."); + } + + public async Task GenerateImage(string prompt) + { + JObject args = new JObject() + { + ["prompt"] = prompt, + ["width"] = 512, + ["height"] = 512, + ["model_key"] = MODEL_KEY, + }; + + string imageUrl = await GenerateImageUrl(args); + + + var req = new HttpRequestMessage(HttpMethod.Get, imageUrl); + using var resp = await httpClient.SendAsync(req, HttpCompletionOption.ResponseHeadersRead); + resp.EnsureSuccessStatusCode(); + + return await resp.Content.ReadAsByteArrayAsync(); + } +} diff --git a/Assets/_PROJECT/Scripts/ModeGeneration/InvokeAiClient.cs.meta b/Assets/_PROJECT/Scripts/ModeGeneration/InvokeAiClient.cs.meta new file mode 100644 index 00000000..9b5cb901 --- /dev/null +++ b/Assets/_PROJECT/Scripts/ModeGeneration/InvokeAiClient.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 4591f6805db240a4ca28e515091ca909 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/Packages/manifest.json b/Packages/manifest.json index 5d426229..c1e4353a 100644 --- a/Packages/manifest.json +++ b/Packages/manifest.json @@ -13,6 +13,7 @@ "com.unity.ide.vscode": "1.2.5", "com.unity.inputsystem": "1.5.0", "com.unity.memoryprofiler": "1.0.0", + "com.unity.nuget.newtonsoft-json": "3.2.2", "com.unity.postprocessing": "3.2.2", "com.unity.probuilder": "5.0.6", "com.unity.progrids": "3.0.3-preview.6", diff --git a/Packages/packages-lock.json b/Packages/packages-lock.json index d15b3e71..d2cb6840 100644 --- a/Packages/packages-lock.json +++ b/Packages/packages-lock.json @@ -152,6 +152,13 @@ }, "url": "https://packages.unity.com" }, + "com.unity.nuget.newtonsoft-json": { + "version": "3.2.2", + "depth": 0, + "source": "registry", + "dependencies": {}, + "url": "https://packages.unity.com" + }, "com.unity.postprocessing": { "version": "3.2.2", "depth": 0,