From 1eabb0020a5975e6ca960203883d46d137edbe0c Mon Sep 17 00:00:00 2001 From: nishika26 Date: Fri, 17 Apr 2026 15:06:48 +0530 Subject: [PATCH 1/9] default file size and adding documentation --- backend/app/api/docs/documents/upload.md | 1 + .../services/collections/create_collection.py | 7 +++--- backend/app/services/collections/helpers.py | 22 ++++++++++++++++++- 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/backend/app/api/docs/documents/upload.md b/backend/app/api/docs/documents/upload.md index e667015f5..c4c06caa6 100644 --- a/backend/app/api/docs/documents/upload.md +++ b/backend/app/api/docs/documents/upload.md @@ -1,6 +1,7 @@ Upload a document to Kaapi. - If only a file is provided, the document will be uploaded and stored, and its ID will be returned. +- The maximum file size allowed for upload is 25 MB. - If a target format is specified, a transformation job will also be created to transform document into target format in the background. The response will include both the uploaded document details and information about the transformation job. - If a callback URL is provided, you will receive a notification at that URL once the document transformation job is completed. 
diff --git a/backend/app/services/collections/create_collection.py b/backend/app/services/collections/create_collection.py index eb37fd039..d12b7be3f 100644 --- a/backend/app/services/collections/create_collection.py +++ b/backend/app/services/collections/create_collection.py @@ -22,6 +22,7 @@ CreationRequest, ) from app.services.collections.helpers import ( + calculate_total_size_kb, extract_error_message, to_collection_public, ) @@ -156,6 +157,7 @@ def execute_job( result = None creation_request = None provider = None + storage = None try: creation_request = CreationRequest(**request) @@ -169,9 +171,10 @@ def execute_job( with Session(engine) as session: document_crud = DocumentCrud(session, project_id) flat_docs = document_crud.read_each(creation_request.documents) + storage = get_cloud_storage(session=session, project_id=project_id) file_exts = {doc.fname.split(".")[-1] for doc in flat_docs if "." in doc.fname} - total_size_kb = sum(doc.file_size_kb or 0 for doc in flat_docs) + total_size_kb = calculate_total_size_kb(flat_docs, storage) total_size_mb = round(total_size_kb / 1024, 2) with Session(engine) as session: @@ -186,8 +189,6 @@ def execute_job( ), ) - storage = get_cloud_storage(session=session, project_id=project_id) - provider = get_llm_provider( session=session, provider=creation_request.provider, diff --git a/backend/app/services/collections/helpers.py b/backend/app/services/collections/helpers.py index 6985ac78e..66f9dc1c0 100644 --- a/backend/app/services/collections/helpers.py +++ b/backend/app/services/collections/helpers.py @@ -2,6 +2,7 @@ import json import ast import re +from typing import TYPE_CHECKING from uuid import UUID from fastapi import HTTPException @@ -11,6 +12,9 @@ from app.api.deps import SessionDep from app.models import DocumentCollection, Collection, CollectionPublic, Document +if TYPE_CHECKING: + from app.core.cloud.storage import CloudStorage + logger = logging.getLogger(__name__) @@ -63,6 +67,22 @@ def 
extract_error_message(err: Exception) -> str: return message.strip()[:1000] +def calculate_total_size_kb(documents: list[Document], storage: CloudStorage) -> float: + """ + Sum document sizes in KB. Uses the stored file_size_kb if available. + """ + total: float = 0 + for doc in documents: + if doc.file_size_kb is not None: + total += doc.file_size_kb + else: + logger.info( + f"[calculate_total_size_kb] file_size_kb missing, fetching from storage | {{'doc_id': '{doc.id}', 'fname': '{doc.fname}'}}" + ) + total += storage.get_file_size_kb(doc.object_store_url) + return total + + def batch_documents(documents: list[Document]) -> list[list[Document]]: """ Batch documents dynamically based on size and count limits. @@ -83,7 +103,7 @@ def batch_documents(documents: list[Document]) -> list[list[Document]]: current_batch_size_kb = 0 for doc in documents: - doc_size_kb = doc.file_size_kb or 0 + doc_size_kb = doc.file_size_kb or 15 * 1024 would_exceed_size = (current_batch_size_kb + doc_size_kb) > MAX_BATCH_SIZE_KB would_exceed_count = len(current_batch) >= MAX_BATCH_COUNT From 7f5d86f45af99e6b8133d9954158377c41106df1 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Fri, 17 Apr 2026 15:09:09 +0530 Subject: [PATCH 2/9] default file size and adding documentation --- backend/app/api/docs/documents/upload.md | 3 +-- backend/app/services/collections/create_collection.py | 1 - backend/app/services/collections/helpers.py | 1 - 3 files changed, 1 insertion(+), 4 deletions(-) diff --git a/backend/app/api/docs/documents/upload.md b/backend/app/api/docs/documents/upload.md index c4c06caa6..438dc3e9b 100644 --- a/backend/app/api/docs/documents/upload.md +++ b/backend/app/api/docs/documents/upload.md @@ -1,7 +1,6 @@ Upload a document to Kaapi. -- If only a file is provided, the document will be uploaded and stored, and its ID will be returned. -- The maximum file size allowed for upload is 25 MB. 
+- If only a file is provided, the document will be uploaded and stored, and its ID will be returned. The maximum file size allowed for upload is 25 MB. - If a target format is specified, a transformation job will also be created to transform document into target format in the background. The response will include both the uploaded document details and information about the transformation job. - If a callback URL is provided, you will receive a notification at that URL once the document transformation job is completed. diff --git a/backend/app/services/collections/create_collection.py b/backend/app/services/collections/create_collection.py index d12b7be3f..009d55fd1 100644 --- a/backend/app/services/collections/create_collection.py +++ b/backend/app/services/collections/create_collection.py @@ -152,7 +152,6 @@ def execute_job( """ start_time = time.time() - # Keeping the references for potential backout/cleanup on failure collection_job = None result = None creation_request = None diff --git a/backend/app/services/collections/helpers.py b/backend/app/services/collections/helpers.py index 66f9dc1c0..1b0ae0ace 100644 --- a/backend/app/services/collections/helpers.py +++ b/backend/app/services/collections/helpers.py @@ -23,7 +23,6 @@ MAX_DOC_SIZE_MB = 25 # 25 MB maximum per document # Maximum batch size for uploading documents to vector store -# Derived from MAX_DOC_SIZE + buffer to ensure single docs always fit MAX_BATCH_SIZE_KB = (MAX_DOC_SIZE_MB + 5) * 1024 # 30 MB in KB (25 + 5 MB buffer) MAX_BATCH_COUNT = 200 # Maximum documents per batch From 8e3d29d2093f488d7c54a0ec2f1708c05d7023b7 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Fri, 17 Apr 2026 15:30:57 +0530 Subject: [PATCH 3/9] coderabbit reviews --- .../services/collections/create_collection.py | 19 +++++++++++++++- backend/app/services/collections/helpers.py | 22 +------------------ 2 files changed, 19 insertions(+), 22 deletions(-) diff --git a/backend/app/services/collections/create_collection.py 
b/backend/app/services/collections/create_collection.py index 009d55fd1..887208e18 100644 --- a/backend/app/services/collections/create_collection.py +++ b/backend/app/services/collections/create_collection.py @@ -22,7 +22,6 @@ CreationRequest, ) from app.services.collections.helpers import ( - calculate_total_size_kb, extract_error_message, to_collection_public, ) @@ -136,6 +135,24 @@ def _mark_job_failed( return None +def calculate_total_size_kb( + documents: list[Document], storage: "CloudStorage" +) -> float: + """ + Sum document sizes in KB. Uses the stored file_size_kb if available. + """ + total: float = 0 + for doc in documents: + if doc.file_size_kb is not None: + total += doc.file_size_kb + else: + logger.info( + f"[calculate_total_size_kb] file_size_kb missing, fetching from storage | {{'doc_id': '{doc.id}', 'fname': '{doc.fname}'}}" + ) + total += storage.get_file_size_kb(doc.object_store_url) + return total + + def execute_job( request: dict, with_assistant: bool, diff --git a/backend/app/services/collections/helpers.py b/backend/app/services/collections/helpers.py index 1b0ae0ace..db972c92d 100644 --- a/backend/app/services/collections/helpers.py +++ b/backend/app/services/collections/helpers.py @@ -2,7 +2,6 @@ import json import ast import re -from typing import TYPE_CHECKING from uuid import UUID from fastapi import HTTPException @@ -12,9 +11,6 @@ from app.api.deps import SessionDep from app.models import DocumentCollection, Collection, CollectionPublic, Document -if TYPE_CHECKING: - from app.core.cloud.storage import CloudStorage - logger = logging.getLogger(__name__) @@ -66,22 +62,6 @@ def extract_error_message(err: Exception) -> str: return message.strip()[:1000] -def calculate_total_size_kb(documents: list[Document], storage: CloudStorage) -> float: - """ - Sum document sizes in KB. Uses the stored file_size_kb if available. 
- """ - total: float = 0 - for doc in documents: - if doc.file_size_kb is not None: - total += doc.file_size_kb - else: - logger.info( - f"[calculate_total_size_kb] file_size_kb missing, fetching from storage | {{'doc_id': '{doc.id}', 'fname': '{doc.fname}'}}" - ) - total += storage.get_file_size_kb(doc.object_store_url) - return total - - def batch_documents(documents: list[Document]) -> list[list[Document]]: """ Batch documents dynamically based on size and count limits. @@ -102,7 +82,7 @@ def batch_documents(documents: list[Document]) -> list[list[Document]]: current_batch_size_kb = 0 for doc in documents: - doc_size_kb = doc.file_size_kb or 15 * 1024 + doc_size_kb = doc.file_size_kb if doc.file_size_kb is not None else 15 * 1024 would_exceed_size = (current_batch_size_kb + doc_size_kb) > MAX_BATCH_SIZE_KB would_exceed_count = len(current_batch) >= MAX_BATCH_COUNT From bed1d1a4f13f2cd918e639c0dd288fe9a661dfdc Mon Sep 17 00:00:00 2001 From: nishika26 Date: Fri, 17 Apr 2026 15:35:28 +0530 Subject: [PATCH 4/9] test cases failing --- backend/app/services/collections/create_collection.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/backend/app/services/collections/create_collection.py b/backend/app/services/collections/create_collection.py index 887208e18..bd55c2871 100644 --- a/backend/app/services/collections/create_collection.py +++ b/backend/app/services/collections/create_collection.py @@ -14,6 +14,7 @@ CollectionJobCrud, ) from app.models import ( + Document, CollectionJobStatus, CollectionJob, Collection, @@ -21,6 +22,7 @@ CollectionJobPublic, CreationRequest, ) +from app.core.cloud.storage import CloudStorage from app.services.collections.helpers import ( extract_error_message, to_collection_public, @@ -135,9 +137,7 @@ def _mark_job_failed( return None -def calculate_total_size_kb( - documents: list[Document], storage: "CloudStorage" -) -> float: +def calculate_total_size_kb(documents: list[Document], storage: CloudStorage) -> 
float: """ Sum document sizes in KB. Uses the stored file_size_kb if available. """ From d02bac8a57d037327e501eb3e74ecfaa46c70c34 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Fri, 17 Apr 2026 17:01:00 +0530 Subject: [PATCH 5/9] changing the logic --- .../services/collections/create_collection.py | 39 +++++++++---------- backend/app/services/collections/helpers.py | 2 +- 2 files changed, 20 insertions(+), 21 deletions(-) diff --git a/backend/app/services/collections/create_collection.py b/backend/app/services/collections/create_collection.py index bd55c2871..14696b191 100644 --- a/backend/app/services/collections/create_collection.py +++ b/backend/app/services/collections/create_collection.py @@ -14,7 +14,6 @@ CollectionJobCrud, ) from app.models import ( - Document, CollectionJobStatus, CollectionJob, Collection, @@ -22,7 +21,6 @@ CollectionJobPublic, CreationRequest, ) -from app.core.cloud.storage import CloudStorage from app.services.collections.helpers import ( extract_error_message, to_collection_public, @@ -137,22 +135,6 @@ def _mark_job_failed( return None -def calculate_total_size_kb(documents: list[Document], storage: CloudStorage) -> float: - """ - Sum document sizes in KB. Uses the stored file_size_kb if available. - """ - total: float = 0 - for doc in documents: - if doc.file_size_kb is not None: - total += doc.file_size_kb - else: - logger.info( - f"[calculate_total_size_kb] file_size_kb missing, fetching from storage | {{'doc_id': '{doc.id}', 'fname': '{doc.fname}'}}" - ) - total += storage.get_file_size_kb(doc.object_store_url) - return total - - def execute_job( request: dict, with_assistant: bool, @@ -190,10 +172,27 @@ def execute_job( storage = get_cloud_storage(session=session, project_id=project_id) file_exts = {doc.fname.split(".")[-1] for doc in flat_docs if "." 
in doc.fname} - total_size_kb = calculate_total_size_kb(flat_docs, storage) - total_size_mb = round(total_size_kb / 1024, 2) + + backfill: list[tuple[UUID, float]] = [] + for doc in flat_docs: + if doc.file_size_kb is None: + size_kb = round(storage.get_file_size_kb(doc.object_store_url)) + doc.file_size_kb = size_kb + backfill.append((doc.id, size_kb)) + + total_size_kb = sum( + doc.file_size_kb for doc in flat_docs if doc.file_size_kb is not None + ) + total_size_mb = total_size_kb / 1024 with Session(engine) as session: + if backfill: + document_crud = DocumentCrud(session, project_id) + for doc_id, size_kb in backfill: + doc = document_crud.read_one(doc_id) + doc.file_size_kb = size_kb + document_crud.update(doc) + collection_job_crud = CollectionJobCrud(session, project_id) collection_job = collection_job_crud.read_one(job_uuid) collection_job = collection_job_crud.update( diff --git a/backend/app/services/collections/helpers.py b/backend/app/services/collections/helpers.py index db972c92d..3f0a0cefd 100644 --- a/backend/app/services/collections/helpers.py +++ b/backend/app/services/collections/helpers.py @@ -82,7 +82,7 @@ def batch_documents(documents: list[Document]) -> list[list[Document]]: current_batch_size_kb = 0 for doc in documents: - doc_size_kb = doc.file_size_kb if doc.file_size_kb is not None else 15 * 1024 + doc_size_kb = doc.file_size_kb would_exceed_size = (current_batch_size_kb + doc_size_kb) > MAX_BATCH_SIZE_KB would_exceed_count = len(current_batch) >= MAX_BATCH_COUNT From 8b7556c226449f0f9a52bc66906642bd55fa8068 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Fri, 17 Apr 2026 20:37:30 +0530 Subject: [PATCH 6/9] fixing test cases --- backend/app/services/collections/create_collection.py | 2 +- backend/app/tests/services/collections/test_helpers.py | 10 ++++------ 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/backend/app/services/collections/create_collection.py b/backend/app/services/collections/create_collection.py index 
14696b191..25aba0919 100644 --- a/backend/app/services/collections/create_collection.py +++ b/backend/app/services/collections/create_collection.py @@ -200,7 +200,7 @@ def execute_job( CollectionJobUpdate( task_id=task_id, status=CollectionJobStatus.PROCESSING, - total_size_mb=total_size_mb, + total_size_mb=round(total_size_mb, 2), ), ) diff --git a/backend/app/tests/services/collections/test_helpers.py b/backend/app/tests/services/collections/test_helpers.py index 7cddaf305..8b43946a1 100644 --- a/backend/app/tests/services/collections/test_helpers.py +++ b/backend/app/tests/services/collections/test_helpers.py @@ -122,14 +122,12 @@ def test_batch_documents_mixed_size_batching() -> None: assert len(batches[2]) == 1 # 15 MB total -def test_batch_documents_with_none_file_size() -> None: - """Test that documents with None file_size are treated as 0 bytes.""" +def test_batch_documents_with_none_file_size_raises() -> None: + """Test that documents with None file_size raise TypeError — sizes must be backfilled before batching.""" docs = create_fake_documents(10, file_size_kb=None) - batches = helpers.batch_documents(docs) - # All files with None/0 size should fit in one batch (under both limits) - assert len(batches) == 1 - assert len(batches[0]) == 10 + with pytest.raises(TypeError): + helpers.batch_documents(docs) def test_batch_documents_empty_input() -> None: From baaeac2eec68e1c89fdf764dcb563bbfdf408c98 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Tue, 5 May 2026 09:26:17 +0530 Subject: [PATCH 7/9] adding alembic file --- .../055_add_columns_to_collections.py | 62 +++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 backend/app/alembic/versions/055_add_columns_to_collections.py diff --git a/backend/app/alembic/versions/055_add_columns_to_collections.py b/backend/app/alembic/versions/055_add_columns_to_collections.py new file mode 100644 index 000000000..804e5cc7d --- /dev/null +++ b/backend/app/alembic/versions/055_add_columns_to_collections.py 
@@ -0,0 +1,62 @@ +"""add batch tracking to collection_job and provider_file_id to document + +Revision ID: 055 +Revises: 054 +Create Date: 2026-04-13 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "055" +down_revision = "054" +branch_labels = None +depends_on = None + + +def upgrade(): + op.add_column( + "collection_jobs", + sa.Column( + "total_batches", + sa.Integer(), + nullable=True, + comment="Total number of batches the documents are split into", + ), + ) + op.add_column( + "collection_jobs", + sa.Column( + "current_batch_number", + sa.Integer(), + nullable=True, + comment="Which batch is currently being processed (1-indexed)", + ), + ) + op.add_column( + "collection_jobs", + sa.Column( + "documents_uploaded", + sa.JSON(), + nullable=True, + comment="List of document IDs successfully uploaded so far", + ), + ) + op.add_column( + "document", + sa.Column( + "openai_file_id", + sa.String(), + nullable=True, + comment="File ID assigned by the LLM provider (e.g. 
OpenAI file ID) to avoid re-uploading", + ), + ) + + +def downgrade(): + op.drop_column("collection_jobs", "total_batches") + op.drop_column("collection_jobs", "current_batch_number") + op.drop_column("collection_jobs", "documents_uploaded") + op.drop_column("document", "openai_file_id") From 47525e49639a2080d96f6a8b9bd980a79c74ff1f Mon Sep 17 00:00:00 2001 From: nishika26 Date: Tue, 5 May 2026 10:16:25 +0530 Subject: [PATCH 8/9] adding logic to the pr --- ..._add_batch_tracking_to_collections_jobs.py | 62 ++ backend/app/celery/utils.py | 18 + backend/app/crud/rag/open_ai.py | 53 ++ backend/app/models/collection_job.py | 27 +- backend/app/models/document.py | 5 + .../services/collections/create_collection.py | 560 +++++++++++------- .../services/collections/providers/base.py | 56 +- .../services/collections/providers/openai.py | 86 ++- 8 files changed, 603 insertions(+), 264 deletions(-) create mode 100644 backend/app/alembic/versions/055_add_batch_tracking_to_collections_jobs.py diff --git a/backend/app/alembic/versions/055_add_batch_tracking_to_collections_jobs.py b/backend/app/alembic/versions/055_add_batch_tracking_to_collections_jobs.py new file mode 100644 index 000000000..804e5cc7d --- /dev/null +++ b/backend/app/alembic/versions/055_add_batch_tracking_to_collections_jobs.py @@ -0,0 +1,62 @@ +"""add batch tracking to collection_job and provider_file_id to document + +Revision ID: 055 +Revises: 054 +Create Date: 2026-04-13 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision = "055" +down_revision = "054" +branch_labels = None +depends_on = None + + +def upgrade(): + op.add_column( + "collection_jobs", + sa.Column( + "total_batches", + sa.Integer(), + nullable=True, + comment="Total number of batches the documents are split into", + ), + ) + op.add_column( + "collection_jobs", + sa.Column( + "current_batch_number", + sa.Integer(), + nullable=True, + comment="Which batch is currently being processed (1-indexed)", + ), + ) + op.add_column( + "collection_jobs", + sa.Column( + "documents_uploaded", + sa.JSON(), + nullable=True, + comment="List of document IDs successfully uploaded so far", + ), + ) + op.add_column( + "document", + sa.Column( + "openai_file_id", + sa.String(), + nullable=True, + comment="File ID assigned by the LLM provider (e.g. OpenAI file ID) to avoid re-uploading", + ), + ) + + +def downgrade(): + op.drop_column("collection_jobs", "total_batches") + op.drop_column("collection_jobs", "current_batch_number") + op.drop_column("collection_jobs", "documents_uploaded") + op.drop_column("document", "openai_file_id") diff --git a/backend/app/celery/utils.py b/backend/app/celery/utils.py index 5ebbf624a..475a4772a 100644 --- a/backend/app/celery/utils.py +++ b/backend/app/celery/utils.py @@ -85,6 +85,24 @@ def start_doctransform_job( return task_id +def start_create_collection_setup_job( + project_id: int, job_id: str, trace_id: str = "N/A", **kwargs +) -> str: + from app.celery.tasks.job_execution import run_create_collection_setup_job + + task_id = _enqueue_with_trace_context( + run_create_collection_job, + project_id=project_id, + job_id=job_id, + trace_id=trace_id, + **kwargs, + ) + logger.info( + f"[start_create_collection_job] Started job {job_id} with Celery task {task_id}" + ) + return task_id + + def start_create_collection_job( project_id: int, job_id: str, trace_id: str = "N/A", **kwargs ) -> str: diff --git a/backend/app/crud/rag/open_ai.py b/backend/app/crud/rag/open_ai.py index cdae82440..be6970235 100644 
--- a/backend/app/crud/rag/open_ai.py +++ b/backend/app/crud/rag/open_ai.py @@ -1,6 +1,7 @@ import json import logging import functools as ft +import time from io import BytesIO from typing import Iterable @@ -149,6 +150,58 @@ def update( yield from docs + def update_batch( + self, + vector_store_id: str, + docs: list[Document], + ) -> tuple[list[Document], list[Document]]: + """ + Attach a batch of documents to the vector store via a single upload_and_poll call. + + All docs must have provider_file_id set before calling this method. + Returns (succeeded, failed) — failed docs should be retried in the next batch. + """ + succeeded: list[Document] = [] + failed: list[Document] = [] + + if not docs: + return succeeded, failed + + try: + _t0 = time.monotonic() + batch = self.client.vector_stores.file_batches.upload_and_poll( + vector_store_id=vector_store_id, + files=[], + file_ids=[doc.openai_file_id for doc in docs], + ) + logger.info( + f"[OpenAIVectorStoreCrud.update_batch] Batch upload_and_poll duration | " + f"{{'vector_store_id': '{vector_store_id}', 'duration_s': {time.monotonic() - _t0:.3f}, " + f"'completed': {batch.file_counts.completed}, 'failed': {batch.file_counts.failed}}}" + ) + if batch.file_counts.failed == 0: + succeeded.extend(docs) + else: + # Can't identify which specific files failed — retry all of them + logger.warning( + f"[OpenAIVectorStoreCrud.update_batch] Batch had failures, marking all for retry | " + f"{{'vector_store_id': '{vector_store_id}', 'failed_count': {batch.file_counts.failed}}}" + ) + failed.extend(docs) + except OpenAIError as err: + logger.error( + f"[OpenAIVectorStoreCrud.update_batch] Batch attach failed | " + f"{{'vector_store_id': '{vector_store_id}', 'error': '{str(err)}'}}", + exc_info=True, + ) + failed.extend(docs) + + logger.info( + f"[OpenAIVectorStoreCrud.update_batch] Batch complete | " + f"{{'vector_store_id': '{vector_store_id}', 'succeeded': {len(succeeded)}, 'failed': {len(failed)}}}" + ) + return succeeded, 
failed + def delete(self, vector_store_id: str, retries: int = 3): if retries < 1: try: diff --git a/backend/app/models/collection_job.py b/backend/app/models/collection_job.py index 333ebfd14..6b628ad7e 100644 --- a/backend/app/models/collection_job.py +++ b/backend/app/models/collection_job.py @@ -77,7 +77,29 @@ class CollectionJob(SQLModel, table=True): documents: list[str] | None = Field( default=None, sa_column=Column( - JSON, nullable=True, comment="List of documents given to make collection" + JSON, nullable=True, comment="List of document IDs given to make collection" + ), + ) + total_batches: int | None = Field( + default=None, + nullable=True, + sa_column_kwargs={ + "comment": "Total number of batches the documents are split into" + }, + ) + current_batch_number: int | None = Field( + default=None, + nullable=True, + sa_column_kwargs={ + "comment": "Which batch is currently being processed (1-indexed)" + }, + ) + documents_uploaded: list[str] | None = Field( + default=None, + sa_column=Column( + JSON, + nullable=True, + comment="List of document IDs successfully uploaded so far", ), ) @@ -139,6 +161,9 @@ class CollectionJobUpdate(SQLModel): collection_id: UUID | None = None total_size_mb: float | None = None trace_id: str | None = None + total_batches: int | None = None + current_batch_number: int | None = None + documents_uploaded: list[str] | None = None ##Response models diff --git a/backend/app/models/document.py b/backend/app/models/document.py index 12843e72a..5bbcddc77 100644 --- a/backend/app/models/document.py +++ b/backend/app/models/document.py @@ -46,6 +46,11 @@ class Document(DocumentBase, table=True): description="The size of the document in kilobytes", sa_column_kwargs={"comment": "Size of the document in kilobytes (KB)"}, ) + openai_file_id: str | None = Field( + default=None, + nullable=True, + sa_column_kwargs={"comment": "File ID assigned by OpenAI (avoid re-uploading)"}, + ) # Foreign keys source_document_id: UUID | None = Field( diff 
--git a/backend/app/services/collections/create_collection.py b/backend/app/services/collections/create_collection.py index bc42aa0d0..884f8e3bd 100644 --- a/backend/app/services/collections/create_collection.py +++ b/backend/app/services/collections/create_collection.py @@ -2,13 +2,11 @@ import time from uuid import UUID, uuid4 -from opentelemetry import trace from sqlmodel import Session from asgi_correlation_id import correlation_id from app.core.cloud import get_cloud_storage from app.core.db import engine -from app.core.telemetry import log_context from app.crud import ( CollectionCrud, DocumentCrud, @@ -23,17 +21,19 @@ CollectionJobPublic, CreationRequest, ) +from app.crud.rag import OpenAIVectorStoreCrud from app.services.collections.helpers import ( + batch_documents, extract_error_message, to_collection_public, ) from app.services.collections.providers.registry import get_llm_provider -from app.celery.utils import start_create_collection_job -from app.utils import send_callback, get_webhook_secret, APIResponse +from gevent import Timeout +from app.celery.utils import start_create_collection_job, start_collection_batch_job +from app.utils import send_callback, APIResponse logger = logging.getLogger(__name__) -tracer = trace.get_tracer(__name__) def start_job( @@ -44,49 +44,31 @@ def start_job( with_assistant: bool, organization_id: int, ) -> str: - with log_context( - tag="collection", - lifecycle="collection.create.start_job", - action="create", - collection_job_id=collection_job_id, - project_id=project_id, - organization_id=organization_id, - ): - trace_id = correlation_id.get() or "N/A" + trace_id = correlation_id.get() or "N/A" - job_crud = CollectionJobCrud(db, project_id) - collection_job = job_crud.update( - collection_job_id, CollectionJobUpdate(trace_id=trace_id) - ) + job_crud = CollectionJobCrud(db, project_id) + job_crud.update(collection_job_id, CollectionJobUpdate(trace_id=trace_id)) - task_id = start_create_collection_job( - 
project_id=project_id, - job_id=str(collection_job_id), - trace_id=trace_id, - request=request.model_dump(mode="json"), - with_assistant=with_assistant, - organization_id=organization_id, - ) + task_id = start_create_collection_job( + project_id=project_id, + job_id=str(collection_job_id), + trace_id=trace_id, + request=request.model_dump(mode="json"), + with_assistant=with_assistant, + organization_id=organization_id, + ) - logger.info( - "[create_collection.start_job] Job scheduled to create collection | " - f"collection_job_id={collection_job_id}, project_id={project_id}, task_id={task_id}" - ) + logger.info( + "[create_collection.start_job] Job scheduled to create collection | " + f"collection_job_id={collection_job_id}, project_id={project_id}, task_id={task_id}" + ) - return collection_job_id + return collection_job_id def build_success_payload( collection_job: CollectionJob, collection: Collection ) -> dict: - """ - { - "success": true, - "data": { job fields + full collection }, - "error": null, - "metadata": null - } - """ collection_public = to_collection_public(collection) collection_dict = collection_public.model_dump(mode="json", exclude_none=True) @@ -100,15 +82,6 @@ def build_success_payload( def build_failure_payload(collection_job: CollectionJob, error_message: str) -> dict: - """ - { - "success": false, - "data": { job fields, collection: null }, - "error": "something went wrong", - "metadata": null - } - """ - # ensure `collection` is explicitly null in the payload job_public = CollectionJobPublic.model_validate( collection_job, update={"collection": None}, @@ -142,11 +115,63 @@ def _mark_job_failed( ) return collection_job except Exception: - logger.warning("[create_collection.execute_job] Failed to mark job as FAILED") + logger.warning("[create_collection] Failed to mark job as FAILED") return None -def execute_job( +def _persist_succeeded_docs(succeeded: list, project_id: int) -> list[str]: + with Session(engine) as session: + document_crud = 
DocumentCrud(session, project_id) + for doc in succeeded: + if doc.openai_file_id: + db_doc = document_crud.read_one(doc.id) + if db_doc.openai_file_id != doc.openai_file_id: + db_doc.openai_file_id = doc.openai_file_id + document_crud.update(db_doc) + return [str(doc.id) for doc in succeeded] + + +def _retry_failed_uploads( + vector_store_crud, + vector_store_id: str, + failed_docs: list, + project_id: int, + max_retries: int = 3, +) -> list[str]: + """ + Retry attaching docs that failed the initial batch upload_and_poll. + All docs must already have provider_file_id set. + Returns the list of successfully retried doc IDs. + Raises RuntimeError if any docs still fail after all retries. + """ + pending = failed_docs + all_succeeded_ids: list[str] = [] + + for attempt in range(1, max_retries + 1): + logger.warning( + "[_retry_failed_uploads] Retry attempt %d/%d: %d doc(s) | vector_store_id=%s", + attempt, + max_retries, + len(pending), + vector_store_id, + ) + succeeded, failed = vector_store_crud.update_batch(vector_store_id, pending) + + if succeeded: + all_succeeded_ids += _persist_succeeded_docs(succeeded, project_id) + + if not failed: + return all_succeeded_ids + + pending = failed + + ids = [str(d.id) for d in pending] + raise RuntimeError( + f"Failed to upload {len(pending)} document(s) after {max_retries} retries: {ids}" + ) + + +def execute_setup_job( request: dict, with_assistant: bool, project_id: int, @@ -156,206 +181,311 @@ def execute_job( task_instance, ) -> None: """ - Worker entrypoint scheduled by start_job. - Orchestrates: job state, provider init, collection creation, - optional assistant creation, collection persistence, linking, callbacks, and cleanup. + Phase 1: Fetch documents, create the vector store, split into batches, + update job state to PROCESSING, then queue the first batch task. 
""" - start_time = time.time() - collection_job = None - result = None creation_request = None - provider = None - storage = None - - with log_context( - tag="collection", - lifecycle="collection.create.execute_job", - action="create", - collection_job_id=job_id, - task_id=task_id, - project_id=project_id, - organization_id=organization_id, - ), tracer.start_as_current_span("collections.create.execute_job") as span: - span.set_attribute("collection.job_id", str(job_id)) - span.set_attribute("kaapi.project_id", project_id) - span.set_attribute("kaapi.organization_id", organization_id) - - try: - creation_request = CreationRequest(**request) - if with_assistant: - creation_request.provider = "openai" - - span.set_attribute("collection.provider", str(creation_request.provider)) - - job_uuid = UUID(job_id) - - with Session(engine) as session: - document_crud = DocumentCrud(session, project_id) - flat_docs = document_crud.read_each(creation_request.documents) - - file_exts = { - doc.fname.split(".")[-1] for doc in flat_docs if "." 
in doc.fname - } - total_size_kb = sum(doc.file_size_kb or 0 for doc in flat_docs) - total_size_mb = round(total_size_kb / 1024, 2) - span.set_attribute("collection.documents.count", len(flat_docs)) - span.set_attribute("collection.documents.total_size_mb", total_size_mb) - - with Session(engine) as session: - collection_job_crud = CollectionJobCrud(session, project_id) - collection_job = collection_job_crud.read_one(job_uuid) - collection_job = collection_job_crud.update( - job_uuid, - CollectionJobUpdate( - task_id=task_id, - status=CollectionJobStatus.PROCESSING, - total_size_mb=total_size_mb, - ), - ) - - storage = get_cloud_storage(session=session, project_id=project_id) - provider = get_llm_provider( - session=session, - provider=creation_request.provider, - project_id=project_id, - organization_id=organization_id, - ) + + try: + creation_request = CreationRequest(**request) + if with_assistant: + creation_request.provider = "openai" + + job_uuid = UUID(job_id) + trace_id = correlation_id.get() or "N/A" with Session(engine) as session: document_crud = DocumentCrud(session, project_id) flat_docs = document_crud.read_each(creation_request.documents) storage = get_cloud_storage(session=session, project_id=project_id) - file_exts = {doc.fname.split(".")[-1] for doc in flat_docs if "." 
in doc.fname} + provider = get_llm_provider( + session=session, + provider=creation_request.provider, + project_id=project_id, + organization_id=organization_id, + ) + + for doc in flat_docs: + session.expunge(doc) - backfill: list[tuple[UUID, float]] = [] - for doc in flat_docs: - if doc.file_size_kb is None: - size_kb = round(storage.get_file_size_kb(doc.object_store_url)) - doc.file_size_kb = size_kb - backfill.append((doc.id, size_kb)) + provider.upload_files(storage, flat_docs, project_id) - total_size_kb = sum( - doc.file_size_kb for doc in flat_docs if doc.file_size_kb is not None + logger.info( + "[create_collection.execute_setup_job] All file uploads complete | " + "job_id=%s, total=%d, failed=%d, duration_s=%.2f", + job_id, + len(flat_docs), ) + + total_size_kb = sum(doc.file_size_kb for doc in flat_docs) total_size_mb = total_size_kb / 1024 - with Session(engine) as session: - if backfill: - document_crud = DocumentCrud(session, project_id) - for doc_id, size_kb in backfill: - doc = document_crud.read_one(doc_id) - doc.file_size_kb = size_kb - document_crud.update(doc) + docs_batches = batch_documents(flat_docs) + total_batches = len(docs_batches) + batch_doc_ids = [[str(doc.id) for doc in batch] for batch in docs_batches] + with Session(engine) as session: collection_job_crud = CollectionJobCrud(session, project_id) - collection_job = collection_job_crud.read_one(job_uuid) collection_job = collection_job_crud.update( job_uuid, CollectionJobUpdate( task_id=task_id, status=CollectionJobStatus.PROCESSING, - total_size_mb=round(total_size_mb, 2), + total_size_mb=total_size_mb, + current_batch_number=0, + total_batches=total_batches, + documents_uploaded=[], ), ) + start_collection_batch_job( + project_id=project_id, + job_id=job_id, + trace_id=trace_id, + batch_number=1, + batch_doc_ids=batch_doc_ids[0], + remaining_batches=batch_doc_ids[1:], + request=request, + with_assistant=with_assistant, + organization_id=organization_id, + ) + + logger.info( + 
"[create_collection.execute_setup_job] Setup complete, first batch queued | " + f"job_id={job_id}, total_batches={total_batches}" + ) + + except Timeout as err: + timeout_err = TimeoutError( + f"[execute_setup_job] Task exceeded soft time limit of {err.seconds}s" + ) + _mark_job_failed( + project_id=project_id, + job_id=job_id, + err=timeout_err, + collection_job=collection_job, + ) + raise + + except Exception as err: + logger.error( + "[create_collection.execute_setup_job] Setup failed | job_id=%s, error=%s", + job_id, + str(err), + exc_info=True, + ) + + collection_job = _mark_job_failed( + project_id=project_id, + job_id=job_id, + err=err, + collection_job=collection_job, + ) + if creation_request and creation_request.callback_url and collection_job: + failure_payload = build_failure_payload(collection_job, str(err)) + send_callback(creation_request.callback_url, failure_payload) + + +def execute_batch_job( + request: dict, + with_assistant: bool, + project_id: int, + organization_id: int, + task_id: str, + job_id: str, + task_instance, + vector_store_id: str | None, + batch_number: int, + batch_doc_ids: list[str], + remaining_batches: list[list[str]], +) -> None: + """ + Phase 2: Upload one batch of documents to the vector store. 
def execute_batch_job(
    request: dict,
    with_assistant: bool,
    project_id: int,
    organization_id: int,
    task_id: str,
    job_id: str,
    task_instance,
    vector_store_id: str | None,
    batch_number: int,
    batch_doc_ids: list[str],
    remaining_batches: list[list[str]],
) -> None:
    """
    Phase 2: Upload one batch of documents to the vector store.

    - Uploads the batch via the provider; raises if provider-level retries fail
    - Checkpoints progress (current batch, uploaded doc ids) to the DB
    - If more batches remain, queues the next batch task
    - If this is the last batch, finalizes: creates the Collection row, links
      documents, marks the job SUCCESSFUL, and fires the success callback
    """
    collection_job = None
    creation_request = None

    try:
        batch_start_time = time.time()
        creation_request = CreationRequest(**request)
        if with_assistant:
            creation_request.provider = "openai"

        job_uuid = UUID(job_id)
        trace_id = correlation_id.get() or "N/A"

        logger.info(
            "[create_collection.execute_batch_job] Starting batch | "
            "job_id=%s, batch_number=%d, doc_count=%d, remaining_batches=%d",
            job_id,
            batch_number,
            len(batch_doc_ids),
            len(remaining_batches),
        )

        all_doc_ids_this_batch = [UUID(d) for d in batch_doc_ids]
        is_final = not remaining_batches

        with Session(engine) as session:
            provider = get_llm_provider(
                session=session,
                provider=creation_request.provider,
                project_id=project_id,
                organization_id=organization_id,
            )

        with Session(engine) as session:
            document_crud = DocumentCrud(session, project_id)
            batch_docs = (
                document_crud.read_each(all_doc_ids_this_batch)
                if all_doc_ids_this_batch
                else []
            )
            # Detach so the docs survive the session close.
            for doc in batch_docs:
                session.expunge(doc)

        # On the final batch the provider also finalizes (e.g. assistant).
        collection_result = provider.create(
            creation_request,
            batch_docs,
            vector_store_id=vector_store_id,
            is_final=is_final,
        )
        # Intermediate batches return the vector store id; on the final batch
        # llm_service_id may be a different resource id, so prefer the id we
        # carried in from earlier batches.
        resolved_vector_store_id = (
            collection_result.llm_service_id
            if not is_final
            else vector_store_id or collection_result.llm_service_id
        )

        # Checkpoint progress so a retry/resume knows what is already uploaded.
        with Session(engine) as session:
            collection_job_crud = CollectionJobCrud(session, project_id)
            collection_job = collection_job_crud.read_one(job_uuid)
            already_uploaded = collection_job.documents_uploaded or []
            now_uploaded = already_uploaded + [str(d) for d in all_doc_ids_this_batch]

            collection_job = collection_job_crud.update(
                job_uuid,
                CollectionJobUpdate(
                    current_batch_number=batch_number,
                    documents_uploaded=now_uploaded,
                ),
            )

        logger.info(
            "[create_collection.execute_batch_job] Batch %d complete | "
            "doc_count=%d, job_id=%s",
            batch_number,
            len(all_doc_ids_this_batch),
            job_id,
        )

        if remaining_batches:
            start_collection_batch_job(
                project_id=project_id,
                job_id=job_id,
                trace_id=trace_id,
                vector_store_id=resolved_vector_store_id,
                batch_number=batch_number + 1,
                batch_doc_ids=remaining_batches[0],
                remaining_batches=remaining_batches[1:],
                request=request,
                with_assistant=with_assistant,
                organization_id=organization_id,
            )
            logger.info(
                "[create_collection.execute_batch_job] Batch %d/%d done, next batch queued | "
                "job_id=%s, elapsed=%.2fs",
                batch_number,
                batch_number + len(remaining_batches),
                job_id,
                time.time() - batch_start_time,
            )
            return

        # Final batch: collection_result already has assistant/vector_store finalized
        finalize_start_time = time.time()

        with Session(engine) as session:
            all_uploaded_ids = [UUID(d) for d in now_uploaded]
            document_crud = DocumentCrud(session, project_id)
            all_docs = (
                document_crud.read_each(all_uploaded_ids) if all_uploaded_ids else []
            )
            for doc in all_docs:
                session.expunge(doc)

        with Session(engine) as session:
            collection_id = uuid4()
            collection = Collection(
                id=collection_id,
                project_id=project_id,
                llm_service_id=collection_result.llm_service_id,
                llm_service_name=collection_result.llm_service_name,
                provider=creation_request.provider,
                name=creation_request.name,
                description=creation_request.description,
            )
            collection_crud = CollectionCrud(session, project_id)
            collection_crud.create(collection)
            collection = collection_crud.read_one(collection.id)

            if all_docs:
                DocumentCollectionCrud(session).create(collection, all_docs)

            collection_job_crud = CollectionJobCrud(session, project_id)
            collection_job = collection_job_crud.update(
                job_uuid,
                CollectionJobUpdate(
                    status=CollectionJobStatus.SUCCESSFUL,
                    collection_id=collection.id,
                ),
            )

            success_payload = build_success_payload(collection_job, collection)

        logger.info(
            "[create_collection.execute_batch_job] All batches done, collection created: %s | "
            "finalize_time=%.2fs, total_time=%.2fs, total_docs=%d",
            collection_id,
            time.time() - finalize_start_time,
            time.time() - batch_start_time,
            len(all_docs),
        )

        if creation_request.callback_url:
            # NOTE(review): the pre-batching implementation signed callbacks
            # with get_webhook_secret(...) — confirm whether send_callback
            # still needs a webhook_secret here.
            send_callback(creation_request.callback_url, success_payload)

    except Timeout as err:
        # gevent.Timeout derives from BaseException and must be handled explicitly.
        timeout_err = TimeoutError(
            f"[execute_batch_job] Task exceeded soft time limit of {err.seconds}s"
        )
        _mark_job_failed(
            project_id=project_id,
            job_id=job_id,
            err=timeout_err,
            collection_job=collection_job,
        )
        raise
    except Exception as err:
        # BUGFIX: was `except BaseException`, which also swallowed SystemExit,
        # KeyboardInterrupt and greenlet kills without re-raising; narrow to
        # Exception so only genuine errors become a failed job + callback.
        logger.error(
            "[create_collection.execute_batch_job] Batch %d failed | job_id=%s, error=%s",
            batch_number,
            job_id,
            str(err),
            exc_info=True,
        )
        collection_job = _mark_job_failed(
            project_id=project_id,
            job_id=job_id,
            err=err,
            collection_job=collection_job,
        )
        if creation_request and creation_request.callback_url and collection_job:
            failure_payload = build_failure_payload(collection_job, str(err))
            send_callback(creation_request.callback_url, failure_payload)
    def __init__(self, client: Any) -> None:
        # Provider-specific SDK client (e.g. an OpenAI client instance).
        self.client = client

    @abstractmethod
    def upload_files(
        self,
        storage: CloudStorage,
        docs: list[Document],
        project_id: int,
    ) -> None:
        """Upload all documents to the provider's file storage and persist their file IDs.

        Args:
            storage: Cloud storage instance to fetch raw file bytes from
            docs: Documents to upload
            project_id: Project ID used to persist the provider file IDs to the DB
        """
        raise NotImplementedError("Providers must implement upload_files method")

    @abstractmethod
    def create(
        self,
        collection_request: CreationRequest,
        docs: list[Document],
        vector_store_id: str | None = None,
        is_final: bool = False,
    ) -> Collection:
        """Upload a docs batch to the vector store (creating it if vector_store_id is None).

        Creates an assistant only when is_final=True and model/instructions are set.
        Returns a Collection with llm_service_id set to the vector_store_id on
        intermediate batches, or to the assistant/vector_store id on the final batch.
        """
        raise NotImplementedError("Providers must implement create method")

    @abstractmethod
    def delete(self, collection: Collection) -> None:
        """Delete remote resources associated with a collection."""
        raise NotImplementedError("Providers must implement delete method")

    def get_existing_file_id(self, _doc: Document) -> str | None:
        """Return the already-uploaded file ID for this provider, or None to trigger upload."""
        return None

    def get_provider_name(self) -> str:
        # Derives e.g. "openai" from the subclass name "OpenAIProvider".
        return self.__class__.__name__.replace("Provider", "").lower()
__init__(self, client: OpenAI): super().__init__(client) self.client = client + def get_existing_file_id(self, doc: Document) -> str | None: + return doc.openai_file_id + + def upload_files( + self, + storage: CloudStorage, + docs: list[Document], + project_id: int, + ) -> None: + for doc in docs: + if self.get_existing_file_id(doc): + continue + try: + content = storage.get(doc.object_store_url) + if doc.file_size_kb is None: + doc.file_size_kb = round(len(content) / 1024, 2) + f_obj = BytesIO(content) + f_obj.name = doc.fname + uploaded = self.client.files.create(file=f_obj, purpose="assistants") + doc.openai_file_id = uploaded.id + with Session(engine) as session: + document_crud = DocumentCrud(session, project_id) + db_doc = document_crud.read_one(doc.id) + db_doc.openai_file_id = uploaded.id + db_doc.file_size_kb = doc.file_size_kb + document_crud.update(db_doc) + except Exception as err: + logger.error( + "[OpenAIProvider.upload_files] Failed to upload file | doc_id=%s, error=%s", + doc.id, + str(err), + exc_info=True, + ) + def create( self, collection_request: CreationRequest, - storage: CloudStorage, - documents: List[Document], + docs: List[Document], + vector_store_id: str | None = None, + is_final: bool = False, ) -> Collection: - """ - Create OpenAI vector store with documents and optionally an assistant. - docs_batches must be pre-fetched inside a DB session before this call. 
- """ try: - docs_batches = batch_documents(documents) vector_store_crud = OpenAIVectorStoreCrud(self.client) - vector_store = vector_store_crud.create() - list(vector_store_crud.update(vector_store.id, storage, docs_batches)) + if vector_store_id is None: + vector_store = vector_store_crud.create() + vector_store_id = vector_store.id + logger.info( + "[OpenAIProvider.create] Vector store created | vector_store_id=%s", + vector_store_id, + ) - logger.info( - "[OpenAIProvider.create] Vector store created | " - f"vector_store_id={vector_store.id}, batches={len(docs_batches)}" - ) + if docs: + vector_store_crud.update_batch(vector_store_id, docs) + logger.info( + "[OpenAIProvider.create] Batch uploaded | vector_store_id=%s, doc_count=%d", + vector_store_id, + len(docs), + ) + + if not is_final: + return Collection( + llm_service_id=vector_store_id, + llm_service_name=get_service_name("openai"), + ) - # Check if we need to create an assistant (based on assistant options in request) with_assistant = ( collection_request.model is not None and collection_request.instructions is not None @@ -59,11 +106,12 @@ def create( k: v for k, v in assistant_options.items() if v is not None } - assistant = assistant_crud.create(vector_store.id, **filtered_options) + assistant = assistant_crud.create(vector_store_id, **filtered_options) logger.info( - "[OpenAIProvider.create] Assistant created | " - f"assistant_id={assistant.id}, vector_store_id={vector_store.id}" + "[OpenAIProvider.create] Assistant created | assistant_id=%s, vector_store_id=%s", + assistant.id, + vector_store_id, ) return Collection( @@ -76,7 +124,7 @@ def create( ) return Collection( - llm_service_id=vector_store.id, + llm_service_id=vector_store_id, llm_service_name=get_service_name("openai"), ) From 2a2e2680ef3f4b7a58821f9f98d7ff119aedaaf7 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Wed, 6 May 2026 08:55:41 +0530 Subject: [PATCH 9/9] pushing few changes --- ..._add_batch_tracking_to_collections_jobs.py | 2 +- 
.../055_add_columns_to_collections.py | 62 ----- backend/app/celery/tasks/job_execution.py | 215 +++++++----------- backend/app/celery/utils.py | 175 +++++++------- 4 files changed, 158 insertions(+), 296 deletions(-) delete mode 100644 backend/app/alembic/versions/055_add_columns_to_collections.py diff --git a/backend/app/alembic/versions/055_add_batch_tracking_to_collections_jobs.py b/backend/app/alembic/versions/055_add_batch_tracking_to_collections_jobs.py index 804e5cc7d..26fb1a8d3 100644 --- a/backend/app/alembic/versions/055_add_batch_tracking_to_collections_jobs.py +++ b/backend/app/alembic/versions/055_add_batch_tracking_to_collections_jobs.py @@ -1,4 +1,4 @@ -"""add batch tracking to collection_job and provider_file_id to document +"""add batch tracking to collection_jobs Revision ID: 055 Revises: 054 diff --git a/backend/app/alembic/versions/055_add_columns_to_collections.py b/backend/app/alembic/versions/055_add_columns_to_collections.py deleted file mode 100644 index 804e5cc7d..000000000 --- a/backend/app/alembic/versions/055_add_columns_to_collections.py +++ /dev/null @@ -1,62 +0,0 @@ -"""add batch tracking to collection_job and provider_file_id to document - -Revision ID: 055 -Revises: 054 -Create Date: 2026-04-13 - -""" -from alembic import op -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. 
-revision = "055" -down_revision = "054" -branch_labels = None -depends_on = None - - -def upgrade(): - op.add_column( - "collection_jobs", - sa.Column( - "total_batches", - sa.Integer(), - nullable=True, - comment="Total number of batches the documents are split into", - ), - ) - op.add_column( - "collection_jobs", - sa.Column( - "current_batch_number", - sa.Integer(), - nullable=True, - comment="Which batch is currently being processed (1-indexed)", - ), - ) - op.add_column( - "collection_jobs", - sa.Column( - "documents_uploaded", - sa.JSON(), - nullable=True, - comment="List of document IDs successfully uploaded so far", - ), - ) - op.add_column( - "document", - sa.Column( - "openai_file_id", - sa.String(), - nullable=True, - comment="File ID assigned by the LLM provider (e.g. OpenAI file ID) to avoid re-uploading", - ), - ) - - -def downgrade(): - op.drop_column("collection_jobs", "total_batches") - op.drop_column("collection_jobs", "current_batch_number") - op.drop_column("collection_jobs", "documents_uploaded") - op.drop_column("document", "openai_file_id") diff --git a/backend/app/celery/tasks/job_execution.py b/backend/app/celery/tasks/job_execution.py index 8dd20091a..0156459aa 100644 --- a/backend/app/celery/tasks/job_execution.py +++ b/backend/app/celery/tasks/job_execution.py @@ -2,11 +2,10 @@ from asgi_correlation_id import correlation_id from celery import current_task -from opentelemetry import context as otel_context -from opentelemetry import trace -from opentelemetry.propagate import extract from app.celery.celery_app import celery_app +from app.celery.utils import gevent_timeout +from app.core.config import settings logger = logging.getLogger(__name__) @@ -16,60 +15,17 @@ def _set_trace(trace_id: str) -> None: logger.info(f"[_set_trace] Set correlation ID: {trace_id}") -def _extract_parent_context(task_instance) -> otel_context.Context: - """Extract OTel parent context from Celery headers if available.""" - headers = 
getattr(task_instance.request, "headers", None) or {} - carrier: dict[str, str] = {} - - if isinstance(headers, dict): - for key, value in headers.items(): - if isinstance(value, str): - carrier[str(key)] = value - - nested = headers.get("otel", {}) - if isinstance(nested, dict): - for key, value in nested.items(): - if isinstance(value, str): - carrier[str(key)] = value - - return extract(carrier) - - -def _run_with_otel_parent(task_instance, fn): - """Attach extracted parent context and execute function. - - When Celery auto-instrumentation is active, there is already a current - `run/...` span. Re-attaching extracted parent context here would make - service spans become siblings of `run/...` instead of children. - - We only attach extracted context as a fallback when no active span exists. - """ - current_ctx = trace.get_current_span().get_span_context() - if current_ctx and current_ctx.is_valid: - return fn() - - parent_ctx = _extract_parent_context(task_instance) - token = otel_context.attach(parent_ctx) - try: - return fn() - finally: - otel_context.detach(token) - - @celery_app.task(bind=True, queue="high_priority", priority=9) def run_llm_job(self, project_id: int, job_id: str, trace_id: str, **kwargs): from app.services.llm.jobs import execute_job _set_trace(trace_id) - return _run_with_otel_parent( - self, - lambda: execute_job( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, - ), + return execute_job( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, ) @@ -78,15 +34,12 @@ def run_llm_chain_job(self, project_id: int, job_id: str, trace_id: str, **kwarg from app.services.llm.jobs import execute_chain_job _set_trace(trace_id) - return _run_with_otel_parent( - self, - lambda: execute_chain_job( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, - ), + return execute_chain_job( 
@celery_app.task(bind=True, queue="low_priority", priority=1)
@gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_create_collection_job")
def run_create_collection_job(
    self, project_id: int, job_id: str, trace_id: str, **kwargs
):
    # Phase 1 of collection creation: setup + enqueue of the first batch.
    # Imported lazily inside the task body (consistent with the other tasks
    # in this module).
    from app.services.collections.create_collection import execute_setup_job

    _set_trace(trace_id)
    return execute_setup_job(
        project_id=project_id,
        job_id=job_id,
        task_id=current_task.request.id,
        task_instance=self,
        **kwargs,
    )


@celery_app.task(bind=True, queue="low_priority", priority=1)
@gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_collection_batch_job")
def run_collection_batch_job(
    self, project_id: int, job_id: str, trace_id: str, **kwargs
):
    # Phase 2 of collection creation: processes one batch; the service layer
    # enqueues the next batch task until all batches are done.
    # NOTE(review): gevent_timeout presumably raises gevent.Timeout on expiry
    # (handled inside execute_batch_job) — confirm the worker runs a gevent pool.
    from app.services.collections.create_collection import execute_batch_job

    _set_trace(trace_id)
    return execute_batch_job(
        project_id=project_id,
        job_id=job_id,
        task_id=current_task.request.id,
        task_instance=self,
        **kwargs,
    )
project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, - ), + return execute_batch_submission( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, ) @@ -228,13 +178,10 @@ def run_tts_result_processing( ) _set_trace(trace_id) - return _run_with_otel_parent( - self, - lambda: execute_tts_result_processing( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, - ), + return execute_tts_result_processing( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, ) diff --git a/backend/app/celery/utils.py b/backend/app/celery/utils.py index 475a4772a..9ffe47113 100644 --- a/backend/app/celery/utils.py +++ b/backend/app/celery/utils.py @@ -3,34 +3,25 @@ Business logic modules can use these functions without knowing Celery internals. """ import logging +import functools from typing import Any, Dict from celery.result import AsyncResult -from opentelemetry.propagate import inject +from gevent import Timeout from app.celery.celery_app import celery_app logger = logging.getLogger(__name__) -def _enqueue_with_trace_context(task, **kwargs) -> str: - """Publish Celery task with explicit trace context headers.""" - otel_headers: dict[str, str] = {} - inject(otel_headers) - celery_headers = dict(otel_headers) - celery_headers["otel"] = otel_headers - async_result = task.apply_async(kwargs=kwargs, headers=celery_headers) - return async_result.id - - def start_llm_job(project_id: int, job_id: str, trace_id: str = "N/A", **kwargs) -> str: from app.celery.tasks.job_execution import run_llm_job - task_id = _enqueue_with_trace_context( - run_llm_job, project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs + task = run_llm_job.delay( + project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs ) - logger.info(f"[start_llm_job] Started job {job_id} 
def _start_job(
    task, log_prefix: str, project_id: int, job_id: str, trace_id: str, **kwargs
) -> str:
    """Enqueue a Celery task with the standard job arguments and log the dispatch.

    Shared implementation for all ``start_*_job`` wrappers below; it removes the
    tenfold duplication of the enqueue / log / return pattern.

    Args:
        task: Celery task object to dispatch via ``.delay``.
        log_prefix: Exact message prefix, e.g. ``"[start_llm_chain_job] Started job"``.
        project_id: Project the job belongs to.
        job_id: Application-level job identifier.
        trace_id: Trace id propagated to the worker for log correlation.
        **kwargs: Extra keyword arguments forwarded verbatim to the task.

    Returns:
        The Celery task id assigned to the enqueued job.
    """
    result = task.delay(
        project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs
    )
    logger.info(f"{log_prefix} {job_id} with Celery task {result.id}")
    return result.id


def start_llm_chain_job(
    project_id: int, job_id: str, trace_id: str = "N/A", **kwargs
) -> str:
    """Start an LLM chain job in the background; returns the Celery task id."""
    # Deferred import, as in the original code (presumably avoids an import
    # cycle between this module and the task definitions — confirm).
    from app.celery.tasks.job_execution import run_llm_chain_job

    return _start_job(
        run_llm_chain_job,
        "[start_llm_chain_job] Started job",
        project_id,
        job_id,
        trace_id,
        **kwargs,
    )


def start_response_job(
    project_id: int, job_id: str, trace_id: str = "N/A", **kwargs
) -> str:
    """Start a response-generation job; returns the Celery task id."""
    from app.celery.tasks.job_execution import run_response_job

    return _start_job(
        run_response_job,
        "[start_response_job] Started job",
        project_id,
        job_id,
        trace_id,
        **kwargs,
    )


def start_doctransform_job(
    project_id: int, job_id: str, trace_id: str = "N/A", **kwargs
) -> str:
    """Start a document-transformation job; returns the Celery task id."""
    from app.celery.tasks.job_execution import run_doctransform_job

    return _start_job(
        run_doctransform_job,
        "[start_doctransform_job] Started job",
        project_id,
        job_id,
        trace_id,
        **kwargs,
    )


def start_create_collection_job(
    project_id: int, job_id: str, trace_id: str = "N/A", **kwargs
) -> str:
    """Start a collection-creation job; returns the Celery task id."""
    from app.celery.tasks.job_execution import run_create_collection_job

    return _start_job(
        run_create_collection_job,
        "[start_create_collection_job] Started job",
        project_id,
        job_id,
        trace_id,
        **kwargs,
    )


def start_collection_batch_job(
    project_id: int, job_id: str, trace_id: str = "N/A", **kwargs
) -> str:
    """Start a collection batch job; returns the Celery task id."""
    from app.celery.tasks.job_execution import run_collection_batch_job

    return _start_job(
        run_collection_batch_job,
        "[start_collection_batch_job] Started batch job",
        project_id,
        job_id,
        trace_id,
        **kwargs,
    )


def start_delete_collection_job(
    project_id: int, job_id: str, trace_id: str = "N/A", **kwargs
) -> str:
    """Start a collection-deletion job; returns the Celery task id."""
    from app.celery.tasks.job_execution import run_delete_collection_job

    return _start_job(
        run_delete_collection_job,
        "[start_delete_collection_job] Started job",
        project_id,
        job_id,
        trace_id,
        **kwargs,
    )


def start_stt_batch_submission(
    project_id: int, job_id: str, trace_id: str = "N/A", **kwargs
) -> str:
    """Start a speech-to-text batch submission job; returns the Celery task id."""
    from app.celery.tasks.job_execution import run_stt_batch_submission

    return _start_job(
        run_stt_batch_submission,
        "[start_stt_batch_submission] Started job",
        project_id,
        job_id,
        trace_id,
        **kwargs,
    )


def start_stt_metric_computation(
    project_id: int, job_id: str, trace_id: str = "N/A", **kwargs
) -> str:
    """Start an STT metric-computation job; returns the Celery task id."""
    from app.celery.tasks.job_execution import run_stt_metric_computation

    return _start_job(
        run_stt_metric_computation,
        "[start_stt_metric_computation] Started job",
        project_id,
        job_id,
        trace_id,
        **kwargs,
    )


def start_tts_batch_submission(
    project_id: int, job_id: str, trace_id: str = "N/A", **kwargs
) -> str:
    """Start a text-to-speech batch submission job; returns the Celery task id."""
    from app.celery.tasks.job_execution import run_tts_batch_submission

    return _start_job(
        run_tts_batch_submission,
        "[start_tts_batch_submission] Started job",
        project_id,
        job_id,
        trace_id,
        **kwargs,
    )


def start_tts_result_processing(
    project_id: int, job_id: str, trace_id: str = "N/A", **kwargs
) -> str:
    """Start a TTS result-processing job; returns the Celery task id."""
    from app.celery.tasks.job_execution import run_tts_result_processing

    return _start_job(
        run_tts_result_processing,
        "[start_tts_result_processing] Started job",
        project_id,
        job_id,
        trace_id,
        **kwargs,
    )


def gevent_timeout(seconds, task_name=None):
    """Decorator that enforces a soft time limit on *func* via ``gevent.Timeout``.

    If the wrapped call exceeds *seconds*, the failure is logged and a
    ``TimeoutError`` is raised; otherwise the wrapped function's return value
    is passed through unchanged.

    BUG FIX vs. the original: the original placed an unconditional
    ``raise TimeoutError(...)`` inside ``finally:``, so *every* call — even a
    successful one — raised, the return value was discarded, and the
    ``timeout.cancel()`` after the raise was unreachable, leaving the gevent
    timer armed. The raise now lives in the ``except Timeout`` branch and the
    ``finally`` block only cancels the timer.

    NOTE(review): this only interrupts cooperative (gevent-patched) blocking
    calls — confirm the workers run with a gevent pool.

    Args:
        seconds: Soft time limit in seconds.
        task_name: Optional label used in log/error messages; defaults to the
            wrapped function's ``__name__``.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            name = task_name or func.__name__
            timeout = Timeout(seconds)
            timeout.start()
            try:
                return func(*args, **kwargs)
            except Timeout:
                logger.error(
                    f"[{name}] Timed out after {seconds}s — args={args}, kwargs={kwargs}"
                )
                raise TimeoutError(
                    f"[{name}] Task exceeded soft time limit of {seconds}s"
                )
            finally:
                # Always disarm the timer so it cannot fire after the call
                # completes (the original never reached this line).
                timeout.cancel()

        return wrapper

    return decorator