forked from cgvr/DeltaVR
use Whisper in streaming mode
This commit is contained in:
@@ -34,6 +34,53 @@ def get_client() -> httpx.AsyncClient:
|
||||
return http_client
|
||||
|
||||
|
||||
async def list_models(model_type: str = "main") -> list:
|
||||
"""List available models."""
|
||||
client = get_client()
|
||||
response = await client.get("/api/v2/models/", params={"model_type": model_type})
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("models", [])
|
||||
|
||||
|
||||
async def get_model_info(model_key: str) -> Optional[dict]:
|
||||
"""Get information about a specific model."""
|
||||
client = get_client()
|
||||
try:
|
||||
response = await client.get(f"/api/v2/models/i/{model_key}")
|
||||
response.raise_for_status()
|
||||
model_data = response.json()
|
||||
|
||||
# Ensure we have a valid dictionary
|
||||
if not isinstance(model_data, dict):
|
||||
logger.error(f"Model info for {model_key} is not a dictionary: {type(model_data)}")
|
||||
return None
|
||||
|
||||
return model_data
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching model info for {model_key}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def download_file(url, filepath):
|
||||
response = requests.get(url)
|
||||
|
||||
if response.status_code == 200:
|
||||
with open(filepath, "wb") as file:
|
||||
file.write(response.content)
|
||||
else:
|
||||
raise RuntimeError(f"Failed to download image. Status code: {response.status_code}")
|
||||
|
||||
|
||||
async def get_image_url(image_name: str) -> str:
|
||||
"""Get the URL for an image."""
|
||||
client = get_client()
|
||||
response = await client.get(f"/api/v1/images/i/{image_name}/urls")
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("image_url", "")
|
||||
|
||||
|
||||
async def wait_for_completion(batch_id: str, queue_id: str = DEFAULT_QUEUE_ID, timeout: int = 300) -> dict:
|
||||
"""Wait for a batch to complete and return the most recent image."""
|
||||
client = get_client()
|
||||
@@ -94,25 +141,6 @@ async def wait_for_completion(batch_id: str, queue_id: str = DEFAULT_QUEUE_ID, t
|
||||
|
||||
# Wait before checking again
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
|
||||
def download_file(url, filepath):
|
||||
response = requests.get(url)
|
||||
|
||||
if response.status_code == 200:
|
||||
with open(filepath, "wb") as file:
|
||||
file.write(response.content)
|
||||
else:
|
||||
raise RuntimeError(f"Failed to download image. Status code: {response.status_code}")
|
||||
|
||||
async def get_image_url(image_name: str) -> str:
|
||||
"""Get the URL for an image."""
|
||||
client = get_client()
|
||||
response = await client.get(f"/api/v1/images/i/{image_name}/urls")
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("image_url", "")
|
||||
|
||||
|
||||
async def enqueue_graph(graph: dict, queue_id: str = DEFAULT_QUEUE_ID) -> dict:
|
||||
@@ -134,31 +162,7 @@ async def enqueue_graph(graph: dict, queue_id: str = DEFAULT_QUEUE_ID) -> dict:
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def list_models(model_type: str = "main") -> list:
|
||||
"""List available models."""
|
||||
client = get_client()
|
||||
response = await client.get("/api/v2/models/", params={"model_type": model_type})
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("models", [])
|
||||
|
||||
async def get_model_info(model_key: str) -> Optional[dict]:
|
||||
"""Get information about a specific model."""
|
||||
client = get_client()
|
||||
try:
|
||||
response = await client.get(f"/api/v2/models/i/{model_key}")
|
||||
response.raise_for_status()
|
||||
model_data = response.json()
|
||||
|
||||
# Ensure we have a valid dictionary
|
||||
if not isinstance(model_data, dict):
|
||||
logger.error(f"Model info for {model_key} is not a dictionary: {type(model_data)}")
|
||||
return None
|
||||
|
||||
return model_data
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching model info for {model_key}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def create_text2img_graph(
|
||||
@@ -423,6 +427,7 @@ async def create_text2img_graph(
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
async def generate_image(arguments: dict):
|
||||
|
||||
# Extract parameters
|
||||
@@ -488,5 +493,6 @@ async def text_to_image_invoke_ai(prompt, output_path):
|
||||
"model_key": INVOKEAI_MODEL_KEY
|
||||
}
|
||||
image_url = await generate_image(args)
|
||||
print("got image url: ", image_url)
|
||||
print("Got image url:", image_url)
|
||||
print("Downloading image file...")
|
||||
download_file(image_url, output_path)
|
||||
|
||||
Reference in New Issue
Block a user