From 0c8c55f2938b5d59f3fea9adca8545397e245e1d Mon Sep 17 00:00:00 2001 From: henrisel Date: Wed, 31 Dec 2025 14:13:25 +0200 Subject: [PATCH] use Whisper in streaming mode --- .../generate_image_local.py | 94 ++++++++++--------- .../_PROJECT/Scenes/DeltaBuilding_base.unity | 4 +- .../Scripts/ModeGeneration/InvokeAiClient.cs | 30 +++--- .../Scripts/ModeGeneration/TrellisClient.cs | 1 + .../ModeGeneration/VoiceTranscriptionBox.cs | 66 ++++--------- 5 files changed, 84 insertions(+), 111 deletions(-) diff --git a/3d-generation-pipeline/generate_image_local.py b/3d-generation-pipeline/generate_image_local.py index 22aafcfb..017c4717 100644 --- a/3d-generation-pipeline/generate_image_local.py +++ b/3d-generation-pipeline/generate_image_local.py @@ -34,6 +34,53 @@ def get_client() -> httpx.AsyncClient: return http_client +async def list_models(model_type: str = "main") -> list: + """List available models.""" + client = get_client() + response = await client.get("/api/v2/models/", params={"model_type": model_type}) + response.raise_for_status() + data = response.json() + return data.get("models", []) + + +async def get_model_info(model_key: str) -> Optional[dict]: + """Get information about a specific model.""" + client = get_client() + try: + response = await client.get(f"/api/v2/models/i/{model_key}") + response.raise_for_status() + model_data = response.json() + + # Ensure we have a valid dictionary + if not isinstance(model_data, dict): + logger.error(f"Model info for {model_key} is not a dictionary: {type(model_data)}") + return None + + return model_data + except Exception as e: + logger.error(f"Error fetching model info for {model_key}: {e}") + return None + + +def download_file(url, filepath): + response = requests.get(url) + + if response.status_code == 200: + with open(filepath, "wb") as file: + file.write(response.content) + else: + raise RuntimeError(f"Failed to download image. Status code: {response.status_code}") + + +async def get_image_url(image_name: str) -> str: + """Get the URL for an image.""" + client = get_client() + response = await client.get(f"/api/v1/images/i/{image_name}/urls") + response.raise_for_status() + data = response.json() + return data.get("image_url", "") + + async def wait_for_completion(batch_id: str, queue_id: str = DEFAULT_QUEUE_ID, timeout: int = 300) -> dict: """Wait for a batch to complete and return the most recent image.""" client = get_client() @@ -94,25 +141,6 @@ async def wait_for_completion(batch_id: str, queue_id: str = DEFAULT_QUEUE_ID, t # Wait before checking again await asyncio.sleep(1) - - - -def download_file(url, filepath): - response = requests.get(url) - - if response.status_code == 200: - with open(filepath, "wb") as file: - file.write(response.content) - else: - raise RuntimeError(f"Failed to download image. Status code: {response.status_code}") - -async def get_image_url(image_name: str) -> str: - """Get the URL for an image.""" - client = get_client() - response = await client.get(f"/api/v1/images/i/{image_name}/urls") - response.raise_for_status() - data = response.json() - return data.get("image_url", "") async def enqueue_graph(graph: dict, queue_id: str = DEFAULT_QUEUE_ID) -> dict: @@ -134,31 +162,7 @@ async def enqueue_graph(graph: dict, queue_id: str = DEFAULT_QUEUE_ID) -> dict: response.raise_for_status() return response.json() -async def list_models(model_type: str = "main") -> list: - """List available models.""" - client = get_client() - response = await client.get("/api/v2/models/", params={"model_type": model_type}) - response.raise_for_status() - data = response.json() - return data.get("models", []) -async def get_model_info(model_key: str) -> Optional[dict]: - """Get information about a specific model.""" - client = get_client() - try: - response = await client.get(f"/api/v2/models/i/{model_key}") - response.raise_for_status() - model_data = response.json() - - # Ensure we have a valid dictionary - if not isinstance(model_data, dict): - logger.error(f"Model info for {model_key} is not a dictionary: {type(model_data)}") - return None - - return model_data - except Exception as e: - logger.error(f"Error fetching model info for {model_key}: {e}") - return None async def create_text2img_graph( @@ -423,6 +427,7 @@ async def create_text2img_graph( return graph + async def generate_image(arguments: dict): # Extract parameters @@ -488,5 +493,6 @@ async def text_to_image_invoke_ai(prompt, output_path): "model_key": INVOKEAI_MODEL_KEY } image_url = await generate_image(args) - print("got image url: ", image_url) + print("Got image url:", image_url) + print("Downloading image file...") download_file(image_url, output_path) diff --git a/Assets/_PROJECT/Scenes/DeltaBuilding_base.unity b/Assets/_PROJECT/Scenes/DeltaBuilding_base.unity index 13b300c1..a3c76177 100644 --- a/Assets/_PROJECT/Scenes/DeltaBuilding_base.unity +++ b/Assets/_PROJECT/Scenes/DeltaBuilding_base.unity @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:329e8815aa533260a7abe80bf6c35c9428f5f81061753352dabcb9d6be828a15 -size 63208000 +oid sha256:e4c2da81f2d77fa3081c93b48f308f1ce4964dee8767417ffbc37371c4ce4043 +size 63208484 diff --git a/Assets/_PROJECT/Scripts/ModeGeneration/InvokeAiClient.cs b/Assets/_PROJECT/Scripts/ModeGeneration/InvokeAiClient.cs index 0b7b144f..32514b34 100644 --- a/Assets/_PROJECT/Scripts/ModeGeneration/InvokeAiClient.cs +++ b/Assets/_PROJECT/Scripts/ModeGeneration/InvokeAiClient.cs @@ -64,18 +64,6 @@ public class InvokeAiClient : MonoBehaviour return JObject.Parse(json); } - private async Task GetImageUrl(string imageName) - { - var requestUri = $"/api/v1/images/i/{Uri.EscapeDataString(imageName)}/urls"; - UnityEngine.Debug.Log("Get image URL: " + requestUri); - using var resp = await httpClient.GetAsync(requestUri).ConfigureAwait(false); - resp.EnsureSuccessStatusCode(); - - var json = await resp.Content.ReadAsStringAsync().ConfigureAwait(false); - var root = JObject.Parse(json); - return root.Value("image_url"); - } - private async Task WaitForCompletion(string batchId, int timeoutSeconds = 300) { @@ -449,6 +437,17 @@ public class InvokeAiClient : MonoBehaviour return graph; } + private async Task GetImageUrl(string imageName) + { + var requestUri = $"/api/v1/images/i/{Uri.EscapeDataString(imageName)}/urls"; + using var resp = await httpClient.GetAsync(requestUri).ConfigureAwait(false); + resp.EnsureSuccessStatusCode(); + + var json = await resp.Content.ReadAsStringAsync().ConfigureAwait(false); + var root = JObject.Parse(json); + return root.Value("image_url"); + } + private async Task GenerateImageUrl(JObject arguments) { @@ -510,9 +509,7 @@ public class InvokeAiClient : MonoBehaviour if (string.IsNullOrEmpty(imageName)) continue; - // Resolve relative URL for the image (API-dependent) - string imageRelativeUrl = await GetImageUrl(imageName); - return imageRelativeUrl; + return await GetImageUrl(imageName); } } } @@ -530,8 +527,9 @@ public class InvokeAiClient : MonoBehaviour ["model_key"] = MODEL_KEY, }; + UnityEngine.Debug.Log("Starting image generation..."); string imageUrl = await GenerateImageUrl(args); - + UnityEngine.Debug.Log("Image URL ready: " + imageUrl); var req = new HttpRequestMessage(HttpMethod.Get, imageUrl); using var resp = await httpClient.SendAsync(req, HttpCompletionOption.ResponseHeadersRead); diff --git a/Assets/_PROJECT/Scripts/ModeGeneration/TrellisClient.cs b/Assets/_PROJECT/Scripts/ModeGeneration/TrellisClient.cs index 7c9cd37a..42c7019c 100644 --- a/Assets/_PROJECT/Scripts/ModeGeneration/TrellisClient.cs +++ b/Assets/_PROJECT/Scripts/ModeGeneration/TrellisClient.cs @@ -94,6 +94,7 @@ public class TrellisClient : MonoBehaviour { downloadResponse.EnsureSuccessStatusCode(); var bytes = await downloadResponse.Content.ReadAsByteArrayAsync(); + Debug.Log($"Downloaded {bytes.Length} bytes"); return bytes; } } diff --git a/Assets/_PROJECT/Scripts/ModeGeneration/VoiceTranscriptionBox.cs b/Assets/_PROJECT/Scripts/ModeGeneration/VoiceTranscriptionBox.cs index 24d3ffe8..4e6002b1 100644 --- a/Assets/_PROJECT/Scripts/ModeGeneration/VoiceTranscriptionBox.cs +++ b/Assets/_PROJECT/Scripts/ModeGeneration/VoiceTranscriptionBox.cs @@ -1,4 +1,3 @@ -using System.Diagnostics; using TMPro; using Unity.XR.CoreUtils; using UnityEngine; @@ -9,17 +8,14 @@ public class VoiceTranscriptionBox : MonoBehaviour { public Material activeMaterial; public Material inactiveMaterial; - public Material loadingMaterial; private MeshRenderer meshRenderer; - private bool isLoading; - public WhisperManager whisper; public MicrophoneRecord microphoneRecord; public TextMeshProUGUI outputText; - private string _buffer; + private WhisperStream stream; private string lastTextOutput; public string LastTextOutput @@ -30,19 +26,16 @@ public class VoiceTranscriptionBox : MonoBehaviour } } - private void Awake() - { - isLoading = false; - - whisper.OnNewSegment += OnNewSegment; - - microphoneRecord.OnRecordStop += OnRecordStop; - } - // Start is called before the first frame update - void Start() + async void Start() { meshRenderer = GetComponent(); + + // This causes about 1 sec long freeze, has to be done once at the start of the game + microphoneRecord.StartRecord(); + + stream = await whisper.CreateStream(microphoneRecord); + stream.OnResultUpdated += OnWhisperResult; } // Update is called once per frame @@ -53,17 +46,12 @@ public class VoiceTranscriptionBox : MonoBehaviour void OnTriggerEnter(Collider other) { - if (isLoading) - { - return; - } - KbmController controller = other.GetComponent(); XROrigin playerOrigin = other.GetComponent(); if (controller != null || playerOrigin != null) { meshRenderer.material = activeMaterial; - microphoneRecord.StartRecord(); + stream.StartStream(); } } @@ -73,40 +61,20 @@ public class VoiceTranscriptionBox : MonoBehaviour XROrigin playerOrigin = other.GetComponent(); if (controller != null | playerOrigin != null) { - microphoneRecord.StopRecord(); - meshRenderer.material = loadingMaterial; - isLoading = true; + stream.StopStream(); + meshRenderer.material = inactiveMaterial; } } - - private async void OnRecordStop(AudioChunk recordedAudio) + private void OnWhisperResult(string result) { - _buffer = ""; - - var sw = new Stopwatch(); - sw.Start(); - - var res = await whisper.GetTextAsync(recordedAudio.Data, recordedAudio.Frequency, recordedAudio.Channels); - if (res == null) - return; - - var time = sw.ElapsedMilliseconds; - var rate = recordedAudio.Length / (time * 0.001f); - UnityEngine.Debug.Log($"Time: {time} ms\nRate: {rate:F1}x"); - - var text = res.Result; - - lastTextOutput = text; - outputText.text = text; - - meshRenderer.material = inactiveMaterial; - isLoading = false; + lastTextOutput = result; + outputText.text = result; } - private void OnNewSegment(WhisperSegment segment) + private void OnDestroy() { - _buffer += segment.Text; - UnityEngine.Debug.Log(_buffer + "..."); + microphoneRecord.StopRecord(); + Destroy(gameObject); } }