1
0
forked from cgvr/DeltaVR

use Whisper in streaming mode

This commit is contained in:
2025-12-31 14:13:25 +02:00
parent 7bc58a48d0
commit 0c8c55f293
5 changed files with 84 additions and 111 deletions

View File

@@ -34,6 +34,53 @@ def get_client() -> httpx.AsyncClient:
return http_client return http_client
async def list_models(model_type: str = "main") -> list:
    """Fetch the models of one type from the ``/api/v2/models/`` endpoint.

    Args:
        model_type: Value sent as the ``model_type`` query parameter.

    Returns:
        The ``"models"`` entry of the JSON response, or an empty list when
        the response has no such key.

    Raises:
        httpx.HTTPStatusError: If the server replies with an error status.
    """
    http = get_client()
    query = {"model_type": model_type}
    reply = await http.get("/api/v2/models/", params=query)
    reply.raise_for_status()
    return reply.json().get("models", [])
async def get_model_info(model_key: str) -> Optional[dict]:
    """Look up a single model by key via ``/api/v2/models/i/{model_key}``.

    Any failure (request error, error status, undecodable body, or a JSON
    payload that is not an object) is logged and reported as ``None``
    instead of being raised.

    Args:
        model_key: Key interpolated into the request path.

    Returns:
        The decoded JSON object, or ``None`` on any failure.
    """
    http = get_client()
    try:
        reply = await http.get(f"/api/v2/models/i/{model_key}")
        reply.raise_for_status()
        info = reply.json()
        if isinstance(info, dict):
            return info
        # Anything other than a JSON object is unusable for callers.
        logger.error(f"Model info for {model_key} is not a dictionary: {type(info)}")
        return None
    except Exception as exc:
        logger.error(f"Error fetching model info for {model_key}: {exc}")
        return None
def download_file(url, filepath, timeout=30):
    """Download ``url`` and write the response body to ``filepath``.

    Args:
        url: Address to fetch with an HTTP GET.
        filepath: Destination path; opened in binary write mode.
        timeout: Seconds passed to ``requests.get`` as its timeout. Without
            it a stalled server would block this call forever.

    Raises:
        RuntimeError: If the response status code is not 200.
        requests.exceptions.RequestException: If the request fails or times
            out.
    """
    response = requests.get(url, timeout=timeout)
    # Check the status first so a failed download never creates the file.
    if response.status_code != 200:
        raise RuntimeError(f"Failed to download image. Status code: {response.status_code}")

    with open(filepath, "wb") as file:
        file.write(response.content)
async def get_image_url(image_name: str) -> str:
    """Resolve the URL of an image via ``/api/v1/images/i/{image_name}/urls``.

    Args:
        image_name: Image name interpolated into the request path.

    Returns:
        The ``"image_url"`` entry of the JSON response, or an empty string
        when the response has no such key.

    Raises:
        httpx.HTTPStatusError: If the server replies with an error status.
    """
    http = get_client()
    endpoint = f"/api/v1/images/i/{image_name}/urls"
    reply = await http.get(endpoint)
    reply.raise_for_status()
    payload = reply.json()
    return payload.get("image_url", "")
