@@ -0,0 +1,62 @@
"""add batch tracking to collection_jobs

Revision ID: 055
Revises: 054
Create Date: 2026-04-13

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "055"
down_revision = "054"
branch_labels = None
depends_on = None


def upgrade():
op.add_column(
"collection_jobs",
sa.Column(
"total_batches",
sa.Integer(),
nullable=True,
comment="Total number of batches the documents are split into",
),
)
op.add_column(
"collection_jobs",
sa.Column(
"current_batch_number",
sa.Integer(),
nullable=True,
comment="Which batch is currently being processed (1-indexed)",
),
)
op.add_column(
"collection_jobs",
sa.Column(
"documents_uploaded",
sa.JSON(),
nullable=True,
comment="List of document IDs successfully uploaded so far",
),
)
op.add_column(
"document",
sa.Column(
"openai_file_id",
sa.String(),
nullable=True,
comment="File ID assigned by the LLM provider (e.g. OpenAI file ID) to avoid re-uploading",
),
)


def downgrade():
op.drop_column("collection_jobs", "total_batches")
op.drop_column("collection_jobs", "current_batch_number")
op.drop_column("collection_jobs", "documents_uploaded")
op.drop_column("document", "openai_file_id")
2 changes: 1 addition & 1 deletion backend/app/api/docs/documents/upload.md
@@ -1,6 +1,6 @@
Upload a document to Kaapi.

- If only a file is provided, the document will be uploaded and stored, and its ID will be returned.
- If only a file is provided, the document will be uploaded and stored, and its ID will be returned. The maximum file size allowed for upload is 25 MB.
- If a target format is specified, a transformation job will also be created to transform the document into the target format in the background. The response will include both the uploaded document details and information about the transformation job.
- If a callback URL is provided, you will receive a notification at that URL once the document transformation job is completed.
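For example, a minimal client-side sketch of an upload request (the endpoint path, host, and field names here are assumptions for illustration, not confirmed by this diff):

```python
import requests

# Hypothetical endpoint and field names -- adjust to the actual Kaapi API.
with open("report.pdf", "rb") as fh:
    response = requests.post(
        "https://kaapi.example.com/api/v1/documents/upload",
        files={"file": fh},  # uploads must be 25 MB or smaller
        data={
            "target_format": "markdown",  # optional: also queues a transformation job
            "callback_url": "https://example.com/hooks/doc-done",  # optional
        },
    )
response.raise_for_status()
print(response.json())  # document ID, plus job details if a target format was given
```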

212 changes: 78 additions & 134 deletions backend/app/celery/tasks/job_execution.py
@@ -4,9 +4,6 @@
import celery
from asgi_correlation_id import correlation_id
from celery import current_task
from opentelemetry import context as otel_context
from opentelemetry import trace
from opentelemetry.propagate import extract

from app.celery.celery_app import celery_app
from app.celery.utils import gevent_timeout
@@ -20,61 +17,18 @@ def _set_trace(trace_id: str) -> None:
logger.info(f"[_set_trace] Set correlation ID: {trace_id}")


def _extract_parent_context(task_instance) -> otel_context.Context:
"""Extract OTel parent context from Celery headers if available."""
headers = getattr(task_instance.request, "headers", None) or {}
carrier: dict[str, str] = {}

if isinstance(headers, dict):
for key, value in headers.items():
if isinstance(value, str):
carrier[str(key)] = value

nested = headers.get("otel", {})
if isinstance(nested, dict):
for key, value in nested.items():
if isinstance(value, str):
carrier[str(key)] = value

return extract(carrier)


def _run_with_otel_parent(task_instance, fn):
"""Attach extracted parent context and execute function.

When Celery auto-instrumentation is active, there is already a current
`run/...` span. Re-attaching the extracted parent context here would make
service spans siblings of `run/...` instead of children.

We only attach extracted context as a fallback when no active span exists.
"""
current_ctx = trace.get_current_span().get_span_context()
if current_ctx and current_ctx.is_valid:
return fn()

parent_ctx = _extract_parent_context(task_instance)
token = otel_context.attach(parent_ctx)
try:
return fn()
finally:
otel_context.detach(token)


@celery_app.task(bind=True, queue="high_priority", priority=9)
@gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_llm_job")
def run_llm_job(self, project_id: int, job_id: str, trace_id: str, **kwargs):
from app.services.llm.jobs import execute_job

_set_trace(trace_id)
return _run_with_otel_parent(
self,
lambda: execute_job(
project_id=project_id,
job_id=job_id,
task_id=current_task.request.id,
task_instance=self,
**kwargs,
),
return execute_job(
project_id=project_id,
job_id=job_id,
task_id=current_task.request.id,
task_instance=self,
**kwargs,
)


@@ -84,15 +38,12 @@ def run_llm_chain_job(self, project_id: int, job_id: str, trace_id: str, **kwargs):
from app.services.llm.jobs import execute_chain_job

_set_trace(trace_id)
return _run_with_otel_parent(
self,
lambda: execute_chain_job(
project_id=project_id,
job_id=job_id,
task_id=current_task.request.id,
task_instance=self,
**kwargs,
),
return execute_chain_job(
project_id=project_id,
job_id=job_id,
task_id=current_task.request.id,
task_instance=self,
**kwargs,
)


@@ -102,15 +53,12 @@ def run_response_job(self, project_id: int, job_id: str, trace_id: str, **kwargs):
from app.services.response.jobs import execute_job

_set_trace(trace_id)
return _run_with_otel_parent(
self,
lambda: execute_job(
project_id=project_id,
job_id=job_id,
task_id=current_task.request.id,
task_instance=self,
**kwargs,
),
return execute_job(
project_id=project_id,
job_id=job_id,
task_id=current_task.request.id,
task_instance=self,
**kwargs,
)


@@ -120,15 +68,12 @@ def run_doctransform_job(self, project_id: int, job_id: str, trace_id: str, **kwargs):
from app.services.doctransform.job import execute_job

_set_trace(trace_id)
return _run_with_otel_parent(
self,
lambda: execute_job(
project_id=project_id,
job_id=job_id,
task_id=current_task.request.id,
task_instance=self,
**kwargs,
),
return execute_job(
project_id=project_id,
job_id=job_id,
task_id=current_task.request.id,
task_instance=self,
**kwargs,
)


@@ -137,18 +82,32 @@ def run_doctransform_job(self, project_id: int, job_id: str, trace_id: str, **kwargs):
def run_create_collection_job(
self, project_id: int, job_id: str, trace_id: str, **kwargs
):
from app.services.collections.create_collection import execute_job
from app.services.collections.create_collection import execute_setup_job

_set_trace(trace_id)
return execute_setup_job(
project_id=project_id,
job_id=job_id,
task_id=current_task.request.id,
task_instance=self,
**kwargs,
)


@celery_app.task(bind=True, queue="low_priority", priority=1)
@gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_collection_batch_job")
def run_collection_batch_job(
self, project_id: int, job_id: str, trace_id: str, **kwargs
):
from app.services.collections.create_collection import execute_batch_job

_set_trace(trace_id)
return _run_with_otel_parent(
self,
lambda: execute_job(
project_id=project_id,
job_id=job_id,
task_id=current_task.request.id,
task_instance=self,
**kwargs,
),
return execute_batch_job(
project_id=project_id,
job_id=job_id,
task_id=current_task.request.id,
task_instance=self,
**kwargs,
)


@@ -160,15 +119,12 @@ def run_delete_collection_job(
from app.services.collections.delete_collection import execute_job

_set_trace(trace_id)
return _run_with_otel_parent(
self,
lambda: execute_job(
project_id=project_id,
job_id=job_id,
task_id=current_task.request.id,
task_instance=self,
**kwargs,
),
return execute_job(
project_id=project_id,
job_id=job_id,
task_id=current_task.request.id,
task_instance=self,
**kwargs,
)


@@ -180,15 +136,12 @@ def run_stt_batch_submission(
from app.services.stt_evaluations.batch_job import execute_batch_submission

_set_trace(trace_id)
return _run_with_otel_parent(
self,
lambda: execute_batch_submission(
project_id=project_id,
job_id=job_id,
task_id=current_task.request.id,
task_instance=self,
**kwargs,
),
return execute_batch_submission(
project_id=project_id,
job_id=job_id,
task_id=current_task.request.id,
task_instance=self,
**kwargs,
)


@@ -200,15 +153,12 @@ def run_stt_metric_computation(
from app.services.stt_evaluations.metric_job import execute_metric_computation

_set_trace(trace_id)
return _run_with_otel_parent(
self,
lambda: execute_metric_computation(
project_id=project_id,
job_id=job_id,
task_id=current_task.request.id,
task_instance=self,
**kwargs,
),
return execute_metric_computation(
project_id=project_id,
job_id=job_id,
task_id=current_task.request.id,
task_instance=self,
**kwargs,
)


@@ -220,15 +170,12 @@ def run_tts_batch_submission(
from app.services.tts_evaluations.batch_job import execute_batch_submission

_set_trace(trace_id)
return _run_with_otel_parent(
self,
lambda: execute_batch_submission(
project_id=project_id,
job_id=job_id,
task_id=current_task.request.id,
task_instance=self,
**kwargs,
),
return execute_batch_submission(
project_id=project_id,
job_id=job_id,
task_id=current_task.request.id,
task_instance=self,
**kwargs,
)


@@ -242,13 +189,10 @@ def run_tts_result_processing(
)

_set_trace(trace_id)
return _run_with_otel_parent(
self,
lambda: execute_tts_result_processing(
project_id=project_id,
job_id=job_id,
task_id=current_task.request.id,
task_instance=self,
**kwargs,
),
return execute_tts_result_processing(
project_id=project_id,
job_id=job_id,
task_id=current_task.request.id,
task_instance=self,
**kwargs,
)
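The split of the old create-collection task into `execute_setup_job` and a new `run_collection_batch_job` on the `low_priority` queue suggests a fan-out pattern: setup partitions the collection's documents and enqueues one batch task per slice. A sketch of such a dispatcher under that assumption (`dispatch_batches`, `batch_size`, and the extra kwargs are hypothetical; the real `execute_setup_job` is not shown in this diff):

```python
from app.celery.tasks.job_execution import run_collection_batch_job


def dispatch_batches(
    project_id: int,
    job_id: str,
    trace_id: str,
    document_ids: list[str],
    batch_size: int = 20,  # assumed tuning knob, not taken from this PR
) -> int:
    """Partition documents and enqueue one collection batch task per slice."""
    batches = [
        document_ids[i : i + batch_size]
        for i in range(0, len(document_ids), batch_size)
    ]
    for number, batch in enumerate(batches, start=1):  # 1-indexed, matching current_batch_number
        run_collection_batch_job.apply_async(
            kwargs={
                "project_id": project_id,
                "job_id": job_id,
                "trace_id": trace_id,
                "batch_number": number,  # assumed extra kwarg, absorbed by **kwargs
                "document_ids": batch,   # assumed extra kwarg
            },
        )
    return len(batches)
```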