From 9828c0b88553b26e36219cc2097653ad3f3ac471 Mon Sep 17 00:00:00 2001 From: Prajna1999 Date: Wed, 6 May 2026 20:58:36 +0530 Subject: [PATCH 1/3] feat: ad hoc API for b64 cleanup --- backend/app/api/routes/private.py | 96 +++++++++++++++++++++++++++++++ backend/app/core/storage_utils.py | 43 +++++++++++++- 2 files changed, 138 insertions(+), 1 deletion(-) diff --git a/backend/app/api/routes/private.py b/backend/app/api/routes/private.py index 6100829c2..14890d787 100644 --- a/backend/app/api/routes/private.py +++ b/backend/app/api/routes/private.py @@ -1,15 +1,24 @@ +import base64 +import logging from typing import Any from fastapi import APIRouter from pydantic import BaseModel +from sqlmodel import col, select from app.api.deps import SessionDep +from app.core.cloud.storage import get_cloud_storage from app.core.security import get_password_hash +from app.core.storage_utils import upload_audio_bytes_to_s3 +from app.core.util import now from app.models import ( + LlmCall, User, UserPublic, ) +logger = logging.getLogger(__name__) + router = APIRouter(tags=["private"], prefix="/private") @@ -20,6 +29,93 @@ class PrivateUserCreate(BaseModel): is_verified: bool = False +@router.post("/migrate/tts-base64-to-s3", include_in_schema=False) +def migrate_tts_base64_to_s3(session: SessionDep) -> dict: + """ + One-shot migration: find all llm_call rows with input_type=text / output_type=audio + whose content still holds raw base64, upload the audio to S3, and replace with a URI. + """ + processed = skipped = failed = 0 + errors: list[dict] = [] + + # Storage instances are cached per project_id to avoid redundant DB lookups. + storage_cache: dict[int, Any] = {} + + statement = ( + select(LlmCall) + .where( + LlmCall.input_type == "text", + LlmCall.output_type == "audio", + col(LlmCall.deleted_at).is_(None), + ) + .order_by(col(LlmCall.created_at).desc()) + .execution_options(yield_per=100) + ) + + for call in session.exec(statement): + content = call.content + if not content: + skipped += 1 + continue + + audio_content = content.get("content", {}) + if audio_content.get("format") != "base64": + skipped += 1 + continue + + b64_value = audio_content.get("value") + if not b64_value: + skipped += 1 + continue + + try: + if call.project_id not in storage_cache: + storage_cache[call.project_id] = get_cloud_storage( + session, call.project_id + ) + storage = storage_cache[call.project_id] + + audio_bytes = base64.b64decode(b64_value) + s3_url = upload_audio_bytes_to_s3( + storage, + audio_bytes, + call.id, + audio_content.get("mime_type"), + "llm/tts/audio", + ) + + if not s3_url: + raise RuntimeError("upload returned None") + + call.content = { + "type": "audio", + "content": { + "format": "uri", + "value": s3_url, + "mime_type": audio_content.get("mime_type"), + }, + } + call.updated_at = now() + session.add(call) + processed += 1 + + except Exception as e: + failed += 1 + errors.append({"call_id": str(call.id), "error": str(e)}) + logger.warning( + f"[migrate_tts_base64_to_s3] Failed | call_id={call.id}, error={e}" + ) + + session.commit() + + return { + "processed": processed, + "skipped": skipped, + "failed": failed, + "errors": errors[:50], + } + + @router.post("/users", response_model=UserPublic, include_in_schema=False) def create_user(user_in: PrivateUserCreate, session: SessionDep) -> Any: """ diff --git a/backend/app/core/storage_utils.py b/backend/app/core/storage_utils.py index 155627c13..dcf4fe02d 100644 --- a/backend/app/core/storage_utils.py +++ b/backend/app/core/storage_utils.py @@ -11,12 +11,13 @@ from datetime import datetime from io import BytesIO from pathlib import Path +from typing import Literal from urllib.parse import unquote, urlparse +from uuid import UUID from starlette.datastructures import Headers, UploadFile from app.core.cloud.storage import CloudStorage, CloudStorageError -from typing import Literal logger = logging.getLogger(__name__) @@ -207,6 +208,46 @@ def load_json_from_object_store(storage: CloudStorage, url: str) -> list | dict return None +_MIME_TO_EXT: dict[str, str] = { + "audio/mpeg": "mp3", + "audio/mp3": "mp3", + "audio/ogg": "ogg", + "audio/wav": "wav", + "audio/wave": "wav", + "audio/x-wav": "wav", + "audio/webm": "webm", + "audio/mp4": "mp4", + "audio/aac": "aac", + "audio/flac": "flac", +} + + +def upload_audio_bytes_to_s3( + storage: CloudStorage, + audio_bytes: bytes, + call_id: UUID, + mime_type: str | None, + prefix: str, +) -> str | None: + """Upload decoded audio bytes to S3 and return the s3:// URI. + + Args: + storage: CloudStorage instance + audio_bytes: Raw audio bytes + call_id: LLM call UUID used as the filename stem + mime_type: MIME type of the audio (determines file extension) + prefix: S3 subdirectory, e.g. "llm/tts/audio" or "llm/stt/audio" + + Returns: + s3:// URI if successful, None on failure + """ + ext = _MIME_TO_EXT.get(mime_type or "", "wav") + filename = f"{call_id}.{ext}" + return upload_to_object_store( + storage, audio_bytes, filename, prefix, mime_type or "audio/wav" + ) + + def generate_timestamped_filename(base_name: str, extension: str = "csv") -> str: """ Generate a filename with timestamp. From 8f6de7f82c7193b1e25de64146afc42af5cb64b7 Mon Sep 17 00:00:00 2001 From: Prajna1999 Date: Wed, 13 May 2026 16:59:19 +0530 Subject: [PATCH 2/3] chore: better loggers and progress tracking for ad-hoc script --- backend/app/api/routes/private.py | 117 ++++++- backend/app/tests/api/routes/test_private.py | 346 ++++++++++++++++++- 2 files changed, 454 insertions(+), 9 deletions(-) diff --git a/backend/app/api/routes/private.py b/backend/app/api/routes/private.py index 14890d787..ac441bf94 100644 --- a/backend/app/api/routes/private.py +++ b/backend/app/api/routes/private.py @@ -1,10 +1,11 @@ import base64 import logging +import time from typing import Any from fastapi import APIRouter from pydantic import BaseModel -from sqlmodel import col, select +from sqlmodel import col, func, select from app.api.deps import SessionDep from app.core.cloud.storage import get_cloud_storage @@ -29,18 +30,42 @@ class PrivateUserCreate(BaseModel): is_verified: bool = False +MIGRATION_BATCH_SIZE = 50 +MIGRATION_LOG_INTERVAL = 100 + + @router.post("/migrate/tts-base64-to-s3", include_in_schema=False) def migrate_tts_base64_to_s3(session: SessionDep) -> dict: """ One-shot migration: find all llm_call rows with input_type=text / output_type=audio whose content still holds raw base64, upload the audio to S3, and replace with a URI. + + Commits in batches so that partial progress is preserved on failure. """ + fn = "migrate_tts_base64_to_s3" + start_time = time.monotonic() + processed = skipped = failed = 0 + committed = 0 + pending_in_batch = 0 errors: list[dict] = [] # Storage instances are cached per project_id to avoid redundant DB lookups. storage_cache: dict[int, Any] = {} + # --- count total candidates for progress logging --- + count_stmt = ( + select(func.count()) + .select_from(LlmCall) + .where( + LlmCall.input_type == "text", + LlmCall.output_type == "audio", + col(LlmCall.deleted_at).is_(None), + ) + ) + total_candidates = session.exec(count_stmt).one() + logger.info(f"[{fn}] Starting migration | total_candidates={total_candidates}") + statement = ( select(LlmCall) .where( @@ -52,7 +77,7 @@ def migrate_tts_base64_to_s3(session: SessionDep) -> dict: .execution_options(yield_per=100) ) - for call in session.exec(statement): + for idx, call in enumerate(session.exec(statement), start=1): content = call.content if not content: skipped += 1 @@ -76,16 +101,20 @@ def migrate_tts_base64_to_s3(session: SessionDep) -> dict: storage = storage_cache[call.project_id] audio_bytes = base64.b64decode(b64_value) + b64_size_kb = len(b64_value) / 1024 + audio_size_kb = len(audio_bytes) / 1024 + + prefix = f"orgs/{call.organization_id}/{call.project_id}/audio/tts" s3_url = upload_audio_bytes_to_s3( storage, audio_bytes, call.id, audio_content.get("mime_type"), - "llm/tts/audio", + prefix, ) if not s3_url: - raise RuntimeError("upload returned None") + raise RuntimeError("upload_audio_bytes_to_s3 returned None") call.content = { "type": "audio", @@ -98,22 +127,94 @@ def migrate_tts_base64_to_s3(session: SessionDep) -> dict: call.updated_at = now() session.add(call) processed += 1 + pending_in_batch += 1 + + logger.debug( + f"[{fn}] Uploaded | call_id={call.id}, " + f"project_id={call.project_id}, " + f"b64_kb={b64_size_kb:.1f}, audio_kb={audio_size_kb:.1f}, " + f"s3_url={s3_url}" + ) except Exception as e: failed += 1 - errors.append({"call_id": str(call.id), "error": str(e)}) + errors.append( + { + "call_id": str(call.id), + "project_id": str(call.project_id), + "error": str(e), + } + ) logger.warning( - f"[migrate_tts_base64_to_s3] Failed | call_id={call.id}, error={e}" + f"[{fn}] Row failed | call_id={call.id}, " + f"project_id={call.project_id}, error={e}", + exc_info=True, + ) + # Expunge the dirty object so the failed row doesn't poison the batch + session.expunge(call) + + # --- batch commit for partial progress --- + if pending_in_batch >= MIGRATION_BATCH_SIZE: + try: + session.commit() + committed += pending_in_batch + logger.info( + f"[{fn}] Batch committed | " + f"batch_size={pending_in_batch}, total_committed={committed}" + ) + except Exception as e: + logger.error( + f"[{fn}] Batch commit failed, rolling back | " + f"pending={pending_in_batch}, error={e}", + exc_info=True, + ) + session.rollback() + failed += pending_in_batch + processed -= pending_in_batch + pending_in_batch = 0 + + # --- periodic progress log --- + if idx % MIGRATION_LOG_INTERVAL == 0: + elapsed = time.monotonic() - start_time + logger.info( + f"[{fn}] Progress | " + f"scanned={idx}/{total_candidates}, " + f"processed={processed}, skipped={skipped}, failed={failed}, " + f"elapsed={elapsed:.1f}s" ) - session.commit() + # --- final batch --- + if pending_in_batch > 0: + try: + session.commit() + committed += pending_in_batch + logger.info( + f"[{fn}] Final batch committed | " + f"batch_size={pending_in_batch}, total_committed={committed}" + ) + except Exception as e: + logger.error( + f"[{fn}] Final batch commit failed, rolling back | " + f"pending={pending_in_batch}, error={e}", + exc_info=True, + ) + session.rollback() + failed += pending_in_batch + processed -= pending_in_batch - return { + elapsed = time.monotonic() - start_time + summary = { "processed": processed, + "committed": committed, "skipped": skipped, "failed": failed, + "total_candidates": total_candidates, + "elapsed_seconds": round(elapsed, 2), "errors": errors[:50], } + logger.info(f"[{fn}] Migration complete | {summary}") + + return summary @router.post("/users", response_model=UserPublic, include_in_schema=False) diff --git a/backend/app/tests/api/routes/test_private.py b/backend/app/tests/api/routes/test_private.py index 1b5a3794c..ef2caa493 100644 --- a/backend/app/tests/api/routes/test_private.py +++ b/backend/app/tests/api/routes/test_private.py @@ -1,8 +1,106 @@ +import base64 +from unittest.mock import MagicMock, patch +from uuid import uuid4 + from fastapi.testclient import TestClient from sqlmodel import Session, select from app.core.config import settings -from app.models import User +from app.crud import JobCrud +from app.crud.llm import create_llm_call, update_llm_call_response +from app.models import JobType, LlmCall, User +from app.models.llm.request import ( + ConfigBlob, + KaapiCompletionConfig, + LLMCallConfig, + QueryParams, +) +from app.models.llm import LLMCallRequest +from app.tests.utils.auth import get_user_test_auth_context + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +FAKE_B64 = base64.b64encode(b"\x00\x01\x02\x03audio-bytes").decode() + +TTS_CONFIG = ConfigBlob( + completion=KaapiCompletionConfig( + provider="openai", + params={"model": "gpt-4o-mini-tts", "temperature": 0.7}, + type="tts", + ) +) + + +def _make_tts_call( + db: Session, + *, + project_id: int, + organization_id: int, + content: dict | None = None, +) -> LlmCall: + """Create an LlmCall with input_type=text, output_type=audio.""" + job = JobCrud(db).create( + job_type=JobType.LLM_API, + trace_id=f"test-tts-{uuid4().hex[:8]}", + project_id=project_id, + ) + call = create_llm_call( + db, + request=LLMCallRequest( + query=QueryParams(input="Say hello"), + config=LLMCallConfig(blob=TTS_CONFIG), + ), + job_id=job.id, + project_id=project_id, + organization_id=organization_id, + resolved_config=TTS_CONFIG, + original_provider="openai", + ) + if content is not None: + update_llm_call_response( + db, + llm_call_id=call.id, + provider_response_id=f"resp_{uuid4().hex[:8]}", + content=content, + ) + db.refresh(call) + return call + + +def _base64_content(mime_type: str = "audio/mp3") -> dict: + return { + "type": "audio", + "content": { + "format": "base64", + "value": FAKE_B64, + "mime_type": mime_type, + }, + } + + +def _uri_content() -> dict: + """Content that has already been migrated.""" + return { + "type": "audio", + "content": { + "format": "uri", + "value": "s3://bucket/audio/existing.mp3", + "mime_type": "audio/mp3", + }, + } + + +MIGRATE_URL = f"{settings.API_V1_STR}/private/migrate/tts-base64-to-s3" +UPLOAD_PATH = "app.api.routes.private.upload_audio_bytes_to_s3" +STORAGE_PATH = "app.api.routes.private.get_cloud_storage" + + +# --------------------------------------------------------------------------- +# Existing user test +# --------------------------------------------------------------------------- def test_create_user(client: TestClient, db: Session) -> None: @@ -24,3 +122,249 @@ def test_create_user(client: TestClient, db: Session) -> None: assert user assert user.email == "pollo@listo.com" assert user.full_name == "Pollo Listo" + + +# --------------------------------------------------------------------------- +# Migration tests +# --------------------------------------------------------------------------- + + +@patch(STORAGE_PATH, return_value=MagicMock()) +@patch(UPLOAD_PATH, return_value="s3://bucket/orgs/1/1/audio/tts/migrated.mp3") +def test_migrate_processes_base64_rows( + mock_upload: MagicMock, + mock_storage: MagicMock, + client: TestClient, + db: Session, +) -> None: + """Rows with base64 content are uploaded and rewritten to URI format.""" + auth = get_user_test_auth_context(db) + call = _make_tts_call( + db, + project_id=auth.project_id, + organization_id=auth.organization_id, + content=_base64_content(), + ) + + r = client.post(MIGRATE_URL) + assert r.status_code == 200 + + data = r.json() + assert data["processed"] >= 1 + assert data["failed"] == 0 + + db.refresh(call) + assert call.content["content"]["format"] == "uri" + assert call.content["content"]["value"].startswith("s3://") + mock_upload.assert_called() + + +@patch(STORAGE_PATH, return_value=MagicMock()) +@patch(UPLOAD_PATH, return_value="s3://bucket/audio.mp3") +def test_migrate_skips_already_migrated_rows( + mock_upload: MagicMock, + mock_storage: MagicMock, + client: TestClient, + db: Session, +) -> None: + """Rows already in URI format are skipped, not re-uploaded.""" + auth = get_user_test_auth_context(db) + _make_tts_call( + db, + project_id=auth.project_id, + organization_id=auth.organization_id, + content=_uri_content(), + ) + + r = client.post(MIGRATE_URL) + assert r.status_code == 200 + + data = r.json() + assert data["skipped"] >= 1 + # upload should not be called for already-migrated rows + mock_upload.assert_not_called() + + +@patch(STORAGE_PATH, return_value=MagicMock()) +@patch(UPLOAD_PATH, return_value="s3://bucket/audio.mp3") +def test_migrate_skips_rows_with_no_content( + mock_upload: MagicMock, + mock_storage: MagicMock, + client: TestClient, + db: Session, +) -> None: + """Rows with NULL content are skipped.""" + auth = get_user_test_auth_context(db) + _make_tts_call( + db, + project_id=auth.project_id, + organization_id=auth.organization_id, + content=None, + ) + + r = client.post(MIGRATE_URL) + assert r.status_code == 200 + + data = r.json() + assert data["skipped"] >= 1 + mock_upload.assert_not_called() + + +@patch(STORAGE_PATH, return_value=MagicMock()) +@patch(UPLOAD_PATH, return_value=None) +def test_migrate_records_failure_when_upload_returns_none( + mock_upload: MagicMock, + mock_storage: MagicMock, + client: TestClient, + db: Session, +) -> None: + """When upload_audio_bytes_to_s3 returns None, the row is counted as failed.""" + auth = get_user_test_auth_context(db) + call = _make_tts_call( + db, + project_id=auth.project_id, + organization_id=auth.organization_id, + content=_base64_content(), + ) + + r = client.post(MIGRATE_URL) + assert r.status_code == 200 + + data = r.json() + assert data["failed"] >= 1 + assert any(e["call_id"] == str(call.id) for e in data["errors"]) + + # Original content should remain unchanged + db.refresh(call) + assert call.content["content"]["format"] == "base64" + + +@patch(STORAGE_PATH, return_value=MagicMock()) +@patch(UPLOAD_PATH, side_effect=RuntimeError("S3 connection timeout")) +def test_migrate_records_failure_on_upload_exception( + mock_upload: MagicMock, + mock_storage: MagicMock, + client: TestClient, + db: Session, +) -> None: + """An exception during upload is caught, logged, and reported in errors.""" + auth = get_user_test_auth_context(db) + call = _make_tts_call( + db, + project_id=auth.project_id, + organization_id=auth.organization_id, + content=_base64_content(), + ) + + r = client.post(MIGRATE_URL) + assert r.status_code == 200 + + data = r.json() + assert data["failed"] >= 1 + error_entry = next(e for e in data["errors"] if e["call_id"] == str(call.id)) + assert "S3 connection timeout" in error_entry["error"] + + # Original content should remain unchanged + db.refresh(call) + assert call.content["content"]["format"] == "base64" + + +@patch(STORAGE_PATH, return_value=MagicMock()) +@patch(UPLOAD_PATH, return_value="s3://bucket/audio.mp3") +def test_migrate_uses_correct_s3_prefix( + mock_upload: MagicMock, + mock_storage: MagicMock, + client: TestClient, + db: Session, +) -> None: + """The S3 prefix follows orgs/{org_id}/{project_id}/audio/tts.""" + auth = get_user_test_auth_context(db) + call = _make_tts_call( + db, + project_id=auth.project_id, + organization_id=auth.organization_id, + content=_base64_content(), + ) + + r = client.post(MIGRATE_URL) + assert r.status_code == 200 + + # Verify the prefix passed to upload_audio_bytes_to_s3 + _, kwargs = mock_upload.call_args + # positional args: (storage, audio_bytes, call_id, mime_type, prefix) + args = mock_upload.call_args[0] + expected_prefix = f"orgs/{auth.organization_id}/{auth.project_id}/audio/tts" + assert args[4] == expected_prefix + + +@patch(STORAGE_PATH, return_value=MagicMock()) +@patch(UPLOAD_PATH, return_value="s3://bucket/audio.mp3") +def test_migrate_preserves_mime_type( + mock_upload: MagicMock, + mock_storage: MagicMock, + client: TestClient, + db: Session, +) -> None: + """The migrated content retains the original mime_type.""" + auth = get_user_test_auth_context(db) + call = _make_tts_call( + db, + project_id=auth.project_id, + organization_id=auth.organization_id, + content=_base64_content(mime_type="audio/wav"), + ) + + r = client.post(MIGRATE_URL) + assert r.status_code == 200 + + db.refresh(call) + assert call.content["content"]["mime_type"] == "audio/wav" + + +@patch(STORAGE_PATH, return_value=MagicMock()) +@patch(UPLOAD_PATH, return_value="s3://bucket/audio.mp3") +def test_migrate_returns_summary_fields( + mock_upload: MagicMock, + mock_storage: MagicMock, + client: TestClient, + db: Session, +) -> None: + """The response includes all expected summary fields.""" + r = client.post(MIGRATE_URL) + assert r.status_code == 200 + + data = r.json() + for key in [ + "processed", + "committed", + "skipped", + "failed", + "total_candidates", + "elapsed_seconds", + "errors", + ]: + assert key in data, f"Missing key: {key}" + + assert isinstance(data["elapsed_seconds"], (int, float)) + assert isinstance(data["errors"], list) + assert data["total_candidates"] >= 0 + + +@patch(STORAGE_PATH, return_value=MagicMock()) +@patch(UPLOAD_PATH, return_value="s3://bucket/audio.mp3") +def test_migrate_no_candidates( + mock_upload: MagicMock, + mock_storage: MagicMock, + client: TestClient, + db: Session, +) -> None: + """When there are no matching rows, migration completes with all zeros.""" + # Don't create any TTS LlmCall rows — the endpoint should still succeed + r = client.post(MIGRATE_URL) + assert r.status_code == 200 + + data = r.json() + assert data["processed"] == 0 + assert data["failed"] == 0 + assert data["committed"] == 0 + mock_upload.assert_not_called() From c4e694f039b8ac0e68438df20044fe996bc420da Mon Sep 17 00:00:00 2001 From: Prajna1999 Date: Wed, 13 May 2026 17:51:30 +0530 Subject: [PATCH 3/3] chore: better test cases --- backend/app/tests/api/routes/test_private.py | 484 ++++++++----------- 1 file changed, 206 insertions(+), 278 deletions(-) diff --git a/backend/app/tests/api/routes/test_private.py b/backend/app/tests/api/routes/test_private.py index ef2caa493..abf0ca7bc 100644 --- a/backend/app/tests/api/routes/test_private.py +++ b/backend/app/tests/api/routes/test_private.py @@ -1,339 +1,246 @@ import base64 -from unittest.mock import MagicMock, patch +from types import SimpleNamespace +from unittest.mock import MagicMock, patch, call from uuid import uuid4 from fastapi.testclient import TestClient from sqlmodel import Session, select from app.core.config import settings -from app.crud import JobCrud -from app.crud.llm import create_llm_call, update_llm_call_response -from app.models import JobType, LlmCall, User -from app.models.llm.request import ( - ConfigBlob, - KaapiCompletionConfig, - LLMCallConfig, - QueryParams, -) -from app.models.llm import LLMCallRequest -from app.tests.utils.auth import get_user_test_auth_context +from app.models import User # --------------------------------------------------------------------------- -# Helpers +# Existing user test (unchanged) # --------------------------------------------------------------------------- -FAKE_B64 = base64.b64encode(b"\x00\x01\x02\x03audio-bytes").decode() -TTS_CONFIG = ConfigBlob( - completion=KaapiCompletionConfig( - provider="openai", - params={"model": "gpt-4o-mini-tts", "temperature": 0.7}, - type="tts", +def test_create_user(client: TestClient, db: Session) -> None: + r = client.post( + f"{settings.API_V1_STR}/private/users", + json={ + "email": "pollo@listo.com", + "password": "password123", + "full_name": "Pollo Listo", + }, ) -) + assert r.status_code == 200 + + data = r.json() + + user = db.exec(select(User).where(User.id == data["id"])).first() -def _make_tts_call( - db: Session, - *, - project_id: int, - organization_id: int, + assert user + assert user.email == "pollo@listo.com" + assert user.full_name == "Pollo Listo" + + +# --------------------------------------------------------------------------- +# Unit tests for migrate_tts_base64_to_s3 +# --------------------------------------------------------------------------- + +MODULE = "app.api.routes.private" +FAKE_AUDIO = b"\x00\x01\x02\x03audio-bytes" +FAKE_B64 = base64.b64encode(FAKE_AUDIO).decode() + + +def _fake_call( content: dict | None = None, -) -> LlmCall: - """Create an LlmCall with input_type=text, output_type=audio.""" - job = JobCrud(db).create( - job_type=JobType.LLM_API, - trace_id=f"test-tts-{uuid4().hex[:8]}", - project_id=project_id, - ) - call = create_llm_call( - db, - request=LLMCallRequest( - query=QueryParams(input="Say hello"), - config=LLMCallConfig(blob=TTS_CONFIG), - ), - job_id=job.id, + project_id: int = 1, + organization_id: int = 10, +) -> SimpleNamespace: + """Lightweight stand-in for an LlmCall row.""" + return SimpleNamespace( + id=uuid4(), project_id=project_id, organization_id=organization_id, - resolved_config=TTS_CONFIG, - original_provider="openai", + content=content, + updated_at=None, ) - if content is not None: - update_llm_call_response( - db, - llm_call_id=call.id, - provider_response_id=f"resp_{uuid4().hex[:8]}", - content=content, - ) - db.refresh(call) - return call - - -def _base64_content(mime_type: str = "audio/mp3") -> dict: + + +def _b64_content(mime_type: str = "audio/mp3") -> dict: return { "type": "audio", - "content": { - "format": "base64", - "value": FAKE_B64, - "mime_type": mime_type, - }, + "content": {"format": "base64", "value": FAKE_B64, "mime_type": mime_type}, } def _uri_content() -> dict: - """Content that has already been migrated.""" return { "type": "audio", "content": { "format": "uri", - "value": "s3://bucket/audio/existing.mp3", + "value": "s3://bucket/existing.mp3", "mime_type": "audio/mp3", }, } -MIGRATE_URL = f"{settings.API_V1_STR}/private/migrate/tts-base64-to-s3" -UPLOAD_PATH = "app.api.routes.private.upload_audio_bytes_to_s3" -STORAGE_PATH = "app.api.routes.private.get_cloud_storage" +def _mock_session(rows: list) -> MagicMock: + """Build a mock session whose .exec() returns count then rows.""" + session = MagicMock() + count_result = MagicMock() + count_result.one.return_value = len(rows) + # First exec call → count, second → row iterator + session.exec.side_effect = [count_result, iter(rows)] + return session -# --------------------------------------------------------------------------- -# Existing user test -# --------------------------------------------------------------------------- +@patch(f"{MODULE}.get_cloud_storage", return_value=MagicMock()) +@patch(f"{MODULE}.upload_audio_bytes_to_s3", return_value="s3://bucket/migrated.mp3") +def test_processes_base64_row(mock_upload: MagicMock, mock_storage: MagicMock) -> None: + """A row with base64 content is uploaded and rewritten to URI format.""" + from app.api.routes.private import migrate_tts_base64_to_s3 + row = _fake_call(content=_b64_content()) + session = _mock_session([row]) -def test_create_user(client: TestClient, db: Session) -> None: - r = client.post( - f"{settings.API_V1_STR}/private/users", - json={ - "email": "pollo@listo.com", - "password": "password123", - "full_name": "Pollo Listo", - }, - ) + result = migrate_tts_base64_to_s3(session) - assert r.status_code == 200 + assert result["processed"] == 1 + assert result["failed"] == 0 + assert row.content["content"]["format"] == "uri" + assert row.content["content"]["value"] == "s3://bucket/migrated.mp3" + session.add.assert_called_once_with(row) + mock_upload.assert_called_once() - data = r.json() - user = db.exec(select(User).where(User.id == data["id"])).first() +@patch(f"{MODULE}.get_cloud_storage", return_value=MagicMock()) +@patch(f"{MODULE}.upload_audio_bytes_to_s3", return_value="s3://bucket/migrated.mp3") +def test_skips_already_migrated_uri( + mock_upload: MagicMock, mock_storage: MagicMock +) -> None: + """Rows already in URI format are skipped.""" + from app.api.routes.private import migrate_tts_base64_to_s3 - assert user - assert user.email == "pollo@listo.com" - assert user.full_name == "Pollo Listo" + row = _fake_call(content=_uri_content()) + session = _mock_session([row]) + result = migrate_tts_base64_to_s3(session) -# --------------------------------------------------------------------------- -# Migration tests -# --------------------------------------------------------------------------- - + assert result["skipped"] == 1 + assert result["processed"] == 0 + mock_upload.assert_not_called() -@patch(STORAGE_PATH, return_value=MagicMock()) -@patch(UPLOAD_PATH, return_value="s3://bucket/orgs/1/1/audio/tts/migrated.mp3") -def test_migrate_processes_base64_rows( - mock_upload: MagicMock, - mock_storage: MagicMock, - client: TestClient, - db: Session, -) -> None: - """Rows with base64 content are uploaded and rewritten to URI format.""" - auth = get_user_test_auth_context(db) - call = _make_tts_call( - db, - project_id=auth.project_id, - organization_id=auth.organization_id, - content=_base64_content(), - ) - r = client.post(MIGRATE_URL) - assert r.status_code == 200 +@patch(f"{MODULE}.get_cloud_storage", return_value=MagicMock()) +@patch(f"{MODULE}.upload_audio_bytes_to_s3", return_value="s3://bucket/migrated.mp3") +def test_skips_null_content(mock_upload: MagicMock, mock_storage: MagicMock) -> None: + """Rows with None content are skipped.""" + from app.api.routes.private import migrate_tts_base64_to_s3 - data = r.json() - assert data["processed"] >= 1 - assert data["failed"] == 0 - - db.refresh(call) - assert call.content["content"]["format"] == "uri" - assert call.content["content"]["value"].startswith("s3://") - mock_upload.assert_called() - - -@patch(STORAGE_PATH, return_value=MagicMock()) -@patch(UPLOAD_PATH, return_value="s3://bucket/audio.mp3") -def test_migrate_skips_already_migrated_rows( - mock_upload: MagicMock, - mock_storage: MagicMock, - client: TestClient, - db: Session, -) -> None: - """Rows already in URI format are skipped, not re-uploaded.""" - auth = get_user_test_auth_context(db) - _make_tts_call( - db, - project_id=auth.project_id, - organization_id=auth.organization_id, - content=_uri_content(), - ) + row = _fake_call(content=None) + session = _mock_session([row]) - r = client.post(MIGRATE_URL) - assert r.status_code == 200 + result = migrate_tts_base64_to_s3(session) - data = r.json() - assert data["skipped"] >= 1 - # upload should not be called for already-migrated rows + assert result["skipped"] == 1 + assert result["processed"] == 0 mock_upload.assert_not_called() -@patch(STORAGE_PATH, return_value=MagicMock()) -@patch(UPLOAD_PATH, return_value="s3://bucket/audio.mp3") -def test_migrate_skips_rows_with_no_content( - mock_upload: MagicMock, - mock_storage: MagicMock, - client: TestClient, - db: Session, +@patch(f"{MODULE}.get_cloud_storage", return_value=MagicMock()) +@patch(f"{MODULE}.upload_audio_bytes_to_s3", return_value=None) +def test_fails_when_upload_returns_none( + mock_upload: MagicMock, mock_storage: MagicMock ) -> None: - """Rows with NULL content are skipped.""" - auth = get_user_test_auth_context(db) - _make_tts_call( - db, - project_id=auth.project_id, - organization_id=auth.organization_id, - content=None, - ) + """upload returning None is recorded as a failure; original content is unchanged.""" + from app.api.routes.private import migrate_tts_base64_to_s3 - r = client.post(MIGRATE_URL) - assert r.status_code == 200 + original_content = _b64_content() + row = _fake_call(content=original_content) + session = _mock_session([row]) - data = r.json() - assert data["skipped"] >= 1 - mock_upload.assert_not_called() + result = migrate_tts_base64_to_s3(session) + + assert result["failed"] == 1 + assert result["processed"] == 0 + assert any(e["call_id"] == str(row.id) for e in result["errors"]) + session.expunge.assert_called_once_with(row) -@patch(STORAGE_PATH, return_value=MagicMock()) -@patch(UPLOAD_PATH, return_value=None) -def test_migrate_records_failure_when_upload_returns_none( - mock_upload: MagicMock, - mock_storage: MagicMock, - client: TestClient, - db: Session, +@patch(f"{MODULE}.get_cloud_storage", return_value=MagicMock()) +@patch(f"{MODULE}.upload_audio_bytes_to_s3", side_effect=RuntimeError("S3 timeout")) +def test_fails_on_upload_exception( + mock_upload: MagicMock, mock_storage: MagicMock ) -> None: - """When upload_audio_bytes_to_s3 returns None, the row is counted as failed.""" - auth = get_user_test_auth_context(db) - call = _make_tts_call( - db, - project_id=auth.project_id, - organization_id=auth.organization_id, - content=_base64_content(), - ) + """An upload exception is caught and recorded in errors.""" + from app.api.routes.private import migrate_tts_base64_to_s3 - r = client.post(MIGRATE_URL) - assert r.status_code == 200 + row = _fake_call(content=_b64_content()) + session = _mock_session([row]) - data = r.json() - assert data["failed"] >= 1 - assert any(e["call_id"] == str(call.id) for e in data["errors"]) + result = migrate_tts_base64_to_s3(session) - # Original content should remain unchanged - db.refresh(call) - assert call.content["content"]["format"] == "base64" + assert result["failed"] == 1 + error = next(e for e in result["errors"] if e["call_id"] == str(row.id)) + assert "S3 timeout" in error["error"] + session.expunge.assert_called_once_with(row) -@patch(STORAGE_PATH, return_value=MagicMock()) -@patch(UPLOAD_PATH, side_effect=RuntimeError("S3 connection timeout")) -def test_migrate_records_failure_on_upload_exception( - mock_upload: MagicMock, - mock_storage: MagicMock, - client: TestClient, - db: Session, +@patch(f"{MODULE}.get_cloud_storage", return_value=MagicMock()) +@patch(f"{MODULE}.upload_audio_bytes_to_s3", return_value="s3://bucket/out.mp3") +def test_uses_correct_s3_prefix( + mock_upload: MagicMock, mock_storage: MagicMock ) -> None: - """An exception during upload is caught, logged, and reported in errors.""" - auth = get_user_test_auth_context(db) - call = _make_tts_call( - db, - project_id=auth.project_id, - organization_id=auth.organization_id, - content=_base64_content(), - ) - - r = client.post(MIGRATE_URL) - assert r.status_code == 200 + """The prefix follows orgs/{org_id}/{project_id}/audio/tts.""" + from app.api.routes.private import migrate_tts_base64_to_s3 - data = r.json() - assert data["failed"] >= 1 - error_entry = next(e for e in data["errors"] if e["call_id"] == str(call.id)) - assert "S3 connection timeout" in error_entry["error"] - - # Original content should remain unchanged - db.refresh(call) - assert call.content["content"]["format"] == "base64" - - -@patch(STORAGE_PATH, return_value=MagicMock()) -@patch(UPLOAD_PATH, return_value="s3://bucket/audio.mp3") -def test_migrate_uses_correct_s3_prefix( - mock_upload: MagicMock, - mock_storage: MagicMock, - client: TestClient, - db: Session, -) -> None: - """The S3 prefix follows orgs/{org_id}/{project_id}/audio/tts.""" - auth = get_user_test_auth_context(db) - call = _make_tts_call( - db, - project_id=auth.project_id, - organization_id=auth.organization_id, - content=_base64_content(), - ) + row = _fake_call(content=_b64_content(), project_id=42, organization_id=7) + session = _mock_session([row]) - r = client.post(MIGRATE_URL) - assert r.status_code == 200 + migrate_tts_base64_to_s3(session) - # Verify the prefix passed to upload_audio_bytes_to_s3 - _, kwargs = mock_upload.call_args - # positional args: (storage, audio_bytes, call_id, mime_type, prefix) args = mock_upload.call_args[0] - expected_prefix = f"orgs/{auth.organization_id}/{auth.project_id}/audio/tts" - assert args[4] == expected_prefix + assert args[4] == "orgs/7/42/audio/tts" -@patch(STORAGE_PATH, return_value=MagicMock()) -@patch(UPLOAD_PATH, return_value="s3://bucket/audio.mp3") -def test_migrate_preserves_mime_type( - mock_upload: MagicMock, - mock_storage: MagicMock, - client: TestClient, - db: Session, -) -> None: +@patch(f"{MODULE}.get_cloud_storage", return_value=MagicMock()) +@patch(f"{MODULE}.upload_audio_bytes_to_s3", return_value="s3://bucket/out.mp3") +def test_preserves_mime_type(mock_upload: MagicMock, mock_storage: MagicMock) -> None: """The migrated content retains the original mime_type.""" - auth = get_user_test_auth_context(db) - call = _make_tts_call( - db, - project_id=auth.project_id, - organization_id=auth.organization_id, - content=_base64_content(mime_type="audio/wav"), - ) + from app.api.routes.private import migrate_tts_base64_to_s3 - r = client.post(MIGRATE_URL) - assert r.status_code == 200 + row = _fake_call(content=_b64_content(mime_type="audio/wav")) + session = _mock_session([row]) + + migrate_tts_base64_to_s3(session) + + assert row.content["content"]["mime_type"] == "audio/wav" - db.refresh(call) - assert call.content["content"]["mime_type"] == "audio/wav" +@patch(f"{MODULE}.get_cloud_storage", return_value=MagicMock()) +@patch(f"{MODULE}.upload_audio_bytes_to_s3", return_value="s3://bucket/out.mp3") +def test_no_candidates(mock_upload: MagicMock, mock_storage: MagicMock) -> None: + """Zero rows means all counters are zero and no uploads happen.""" + from app.api.routes.private import migrate_tts_base64_to_s3 -@patch(STORAGE_PATH, return_value=MagicMock()) -@patch(UPLOAD_PATH, return_value="s3://bucket/audio.mp3") -def test_migrate_returns_summary_fields( - mock_upload: MagicMock, - mock_storage: MagicMock, - client: TestClient, - db: Session, + session = _mock_session([]) + + result = migrate_tts_base64_to_s3(session) + + assert result["processed"] == 0 + assert result["failed"] == 0 + assert result["committed"] == 0 + assert result["total_candidates"] == 0 + mock_upload.assert_not_called() + session.commit.assert_not_called() + + +@patch(f"{MODULE}.get_cloud_storage", return_value=MagicMock()) +@patch(f"{MODULE}.upload_audio_bytes_to_s3", return_value="s3://bucket/out.mp3") +def test_returns_all_summary_fields( + mock_upload: MagicMock, mock_storage: MagicMock ) -> None: - """The response includes all expected summary fields.""" - r = client.post(MIGRATE_URL) - assert r.status_code == 200 + """The response dict contains every expected key.""" + from app.api.routes.private import migrate_tts_base64_to_s3 + + session = _mock_session([]) + + result = migrate_tts_base64_to_s3(session) - data = r.json() for key in [ "processed", "committed", @@ -343,28 +250,49 @@ def test_migrate_returns_summary_fields( "elapsed_seconds", "errors", ]: - assert key in data, f"Missing key: {key}" + assert key in result, f"Missing key: {key}" + assert isinstance(result["elapsed_seconds"], (int, float)) + assert isinstance(result["errors"], list) + + +@patch(f"{MODULE}.get_cloud_storage", return_value=MagicMock()) +@patch(f"{MODULE}.upload_audio_bytes_to_s3", return_value="s3://bucket/out.mp3") +def test_mixed_rows(mock_upload: MagicMock, mock_storage: MagicMock) -> None: + """A mix of base64, URI, and null-content rows are handled correctly.""" + from app.api.routes.private import migrate_tts_base64_to_s3 + + rows = [ + _fake_call(content=_b64_content()), + _fake_call(content=_uri_content()), + _fake_call(content=None), + _fake_call(content=_b64_content(mime_type="audio/wav")), + ] + session = _mock_session(rows) - assert isinstance(data["elapsed_seconds"], (int, float)) - assert isinstance(data["errors"], list) - assert data["total_candidates"] >= 0 + result = migrate_tts_base64_to_s3(session) + assert result["processed"] == 2 + assert result["skipped"] == 2 + assert result["failed"] == 0 + assert mock_upload.call_count == 2 -@patch(STORAGE_PATH, return_value=MagicMock()) -@patch(UPLOAD_PATH, return_value="s3://bucket/audio.mp3") -def test_migrate_no_candidates( - mock_upload: MagicMock, - mock_storage: MagicMock, - client: TestClient, - db: Session, + +@patch(f"{MODULE}.get_cloud_storage", return_value=MagicMock()) +@patch(f"{MODULE}.upload_audio_bytes_to_s3", return_value="s3://bucket/out.mp3") +def test_caches_storage_per_project( + mock_upload: MagicMock, mock_storage: MagicMock ) -> None: - """When there are no matching rows, migration completes with all zeros.""" - # Don't create any TTS LlmCall rows — the endpoint should still succeed - r = client.post(MIGRATE_URL) - assert r.status_code == 200 + """get_cloud_storage is called once per unique project_id, not per row.""" + from app.api.routes.private import migrate_tts_base64_to_s3 - data = r.json() - assert data["processed"] == 0 - assert data["failed"] == 0 - assert data["committed"] == 0 - mock_upload.assert_not_called() + rows = [ + _fake_call(content=_b64_content(), project_id=1), + _fake_call(content=_b64_content(), project_id=1), + _fake_call(content=_b64_content(), project_id=2), + ] + session = _mock_session(rows) + + migrate_tts_base64_to_s3(session) + + # Only 2 distinct project_ids → 2 calls + assert mock_storage.call_count == 2