async def wait_for_completion(batch_id: str, queue_id: str = DEFAULT_QUEUE_ID, timeout: int = 300) -> dict: async def wait_for_completion(batch_id: str, queue_id: str = DEFAULT_QUEUE_ID, timeout: int = 300) -> dict:
"""Wait for a batch to complete and return the most recent image.""" """Wait for a batch to complete and return the most recent image."""
client = get_client() client = get_client()
@@ -96,25 +143,6 @@ async def wait_for_completion(batch_id: str, queue_id: str = DEFAULT_QUEUE_ID, t
await asyncio.sleep(1) await asyncio.sleep(1)
def download_file(url, filepath):
response = requests.get(url)
if response.status_code == 200:
with open(filepath, "wb") as file:
file.write(response.content)
else:
raise RuntimeError(f"Failed to download image. Status code: {response.status_code}")
async def get_image_url(image_name: str) -> str:
"""Get the URL for an image."""
client = get_client()
response = await client.get(f"/api/v1/images/i/{image_name}/urls")
response.raise_for_status()
data = response.json()
return data.get("image_url", "")
async def enqueue_graph(graph: dict, queue_id: str = DEFAULT_QUEUE_ID) -> dict: async def enqueue_graph(graph: dict, queue_id: str = DEFAULT_QUEUE_ID) -> dict:
"""Enqueue a graph for processing.""" """Enqueue a graph for processing."""
client = get_client() client = get_client()
@@ -134,31 +162,7 @@ async def enqueue_graph(graph: dict, queue_id: str = DEFAULT_QUEUE_ID) -> dict:
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()
async def list_models(model_type: str = "main") -> list:
"""List available models."""
client = get_client()
response = await client.get("/api/v2/models/", params={"model_type": model_type})
response.raise_for_status()
data = response.json()
return data.get("models", [])
async def get_model_info(model_key: str) -> Optional[dict]:
"""Get information about a specific model."""
client = get_client()
try:
response = await client.get(f"/api/v2/models/i/{model_key}")
response.raise_for_status()
model_data = response.json()
# Ensure we have a valid dictionary
if not isinstance(model_data, dict):
logger.error(f"Model info for {model_key} is not a dictionary: {type(model_data)}")
return None
return model_data
except Exception as e:
logger.error(f"Error fetching model info for {model_key}: {e}")
return None
async def create_text2img_graph( async def create_text2img_graph(
@@ -423,6 +427,7 @@ async def create_text2img_graph(
return graph return graph
async def generate_image(arguments: dict): async def generate_image(arguments: dict):
# Extract parameters # Extract parameters
@@ -488,5 +493,6 @@ async def text_to_image_invoke_ai(prompt, output_path):
"model_key": INVOKEAI_MODEL_KEY "model_key": INVOKEAI_MODEL_KEY
} }
image_url = await generate_image(args) image_url = await generate_image(args)
print("got image url: ", image_url) print("Got image url:", image_url)
print("Downloading image file...")
download_file(image_url, output_path) download_file(image_url, output_path)

View File

@@ -64,18 +64,6 @@ public class InvokeAiClient : MonoBehaviour
return JObject.Parse(json); return JObject.Parse(json);
} }
private async Task<string> GetImageUrl(string imageName)
{
var requestUri = $"/api/v1/images/i/{Uri.EscapeDataString(imageName)}/urls";
UnityEngine.Debug.Log("Get image URL: " + requestUri);
using var resp = await httpClient.GetAsync(requestUri).ConfigureAwait(false);
resp.EnsureSuccessStatusCode();
var json = await resp.Content.ReadAsStringAsync().ConfigureAwait(false);
var root = JObject.Parse(json);
return root.Value<string>("image_url");
}
private async Task<JObject> WaitForCompletion(string batchId, int timeoutSeconds = 300) private async Task<JObject> WaitForCompletion(string batchId, int timeoutSeconds = 300)
{ {
@@ -449,6 +437,17 @@ public class InvokeAiClient : MonoBehaviour
return graph; return graph;
} }
/// <summary>
/// Requests <c>/api/v1/images/i/{imageName}/urls</c> and returns the
/// <c>image_url</c> field of the JSON response.
/// </summary>
/// <param name="imageName">Image name; URL-escaped before being placed in the path.</param>
/// <returns>The value of <c>image_url</c>, or <c>null</c> when the field is absent.</returns>
/// <exception cref="HttpRequestException">The response status is not a success code.</exception>
private async Task<string> GetImageUrl(string imageName)
{
    var escapedName = Uri.EscapeDataString(imageName);
    using var response = await httpClient
        .GetAsync($"/api/v1/images/i/{escapedName}/urls")
        .ConfigureAwait(false);
    response.EnsureSuccessStatusCode();

    var body = await response.Content.ReadAsStringAsync().ConfigureAwait(false);
    return JObject.Parse(body).Value<string>("image_url");
}
private async Task<string> GenerateImageUrl(JObject arguments) private async Task<string> GenerateImageUrl(JObject arguments)
{ {
@@ -510,9 +509,7 @@ public class InvokeAiClient : MonoBehaviour
if (string.IsNullOrEmpty(imageName)) if (string.IsNullOrEmpty(imageName))
continue; continue;
// Resolve relative URL for the image (API-dependent) return await GetImageUrl(imageName);
string imageRelativeUrl = await GetImageUrl(imageName);
return imageRelativeUrl;
} }
} }
} }
@@ -530,8 +527,9 @@ public class InvokeAiClient : MonoBehaviour
["model_key"] = MODEL_KEY, ["model_key"] = MODEL_KEY,
}; };
UnityEngine.Debug.Log("Starting image generation...");
string imageUrl = await GenerateImageUrl(args); string imageUrl = await GenerateImageUrl(args);
UnityEngine.Debug.Log("Image URL ready: " + imageUrl);
var req = new HttpRequestMessage(HttpMethod.Get, imageUrl); var req = new HttpRequestMessage(HttpMethod.Get, imageUrl);
using var resp = await httpClient.SendAsync(req, HttpCompletionOption.ResponseHeadersRead); using var resp = await httpClient.SendAsync(req, HttpCompletionOption.ResponseHeadersRead);

View File

