diff --git a/backend/app/alembic/versions/055_add_batch_tracking_to_collections_jobs.py b/backend/app/alembic/versions/055_add_batch_tracking_to_collections_jobs.py new file mode 100644 index 000000000..26fb1a8d3 --- /dev/null +++ b/backend/app/alembic/versions/055_add_batch_tracking_to_collections_jobs.py @@ -0,0 +1,62 @@ +"""add batch tracking to collection_jobs + +Revision ID: 055 +Revises: 054 +Create Date: 2026-04-13 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "055" +down_revision = "054" +branch_labels = None +depends_on = None + + +def upgrade(): + op.add_column( + "collection_jobs", + sa.Column( + "total_batches", + sa.Integer(), + nullable=True, + comment="Total number of batches the documents are split into", + ), + ) + op.add_column( + "collection_jobs", + sa.Column( + "current_batch_number", + sa.Integer(), + nullable=True, + comment="Which batch is currently being processed (1-indexed)", + ), + ) + op.add_column( + "collection_jobs", + sa.Column( + "documents_uploaded", + sa.JSON(), + nullable=True, + comment="List of document IDs successfully uploaded so far", + ), + ) + op.add_column( + "document", + sa.Column( + "openai_file_id", + sa.String(), + nullable=True, + comment="File ID assigned by the LLM provider (e.g. OpenAI file ID) to avoid re-uploading", + ), + ) + + +def downgrade(): + op.drop_column("collection_jobs", "total_batches") + op.drop_column("collection_jobs", "current_batch_number") + op.drop_column("collection_jobs", "documents_uploaded") + op.drop_column("document", "openai_file_id") diff --git a/backend/app/api/docs/documents/upload.md b/backend/app/api/docs/documents/upload.md index e667015f5..438dc3e9b 100644 --- a/backend/app/api/docs/documents/upload.md +++ b/backend/app/api/docs/documents/upload.md @@ -1,6 +1,6 @@ Upload a document to Kaapi. -- If only a file is provided, the document will be uploaded and stored, and its ID will be returned. +- If only a file is provided, the document will be uploaded and stored, and its ID will be returned. The maximum file size allowed for upload is 25 MB. - If a target format is specified, a transformation job will also be created to transform document into target format in the background. The response will include both the uploaded document details and information about the transformation job. - If a callback URL is provided, you will receive a notification at that URL once the document transformation job is completed. diff --git a/backend/app/celery/tasks/job_execution.py b/backend/app/celery/tasks/job_execution.py index adadf1c9c..9a13fddcf 100644 --- a/backend/app/celery/tasks/job_execution.py +++ b/backend/app/celery/tasks/job_execution.py @@ -4,9 +4,6 @@ import celery from asgi_correlation_id import correlation_id from celery import current_task -from opentelemetry import context as otel_context -from opentelemetry import trace -from opentelemetry.propagate import extract from app.celery.celery_app import celery_app from app.celery.utils import gevent_timeout @@ -20,61 +17,18 @@ def _set_trace(trace_id: str) -> None: logger.info(f"[_set_trace] Set correlation ID: {trace_id}") -def _extract_parent_context(task_instance) -> otel_context.Context: - """Extract OTel parent context from Celery headers if available.""" - headers = getattr(task_instance.request, "headers", None) or {} - carrier: dict[str, str] = {} - - if isinstance(headers, dict): - for key, value in headers.items(): - if isinstance(value, str): - carrier[str(key)] = value - - nested = headers.get("otel", {}) - if isinstance(nested, dict): - for key, value in nested.items(): - if isinstance(value, str): - carrier[str(key)] = value - - return extract(carrier) - - -def _run_with_otel_parent(task_instance, fn): - """Attach extracted parent context and execute function. - - When Celery auto-instrumentation is active, there is already a current - `run/...` span. Re-attaching extracted parent context here would make - service spans become siblings of `run/...` instead of children. - - We only attach extracted context as a fallback when no active span exists. - """ - current_ctx = trace.get_current_span().get_span_context() - if current_ctx and current_ctx.is_valid: - return fn() - - parent_ctx = _extract_parent_context(task_instance) - token = otel_context.attach(parent_ctx) - try: - return fn() - finally: - otel_context.detach(token) - - @celery_app.task(bind=True, queue="high_priority", priority=9) @gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_llm_job") def run_llm_job(self, project_id: int, job_id: str, trace_id: str, **kwargs): from app.services.llm.jobs import execute_job _set_trace(trace_id) - return _run_with_otel_parent( - self, - lambda: execute_job( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, - ), + return execute_job( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, ) @@ -84,15 +38,12 @@ def run_llm_chain_job(self, project_id: int, job_id: str, trace_id: str, **kwarg from app.services.llm.jobs import execute_chain_job _set_trace(trace_id) - return _run_with_otel_parent( - self, - lambda: execute_chain_job( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, - ), + return execute_chain_job( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, ) @@ -102,15 +53,12 @@ def run_response_job(self, project_id: int, job_id: str, trace_id: str, **kwargs from app.services.response.jobs import execute_job _set_trace(trace_id) - return _run_with_otel_parent( - self, - lambda: execute_job( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, - ), + return execute_job( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, ) @@ -120,15 +68,12 @@ def run_doctransform_job(self, project_id: int, job_id: str, trace_id: str, **kw from app.services.doctransform.job import execute_job _set_trace(trace_id) - return _run_with_otel_parent( - self, - lambda: execute_job( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, - ), + return execute_job( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, ) @@ -137,18 +82,32 @@ def run_doctransform_job(self, project_id: int, job_id: str, trace_id: str, **kw def run_create_collection_job( self, project_id: int, job_id: str, trace_id: str, **kwargs ): - from app.services.collections.create_collection import execute_job + from app.services.collections.create_collection import execute_setup_job + + _set_trace(trace_id) + return execute_setup_job( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, + ) + + +@celery_app.task(bind=True, queue="low_priority", priority=1) +@gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_collection_batch_job") +def run_collection_batch_job( + self, project_id: int, job_id: str, trace_id: str, **kwargs +): + from app.services.collections.create_collection import execute_batch_job _set_trace(trace_id) - return _run_with_otel_parent( - self, - lambda: execute_job( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, - ), + return execute_batch_job( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, ) @@ -160,15 +119,12 @@ def run_delete_collection_job( from app.services.collections.delete_collection import execute_job _set_trace(trace_id) - return _run_with_otel_parent( - self, - lambda: execute_job( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, - ), + return execute_job( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, ) @@ -180,15 +136,12 @@ def run_stt_batch_submission( from app.services.stt_evaluations.batch_job import execute_batch_submission _set_trace(trace_id) - return _run_with_otel_parent( - self, - lambda: execute_batch_submission( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, - ), + return execute_batch_submission( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, ) @@ -200,15 +153,12 @@ def run_stt_metric_computation( from app.services.stt_evaluations.metric_job import execute_metric_computation _set_trace(trace_id) - return _run_with_otel_parent( - self, - lambda: execute_metric_computation( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, - ), + return execute_metric_computation( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, ) @@ -220,15 +170,12 @@ def run_tts_batch_submission( from app.services.tts_evaluations.batch_job import execute_batch_submission _set_trace(trace_id) - return _run_with_otel_parent( - self, - lambda: execute_batch_submission( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, - ), + return execute_batch_submission( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, ) @@ -242,13 +189,10 @@ def run_tts_result_processing( ) _set_trace(trace_id) - return _run_with_otel_parent( - self, - lambda: execute_tts_result_processing( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, - ), + return execute_tts_result_processing( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, ) diff --git a/backend/app/celery/utils.py b/backend/app/celery/utils.py index 288cba7c4..f39fb9d5d 100644 --- a/backend/app/celery/utils.py +++ b/backend/app/celery/utils.py @@ -18,24 +18,14 @@ F = TypeVar("F", bound=Callable[..., Any]) -def _enqueue_with_trace_context(task, **kwargs) -> str: - """Publish Celery task with explicit trace context headers.""" - otel_headers: dict[str, str] = {} - inject(otel_headers) - celery_headers = dict(otel_headers) - celery_headers["otel"] = otel_headers - async_result = task.apply_async(kwargs=kwargs, headers=celery_headers) - return async_result.id - - def start_llm_job(project_id: int, job_id: str, trace_id: str = "N/A", **kwargs) -> str: from app.celery.tasks.job_execution import run_llm_job - task_id = _enqueue_with_trace_context( - run_llm_job, project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs + task = run_llm_job.delay( + project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs ) - logger.info(f"[start_llm_job] Started job {job_id} with Celery task {task_id}") - return task_id + logger.info(f"[start_llm_job] Started job {job_id} with Celery task {task.id}") + return task.id def start_llm_chain_job( @@ -43,17 +33,13 @@ def start_llm_chain_job( ) -> str: from app.celery.tasks.job_execution import run_llm_chain_job - task_id = _enqueue_with_trace_context( - run_llm_chain_job, - project_id=project_id, - job_id=job_id, - trace_id=trace_id, - **kwargs, + task = run_llm_chain_job.delay( + project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs ) logger.info( - f"[start_llm_chain_job] Started job {job_id} with Celery task {task_id}" + f"[start_llm_chain_job] Started job {job_id} with Celery task {task.id}" ) - return task_id + return task.id def start_response_job( @@ -61,15 +47,11 @@ def start_response_job( ) -> str: from app.celery.tasks.job_execution import run_response_job - task_id = _enqueue_with_trace_context( - run_response_job, - project_id=project_id, - job_id=job_id, - trace_id=trace_id, - **kwargs, + task = run_response_job.delay( + project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs ) - logger.info(f"[start_response_job] Started job {job_id} with Celery task {task_id}") - return task_id + logger.info(f"[start_response_job] Started job {job_id} with Celery task {task.id}") + return task.id def start_doctransform_job( @@ -77,17 +59,13 @@ def start_doctransform_job( ) -> str: from app.celery.tasks.job_execution import run_doctransform_job - task_id = _enqueue_with_trace_context( - run_doctransform_job, - project_id=project_id, - job_id=job_id, - trace_id=trace_id, - **kwargs, + task = run_doctransform_job.delay( + project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs ) logger.info( - f"[start_doctransform_job] Started job {job_id} with Celery task {task_id}" + f"[start_doctransform_job] Started job {job_id} with Celery task {task.id}" ) - return task_id + return task.id def start_create_collection_job( @@ -95,17 +73,27 @@ def start_create_collection_job( ) -> str: from app.celery.tasks.job_execution import run_create_collection_job - task_id = _enqueue_with_trace_context( - run_create_collection_job, - project_id=project_id, - job_id=job_id, - trace_id=trace_id, - **kwargs, + task = run_create_collection_job.delay( + project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs + ) + logger.info( + f"[start_create_collection_job] Started job {job_id} with Celery task {task.id}" + ) + return task.id + + +def start_collection_batch_job( + project_id: int, job_id: str, trace_id: str = "N/A", **kwargs +) -> str: + from app.celery.tasks.job_execution import run_collection_batch_job + + task = run_collection_batch_job.delay( + project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs ) logger.info( - f"[start_create_collection_job] Started job {job_id} with Celery task {task_id}" + f"[start_collection_batch_job] Started batch job {job_id} with Celery task {task.id}" ) - return task_id + return task.id def start_delete_collection_job( @@ -113,17 +101,13 @@ def start_delete_collection_job( ) -> str: from app.celery.tasks.job_execution import run_delete_collection_job - task_id = _enqueue_with_trace_context( - run_delete_collection_job, - project_id=project_id, - job_id=job_id, - trace_id=trace_id, - **kwargs, + task = run_delete_collection_job.delay( + project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs ) logger.info( - f"[start_delete_collection_job] Started job {job_id} with Celery task {task_id}" + f"[start_delete_collection_job] Started job {job_id} with Celery task {task.id}" ) - return task_id + return task.id def start_stt_batch_submission( @@ -131,17 +115,13 @@ def start_stt_batch_submission( ) -> str: from app.celery.tasks.job_execution import run_stt_batch_submission - task_id = _enqueue_with_trace_context( - run_stt_batch_submission, - project_id=project_id, - job_id=job_id, - trace_id=trace_id, - **kwargs, + task = run_stt_batch_submission.delay( + project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs ) logger.info( - f"[start_stt_batch_submission] Started job {job_id} with Celery task {task_id}" + f"[start_stt_batch_submission] Started job {job_id} with Celery task {task.id}" ) - return task_id + return task.id def start_stt_metric_computation( @@ -149,17 +129,13 @@ def start_stt_metric_computation( ) -> str: from app.celery.tasks.job_execution import run_stt_metric_computation - task_id = _enqueue_with_trace_context( - run_stt_metric_computation, - project_id=project_id, - job_id=job_id, - trace_id=trace_id, - **kwargs, + task = run_stt_metric_computation.delay( + project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs ) logger.info( - f"[start_stt_metric_computation] Started job {job_id} with Celery task {task_id}" + f"[start_stt_metric_computation] Started job {job_id} with Celery task {task.id}" ) - return task_id + return task.id def start_tts_batch_submission( @@ -167,17 +143,13 @@ def start_tts_batch_submission( ) -> str: from app.celery.tasks.job_execution import run_tts_batch_submission - task_id = _enqueue_with_trace_context( - run_tts_batch_submission, - project_id=project_id, - job_id=job_id, - trace_id=trace_id, - **kwargs, + task = run_tts_batch_submission.delay( + project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs ) logger.info( - f"[start_tts_batch_submission] Started job {job_id} with Celery task {task_id}" + f"[start_tts_batch_submission] Started job {job_id} with Celery task {task.id}" ) - return task_id + return task.id def start_tts_result_processing( @@ -185,17 +157,13 @@ def start_tts_result_processing( ) -> str: from app.celery.tasks.job_execution import run_tts_result_processing - task_id = _enqueue_with_trace_context( - run_tts_result_processing, - project_id=project_id, - job_id=job_id, - trace_id=trace_id, - **kwargs, + task = run_tts_result_processing.delay( + project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs ) logger.info( - f"[start_tts_result_processing] Started job {job_id} with Celery task {task_id}" + f"[start_tts_result_processing] Started job {job_id} with Celery task {task.id}" ) - return task_id + return task.id def get_task_status(task_id: str) -> Dict[str, Any]: diff --git a/backend/app/crud/rag/open_ai.py b/backend/app/crud/rag/open_ai.py index cdae82440..be6970235 100644 --- a/backend/app/crud/rag/open_ai.py +++ b/backend/app/crud/rag/open_ai.py @@ -1,6 +1,7 @@ import json import logging import functools as ft +import time from io import BytesIO from typing import Iterable @@ -149,6 +150,58 @@ def update( yield from docs + def update_batch( + self, + vector_store_id: str, + docs: list[Document], + ) -> tuple[list[Document], list[Document]]: + """ + Attach a batch of documents to the vector store via a single upload_and_poll call. + + All docs must have provider_file_id set before calling this method. + Returns (succeeded, failed) — failed docs should be retried in the next batch. + """ + succeeded: list[Document] = [] + failed: list[Document] = [] + + if not docs: + return succeeded, failed + + try: + _t0 = time.monotonic() + batch = self.client.vector_stores.file_batches.upload_and_poll( + vector_store_id=vector_store_id, + files=[], + file_ids=[doc.openai_file_id for doc in docs], + ) + logger.info( + f"[OpenAIVectorStoreCrud.update_batch] Batch upload_and_poll duration | " + f"{{'vector_store_id': '{vector_store_id}', 'duration_s': {time.monotonic() - _t0:.3f}, " + f"'completed': {batch.file_counts.completed}, 'failed': {batch.file_counts.failed}}}" + ) + if batch.file_counts.failed == 0: + succeeded.extend(docs) + else: + # Can't identify which specific files failed — retry all of them + logger.warning( + f"[OpenAIVectorStoreCrud.update_batch] Batch had failures, marking all for retry | " + f"{{'vector_store_id': '{vector_store_id}', 'failed_count': {batch.file_counts.failed}}}" + ) + failed.extend(docs) + except OpenAIError as err: + logger.error( + f"[OpenAIVectorStoreCrud.update_batch] Batch attach failed | " + f"{{'vector_store_id': '{vector_store_id}', 'error': '{str(err)}'}}", + exc_info=True, + ) + failed.extend(docs) + + logger.info( + f"[OpenAIVectorStoreCrud.update_batch] Batch complete | " + f"{{'vector_store_id': '{vector_store_id}', 'succeeded': {len(succeeded)}, 'failed': {len(failed)}}}" + ) + return succeeded, failed + def delete(self, vector_store_id: str, retries: int = 3): if retries < 1: try: diff --git a/backend/app/models/collection_job.py b/backend/app/models/collection_job.py index 333ebfd14..6b628ad7e 100644 --- a/backend/app/models/collection_job.py +++ b/backend/app/models/collection_job.py @@ -77,7 +77,29 @@ class CollectionJob(SQLModel, table=True): documents: list[str] | None = Field( default=None, sa_column=Column( - JSON, nullable=True, comment="List of documents given to make collection" + JSON, nullable=True, comment="List of document IDs given to make collection" + ), + ) + total_batches: int | None = Field( + default=None, + nullable=True, + sa_column_kwargs={ + "comment": "Total number of batches the documents are split into" + }, + ) + current_batch_number: int | None = Field( + default=None, + nullable=True, + sa_column_kwargs={ + "comment": "Which batch is currently being processed (1-indexed)" + }, + ) + documents_uploaded: list[str] | None = Field( + default=None, + sa_column=Column( + JSON, + nullable=True, + comment="List of document IDs successfully uploaded so far", ), ) @@ -139,6 +161,9 @@ class CollectionJobUpdate(SQLModel): collection_id: UUID | None = None total_size_mb: float | None = None trace_id: str | None = None + total_batches: int | None = None + current_batch_number: int | None = None + documents_uploaded: list[str] | None = None ##Response models diff --git a/backend/app/models/document.py b/backend/app/models/document.py index 12843e72a..5bbcddc77 100644 --- a/backend/app/models/document.py +++ b/backend/app/models/document.py @@ -46,6 +46,11 @@ class Document(DocumentBase, table=True): description="The size of the document in kilobytes", sa_column_kwargs={"comment": "Size of the document in kilobytes (KB)"}, ) + openai_file_id: str | None = Field( + default=None, + nullable=True, + sa_column_kwargs={"comment": "File ID assigned by OpenAI (avoid re-uploading)"}, + ) # Foreign keys source_document_id: UUID | None = Field( diff --git a/backend/app/services/collections/create_collection.py b/backend/app/services/collections/create_collection.py index a9b787f6b..22f5e5602 100644 --- a/backend/app/services/collections/create_collection.py +++ b/backend/app/services/collections/create_collection.py @@ -2,7 +2,6 @@ import time from uuid import UUID, uuid4 -from opentelemetry import trace from sqlmodel import Session from celery.exceptions import SoftTimeLimitExceeded from gevent import Timeout @@ -10,7 +9,6 @@ from app.core.cloud import get_cloud_storage from app.core.db import engine -from app.core.telemetry import log_context from app.crud import ( CollectionCrud, DocumentCrud, @@ -25,17 +23,19 @@ CollectionJobPublic, CreationRequest, ) +from app.crud.rag import OpenAIVectorStoreCrud from app.services.collections.helpers import ( + batch_documents, extract_error_message, to_collection_public, ) from app.services.collections.providers.registry import get_llm_provider -from app.celery.utils import start_create_collection_job -from app.utils import send_callback, get_webhook_secret, APIResponse +from gevent import Timeout +from app.celery.utils import start_create_collection_job, start_collection_batch_job +from app.utils import send_callback, APIResponse logger = logging.getLogger(__name__) -tracer = trace.get_tracer(__name__) def start_job( @@ -46,49 +46,31 @@ def start_job( with_assistant: bool, organization_id: int, ) -> str: - with log_context( - tag="collection", - lifecycle="collection.create.start_job", - action="create", - collection_job_id=collection_job_id, - project_id=project_id, - organization_id=organization_id, - ): - trace_id = correlation_id.get() or "N/A" + trace_id = correlation_id.get() or "N/A" - job_crud = CollectionJobCrud(db, project_id) - collection_job = job_crud.update( - collection_job_id, CollectionJobUpdate(trace_id=trace_id) - ) + job_crud = CollectionJobCrud(db, project_id) + job_crud.update(collection_job_id, CollectionJobUpdate(trace_id=trace_id)) - task_id = start_create_collection_job( - project_id=project_id, - job_id=str(collection_job_id), - trace_id=trace_id, - request=request.model_dump(mode="json"), - with_assistant=with_assistant, - organization_id=organization_id, - ) + task_id = start_create_collection_job( + project_id=project_id, + job_id=str(collection_job_id), + trace_id=trace_id, + request=request.model_dump(mode="json"), + with_assistant=with_assistant, + organization_id=organization_id, + ) - logger.info( - "[create_collection.start_job] Job scheduled to create collection | " - f"collection_job_id={collection_job_id}, project_id={project_id}, task_id={task_id}" - ) + logger.info( + "[create_collection.start_job] Job scheduled to create collection | " + f"collection_job_id={collection_job_id}, project_id={project_id}, task_id={task_id}" + ) - return collection_job_id + return collection_job_id def build_success_payload( collection_job: CollectionJob, collection: Collection ) -> dict: - """ - { - "success": true, - "data": { job fields + full collection }, - "error": null, - "metadata": null - } - """ collection_public = to_collection_public(collection) collection_dict = collection_public.model_dump(mode="json", exclude_none=True) @@ -102,15 +84,6 @@ def build_success_payload( def build_failure_payload(collection_job: CollectionJob, error_message: str) -> dict: - """ - { - "success": false, - "data": { job fields, collection: null }, - "error": "something went wrong", - "metadata": null - } - """ - # ensure `collection` is explicitly null in the payload job_public = CollectionJobPublic.model_validate( collection_job, update={"collection": None}, @@ -144,10 +117,63 @@ def _mark_job_failed( ) return collection_job except Exception: - logger.warning("[create_collection.execute_job] Failed to mark job as FAILED") + logger.warning("[create_collection] Failed to mark job as FAILED") return None +def _persist_succeeded_docs(succeeded: list, project_id: int) -> list[str]: + with Session(engine) as session: + document_crud = DocumentCrud(session, project_id) + for doc in succeeded: + if doc.openai_file_id: + db_doc = document_crud.read_one(doc.id) + if db_doc.openai_file_id != doc.openai_file_id: + db_doc.openai_file_id = doc.openai_file_id + document_crud.update(db_doc) + return [str(doc.id) for doc in succeeded] + + +def _retry_failed_uploads( + vector_store_crud, + vector_store_id: str, + failed_docs: list, + project_id: int, + max_retries: int = 3, +) -> list[str]: + """ + Retry attaching docs that failed the initial batch upload_and_poll. + All docs must already have provider_file_id set. + Returns the list of successfully retried doc IDs. + Raises RuntimeError if any docs still fail after all retries. + """ + pending = failed_docs + all_succeeded_ids: list[str] = [] + + for attempt in range(1, max_retries + 1): + logger.warning( + "[_retry_failed_uploads] Retry attempt %d/%d: %d doc(s) | vector_store_id=%s", + attempt, + max_retries, + len(pending), + vector_store_id, + ) + succeeded, failed = vector_store_crud.update_batch(vector_store_id, pending) + + if succeeded: + all_succeeded_ids += _persist_succeeded_docs(succeeded, project_id) + + if not failed: + return all_succeeded_ids + + pending = failed + + ids = [str(d.id) for d in pending] + raise RuntimeError( + f"Failed to upload {len(pending)} document(s) after {max_retries} retries: {ids}" + ) + + +def execute_setup_job( def _handle_job_failure( span, project_id: int, @@ -196,122 +222,282 @@ def execute_job( task_instance, ) -> None: """ - Worker entrypoint scheduled by start_job. - Orchestrates: job state, provider init, collection creation, - optional assistant creation, collection persistence, linking, callbacks, and cleanup. + Phase 1: Fetch documents, create the vector store, split into batches, + update job state to PROCESSING, then queue the first batch task. """ - start_time = time.time() + collection_job = None + creation_request = None + + try: + creation_request = CreationRequest(**request) + if with_assistant: + creation_request.provider = "openai" + + job_uuid = UUID(job_id) + trace_id = correlation_id.get() or "N/A" + + with Session(engine) as session: + document_crud = DocumentCrud(session, project_id) + flat_docs = document_crud.read_each(creation_request.documents) + storage = get_cloud_storage(session=session, project_id=project_id) + + provider = get_llm_provider( + session=session, + provider=creation_request.provider, + project_id=project_id, + organization_id=organization_id, + ) + + for doc in flat_docs: + session.expunge(doc) + provider.upload_files(storage, flat_docs, project_id) + + logger.info( + "[create_collection.execute_setup_job] All file uploads complete | " + "job_id=%s, total=%d, failed=%d, duration_s=%.2f", + job_id, + len(flat_docs), + ) + + total_size_kb = sum(doc.file_size_kb for doc in flat_docs) + total_size_mb = total_size_kb / 1024 + + docs_batches = batch_documents(flat_docs) + total_batches = len(docs_batches) + batch_doc_ids = [[str(doc.id) for doc in batch] for batch in docs_batches] + + with Session(engine) as session: + collection_job_crud = CollectionJobCrud(session, project_id) + collection_job = collection_job_crud.update( + job_uuid, + CollectionJobUpdate( + task_id=task_id, + status=CollectionJobStatus.PROCESSING, + total_size_mb=total_size_mb, + current_batch_number=0, + total_batches=total_batches, + documents_uploaded=[], + ), + ) + + start_collection_batch_job( + project_id=project_id, + job_id=job_id, + trace_id=trace_id, + batch_number=1, + batch_doc_ids=batch_doc_ids[0], + remaining_batches=batch_doc_ids[1:], + request=request, + with_assistant=with_assistant, + organization_id=organization_id, + ) + + logger.info( + "[create_collection.execute_setup_job] Setup complete, first batch queued | " + f"job_id={job_id}, total_batches={total_batches}" + ) + + except Timeout as err: + timeout_err = TimeoutError( + f"[execute_setup_job] Task exceeded soft time limit of {err.seconds}s" + ) + _mark_job_failed( + project_id=project_id, + job_id=job_id, + err=timeout_err, + collection_job=collection_job, + ) + raise + + except Exception as err: + logger.error( + "[create_collection.execute_setup_job] Setup failed | job_id=%s, error=%s", + job_id, + str(err), + exc_info=True, + ) + + collection_job = _mark_job_failed( + project_id=project_id, + job_id=job_id, + err=err, + collection_job=collection_job, + ) + if creation_request and creation_request.callback_url and collection_job: + failure_payload = build_failure_payload(collection_job, str(err)) + send_callback(creation_request.callback_url, failure_payload) + + +def execute_batch_job( + request: dict, + with_assistant: bool, + project_id: int, + organization_id: int, + task_id: str, + job_id: str, + task_instance, + vector_store_id: str | None, + batch_number: int, + batch_doc_ids: list[str], + remaining_batches: list[list[str]], +) -> None: + """ + Phase 2: Upload one batch of documents to the vector store. + - Uploads the batch; any failures within the batch are retried inline by _upload_batch_with_retry + - Raises immediately if all retries for the batch are exhausted + - Checkpoints progress to the DB + - If more batches remain, queues the next batch task + - If this is the last batch, finalizes: creates Collection, links docs, marks job SUCCESSFUL + """ collection_job = None - result = None creation_request = None - provider = None - - with log_context( - tag="collection", - lifecycle="collection.create.execute_job", - action="create", - collection_job_id=job_id, - task_id=task_id, - project_id=project_id, - organization_id=organization_id, - ), tracer.start_as_current_span("collections.create.execute_job") as span: - span.set_attribute("collection.job_id", str(job_id)) - span.set_attribute("kaapi.project_id", project_id) - span.set_attribute("kaapi.organization_id", organization_id) - try: - creation_request = CreationRequest(**request) - if with_assistant: - creation_request.provider = "openai" - - span.set_attribute("collection.provider", str(creation_request.provider)) - - job_uuid = UUID(job_id) - - with Session(engine) as session: - document_crud = DocumentCrud(session, project_id) - flat_docs = document_crud.read_each(creation_request.documents) - - file_exts = { - doc.fname.split(".")[-1] for doc in flat_docs if "." in doc.fname - } - total_size_kb = sum(doc.file_size_kb or 0 for doc in flat_docs) - total_size_mb = round(total_size_kb / 1024, 2) - span.set_attribute("collection.documents.count", len(flat_docs)) - span.set_attribute("collection.documents.total_size_mb", total_size_mb) - - with Session(engine) as session: - collection_job_crud = CollectionJobCrud(session, project_id) - collection_job = collection_job_crud.read_one(job_uuid) - collection_job = collection_job_crud.update( - job_uuid, - CollectionJobUpdate( - task_id=task_id, - status=CollectionJobStatus.PROCESSING, - total_size_mb=total_size_mb, - ), - ) + try: + batch_start_time = time.time() + creation_request = CreationRequest(**request) + if with_assistant: + creation_request.provider = "openai" - storage = get_cloud_storage(session=session, project_id=project_id) - provider = get_llm_provider( - session=session, - provider=creation_request.provider, - project_id=project_id, - organization_id=organization_id, - ) + job_uuid = UUID(job_id) + trace_id = correlation_id.get() or "N/A" - with tracer.start_as_current_span("collections.create.provider"): - result = provider.create( - collection_request=creation_request, - storage=storage, - documents=flat_docs, - ) + logger.info( + "[create_collection.execute_batch_job] Starting batch | " + "job_id=%s, batch_number=%d, doc_count=%d, remaining_batches=%d", + job_id, + batch_number, + len(batch_doc_ids), + len(remaining_batches), + ) - llm_service_id = result.llm_service_id - llm_service_name = result.llm_service_name - - with Session(engine) as session: - collection_crud = CollectionCrud(session, project_id) - collection_id = uuid4() - - collection = Collection( - id=collection_id, - project_id=project_id, - llm_service_id=llm_service_id, - llm_service_name=llm_service_name, - provider=creation_request.provider, - name=creation_request.name, - description=creation_request.description, - ) - collection_crud.create(collection) - collection = collection_crud.read_one(collection.id) - - if flat_docs: - DocumentCollectionCrud(session).create(collection, flat_docs) - - collection_job_crud = CollectionJobCrud(session, project_id) - collection_job = collection_job_crud.update( - collection_job.id, - CollectionJobUpdate( - status=CollectionJobStatus.SUCCESSFUL, - collection_id=collection.id, - ), - ) + all_doc_ids_this_batch = [UUID(d) for d in batch_doc_ids] + is_final = not remaining_batches + + with Session(engine) as session: + provider = get_llm_provider( + session=session, + provider=creation_request.provider, + project_id=project_id, + organization_id=organization_id, + ) + + with Session(engine) as session: + document_crud = DocumentCrud(session, project_id) + batch_docs = ( + document_crud.read_each(all_doc_ids_this_batch) + if all_doc_ids_this_batch + else [] + ) + for doc in batch_docs: + session.expunge(doc) + + collection_result = provider.create( + creation_request, + batch_docs, + vector_store_id=vector_store_id, + is_final=is_final, + ) + resolved_vector_store_id = ( + collection_result.llm_service_id + if not is_final + else vector_store_id or collection_result.llm_service_id + ) - success_payload = build_success_payload(collection_job, collection) + with Session(engine) as session: + collection_job_crud = CollectionJobCrud(session, project_id) + collection_job = collection_job_crud.read_one(job_uuid) + already_uploaded = collection_job.documents_uploaded or [] + now_uploaded = already_uploaded + [str(d) for d in all_doc_ids_this_batch] - span.set_attribute("collection.id", str(collection_id)) + collection_job = collection_job_crud.update( + job_uuid, + CollectionJobUpdate( + current_batch_number=batch_number, + documents_uploaded=now_uploaded, + ), + ) + + logger.info( + "[create_collection.execute_batch_job] Batch %d complete | " + "doc_count=%d, job_id=%s", + batch_number, + len(all_doc_ids_this_batch), + job_id, + ) - elapsed = time.time() - start_time + if remaining_batches: + start_collection_batch_job( + project_id=project_id, + job_id=job_id, + trace_id=trace_id, + vector_store_id=resolved_vector_store_id, + batch_number=batch_number + 1, + batch_doc_ids=remaining_batches[0], + remaining_batches=remaining_batches[1:], + request=request, + with_assistant=with_assistant, + organization_id=organization_id, + ) logger.info( - "[create_collection.execute_job] Collection created: %s | Time: %.2fs | Files: %d | Total Size: %s MB | Types: %s", - collection_id, - elapsed, - len(flat_docs), - collection_job.total_size_mb, - list(file_exts), + "[create_collection.execute_batch_job] Batch %d/%d done, next batch queued | " + "job_id=%s, elapsed=%.2fs", + batch_number, + batch_number + len(remaining_batches), + job_id, + time.time() - batch_start_time, ) + return + + # Final batch: collection_result already has assistant/vector_store finalized + finalize_start_time = time.time() + + with Session(engine) as session: + all_uploaded_ids = [UUID(d) for d in now_uploaded] + document_crud = DocumentCrud(session, project_id) + all_docs = ( + document_crud.read_each(all_uploaded_ids) if all_uploaded_ids else [] + ) + for doc in all_docs: + session.expunge(doc) + + with Session(engine) as session: + collection_id = uuid4() + collection = Collection( + id=collection_id, + project_id=project_id, + llm_service_id=collection_result.llm_service_id, + llm_service_name=collection_result.llm_service_name, + provider=creation_request.provider, + name=creation_request.name, + description=creation_request.description, + ) + collection_crud = CollectionCrud(session, project_id) + collection_crud.create(collection) + collection = collection_crud.read_one(collection.id) + + if all_docs: + DocumentCollectionCrud(session).create(collection, all_docs) + + collection_job_crud = CollectionJobCrud(session, project_id) + collection_job = collection_job_crud.update( + job_uuid, + CollectionJobUpdate( + status=CollectionJobStatus.SUCCESSFUL, + collection_id=collection.id, + ), + ) + + success_payload = build_success_payload(collection_job, collection) + + logger.info( + "[create_collection.execute_batch_job] All batches done, collection created: %s | " + "finalize_time=%.2fs, total_time=%.2fs, total_docs=%d", + collection_id, + time.time() - finalize_start_time, + time.time() - batch_start_time, + len(all_docs), + ) if creation_request.callback_url: webhook_secret = get_webhook_secret(project_id, organization_id) diff --git a/backend/app/services/collections/helpers.py b/backend/app/services/collections/helpers.py index 6985ac78e..3f0a0cefd 100644 --- a/backend/app/services/collections/helpers.py +++ b/backend/app/services/collections/helpers.py @@ -19,7 +19,6 @@ MAX_DOC_SIZE_MB = 25 # 25 MB maximum per document # Maximum batch size for uploading documents to vector store -# Derived from MAX_DOC_SIZE + buffer to ensure single docs always fit MAX_BATCH_SIZE_KB = (MAX_DOC_SIZE_MB + 5) * 1024 # 30 MB in KB (25 + 5 MB buffer) MAX_BATCH_COUNT = 200 # Maximum documents per batch @@ -83,7 +82,7 @@ def batch_documents(documents: list[Document]) -> list[list[Document]]: current_batch_size_kb = 0 for doc in documents: - doc_size_kb = doc.file_size_kb or 0 + doc_size_kb = doc.file_size_kb would_exceed_size = (current_batch_size_kb + doc_size_kb) > MAX_BATCH_SIZE_KB would_exceed_count = len(current_batch) >= MAX_BATCH_COUNT diff --git a/backend/app/services/collections/providers/base.py b/backend/app/services/collections/providers/base.py index 36283d1fa..6649a0725 100644 --- a/backend/app/services/collections/providers/base.py +++ b/backend/app/services/collections/providers/base.py @@ -19,48 +19,46 @@ class BaseProvider(ABC): """ def __init__(self, client: Any) -> None: - """Initialize provider with client. + self.client = client + + @abstractmethod + def upload_files( + self, + storage: CloudStorage, + docs: list[Document], + project_id: int, + ) -> None: + """Upload all documents to the provider's file storage and persist their file IDs. Args: - client: Provider-specific client instance + storage: Cloud storage instance to fetch raw file bytes from + docs: Documents to upload + project_id: Project ID used to persist the provider file IDs to the DB """ - self.client = client + raise NotImplementedError("Providers must implement upload_files method") @abstractmethod def create( self, collection_request: CreationRequest, - storage: CloudStorage, - documents: list[Document], + docs: list[Document], + vector_store_id: str | None = None, + is_final: bool = False, ) -> Collection: - """Create collection with documents and optionally an assistant. - - Args: - collection_request: Collection parameters (name, description, document list, etc.) - storage: Cloud storage instance for file access - documents: Pre-fetched list of Document objects to add to the collection - - Returns: - Collection object with llm_service_id and llm_service_name populated - """ - raise NotImplementedError("Providers must implement execute method") + """Upload docs batch to vector store (creating it if vector_store_id is None). + Creates assistant only when is_final=True and model/instructions are set. + Returns Collection with llm_service_id set to vector_store_id on intermediate batches, + or to assistant/vector_store id on the final batch.""" + raise NotImplementedError("Providers must implement create method") @abstractmethod def delete(self, collection: Collection) -> None: - """Delete remote resources associated with a collection. - - Called when a collection is being deleted and remote resources need to be cleaned up. - - Args: - llm_service_id: ID of the resource to delete - llm_service_name: Name of the service (determines resource type) - """ + """Delete remote resources associated with a collection.""" raise NotImplementedError("Providers must implement delete method") - def get_provider_name(self) -> str: - """Get the name of the provider. + def get_existing_file_id(self, _doc: Document) -> str | None: + """Return the already-uploaded file ID for this provider, or None to trigger upload.""" + return None - Returns: - Provider name (e.g., "openai", "bedrock", "pinecone") - """ + def get_provider_name(self) -> str: return self.__class__.__name__.replace("Provider", "").lower() diff --git a/backend/app/services/collections/providers/openai.py b/backend/app/services/collections/providers/openai.py index f52e83394..3afaaba81 100644 --- a/backend/app/services/collections/providers/openai.py +++ b/backend/app/services/collections/providers/openai.py @@ -1,12 +1,16 @@ import logging +from io import BytesIO from typing import List from openai import OpenAI +from sqlmodel import Session from app.services.collections.providers import BaseProvider from app.core.cloud.storage import CloudStorage +from app.core.db import engine +from app.crud import DocumentCrud from app.crud.rag import OpenAIVectorStoreCrud, OpenAIAssistantCrud -from app.services.collections.helpers import get_service_name, batch_documents +from app.services.collections.helpers import get_service_name from app.models import CreationRequest, Collection, Document @@ -20,29 +24,72 @@ def __init__(self, client: OpenAI): super().__init__(client) self.client = client + def get_existing_file_id(self, doc: Document) -> str | None: + return doc.openai_file_id + + def upload_files( + self, + storage: CloudStorage, + docs: list[Document], + project_id: int, + ) -> None: + for doc in docs: + if self.get_existing_file_id(doc): + continue + try: + content = storage.get(doc.object_store_url) + if doc.file_size_kb is None: + doc.file_size_kb = round(len(content) / 1024, 2) + f_obj = BytesIO(content) + f_obj.name = doc.fname + uploaded = self.client.files.create(file=f_obj, purpose="assistants") + doc.openai_file_id = uploaded.id + with Session(engine) as session: + document_crud = DocumentCrud(session, project_id) + db_doc = document_crud.read_one(doc.id) + db_doc.openai_file_id = uploaded.id + db_doc.file_size_kb = doc.file_size_kb + document_crud.update(db_doc) + except Exception as err: + logger.error( + "[OpenAIProvider.upload_files] Failed to upload file | doc_id=%s, error=%s", + doc.id, + str(err), + exc_info=True, + ) + def create( self, collection_request: CreationRequest, - storage: CloudStorage, - documents: List[Document], + docs: List[Document], + vector_store_id: str | None = None, + is_final: bool = False, ) -> Collection: - """ - Create OpenAI vector store with documents and optionally an assistant. - docs_batches must be pre-fetched inside a DB session before this call. - """ try: - docs_batches = batch_documents(documents) vector_store_crud = OpenAIVectorStoreCrud(self.client) - vector_store = vector_store_crud.create() - list(vector_store_crud.update(vector_store.id, storage, docs_batches)) + if vector_store_id is None: + vector_store = vector_store_crud.create() + vector_store_id = vector_store.id + logger.info( + "[OpenAIProvider.create] Vector store created | vector_store_id=%s", + vector_store_id, + ) - logger.info( - "[OpenAIProvider.create] Vector store created | " - f"vector_store_id={vector_store.id}, batches={len(docs_batches)}" - ) + if docs: + vector_store_crud.update_batch(vector_store_id, docs) + logger.info( + "[OpenAIProvider.create] Batch uploaded | vector_store_id=%s, doc_count=%d", + vector_store_id, + len(docs), + ) + + if not is_final: + return Collection( + llm_service_id=vector_store_id, + llm_service_name=get_service_name("openai"), + ) - # Check if we need to create an assistant (based on assistant options in request) with_assistant = ( collection_request.model is not None and collection_request.instructions is not None @@ -59,11 +106,12 @@ def create( k: v for k, v in assistant_options.items() if v is not None } - assistant = assistant_crud.create(vector_store.id, **filtered_options) + assistant = assistant_crud.create(vector_store_id, **filtered_options) logger.info( - "[OpenAIProvider.create] Assistant created | " - f"assistant_id={assistant.id}, vector_store_id={vector_store.id}" + "[OpenAIProvider.create] Assistant created | assistant_id=%s, vector_store_id=%s", + assistant.id, + vector_store_id, ) return Collection( @@ -76,7 +124,7 @@ def create( ) return Collection( - llm_service_id=vector_store.id, + llm_service_id=vector_store_id, llm_service_name=get_service_name("openai"), ) diff --git a/backend/app/tests/services/collections/test_helpers.py b/backend/app/tests/services/collections/test_helpers.py index 7cddaf305..8b43946a1 100644 --- a/backend/app/tests/services/collections/test_helpers.py +++ b/backend/app/tests/services/collections/test_helpers.py @@ -122,14 +122,12 @@ def test_batch_documents_mixed_size_batching() -> None: assert len(batches[2]) == 1 # 15 MB total -def test_batch_documents_with_none_file_size() -> None: - """Test that documents with None file_size are treated as 0 bytes.""" +def test_batch_documents_with_none_file_size_raises() -> None: + """Test that documents with None file_size raise TypeError — sizes must be backfilled before batching.""" docs = create_fake_documents(10, file_size_kb=None) - batches = helpers.batch_documents(docs) - # All files with None/0 size should fit in one batch (under both limits) - assert len(batches) == 1 - assert len(batches[0]) == 10 + with pytest.raises(TypeError): + helpers.batch_documents(docs) def test_batch_documents_empty_input() -> None: