DeltaVR/3d-generation-pipeline/start_pipeline.py
2025-11-14 09:53:30 +02:00

74 lines
2.3 KiB
Python

import os
import argparse
import asyncio
import logging
import time
from pathlib import Path
from datetime import datetime
from dotenv import load_dotenv
from cloudflare_api import text_to_image_cloudflare, refine_text_prompt
from generate_image_local import text_to_image_invoke_ai
from generate_model_local import image_to_3d_api, image_to_3d_subprocess
# Load variables from a .env file into the process environment. This must run
# before any os.environ lookup below (and before main() reads REFINE_PROMPT).
load_dotenv()
# Root folder under which this script places its "logs", "images" and "models"
# paths. A missing PIPELINE_FOLDER variable raises KeyError at import time.
PIPELINE_FOLDER = os.environ["PIPELINE_FOLDER"]
def get_timestamp():
    """Return the current local time formatted as ``YYYY-MM-DD-HH-MM-SS``.

    The result contains only digits and dashes; this script uses it as the
    stem of the log, image and model paths of one pipeline run.
    """
    current_moment = datetime.now()
    return current_moment.strftime("%Y-%m-%d-%H-%M-%S")
def setup_logger(base_folder, timestamp):
    """Route root-logger output to ``<base_folder>/logs/<timestamp>.log``.

    Args:
        base_folder: Folder that will contain the ``logs`` directory; the
            directory (and any missing parents) is created if needed.
        timestamp: String used as the log file's stem.
    """
    log_dir = base_folder / Path("logs")
    log_dir.mkdir(parents=True, exist_ok=True)
    log_file = log_dir / f"{timestamp}.log"

    # No ``format`` is passed, so logging's default record format is used.
    # force=True replaces any handlers already attached to the root logger.
    logging.basicConfig(filename=log_file, level=logging.INFO, force=True)
async def main():
    """Run the text -> image -> 3D model pipeline for one command-line prompt.

    Steps:
      1. Parse the required ``--prompt`` argument.
      2. If the ``REFINE_PROMPT`` environment variable equals ``"1"``, pass the
         prompt through ``refine_text_prompt``; otherwise use it unchanged.
      3. Generate an image at ``<PIPELINE_FOLDER>/images/<timestamp>.jpg`` with
         ``text_to_image_invoke_ai``.
      4. Pass that image and ``<PIPELINE_FOLDER>/models/<timestamp>`` to
         ``image_to_3d_api`` and print the value it returns.

    The duration of steps 3 and 4 is written to the per-run log file
    configured by ``setup_logger``; progress messages go to stdout.
    """
    parser = argparse.ArgumentParser(description="Text to 3D model pipeline")
    parser.add_argument("--prompt", type=str, required=True, help="User text prompt")
    args = parser.parse_args()

    input_prompt = args.prompt
    print(f"Input prompt: {input_prompt}")

    # Prompt refinement is opt-in: an unset REFINE_PROMPT means "off" instead
    # of aborting the whole run with a KeyError.
    refine_prompt = os.environ.get("REFINE_PROMPT", "0") == "1"
    if refine_prompt:
        image_generation_prompt = refine_text_prompt(input_prompt)
        print(f"Refined prompt: {image_generation_prompt}")
    else:
        image_generation_prompt = input_prompt

    pipeline_folder = Path(PIPELINE_FOLDER)
    timestamp = get_timestamp()
    setup_logger(pipeline_folder, timestamp)

    # perf_counter() is monotonic, so the measured durations cannot be skewed
    # by system clock adjustments the way time.time() differences can.
    stage_start = time.perf_counter()
    image_path = pipeline_folder / "images" / f"{timestamp}.jpg"
    # TODO: use Invoke AI or Cloudflare (text_to_image_cloudflare), depending on env var
    await text_to_image_invoke_ai(image_generation_prompt, image_path)
    image_generation_time = time.perf_counter() - stage_start
    logging.info("Image generation time: %s s", round(image_generation_time, 1))
    print(f"Generated image file: {image_path}")

    stage_start = time.perf_counter()
    model_path = pipeline_folder / "models" / timestamp
    model_file = image_to_3d_api(image_path, model_path)
    model_generation_time = time.perf_counter() - stage_start
    logging.info("Model generation time: %s s", round(model_generation_time, 1))
    print(f"Generated 3D model file: {model_file}")
if __name__ == "__main__":
    # Start the pipeline only when this file is executed as a script, not
    # when it is imported; asyncio.run drives the main() coroutine to completion.
    asyncio.run(main())