@@ -94,6 +94,7 @@ public class TrellisClient : MonoBehaviour
{ {
downloadResponse.EnsureSuccessStatusCode(); downloadResponse.EnsureSuccessStatusCode();
var bytes = await downloadResponse.Content.ReadAsByteArrayAsync(); var bytes = await downloadResponse.Content.ReadAsByteArrayAsync();
Debug.Log($"Downloaded {bytes.Length} bytes");
return bytes; return bytes;
} }
} }

View File

@@ -1,4 +1,3 @@
using System.Diagnostics;
using TMPro; using TMPro;
using Unity.XR.CoreUtils; using Unity.XR.CoreUtils;
using UnityEngine; using UnityEngine;
@@ -9,17 +8,14 @@ public class VoiceTranscriptionBox : MonoBehaviour
{ {
public Material activeMaterial; public Material activeMaterial;
public Material inactiveMaterial; public Material inactiveMaterial;
public Material loadingMaterial;
private MeshRenderer meshRenderer; private MeshRenderer meshRenderer;
private bool isLoading;
public WhisperManager whisper; public WhisperManager whisper;
public MicrophoneRecord microphoneRecord; public MicrophoneRecord microphoneRecord;
public TextMeshProUGUI outputText; public TextMeshProUGUI outputText;
private string _buffer; private WhisperStream stream;
private string lastTextOutput; private string lastTextOutput;
public string LastTextOutput public string LastTextOutput
@@ -30,19 +26,16 @@ public class VoiceTranscriptionBox : MonoBehaviour
} }
} }
private void Awake()
{
isLoading = false;
whisper.OnNewSegment += OnNewSegment;
microphoneRecord.OnRecordStop += OnRecordStop;
}
// Start is called before the first frame update
/// <summary>
/// Caches the mesh renderer, starts microphone recording and creates the
/// Whisper stream whose results are forwarded to <c>OnWhisperResult</c>.
/// </summary>
async void Start()
{
    meshRenderer = GetComponent<MeshRenderer>();

    // This causes about 1 sec long freeze, has to be done once at the start of the game
    microphoneRecord.StartRecord();

    var createdStream = await whisper.CreateStream(microphoneRecord);

    // The component may have been destroyed while the stream was being created;
    // subscribing then would leave a handler pointing at a dead object.
    if (this == null)
    {
        return;
    }

    // NOTE(review): guards against CreateStream yielding no stream (e.g. model
    // not loaded) — confirm against the WhisperManager implementation.
    if (createdStream == null)
    {
        UnityEngine.Debug.LogError("VoiceTranscriptionBox: failed to create Whisper stream.");
        return;
    }

    createdStream.OnResultUpdated += OnWhisperResult;
    stream = createdStream;
}
// Update is called once per frame // Update is called once per frame
@@ -53,17 +46,12 @@ public class VoiceTranscriptionBox : MonoBehaviour
/// <summary>
/// Starts streaming transcription and switches to the active material when
/// a collider carrying a <c>KbmController</c> or <c>XROrigin</c> enters.
/// </summary>
/// <param name="other">The collider that entered the trigger volume.</param>
void OnTriggerEnter(Collider other)
{
    KbmController controller = other.GetComponent<KbmController>();
    XROrigin playerOrigin = other.GetComponent<XROrigin>();
    if (controller == null && playerOrigin == null)
    {
        return;
    }

    // The stream is created asynchronously in Start; until it exists there is
    // nothing to start, and the box must not look active.
    if (stream == null)
    {
        return;
    }

    meshRenderer.material = activeMaterial;
    stream.StartStream();
}
@@ -73,40 +61,20 @@ public class VoiceTranscriptionBox : MonoBehaviour
XROrigin playerOrigin = other.GetComponent<XROrigin>(); XROrigin playerOrigin = other.GetComponent<XROrigin>();
if (controller != null | playerOrigin != null) if (controller != null | playerOrigin != null)
{ {
microphoneRecord.StopRecord(); stream.StopStream();
meshRenderer.material = loadingMaterial; meshRenderer.material = inactiveMaterial;
isLoading = true;
} }
} }
/// <summary>
/// Handler for the stream's result updates: remembers the latest text and
/// shows it in <c>outputText</c>.
/// </summary>
/// <param name="result">The text delivered by the stream.</param>
private void OnWhisperResult(string result)
{
    outputText.text = lastTextOutput = result;
}
/// <summary>
/// Detaches from the Whisper stream and stops microphone recording when
/// this component is destroyed.
/// </summary>
private void OnDestroy()
{
    // Unsubscribe first so that nothing triggered by stopping the recording
    // can call back into this component after it is gone.
    if (stream != null)
    {
        stream.OnResultUpdated -= OnWhisperResult;
    }

    if (microphoneRecord != null)
    {
        microphoneRecord.StopRecord();
    }

    // NOTE(review): destroying the GameObject from OnDestroy is kept as-is;
    // confirm it is intended that removing only this component removes the object.
    Destroy(gameObject);
}
} }