forked from cgvr/DeltaVR
use Whisper in streaming mode
This commit is contained in:
@@ -34,6 +34,53 @@ def get_client() -> httpx.AsyncClient:
|
|||||||
return http_client
|
return http_client
|
||||||
|
|
||||||
|
|
||||||
|
async def list_models(model_type: str = "main") -> list:
|
||||||
|
"""List available models."""
|
||||||
|
client = get_client()
|
||||||
|
response = await client.get("/api/v2/models/", params={"model_type": model_type})
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
return data.get("models", [])
|
||||||
|
|
||||||
|
|
||||||
|
async def get_model_info(model_key: str) -> Optional[dict]:
|
||||||
|
"""Get information about a specific model."""
|
||||||
|
client = get_client()
|
||||||
|
try:
|
||||||
|
response = await client.get(f"/api/v2/models/i/{model_key}")
|
||||||
|
response.raise_for_status()
|
||||||
|
model_data = response.json()
|
||||||
|
|
||||||
|
# Ensure we have a valid dictionary
|
||||||
|
if not isinstance(model_data, dict):
|
||||||
|
logger.error(f"Model info for {model_key} is not a dictionary: {type(model_data)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
return model_data
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching model info for {model_key}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def download_file(url, filepath):
|
||||||
|
response = requests.get(url)
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
with open(filepath, "wb") as file:
|
||||||
|
file.write(response.content)
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Failed to download image. Status code: {response.status_code}")
|
||||||
|
|
||||||
|
|
||||||
|
async def get_image_url(image_name: str) -> str:
|
||||||
|
"""Get the URL for an image."""
|
||||||
|
client = get_client()
|
||||||
|
response = await client.get(f"/api/v1/images/i/{image_name}/urls")
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
return data.get("image_url", "")
|
||||||
|
|
||||||
|
|
||||||
async def wait_for_completion(batch_id: str, queue_id: str = DEFAULT_QUEUE_ID, timeout: int = 300) -> dict:
|
async def wait_for_completion(batch_id: str, queue_id: str = DEFAULT_QUEUE_ID, timeout: int = 300) -> dict:
|
||||||
"""Wait for a batch to complete and return the most recent image."""
|
"""Wait for a batch to complete and return the most recent image."""
|
||||||
client = get_client()
|
client = get_client()
|
||||||
@@ -96,25 +143,6 @@ async def wait_for_completion(batch_id: str, queue_id: str = DEFAULT_QUEUE_ID, t
|
|||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def download_file(url, filepath):
|
|
||||||
response = requests.get(url)
|
|
||||||
|
|
||||||
if response.status_code == 200:
|
|
||||||
with open(filepath, "wb") as file:
|
|
||||||
file.write(response.content)
|
|
||||||
else:
|
|
||||||
raise RuntimeError(f"Failed to download image. Status code: {response.status_code}")
|
|
||||||
|
|
||||||
async def get_image_url(image_name: str) -> str:
|
|
||||||
"""Get the URL for an image."""
|
|
||||||
client = get_client()
|
|
||||||
response = await client.get(f"/api/v1/images/i/{image_name}/urls")
|
|
||||||
response.raise_for_status()
|
|
||||||
data = response.json()
|
|
||||||
return data.get("image_url", "")
|
|
||||||
|
|
||||||
|
|
||||||
async def enqueue_graph(graph: dict, queue_id: str = DEFAULT_QUEUE_ID) -> dict:
|
async def enqueue_graph(graph: dict, queue_id: str = DEFAULT_QUEUE_ID) -> dict:
|
||||||
"""Enqueue a graph for processing."""
|
"""Enqueue a graph for processing."""
|
||||||
client = get_client()
|
client = get_client()
|
||||||
@@ -134,31 +162,7 @@ async def enqueue_graph(graph: dict, queue_id: str = DEFAULT_QUEUE_ID) -> dict:
|
|||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
async def list_models(model_type: str = "main") -> list:
|
|
||||||
"""List available models."""
|
|
||||||
client = get_client()
|
|
||||||
response = await client.get("/api/v2/models/", params={"model_type": model_type})
|
|
||||||
response.raise_for_status()
|
|
||||||
data = response.json()
|
|
||||||
return data.get("models", [])
|
|
||||||
|
|
||||||
async def get_model_info(model_key: str) -> Optional[dict]:
|
|
||||||
"""Get information about a specific model."""
|
|
||||||
client = get_client()
|
|
||||||
try:
|
|
||||||
response = await client.get(f"/api/v2/models/i/{model_key}")
|
|
||||||
response.raise_for_status()
|
|
||||||
model_data = response.json()
|
|
||||||
|
|
||||||
# Ensure we have a valid dictionary
|
|
||||||
if not isinstance(model_data, dict):
|
|
||||||
logger.error(f"Model info for {model_key} is not a dictionary: {type(model_data)}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
return model_data
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error fetching model info for {model_key}: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def create_text2img_graph(
|
async def create_text2img_graph(
|
||||||
@@ -423,6 +427,7 @@ async def create_text2img_graph(
|
|||||||
|
|
||||||
return graph
|
return graph
|
||||||
|
|
||||||
|
|
||||||
async def generate_image(arguments: dict):
|
async def generate_image(arguments: dict):
|
||||||
|
|
||||||
# Extract parameters
|
# Extract parameters
|
||||||
@@ -488,5 +493,6 @@ async def text_to_image_invoke_ai(prompt, output_path):
|
|||||||
"model_key": INVOKEAI_MODEL_KEY
|
"model_key": INVOKEAI_MODEL_KEY
|
||||||
}
|
}
|
||||||
image_url = await generate_image(args)
|
image_url = await generate_image(args)
|
||||||
print("got image url: ", image_url)
|
print("Got image url:", image_url)
|
||||||
|
print("Downloading image file...")
|
||||||
download_file(image_url, output_path)
|
download_file(image_url, output_path)
|
||||||
|
|||||||
Binary file not shown.
@@ -64,18 +64,6 @@ public class InvokeAiClient : MonoBehaviour
|
|||||||
return JObject.Parse(json);
|
return JObject.Parse(json);
|
||||||
}
|
}
|
||||||
|
|
||||||
private async Task<string> GetImageUrl(string imageName)
|
|
||||||
{
|
|
||||||
var requestUri = $"/api/v1/images/i/{Uri.EscapeDataString(imageName)}/urls";
|
|
||||||
UnityEngine.Debug.Log("Get image URL: " + requestUri);
|
|
||||||
using var resp = await httpClient.GetAsync(requestUri).ConfigureAwait(false);
|
|
||||||
resp.EnsureSuccessStatusCode();
|
|
||||||
|
|
||||||
var json = await resp.Content.ReadAsStringAsync().ConfigureAwait(false);
|
|
||||||
var root = JObject.Parse(json);
|
|
||||||
return root.Value<string>("image_url");
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
private async Task<JObject> WaitForCompletion(string batchId, int timeoutSeconds = 300)
|
private async Task<JObject> WaitForCompletion(string batchId, int timeoutSeconds = 300)
|
||||||
{
|
{
|
||||||
@@ -449,6 +437,17 @@ public class InvokeAiClient : MonoBehaviour
|
|||||||
return graph;
|
return graph;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private async Task<string> GetImageUrl(string imageName)
|
||||||
|
{
|
||||||
|
var requestUri = $"/api/v1/images/i/{Uri.EscapeDataString(imageName)}/urls";
|
||||||
|
using var resp = await httpClient.GetAsync(requestUri).ConfigureAwait(false);
|
||||||
|
resp.EnsureSuccessStatusCode();
|
||||||
|
|
||||||
|
var json = await resp.Content.ReadAsStringAsync().ConfigureAwait(false);
|
||||||
|
var root = JObject.Parse(json);
|
||||||
|
return root.Value<string>("image_url");
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
private async Task<string> GenerateImageUrl(JObject arguments)
|
private async Task<string> GenerateImageUrl(JObject arguments)
|
||||||
{
|
{
|
||||||
@@ -510,9 +509,7 @@ public class InvokeAiClient : MonoBehaviour
|
|||||||
if (string.IsNullOrEmpty(imageName))
|
if (string.IsNullOrEmpty(imageName))
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
// Resolve relative URL for the image (API-dependent)
|
return await GetImageUrl(imageName);
|
||||||
string imageRelativeUrl = await GetImageUrl(imageName);
|
|
||||||
return imageRelativeUrl;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -530,8 +527,9 @@ public class InvokeAiClient : MonoBehaviour
|
|||||||
["model_key"] = MODEL_KEY,
|
["model_key"] = MODEL_KEY,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
UnityEngine.Debug.Log("Starting image generation...");
|
||||||
string imageUrl = await GenerateImageUrl(args);
|
string imageUrl = await GenerateImageUrl(args);
|
||||||
|
UnityEngine.Debug.Log("Image URL ready: " + imageUrl);
|
||||||
|
|
||||||
var req = new HttpRequestMessage(HttpMethod.Get, imageUrl);
|
var req = new HttpRequestMessage(HttpMethod.Get, imageUrl);
|
||||||
using var resp = await httpClient.SendAsync(req, HttpCompletionOption.ResponseHeadersRead);
|
using var resp = await httpClient.SendAsync(req, HttpCompletionOption.ResponseHeadersRead);
|
||||||
|
|||||||
@@ -94,6 +94,7 @@ public class TrellisClient : MonoBehaviour
|
|||||||
{
|
{
|
||||||
downloadResponse.EnsureSuccessStatusCode();
|
downloadResponse.EnsureSuccessStatusCode();
|
||||||
var bytes = await downloadResponse.Content.ReadAsByteArrayAsync();
|
var bytes = await downloadResponse.Content.ReadAsByteArrayAsync();
|
||||||
|
Debug.Log($"Downloaded {bytes.Length} bytes");
|
||||||
return bytes;
|
return bytes;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
using System.Diagnostics;
|
|
||||||
using TMPro;
|
using TMPro;
|
||||||
using Unity.XR.CoreUtils;
|
using Unity.XR.CoreUtils;
|
||||||
using UnityEngine;
|
using UnityEngine;
|
||||||
@@ -9,17 +8,14 @@ public class VoiceTranscriptionBox : MonoBehaviour
|
|||||||
{
|
{
|
||||||
public Material activeMaterial;
|
public Material activeMaterial;
|
||||||
public Material inactiveMaterial;
|
public Material inactiveMaterial;
|
||||||
public Material loadingMaterial;
|
|
||||||
|
|
||||||
private MeshRenderer meshRenderer;
|
private MeshRenderer meshRenderer;
|
||||||
private bool isLoading;
|
|
||||||
|
|
||||||
|
|
||||||
public WhisperManager whisper;
|
public WhisperManager whisper;
|
||||||
public MicrophoneRecord microphoneRecord;
|
public MicrophoneRecord microphoneRecord;
|
||||||
public TextMeshProUGUI outputText;
|
public TextMeshProUGUI outputText;
|
||||||
|
|
||||||
private string _buffer;
|
private WhisperStream stream;
|
||||||
|
|
||||||
private string lastTextOutput;
|
private string lastTextOutput;
|
||||||
public string LastTextOutput
|
public string LastTextOutput
|
||||||
@@ -30,19 +26,16 @@ public class VoiceTranscriptionBox : MonoBehaviour
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void Awake()
|
|
||||||
{
|
|
||||||
isLoading = false;
|
|
||||||
|
|
||||||
whisper.OnNewSegment += OnNewSegment;
|
|
||||||
|
|
||||||
microphoneRecord.OnRecordStop += OnRecordStop;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start is called before the first frame update
|
// Start is called before the first frame update
|
||||||
void Start()
|
async void Start()
|
||||||
{
|
{
|
||||||
meshRenderer = GetComponent<MeshRenderer>();
|
meshRenderer = GetComponent<MeshRenderer>();
|
||||||
|
|
||||||
|
// This causes about 1 sec long freeze, has to be done once at the start of the game
|
||||||
|
microphoneRecord.StartRecord();
|
||||||
|
|
||||||
|
stream = await whisper.CreateStream(microphoneRecord);
|
||||||
|
stream.OnResultUpdated += OnWhisperResult;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update is called once per frame
|
// Update is called once per frame
|
||||||
@@ -53,17 +46,12 @@ public class VoiceTranscriptionBox : MonoBehaviour
|
|||||||
|
|
||||||
void OnTriggerEnter(Collider other)
|
void OnTriggerEnter(Collider other)
|
||||||
{
|
{
|
||||||
if (isLoading)
|
|
||||||
{
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
KbmController controller = other.GetComponent<KbmController>();
|
KbmController controller = other.GetComponent<KbmController>();
|
||||||
XROrigin playerOrigin = other.GetComponent<XROrigin>();
|
XROrigin playerOrigin = other.GetComponent<XROrigin>();
|
||||||
if (controller != null || playerOrigin != null)
|
if (controller != null || playerOrigin != null)
|
||||||
{
|
{
|
||||||
meshRenderer.material = activeMaterial;
|
meshRenderer.material = activeMaterial;
|
||||||
microphoneRecord.StartRecord();
|
stream.StartStream();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -73,40 +61,20 @@ public class VoiceTranscriptionBox : MonoBehaviour
|
|||||||
XROrigin playerOrigin = other.GetComponent<XROrigin>();
|
XROrigin playerOrigin = other.GetComponent<XROrigin>();
|
||||||
if (controller != null | playerOrigin != null)
|
if (controller != null | playerOrigin != null)
|
||||||
{
|
{
|
||||||
microphoneRecord.StopRecord();
|
stream.StopStream();
|
||||||
meshRenderer.material = loadingMaterial;
|
|
||||||
isLoading = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
private async void OnRecordStop(AudioChunk recordedAudio)
|
|
||||||
{
|
|
||||||
_buffer = "";
|
|
||||||
|
|
||||||
var sw = new Stopwatch();
|
|
||||||
sw.Start();
|
|
||||||
|
|
||||||
var res = await whisper.GetTextAsync(recordedAudio.Data, recordedAudio.Frequency, recordedAudio.Channels);
|
|
||||||
if (res == null)
|
|
||||||
return;
|
|
||||||
|
|
||||||
var time = sw.ElapsedMilliseconds;
|
|
||||||
var rate = recordedAudio.Length / (time * 0.001f);
|
|
||||||
UnityEngine.Debug.Log($"Time: {time} ms\nRate: {rate:F1}x");
|
|
||||||
|
|
||||||
var text = res.Result;
|
|
||||||
|
|
||||||
lastTextOutput = text;
|
|
||||||
outputText.text = text;
|
|
||||||
|
|
||||||
meshRenderer.material = inactiveMaterial;
|
meshRenderer.material = inactiveMaterial;
|
||||||
isLoading = false;
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void OnNewSegment(WhisperSegment segment)
|
private void OnWhisperResult(string result)
|
||||||
{
|
{
|
||||||
_buffer += segment.Text;
|
lastTextOutput = result;
|
||||||
UnityEngine.Debug.Log(_buffer + "...");
|
outputText.text = result;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void OnDestroy()
|
||||||
|
{
|
||||||
|
microphoneRecord.StopRecord();
|
||||||
|
Destroy(gameObject);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user