diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index a2167eb..c3bef3a 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -42,8 +42,8 @@ env: CBCI_SUPPORTED_ARM64_PLATFORMS: "linux macos" CBCI_DEFAULT_LINUX_X86_64_PLATFORM: "ubuntu-22.04" CBCI_DEFAULT_LINUX_ARM64_PLATFORM: "ubuntu-22.04-arm" - CBCI_DEFAULT_MACOS_X86_64_PLATFORM: "macos-13" - CBCI_DEFAULT_MACOS_ARM64_PLATFORM: "macos-14" + CBCI_DEFAULT_MACOS_X86_64_PLATFORM: "macos-15-intel" + CBCI_DEFAULT_MACOS_ARM64_PLATFORM: "macos-15" CBCI_DEFAULT_WINDOWS_PLATFORM: "windows-2022" CBCI_DEFAULT_LINUX_CONTAINER: "slim-bookworm" CBCI_DEFAULT_ALPINE_CONTAINER: "alpine" diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 6366d78..b89ef77 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -64,8 +64,8 @@ env: CBCI_SUPPORTED_ARM64_PLATFORMS: "linux macos" CBCI_DEFAULT_LINUX_X86_64_PLATFORM: "ubuntu-22.04" CBCI_DEFAULT_LINUX_ARM64_PLATFORM: "ubuntu-22.04-arm" - CBCI_DEFAULT_MACOS_X86_64_PLATFORM: "macos-13" - CBCI_DEFAULT_MACOS_ARM64_PLATFORM: "macos-14" + CBCI_DEFAULT_MACOS_X86_64_PLATFORM: "macos-15-intel" + CBCI_DEFAULT_MACOS_ARM64_PLATFORM: "macos-15" CBCI_DEFAULT_WINDOWS_PLATFORM: "windows-2022" CBCI_DEFAULT_LINUX_CONTAINER: "slim-bookworm" CBCI_DEFAULT_ALPINE_CONTAINER: "alpine" diff --git a/.github/workflows/verify_release.yml b/.github/workflows/verify_release.yml index 81d7c65..7c8ceb6 100644 --- a/.github/workflows/verify_release.yml +++ b/.github/workflows/verify_release.yml @@ -57,8 +57,8 @@ env: CBCI_SUPPORTED_ARM64_PLATFORMS: "linux macos" CBCI_DEFAULT_LINUX_X86_64_PLATFORM: "ubuntu-22.04" CBCI_DEFAULT_LINUX_ARM64_PLATFORM: "ubuntu-22.04-arm" - CBCI_DEFAULT_MACOS_X86_64_PLATFORM: "macos-13" - CBCI_DEFAULT_MACOS_ARM64_PLATFORM: "macos-14" + CBCI_DEFAULT_MACOS_X86_64_PLATFORM: "macos-15-intel" + CBCI_DEFAULT_MACOS_ARM64_PLATFORM: "macos-15" CBCI_DEFAULT_WINDOWS_PLATFORM: "windows-2022" 
CBCI_DEFAULT_LINUX_CONTAINER: "slim-bookworm" CBCI_DEFAULT_ALPINE_CONTAINER: "alpine" diff --git a/.gitignore b/.gitignore index 2511da8..164c3c3 100644 --- a/.gitignore +++ b/.gitignore @@ -176,5 +176,8 @@ gocaves* .pytest_cache/ test_scripts/ -# rff +# ruff .ruff_cache/ + +# other +.DS_Store diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6edfbdc..b9feb9f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -45,6 +45,8 @@ repos: - pytest~=8.3.5 - httpx~=0.28.1 - aiohttp~=3.11.10 + - sniffio~=1.3.1 + - anyio~=4.9.0 types: - python require_serial: true diff --git a/acouchbase_analytics/cluster.py b/acouchbase_analytics/cluster.py index 2cc0caf..bf07129 100644 --- a/acouchbase_analytics/cluster.py +++ b/acouchbase_analytics/cluster.py @@ -25,7 +25,8 @@ from typing import TypeAlias from acouchbase_analytics.database import AsyncDatabase -from couchbase_analytics.result import AsyncQueryResult +from acouchbase_analytics.query_handle import AsyncQueryHandle +from acouchbase_analytics.result import AsyncQueryResult if TYPE_CHECKING: from couchbase_analytics.credential import Credential @@ -92,9 +93,7 @@ def execute_query(self, statement: str, *args: object, **kwargs: object) -> Awai **kwargs (Dict[str, Any]): keyword arguments that can be used in place or to override provided :class:`~couchbase_analytics.options.QueryOptions` Returns: - Future[:class:`~couchbase_analytics.result.AsyncQueryResult`]: A :class:`~asyncio.Future` is returned. - Once the :class:`~asyncio.Future` completes, an instance of a :class:`~acouchbase_analytics.result.AsyncQueryResult` - is available to provide access to iterate over the query results and access metadata and metrics about the query. + :class:`~couchbase_analytics.result.AsyncQueryResult`: An instance of a :class:`~acouchbase_analytics.result.AsyncQueryResult`. 
Examples: Simple query:: @@ -143,6 +142,22 @@ def execute_query(self, statement: str, *args: object, **kwargs: object) -> Awai """ # noqa: E501 return self._impl.execute_query(statement, *args, **kwargs) + def start_query(self, statement: str, *args: object, **kwargs: object) -> Awaitable[AsyncQueryHandle]: + """Executes a query against an Analytics cluster in async mode. + + .. seealso:: + :meth:`acouchbase_analytics.Scope.start_query`: For how to execute scope-level queries. + + Args: + statement: The SQL++ statement to execute. + options (:class:`~acouchbase_analytics.options.StartQueryOptions`): Optional parameters for the query operation. + **kwargs (Dict[str, Any]): keyword arguments that can be used in place or to override provided :class:`~acouchbase_analytics.options.StartQueryOptions` + + Returns: + :class:`~acouchbase_analytics.query_handle.AsyncQueryHandle`: An instance of a :class:`~acouchbase_analytics.query_handle.AsyncQueryHandle` + """ # noqa: E501 + return self._impl.start_query(statement, *args, **kwargs) + async def shutdown(self) -> None: """Shuts down this cluster instance. Cleaning up all resources associated with it. 
diff --git a/acouchbase_analytics/cluster.pyi b/acouchbase_analytics/cluster.pyi index bea6643..1745c5f 100644 --- a/acouchbase_analytics/cluster.pyi +++ b/acouchbase_analytics/cluster.pyi @@ -21,10 +21,19 @@ if sys.version_info < (3, 11): else: from typing import Unpack +from acouchbase_analytics import JSONType +from acouchbase_analytics.credential import Credential from acouchbase_analytics.database import AsyncDatabase -from couchbase_analytics.credential import Credential -from couchbase_analytics.options import ClusterOptions, ClusterOptionsKwargs, QueryOptions, QueryOptionsKwargs -from couchbase_analytics.result import AsyncQueryResult +from acouchbase_analytics.options import ( + ClusterOptions, + ClusterOptionsKwargs, + QueryOptions, + QueryOptionsKwargs, + StartQueryOptions, + StartQueryOptionsKwargs, +) +from acouchbase_analytics.query_handle import AsyncQueryHandle +from acouchbase_analytics.result import AsyncQueryResult class AsyncCluster: @overload @@ -54,14 +63,34 @@ class AsyncCluster: ) -> Awaitable[AsyncQueryResult]: ... @overload def execute_query( - self, statement: str, options: QueryOptions, *args: str, **kwargs: Unpack[QueryOptionsKwargs] + self, statement: str, options: QueryOptions, *args: JSONType, **kwargs: Unpack[QueryOptionsKwargs] ) -> Awaitable[AsyncQueryResult]: ... @overload def execute_query( - self, statement: str, options: QueryOptions, *args: str, **kwargs: str + self, statement: str, options: QueryOptions, *args: JSONType, **kwargs: str ) -> Awaitable[AsyncQueryResult]: ... @overload - def execute_query(self, statement: str, *args: str, **kwargs: str) -> Awaitable[AsyncQueryResult]: ... + def execute_query(self, statement: str, *args: JSONType, **kwargs: str) -> Awaitable[AsyncQueryResult]: ... + @overload + def start_query(self, statement: str) -> Awaitable[AsyncQueryHandle]: ... + @overload + def start_query(self, statement: str, options: StartQueryOptions) -> Awaitable[AsyncQueryHandle]: ... 
+ @overload + def start_query(self, statement: str, **kwargs: Unpack[StartQueryOptionsKwargs]) -> Awaitable[AsyncQueryHandle]: ... + @overload + def start_query( + self, statement: str, options: StartQueryOptions, **kwargs: Unpack[StartQueryOptionsKwargs] + ) -> Awaitable[AsyncQueryHandle]: ... + @overload + def start_query( + self, statement: str, options: StartQueryOptions, *args: JSONType, **kwargs: Unpack[StartQueryOptionsKwargs] + ) -> Awaitable[AsyncQueryHandle]: ... + @overload + def start_query( + self, statement: str, options: StartQueryOptions, *args: JSONType, **kwargs: str + ) -> Awaitable[AsyncQueryHandle]: ... + @overload + def start_query(self, statement: str, *args: JSONType, **kwargs: str) -> Awaitable[AsyncQueryHandle]: ... def shutdown(self) -> Awaitable[None]: ... @overload @classmethod diff --git a/acouchbase_analytics/errors.py b/acouchbase_analytics/errors.py index 03a6439..5d40e7a 100644 --- a/acouchbase_analytics/errors.py +++ b/acouchbase_analytics/errors.py @@ -18,4 +18,5 @@ from couchbase_analytics.common.errors import InternalSDKError as InternalSDKError # noqa: F401 from couchbase_analytics.common.errors import InvalidCredentialError as InvalidCredentialError # noqa: F401 from couchbase_analytics.common.errors import QueryError as QueryError # noqa: F401 +from couchbase_analytics.common.errors import QueryNotFoundError as QueryNotFoundError # noqa: F401 from couchbase_analytics.common.errors import TimeoutError as TimeoutError # noqa: F401 diff --git a/acouchbase_analytics/options.py b/acouchbase_analytics/options.py index bc8f846..47c432b 100644 --- a/acouchbase_analytics/options.py +++ b/acouchbase_analytics/options.py @@ -16,9 +16,13 @@ from couchbase_analytics.common.options import ClusterOptions as ClusterOptions # noqa: F401 from couchbase_analytics.common.options import ClusterOptionsKwargs as ClusterOptionsKwargs # noqa: F401 +from couchbase_analytics.common.options import FetchResultsOptions as FetchResultsOptions # noqa: F401 
+from couchbase_analytics.common.options import FetchResultsOptionsKwargs as FetchResultsOptionsKwargs # noqa: F401 from couchbase_analytics.common.options import QueryOptions as QueryOptions # noqa: F401 from couchbase_analytics.common.options import QueryOptionsKwargs as QueryOptionsKwargs # noqa: F401 from couchbase_analytics.common.options import SecurityOptions as SecurityOptions # noqa: F401 from couchbase_analytics.common.options import SecurityOptionsKwargs as SecurityOptionsKwargs # noqa: F401 +from couchbase_analytics.common.options import StartQueryOptions as StartQueryOptions # noqa: F401 +from couchbase_analytics.common.options import StartQueryOptionsKwargs as StartQueryOptionsKwargs # noqa: F401 from couchbase_analytics.common.options import TimeoutOptions as TimeoutOptions # noqa: F401 from couchbase_analytics.common.options import TimeoutOptionsKwargs as TimeoutOptionsKwargs # noqa: F401 diff --git a/acouchbase_analytics/protocol/_core/anyio_utils.py b/acouchbase_analytics/protocol/_core/anyio_utils.py index ce7a751..5a2e211 100644 --- a/acouchbase_analytics/protocol/_core/anyio_utils.py +++ b/acouchbase_analytics/protocol/_core/anyio_utils.py @@ -66,7 +66,7 @@ def current_async_library() -> Optional[AsyncBackend]: try: import sniffio except ImportError: - async_lib = 'asyncio' + return AsyncBackend('asyncio') try: async_lib = sniffio.current_async_library() diff --git a/acouchbase_analytics/protocol/_core/client_adapter.py b/acouchbase_analytics/protocol/_core/client_adapter.py index 6a6ac60..c6d00c5 100644 --- a/acouchbase_analytics/protocol/_core/client_adapter.py +++ b/acouchbase_analytics/protocol/_core/client_adapter.py @@ -17,7 +17,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Optional, cast +from typing import Optional, cast from uuid import uuid4 from httpx import URL, AsyncClient, BasicAuth, Response @@ -25,12 +25,10 @@ from couchbase_analytics.common.credential import Credential from 
couchbase_analytics.common.deserializer import Deserializer from couchbase_analytics.common.logging import LogLevel, log_message +from couchbase_analytics.protocol._core.request import CancelRequest, HttpRequest, QueryRequest, StartQueryRequest from couchbase_analytics.protocol.connection import _ConnectionDetails from couchbase_analytics.protocol.options import OptionsBuilder -if TYPE_CHECKING: - from couchbase_analytics.protocol._core.request import QueryRequest - class _AsyncClientAdapter: """ @@ -164,7 +162,7 @@ async def create_client(self) -> None: def log_message(self, message: str, log_level: LogLevel) -> None: log_message(logger, f'{self.log_prefix} {message}', log_level) - async def send_request(self, request: QueryRequest) -> Response: + async def send_request(self, request: HttpRequest, stream: Optional[bool] = True) -> Response: """ **INTERNAL** """ @@ -177,8 +175,18 @@ async def send_request(self, request: QueryRequest) -> Response: port=request.url.port, path=request.url.path, ) - req = self._client.build_request(request.method, url, json=request.body, extensions=request.extensions) - return await self._client.send(req, stream=True) + if isinstance(request, (QueryRequest, StartQueryRequest)): + req = self._client.build_request(request.method, url, json=request.body, extensions=request.extensions) + else: + data = request.data if isinstance(request, CancelRequest) else None + req = self._client.build_request( + request.method, url, data=data, headers=request.headers, extensions=request.extensions + ) + + if stream is None: + stream = True + + return await self._client.send(req, stream=stream) def reset_client(self) -> None: """ diff --git a/acouchbase_analytics/protocol/_core/request_context.py b/acouchbase_analytics/protocol/_core/request_context.py index f01eec1..a395251 100644 --- a/acouchbase_analytics/protocol/_core/request_context.py +++ b/acouchbase_analytics/protocol/_core/request_context.py @@ -36,20 +36,20 @@ from 
couchbase_analytics.common.errors import AnalyticsError from couchbase_analytics.common.logging import LogLevel from couchbase_analytics.common.request import RequestState +from couchbase_analytics.protocol._core.request import FetchResultsRequest, HttpRequest, QueryRequest, StartQueryRequest from couchbase_analytics.protocol.connection import DEFAULT_TIMEOUTS from couchbase_analytics.protocol.errors import ErrorMapper if TYPE_CHECKING: from acouchbase_analytics.protocol._core.client_adapter import _AsyncClientAdapter - from couchbase_analytics.protocol._core.request import QueryRequest class AsyncRequestContext: def __init__( self, client_adapter: _AsyncClientAdapter, - request: QueryRequest, - stream_config: Optional[JsonStreamConfig] = None, + request: HttpRequest, + supports_cancellation: Optional[bool] = None, backend: Optional[AsyncBackend] = None, ) -> None: self._id = str(uuid4()) @@ -57,11 +57,11 @@ def __init__( self._request = request self._backend = backend or current_async_library() self._backoff_calc = DefaultBackoffCalculator() - self._error_ctx = ErrorContext(num_attempts=0, method=request.method, statement=request.get_request_statement()) + self._error_context = ErrorContext(num_attempts=0, method=request.method) + if isinstance(request, (QueryRequest, StartQueryRequest)): + self._error_context.set_statement(request.get_request_statement()) self._request_state = RequestState.NotStarted - self._stream_config = stream_config or JsonStreamConfig() - self._json_stream: AsyncJsonStream - self._stage_completed: Optional[anyio.Event] = None + self._supports_cancellation = False if supports_cancellation is None else supports_cancellation self._request_error: Optional[Union[BaseException, Exception]] = None connect_timeout = self._client_adapter.connection_details.get_connect_timeout() self._connect_deadline = get_time() + connect_timeout @@ -71,31 +71,19 @@ def __init__( @property def cancelled(self) -> bool: + if not self._supports_cancellation: + return 
False self._check_timed_out() return self._request_state in [RequestState.Cancelled, RequestState.AsyncCancelledPriorToTimeout] @property def error_context(self) -> ErrorContext: - return self._error_ctx - - @property - def has_stage_completed(self) -> bool: - return self._stage_completed is not None and self._stage_completed.is_set() + return self._error_context @property def is_shutdown(self) -> bool: return self._shutdown - @property - def okay_to_iterate(self) -> bool: - self._check_timed_out() - return RequestState.okay_to_iterate(self._request_state) - - @property - def okay_to_stream(self) -> bool: - self._check_timed_out() - return RequestState.okay_to_stream(self._request_state) - @property def request_error(self) -> Optional[Union[BaseException, Exception]]: return self._request_error @@ -106,139 +94,15 @@ def request_state(self) -> RequestState: @property def retry_limit_exceeded(self) -> bool: - return self.error_context.num_attempts > self._request.max_retries - - @property - def results_or_errors_type(self) -> ParsedResultType: - return self._json_stream.results_or_errors_type + return self._error_context.num_attempts > self._request.max_retries @property def timed_out(self) -> bool: self._check_timed_out() return self._request_state == RequestState.Timeout - def _check_timed_out(self) -> None: - if self._request_state in [RequestState.Timeout, RequestState.Cancelled, RequestState.Error]: - return - - if hasattr(self, '_request_deadline') is False: - return - - current_time = get_time() - timed_out = current_time >= self._request_deadline - if timed_out: - message_data = {'current_time': f'{current_time}', 'request_deadline': f'{self._request_deadline}'} - self.log_message('Request has timed out', LogLevel.DEBUG, message_data=message_data) - if self._request_state == RequestState.Cancelled: - self._request_state = RequestState.AsyncCancelledPriorToTimeout - else: - self._request_state = RequestState.Timeout - - async def _execute(self, fn: 
Callable[..., Awaitable[Any]], *args: object) -> None: - await fn(*args) - if self._stage_completed is not None: - self._stage_completed.set() - - def _maybe_set_request_error( - self, exc_type: Optional[Type[BaseException]] = None, exc_val: Optional[BaseException] = None - ) -> None: - self._check_timed_out() - if exc_val is None: - return - if not RequestState.is_timeout_or_cancelled(self._request_state): - # This handles httpx timeouts - if exc_type is not None and issubclass(exc_type, TimeoutException): - self._request_state = RequestState.Timeout - elif issubclass(type(exc_val), TimeoutException): - self._request_state = RequestState.Timeout - elif isinstance(exc_val, CancelledError): - self._request_state = RequestState.Cancelled - else: - self._request_state = RequestState.Error - self._request_error = exc_val - - async def _process_error( - self, json_data: Union[str, List[Dict[str, Any]]], handle_context_shutdown: Optional[bool] = False - ) -> None: - self._request_state = RequestState.Error - if isinstance(json_data, str): - self._request_error = ErrorMapper.build_error_from_http_status_code(json_data, self._error_ctx) - elif not isinstance(json_data, list): - self._request_error = AnalyticsError( - 'Cannot parse error response; expected JSON array', context=str(self._error_ctx) - ) - else: - self._request_error = ErrorMapper.build_error_from_json(json_data, self._error_ctx) - if handle_context_shutdown is True: - await self.reraise_after_shutdown(self._request_error) - - raise self._request_error - - def _reset_stream(self) -> None: - if hasattr(self, '_json_stream'): - del self._json_stream - self._request_state = RequestState.ResetAndNotStarted - self._stage_completed = None - self._cancel_scope_deadline_updated = False - - def _start_next_stage( - self, fn: Callable[..., Awaitable[Any]], *args: object, reset_previous_stage: Optional[bool] = False - ) -> None: - if self._stage_completed is not None: - if reset_previous_stage is True: - 
self._stage_completed = None - else: - raise RuntimeError('Task already running in this context.') - - self._stage_completed = anyio.Event() - self._taskgroup.start_soon(self._execute, fn, *args) - - async def _trace_handler(self, event_name: str, _: str) -> None: - if event_name == 'connection.connect_tcp.complete': - # after connection is established, we need to update the cancel_scope deadline to match the query_timeout - self._update_cancel_scope_deadline(self._request_deadline, is_absolute=True) - self._cancel_scope_deadline_updated = True - elif self._cancel_scope_deadline_updated is False and event_name.endswith('send_request_headers.started'): - # if the socket is reused, we won't get the connect_tcp.complete event, - # so the deadline at the next closest event - self._update_cancel_scope_deadline(self._request_deadline, is_absolute=True) - self._cancel_scope_deadline_updated = True - - def _update_cancel_scope_deadline(self, deadline: float, is_absolute: Optional[bool] = False) -> None: - new_deadline = deadline if is_absolute else get_time() + deadline - current_time = get_time() - if current_time >= new_deadline: - self.log_message( - 'Deadline already exceeded, cancelling request', - LogLevel.DEBUG, - message_data={ - 'current_time': f'{current_time}', - 'new_deadline': f'{new_deadline}', - }, - ) - self._taskgroup.cancel_scope.cancel() - else: - self.log_message( - f'Updating cancel scope deadline: {self._taskgroup.cancel_scope.deadline} -> {new_deadline}', - LogLevel.DEBUG, - ) - self._taskgroup.cancel_scope.deadline = new_deadline - - async def _wait_for_stage_to_complete(self) -> None: - if self._stage_completed is None: - return - await self._stage_completed.wait() - def calculate_backoff(self) -> float: - return self._backoff_calc.calculate_backoff(self._error_ctx.num_attempts) / 1000 - - def cancel_request(self, fn: Optional[Callable[..., Awaitable[Any]]] = None, *args: object) -> None: - if fn is not None: - self._taskgroup.start_soon(fn, *args) 
- if self._request_state == RequestState.Timeout: - return - self._taskgroup.cancel_scope.cancel() - self._request_state = RequestState.Cancelled + return self._backoff_calc.calculate_backoff(self._error_context.num_attempts) / 1000 def create_response_task(self, fn: Callable[..., Coroutine[Any, Any, Any]], *args: object) -> Task[Any]: if self._backend is None or self._backend.backend_lib != 'asyncio': @@ -250,20 +114,6 @@ def create_response_task(self, fn: Callable[..., Coroutine[Any, Any, Any]], *arg self._response_task = task return task - def deserialize_result(self, result: bytes) -> Any: - return self._request.deserializer.deserialize(result) - - async def finish_processing_stream(self) -> None: - if not self.has_stage_completed: - await self._wait_for_stage_to_complete() - - while not self._json_stream.token_stream_exhausted: - self._start_next_stage(self._json_stream.continue_parsing, reset_previous_stage=True) - await self._wait_for_stage_to_complete() - - async def get_result_from_stream(self) -> ParsedResult: - return await self._json_stream.get_result() - async def initialize(self) -> None: if self._request_state == RequestState.ResetAndNotStarted: current_time = get_time() @@ -297,18 +147,11 @@ def log_message( message = f'{message}, {message_data_str}' self._client_adapter.log_message(message, log_level) - def maybe_continue_to_process_stream(self) -> None: - if not self.has_stage_completed: - return - - if self._json_stream.token_stream_exhausted: - return - - self._start_next_stage(self._json_stream.continue_parsing, reset_previous_stage=True) - def okay_to_delay_and_retry(self, delay: float) -> bool: - self._check_timed_out() - if self._request_state in [RequestState.Timeout, RequestState.Cancelled]: + # calling self.timed_out will call _check_timed_out, so we don't need to call it again + if self.timed_out: + return False + elif self._supports_cancellation and self._request_state == RequestState.Cancelled: return False current_time = get_time() @@ 
-331,37 +174,28 @@ def okay_to_delay_and_retry(self, delay: float) -> bool: } self.log_message('Request has exceeded max retries', LogLevel.DEBUG, message_data=message_data) return False - else: - self._reset_stream() - return True + elif self._supports_cancellation: + # _reset_stream() _should_ exist, but surround w/ try/except just in case + try: + self._reset_stream() # type: ignore[attr-defined] + except AttributeError: + pass # nosec + + return True async def process_response( self, - close_handler: Callable[[], Coroutine[Any, Any, None]], - raw_response: Optional[ParsedResult] = None, + core_response: HttpCoreResponse, + close_handler: Callable[[], Awaitable[None]], handle_context_shutdown: Optional[bool] = False, ) -> Any: - if raw_response is None: - raw_response = await self._json_stream.get_result() - if raw_response is None: - await close_handler() - raise AnalyticsError( - message='Received unexpected empty result from JsonStream.', context=str(self._error_ctx) - ) - - if raw_response.value is None: - await close_handler() - raise AnalyticsError( - message='Received unexpected empty result from JsonStream.', context=str(self._error_ctx) - ) - # we have all the data, close the core response/stream await close_handler() try: - json_response = json.loads(raw_response.value) + json_response = core_response.json() except json.JSONDecodeError: - await self._process_error(str(raw_response.value), handle_context_shutdown=handle_context_shutdown) + await self._process_error(str(core_response.text), handle_context_shutdown=handle_context_shutdown) else: if 'errors' in json_response: await self._process_error(json_response['errors'], handle_context_shutdown=handle_context_shutdown) @@ -374,33 +208,42 @@ async def reraise_after_shutdown(self, err: Exception) -> None: await self.shutdown(type(ex), ex, ex.__traceback__) raise ex from None - async def send_request(self, enable_trace_handling: Optional[bool] = False) -> HttpCoreResponse: - 
self._error_ctx.update_num_attempts() + async def send_request( + self, enable_trace_handling: Optional[bool] = False, ignore_not_found_status: Optional[bool] = False + ) -> HttpCoreResponse: + self._error_context.update_num_attempts() ip = await get_request_ip_async(self._request.url.host, self._request.url.port, self.log_message) - if enable_trace_handling is True: - ( - self._request.update_url(ip, self._client_adapter.analytics_path).add_trace_to_extensions( - self._trace_handler - ) - ) + if self._request.path and not self._request.path.isspace(): + req_path = f'{self._request.path}' else: - self._request.update_url(ip, self._client_adapter.analytics_path) - self._error_ctx.update_request_context(self._request) + req_path = self._client_adapter.analytics_path + if enable_trace_handling is True and hasattr(self, '_trace_handler'): + self._request.update_url(ip, req_path).add_trace_to_extensions(self._trace_handler) + else: + self._request.update_url(ip, req_path) + self._error_context.update_request_context(self._request, path=req_path) message_data = { 'url': f'{self._request.url.get_formatted_url()}', - 'body': f'{self._request.body}', 'request_deadline': f'{self._request_deadline}', } + + if isinstance(self._request, (QueryRequest, StartQueryRequest)): + message_data['body'] = f'{self._request.body}' + + stream = hasattr(self._request, 'should_stream') and self._request.should_stream is True + message_data['streaming'] = str(stream) + self.log_message('HTTP request', LogLevel.DEBUG, message_data=message_data) - response = await self._client_adapter.send_request(self._request) - self._error_ctx.update_response_context(response) + response = await self._client_adapter.send_request(self._request, stream=stream) + self._error_context.update_response_context(response) message_data = { 'status_code': f'{response.status_code}', - 'last_dispatched_to': f'{self._error_ctx.last_dispatched_to}', - 'last_dispatched_from': f'{self._error_ctx.last_dispatched_from}', + 
'last_dispatched_to': f'{self._error_context.last_dispatched_to}', + 'last_dispatched_from': f'{self._error_context.last_dispatched_from}', 'request_deadline': f'{self._request_deadline}', } self.log_message('HTTP response', LogLevel.DEBUG, message_data=message_data) + self._check_for_http_status_error(response.status_code, ignore_not_found_status=ignore_not_found_status) return response async def shutdown( @@ -422,21 +265,66 @@ async def shutdown( self._shutdown = True self.log_message('Request context shutdown complete', LogLevel.INFO) - def start_stream(self, core_response: HttpCoreResponse) -> None: - if hasattr(self, '_json_stream'): - self.log_message('JSON stream already exists', LogLevel.WARNING) + def _check_for_http_status_error(self, status_code: int, ignore_not_found_status: Optional[bool] = False) -> None: + ctx = str(self._error_context) + ErrorMapper.maybe_raise_error_from_status_code( + status_code, ctx, ignore_not_found_status=ignore_not_found_status + ) + + def _check_timed_out(self) -> None: + if self._request_state in (RequestState.Timeout, RequestState.Error): return - self._json_stream = AsyncJsonStream( - core_response.aiter_bytes(), stream_config=self._stream_config, logger_handler=self.log_message - ) - self._start_next_stage(self._json_stream.start_parsing) + if self._supports_cancellation and self._request_state == RequestState.Cancelled: + return - async def wait_for_results_or_errors(self) -> None: - await self._json_stream.has_results_or_errors.wait() - if self._json_stream.results_or_errors_type == ParsedResultType.ROW: - # we move to iterating rows - self._request_state = RequestState.StreamingResults + if hasattr(self, '_request_deadline') is False: + return + + current_time = get_time() + timed_out = current_time >= self._request_deadline + if timed_out: + message_data = {'current_time': f'{current_time}', 'request_deadline': f'{self._request_deadline}'} + self.log_message('Request has timed out', LogLevel.DEBUG, 
message_data=message_data) + if self._supports_cancellation and self._request_state == RequestState.Cancelled: + self._request_state = RequestState.AsyncCancelledPriorToTimeout + else: + self._request_state = RequestState.Timeout + + def _maybe_set_request_error( + self, exc_type: Optional[Type[BaseException]] = None, exc_val: Optional[BaseException] = None + ) -> None: + self._check_timed_out() + if exc_val is None: + return + if not RequestState.is_timeout_or_cancelled(self._request_state): + # This handles httpx timeouts + if exc_type is not None and issubclass(exc_type, TimeoutException): + self._request_state = RequestState.Timeout + elif issubclass(type(exc_val), TimeoutException): + self._request_state = RequestState.Timeout + elif isinstance(exc_val, CancelledError): + self._request_state = RequestState.Cancelled + else: + self._request_state = RequestState.Error + self._request_error = exc_val + + async def _process_error( + self, json_data: Union[str, List[Dict[str, Any]]], handle_context_shutdown: Optional[bool] = False + ) -> None: + self._request_state = RequestState.Error + if isinstance(json_data, str): + self._request_error = ErrorMapper.build_error_from_http_status_code(json_data, self._error_context) + elif not isinstance(json_data, list): + self._request_error = AnalyticsError( + 'Cannot parse error response; expected JSON array', context=str(self._error_context) + ) + else: + self._request_error = ErrorMapper.build_error_from_json(json_data, self._error_context) + if handle_context_shutdown is True: + await self.reraise_after_shutdown(self._request_error) + + raise self._request_error async def __aenter__(self) -> AsyncRequestContext: self._taskgroup = anyio.create_task_group() @@ -457,3 +345,177 @@ async def __aexit__( self._maybe_set_request_error(exc_type, exc_val) del self._taskgroup return None # noqa: B012 + + +class AsyncStreamingRequestContext(AsyncRequestContext): + def __init__( + self, + client_adapter: _AsyncClientAdapter, + request: 
Union[FetchResultsRequest, QueryRequest], + stream_config: Optional[JsonStreamConfig] = None, + backend: Optional[AsyncBackend] = None, + ) -> None: + super().__init__(client_adapter, request, supports_cancellation=True, backend=backend) + self._stream_config = stream_config or JsonStreamConfig() + self._json_stream: AsyncJsonStream + self._stage_completed: Optional[anyio.Event] = None + self._deserializer = request.deserializer + + @property + def has_stage_completed(self) -> bool: + return self._stage_completed is not None and self._stage_completed.is_set() + + @property + def okay_to_iterate(self) -> bool: + self._check_timed_out() + return RequestState.okay_to_iterate(self._request_state) + + @property + def okay_to_stream(self) -> bool: + self._check_timed_out() + return RequestState.okay_to_stream(self._request_state) + + @property + def results_or_errors_type(self) -> ParsedResultType: + return self._json_stream.results_or_errors_type + + def cancel_request(self, fn: Optional[Callable[..., Awaitable[Any]]] = None, *args: object) -> None: + if fn is not None: + self._taskgroup.start_soon(fn, *args) + if self._request_state == RequestState.Timeout: + return + self._taskgroup.cancel_scope.cancel() + self._request_state = RequestState.Cancelled + + def deserialize_result(self, result: bytes) -> Any: + if not self._deserializer: + raise RuntimeError('No deserializer found for this request context.') + return self._deserializer.deserialize(result) + + async def finish_processing_stream(self) -> None: + if not self.has_stage_completed: + await self._wait_for_stage_to_complete() + + while not self._json_stream.token_stream_exhausted: + self._start_next_stage(self._json_stream.continue_parsing, reset_previous_stage=True) + await self._wait_for_stage_to_complete() + + async def get_result_from_stream(self) -> ParsedResult: + return await self._json_stream.get_result() + + def maybe_continue_to_process_stream(self) -> None: + if not self.has_stage_completed: + return + 
+ if self._json_stream.token_stream_exhausted: + return + + self._start_next_stage(self._json_stream.continue_parsing, reset_previous_stage=True) + + async def process_streaming_response( + self, + close_handler: Callable[[], Awaitable[None]], + raw_response: Optional[ParsedResult] = None, + handle_context_shutdown: Optional[bool] = False, + ) -> Any: + if raw_response is None: + raw_response = await self._json_stream.get_result() + if raw_response is None: + await close_handler() + raise AnalyticsError( + message='Received unexpected empty result from JsonStream.', context=str(self._error_context) + ) + + if raw_response.value is None: + await close_handler() + raise AnalyticsError( + message='Received unexpected empty result from JsonStream.', context=str(self._error_context) + ) + + # we have all the data, close the core response/stream + await close_handler() + + try: + json_response = json.loads(raw_response.value) + except json.JSONDecodeError: + await self._process_error(str(raw_response.value), handle_context_shutdown=handle_context_shutdown) + else: + if 'errors' in json_response: + await self._process_error(json_response['errors'], handle_context_shutdown=handle_context_shutdown) + return json_response + + def start_stream(self, core_response: HttpCoreResponse) -> None: + if hasattr(self, '_json_stream'): + self.log_message('JSON stream already exists', LogLevel.WARNING) + return + + self._json_stream = AsyncJsonStream( + core_response.aiter_bytes(), stream_config=self._stream_config, logger_handler=self.log_message + ) + self._start_next_stage(self._json_stream.start_parsing) + + async def wait_for_results_or_errors(self) -> None: + await self._json_stream.has_results_or_errors.wait() + if self._json_stream.results_or_errors_type == ParsedResultType.ROW: + # we move to iterating rows + self._request_state = RequestState.StreamingResults + + async def _execute(self, fn: Callable[..., Awaitable[Any]], *args: object) -> None: + await fn(*args) + if 
self._stage_completed is not None: + self._stage_completed.set() + + def _reset_stream(self) -> None: + if hasattr(self, '_json_stream'): + del self._json_stream + self._request_state = RequestState.ResetAndNotStarted + self._stage_completed = None + self._cancel_scope_deadline_updated = False + + def _start_next_stage( + self, fn: Callable[..., Awaitable[Any]], *args: object, reset_previous_stage: Optional[bool] = False + ) -> None: + if self._stage_completed is not None: + if reset_previous_stage is True: + self._stage_completed = None + else: + raise RuntimeError('Task already running in this context.') + + self._stage_completed = anyio.Event() + self._taskgroup.start_soon(self._execute, fn, *args) + + async def _trace_handler(self, event_name: str, _: str) -> None: + if event_name == 'connection.connect_tcp.complete': + # after connection is established, we need to update the cancel_scope deadline to match the query_timeout + self._update_cancel_scope_deadline(self._request_deadline, is_absolute=True) + self._cancel_scope_deadline_updated = True + elif self._cancel_scope_deadline_updated is False and event_name.endswith('send_request_headers.started'): + # if the socket is reused, we won't get the connect_tcp.complete event, + # so we update the deadline at the next closest event + self._update_cancel_scope_deadline(self._request_deadline, is_absolute=True) + self._cancel_scope_deadline_updated = True + + def _update_cancel_scope_deadline(self, deadline: float, is_absolute: Optional[bool] = False) -> None: + new_deadline = deadline if is_absolute else get_time() + deadline + current_time = get_time() + if current_time >= new_deadline: + self.log_message( + 'Deadline already exceeded, cancelling request', + LogLevel.DEBUG, + message_data={ + 'current_time': f'{current_time}', + 'new_deadline': f'{new_deadline}', + }, + ) + self._taskgroup.cancel_scope.cancel() + else: + self.log_message( + f'Updating cancel scope deadline: {self._taskgroup.cancel_scope.deadline} -> 
{new_deadline}', + LogLevel.DEBUG, + ) + self._taskgroup.cancel_scope.deadline = new_deadline + + async def _wait_for_stage_to_complete(self) -> None: + if self._stage_completed is None: + return + await self._stage_completed.wait() diff --git a/acouchbase_analytics/protocol/_core/response.py b/acouchbase_analytics/protocol/_core/response.py new file mode 100644 index 0000000..6d30731 --- /dev/null +++ b/acouchbase_analytics/protocol/_core/response.py @@ -0,0 +1,136 @@ +# Copyright 2016-2025. Couchbase, Inc. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from __future__ import annotations + +from typing import Any, Optional + +from httpx import Response as HttpCoreResponse + +from acouchbase_analytics.protocol._core.request_context import AsyncRequestContext +from acouchbase_analytics.protocol._core.retries import AsyncRetryHandler +from couchbase_analytics.common._core.query import build_query_metadata +from couchbase_analytics.common.errors import AnalyticsError, InternalSDKError +from couchbase_analytics.common.logging import LogLevel +from couchbase_analytics.common.query import QueryMetadata +from couchbase_analytics.protocol.errors import WrappedError + + +class AsyncHttpResponse: + def __init__( + self, + request_context: AsyncRequestContext, + has_no_body_response: Optional[bool] = None, + request_id: Optional[str] = None, + ) -> None: + # Goal is to treat the AsyncHttpResponse as a "task group" + self._request_context = request_context + self._metadata: Optional[QueryMetadata] = None + self._core_response: HttpCoreResponse + self._json_response: Optional[Any] = None + self._has_no_body_response = has_no_body_response + self._request_id = request_id + + @property + def json_response(self) -> Optional[Any]: + """ + **INTERNAL** + """ + return self._json_response + + async def close(self) -> None: + """ + **INTERNAL** + """ + if hasattr(self, '_core_response'): + await self._core_response.aclose() + self._request_context.log_message('HTTP core response closed', LogLevel.INFO) + del self._core_response + + def get_metadata(self) -> QueryMetadata: + """ + **INTERNAL** + """ + if self._metadata is None: + raise RuntimeError('Query metadata is only available after all rows have been iterated.') + return self._metadata + + async def set_metadata(self, json_data: Optional[Any] = None, raw_metadata: Optional[bytes] = None) -> None: + """ + **INTERNAL** + """ + try: + self._metadata = QueryMetadata( + build_query_metadata( + json_data=json_data, + raw_metadata=raw_metadata, + 
request_id=self._request_id, + log_fn=self._request_context.log_message, + ) + ) + await self._request_context.shutdown() + except (AnalyticsError, ValueError) as err: + await self._request_context.reraise_after_shutdown(err) + except Exception as ex: + internal_err = InternalSDKError(cause=ex, message=str(ex), context=str(self._request_context.error_context)) + await self._request_context.reraise_after_shutdown(internal_err) + finally: + await self.close() + + @AsyncRetryHandler.with_retries + async def send_request(self) -> None: + """ + **INTERNAL** + """ + await self._request_context.initialize() + self._core_response = await self._request_context.send_request( + ignore_not_found_status=self._has_no_body_response + ) + if self._has_no_body_response is True: + await self._process_no_body_response() + return + await self._process_response() + + async def shutdown(self) -> None: + """ + **INTERNAL** + """ + await self.close() + await self._request_context.shutdown() + + async def _close_in_background(self) -> None: + """ + **INTERNAL** + """ + await self.close() + + async def _process_no_body_response(self) -> None: + status_code = self._core_response.status_code + await self.close() + if 200 <= status_code < 300 or status_code == 404: + await self._request_context.shutdown() + return + ctx = str(self._request_context.error_context) + raise WrappedError(AnalyticsError(context=ctx, message=f'Request failed with status {status_code}.')) + + async def _process_response(self) -> None: + """ + **INTERNAL** + """ + self._json_response = await self._request_context.process_response( + self._core_response, self.close, handle_context_shutdown=True + ) + await self.set_metadata(json_data=self._json_response) diff --git a/acouchbase_analytics/protocol/_core/retries.py b/acouchbase_analytics/protocol/_core/retries.py index 83423d1..4570225 100644 --- a/acouchbase_analytics/protocol/_core/retries.py +++ b/acouchbase_analytics/protocol/_core/retries.py @@ -18,7 +18,7 @@ from 
asyncio import CancelledError from functools import wraps -from typing import TYPE_CHECKING, Any, Callable, Coroutine, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Coroutine, Optional, TypeVar, Union from httpx import ConnectError, ConnectTimeout, CookieConflict, HTTPError, InvalidURL, ReadTimeout, StreamError @@ -29,19 +29,22 @@ from couchbase_analytics.protocol.errors import WrappedError if TYPE_CHECKING: - from acouchbase_analytics.protocol._core.request_context import AsyncRequestContext + from acouchbase_analytics.protocol._core.request_context import AsyncRequestContext, AsyncStreamingRequestContext + from acouchbase_analytics.protocol._core.response import AsyncHttpResponse from acouchbase_analytics.protocol.streaming import AsyncHttpStreamingResponse +AsyncReqContext = Union['AsyncRequestContext', 'AsyncStreamingRequestContext'] +T = TypeVar('T', bound=Union['AsyncHttpResponse', 'AsyncHttpStreamingResponse']) + + class AsyncRetryHandler: """ **INTERNAL** """ @staticmethod - async def handle_httpx_retry( - ex: Union[ConnectError, ConnectTimeout], ctx: AsyncRequestContext - ) -> Optional[Exception]: + async def handle_httpx_retry(ex: Union[ConnectError, ConnectTimeout], ctx: AsyncReqContext) -> Optional[Exception]: err_str = str(ex) if 'SSL:' in err_str: message = 'TLS connection error occurred.' 
@@ -64,7 +67,7 @@ async def handle_httpx_retry( return None @staticmethod - async def handle_retry(ex: WrappedError, ctx: AsyncRequestContext) -> Optional[Union[BaseException, Exception]]: + async def handle_retry(ex: WrappedError, ctx: AsyncReqContext) -> Optional[Union[BaseException, Exception]]: if ex.retriable is True: delay = ctx.calculate_backoff() err: Optional[Union[BaseException, Exception]] = None @@ -94,10 +97,10 @@ async def handle_retry(ex: WrappedError, ctx: AsyncRequestContext) -> Optional[U @staticmethod def with_retries( # noqa: C901 - fn: Callable[[AsyncHttpStreamingResponse], Coroutine[Any, Any, None]], - ) -> Callable[[AsyncHttpStreamingResponse], Coroutine[Any, Any, None]]: + fn: Callable[[T], Coroutine[Any, Any, None]], + ) -> Callable[[T], Coroutine[Any, Any, None]]: @wraps(fn) - async def wrapped_fn(self: AsyncHttpStreamingResponse) -> None: # noqa: C901 + async def wrapped_fn(self: T) -> None: # noqa: C901 while True: try: await fn(self) diff --git a/acouchbase_analytics/protocol/cluster.py b/acouchbase_analytics/protocol/cluster.py index f08519f..c28661b 100644 --- a/acouchbase_analytics/protocol/cluster.py +++ b/acouchbase_analytics/protocol/cluster.py @@ -27,15 +27,18 @@ from acouchbase_analytics.protocol._core.anyio_utils import current_async_library from acouchbase_analytics.protocol._core.client_adapter import _AsyncClientAdapter -from acouchbase_analytics.protocol._core.request_context import AsyncRequestContext +from acouchbase_analytics.protocol._core.request_context import AsyncRequestContext, AsyncStreamingRequestContext +from acouchbase_analytics.protocol._core.response import AsyncHttpResponse +from acouchbase_analytics.protocol.query_handle import AsyncQueryHandle from acouchbase_analytics.protocol.streaming import AsyncHttpStreamingResponse from couchbase_analytics.common.logging import LogLevel from couchbase_analytics.common.result import AsyncQueryResult from couchbase_analytics.protocol._core.request import 
_RequestBuilder if TYPE_CHECKING: + from acouchbase_analytics.options import ClusterOptions + from couchbase_analytics.common._core import JsonStreamConfig from couchbase_analytics.common.credential import Credential - from couchbase_analytics.options import ClusterOptions class AsyncCluster: @@ -106,16 +109,36 @@ async def _execute_query(self, http_resp: AsyncHttpStreamingResponse) -> AsyncQu return AsyncQueryResult(http_resp) def execute_query(self, statement: str, *args: object, **kwargs: object) -> Awaitable[AsyncQueryResult]: - base_req = self._request_builder.build_base_query_request(statement, *args, is_async=True, **kwargs) - stream_config = base_req.options.pop('stream_config', None) - request_context = AsyncRequestContext( - client_adapter=self.client_adapter, request=base_req, stream_config=stream_config, backend=self._backend + req = self._request_builder.build_query_request(statement, *args, **kwargs) + stream_config = req.options.pop('stream_config', None) + request_context = AsyncStreamingRequestContext( + self.client_adapter, req, stream_config=stream_config, backend=self._backend ) resp = AsyncHttpStreamingResponse(request_context) if self._backend.backend_lib == 'asyncio': return request_context.create_response_task(self._execute_query, resp) return self._execute_query(resp) + async def _start_query( + self, http_resp: AsyncHttpResponse, stream_config: Optional[JsonStreamConfig] + ) -> AsyncQueryHandle: + if not self.has_client: + self.client_adapter.log_message( + 'Cluster does not have a connection. 
Creating the client.', LogLevel.WARNING + ) + await self._create_client() + await http_resp.send_request() + return AsyncQueryHandle(self._client_adapter, self._request_builder, http_resp, stream_config=stream_config) + + def start_query(self, statement: str, *args: object, **kwargs: object) -> Awaitable[AsyncQueryHandle]: + req = self._request_builder.build_start_query_request(statement, *args, **kwargs) + stream_config = req.options.pop('stream_config', None) + request_context = AsyncRequestContext(self.client_adapter, req, backend=self._backend) + resp = AsyncHttpResponse(request_context) + if self._backend.backend_lib == 'asyncio': + return request_context.create_response_task(self._start_query, resp, stream_config) + return self._start_query(resp, stream_config) + @classmethod def create_instance( cls, endpoint: str, credential: Credential, options: Optional[ClusterOptions] = None, **kwargs: object diff --git a/acouchbase_analytics/protocol/cluster.pyi b/acouchbase_analytics/protocol/cluster.pyi index 87cb2a1..3fd9f3a 100644 --- a/acouchbase_analytics/protocol/cluster.pyi +++ b/acouchbase_analytics/protocol/cluster.pyi @@ -21,11 +21,20 @@ if sys.version_info < (3, 11): else: from typing import Unpack +from acouchbase_analytics import JSONType +from acouchbase_analytics.credential import Credential +from acouchbase_analytics.options import ( + ClusterOptions, + ClusterOptionsKwargs, + QueryOptions, + QueryOptionsKwargs, + StartQueryOptions, + StartQueryOptionsKwargs, +) from acouchbase_analytics.protocol._core.client_adapter import _AsyncClientAdapter from acouchbase_analytics.protocol.database import AsyncDatabase -from couchbase_analytics.common.credential import Credential -from couchbase_analytics.common.result import AsyncQueryResult -from couchbase_analytics.options import ClusterOptions, ClusterOptionsKwargs, QueryOptions, QueryOptionsKwargs +from acouchbase_analytics.protocol.query_handle import AsyncQueryHandle +from acouchbase_analytics.result import 
AsyncQueryResult class AsyncCluster: @overload @@ -56,14 +65,34 @@ class AsyncCluster: ) -> Awaitable[AsyncQueryResult]: ... @overload def execute_query( - self, statement: str, options: QueryOptions, *args: str, **kwargs: Unpack[QueryOptionsKwargs] + self, statement: str, options: QueryOptions, *args: JSONType, **kwargs: Unpack[QueryOptionsKwargs] ) -> Awaitable[AsyncQueryResult]: ... @overload def execute_query( - self, statement: str, options: QueryOptions, *args: str, **kwargs: str + self, statement: str, options: QueryOptions, *args: JSONType, **kwargs: str ) -> Awaitable[AsyncQueryResult]: ... @overload - def execute_query(self, statement: str, *args: str, **kwargs: str) -> Awaitable[AsyncQueryResult]: ... + def execute_query(self, statement: str, *args: JSONType, **kwargs: str) -> Awaitable[AsyncQueryResult]: ... + @overload + def start_query(self, statement: str) -> Awaitable[AsyncQueryHandle]: ... + @overload + def start_query(self, statement: str, options: StartQueryOptions) -> Awaitable[AsyncQueryHandle]: ... + @overload + def start_query(self, statement: str, **kwargs: Unpack[StartQueryOptionsKwargs]) -> Awaitable[AsyncQueryHandle]: ... + @overload + def start_query( + self, statement: str, options: StartQueryOptions, **kwargs: Unpack[StartQueryOptionsKwargs] + ) -> Awaitable[AsyncQueryHandle]: ... + @overload + def start_query( + self, statement: str, options: StartQueryOptions, *args: JSONType, **kwargs: Unpack[StartQueryOptionsKwargs] + ) -> Awaitable[AsyncQueryHandle]: ... + @overload + def start_query( + self, statement: str, options: StartQueryOptions, *args: JSONType, **kwargs: str + ) -> Awaitable[AsyncQueryHandle]: ... + @overload + def start_query(self, statement: str, *args: JSONType, **kwargs: str) -> Awaitable[AsyncQueryHandle]: ... @overload @classmethod def create_instance(cls, endpoint: str, credential: Credential) -> AsyncCluster: ... 
diff --git a/acouchbase_analytics/protocol/query_handle.py b/acouchbase_analytics/protocol/query_handle.py new file mode 100644 index 0000000..9a31803 --- /dev/null +++ b/acouchbase_analytics/protocol/query_handle.py @@ -0,0 +1,163 @@ +# Copyright 2016-2025. Couchbase, Inc. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Optional + +from acouchbase_analytics.protocol._core.client_adapter import _AsyncClientAdapter +from acouchbase_analytics.protocol._core.request_context import AsyncRequestContext, AsyncStreamingRequestContext +from acouchbase_analytics.protocol._core.response import AsyncHttpResponse +from acouchbase_analytics.protocol.streaming import AsyncHttpStreamingResponse +from couchbase_analytics.common._core.query_handle import QueryHandleStatusResponse +from couchbase_analytics.common.errors import AnalyticsError, QueryNotFoundError +from couchbase_analytics.common.query_handle import AsyncQueryHandle as _CoreAsyncQueryHandle +from couchbase_analytics.common.query_handle import AsyncQueryResultHandle as _CoreAsyncQueryResultHandle +from couchbase_analytics.common.query_handle import AsyncQueryStatus as _CoreAsyncQueryStatus +from couchbase_analytics.common.result import AsyncQueryResult +from couchbase_analytics.protocol._core.request import _RequestBuilder + +if TYPE_CHECKING: + from acouchbase_analytics.options import FetchResultsOptions + from 
couchbase_analytics.common._core import JsonStreamConfig + + +class AsyncQueryHandle(_CoreAsyncQueryHandle): + def __init__( + self, + client_adapter: _AsyncClientAdapter, + request_builder: _RequestBuilder, + http_response: AsyncHttpResponse, + stream_config: Optional[JsonStreamConfig] = None, + ) -> None: + super().__init__() + self._client_adapter = client_adapter + self._request_builder = request_builder + self._http_response = http_response + self._stream_config = stream_config + self._request_id: str = '' + self._handle: str = '' + self._get_status_handle() + + async def fetch_status(self, options: Optional[Any] = None, **kwargs: Any) -> AsyncQueryStatus: + server_req = self._request_builder.build_request_from_handle(self._handle) + request_context = AsyncRequestContext(self._client_adapter, server_req) + resp = AsyncHttpResponse(request_context) + await resp.send_request() + if resp.json_response is None: + raise AnalyticsError(message='HTTP response does not contain JSON data.') + + status_response = self._get_handle_status_response(resp) + return AsyncQueryStatus( + self._client_adapter, + self._request_builder, + status_response, + stream_config=self._stream_config, + ) + + async def cancel(self, options: Optional[Any] = None, **kwargs: Any) -> None: + cancel_req = self._request_builder.build_cancel_request(self._request_id) + request_context = AsyncRequestContext(self._client_adapter, cancel_req) + resp = AsyncHttpResponse(request_context, has_no_body_response=True, request_id=self._request_id) + await resp.send_request() + + def _get_status_handle(self) -> None: + if self._http_response.json_response is None: + raise AnalyticsError(message='HTTP response does not contain JSON data.') + + request_id = self._http_response.json_response.get('requestID', None) + if request_id is None: + raise QueryNotFoundError(message='Server response is missing "requestID" field.') + handle = self._http_response.json_response.get('handle', None) + if handle is None: + 
raise QueryNotFoundError(message='Server response is missing "handle" field.') + + self._request_id = request_id + self._handle = handle + + def _get_handle_status_response(self, resp: AsyncHttpResponse) -> QueryHandleStatusResponse: + if resp.json_response is None: + raise AnalyticsError(message='HTTP response does not contain JSON data.') + + return QueryHandleStatusResponse.from_server(self._request_id, resp.json_response) + + +class AsyncQueryResultHandle(_CoreAsyncQueryResultHandle): + def __init__( + self, + client_adapter: _AsyncClientAdapter, + request_builder: _RequestBuilder, + request_id: str, + handle: str, + stream_config: Optional[JsonStreamConfig] = None, + ) -> None: + super().__init__() + self._client_adapter = client_adapter + self._request_builder = request_builder + self._request_id = request_id + self._handle = handle + self._stream_config = stream_config + + async def fetch_results(self, options: Optional[FetchResultsOptions] = None, **kwargs: Any) -> AsyncQueryResult: + server_req = self._request_builder.build_fetch_results_request(self._handle, options, **kwargs) + request_context = AsyncStreamingRequestContext( + self._client_adapter, server_req, stream_config=self._stream_config + ) + resp = AsyncHttpStreamingResponse(request_context, request_id=self._request_id) + await resp.send_request() + return AsyncQueryResult(resp) + + async def discard_results(self, options: Optional[Any] = None, **kwargs: Any) -> None: + req = self._request_builder.build_discard_results_request(self._handle) + request_context = AsyncRequestContext(self._client_adapter, req) + resp = AsyncHttpResponse(request_context, has_no_body_response=True, request_id=self._request_id) + await resp.send_request() + + +class AsyncQueryStatus(_CoreAsyncQueryStatus): + def __init__( + self, + client_adapter: _AsyncClientAdapter, + request_builder: _RequestBuilder, + status_resp: QueryHandleStatusResponse, + stream_config: Optional[JsonStreamConfig] = None, + ) -> None: + 
super().__init__() + self._client_adapter = client_adapter + self._request_builder = request_builder + self._status_resp = status_resp + self._stream_config = stream_config + + def results_ready(self) -> bool: + return self._status_resp.handle is not None + + def result_handle(self) -> AsyncQueryResultHandle: + if self._status_resp.handle is None: + raise AnalyticsError(message='Query is not ready. Handle is not available.') + + return AsyncQueryResultHandle( + self._client_adapter, + self._request_builder, + self._status_resp.request_id, + self._status_resp.handle, + stream_config=self._stream_config, + ) + + def __repr__(self) -> str: + return f'AsyncQueryStatus({self._status_resp.get_details()})' + + def __str__(self) -> str: + return self.__repr__() diff --git a/acouchbase_analytics/protocol/scope.py b/acouchbase_analytics/protocol/scope.py index cd97f5f..c41bc4a 100644 --- a/acouchbase_analytics/protocol/scope.py +++ b/acouchbase_analytics/protocol/scope.py @@ -17,7 +17,7 @@ from __future__ import annotations import sys -from typing import TYPE_CHECKING, Awaitable +from typing import TYPE_CHECKING, Awaitable, Optional if sys.version_info < (3, 10): from typing_extensions import TypeAlias @@ -26,14 +26,17 @@ from acouchbase_analytics.protocol._core.anyio_utils import current_async_library from acouchbase_analytics.protocol._core.client_adapter import _AsyncClientAdapter -from acouchbase_analytics.protocol._core.request_context import AsyncRequestContext +from acouchbase_analytics.protocol._core.request_context import AsyncRequestContext, AsyncStreamingRequestContext +from acouchbase_analytics.protocol._core.response import AsyncHttpResponse +from acouchbase_analytics.protocol.query_handle import AsyncQueryHandle from acouchbase_analytics.protocol.streaming import AsyncHttpStreamingResponse +from acouchbase_analytics.result import AsyncQueryResult from couchbase_analytics.common.logging import LogLevel -from couchbase_analytics.common.result import 
AsyncQueryResult from couchbase_analytics.protocol._core.request import _RequestBuilder if TYPE_CHECKING: from acouchbase_analytics.protocol.database import AsyncDatabase + from couchbase_analytics.common._core import JsonStreamConfig class AsyncScope: @@ -73,15 +76,35 @@ async def _execute_query(self, http_resp: AsyncHttpStreamingResponse) -> AsyncQu return AsyncQueryResult(http_resp) def execute_query(self, statement: str, *args: object, **kwargs: object) -> Awaitable[AsyncQueryResult]: - base_req = self._request_builder.build_base_query_request(statement, *args, is_async=True, **kwargs) - stream_config = base_req.options.pop('stream_config', None) - request_context = AsyncRequestContext( - client_adapter=self.client_adapter, request=base_req, stream_config=stream_config, backend=self._backend + req = self._request_builder.build_query_request(statement, *args, **kwargs) + stream_config = req.options.pop('stream_config', None) + request_context = AsyncStreamingRequestContext( + client_adapter=self.client_adapter, request=req, stream_config=stream_config, backend=self._backend ) resp = AsyncHttpStreamingResponse(request_context) if self._backend.backend_lib == 'asyncio': return request_context.create_response_task(self._execute_query, resp) return self._execute_query(resp) + async def _start_query( + self, http_resp: AsyncHttpResponse, stream_config: Optional[JsonStreamConfig] + ) -> AsyncQueryHandle: + if not self.client_adapter.has_client: + self.client_adapter.log_message( + 'Cluster does not have a connection. 
Creating the client.', LogLevel.WARNING + ) + await self._create_client() + await http_resp.send_request() + return AsyncQueryHandle(self.client_adapter, self._request_builder, http_resp, stream_config=stream_config) + + def start_query(self, statement: str, *args: object, **kwargs: object) -> Awaitable[AsyncQueryHandle]: + req = self._request_builder.build_start_query_request(statement, *args, **kwargs) + stream_config = req.options.pop('stream_config', None) + request_context = AsyncRequestContext(self.client_adapter, req, backend=self._backend) + resp = AsyncHttpResponse(request_context) + if self._backend.backend_lib == 'asyncio': + return request_context.create_response_task(self._start_query, resp, stream_config) + return self._start_query(resp, stream_config) + Scope: TypeAlias = AsyncScope diff --git a/acouchbase_analytics/protocol/scope.pyi b/acouchbase_analytics/protocol/scope.pyi index 87b1a52..90bbadd 100644 --- a/acouchbase_analytics/protocol/scope.pyi +++ b/acouchbase_analytics/protocol/scope.pyi @@ -21,10 +21,12 @@ if sys.version_info < (3, 11): else: from typing import Unpack +from acouchbase_analytics import JSONType +from acouchbase_analytics.options import QueryOptions, QueryOptionsKwargs, StartQueryOptions, StartQueryOptionsKwargs from acouchbase_analytics.protocol._core.client_adapter import _AsyncClientAdapter from acouchbase_analytics.protocol.database import AsyncDatabase as AsyncDatabase -from couchbase_analytics.options import QueryOptions, QueryOptionsKwargs -from couchbase_analytics.result import AsyncQueryResult +from acouchbase_analytics.protocol.query_handle import AsyncQueryHandle +from acouchbase_analytics.result import AsyncQueryResult class AsyncScope: def __init__(self, database: AsyncDatabase, scope_name: str) -> None: ... @@ -52,3 +54,23 @@ class AsyncScope: ) -> Awaitable[AsyncQueryResult]: ... @overload def execute_query(self, statement: str, *args: str, **kwargs: str) -> Awaitable[AsyncQueryResult]: ... 
+ @overload + def start_query(self, statement: str) -> Awaitable[AsyncQueryHandle]: ... + @overload + def start_query(self, statement: str, options: StartQueryOptions) -> Awaitable[AsyncQueryHandle]: ... + @overload + def start_query(self, statement: str, **kwargs: Unpack[StartQueryOptionsKwargs]) -> Awaitable[AsyncQueryHandle]: ... + @overload + def start_query( + self, statement: str, options: StartQueryOptions, **kwargs: Unpack[StartQueryOptionsKwargs] + ) -> Awaitable[AsyncQueryHandle]: ... + @overload + def start_query( + self, statement: str, options: StartQueryOptions, *args: JSONType, **kwargs: Unpack[StartQueryOptionsKwargs] + ) -> Awaitable[AsyncQueryHandle]: ... + @overload + def start_query( + self, statement: str, options: StartQueryOptions, *args: JSONType, **kwargs: str + ) -> Awaitable[AsyncQueryHandle]: ... + @overload + def start_query(self, statement: str, *args: JSONType, **kwargs: str) -> Awaitable[AsyncQueryHandle]: ... diff --git a/acouchbase_analytics/protocol/streaming.py b/acouchbase_analytics/protocol/streaming.py index 48d7e58..5705aa4 100644 --- a/acouchbase_analytics/protocol/streaming.py +++ b/acouchbase_analytics/protocol/streaming.py @@ -20,7 +20,7 @@ from httpx import Response as HttpCoreResponse -from acouchbase_analytics.protocol._core.request_context import AsyncRequestContext +from acouchbase_analytics.protocol._core.request_context import AsyncStreamingRequestContext from acouchbase_analytics.protocol._core.retries import AsyncRetryHandler from couchbase_analytics.common._core import ParsedResult, ParsedResultType from couchbase_analytics.common._core.query import build_query_metadata @@ -30,11 +30,12 @@ class AsyncHttpStreamingResponse: - def __init__(self, request_context: AsyncRequestContext) -> None: + def __init__(self, request_context: AsyncStreamingRequestContext, request_id: Optional[str] = None) -> None: self._metadata: Optional[QueryMetadata] = None self._core_response: HttpCoreResponse # Goal is to treat the 
AsyncHttpStreamingResponse as a "task group" self._request_context = request_context + self._request_id = request_id async def _close_in_background(self) -> None: """ @@ -68,7 +69,7 @@ async def _process_response( """ **INTERNAL** """ - json_response = await self._request_context.process_response( + json_response = await self._request_context.process_streaming_response( self.close, raw_response=raw_response, handle_context_shutdown=handle_context_shutdown ) await self.set_metadata(json_data=json_response) @@ -111,7 +112,14 @@ async def set_metadata(self, json_data: Optional[Any] = None, raw_metadata: Opti **INTERNAL** """ try: - self._metadata = QueryMetadata(build_query_metadata(json_data=json_data, raw_metadata=raw_metadata)) + self._metadata = QueryMetadata( + build_query_metadata( + json_data=json_data, + raw_metadata=raw_metadata, + request_id=self._request_id, + log_fn=self._request_context.log_message, + ) + ) await self._request_context.shutdown() except (AnalyticsError, ValueError) as err: await self._request_context.reraise_after_shutdown(err) diff --git a/couchbase_analytics/protocol/result.py b/acouchbase_analytics/query_handle.py similarity index 65% rename from couchbase_analytics/protocol/result.py rename to acouchbase_analytics/query_handle.py index 7165b68..3ca1fc3 100644 --- a/couchbase_analytics/protocol/result.py +++ b/acouchbase_analytics/query_handle.py @@ -13,5 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
- -from __future__ import annotations +from couchbase_analytics.common.query_handle import AsyncQueryHandle as AsyncQueryHandle # noqa: F401 +from couchbase_analytics.common.query_handle import AsyncQueryResultHandle as AsyncQueryResultHandle # noqa: F401 +from couchbase_analytics.common.query_handle import AsyncQueryStatus as AsyncQueryStatus # noqa: F401 diff --git a/acouchbase_analytics/scope.py b/acouchbase_analytics/scope.py index 6d12ac3..6560187 100644 --- a/acouchbase_analytics/scope.py +++ b/acouchbase_analytics/scope.py @@ -17,15 +17,15 @@ from __future__ import annotations import sys -from asyncio import Future -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Awaitable if sys.version_info < (3, 10): from typing_extensions import TypeAlias else: from typing import TypeAlias -from couchbase_analytics.result import AsyncQueryResult +from acouchbase_analytics.query_handle import AsyncQueryHandle +from acouchbase_analytics.result import AsyncQueryResult if TYPE_CHECKING: from acouchbase_analytics.protocol.database import AsyncDatabase @@ -44,7 +44,7 @@ def name(self) -> str: """ return self._impl.name - def execute_query(self, statement: str, *args: object, **kwargs: object) -> Future[AsyncQueryResult]: + def execute_query(self, statement: str, *args: object, **kwargs: object) -> Awaitable[AsyncQueryResult]: """Executes a query against an Analytics scope. .. note:: @@ -59,9 +59,7 @@ def execute_query(self, statement: str, *args: object, **kwargs: object) -> Futu **kwargs (Dict[str, Any]): keyword arguments that can be used in place or to override provided :class:`~acouchbase_analytics.options.QueryOptions` Returns: - Future[:class:`~couchbase_analytics.result.AsyncQueryResult`]: A :class:`~asyncio.Future` is returned. 
- Once the :class:`~asyncio.Future` completes, an instance of a :class:`~acouchbase_analytics.result.AsyncQueryResult` - is available to provide access to iterate over the query results and access metadata and metrics about the query. + :class:`~couchbase_analytics.result.AsyncQueryResult`: An instance of a :class:`~acouchbase_analytics.result.AsyncQueryResult`. Examples: Simple query:: @@ -110,5 +108,21 @@ def execute_query(self, statement: str, *args: object, **kwargs: object) -> Futu """ # noqa: E501 return self._impl.execute_query(statement, *args, **kwargs) + def start_query(self, statement: str, *args: object, **kwargs: object) -> Awaitable[AsyncQueryHandle]: + """Executes a query against an Analytics scope in async mode. + + .. seealso:: + :meth:`acouchbase_analytics.AsyncCluster.start_query`: For how to execute cluster-level queries. + + Args: + statement: The SQL++ statement to execute. + options (:class:`~acouchbase_analytics.options.StartQueryOptions`): Optional parameters for the query operation. 
+ **kwargs (Dict[str, Any]): keyword arguments that can be used in place or to override provided :class:`~acouchbase_analytics.options.StartQueryOptions` + + Returns: + :class:`~acouchbase_analytics.query_handle.AsyncQueryHandle`: An instance of a :class:`~acouchbase_analytics.query_handle.AsyncQueryHandle` + """ # noqa: E501 + return self._impl.start_query(statement, *args, **kwargs) + Scope: TypeAlias = AsyncScope diff --git a/acouchbase_analytics/scope.pyi b/acouchbase_analytics/scope.pyi index b02fa4d..6f02cbe 100644 --- a/acouchbase_analytics/scope.pyi +++ b/acouchbase_analytics/scope.pyi @@ -21,9 +21,11 @@ if sys.version_info < (3, 11): else: from typing import Unpack +from acouchbase_analytics import JSONType +from acouchbase_analytics.options import QueryOptions, QueryOptionsKwargs, StartQueryOptions, StartQueryOptionsKwargs from acouchbase_analytics.protocol.database import AsyncDatabase as AsyncDatabase -from couchbase_analytics.options import QueryOptions, QueryOptionsKwargs -from couchbase_analytics.result import AsyncQueryResult +from acouchbase_analytics.query_handle import AsyncQueryHandle +from acouchbase_analytics.result import AsyncQueryResult class AsyncScope: def __init__(self, database: AsyncDatabase, scope_name: str) -> None: ... @@ -41,11 +43,31 @@ class AsyncScope: ) -> Awaitable[AsyncQueryResult]: ... @overload def execute_query( - self, statement: str, options: QueryOptions, *args: str, **kwargs: Unpack[QueryOptionsKwargs] + self, statement: str, options: QueryOptions, *args: JSONType, **kwargs: Unpack[QueryOptionsKwargs] ) -> Awaitable[AsyncQueryResult]: ... @overload def execute_query( - self, statement: str, options: QueryOptions, *args: str, **kwargs: str + self, statement: str, options: QueryOptions, *args: JSONType, **kwargs: str ) -> Awaitable[AsyncQueryResult]: ... @overload - def execute_query(self, statement: str, *args: str, **kwargs: str) -> Awaitable[AsyncQueryResult]: ... 
+ def execute_query(self, statement: str, *args: JSONType, **kwargs: str) -> Awaitable[AsyncQueryResult]: ... + @overload + def start_query(self, statement: str) -> Awaitable[AsyncQueryHandle]: ... + @overload + def start_query(self, statement: str, options: StartQueryOptions) -> Awaitable[AsyncQueryHandle]: ... + @overload + def start_query(self, statement: str, **kwargs: Unpack[StartQueryOptionsKwargs]) -> Awaitable[AsyncQueryHandle]: ... + @overload + def start_query( + self, statement: str, options: StartQueryOptions, **kwargs: Unpack[StartQueryOptionsKwargs] + ) -> Awaitable[AsyncQueryHandle]: ... + @overload + def start_query( + self, statement: str, options: StartQueryOptions, *args: JSONType, **kwargs: Unpack[StartQueryOptionsKwargs] + ) -> Awaitable[AsyncQueryHandle]: ... + @overload + def start_query( + self, statement: str, options: StartQueryOptions, *args: JSONType, **kwargs: str + ) -> Awaitable[AsyncQueryHandle]: ... + @overload + def start_query(self, statement: str, *args: JSONType, **kwargs: str) -> Awaitable[AsyncQueryHandle]: ... 
diff --git a/acouchbase_analytics/tests/connection_t.py b/acouchbase_analytics/tests/connection_t.py index 1567cfb..bf62ee1 100644 --- a/acouchbase_analytics/tests/connection_t.py +++ b/acouchbase_analytics/tests/connection_t.py @@ -67,7 +67,7 @@ def test_connstr_options_max_retries(self) -> None: connstr = f'https://localhost?max_retries={max_retries}' client = _AsyncClientAdapter(connstr, cred) req_builder = _RequestBuilder(client) - req = req_builder.build_base_query_request('SELECT 1=1') + req = req_builder.build_query_request('SELECT 1=1') assert req.max_retries == max_retries @pytest.mark.parametrize( @@ -100,7 +100,7 @@ def test_connstr_options_timeout(self, duration: str, expected_seconds: str) -> connstr = f'https://localhost?{to_query_str(opts)}' client = _AsyncClientAdapter(connstr, cred) req_builder = _RequestBuilder(client) - req = req_builder.build_base_query_request('SELECT 1=1') + req = req_builder.build_query_request('SELECT 1=1') expected = float(expected_seconds) returned_timeout_opts = req.get_request_timeouts() assert isinstance(returned_timeout_opts, dict) diff --git a/acouchbase_analytics/tests/options_t.py b/acouchbase_analytics/tests/options_t.py index e26401f..6b3c844 100644 --- a/acouchbase_analytics/tests/options_t.py +++ b/acouchbase_analytics/tests/options_t.py @@ -21,16 +21,16 @@ import pytest -from acouchbase_analytics.protocol._core.client_adapter import _AsyncClientAdapter -from couchbase_analytics.credential import Credential -from couchbase_analytics.deserializer import DefaultJsonDeserializer, Deserializer, PassthroughDeserializer -from couchbase_analytics.options import ( +from acouchbase_analytics.credential import Credential +from acouchbase_analytics.deserializer import DefaultJsonDeserializer, Deserializer, PassthroughDeserializer +from acouchbase_analytics.options import ( ClusterOptions, SecurityOptions, SecurityOptionsKwargs, TimeoutOptions, TimeoutOptionsKwargs, ) +from acouchbase_analytics.protocol._core.client_adapter 
import _AsyncClientAdapter from tests.utils import get_test_cert_list, get_test_cert_path, get_test_cert_str TEST_CERT_PATH = get_test_cert_path() @@ -186,12 +186,16 @@ def test_security_options_invalid_kwargs(self, opts: Dict[str, object]) -> None: @pytest.mark.parametrize( 'opts, expected_opts', [ - ({}, None), ({'connect_timeout': timedelta(seconds=30)}, {'connect_timeout': 30}), + ({'handle_request_timeout': timedelta(seconds=20)}, {'handle_request_timeout': 20}), ({'query_timeout': timedelta(seconds=30)}, {'query_timeout': 30}), ( - {'connect_timeout': timedelta(seconds=60), 'query_timeout': timedelta(seconds=30)}, - {'connect_timeout': 60, 'query_timeout': 30}, + { + 'connect_timeout': timedelta(seconds=60), + 'handle_request_timeout': timedelta(seconds=20), + 'query_timeout': timedelta(seconds=30), + }, + {'connect_timeout': 60, 'handle_request_timeout': 20, 'query_timeout': 30}, ), ], ) @@ -204,10 +208,15 @@ def test_timeout_options(self, opts: TimeoutOptionsKwargs, expected_opts: Timeou 'opts, expected_opts', [ ({'connect_timeout': timedelta(seconds=30)}, {'connect_timeout': 30}), + ({'handle_request_timeout': timedelta(seconds=20)}, {'handle_request_timeout': 20}), ({'query_timeout': timedelta(seconds=30)}, {'query_timeout': 30}), ( - {'connect_timeout': timedelta(seconds=60), 'query_timeout': timedelta(seconds=30)}, - {'connect_timeout': 60, 'query_timeout': 30}, + { + 'connect_timeout': timedelta(seconds=60), + 'handle_request_timeout': timedelta(seconds=20), + 'query_timeout': timedelta(seconds=30), + }, + {'connect_timeout': 60, 'handle_request_timeout': 20, 'query_timeout': 30}, ), ], ) @@ -217,7 +226,12 @@ def test_timeout_options_kwargs(self, opts: Dict[str, object], expected_opts: Di assert expected_opts == client.connection_details.cluster_options.get('timeout_options') @pytest.mark.parametrize( - 'opts', [{'connect_timeout': timedelta(seconds=-1)}, {'query_timeout': timedelta(seconds=-1)}] + 'opts', + [ + {'connect_timeout': 
timedelta(seconds=-1)}, + {'handle_request_timeout': timedelta(seconds=-1)}, + {'query_timeout': timedelta(seconds=-1)}, + ], ) def test_timeout_options_must_be_positive(self, opts: TimeoutOptionsKwargs) -> None: cred = Credential.from_username_and_password('Administrator', 'password') @@ -225,7 +239,12 @@ def test_timeout_options_must_be_positive(self, opts: TimeoutOptionsKwargs) -> N _AsyncClientAdapter('https://localhost', cred, ClusterOptions(timeout_options=TimeoutOptions(**opts))) @pytest.mark.parametrize( - 'opts', [{'connect_timeout': timedelta(seconds=-1)}, {'query_timeout': timedelta(seconds=-1)}] + 'opts', + [ + {'connect_timeout': timedelta(seconds=-1)}, + {'handle_request_timeout': timedelta(seconds=-1)}, + {'query_timeout': timedelta(seconds=-1)}, + ], ) def test_timeout_options_must_be_positive_kwargs(self, opts: Dict[str, object]) -> None: cred = Credential.from_username_and_password('Administrator', 'password') diff --git a/acouchbase_analytics/tests/query_options_t.py b/acouchbase_analytics/tests/query_options_t.py index dea8e14..ebc0525 100644 --- a/acouchbase_analytics/tests/query_options_t.py +++ b/acouchbase_analytics/tests/query_options_t.py @@ -76,7 +76,7 @@ def test_options_deserializer( deserializer = DefaultJsonDeserializer() q_opts = QueryOptions(deserializer=deserializer) - req = request_builder.build_base_query_request(query_statment, q_opts) + req = request_builder.build_query_request(query_statment, q_opts) exp_opts: QueryOptionsTransformedKwargs = {} assert req.options == exp_opts assert req.deserializer == deserializer @@ -89,35 +89,35 @@ def test_options_deserializer_kwargs( deserializer = DefaultJsonDeserializer() kwargs: QueryOptionsKwargs = {'deserializer': deserializer} - req = request_builder.build_base_query_request(query_statment, **kwargs) + req = request_builder.build_query_request(query_statment, **kwargs) exp_opts: QueryOptionsTransformedKwargs = {} assert req.options == exp_opts assert req.deserializer == deserializer 
query_ctx.validate_query_context(req.body) - @pytest.mark.parametrize('max_retries', [5, 10, None]) + @pytest.mark.parametrize('max_retries', [5, 10, 0, None]) def test_options_max_retries( self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext, max_retries: Optional[int] ) -> None: if max_retries is not None: q_opts = QueryOptions(max_retries=max_retries) - req = request_builder.build_base_query_request(query_statment, q_opts) + req = request_builder.build_query_request(query_statment, q_opts) else: - req = request_builder.build_base_query_request(query_statment) + req = request_builder.build_query_request(query_statment) exp_opts: QueryOptionsTransformedKwargs = {} assert req.options == exp_opts assert req.max_retries == (max_retries if max_retries is not None else 7) query_ctx.validate_query_context(req.body) - @pytest.mark.parametrize('max_retries', [5, 10, None]) + @pytest.mark.parametrize('max_retries', [5, 10, 0, None]) def test_options_max_retries_kwargs( self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext, max_retries: Optional[int] ) -> None: if max_retries is not None: kwargs: QueryOptionsKwargs = {'max_retries': max_retries} - req = request_builder.build_base_query_request(query_statment, **kwargs) + req = request_builder.build_query_request(query_statment, **kwargs) else: - req = request_builder.build_base_query_request(query_statment) + req = request_builder.build_query_request(query_statment) exp_opts: QueryOptionsTransformedKwargs = {} assert req.options == exp_opts assert req.max_retries == (max_retries if max_retries is not None else 7) @@ -128,7 +128,7 @@ def test_options_named_parameters( ) -> None: params: Dict[str, JSONType] = {'foo': 'bar', 'baz': 1, 'quz': False} q_opts = QueryOptions(named_parameters=params) - req = request_builder.build_base_query_request(query_statment, q_opts) + req = request_builder.build_query_request(query_statment, q_opts) exp_opts: 
QueryOptionsTransformedKwargs = {'named_parameters': params} assert req.options == exp_opts query_ctx.validate_query_context(req.body) @@ -138,7 +138,7 @@ def test_options_named_parameters_kwargs( ) -> None: params: Dict[str, JSONType] = {'foo': 'bar', 'baz': 1, 'quz': False} kwargs: QueryOptionsKwargs = {'named_parameters': params} - req = request_builder.build_base_query_request(query_statment, **kwargs) + req = request_builder.build_query_request(query_statment, **kwargs) exp_opts: QueryOptionsTransformedKwargs = {'named_parameters': params} assert req.options == exp_opts query_ctx.validate_query_context(req.body) @@ -148,7 +148,7 @@ def test_options_positional_parameters( ) -> None: params: List[JSONType] = ['foo', 'bar', 1, False] q_opts = QueryOptions(positional_parameters=params) - req = request_builder.build_base_query_request(query_statment, q_opts) + req = request_builder.build_query_request(query_statment, q_opts) exp_opts: QueryOptionsTransformedKwargs = {'positional_parameters': params} assert req.options == exp_opts query_ctx.validate_query_context(req.body) @@ -158,7 +158,7 @@ def test_options_positional_parameters_kwargs( ) -> None: params: List[JSONType] = ['foo', 'bar', 1, False] kwargs: QueryOptionsKwargs = {'positional_parameters': params} - req = request_builder.build_base_query_request(query_statment, **kwargs) + req = request_builder.build_query_request(query_statment, **kwargs) exp_opts: QueryOptionsTransformedKwargs = {'positional_parameters': params} assert req.options == exp_opts query_ctx.validate_query_context(req.body) @@ -167,7 +167,7 @@ def test_options_raw(self, query_statment: str, request_builder: _RequestBuilder pos_params: List[JSONType] = ['foo', 'bar', 1, False] params: Dict[str, Any] = {'readonly': True, 'positional_params': pos_params} q_opts = QueryOptions(raw=params) - req = request_builder.build_base_query_request(query_statment, q_opts) + req = request_builder.build_query_request(query_statment, q_opts) exp_opts: 
QueryOptionsTransformedKwargs = {'raw': params} assert req.options == exp_opts query_ctx.validate_query_context(req.body) @@ -178,7 +178,7 @@ def test_options_raw_kwargs( pos_params: List[JSONType] = ['foo', 'bar', 1, False] params: Dict[str, Any] = {'readonly': True, 'positional_params': pos_params} kwargs: QueryOptionsKwargs = {'raw': params} - req = request_builder.build_base_query_request(query_statment, **kwargs) + req = request_builder.build_query_request(query_statment, **kwargs) exp_opts: QueryOptionsTransformedKwargs = {'raw': params} assert req.options == exp_opts query_ctx.validate_query_context(req.body) @@ -187,7 +187,7 @@ def test_options_readonly( self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext ) -> None: q_opts = QueryOptions(readonly=True) - req = request_builder.build_base_query_request(query_statment, q_opts) + req = request_builder.build_query_request(query_statment, q_opts) exp_opts: QueryOptionsTransformedKwargs = {'readonly': True} assert req.options == exp_opts query_ctx.validate_query_context(req.body) @@ -196,7 +196,7 @@ def test_options_readonly_kwargs( self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext ) -> None: kwargs: QueryOptionsKwargs = {'readonly': True} - req = request_builder.build_base_query_request(query_statment, **kwargs) + req = request_builder.build_query_request(query_statment, **kwargs) exp_opts: QueryOptionsTransformedKwargs = {'readonly': True} assert req.options == exp_opts query_ctx.validate_query_context(req.body) @@ -207,7 +207,7 @@ def test_options_scan_consistency( from couchbase_analytics.query import QueryScanConsistency q_opts = QueryOptions(scan_consistency=QueryScanConsistency.REQUEST_PLUS) - req = request_builder.build_base_query_request(query_statment, q_opts) + req = request_builder.build_query_request(query_statment, q_opts) exp_opts: QueryOptionsTransformedKwargs = {'scan_consistency': QueryScanConsistency.REQUEST_PLUS.value} assert 
req.options == exp_opts query_ctx.validate_query_context(req.body) @@ -218,7 +218,7 @@ def test_options_scan_consistency_kwargs( from couchbase_analytics.query import QueryScanConsistency kwargs: QueryOptionsKwargs = {'scan_consistency': QueryScanConsistency.REQUEST_PLUS} - req = request_builder.build_base_query_request(query_statment, **kwargs) + req = request_builder.build_query_request(query_statment, **kwargs) exp_opts: QueryOptionsTransformedKwargs = {'scan_consistency': QueryScanConsistency.REQUEST_PLUS.value} assert req.options == exp_opts query_ctx.validate_query_context(req.body) @@ -227,7 +227,7 @@ def test_options_timeout( self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext ) -> None: q_opts = QueryOptions(timeout=timedelta(seconds=20)) - req = request_builder.build_base_query_request(query_statment, q_opts) + req = request_builder.build_query_request(query_statment, q_opts) exp_opts: QueryOptionsTransformedKwargs = {'timeout': 20.0} assert req.options == exp_opts # NOTE: we add time to the server timeout to ensure a client side timeout @@ -238,7 +238,7 @@ def test_options_timeout_kwargs( self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext ) -> None: kwargs: QueryOptionsKwargs = {'timeout': timedelta(seconds=20)} - req = request_builder.build_base_query_request(query_statment, **kwargs) + req = request_builder.build_query_request(query_statment, **kwargs) exp_opts: QueryOptionsTransformedKwargs = {'timeout': 20.0} assert req.options == exp_opts # NOTE: we add time to the server timeout to ensure a client side timeout @@ -248,14 +248,14 @@ def test_options_timeout_kwargs( def test_options_timeout_must_be_positive(self, query_statment: str, request_builder: _RequestBuilder) -> None: q_opts = QueryOptions(timeout=timedelta(seconds=-1)) with pytest.raises(ValueError): - request_builder.build_base_query_request(query_statment, q_opts) + request_builder.build_query_request(query_statment, 
q_opts) def test_options_timeout_must_be_positive_kwargs( self, query_statment: str, request_builder: _RequestBuilder ) -> None: kwargs: QueryOptionsKwargs = {'timeout': timedelta(seconds=-1)} with pytest.raises(ValueError): - request_builder.build_base_query_request(query_statment, **kwargs) + request_builder.build_query_request(query_statment, **kwargs) class ClusterQueryOptionsTests(QueryOptionsTestSuite): diff --git a/acouchbase_analytics/tests/start_query_integration_t.py b/acouchbase_analytics/tests/start_query_integration_t.py new file mode 100644 index 0000000..e388834 --- /dev/null +++ b/acouchbase_analytics/tests/start_query_integration_t.py @@ -0,0 +1,403 @@ +# Copyright 2016-2026. Couchbase, Inc. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from __future__ import annotations + +import json +from datetime import timedelta +from typing import Any, Dict + +import pytest + +from acouchbase_analytics.deserializer import PassthroughDeserializer +from acouchbase_analytics.errors import AnalyticsError, QueryError, QueryNotFoundError, TimeoutError +from acouchbase_analytics.options import FetchResultsOptions, StartQueryOptions +from acouchbase_analytics.protocol.query_handle import AsyncQueryHandle, AsyncQueryStatus +from couchbase_analytics.common.request import RequestState +from tests import AsyncYieldFixture +from tests.environments.base_environment import AsyncTestEnvironment + + +class QueryTestSuite: + TEST_MANIFEST = [ + 'test_cancel_prior_iterating', + 'test_cancel_while_iterating', + 'test_query_metadata', + 'test_query_metadata_not_available', + 'test_query_named_parameters', + 'test_query_named_parameters_no_options', + 'test_query_named_parameters_override', + 'test_query_passthrough_deserializer', + 'test_query_positional_params', + 'test_query_positional_params_no_option', + 'test_query_positional_params_override', + 'test_query_raises_exception_prior_to_iterating', + 'test_query_raw_options', + 'test_query_results', + 'test_query_status_not_found', + 'test_query_status_prior_to_results', + 'test_query_timeout', + ] + + @pytest.fixture(scope='class') + def query_statement_limit2(self, test_env: AsyncTestEnvironment) -> str: + if test_env.use_scope: + return f'SELECT * FROM {test_env.collection_name} LIMIT 2;' + else: + return f'SELECT * FROM {test_env.fqdn} LIMIT 2;' + + @pytest.fixture(scope='class') + def query_statement_pos_params_limit2(self, test_env: AsyncTestEnvironment) -> str: + if test_env.use_scope: + return f'SELECT * FROM {test_env.collection_name} WHERE country = $1 LIMIT 2;' + else: + return f'SELECT * FROM {test_env.fqdn} WHERE country = $1 LIMIT 2;' + + @pytest.fixture(scope='class') + def query_statement_named_params_limit2(self, test_env: AsyncTestEnvironment) -> str: + 
if test_env.use_scope: + return f'SELECT * FROM {test_env.collection_name} WHERE country = $country LIMIT 2;' + else: + return f'SELECT * FROM {test_env.fqdn} WHERE country = $country LIMIT 2;' + + @pytest.fixture(scope='class') + def query_statement_limit5(self, test_env: AsyncTestEnvironment) -> str: + if test_env.use_scope: + return f'SELECT * FROM {test_env.collection_name} LIMIT 5;' + else: + return f'SELECT * FROM {test_env.fqdn} LIMIT 5;' + + async def test_cancel_prior_iterating(self, test_env: AsyncTestEnvironment) -> None: + statement = 'FROM range(0, 100000) AS r SELECT *' + q_handle = await test_env.cluster_or_scope.start_query(statement) + assert isinstance(q_handle, AsyncQueryHandle) + await q_handle.cancel() + + # it takes a moment for the cancellation to propagate, so we'll retry fetching + # status a few times until we get an exception + await AsyncTestEnvironment.try_n_times_till_exception(10, 2, q_handle.fetch_status) + + with pytest.raises(QueryError): + await q_handle.fetch_status() + + await q_handle.cancel() # should be idempotent and not raise + + async def test_cancel_while_iterating(self, test_env: AsyncTestEnvironment, query_statement_limit5: str) -> None: + q_handle = await test_env.cluster_or_scope.start_query(query_statement_limit5) + result_handle, result = await test_env.wait_for_query_results(q_handle) + try: + assert result is not None + rows = [] + count = 0 + async for row in result.rows(): + if count == 2: + result.cancel() + assert row is not None + rows.append(row) + count += 1 + + assert len(rows) == count + expected_state = RequestState.Cancelled + assert result._http_response._request_context.request_state == expected_state + with pytest.raises(RuntimeError): + result.metadata() + test_env.assert_streaming_response_state(result) + finally: + await result_handle.discard_results() + + async def test_query_metadata(self, test_env: AsyncTestEnvironment, query_statement_limit5: str) -> None: + q_handle = await 
test_env.cluster_or_scope.start_query(query_statement_limit5) + result_handle, result = await test_env.wait_for_query_results(q_handle) + try: + assert result is not None + + expected_count = 5 + await test_env.assert_rows(result, expected_count) + + metadata = result.metadata() + + assert len(metadata.warnings()) == 0 + assert len(metadata.request_id()) > 0 + + metrics = metadata.metrics() + + assert metrics.result_size() > 0 + assert metrics.result_count() == expected_count + assert metrics.processed_objects() > 0 + # sometimes we have a negative elapsed time which we set to 0 + assert metrics.elapsed_time() >= timedelta(0) + assert metrics.execution_time() > timedelta(0) + test_env.assert_streaming_response_state(result) + finally: + await result_handle.discard_results() + + async def test_query_metadata_not_available( + self, test_env: AsyncTestEnvironment, query_statement_limit5: str + ) -> None: + q_handle = await test_env.cluster_or_scope.start_query(query_statement_limit5) + result_handle, result = await test_env.wait_for_query_results(q_handle) + try: + assert result is not None + + with pytest.raises(RuntimeError): + result.metadata() + + # Read one row -- NOTE: anext()/aiter() added in Python 3.10 + aiter = result.rows() + row = await aiter.__anext__() + assert row is not None + assert isinstance(row, dict) + + with pytest.raises(RuntimeError): + result.metadata() + + # Iterate the rest of the rows + rows = [r async for r in result.rows()] + assert len(rows) == 4 + + metadata = result.metadata() + assert len(metadata.warnings()) == 0 + assert len(metadata.request_id()) > 0 + test_env.assert_streaming_response_state(result) + finally: + await result_handle.discard_results() + + async def test_query_named_parameters( + self, + test_env: AsyncTestEnvironment, + query_statement_named_params_limit2: str, + ) -> None: + named_parameters: Dict[str, Any] = {'country': 'United States'} + q_handle = await test_env.cluster_or_scope.start_query( + 
query_statement_named_params_limit2, StartQueryOptions(named_parameters=named_parameters) + ) + result_handle, result = await test_env.wait_for_query_results(q_handle) + try: + assert result is not None + await test_env.assert_rows(result, 2) + test_env.assert_streaming_response_state(result) + finally: + await result_handle.discard_results() + + async def test_query_named_parameters_no_options( + self, test_env: AsyncTestEnvironment, query_statement_named_params_limit2: str + ) -> None: + q_handle = await test_env.cluster_or_scope.start_query( + query_statement_named_params_limit2, country='United States' + ) + result_handle, result = await test_env.wait_for_query_results(q_handle) + try: + assert result is not None + await test_env.assert_rows(result, 2) + test_env.assert_streaming_response_state(result) + finally: + await result_handle.discard_results() + + async def test_query_named_parameters_override( + self, test_env: AsyncTestEnvironment, query_statement_named_params_limit2: str + ) -> None: + q_handle = await test_env.cluster_or_scope.start_query( + query_statement_named_params_limit2, + StartQueryOptions(named_parameters={'country': 'abcdefg'}), + country='United States', + ) + result_handle, result = await test_env.wait_for_query_results(q_handle) + try: + assert result is not None + await test_env.assert_rows(result, 2) + test_env.assert_streaming_response_state(result) + finally: + await result_handle.discard_results() + + async def test_query_passthrough_deserializer(self, test_env: AsyncTestEnvironment) -> None: + statement = 'FROM range(0, 10) AS num SELECT *' + q_handle = await test_env.cluster_or_scope.start_query(statement) + result_handle, _ = await test_env.wait_for_query_results(q_handle, return_only_result_handle=True) + result = await result_handle.fetch_results(FetchResultsOptions(deserializer=PassthroughDeserializer())) + idx = 0 + async for row in result.rows(): + assert isinstance(row, bytes) + assert json.loads(row) == {'num': idx} + 
idx += 1 + test_env.assert_streaming_response_state(result) + await result_handle.discard_results() + + async def test_query_positional_params( + self, test_env: AsyncTestEnvironment, query_statement_pos_params_limit2: str + ) -> None: + q_handle = await test_env.cluster_or_scope.start_query( + query_statement_pos_params_limit2, StartQueryOptions(positional_parameters=['United States']) + ) + result_handle, result = await test_env.wait_for_query_results(q_handle) + try: + assert result is not None + await test_env.assert_rows(result, 2) + test_env.assert_streaming_response_state(result) + finally: + await result_handle.discard_results() + + async def test_query_positional_params_no_option( + self, test_env: AsyncTestEnvironment, query_statement_pos_params_limit2: str + ) -> None: + q_handle = await test_env.cluster_or_scope.start_query(query_statement_pos_params_limit2, 'United States') + result_handle, result = await test_env.wait_for_query_results(q_handle) + try: + assert result is not None + await test_env.assert_rows(result, 2) + test_env.assert_streaming_response_state(result) + finally: + await result_handle.discard_results() + + async def test_query_positional_params_override( + self, test_env: AsyncTestEnvironment, query_statement_pos_params_limit2: str + ) -> None: + q_handle = await test_env.cluster_or_scope.start_query( + query_statement_pos_params_limit2, + StartQueryOptions(positional_parameters=['abcdefg']), + 'United States', + ) + result_handle, result = await test_env.wait_for_query_results(q_handle) + try: + assert result is not None + await test_env.assert_rows(result, 2) + test_env.assert_streaming_response_state(result) + finally: + await result_handle.discard_results() + + async def test_query_raises_exception_prior_to_iterating(self, test_env: AsyncTestEnvironment) -> None: + statement = "I'm not N1QL!" 
+ with pytest.raises(QueryError): + await test_env.cluster_or_scope.start_query(statement) + + async def test_query_raw_options( + self, test_env: AsyncTestEnvironment, query_statement_pos_params_limit2: str + ) -> None: + if test_env.use_scope: + statement = f'SELECT * FROM {test_env.collection_name} WHERE country = $country LIMIT $1;' + else: + statement = f'SELECT * FROM {test_env.fqdn} WHERE country = $country LIMIT $1;' + + q_handle = await test_env.cluster_or_scope.start_query( + statement, StartQueryOptions(raw={'$country': 'United States', 'args': [2]}) + ) + result_handle, result = await test_env.wait_for_query_results(q_handle) + try: + assert result is not None + await test_env.assert_rows(result, 2) + finally: + await result_handle.discard_results() + + q_handle = await test_env.cluster_or_scope.start_query( + query_statement_pos_params_limit2, StartQueryOptions(raw={'args': ['United States']}) + ) + result_handle, result = await test_env.wait_for_query_results(q_handle) + try: + assert result is not None + await test_env.assert_rows(result, 2) + test_env.assert_streaming_response_state(result) + finally: + await result_handle.discard_results() + + async def test_query_results(self, test_env: AsyncTestEnvironment, query_statement_limit5: str) -> None: + q_handle = await test_env.cluster_or_scope.start_query(query_statement_limit5) + result_handle, _ = await test_env.wait_for_query_results(q_handle, return_only_result_handle=True) + result = await result_handle.fetch_results() + await test_env.assert_rows(result, 5) + # fetch results again + result = await result_handle.fetch_results() + await test_env.assert_rows(result, 5) + # now discard results + await result_handle.discard_results() + # fetching results after discarding should raise + with pytest.raises(QueryNotFoundError): + await result_handle.fetch_results() + + async def test_query_status_not_found(self, test_env: AsyncTestEnvironment) -> None: + statement = 'SELECT sleep("some value", 1000) AS 
some_field;' + q_handle = await test_env.cluster_or_scope.start_query(statement) + + result_handle, _ = await test_env.wait_for_query_results(q_handle, return_only_result_handle=True) + await result_handle.discard_results() + + with pytest.raises(QueryNotFoundError): + await q_handle.fetch_status() + + async def test_query_status_prior_to_results(self, test_env: AsyncTestEnvironment) -> None: + statement = 'SELECT sleep("some value", 1000) AS some_field;' + q_handle = await test_env.cluster_or_scope.start_query(statement) + assert isinstance(q_handle, AsyncQueryHandle) + q_status = await q_handle.fetch_status() + assert isinstance(q_status, AsyncQueryStatus) + assert q_status.results_ready() is False + with pytest.raises(AnalyticsError): + q_status.result_handle() + + # clean up + result_handle, _ = await test_env.wait_for_query_results(q_handle, return_only_result_handle=True) + await result_handle.discard_results() + + async def test_query_timeout(self, test_env: AsyncTestEnvironment) -> None: + statement = 'SELECT sleep("some value", 10000) AS some_field;' + q_handle = await test_env.cluster_or_scope.start_query( + statement, StartQueryOptions(timeout=timedelta(seconds=2)) + ) + await AsyncTestEnvironment.try_n_times_till_exception(10, 2, q_handle.fetch_status) + with pytest.raises(TimeoutError): + await q_handle.fetch_status() + + +class ClusterStartQueryTests(QueryTestSuite): + @pytest.fixture(scope='class', autouse=True) + def validate_test_manifest(self) -> None: + def valid_test_method(meth: str) -> bool: + attr = getattr(ClusterStartQueryTests, meth) + return callable(attr) and not meth.startswith('__') and meth.startswith('test') + + method_list = [meth for meth in dir(ClusterStartQueryTests) if valid_test_method(meth)] + test_list = set(QueryTestSuite.TEST_MANIFEST).symmetric_difference(method_list) + if test_list: + pytest.fail(f'Test manifest invalid. 
Missing/extra tests: {test_list}.') + + @pytest.fixture(scope='class', name='test_env') + async def couchbase_test_environment( + self, async_test_env: AsyncTestEnvironment + ) -> AsyncYieldFixture[AsyncTestEnvironment]: + await async_test_env.setup() + yield async_test_env + await async_test_env.teardown() + + +class ScopeStartQueryTests(QueryTestSuite): + @pytest.fixture(scope='class', autouse=True) + def validate_test_manifest(self) -> None: + def valid_test_method(meth: str) -> bool: + attr = getattr(ScopeStartQueryTests, meth) + return callable(attr) and not meth.startswith('__') and meth.startswith('test') + + method_list = [meth for meth in dir(ScopeStartQueryTests) if valid_test_method(meth)] + test_list = set(QueryTestSuite.TEST_MANIFEST).symmetric_difference(method_list) + if test_list: + pytest.fail(f'Test manifest invalid. Missing/extra tests: {test_list}.') + + @pytest.fixture(scope='class', name='test_env') + async def couchbase_test_environment( + self, async_test_env: AsyncTestEnvironment + ) -> AsyncYieldFixture[AsyncTestEnvironment]: + await async_test_env.setup() + test_env = async_test_env.enable_scope() + yield test_env + test_env.disable_scope() + await test_env.teardown() diff --git a/acouchbase_analytics/tests/start_query_options_t.py b/acouchbase_analytics/tests/start_query_options_t.py new file mode 100644 index 0000000..d255b89 --- /dev/null +++ b/acouchbase_analytics/tests/start_query_options_t.py @@ -0,0 +1,274 @@ +# Copyright 2016-2025. Couchbase, Inc. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +from dataclasses import dataclass +from datetime import timedelta +from typing import Any, Dict, List, Optional, Union + +import pytest + +from acouchbase_analytics import JSONType +from acouchbase_analytics.credential import Credential +from acouchbase_analytics.options import StartQueryOptions, StartQueryOptionsKwargs +from acouchbase_analytics.protocol._core.client_adapter import _AsyncClientAdapter +from couchbase_analytics.protocol._core.request import _RequestBuilder +from couchbase_analytics.protocol.options import StartQueryOptionsTransformedKwargs + + +@dataclass +class QueryContext: + database_name: Optional[str] = None + scope_name: Optional[str] = None + + def validate_query_context(self, body: Dict[str, Union[str, object]]) -> None: + if self.database_name is None or self.scope_name is None: + with pytest.raises(KeyError): + body['query_context'] + else: + assert body['query_context'] == f'default:`{self.database_name}`.`{self.scope_name}`' + + +class StartQueryOptionsTestSuite: + TEST_MANIFEST = [ + 'test_options_max_retries', + 'test_options_max_retries_kwargs', + 'test_options_named_parameters', + 'test_options_named_parameters_kwargs', + 'test_options_positional_parameters', + 'test_options_positional_parameters_kwargs', + 'test_options_raw', + 'test_options_raw_kwargs', + 'test_options_readonly', + 'test_options_readonly_kwargs', + 'test_options_scan_consistency', + 'test_options_scan_consistency_kwargs', + 'test_options_timeout', + 'test_options_timeout_kwargs', + 'test_options_timeout_must_be_positive', + 'test_options_timeout_must_be_positive_kwargs', + ] + + @pytest.fixture(scope='class') + def query_statment(self) -> str: + return 'SELECT * FROM default' + + @pytest.mark.parametrize('max_retries', [5, 10, 0, None]) + def test_options_max_retries( + self, query_statment: str, request_builder: 
_RequestBuilder, query_ctx: QueryContext, max_retries: Optional[int] + ) -> None: + if max_retries is not None: + q_opts = StartQueryOptions(max_retries=max_retries) + req = request_builder.build_start_query_request(query_statment, q_opts) + else: + req = request_builder.build_start_query_request(query_statment) + exp_opts: StartQueryOptionsTransformedKwargs = {} + assert req.options == exp_opts + assert req.max_retries == (max_retries if max_retries is not None else 7) + query_ctx.validate_query_context(req.body) + + @pytest.mark.parametrize('max_retries', [5, 10, 0, None]) + def test_options_max_retries_kwargs( + self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext, max_retries: Optional[int] + ) -> None: + if max_retries is not None: + kwargs: StartQueryOptionsKwargs = {'max_retries': max_retries} + req = request_builder.build_start_query_request(query_statment, **kwargs) + else: + req = request_builder.build_start_query_request(query_statment) + exp_opts: StartQueryOptionsTransformedKwargs = {} + assert req.options == exp_opts + assert req.max_retries == (max_retries if max_retries is not None else 7) + query_ctx.validate_query_context(req.body) + + def test_options_named_parameters( + self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext + ) -> None: + params: Dict[str, JSONType] = {'foo': 'bar', 'baz': 1, 'quz': False} + q_opts = StartQueryOptions(named_parameters=params) + req = request_builder.build_start_query_request(query_statment, q_opts) + exp_opts: StartQueryOptionsTransformedKwargs = {'named_parameters': params} + assert req.options == exp_opts + query_ctx.validate_query_context(req.body) + + def test_options_named_parameters_kwargs( + self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext + ) -> None: + params: Dict[str, JSONType] = {'foo': 'bar', 'baz': 1, 'quz': False} + kwargs: StartQueryOptionsKwargs = {'named_parameters': params} + req = 
request_builder.build_start_query_request(query_statment, **kwargs) + exp_opts: StartQueryOptionsTransformedKwargs = {'named_parameters': params} + assert req.options == exp_opts + query_ctx.validate_query_context(req.body) + + def test_options_positional_parameters( + self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext + ) -> None: + params: List[JSONType] = ['foo', 'bar', 1, False] + q_opts = StartQueryOptions(positional_parameters=params) + req = request_builder.build_start_query_request(query_statment, q_opts) + exp_opts: StartQueryOptionsTransformedKwargs = {'positional_parameters': params} + assert req.options == exp_opts + query_ctx.validate_query_context(req.body) + + def test_options_positional_parameters_kwargs( + self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext + ) -> None: + params: List[JSONType] = ['foo', 'bar', 1, False] + kwargs: StartQueryOptionsKwargs = {'positional_parameters': params} + req = request_builder.build_start_query_request(query_statment, **kwargs) + exp_opts: StartQueryOptionsTransformedKwargs = {'positional_parameters': params} + assert req.options == exp_opts + query_ctx.validate_query_context(req.body) + + def test_options_raw(self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext) -> None: + pos_params: List[JSONType] = ['foo', 'bar', 1, False] + params: Dict[str, Any] = {'readonly': True, 'positional_params': pos_params} + q_opts = StartQueryOptions(raw=params) + req = request_builder.build_start_query_request(query_statment, q_opts) + exp_opts: StartQueryOptionsTransformedKwargs = {'raw': params} + assert req.options == exp_opts + query_ctx.validate_query_context(req.body) + + def test_options_raw_kwargs( + self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext + ) -> None: + pos_params: List[JSONType] = ['foo', 'bar', 1, False] + params: Dict[str, Any] = {'readonly': True, 'positional_params': 
pos_params} + kwargs: StartQueryOptionsKwargs = {'raw': params} + req = request_builder.build_start_query_request(query_statment, **kwargs) + exp_opts: StartQueryOptionsTransformedKwargs = {'raw': params} + assert req.options == exp_opts + query_ctx.validate_query_context(req.body) + + def test_options_readonly( + self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext + ) -> None: + q_opts = StartQueryOptions(readonly=True) + req = request_builder.build_start_query_request(query_statment, q_opts) + exp_opts: StartQueryOptionsTransformedKwargs = {'readonly': True} + assert req.options == exp_opts + query_ctx.validate_query_context(req.body) + + def test_options_readonly_kwargs( + self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext + ) -> None: + kwargs: StartQueryOptionsKwargs = {'readonly': True} + req = request_builder.build_start_query_request(query_statment, **kwargs) + exp_opts: StartQueryOptionsTransformedKwargs = {'readonly': True} + assert req.options == exp_opts + query_ctx.validate_query_context(req.body) + + def test_options_scan_consistency( + self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext + ) -> None: + from couchbase_analytics.query import QueryScanConsistency + + q_opts = StartQueryOptions(scan_consistency=QueryScanConsistency.REQUEST_PLUS) + req = request_builder.build_start_query_request(query_statment, q_opts) + exp_opts: StartQueryOptionsTransformedKwargs = {'scan_consistency': QueryScanConsistency.REQUEST_PLUS.value} + assert req.options == exp_opts + query_ctx.validate_query_context(req.body) + + def test_options_scan_consistency_kwargs( + self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext + ) -> None: + from couchbase_analytics.query import QueryScanConsistency + + kwargs: StartQueryOptionsKwargs = {'scan_consistency': QueryScanConsistency.REQUEST_PLUS} + req = 
request_builder.build_start_query_request(query_statment, **kwargs) + exp_opts: StartQueryOptionsTransformedKwargs = {'scan_consistency': QueryScanConsistency.REQUEST_PLUS.value} + assert req.options == exp_opts + query_ctx.validate_query_context(req.body) + + def test_options_timeout( + self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext + ) -> None: + q_opts = StartQueryOptions(timeout=timedelta(seconds=20)) + req = request_builder.build_start_query_request(query_statment, q_opts) + exp_opts: StartQueryOptionsTransformedKwargs = {'timeout': 20.0} + assert req.options == exp_opts + # NOTE: we add time to the server timeout to ensure a client side timeout + assert req.body['timeout'] == '25000.0ms' + query_ctx.validate_query_context(req.body) + + def test_options_timeout_kwargs( + self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext + ) -> None: + kwargs: StartQueryOptionsKwargs = {'timeout': timedelta(seconds=20)} + req = request_builder.build_start_query_request(query_statment, **kwargs) + exp_opts: StartQueryOptionsTransformedKwargs = {'timeout': 20.0} + assert req.options == exp_opts + # NOTE: we add time to the server timeout to ensure a client side timeout + assert req.body['timeout'] == '25000.0ms' + query_ctx.validate_query_context(req.body) + + def test_options_timeout_must_be_positive(self, query_statment: str, request_builder: _RequestBuilder) -> None: + q_opts = StartQueryOptions(timeout=timedelta(seconds=-1)) + with pytest.raises(ValueError): + request_builder.build_start_query_request(query_statment, q_opts) + + def test_options_timeout_must_be_positive_kwargs( + self, query_statment: str, request_builder: _RequestBuilder + ) -> None: + kwargs: StartQueryOptionsKwargs = {'timeout': timedelta(seconds=-1)} + with pytest.raises(ValueError): + request_builder.build_start_query_request(query_statment, **kwargs) + + +class ClusterStartQueryOptionsTests(StartQueryOptionsTestSuite): + 
@pytest.fixture(scope='class', autouse=True) + def validate_test_manifest(self) -> None: + def valid_test_method(meth: str) -> bool: + attr = getattr(ClusterStartQueryOptionsTests, meth) + return callable(attr) and not meth.startswith('__') and meth.startswith('test') + + method_list = [meth for meth in dir(ClusterStartQueryOptionsTests) if valid_test_method(meth)] + test_list = set(StartQueryOptionsTestSuite.TEST_MANIFEST).symmetric_difference(method_list) + if test_list: + pytest.fail(f'Test manifest invalid. Missing/extra tests: {test_list}.') + + @pytest.fixture(scope='class', name='query_ctx') + def query_context(self) -> QueryContext: + return QueryContext() + + @pytest.fixture(scope='class') + def request_builder(self) -> _RequestBuilder: + cred = Credential.from_username_and_password('Administrator', 'password') + return _RequestBuilder(_AsyncClientAdapter('https://localhost', cred)) + + +class ScopeStartQueryOptionsTests(StartQueryOptionsTestSuite): + @pytest.fixture(scope='class', autouse=True) + def validate_test_manifest(self) -> None: + def valid_test_method(meth: str) -> bool: + attr = getattr(ScopeStartQueryOptionsTests, meth) + return callable(attr) and not meth.startswith('__') and meth.startswith('test') + + method_list = [meth for meth in dir(ScopeStartQueryOptionsTests) if valid_test_method(meth)] + test_list = set(StartQueryOptionsTestSuite.TEST_MANIFEST).symmetric_difference(method_list) + if test_list: + pytest.fail(f'Test manifest invalid. 
Missing/extra tests: {test_list}.') + + @pytest.fixture(scope='class', name='query_ctx') + def query_context(self) -> QueryContext: + return QueryContext('test-database', 'test-scope') + + @pytest.fixture(scope='class') + def request_builder(self) -> _RequestBuilder: + cred = Credential.from_username_and_password('Administrator', 'password') + return _RequestBuilder(_AsyncClientAdapter('https://localhost', cred), 'test-database', 'test-scope') diff --git a/acouchbase_analytics/tests/test_server_t.py b/acouchbase_analytics/tests/test_server_t.py index 912f85f..66080d1 100644 --- a/acouchbase_analytics/tests/test_server_t.py +++ b/acouchbase_analytics/tests/test_server_t.py @@ -17,7 +17,7 @@ from __future__ import annotations from datetime import timedelta -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING import pytest @@ -132,13 +132,8 @@ async def test_error_retriable_http503(self, test_env: AsyncTestEnvironment, ana statement = 'SELECT "Hello, data!" AS greeting' allowed_retries = 5 q_opts = QueryOptions(max_retries=allowed_retries, timeout=timedelta(seconds=10)) - ex: Union[pytest.ExceptionInfo[AnalyticsError], pytest.ExceptionInfo[QueryError]] - if analytics_error: - with pytest.raises(QueryError) as ex: - await test_env.cluster_or_scope.execute_query(statement, q_opts) - else: - with pytest.raises(AnalyticsError) as ex: - await test_env.cluster_or_scope.execute_query(statement, q_opts) + with pytest.raises(AnalyticsError) as ex: + await test_env.cluster_or_scope.execute_query(statement, q_opts) test_env.assert_error_context_num_attempts(allowed_retries + 1, ex.value._context) test_env.assert_error_context_contains_last_dispatch(ex.value._context) diff --git a/conftest.py b/conftest.py index 8d72d1b..42a235f 100644 --- a/conftest.py +++ b/conftest.py @@ -38,6 +38,8 @@ 'couchbase_analytics/tests/options_t.py::ClusterOptionsTests', 'couchbase_analytics/tests/query_options_t.py::ClusterQueryOptionsTests', 
'couchbase_analytics/tests/query_options_t.py::ScopeQueryOptionsTests', + 'couchbase_analytics/tests/start_query_options_t.py::ClusterStartQueryOptionsTests', + 'couchbase_analytics/tests/start_query_options_t.py::ScopeStartQueryOptionsTests', 'couchbase_analytics/tests/test_server_t.py::ClusterTestServerTests', 'couchbase_analytics/tests/test_server_t.py::ScopeTestServerTests', ] @@ -46,9 +48,13 @@ 'acouchbase_analytics/tests/connect_integration_t.py::ConnectTests', 'acouchbase_analytics/tests/query_integration_t.py::ClusterQueryTests', 'acouchbase_analytics/tests/query_integration_t.py::ScopeQueryTests', + 'acouchbase_analytics/tests/start_query_integration_t.py::ClusterStartQueryTests', + 'acouchbase_analytics/tests/start_query_integration_t.py::ScopeStartQueryTests', 'couchbase_analytics/tests/connect_integration_t.py::ConnectTests', 'couchbase_analytics/tests/query_integration_t.py::ClusterQueryTests', 'couchbase_analytics/tests/query_integration_t.py::ScopeQueryTests', + 'couchbase_analytics/tests/start_query_integration_t.py::ClusterStartQueryTests', + 'couchbase_analytics/tests/start_query_integration_t.py::ScopeStartQueryTests', ] diff --git a/couchbase_analytics/cluster.py b/couchbase_analytics/cluster.py index 8ca0784..9c133de 100644 --- a/couchbase_analytics/cluster.py +++ b/couchbase_analytics/cluster.py @@ -20,6 +20,7 @@ from typing import TYPE_CHECKING, Optional, Union from couchbase_analytics.database import Database +from couchbase_analytics.query_handle import BlockingQueryHandle from couchbase_analytics.result import BlockingQueryResult if TYPE_CHECKING: @@ -139,6 +140,22 @@ def execute_query( """ # noqa: E501 return self._impl.execute_query(statement, *args, **kwargs) + def start_query(self, statement: str, *args: object, **kwargs: object) -> BlockingQueryHandle: + """Executes a query against an Analytics cluster in async mode. + + .. seealso:: + :meth:`couchbase_analytics.Scope.start_query`: For how to execute scope-level queries. 
+ + Args: + statement: The SQL++ statement to execute. + options (:class:`~couchbase_analytics.options.StartQueryOptions`): Optional parameters for the query operation. + **kwargs (Dict[str, Any]): keyword arguments that can be used in place or to override provided :class:`~couchbase_analytics.options.StartQueryOptions` + + Returns: + :class:`~couchbase_analytics.query_handle.BlockingQueryHandle`: An instance of a :class:`~couchbase_analytics.query_handle.BlockingQueryHandle` + """ # noqa: E501 + return self._impl.start_query(statement, *args, **kwargs) + def shutdown(self) -> None: """Shuts down this cluster instance. Cleaning up all resources associated with it. diff --git a/couchbase_analytics/cluster.pyi b/couchbase_analytics/cluster.pyi index 38f44ff..0b32759 100644 --- a/couchbase_analytics/cluster.pyi +++ b/couchbase_analytics/cluster.pyi @@ -25,7 +25,15 @@ else: from couchbase_analytics import JSONType from couchbase_analytics.credential import Credential from couchbase_analytics.database import Database -from couchbase_analytics.options import ClusterOptions, ClusterOptionsKwargs, QueryOptions, QueryOptionsKwargs +from couchbase_analytics.options import ( + ClusterOptions, + ClusterOptionsKwargs, + QueryOptions, + QueryOptionsKwargs, + StartQueryOptions, + StartQueryOptionsKwargs, +) +from couchbase_analytics.query_handle import BlockingQueryHandle from couchbase_analytics.result import BlockingQueryResult class Cluster: @@ -114,6 +122,26 @@ class Cluster: def execute_query( self, statement: str, *args: JSONType, enable_cancel: bool, **kwargs: str ) -> Future[BlockingQueryResult]: ... + @overload + def start_query(self, statement: str) -> BlockingQueryHandle: ... + @overload + def start_query(self, statement: str, options: StartQueryOptions) -> BlockingQueryHandle: ... + @overload + def start_query(self, statement: str, **kwargs: Unpack[StartQueryOptionsKwargs]) -> BlockingQueryHandle: ... 
+ @overload + def start_query( + self, statement: str, options: StartQueryOptions, **kwargs: Unpack[StartQueryOptionsKwargs] + ) -> BlockingQueryHandle: ... + @overload + def start_query( + self, statement: str, options: StartQueryOptions, *args: JSONType, **kwargs: Unpack[StartQueryOptionsKwargs] + ) -> BlockingQueryHandle: ... + @overload + def start_query( + self, statement: str, options: StartQueryOptions, *args: JSONType, **kwargs: str + ) -> BlockingQueryHandle: ... + @overload + def start_query(self, statement: str, *args: JSONType, **kwargs: str) -> BlockingQueryHandle: ... def shutdown(self) -> None: ... @overload @classmethod diff --git a/couchbase_analytics/common/_core/error_context.py b/couchbase_analytics/common/_core/error_context.py index 1356bc0..b43c1c6 100644 --- a/couchbase_analytics/common/_core/error_context.py +++ b/couchbase_analytics/common/_core/error_context.py @@ -17,11 +17,11 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from httpx import Response as HttpCoreResponse -from couchbase_analytics.protocol._core.request import QueryRequest +from couchbase_analytics.protocol._core.request import FetchResultsRequest, HttpRequest, QueryRequest @dataclass @@ -42,6 +42,9 @@ def set_errors(self, errors: List[Dict[str, Any]]) -> None: def set_first_error(self, error: Dict[str, Any]) -> None: self.first_error = error + def set_statement(self, statement: Optional[str]) -> None: + self.statement = statement + def maybe_update_errors(self) -> None: if self.errors is not None and len(self.errors) > 0: return @@ -51,8 +54,10 @@ def maybe_update_errors(self) -> None: def update_num_attempts(self) -> None: self.num_attempts += 1 - def update_request_context(self, request: QueryRequest) -> None: - self.path = request.url.path + def update_request_context( + self, request: Union[HttpRequest, FetchResultsRequest, QueryRequest], path: 
Optional[str] = None + ) -> None: + self.path = path or request.url.path def update_response_context(self, response: HttpCoreResponse) -> None: network_stream = response.extensions.get('network_stream', None) diff --git a/couchbase_analytics/common/_core/query.py b/couchbase_analytics/common/_core/query.py index 93c18d7..617a7fa 100644 --- a/couchbase_analytics/common/_core/query.py +++ b/couchbase_analytics/common/_core/query.py @@ -17,9 +17,10 @@ from __future__ import annotations import json -from typing import Any, List, Optional, TypedDict +from typing import Any, Callable, List, Optional, TypedDict from couchbase_analytics.common._core.duration_str_utils import parse_duration_str +from couchbase_analytics.common.logging import LogLevel class QueryMetricsCore(TypedDict, total=False): @@ -59,7 +60,25 @@ class QueryMetadataCore(TypedDict, total=False): status: Optional[str] -def build_query_metadata(json_data: Optional[Any] = None, raw_metadata: Optional[bytes] = None) -> QueryMetadataCore: +def _parse_duration_metric(metrics: Any, field: str, log_fn: Optional[Callable[[str, LogLevel], None]] = None) -> float: + raw = metrics.get(field, '0') + try: + return parse_duration_str(raw, in_millis=True) + except ValueError: + if log_fn is not None: + log_fn( + f'Could not parse metrics field "{field}"; received value="{raw}". Defaulting to 0.', + LogLevel.WARNING, + ) + return 0.0 + + +def build_query_metadata( + json_data: Optional[Any] = None, + raw_metadata: Optional[bytes] = None, + request_id: Optional[str] = None, + log_fn: Optional[Callable[[str, LogLevel], None]] = None, +) -> QueryMetadataCore: """ Builds the query metadata from the raw bytes. 
@@ -83,7 +102,7 @@ def build_query_metadata(json_data: Optional[Any] = None, raw_metadata: Optional warnings.append({'code': warning.get('code', 0), 'message': warning.get('msg', '')}) metadata: QueryMetadataCore = { - 'request_id': json_data.get('requestID', ''), + 'request_id': json_data.get('requestID', request_id or ''), 'client_context_id': json_data.get('clientContextID', ''), 'warnings': warnings, } @@ -96,10 +115,10 @@ def build_query_metadata(json_data: Optional[Any] = None, raw_metadata: Optional return metadata metrics: QueryMetricsCore = { - 'elapsed_time': parse_duration_str(json_data['metrics'].get('elapsedTime', '0'), in_millis=True), - 'execution_time': parse_duration_str(json_data['metrics'].get('executionTime', '0'), in_millis=True), - 'compile_time': parse_duration_str(json_data['metrics'].get('compileTime', '0'), in_millis=True), - 'queue_wait_time': parse_duration_str(json_data['metrics'].get('queueWaitTime', '0'), in_millis=True), + 'elapsed_time': _parse_duration_metric(json_data['metrics'], 'elapsedTime', log_fn), + 'execution_time': _parse_duration_metric(json_data['metrics'], 'executionTime', log_fn), + 'compile_time': _parse_duration_metric(json_data['metrics'], 'compileTime', log_fn), + 'queue_wait_time': _parse_duration_metric(json_data['metrics'], 'queueWaitTime', log_fn), 'result_count': json_data['metrics'].get('resultCount', 0), 'result_size': json_data['metrics'].get('resultSize', 0), 'processed_objects': json_data['metrics'].get('processedObjects', 0), diff --git a/couchbase_analytics/common/_core/query_handle.py b/couchbase_analytics/common/_core/query_handle.py new file mode 100644 index 0000000..09b8e2c --- /dev/null +++ b/couchbase_analytics/common/_core/query_handle.py @@ -0,0 +1,111 @@ +# Copyright 2016-2025. Couchbase, Inc. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Awaitable, List, Mapping, Optional, TypedDict, Union + +from couchbase_analytics.common._core.result import QueryResult + + +class QueryHandle(ABC): + @abstractmethod + def cancel(self) -> Union[Awaitable[None], None]: + """ + Cancel the query associated with the QueryHandle. + """ + raise NotImplementedError + + @abstractmethod + def fetch_status(self) -> Union[Awaitable[QueryStatus], QueryStatus]: + raise NotImplementedError + + +class QueryResultHandle(ABC): + """Abstract base class for query result handle.""" + + @abstractmethod + def fetch_results(self) -> Union[Awaitable[QueryResult], QueryResult]: + """ + Get all the results. + """ + raise NotImplementedError + + @abstractmethod + def discard_results(self) -> Union[Awaitable[None], None]: + """ + Discard the query results associated with the QueryResultHandle. 
+ """ + raise NotImplementedError + + +class QueryStatus(ABC): + @abstractmethod + def results_ready(self) -> bool: + raise NotImplementedError + + @abstractmethod + def result_handle(self) -> QueryResultHandle: + raise NotImplementedError + + +class ResultPartition(TypedDict): + handle: str + result_count: Optional[int] + + +@dataclass +class QueryHandleStatusResponse: + """**INTERNAL**""" + + request_id: str + status: str + handle: Optional[str] = None + result_count: Optional[int] = None + partitions: Optional[List[ResultPartition]] = None + result_set_ordered: Optional[bool] = None + metrics: Optional[Mapping[str, Union[str, int]]] = None + created_at: Optional[str] = None + + def get_details(self) -> Mapping[str, Any]: + """**INTERNAL**""" + return { + 'request_id': self.request_id, + 'status': self.status, + 'handle': self.handle, + 'metrics': self.metrics, + } + + @classmethod + def from_server(cls, request_id: str, raw_json: Any) -> QueryHandleStatusResponse: + raw_partitions = raw_json.get('partitions', []) + partitions: list[ResultPartition] = [] + for partition in raw_partitions: + partitions.append( + {'handle': partition.get('handle', None), 'result_count': partition.get('resultCount', None)} + ) + return cls( + request_id, + raw_json.get('status', None), + raw_json.get('handle', None), + result_count=raw_json.get('resultCount', None), + partitions=partitions, + result_set_ordered=raw_json.get('resultSetOrdered', None), + metrics=raw_json.get('metrics', None), + created_at=raw_json.get('createdAt', None), + ) diff --git a/couchbase_analytics/common/_core/result.py b/couchbase_analytics/common/_core/result.py index 0146122..0dbfef5 100644 --- a/couchbase_analytics/common/_core/result.py +++ b/couchbase_analytics/common/_core/result.py @@ -18,7 +18,7 @@ import sys from abc import ABC, abstractmethod -from typing import Any, Coroutine, List, Optional, Union +from typing import Any, Awaitable, List, Optional, Union if sys.version_info < (3, 9): from typing 
import AsyncIterator as PyAsyncIterator @@ -34,7 +34,7 @@ class QueryResult(ABC): """Abstract base class for query results.""" @abstractmethod - def cancel(self) -> Union[Coroutine[Any, Any, None], None]: + def cancel(self) -> Union[Awaitable[None], None]: """ Cancel streaming the query results. @@ -43,7 +43,7 @@ def cancel(self) -> Union[Coroutine[Any, Any, None], None]: raise NotImplementedError @abstractmethod - def get_all_rows(self) -> Union[Coroutine[Any, Any, List[Any]], List[Any]]: + def get_all_rows(self) -> Union[Awaitable[List[Any]], List[Any]]: """Convenience method to load all query results into memory.""" raise NotImplementedError diff --git a/couchbase_analytics/common/errors.py b/couchbase_analytics/common/errors.py index 24a697d..96e3278 100644 --- a/couchbase_analytics/common/errors.py +++ b/couchbase_analytics/common/errors.py @@ -122,6 +122,29 @@ def __str__(self) -> str: return self.__repr__() +class QueryNotFoundError(AnalyticsError): + """ + Indicates that a request returned a not found status code. + """ + + def __init__( + self, + cause: Optional[Union[BaseException, Exception]] = None, + context: Optional[str] = None, + message: Optional[str] = None, + ) -> None: + super().__init__(cause=cause, context=context, message=message) + + def __repr__(self) -> str: + details = self._err_details() + if details: + return f'{type(self).__name__}({details})' + return f'{type(self).__name__}()' + + def __str__(self) -> str: + return self.__repr__() + + class TimeoutError(AnalyticsError): """ Indicates that a request was unable to complete prior to reaching the deadline specified for the reqest. 
diff --git a/couchbase_analytics/common/logging.py b/couchbase_analytics/common/logging.py index d599174..a2c0918 100644 --- a/couchbase_analytics/common/logging.py +++ b/couchbase_analytics/common/logging.py @@ -16,6 +16,7 @@ import logging from enum import Enum +from typing import Optional LOG_FORMAT_ARR = [ '[%(asctime)s.%(msecs)03d]', @@ -36,10 +37,27 @@ class LogLevel(Enum): CRITICAL = logging.CRITICAL +def _has_open_handlers(logger: logging.Logger) -> bool: + current: Optional[logging.Logger] = logger + while current is not None: + for handler in current.handlers: + if isinstance(handler, logging.StreamHandler): + if hasattr(handler.stream, 'closed') and handler.stream.closed: + continue + return True + if not current.propagate: + break + current = current.parent + return False + + def log_message(logger: logging.Logger, message: str, log_level: LogLevel) -> None: if not logger or not logger.hasHandlers(): return + if not _has_open_handlers(logger): + return + if log_level == LogLevel.DEBUG: logger.debug(message) elif log_level == LogLevel.INFO: diff --git a/couchbase_analytics/common/options.py b/couchbase_analytics/common/options.py index 5a680cc..5e7c897 100644 --- a/couchbase_analytics/common/options.py +++ b/couchbase_analytics/common/options.py @@ -26,13 +26,17 @@ from couchbase_analytics.common.options_base import ( ClusterOptionsBase, + FetchResultsOptionsBase, QueryOptionsBase, SecurityOptionsBase, + StartQueryOptionsBase, TimeoutOptionsBase, ) from couchbase_analytics.common.options_base import ClusterOptionsKwargs as ClusterOptionsKwargs # noqa: F401 +from couchbase_analytics.common.options_base import FetchResultsOptionsKwargs as FetchResultsOptionsKwargs # noqa: F401 from couchbase_analytics.common.options_base import QueryOptionsKwargs as QueryOptionsKwargs # noqa: F401 from couchbase_analytics.common.options_base import SecurityOptionsKwargs as SecurityOptionsKwargs # noqa: F401 +from couchbase_analytics.common.options_base import 
StartQueryOptionsKwargs as StartQueryOptionsKwargs # noqa: F401 from couchbase_analytics.common.options_base import TimeoutOptionsKwargs as TimeoutOptionsKwargs # noqa: F401 """ @@ -57,6 +61,14 @@ class ClusterOptions(ClusterOptionsBase): """ # noqa: E501 +class FetchResultsOptions(FetchResultsOptionsBase): + """Available options for Analytics asynchronous server query fetch results operation. + + Args: + deserializer (Optional[Deserializer]): Specifies a :class:`~couchbase_analytics.deserializer.Deserializer` to apply to results. Defaults to `None` (:class:`~couchbase_analytics.deserializer.DefaultJsonDeserializer`). + """ # noqa: E501 + + class SecurityOptions(SecurityOptionsBase): """Available security options to set when creating a cluster. @@ -134,6 +146,7 @@ class TimeoutOptions(TimeoutOptionsBase): Args: connect_timeout (Optional[timedelta]): Set to configure the period of time allowed to make a connection. Defaults to `None` (10s). + handle_request_timeout (Optional[timedelta]): Set to configure the period of time allowed for HTTP requests when using the server asynchronous requests API. Defaults to `None` (10s). query_timeout (Optional[timedelta]): Set to configure the period of time allowed for query operations. Defaults to `None` (10m). """ # noqa: E501 @@ -149,10 +162,32 @@ class QueryOptions(QueryOptionsBase): Args: client_context_id (Optional[str]): Set to configure a unique identifier for this query request. Defaults to `None` (autogenerated by client). deserializer (Optional[Deserializer]): Specifies a :class:`~couchbase_analytics.deserializer.Deserializer` to apply to results. Defaults to `None` (:class:`~couchbase_analytics.deserializer.DefaultJsonDeserializer`). - lazy_execute (Optional[bool]): **VOLATILE** If enabled, the query will not execute until the application begins to iterate over results. Defaulst to `None` (disabled). 
+ lazy_execute (Optional[bool]): **VOLATILE** If enabled, the query will not execute until the application begins to iterate over results. Defaults to `None` (disabled). + max_retries (Optional[int]): **VOLATILE** Set to configure the maximum number of retries for a request. + named_parameters (Optional[Dict[str, :py:type:`~couchbase_analytics.JSONType`]]): Values to use for named placeholders in query. + positional_parameters (Optional[List[:py:type:`~couchbase_analytics.JSONType`]]): Values to use for positional placeholders in query. + query_context (Optional[str]): Specifies the context within which this query should be executed. + raw (Optional[Dict[str, Any]]): Specifies any additional parameters which should be passed to the Analytics engine when executing the query. + readonly (Optional[bool]): Specifies that this query should be executed in read-only mode, disabling the ability for the query to make any changes to the data. + scan_consistency (Optional[QueryScanConsistency]): Specifies the consistency requirements when executing the query. + timeout (Optional[timedelta]): Set to configure allowed time for operation to complete. Defaults to `None` (75s). + stream_config (Optional[JsonStreamConfig]): **VOLATILE** Configuration for JSON stream processing. Defaults to `None` (default configuration). See :class:`~couchbase_analytics.common.json_parsing.JsonStreamConfig` for details. + """ # noqa: E501 + + +class StartQueryOptions(StartQueryOptionsBase): + """Available options for Analytics asynchronous server query operation. + + Timeout will default to cluster setting if not set for the operation. + + .. note:: + Options marked **VOLATILE** are subject to change at any time. + + Args: + client_context_id (Optional[str]): Set to configure a unique identifier for this query request. Defaults to `None` (autogenerated by client). max_retries (Optional[int]): **VOLATILE** Set to configure the maximum number of retries for a request. 
- named_parameters (Optional[Dict[str, :py:type:`~couchbase_analytics.JSONType`]]): Values to use for positional placeholders in query. - positional_parameters (Optional[List[:py:type:`~couchbase_analytics.JSONType`]]):, optional): Values to use for named placeholders in query. + named_parameters (Optional[Dict[str, :py:type:`~couchbase_analytics.JSONType`]]): Values to use for named placeholders in query. + positional_parameters (Optional[List[:py:type:`~couchbase_analytics.JSONType`]]): Values to use for positional placeholders in query. query_context (Optional[str]): Specifies the context within which this query should be executed. raw (Optional[Dict[str, Any]]): Specifies any additional parameters which should be passed to the Analytics engine when executing the query. readonly (Optional[bool]): Specifies that this query should be executed in read-only mode, disabling the ability for the query to make any changes to the data. @@ -164,7 +199,9 @@ class QueryOptions(QueryOptionsBase): OptionsClass: TypeAlias = Union[ ClusterOptions, + FetchResultsOptions, SecurityOptions, TimeoutOptions, QueryOptions, + StartQueryOptions, ] diff --git a/couchbase_analytics/common/options_base.py b/couchbase_analytics/common/options_base.py index 1fd4811..5d73041 100644 --- a/couchbase_analytics/common/options_base.py +++ b/couchbase_analytics/common/options_base.py @@ -109,11 +109,13 @@ def __init__(self, **kwargs: Unpack[SecurityOptionsKwargs]) -> None: class TimeoutOptionsKwargs(TypedDict, total=False): connect_timeout: Optional[timedelta] + handle_request_timeout: Optional[timedelta] query_timeout: Optional[timedelta] TimeoutOptionsValidKeys: TypeAlias = Literal[ 'connect_timeout', + 'handle_request_timeout', 'query_timeout', ] @@ -125,6 +127,7 @@ class TimeoutOptionsBase(Dict[str, object]): VALID_OPTION_KEYS: List[TimeoutOptionsValidKeys] = [ 'connect_timeout', + 'handle_request_timeout', 'query_timeout', ] @@ -183,3 +186,66 @@ class QueryOptionsBase(Dict[str, object]): def 
__init__(self, **kwargs: Unpack[QueryOptionsKwargs]) -> None: filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None} super().__init__(**filtered_kwargs) + + +class StartQueryOptionsKwargs(TypedDict, total=False): + client_context_id: Optional[str] + max_retries: Optional[int] + named_parameters: Optional[Dict[str, JSONType]] + positional_parameters: Optional[Iterable[JSONType]] + query_context: Optional[str] + raw: Optional[Dict[str, Any]] + readonly: Optional[bool] + scan_consistency: Optional[Union[QueryScanConsistency, str]] + stream_config: Optional[JsonStreamConfig] + timeout: Optional[timedelta] + + +StartQueryOptionsValidKeys: TypeAlias = Literal[ + 'client_context_id', + 'max_retries', + 'named_parameters', + 'positional_parameters', + 'query_context', + 'raw', + 'readonly', + 'scan_consistency', + 'stream_config', + 'timeout', +] + + +class StartQueryOptionsBase(Dict[str, object]): + VALID_OPTION_KEYS: List[StartQueryOptionsValidKeys] = [ + 'client_context_id', + 'max_retries', + 'named_parameters', + 'positional_parameters', + 'query_context', + 'raw', + 'readonly', + 'scan_consistency', + 'stream_config', + 'timeout', + ] + + def __init__(self, **kwargs: Unpack[StartQueryOptionsKwargs]) -> None: + filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None} + super().__init__(**filtered_kwargs) + + +class FetchResultsOptionsKwargs(TypedDict, total=False): + deserializer: Optional[Deserializer] + + +FetchResultsOptionsValidKeys: TypeAlias = Literal['deserializer',] + + +class FetchResultsOptionsBase(Dict[str, object]): + VALID_OPTION_KEYS: List[FetchResultsOptionsValidKeys] = [ + 'deserializer', + ] + + def __init__(self, **kwargs: Unpack[FetchResultsOptionsKwargs]) -> None: + filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None} + super().__init__(**filtered_kwargs) diff --git a/couchbase_analytics/common/query_handle.py b/couchbase_analytics/common/query_handle.py new file mode 100644 index 0000000..8c8bed7 --- 
/dev/null +++ b/couchbase_analytics/common/query_handle.py @@ -0,0 +1,108 @@ +# Copyright 2016-2025. Couchbase, Inc. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +from abc import abstractmethod +from typing import Any, Awaitable, Optional + +from couchbase_analytics.common._core.query_handle import QueryHandle, QueryResultHandle, QueryStatus +from couchbase_analytics.common.options import FetchResultsOptions +from couchbase_analytics.common.result import AsyncQueryResult, BlockingQueryResult + + +class AsyncQueryHandle(QueryHandle): + @abstractmethod + def cancel(self, options: Optional[Any] = None, **kwargs: Any) -> Awaitable[None]: + """ + Cancel the query associated with the QueryHandle. + """ + raise NotImplementedError + + @abstractmethod + def fetch_status(self, options: Optional[Any] = None, **kwargs: Any) -> Awaitable[AsyncQueryStatus]: + raise NotImplementedError + + +class BlockingQueryHandle(QueryHandle): + @abstractmethod + def cancel(self, options: Optional[Any] = None, **kwargs: Any) -> None: + """ + Cancel the query associated with the QueryHandle. 
+ """ + raise NotImplementedError + + @abstractmethod + def fetch_status(self, options: Optional[Any] = None, **kwargs: Any) -> BlockingQueryStatus: + raise NotImplementedError + + +class AsyncQueryResultHandle(QueryResultHandle): + """Abstract base class for async query result handle.""" + + @abstractmethod + def fetch_results( + self, options: Optional[FetchResultsOptions] = None, **kwargs: Any + ) -> Awaitable[AsyncQueryResult]: + """ + Get all the results. + """ + raise NotImplementedError + + @abstractmethod + def discard_results(self, options: Optional[Any] = None, **kwargs: Any) -> Awaitable[None]: + """ + Discard the query results associated with the AsyncQueryResultHandle. + """ + raise NotImplementedError + + +class BlockingQueryResultHandle(QueryResultHandle): + """Abstract base class for query result handle.""" + + @abstractmethod + def fetch_results(self, options: Optional[FetchResultsOptions] = None, **kwargs: Any) -> BlockingQueryResult: + """ + Get all the results. + """ + raise NotImplementedError + + @abstractmethod + def discard_results(self, options: Optional[Any] = None, **kwargs: Any) -> None: + """ + Discard the query results associated with the BlockingQueryResultHandle. 
+ """ + raise NotImplementedError + + +class AsyncQueryStatus(QueryStatus): + @abstractmethod + def results_ready(self) -> bool: + raise NotImplementedError + + @abstractmethod + def result_handle(self) -> AsyncQueryResultHandle: + raise NotImplementedError + + +class BlockingQueryStatus(QueryStatus): + @abstractmethod + def results_ready(self) -> bool: + raise NotImplementedError + + @abstractmethod + def result_handle(self) -> BlockingQueryResultHandle: + raise NotImplementedError diff --git a/couchbase_analytics/errors.py b/couchbase_analytics/errors.py index 03a6439..5d40e7a 100644 --- a/couchbase_analytics/errors.py +++ b/couchbase_analytics/errors.py @@ -18,4 +18,5 @@ from couchbase_analytics.common.errors import InternalSDKError as InternalSDKError # noqa: F401 from couchbase_analytics.common.errors import InvalidCredentialError as InvalidCredentialError # noqa: F401 from couchbase_analytics.common.errors import QueryError as QueryError # noqa: F401 +from couchbase_analytics.common.errors import QueryNotFoundError as QueryNotFoundError # noqa: F401 from couchbase_analytics.common.errors import TimeoutError as TimeoutError # noqa: F401 diff --git a/couchbase_analytics/options.py b/couchbase_analytics/options.py index bc8f846..47c432b 100644 --- a/couchbase_analytics/options.py +++ b/couchbase_analytics/options.py @@ -16,9 +16,13 @@ from couchbase_analytics.common.options import ClusterOptions as ClusterOptions # noqa: F401 from couchbase_analytics.common.options import ClusterOptionsKwargs as ClusterOptionsKwargs # noqa: F401 +from couchbase_analytics.common.options import FetchResultsOptions as FetchResultsOptions # noqa: F401 +from couchbase_analytics.common.options import FetchResultsOptionsKwargs as FetchResultsOptionsKwargs # noqa: F401 from couchbase_analytics.common.options import QueryOptions as QueryOptions # noqa: F401 from couchbase_analytics.common.options import QueryOptionsKwargs as QueryOptionsKwargs # noqa: F401 from 
couchbase_analytics.common.options import SecurityOptions as SecurityOptions # noqa: F401 from couchbase_analytics.common.options import SecurityOptionsKwargs as SecurityOptionsKwargs # noqa: F401 +from couchbase_analytics.common.options import StartQueryOptions as StartQueryOptions # noqa: F401 +from couchbase_analytics.common.options import StartQueryOptionsKwargs as StartQueryOptionsKwargs # noqa: F401 from couchbase_analytics.common.options import TimeoutOptions as TimeoutOptions # noqa: F401 from couchbase_analytics.common.options import TimeoutOptionsKwargs as TimeoutOptionsKwargs # noqa: F401 diff --git a/couchbase_analytics/protocol/_core/client_adapter.py b/couchbase_analytics/protocol/_core/client_adapter.py index e620ac8..04fc078 100644 --- a/couchbase_analytics/protocol/_core/client_adapter.py +++ b/couchbase_analytics/protocol/_core/client_adapter.py @@ -17,7 +17,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Optional, cast +from typing import Optional, cast from uuid import uuid4 from httpx import URL, BasicAuth, Client, Response @@ -25,12 +25,10 @@ from couchbase_analytics.common.credential import Credential from couchbase_analytics.common.deserializer import Deserializer from couchbase_analytics.common.logging import LogLevel, log_message +from couchbase_analytics.protocol._core.request import CancelRequest, HttpRequest, QueryRequest, StartQueryRequest from couchbase_analytics.protocol.connection import _ConnectionDetails from couchbase_analytics.protocol.options import OptionsBuilder -if TYPE_CHECKING: - from couchbase_analytics.protocol._core.request import QueryRequest - class _ClientAdapter: """ @@ -162,7 +160,7 @@ def create_client(self) -> None: def log_message(self, message: str, log_level: LogLevel) -> None: log_message(logger, f'{self.log_prefix} {message}', log_level) - def send_request(self, request: QueryRequest) -> Response: + def send_request(self, request: HttpRequest, stream: Optional[bool] 
= True) -> Response: """ **INTERNAL** """ @@ -170,8 +168,17 @@ def send_request(self, request: QueryRequest) -> Response: raise RuntimeError('Client not created yet') url = URL(scheme=request.url.scheme, host=request.url.ip, port=request.url.port, path=request.url.path) - req = self._client.build_request(request.method, url, json=request.body, extensions=request.extensions) - return self._client.send(req, stream=True) + if isinstance(request, (QueryRequest, StartQueryRequest)): + req = self._client.build_request(request.method, url, json=request.body, extensions=request.extensions) + else: + data = request.data if isinstance(request, CancelRequest) else None + req = self._client.build_request( + request.method, url, data=data, headers=request.headers, extensions=request.extensions + ) + + if stream is None: + stream = True + return self._client.send(req, stream=stream) def reset_client(self) -> None: """ diff --git a/couchbase_analytics/protocol/_core/json_stream.py b/couchbase_analytics/protocol/_core/json_stream.py index f1bc6d5..0fc975a 100644 --- a/couchbase_analytics/protocol/_core/json_stream.py +++ b/couchbase_analytics/protocol/_core/json_stream.py @@ -30,7 +30,7 @@ from couchbase_analytics.protocol._core.json_token_parser import JsonTokenParser if TYPE_CHECKING: - from couchbase_analytics.protocol._core.request_context import RequestContext + from couchbase_analytics.protocol._core.request_context import StreamingRequestContext class JsonStream: @@ -80,7 +80,7 @@ def token_stream_exhausted(self) -> bool: """ return self._token_stream_exhausted - def _continue_processing(self, request_context: Optional[RequestContext] = None) -> bool: + def _continue_processing(self, request_context: Optional[StreamingRequestContext] = None) -> bool: """ **INTERNAL** """ @@ -125,7 +125,7 @@ def _log_message(self, message: str, level: LogLevel) -> None: if self._log_handler is not None: self._log_handler(message, level) - def _process_token_stream(self, request_context: 
Optional[RequestContext] = None) -> None: + def _process_token_stream(self, request_context: Optional[StreamingRequestContext] = None) -> None: """ **INTERNAL** """ @@ -207,7 +207,7 @@ def get_result(self, timeout: float) -> Optional[ParsedResult]: def start_parsing( self, - request_context: Optional[RequestContext] = None, + request_context: Optional[StreamingRequestContext] = None, notify_on_results_or_error: Optional[Future[ParsedResultType]] = None, ) -> None: if self._json_stream_parser is not None: @@ -218,6 +218,6 @@ def start_parsing( def continue_parsing( self, - request_context: Optional[RequestContext] = None, + request_context: Optional[StreamingRequestContext] = None, ) -> None: self._process_token_stream(request_context=request_context) diff --git a/couchbase_analytics/protocol/_core/request.py b/couchbase_analytics/protocol/_core/request.py index c2cc5f1..9b590a8 100644 --- a/couchbase_analytics/protocol/_core/request.py +++ b/couchbase_analytics/protocol/_core/request.py @@ -18,13 +18,27 @@ from copy import deepcopy from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, Coroutine, Dict, Optional, TypedDict, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Coroutine, + Dict, + List, + Mapping, + Optional, + TypedDict, + Union, + cast, + overload, +) +from urllib.parse import urlparse from uuid import uuid4 from couchbase_analytics.common.deserializer import Deserializer -from couchbase_analytics.common.options import QueryOptions +from couchbase_analytics.common.options import FetchResultsOptions, QueryOptions, StartQueryOptions from couchbase_analytics.common.request import RequestURL -from couchbase_analytics.protocol.options import QueryOptionsTransformedKwargs +from couchbase_analytics.protocol.options import QueryOptionsTransformedKwargs, StartQueryOptionsTransformedKwargs from couchbase_analytics.query import QueryScanConsistency if TYPE_CHECKING: @@ -46,20 +60,17 @@ class 
RequestExtensions(TypedDict, total=False): @dataclass -class QueryRequest: +class HttpRequest: url: RequestURL - deserializer: Deserializer - body: Dict[str, Union[str, object]] extensions: RequestExtensions + path: str + method: str + headers: Mapping[str, str] max_retries: int - method: str = 'POST' - - options: Optional[QueryOptionsTransformedKwargs] = None - enable_cancel: Optional[bool] = None def add_trace_to_extensions( self, handler: Callable[[str, str], Union[None, Coroutine[Any, Any, None]]] - ) -> QueryRequest: + ) -> HttpRequest: """ **INTERNAL** """ @@ -68,14 +79,6 @@ def add_trace_to_extensions( self.extensions['trace'] = handler return self - def get_request_statement(self) -> Optional[str]: - """ - **INTERNAL** - """ - if 'statement' in self.body: - return cast(str, self.body['statement']) - return None - def get_request_timeouts(self) -> Optional[RequestTimeoutExtensions]: """ **INTERNAL** @@ -84,7 +87,7 @@ def get_request_timeouts(self) -> Optional[RequestTimeoutExtensions]: return {} return self.extensions['timeout'] - def update_url(self, ip: str, path: str) -> QueryRequest: + def update_url(self, ip: str, path: str) -> HttpRequest: """ **INTERNAL** """ @@ -93,6 +96,53 @@ def update_url(self, ip: str, path: str) -> QueryRequest: return self +class CancelRequestData(TypedDict): + request_id: str + + +@dataclass +class CancelRequest(HttpRequest): + data: CancelRequestData + + +@dataclass +class FetchResultsRequest(HttpRequest): + deserializer: Deserializer + should_stream: bool = True + + +@dataclass +class QueryRequest(HttpRequest): + deserializer: Deserializer + body: Dict[str, Union[str, object]] + options: Optional[QueryOptionsTransformedKwargs] = None + enable_cancel: Optional[bool] = None + should_stream: bool = True + + def get_request_statement(self) -> Optional[str]: + """ + **INTERNAL** + """ + if 'statement' in self.body: + return cast(str, self.body['statement']) + return None + + +@dataclass +class StartQueryRequest(HttpRequest): + 
body: Dict[str, Union[str, object]] + options: Optional[StartQueryOptionsTransformedKwargs] = None + should_stream: bool = False + + def get_request_statement(self) -> Optional[str]: + """ + **INTERNAL** + """ + if 'statement' in self.body: + return cast(str, self.body['statement']) + return None + + class _RequestBuilder: def __init__( self, @@ -106,6 +156,7 @@ def __init__( self._scope_name = scope_name connect_timeout = self._conn_details.get_connect_timeout() + self._handle_request_timeout = self._conn_details.get_handle_request_timeout() self._default_query_timeout = self._conn_details.get_query_timeout() self._extensions: RequestExtensions = { 'timeout': {'pool': connect_timeout, 'connect': connect_timeout, 'read': self._default_query_timeout} @@ -113,13 +164,55 @@ def __init__( if self._conn_details.is_secure() and self._conn_details.sni_hostname is not None: self._extensions['sni_hostname'] = self._conn_details.sni_hostname - def build_base_query_request( # noqa: C901 + def build_request_from_handle(self, handle: str, method: Optional[str] = None) -> HttpRequest: + method = method or 'GET' + extensions = deepcopy(self._extensions) + extensions['timeout']['read'] = self._handle_request_timeout + max_retries = self._conn_details.get_max_retries() + parsed = urlparse(handle) + path = parsed.path if parsed.scheme else handle + return HttpRequest(self._conn_details.url, extensions, path, method=method, headers={}, max_retries=max_retries) + + def build_cancel_request(self, request_id: str) -> CancelRequest: + extensions = deepcopy(self._extensions) + extensions['timeout']['read'] = self._handle_request_timeout + max_retries = self._conn_details.get_max_retries() + return CancelRequest( + self._conn_details.url, + extensions, + '/api/v1/active_requests', + 'DELETE', + {'Content-Type': 'application/x-www-form-urlencoded'}, + max_retries, + {'request_id': request_id}, + ) + + def build_discard_results_request(self, handle: str) -> HttpRequest: + return 
self.build_request_from_handle(handle, method='DELETE') + + def build_fetch_results_request( + self, handle: str, options: Optional[FetchResultsOptions] = None, **kwargs: object + ) -> FetchResultsRequest: + q_opts = self._opts_builder.build_options(FetchResultsOptions, kwargs, options) + base_request = self.build_request_from_handle(handle) + deserializer = q_opts.pop('deserializer', None) or self._conn_details.default_deserializer + max_retries = self._conn_details.get_max_retries() + return FetchResultsRequest( + base_request.url, + base_request.extensions, + base_request.path, + base_request.method, + {}, + max_retries, + deserializer, + ) + + def build_query_request( self, statement: str, *args: object, - is_async: Optional[bool] = False, **kwargs: object, - ) -> QueryRequest: # noqa: C901 + ) -> QueryRequest: enable_cancel: Optional[bool] = None cancel_kwarg_token = kwargs.pop('enable_cancel', None) if isinstance(cancel_kwarg_token, bool): @@ -138,21 +231,104 @@ def build_base_query_request( # noqa: C901 else: parsed_args_list.append(arg) + extensions, body, q_opts = self._get_query_request_details( + QueryOptions, opts, statement, parsed_args_list=parsed_args_list, **kwargs + ) + + # handle deserializer and max_retries + deserializer = q_opts.pop('deserializer', None) or self._conn_details.default_deserializer + retries = q_opts.pop('max_retries', None) + max_retries = retries if retries is not None else self._conn_details.get_max_retries() + + return QueryRequest( + self._conn_details.url, + extensions, + '', + 'POST', + {}, + max_retries, + deserializer, + body, + options=q_opts, + enable_cancel=enable_cancel, + ) + + def build_start_query_request( # noqa: C901 + self, + statement: str, + *args: object, + **kwargs: object, + ) -> StartQueryRequest: # noqa: C901 + # default if no options provided + opts = StartQueryOptions() + args_list = list(args) + parsed_args_list = [] + for arg in args_list: + if isinstance(arg, StartQueryOptions): + # we have options 
passed in + opts = arg + else: + parsed_args_list.append(arg) + + extensions, body, q_opts = self._get_query_request_details( + StartQueryOptions, opts, statement, parsed_args_list=parsed_args_list, **kwargs + ) + + body['mode'] = 'async' + retries = q_opts.pop('max_retries', None) + max_retries = retries if retries is not None else self._conn_details.get_max_retries() + + return StartQueryRequest( + self._conn_details.url, + extensions, + '', + 'POST', + {}, + max_retries, + body, + options=q_opts, + ) + + @overload + def _get_query_request_details( + self, + option_type: type[QueryOptions], + query_opts: QueryOptions, + statement: str, + parsed_args_list: Optional[List[object]] = None, + **kwargs: object, + ) -> tuple[RequestExtensions, Dict[str, Union[str, object]], QueryOptionsTransformedKwargs]: ... + + @overload + def _get_query_request_details( + self, + option_type: type[StartQueryOptions], + query_opts: StartQueryOptions, + statement: str, + parsed_args_list: Optional[List[object]] = None, + **kwargs: object, + ) -> tuple[RequestExtensions, Dict[str, Union[str, object]], StartQueryOptionsTransformedKwargs]: ... 
+ + def _get_query_request_details( # noqa: C901 + self, + option_type: Union[type[QueryOptions], type[StartQueryOptions]], + query_opts: Union[QueryOptions, StartQueryOptions], + statement: str, + parsed_args_list: Optional[List[object]] = None, + **kwargs: object, + ) -> Any: # noqa: C901 # need to pop out named params prior to sending options to the builder - named_param_keys = list(filter(lambda k: k not in QueryOptions.VALID_OPTION_KEYS, kwargs.keys())) + named_param_keys = list(filter(lambda k: k not in option_type.VALID_OPTION_KEYS, kwargs.keys())) named_params = {} for key in named_param_keys: named_params[key] = kwargs.pop(key) - q_opts = self._opts_builder.build_options(QueryOptions, QueryOptionsTransformedKwargs, kwargs, opts) + q_opts = self._opts_builder.build_options(option_type, kwargs, query_opts) # positional params and named params passed in outside of QueryOptions serve as overrides if parsed_args_list and len(parsed_args_list) > 0: q_opts['positional_parameters'] = parsed_args_list if named_params and len(named_params) > 0: q_opts['named_parameters'] = named_params - # handle deserializer and max_retries - deserializer = q_opts.pop('deserializer', None) or self._conn_details.default_deserializer - max_retries = q_opts.pop('max_retries', None) or self._conn_details.get_max_retries() body: Dict[str, Union[str, object]] = { 'statement': statement, @@ -165,8 +341,11 @@ def build_base_query_request( # noqa: C901 # handle timeouts timeout = q_opts.get('timeout', None) or self._default_query_timeout extensions = deepcopy(self._extensions) - if timeout is not None and timeout != self._default_query_timeout: - extensions['timeout']['read'] = timeout + if option_type == QueryOptions: + if timeout is not None and timeout != self._default_query_timeout: + extensions['timeout']['read'] = timeout + else: + extensions['timeout']['read'] = self._handle_request_timeout # we add 5 seconds to the server timeout to ensure we always trigger a client side timeout 
timeout_ms = (timeout + 5) * 1e3 # convert to milliseconds body['timeout'] = f'{timeout_ms}ms' @@ -191,12 +370,4 @@ def build_base_query_request( # noqa: C901 else: body['scan_consistency'] = opt_val - return QueryRequest( - self._conn_details.url, - deserializer, - body, - extensions=extensions, - max_retries=max_retries, - options=q_opts, - enable_cancel=enable_cancel, - ) + return extensions, body, q_opts diff --git a/couchbase_analytics/protocol/_core/request_context.py b/couchbase_analytics/protocol/_core/request_context.py index feab717..239f7b5 100644 --- a/couchbase_analytics/protocol/_core/request_context.py +++ b/couchbase_analytics/protocol/_core/request_context.py @@ -21,7 +21,7 @@ import time from concurrent.futures import CancelledError, Future, ThreadPoolExecutor from threading import Event -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union, cast from uuid import uuid4 from httpx import Response as HttpCoreResponse @@ -35,12 +35,12 @@ from couchbase_analytics.common.result import BlockingQueryResult from couchbase_analytics.protocol._core.json_stream import JsonStream from couchbase_analytics.protocol._core.net_utils import get_request_ip +from couchbase_analytics.protocol._core.request import FetchResultsRequest, HttpRequest, QueryRequest, StartQueryRequest from couchbase_analytics.protocol.connection import DEFAULT_TIMEOUTS from couchbase_analytics.protocol.errors import ErrorMapper, WrappedError if TYPE_CHECKING: from couchbase_analytics.protocol._core.client_adapter import _ClientAdapter - from couchbase_analytics.protocol._core.request import QueryRequest class BackgroundRequest: @@ -89,84 +89,210 @@ class RequestContext: def __init__( self, client_adapter: _ClientAdapter, - request: QueryRequest, - tp_executor: ThreadPoolExecutor, - stream_config: Optional[JsonStreamConfig] = None, + request: HttpRequest, + supports_cancellation: 
Optional[bool] = None, ) -> None: self._id = str(uuid4()) self._client_adapter = client_adapter self._request = request self._backoff_calc = DefaultBackoffCalculator() - self._error_ctx = ErrorContext(num_attempts=0, method=request.method, statement=request.get_request_statement()) + self._error_context = ErrorContext(num_attempts=0, method=request.method) + if isinstance(request, (QueryRequest, StartQueryRequest)): + self._error_context.set_statement(request.get_request_statement()) + self._supports_cancellation = False if supports_cancellation is None else supports_cancellation self._request_state = RequestState.NotStarted - self._stream_config = stream_config or JsonStreamConfig() - self._json_stream: JsonStream - self._cancel_event = Event() - self._tp_executor = tp_executor - self._stage_completed_ft: Optional[Future[Any]] = None - self._stage_notification_ft: Optional[Future[ParsedResultType]] = None + self._cancel_event: Optional[Event] = None self._request_deadline = math.inf self._background_request: Optional[BackgroundRequest] = None self._shutdown = False - - @property - def cancel_enabled(self) -> Optional[bool]: - return self._request.enable_cancel + if self._supports_cancellation: + self._cancel_event = Event() @property def cancelled(self) -> bool: + if not self._supports_cancellation: + return False self._check_cancelled_or_timed_out() return self._request_state in [RequestState.Cancelled, RequestState.SyncCancelledPriorToTimeout] @property def error_context(self) -> ErrorContext: - return self._error_ctx - - @property - def has_stage_completed(self) -> bool: - return self._stage_completed_ft is not None and self._stage_completed_ft.done() + return self._error_context @property def is_shutdown(self) -> bool: return self._shutdown - @property - def okay_to_iterate(self) -> bool: - # NOTE: Called prior to upstream logic attempting to iterate over results from HTTP client - self._check_cancelled_or_timed_out() - return 
RequestState.okay_to_iterate(self._request_state) - - @property - def okay_to_stream(self) -> bool: - # NOTE: Called prior to upstream logic attempting to send request to HTTP client - self._check_cancelled_or_timed_out() - return RequestState.okay_to_stream(self._request_state) - @property def request_state(self) -> RequestState: return self._request_state @property def retry_limit_exceeded(self) -> bool: - return self.error_context.num_attempts > self._request.max_retries + return self._error_context.num_attempts > self._request.max_retries @property def timed_out(self) -> bool: self._check_cancelled_or_timed_out() return self._request_state == RequestState.Timeout + def calculate_backoff(self) -> float: + return self._backoff_calc.calculate_backoff(self._error_context.num_attempts) / 1000 + + def initialize(self) -> None: + if self._request_state == RequestState.ResetAndNotStarted: + self.log_message( + 'Request is a retry, skipping initialization', + LogLevel.DEBUG, + message_data={'request_deadline': f'{self._request_deadline}'}, + ) + return + self._request_state = RequestState.Started + timeouts = self._request.get_request_timeouts() or {} + current_time = time.monotonic() + self._request_deadline = current_time + (timeouts.get('read', None) or DEFAULT_TIMEOUTS['query_timeout']) + message_data = {'current_time': f'{current_time}', 'request_deadline': f'{self._request_deadline}'} + self.log_message('Request context initialized', LogLevel.DEBUG, message_data=message_data) + + def log_message( + self, + message: str, + log_level: LogLevel, + message_data: Optional[Dict[str, str]] = None, + append_ctx: Optional[bool] = True, + ) -> None: + if append_ctx is True: + message = f'{message}: ctx={self._id}' + if message_data is not None: + message_data_str = ', '.join(f'{k}={v}' for k, v in message_data.items()) + message = f'{message}, {message_data_str}' + self._client_adapter.log_message(message, log_level) + + def okay_to_delay_and_retry(self, delay: float) -> 
bool: + # calling self.timed_out will call _check_cancelled_or_timed_out, so we don't need to call it again + if self.timed_out: + return False + elif self._supports_cancellation and self._request_state == RequestState.Cancelled: + return False + + current_time = time.monotonic() + delay_time = current_time + delay + will_time_out = self._request_deadline < delay_time + if will_time_out: + self._request_state = RequestState.Timeout + message_data = { + 'current_time': f'{current_time}', + 'delay_time': f'{delay_time}', + 'request_deadline': f'{self._request_deadline}', + } + self.log_message('Request will timeout after delay', LogLevel.DEBUG, message_data=message_data) + return False + elif self.retry_limit_exceeded: + self._request_state = RequestState.Error + message_data = { + 'num_attempts': f'{self.error_context.num_attempts}', + 'max_retries': f'{self._request.max_retries}', + } + self.log_message('Request has exceeded max retries', LogLevel.DEBUG, message_data=message_data) + return False + elif self._supports_cancellation: + # _reset_stream() _should_ exist, but surround w/ try/except just in case + try: + self._reset_stream() # type: ignore[attr-defined] + except AttributeError: + pass # nosec + + return True + + def process_response( + self, + core_response: HttpCoreResponse, + close_handler: Callable[[], None], + handle_context_shutdown: Optional[bool] = False, + ) -> Any: + # we have all the data, close the core response/stream + close_handler() + try: + json_response = core_response.json() + except json.JSONDecodeError: + self._process_error(core_response.text, handle_context_shutdown=handle_context_shutdown) + else: + if 'errors' in json_response: + self._process_error(json_response['errors'], handle_context_shutdown=handle_context_shutdown) + return json_response + + def send_request( + self, enable_trace_handling: Optional[bool] = False, ignore_not_found_status: Optional[bool] = False + ) -> HttpCoreResponse: + 
self._error_context.update_num_attempts() + ip = get_request_ip(self._request.url.host, self._request.url.port, self.log_message) + + if self._request.path and not self._request.path.isspace(): + req_path = f'{self._request.path}' + else: + req_path = self._client_adapter.analytics_path + + if enable_trace_handling is True and hasattr(self, '_trace_handler'): + self._request.update_url(ip, req_path).add_trace_to_extensions(self._trace_handler) + else: + self._request.update_url(ip, req_path) + + self._error_context.update_request_context(self._request, path=req_path) + message_data = { + 'url': f'{self._request.url.get_formatted_url()}', + 'request_deadline': f'{self._request_deadline}', + } + + if isinstance(self._request, (QueryRequest, StartQueryRequest)): + message_data['body'] = f'{self._request.body}' + + stream = hasattr(self._request, 'should_stream') and self._request.should_stream is True + message_data['streaming'] = str(stream) + self.log_message('HTTP request', LogLevel.DEBUG, message_data=message_data) + response = self._client_adapter.send_request(self._request, stream=stream) + self._error_context.update_response_context(response) + message_data = { + 'status_code': f'{response.status_code}', + 'last_dispatched_to': f'{self._error_context.last_dispatched_to}', + 'last_dispatched_from': f'{self._error_context.last_dispatched_from}', + 'request_deadline': f'{self._request_deadline}', + } + self.log_message('HTTP response', LogLevel.DEBUG, message_data=message_data) + self._check_for_http_status_error(response.status_code, ignore_not_found_status=ignore_not_found_status) + return response + + def shutdown(self, exc_val: Optional[BaseException] = None) -> None: + if self.is_shutdown: + self.log_message('Request context already shutdown', LogLevel.WARNING) + return + if self._supports_cancellation and isinstance(exc_val, CancelledError): + self._request_state = RequestState.Cancelled + elif exc_val is not None: + # calling self.timed_out will call 
_check_cancelled_or_timed_out, so we don't need to call it again + is_timed_out = self.timed_out + is_cancelled = self._supports_cancellation and self._request_state in ( + RequestState.Cancelled, + RequestState.SyncCancelledPriorToTimeout, + ) + if not is_timed_out and not is_cancelled: + self._request_state = RequestState.Error + + if RequestState.is_okay(self._request_state): + self._request_state = RequestState.Completed + self._shutdown = True + self.log_message('Request context shutdown complete', LogLevel.INFO) + def _check_cancelled_or_timed_out(self) -> None: - if self._request_state in [RequestState.Timeout, RequestState.Cancelled, RequestState.Error]: + if self._request_state in (RequestState.Timeout, RequestState.Error): return - if self._cancel_event.is_set() or ( - self._background_request is not None and self._background_request.user_cancelled - ): + if self._supports_cancellation and self._request_state == RequestState.Cancelled: + return + + if self._supports_cancellation and self._cancel_event and self._cancel_event.is_set(): self._request_state = RequestState.Cancelled if self._cancel_event.is_set(): self.log_message('Request has been cancelled', LogLevel.DEBUG) - elif self._background_request is not None and self._background_request.user_cancelled: - self.log_message('Request has been cancelled via user background request', LogLevel.DEBUG) return current_time = time.monotonic() @@ -174,16 +300,16 @@ def _check_cancelled_or_timed_out(self) -> None: if timed_out: message_data = {'current_time': f'{current_time}', 'request_deadline': f'{self._request_deadline}'} self.log_message('Request has timed out', LogLevel.DEBUG, message_data=message_data) - if self._request_state == RequestState.Cancelled: + if self._supports_cancellation and self._request_state == RequestState.Cancelled: self._request_state = RequestState.SyncCancelledPriorToTimeout else: self._request_state = RequestState.Timeout - def _create_stage_notification_future(self) -> None: - # 
TODO(PYCO-75): custom ThreadPoolExecutor, to get a "plain" future - if self._stage_notification_ft is not None: - raise RuntimeError('Stage notification future already created for this context.') - self._stage_notification_ft = Future[ParsedResultType]() + def _check_for_http_status_error(self, status_code: int, ignore_not_found_status: Optional[bool] = False) -> None: + ctx = str(self._error_context) + ErrorMapper.maybe_raise_error_from_status_code( + status_code, ctx, ignore_not_found_status=ignore_not_found_status + ) def _process_error( self, json_data: Union[str, List[Dict[str, Any]]], handle_context_shutdown: Optional[bool] = False @@ -191,57 +317,55 @@ def _process_error( self._request_state = RequestState.Error request_error: Union[AnalyticsError, WrappedError] if isinstance(json_data, str): - request_error = ErrorMapper.build_error_from_http_status_code(json_data, self._error_ctx) + request_error = ErrorMapper.build_error_from_http_status_code(json_data, self._error_context) elif not isinstance(json_data, list): request_error = AnalyticsError( - message='Cannot parse error response; expected JSON array', context=str(self._error_ctx) + message='Cannot parse error response; expected JSON array', context=str(self._error_context) ) else: - request_error = ErrorMapper.build_error_from_json(json_data, self._error_ctx) + request_error = ErrorMapper.build_error_from_json(json_data, self._error_context) if handle_context_shutdown is True: self.shutdown() raise request_error - def _reset_stream(self) -> None: - if hasattr(self, '_json_stream'): - del self._json_stream - self._request_state = RequestState.ResetAndNotStarted - self._stage_notification_ft = None - self.log_message('Request state has been reset', LogLevel.DEBUG) - def _start_next_stage( +class StreamingRequestContext(RequestContext): + def __init__( self, - fn: Callable[..., Any], - *args: object, - create_notification: Optional[bool] = False, - reset_previous_stage: Optional[bool] = False, + 
client_adapter: _ClientAdapter, + request: Union[FetchResultsRequest, QueryRequest], + tp_executor: ThreadPoolExecutor, + stream_config: Optional[JsonStreamConfig] = None, ) -> None: - if reset_previous_stage is True: - if self._stage_completed_ft is not None: - self._stage_completed_ft = None - elif self._stage_completed_ft is not None and not self._stage_completed_ft.done(): - raise RuntimeError('Future already running in this context.') - - kwargs: Dict[str, Union[RequestContext, Future[ParsedResultType]]] = {'request_context': self} - if create_notification is True: - self._create_stage_notification_future() - if self._stage_notification_ft is None: - raise RuntimeError('Unable to create stage notification future.') - kwargs['notify_on_results_or_error'] = self._stage_notification_ft + super().__init__(client_adapter, request, supports_cancellation=True) + self._stream_config = stream_config or JsonStreamConfig() + self._json_stream: JsonStream + self._tp_executor = tp_executor + self._stage_completed_ft: Optional[Future[Any]] = None + self._stage_notification_ft: Optional[Future[ParsedResultType]] = None + self._deserializer = request.deserializer - self._stage_completed_ft = self._tp_executor.submit(fn, *args, **kwargs) + @property + def cancel_enabled(self) -> Optional[bool]: + if not isinstance(self._request, QueryRequest): + return None + return self._request.enable_cancel - def _trace_handler(self, event_name: str, _: str) -> None: - if event_name == 'connection.connect_tcp.complete': - pass + @property + def has_stage_completed(self) -> bool: + return self._stage_completed_ft is not None and self._stage_completed_ft.done() - def _wait_for_stage_completed(self) -> None: - if self._stage_completed_ft is None: - raise RuntimeError('Stage completed future not created for this context.') - self._stage_completed_ft.result() + @property + def okay_to_iterate(self) -> bool: + # NOTE: Called prior to upstream logic attempting to iterate over results from HTTP 
client + self._check_cancelled_or_timed_out() + return RequestState.okay_to_iterate(self._request_state) - def calculate_backoff(self) -> float: - return self._backoff_calc.calculate_backoff(self._error_ctx.num_attempts) / 1000 + @property + def okay_to_stream(self) -> bool: + # NOTE: Called prior to upstream logic attempting to send request to HTTP client + self._check_cancelled_or_timed_out() + return RequestState.okay_to_stream(self._request_state) def cancel_request(self) -> None: if self._request_state == RequestState.Timeout: @@ -249,7 +373,9 @@ def cancel_request(self) -> None: self._request_state = RequestState.Cancelled def deserialize_result(self, result: bytes) -> Any: - return self._request.deserializer.deserialize(result) + if not self._deserializer: + raise RuntimeError('No deserializer found for this request context.') + return self._deserializer.deserialize(result) def finish_processing_stream(self) -> None: if not self.has_stage_completed: @@ -264,35 +390,6 @@ def finish_processing_stream(self) -> None: def get_result_from_stream(self) -> Optional[ParsedResult]: return self._json_stream.get_result(self._stream_config.queue_timeout) - def initialize(self) -> None: - if self._request_state == RequestState.ResetAndNotStarted: - self.log_message( - 'Request is a retry, skipping initialization', - LogLevel.DEBUG, - message_data={'request_deadline': f'{self._request_deadline}'}, - ) - return - self._request_state = RequestState.Started - timeouts = self._request.get_request_timeouts() or {} - current_time = time.monotonic() - self._request_deadline = current_time + (timeouts.get('read', None) or DEFAULT_TIMEOUTS['query_timeout']) - message_data = {'current_time': f'{current_time}', 'request_deadline': f'{self._request_deadline}'} - self.log_message('Request context initialized', LogLevel.DEBUG, message_data=message_data) - - def log_message( - self, - message: str, - log_level: LogLevel, - message_data: Optional[Dict[str, str]] = None, - append_ctx: 
Optional[bool] = True, - ) -> None: - if append_ctx is True: - message = f'{message}: ctx={self._id}' - if message_data is not None: - message_data_str = ', '.join(f'{k}={v}' for k, v in message_data.items()) - message = f'{message}, {message_data_str}' - self._client_adapter.log_message(message, log_level) - def maybe_continue_to_process_stream(self) -> None: if not self.has_stage_completed: return @@ -306,36 +403,7 @@ def maybe_continue_to_process_stream(self) -> None: # NOTE: _start_next_stage injects the request context into args self._start_next_stage(self._json_stream.continue_parsing, reset_previous_stage=True) - def okay_to_delay_and_retry(self, delay: float) -> bool: - self._check_cancelled_or_timed_out() - if self._request_state in [RequestState.Timeout, RequestState.Cancelled]: - return False - - current_time = time.monotonic() - delay_time = current_time + delay - will_time_out = self._request_deadline < delay_time - if will_time_out: - self._request_state = RequestState.Timeout - message_data = { - 'current_time': f'{current_time}', - 'delay_time': f'{delay_time}', - 'request_deadline': f'{self._request_deadline}', - } - self.log_message('Request will timeout after delay', LogLevel.DEBUG, message_data=message_data) - return False - elif self.retry_limit_exceeded: - self._request_state = RequestState.Error - message_data = { - 'num_attempts': f'{self.error_context.num_attempts}', - 'max_retries': f'{self._request.max_retries}', - } - self.log_message('Request has exceeded max retries', LogLevel.DEBUG, message_data=message_data) - return False - else: - self._reset_stream() - return True - - def process_response( + def process_streaming_response( self, close_handler: Callable[[], None], raw_response: Optional[ParsedResult] = None, @@ -346,13 +414,13 @@ def process_response( if raw_response is None: close_handler() raise AnalyticsError( - message='Received unexpected empty result from JsonStream.', context=str(self._error_ctx) + message='Received 
unexpected empty result from JsonStream.', context=str(self._error_context) ) if raw_response.value is None: close_handler() raise AnalyticsError( - message='Received unexpected empty response value from JsonStream.', context=str(self._error_ctx) + message='Received unexpected empty response value from JsonStream.', context=str(self._error_context) ) # we have all the data, close the core response/stream @@ -366,35 +434,6 @@ def process_response( self._process_error(json_response['errors'], handle_context_shutdown=handle_context_shutdown) return json_response - def send_request(self, enable_trace_handling: Optional[bool] = False) -> HttpCoreResponse: - self._error_ctx.update_num_attempts() - ip = get_request_ip(self._request.url.host, self._request.url.port, self.log_message) - if enable_trace_handling is True: - ( - self._request.update_url(ip, self._client_adapter.analytics_path).add_trace_to_extensions( - self._trace_handler - ) - ) - else: - self._request.update_url(ip, self._client_adapter.analytics_path) - self._error_ctx.update_request_context(self._request) - message_data = { - 'url': f'{self._request.url.get_formatted_url()}', - 'body': f'{self._request.body}', - 'request_deadline': f'{self._request_deadline}', - } - self.log_message('HTTP request', LogLevel.DEBUG, message_data=message_data) - response = self._client_adapter.send_request(self._request) - self._error_ctx.update_response_context(response) - message_data = { - 'status_code': f'{response.status_code}', - 'last_dispatched_to': f'{self._error_ctx.last_dispatched_to}', - 'last_dispatched_from': f'{self._error_ctx.last_dispatched_from}', - 'request_deadline': f'{self._request_deadline}', - } - self.log_message('HTTP response', LogLevel.DEBUG, message_data=message_data) - return response - def send_request_in_background( self, fn: Callable[..., BlockingQueryResult], @@ -405,32 +444,12 @@ def send_request_in_background( # TODO(PYCO-75): custom ThreadPoolExecutor, to get a "plain" future user_ft = 
Future[BlockingQueryResult]() background_work_ft = self._tp_executor.submit(fn, *args) - self._background_request = BackgroundRequest(background_work_ft, user_ft, self._cancel_event) + self._background_request = BackgroundRequest(background_work_ft, user_ft, cast(Event, self._cancel_event)) return user_ft def set_state_to_streaming(self) -> None: self._request_state = RequestState.StreamingResults - def shutdown(self, exc_val: Optional[BaseException] = None) -> None: - if self.is_shutdown: - self.log_message('Request context already shutdown', LogLevel.WARNING) - return - if isinstance(exc_val, CancelledError): - self._request_state = RequestState.Cancelled - elif exc_val is not None: - self._check_cancelled_or_timed_out() - if self._request_state not in [ - RequestState.Timeout, - RequestState.Cancelled, - RequestState.SyncCancelledPriorToTimeout, - ]: - self._request_state = RequestState.Error - - if RequestState.is_okay(self._request_state): - self._request_state = RequestState.Completed - self._shutdown = True - self.log_message('Request context shutdown complete', LogLevel.INFO) - def start_stream(self, core_response: HttpCoreResponse) -> None: if hasattr(self, '_json_stream'): self.log_message('JSON stream already exists', LogLevel.WARNING) @@ -447,7 +466,9 @@ def wait_for_stage_notification(self) -> None: raise RuntimeError('Stage notification future not created for this context.') deadline = round(self._request_deadline - time.monotonic(), 6) # round to microseconds if deadline <= 0: - raise TimeoutError(message='Request timed out waiting for stage notification', context=str(self._error_ctx)) + raise TimeoutError( + message='Request timed out waiting for stage notification', context=str(self._error_context) + ) result_type = self._stage_notification_ft.result(timeout=deadline) if result_type == ParsedResultType.ROW: self.log_message('Received row, setting status to streaming', LogLevel.DEBUG) @@ -455,3 +476,47 @@ def wait_for_stage_notification(self) -> 
None: self._request_state = RequestState.StreamingResults else: self.log_message(f'Received result type {result_type.name}', LogLevel.DEBUG) + + def _create_stage_notification_future(self) -> None: + # TODO(PYCO-75): custom ThreadPoolExecutor, to get a "plain" future + if self._stage_notification_ft is not None: + raise RuntimeError('Stage notification future already created for this context.') + self._stage_notification_ft = Future[ParsedResultType]() + + def _reset_stream(self) -> None: + if hasattr(self, '_json_stream'): + del self._json_stream + self._request_state = RequestState.ResetAndNotStarted + self._stage_notification_ft = None + self.log_message('Request state has been reset', LogLevel.DEBUG) + + def _start_next_stage( + self, + fn: Callable[..., Any], + *args: object, + create_notification: Optional[bool] = False, + reset_previous_stage: Optional[bool] = False, + ) -> None: + if reset_previous_stage is True: + if self._stage_completed_ft is not None: + self._stage_completed_ft = None + elif self._stage_completed_ft is not None and not self._stage_completed_ft.done(): + raise RuntimeError('Future already running in this context.') + + kwargs: Dict[str, Union[StreamingRequestContext, Future[ParsedResultType]]] = {'request_context': self} + if create_notification is True: + self._create_stage_notification_future() + if self._stage_notification_ft is None: + raise RuntimeError('Unable to create stage notification future.') + kwargs['notify_on_results_or_error'] = self._stage_notification_ft + + self._stage_completed_ft = self._tp_executor.submit(fn, *args, **kwargs) + + def _trace_handler(self, event_name: str, _: str) -> None: + if event_name == 'connection.connect_tcp.complete': + pass + + def _wait_for_stage_completed(self) -> None: + if self._stage_completed_ft is None: + raise RuntimeError('Stage completed future not created for this context.') + self._stage_completed_ft.result() diff --git a/couchbase_analytics/protocol/_core/response.py 
b/couchbase_analytics/protocol/_core/response.py new file mode 100644 index 0000000..d3bafac --- /dev/null +++ b/couchbase_analytics/protocol/_core/response.py @@ -0,0 +1,118 @@ +# Copyright 2016-2025. Couchbase, Inc. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +from typing import Any, Optional + +from httpx import Response as HttpCoreResponse + +from couchbase_analytics.common._core.query import build_query_metadata +from couchbase_analytics.common.errors import AnalyticsError, InternalSDKError +from couchbase_analytics.common.logging import LogLevel +from couchbase_analytics.common.query import QueryMetadata +from couchbase_analytics.protocol._core.request_context import RequestContext +from couchbase_analytics.protocol._core.retries import RetryHandler +from couchbase_analytics.protocol.errors import WrappedError + + +class HttpResponse: + def __init__( + self, + request_context: RequestContext, + has_no_body_response: Optional[bool] = None, + request_id: Optional[str] = None, + ) -> None: + self._request_context = request_context + self._metadata: Optional[QueryMetadata] = None + self._core_response: HttpCoreResponse + self._json_response: Optional[Any] = None + self._has_no_body_response = has_no_body_response + self._request_id = request_id + + @property + def json_response(self) -> Optional[Any]: + """ + **INTERNAL** + """ + return self._json_response + + def close(self) -> None: + """ + 
**INTERNAL** + """ + if hasattr(self, '_core_response'): + self._core_response.close() + self._request_context.log_message('HTTP core response closed', LogLevel.INFO) + del self._core_response + + def get_metadata(self) -> QueryMetadata: + """ + **INTERNAL** + """ + if self._metadata is None: + raise RuntimeError('Query metadata is only available after response has been processed.') + return self._metadata + + def set_metadata(self, json_data: Optional[Any] = None, raw_metadata: Optional[bytes] = None) -> None: + """ + **INTERNAL** + """ + try: + self._metadata = QueryMetadata( + build_query_metadata( + json_data=json_data, + raw_metadata=raw_metadata, + request_id=self._request_id, + log_fn=self._request_context.log_message, + ) + ) + self._request_context.shutdown() + except (AnalyticsError, ValueError) as err: + self._request_context.shutdown(err) + raise err + except Exception as ex: + internal_err = InternalSDKError(cause=ex, message=str(ex), context=str(self._request_context.error_context)) + self._request_context.shutdown(internal_err) + finally: + self.close() + + @RetryHandler.with_retries + def send_request(self) -> None: + """ + **INTERNAL** + """ + self._request_context.initialize() + self._core_response = self._request_context.send_request(ignore_not_found_status=self._has_no_body_response) + if self._has_no_body_response is True: + self._process_no_body_response() + return + self._process_response() + + def _process_no_body_response(self) -> None: + status_code = self._core_response.status_code + self.close() + if 200 <= status_code < 300 or status_code == 404: + self._request_context.shutdown() + return + ctx = str(self._request_context.error_context) + raise WrappedError(AnalyticsError(context=ctx, message=f'Request failed with status {status_code}.')) + + def _process_response(self) -> None: + self._json_response = self._request_context.process_response( + self._core_response, self.close, handle_context_shutdown=True + ) + 
self.set_metadata(json_data=self._json_response) diff --git a/couchbase_analytics/protocol/_core/retries.py b/couchbase_analytics/protocol/_core/retries.py index c87fa43..4df1266 100644 --- a/couchbase_analytics/protocol/_core/retries.py +++ b/couchbase_analytics/protocol/_core/retries.py @@ -19,7 +19,7 @@ from concurrent.futures import CancelledError from functools import wraps from time import sleep -from typing import TYPE_CHECKING, Callable, Optional, Union +from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union from httpx import ConnectError, ConnectTimeout, CookieConflict, HTTPError, InvalidURL, ReadTimeout, StreamError @@ -29,9 +29,13 @@ from couchbase_analytics.protocol.errors import WrappedError if TYPE_CHECKING: - from couchbase_analytics.protocol._core.request_context import RequestContext + from couchbase_analytics.protocol._core.request_context import RequestContext, StreamingRequestContext + from couchbase_analytics.protocol._core.response import HttpResponse from couchbase_analytics.protocol.streaming import HttpStreamingResponse +ReqContext = Union['RequestContext', 'StreamingRequestContext'] +T = TypeVar('T', bound=Union['HttpResponse', 'HttpStreamingResponse']) + class RetryHandler: """ @@ -39,7 +43,7 @@ class RetryHandler: """ @staticmethod - def handle_httpx_retry(ex: Union[ConnectError, ConnectTimeout], ctx: RequestContext) -> Optional[Exception]: + def handle_httpx_retry(ex: Union[ConnectError, ConnectTimeout], ctx: ReqContext) -> Optional[Exception]: err_str = str(ex) if 'SSL:' in err_str: message = 'TLS connection error occurred.' 
@@ -62,7 +66,7 @@ def handle_httpx_retry(ex: Union[ConnectError, ConnectTimeout], ctx: RequestCont return None @staticmethod - def handle_retry(ex: WrappedError, ctx: RequestContext) -> Optional[Union[BaseException, Exception]]: + def handle_retry(ex: WrappedError, ctx: ReqContext) -> Optional[Union[BaseException, Exception]]: if ex.retriable is True: delay = ctx.calculate_backoff() err: Optional[Union[BaseException, Exception]] = None @@ -91,9 +95,11 @@ def handle_retry(ex: WrappedError, ctx: RequestContext) -> Optional[Union[BaseEx return ex.unwrap() @staticmethod - def with_retries(fn: Callable[[HttpStreamingResponse], None]) -> Callable[[HttpStreamingResponse], None]: # noqa: C901 + def with_retries( # noqa: C901 + fn: Callable[[T], None], + ) -> Callable[[T], None]: # noqa: C901 @wraps(fn) - def wrapped_fn(self: HttpStreamingResponse) -> None: # noqa: C901 + def wrapped_fn(self: T) -> None: # noqa: C901 while True: try: fn(self) diff --git a/couchbase_analytics/protocol/cluster.py b/couchbase_analytics/protocol/cluster.py index a0b2053..4e77b62 100644 --- a/couchbase_analytics/protocol/cluster.py +++ b/couchbase_analytics/protocol/cluster.py @@ -17,6 +17,7 @@ from __future__ import annotations import atexit +import sys from concurrent.futures import Future, ThreadPoolExecutor from typing import TYPE_CHECKING, Optional, Union from uuid import uuid4 @@ -25,7 +26,9 @@ from couchbase_analytics.common.result import BlockingQueryResult from couchbase_analytics.protocol._core.client_adapter import _ClientAdapter from couchbase_analytics.protocol._core.request import _RequestBuilder -from couchbase_analytics.protocol._core.request_context import RequestContext +from couchbase_analytics.protocol._core.request_context import RequestContext, StreamingRequestContext +from couchbase_analytics.protocol._core.response import HttpResponse +from couchbase_analytics.protocol.query_handle import BlockingQueryHandle from couchbase_analytics.protocol.streaming import 
HttpStreamingResponse if TYPE_CHECKING: @@ -85,6 +88,7 @@ def _shutdown(self) -> None: """ **INTERNAL** """ + atexit.unregister(self._shutdown_executor) self._client_adapter.close_client() self._client_adapter.reset_client() self._shutdown_executor() @@ -97,9 +101,10 @@ def _create_client(self) -> None: def _shutdown_executor(self) -> None: if self._tp_executor_shutdown_called is False: - self._client_adapter.log_message( - f'Shutting down ThreadPoolExecutor({self._tp_executor_prefix})', LogLevel.INFO - ) + if not sys.is_finalizing(): + self._client_adapter.log_message( + f'Shutting down ThreadPoolExecutor({self._tp_executor_prefix})', LogLevel.INFO + ) self._tp_executor.shutdown() self._tp_executor_shutdown_called = True @@ -120,11 +125,11 @@ def shutdown(self) -> None: def execute_query( self, statement: str, *args: object, **kwargs: object ) -> Union[BlockingQueryResult, Future[BlockingQueryResult]]: - base_req = self._request_builder.build_base_query_request(statement, *args, **kwargs) - lazy_execute = base_req.options.pop('lazy_execute', None) - stream_config = base_req.options.pop('stream_config', None) - request_context = RequestContext( - self.client_adapter, base_req, self.threadpool_executor, stream_config=stream_config + req = self._request_builder.build_query_request(statement, *args, **kwargs) + lazy_execute = req.options.pop('lazy_execute', None) + stream_config = req.options.pop('stream_config', None) + request_context = StreamingRequestContext( + self.client_adapter, req, self.threadpool_executor, stream_config=stream_config ) resp = HttpStreamingResponse(request_context, lazy_execute=lazy_execute) @@ -147,6 +152,16 @@ def _execute_query(http_response: HttpStreamingResponse) -> BlockingQueryResult: resp.send_request() return BlockingQueryResult(resp) + def start_query(self, statement: str, *args: object, **kwargs: object) -> BlockingQueryHandle: + base_req = self._request_builder.build_start_query_request(statement, *args, **kwargs) + stream_config 
= base_req.options.pop('stream_config', None) + request_context = RequestContext(self.client_adapter, base_req) + resp = HttpResponse(request_context) + resp.send_request() + return BlockingQueryHandle( + self._client_adapter, self._request_builder, resp, self._tp_executor, stream_config=stream_config + ) + @classmethod def create_instance( cls, http_endpoint: str, credential: Credential, options: Optional[ClusterOptions], **kwargs: object diff --git a/couchbase_analytics/protocol/cluster.pyi b/couchbase_analytics/protocol/cluster.pyi index dbb950a..206bb55 100644 --- a/couchbase_analytics/protocol/cluster.pyi +++ b/couchbase_analytics/protocol/cluster.pyi @@ -25,8 +25,16 @@ else: from couchbase_analytics import JSONType from couchbase_analytics.common.credential import Credential from couchbase_analytics.common.result import BlockingQueryResult -from couchbase_analytics.options import ClusterOptions, ClusterOptionsKwargs, QueryOptions, QueryOptionsKwargs +from couchbase_analytics.options import ( + ClusterOptions, + ClusterOptionsKwargs, + QueryOptions, + QueryOptionsKwargs, + StartQueryOptions, + StartQueryOptionsKwargs, +) from couchbase_analytics.protocol._core.client_adapter import _ClientAdapter +from couchbase_analytics.protocol.query_handle import BlockingQueryHandle class Cluster: @overload @@ -119,6 +127,26 @@ class Cluster: def execute_query( self, statement: str, *args: JSONType, enable_cancel: bool, **kwargs: str ) -> Future[BlockingQueryResult]: ... + @overload + def start_query(self, statement: str) -> BlockingQueryHandle: ... + @overload + def start_query(self, statement: str, options: StartQueryOptions) -> BlockingQueryHandle: ... + @overload + def start_query(self, statement: str, **kwargs: Unpack[StartQueryOptionsKwargs]) -> BlockingQueryHandle: ... + @overload + def start_query( + self, statement: str, options: StartQueryOptions, **kwargs: Unpack[StartQueryOptionsKwargs] + ) -> BlockingQueryHandle: ... 
+ @overload + def start_query( + self, statement: str, options: StartQueryOptions, *args: JSONType, **kwargs: Unpack[StartQueryOptionsKwargs] + ) -> BlockingQueryHandle: ... + @overload + def start_query( + self, statement: str, options: StartQueryOptions, *args: JSONType, **kwargs: str + ) -> BlockingQueryHandle: ... + @overload + def start_query(self, statement: str, *args: JSONType, **kwargs: str) -> BlockingQueryHandle: ... def shutdown(self) -> None: ... @overload @classmethod diff --git a/couchbase_analytics/protocol/connection.py b/couchbase_analytics/protocol/connection.py index 3d37496..031f511 100644 --- a/couchbase_analytics/protocol/connection.py +++ b/couchbase_analytics/protocol/connection.py @@ -27,7 +27,7 @@ from couchbase_analytics.common._core.utils import is_null_or_empty from couchbase_analytics.common.credential import Credential from couchbase_analytics.common.deserializer import DefaultJsonDeserializer, Deserializer -from couchbase_analytics.common.options import ClusterOptions, SecurityOptions, TimeoutOptions +from couchbase_analytics.common.options import SecurityOptions, TimeoutOptions from couchbase_analytics.common.request import RequestURL from couchbase_analytics.protocol.options import ( ClusterOptionsTransformedKwargs, @@ -46,11 +46,13 @@ class StreamingTimeouts(TypedDict, total=False): class DefaultTimeouts(TypedDict): connect_timeout: float + handle_request_timeout: float query_timeout: float DEFAULT_TIMEOUTS: DefaultTimeouts = { 'connect_timeout': 10, + 'handle_request_timeout': 10, 'query_timeout': 60 * 10, } @@ -173,8 +175,17 @@ def get_connect_timeout(self) -> float: return connect_timeout return DEFAULT_TIMEOUTS['connect_timeout'] + def get_handle_request_timeout(self) -> float: + timeout_opts: Optional[TimeoutOptionsTransformedKwargs] = self.cluster_options.get('timeout_options') + if timeout_opts is not None: + handle_request_timeout = timeout_opts.get('handle_request_timeout', None) + if handle_request_timeout is not None: 
+ return handle_request_timeout + return DEFAULT_TIMEOUTS['handle_request_timeout'] + def get_max_retries(self) -> int: - return self.cluster_options.get('max_retries', None) or DEFAULT_MAX_RETRIES + max_retries = self.cluster_options.get('max_retries', None) + return max_retries if max_retries is not None else DEFAULT_MAX_RETRIES def get_init_details(self) -> str: details = {'url': self.url.get_formatted_url(), 'cluster_options': self.cluster_options} @@ -255,8 +266,6 @@ def create( logger_name = cast(Optional[str], kwargs.pop('logger_name', None)) cluster_opts = opts_builder.build_cluster_options( - ClusterOptions, - ClusterOptionsTransformedKwargs, kwargs, options, query_str_opts=parse_query_str_options(query_str_opts, logger_name=logger_name), diff --git a/couchbase_analytics/protocol/errors.py b/couchbase_analytics/protocol/errors.py index 3324e30..c417da7 100644 --- a/couchbase_analytics/protocol/errors.py +++ b/couchbase_analytics/protocol/errors.py @@ -25,6 +25,7 @@ AnalyticsError, InvalidCredentialError, QueryError, + QueryNotFoundError, TimeoutError, ) from couchbase_analytics.common.logging import LogLevel @@ -115,15 +116,19 @@ class ErrorMapper: def build_error_from_http_status_code(message: str, context: ErrorContext) -> WrappedError: if context.status_code == 503: return WrappedError(AnalyticsError(context=str(context), message=message), retriable=True) + if context.status_code == 404: + return WrappedError(QueryNotFoundError(context=str(context), message=message), retriable=False) return WrappedError(AnalyticsError(context=str(context), message=message)) @staticmethod # noqa: C901 - def build_error_from_json(json_data: List[Dict[str, Any]], context: ErrorContext) -> WrappedError: + def build_error_from_json(json_data: List[Dict[str, Any]], context: ErrorContext) -> WrappedError: # noqa: C901 if context.status_code is None: return WrappedError(AnalyticsError(context=str(context), message='Unknown error occurred.')) if context.status_code == 401: 
return WrappedError(InvalidCredentialError(context=str(context), message='Invalid credentials provided.')) + if context.status_code == 404: + return WrappedError(QueryNotFoundError(context=str(context), message='Resource not found'), retriable=False) first_non_retriable_error: Optional[ServerQueryError] = None first_retriable_error: Optional[ServerQueryError] = None @@ -156,6 +161,17 @@ def build_error_from_json(json_data: List[Dict[str, Any]], context: ErrorContext retriable = first_non_retriable_error is None and first_retriable_error is not None return WrappedError(q_err, retriable=retriable) + @staticmethod + def maybe_raise_error_from_status_code( + status_code: int, context: str, ignore_not_found_status: Optional[bool] = False + ) -> None: + if status_code == 401: + raise WrappedError(InvalidCredentialError(context=context, message='Invalid credentials provided.')) + if status_code == 404 and ignore_not_found_status is not True: + raise WrappedError(QueryNotFoundError(context=context, message='Resource not found')) + if status_code == 503: + raise WrappedError(AnalyticsError(context=context, message='Service unavailable.'), retriable=True) + @staticmethod def handle_socket_error( fn: Callable[[str, int, Optional[Callable[..., None]]], str], diff --git a/couchbase_analytics/protocol/options.py b/couchbase_analytics/protocol/options.py index d47ee08..b8827e8 100644 --- a/couchbase_analytics/protocol/options.py +++ b/couchbase_analytics/protocol/options.py @@ -17,7 +17,7 @@ from __future__ import annotations from copy import copy -from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, TypedDict, TypeVar, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, TypedDict, Union, overload from couchbase_analytics.common._core import JsonStreamConfig from couchbase_analytics.common._core.utils import ( @@ -35,15 +35,19 @@ from couchbase_analytics.common.enums import QueryScanConsistency from couchbase_analytics.common.options 
import ( ClusterOptions, + FetchResultsOptions, OptionsClass, QueryOptions, SecurityOptions, + StartQueryOptions, TimeoutOptions, ) from couchbase_analytics.common.options_base import ( ClusterOptionsValidKeys, + FetchResultsOptionsValidKeys, QueryOptionsValidKeys, SecurityOptionsValidKeys, + StartQueryOptionsValidKeys, TimeoutOptionsValidKeys, ) @@ -103,17 +107,20 @@ class SecurityOptionsTransformedKwargs(TypedDict, total=False): class TimeoutOptionsTransforms(TypedDict): connect_timeout: Dict[Literal['connect_timeout'], Callable[[Any], float]] + handle_request_timeout: Dict[Literal['handle_request_timeout'], Callable[[Any], float]] query_timeout: Dict[Literal['query_timeout'], Callable[[Any], float]] TIMEOUT_OPTIONS_TRANSFORMS: TimeoutOptionsTransforms = { 'connect_timeout': {'connect_timeout': to_seconds}, + 'handle_request_timeout': {'handle_request_timeout': to_seconds}, 'query_timeout': {'query_timeout': to_seconds}, } class TimeoutOptionsTransformedKwargs(TypedDict, total=False): connect_timeout: Optional[int] + handle_request_timeout: Optional[int] query_timeout: Optional[int] @@ -155,7 +162,6 @@ class QueryOptionsTransformedKwargs(TypedDict, total=False): max_retries: Optional[int] named_parameters: Optional[Any] positional_parameters: Optional[Any] - priority: Optional[bool] query_context: Optional[str] raw: Optional[Dict[str, Any]] readonly: Optional[bool] @@ -164,25 +170,65 @@ class QueryOptionsTransformedKwargs(TypedDict, total=False): timeout: Optional[float] -TransformedOptionKwargs = TypeVar( - 'TransformedOptionKwargs', - QueryOptionsTransformedKwargs, - ClusterOptionsTransformedKwargs, - SecurityOptionsTransformedKwargs, - TimeoutOptionsTransformedKwargs, -) +class StartQueryOptionsTransforms(TypedDict): + client_context_id: Dict[Literal['client_context_id'], Callable[[Any], str]] + max_retries: Dict[Literal['max_retries'], Callable[[Any], int]] + named_parameters: Dict[Literal['named_parameters'], Callable[[Any], Any]] + positional_parameters: 
Dict[Literal['positional_parameters'], Callable[[Any], Any]] + query_context: Dict[Literal['query_context'], Callable[[Any], str]] + raw: Dict[Literal['raw'], Callable[[Any], Dict[str, Any]]] + readonly: Dict[Literal['readonly'], Callable[[Any], bool]] + scan_consistency: Dict[Literal['scan_consistency'], Callable[[Any], str]] + stream_config: Dict[Literal['stream_config'], Callable[[Any], JsonStreamConfig]] + timeout: Dict[Literal['timeout'], Callable[[Any], float]] + + +START_QUERY_OPTIONS_TRANSFORMS: StartQueryOptionsTransforms = { + 'client_context_id': {'client_context_id': VALIDATE_STR}, + 'max_retries': {'max_retries': VALIDATE_INT}, + 'named_parameters': {'named_parameters': lambda x: x}, + 'positional_parameters': {'positional_parameters': lambda x: x}, + 'query_context': {'query_context': VALIDATE_STR}, + 'raw': {'raw': validate_raw_dict}, + 'readonly': {'readonly': VALIDATE_BOOL}, + 'scan_consistency': {'scan_consistency': QUERY_CONSISTENCY_TO_STR}, + 'stream_config': {'stream_config': lambda x: x}, + 'timeout': {'timeout': to_seconds}, +} + + +class StartQueryOptionsTransformedKwargs(TypedDict, total=False): + client_context_id: Optional[str] + max_retries: Optional[int] + named_parameters: Optional[Any] + positional_parameters: Optional[Any] + query_context: Optional[str] + raw: Optional[Dict[str, Any]] + readonly: Optional[bool] + scan_consistency: Optional[str] + stream_config: Optional[JsonStreamConfig] + timeout: Optional[float] + + +class FetchResultsOptionsTransforms(TypedDict): + deserializer: Dict[Literal['deserializer'], Callable[[Any], Deserializer]] + + +FETCH_RESULTS_OPTIONS_TRANSFORMS: FetchResultsOptionsTransforms = { + 'deserializer': {'deserializer': VALIDATE_DESERIALIZER}, +} + + +class FetchResultsOptionsTransformedKwargs(TypedDict, total=False): + deserializer: Optional[Deserializer] -TransformedClusterOptionKwargs = TypeVar( - 'TransformedClusterOptionKwargs', - ClusterOptionsTransformedKwargs, - SecurityOptionsTransformedKwargs, - 
TimeoutOptionsTransformedKwargs, -) TransformDetailsPair = Union[ Tuple[List[QueryOptionsValidKeys], QueryOptionsTransforms], Tuple[List[ClusterOptionsValidKeys], ClusterOptionsTransforms], + Tuple[List[FetchResultsOptionsValidKeys], FetchResultsOptionsTransforms], Tuple[List[SecurityOptionsValidKeys], SecurityOptionsTransforms], + Tuple[List[StartQueryOptionsValidKeys], StartQueryOptionsTransforms], Tuple[List[TimeoutOptionsValidKeys], TimeoutOptionsTransforms], ] @@ -216,18 +262,20 @@ def _get_transform_details(self, option_type: str) -> TransformDetailsPair: # n return TimeoutOptions.VALID_OPTION_KEYS, TIMEOUT_OPTIONS_TRANSFORMS elif option_type == 'QueryOptions': return QueryOptions.VALID_OPTION_KEYS, QUERY_OPTIONS_TRANSFORMS + elif option_type == 'StartQueryOptions': + return StartQueryOptions.VALID_OPTION_KEYS, START_QUERY_OPTIONS_TRANSFORMS + elif option_type == 'FetchResultsOptions': + return FetchResultsOptions.VALID_OPTION_KEYS, FETCH_RESULTS_OPTIONS_TRANSFORMS else: raise ValueError('Invalid OptionType.') def build_cluster_options( # noqa: C901 self, - option_type: type[OptionsClass], - output_type: type[TransformedClusterOptionKwargs], orig_kwargs: Dict[str, object], options: Optional[object] = None, query_str_opts: Optional[Dict[str, QueryStrVal]] = None, - ) -> TransformedClusterOptionKwargs: - temp_options = self._get_options_copy(option_type, orig_kwargs, options) + ) -> ClusterOptionsTransformedKwargs: + temp_options = self._get_options_copy(ClusterOptions, orig_kwargs, options) # we flatten all the nested options (timeout_options & security_options) # so that we can combine the nested options w/ potential query string options @@ -254,37 +302,84 @@ def build_cluster_options( # noqa: C901 keys_to_ignore: List[str] = [*ClusterOptions.VALID_OPTION_KEYS, *TimeoutOptions.VALID_OPTION_KEYS] - # not going to be able to make mypy happy w/ keys_to_ignore :/ - transformed_security_opts = self.build_options( - SecurityOptions, 
SecurityOptionsTransformedKwargs, temp_options, keys_to_ignore=keys_to_ignore - ) + transformed_security_opts = self.build_options(SecurityOptions, temp_options, keys_to_ignore=keys_to_ignore) if transformed_security_opts: temp_options['security_options'] = transformed_security_opts keys_to_ignore = [*ClusterOptions.VALID_OPTION_KEYS, *SecurityOptions.VALID_OPTION_KEYS] - # not going to be able to make mypy happy w/ keys_to_ignore :/ - transformed_timeout_opts = self.build_options( - TimeoutOptions, TimeoutOptionsTransformedKwargs, temp_options, keys_to_ignore=keys_to_ignore - ) + transformed_timeout_opts = self.build_options(TimeoutOptions, temp_options, keys_to_ignore=keys_to_ignore) if transformed_timeout_opts: temp_options['timeout_options'] = transformed_timeout_opts # transform final ClusterOptions - transformed_opts = self.build_options(option_type, output_type, temp_options) + transformed_opts = self.build_options(ClusterOptions, temp_options) return transformed_opts + @overload + def build_options( + self, + option_type: type[ClusterOptions], + orig_kwargs: Dict[str, object], + options: Optional[object] = ..., + keys_to_ignore: Optional[List[str]] = ..., + ) -> ClusterOptionsTransformedKwargs: ... + + @overload + def build_options( + self, + option_type: type[SecurityOptions], + orig_kwargs: Dict[str, object], + options: Optional[object] = ..., + keys_to_ignore: Optional[List[str]] = ..., + ) -> SecurityOptionsTransformedKwargs: ... + + @overload + def build_options( + self, + option_type: type[TimeoutOptions], + orig_kwargs: Dict[str, object], + options: Optional[object] = ..., + keys_to_ignore: Optional[List[str]] = ..., + ) -> TimeoutOptionsTransformedKwargs: ... + + @overload + def build_options( + self, + option_type: type[QueryOptions], + orig_kwargs: Dict[str, object], + options: Optional[object] = ..., + keys_to_ignore: Optional[List[str]] = ..., + ) -> QueryOptionsTransformedKwargs: ... 
+ + @overload + def build_options( + self, + option_type: type[StartQueryOptions], + orig_kwargs: Dict[str, object], + options: Optional[object] = ..., + keys_to_ignore: Optional[List[str]] = ..., + ) -> StartQueryOptionsTransformedKwargs: ... + + @overload + def build_options( + self, + option_type: type[FetchResultsOptions], + orig_kwargs: Dict[str, object], + options: Optional[object] = ..., + keys_to_ignore: Optional[List[str]] = ..., + ) -> FetchResultsOptionsTransformedKwargs: ... + def build_options( self, option_type: type[OptionsClass], - output_type: type[TransformedOptionKwargs], orig_kwargs: Dict[str, object], options: Optional[object] = None, keys_to_ignore: Optional[List[str]] = None, - ) -> TransformedOptionKwargs: + ) -> Any: temp_options = self._get_options_copy(option_type, orig_kwargs, options) - transformed_opts: TransformedOptionKwargs = {} + transformed_opts: Any = {} # Option 1 satisfies mypy, but we want temp_options to be the limiting factor for the loop. # Option 2. Also makes providing warnings/exceptions for users not using static type checking easier, # but unfortunately we need to use some type: ignore comments @@ -304,7 +399,7 @@ def build_options( for nk, cfn in transforms.items(): conv = cfn(v) if conv is not None: - transformed_opts[nk] = conv # type: ignore[literal-required] + transformed_opts[nk] = conv elif keys_to_ignore and k not in keys_to_ignore: raise ValueError(f'Invalid key provided (key={k}).') diff --git a/couchbase_analytics/protocol/query_handle.py b/couchbase_analytics/protocol/query_handle.py new file mode 100644 index 0000000..a37d5ca --- /dev/null +++ b/couchbase_analytics/protocol/query_handle.py @@ -0,0 +1,172 @@ +# Copyright 2016-2025. Couchbase, Inc. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +from concurrent.futures import ThreadPoolExecutor +from typing import TYPE_CHECKING, Any, Optional + +from couchbase_analytics.common._core.query_handle import QueryHandleStatusResponse +from couchbase_analytics.common.errors import AnalyticsError, QueryNotFoundError +from couchbase_analytics.common.query_handle import BlockingQueryHandle as _CoreBlockingQueryHandle +from couchbase_analytics.common.query_handle import BlockingQueryResultHandle as _CoreBlockingQueryResultHandle +from couchbase_analytics.common.query_handle import BlockingQueryStatus as _CoreBlockingQueryStatus +from couchbase_analytics.common.result import BlockingQueryResult +from couchbase_analytics.protocol._core.client_adapter import _ClientAdapter +from couchbase_analytics.protocol._core.request import _RequestBuilder +from couchbase_analytics.protocol._core.request_context import RequestContext, StreamingRequestContext +from couchbase_analytics.protocol._core.response import HttpResponse +from couchbase_analytics.protocol.streaming import HttpStreamingResponse + +if TYPE_CHECKING: + from couchbase_analytics.common._core import JsonStreamConfig + from couchbase_analytics.options import FetchResultsOptions + + +class BlockingQueryHandle(_CoreBlockingQueryHandle): + def __init__( + self, + client_adapter: _ClientAdapter, + request_builder: _RequestBuilder, + http_response: HttpResponse, + tp_executor: ThreadPoolExecutor, + stream_config: Optional[JsonStreamConfig] = None, + ) -> None: + super().__init__() + self._client_adapter = client_adapter + 
self._request_builder = request_builder + self._http_response = http_response + self._tp_executor = tp_executor + self._stream_config = stream_config + self._request_id: str = '' + self._handle: str = '' + self._get_status_handle() + + def fetch_status(self, options: Optional[Any] = None, **kwargs: Any) -> BlockingQueryStatus: + server_req = self._request_builder.build_request_from_handle(self._handle) + request_context = RequestContext(self._client_adapter, server_req) + resp = HttpResponse(request_context) + resp.send_request() + if resp.json_response is None: + raise AnalyticsError(message='HTTP response does not contain JSON data.') + + status_response = self._get_handle_status_response(resp) + return BlockingQueryStatus( + self._client_adapter, + self._request_builder, + self._tp_executor, + status_response, + stream_config=self._stream_config, + ) + + def cancel(self, options: Optional[Any] = None, **kwargs: Any) -> None: + cancel_req = self._request_builder.build_cancel_request(self._request_id) + request_context = RequestContext(self._client_adapter, cancel_req) + resp = HttpResponse(request_context, has_no_body_response=True, request_id=self._request_id) + resp.send_request() + + def _get_status_handle(self) -> None: + if self._http_response.json_response is None: + raise AnalyticsError(message='HTTP response does not contain JSON data.') + + request_id = self._http_response.json_response.get('requestID', None) + if request_id is None: + raise QueryNotFoundError(message='Server response is missing "requestID" field.') + handle = self._http_response.json_response.get('handle', None) + if handle is None: + raise QueryNotFoundError(message='Server response is missing "handle" field.') + + self._request_id = request_id + self._handle = handle + + def _get_handle_status_response(self, resp: HttpResponse) -> QueryHandleStatusResponse: + if resp.json_response is None: + raise AnalyticsError(message='HTTP response does not contain JSON data.') + + return 
QueryHandleStatusResponse.from_server(self._request_id, resp.json_response) + + +class BlockingQueryResultHandle(_CoreBlockingQueryResultHandle): + def __init__( + self, + client_adapter: _ClientAdapter, + request_builder: _RequestBuilder, + tp_executor: ThreadPoolExecutor, + request_id: str, + handle: str, + stream_config: Optional[JsonStreamConfig] = None, + ) -> None: + super().__init__() + self._client_adapter = client_adapter + self._request_builder = request_builder + self._tp_executor = tp_executor + self._request_id = request_id + self._handle = handle + self._stream_config = stream_config + + def fetch_results(self, options: Optional[FetchResultsOptions] = None, **kwargs: Any) -> BlockingQueryResult: + server_req = self._request_builder.build_fetch_results_request(self._handle, options, **kwargs) + request_context = StreamingRequestContext( + self._client_adapter, server_req, self._tp_executor, stream_config=self._stream_config + ) + resp = HttpStreamingResponse(request_context, request_id=self._request_id) + resp.send_request() + return BlockingQueryResult(resp) + + def discard_results(self, options: Optional[Any] = None, **kwargs: Any) -> None: + req = self._request_builder.build_discard_results_request(self._handle) + request_context = RequestContext(self._client_adapter, req) + resp = HttpResponse(request_context, has_no_body_response=True, request_id=self._request_id) + resp.send_request() + + +class BlockingQueryStatus(_CoreBlockingQueryStatus): + def __init__( + self, + client_adapter: _ClientAdapter, + request_builder: _RequestBuilder, + tp_executor: ThreadPoolExecutor, + status_resp: QueryHandleStatusResponse, + stream_config: Optional[JsonStreamConfig] = None, + ) -> None: + super().__init__() + self._client_adapter = client_adapter + self._request_builder = request_builder + self._tp_executor = tp_executor + self._status_resp = status_resp + self._stream_config = stream_config + + def results_ready(self) -> bool: + return 
self._status_resp.handle is not None + + def result_handle(self) -> BlockingQueryResultHandle: + if self._status_resp.handle is None: + raise AnalyticsError(message='Query is not ready. Handle is not available.') + + return BlockingQueryResultHandle( + self._client_adapter, + self._request_builder, + self._tp_executor, + self._status_resp.request_id, + self._status_resp.handle, + stream_config=self._stream_config, + ) + + def __repr__(self) -> str: + return f'BlockingQueryStatus({self._status_resp.get_details()})' + + def __str__(self) -> str: + return self.__repr__() diff --git a/couchbase_analytics/protocol/scope.py b/couchbase_analytics/protocol/scope.py index 1e77457..4600e71 100644 --- a/couchbase_analytics/protocol/scope.py +++ b/couchbase_analytics/protocol/scope.py @@ -22,7 +22,9 @@ from couchbase_analytics.common.result import BlockingQueryResult from couchbase_analytics.protocol._core.client_adapter import _ClientAdapter from couchbase_analytics.protocol._core.request import _RequestBuilder -from couchbase_analytics.protocol._core.request_context import RequestContext +from couchbase_analytics.protocol._core.request_context import RequestContext, StreamingRequestContext +from couchbase_analytics.protocol._core.response import HttpResponse +from couchbase_analytics.protocol.query_handle import BlockingQueryHandle from couchbase_analytics.protocol.streaming import HttpStreamingResponse if TYPE_CHECKING: @@ -33,7 +35,7 @@ class Scope: def __init__(self, database: Database, scope_name: str) -> None: self._database = database self._scope_name = scope_name - self._request_builder = _RequestBuilder(self.client_adapter, self._database.name, self.name) + self._request_builder = _RequestBuilder(self._database.client_adapter, self._database.name, self.name) @property def client_adapter(self) -> _ClientAdapter: @@ -59,11 +61,11 @@ def threadpool_executor(self) -> ThreadPoolExecutor: def execute_query( self, statement: str, *args: object, **kwargs: object ) -> 
Union[BlockingQueryResult, Future[BlockingQueryResult]]: - base_req = self._request_builder.build_base_query_request(statement, *args, **kwargs) - lazy_execute = base_req.options.pop('lazy_execute', None) - stream_config = base_req.options.pop('stream_config', None) - request_context = RequestContext( - self.client_adapter, base_req, self.threadpool_executor, stream_config=stream_config + req = self._request_builder.build_query_request(statement, *args, **kwargs) + lazy_execute = req.options.pop('lazy_execute', None) + stream_config = req.options.pop('stream_config', None) + request_context = StreamingRequestContext( + self.client_adapter, req, self.threadpool_executor, stream_config=stream_config ) resp = HttpStreamingResponse(request_context, lazy_execute=lazy_execute) @@ -84,3 +86,13 @@ def _execute_query(http_response: HttpStreamingResponse) -> BlockingQueryResult: if lazy_execute is not True: resp.send_request() return BlockingQueryResult(resp) + + def start_query(self, statement: str, *args: object, **kwargs: object) -> BlockingQueryHandle: + base_req = self._request_builder.build_start_query_request(statement, *args, **kwargs) + stream_config = base_req.options.pop('stream_config', None) + request_context = RequestContext(self.client_adapter, base_req) + resp = HttpResponse(request_context) + resp.send_request() + return BlockingQueryHandle( + self.client_adapter, self._request_builder, resp, self.threadpool_executor, stream_config=stream_config + ) diff --git a/couchbase_analytics/protocol/scope.pyi b/couchbase_analytics/protocol/scope.pyi index 9296863..e3d0b59 100644 --- a/couchbase_analytics/protocol/scope.pyi +++ b/couchbase_analytics/protocol/scope.pyi @@ -24,9 +24,10 @@ else: from couchbase_analytics import JSONType from couchbase_analytics.common.result import BlockingQueryResult -from couchbase_analytics.options import QueryOptions, QueryOptionsKwargs +from couchbase_analytics.options import QueryOptions, QueryOptionsKwargs, StartQueryOptions, 
StartQueryOptionsKwargs from couchbase_analytics.protocol._core.client_adapter import _ClientAdapter from couchbase_analytics.protocol.database import Database as Database +from couchbase_analytics.protocol.query_handle import BlockingQueryHandle class Scope: def __init__(self, database: Database, scope_name: str) -> None: ... @@ -106,3 +107,23 @@ class Scope: def execute_query( self, statement: str, *args: JSONType, enable_cancel: bool, **kwargs: str ) -> Future[BlockingQueryResult]: ... + @overload + def start_query(self, statement: str) -> BlockingQueryHandle: ... + @overload + def start_query(self, statement: str, options: StartQueryOptions) -> BlockingQueryHandle: ... + @overload + def start_query(self, statement: str, **kwargs: Unpack[StartQueryOptionsKwargs]) -> BlockingQueryHandle: ... + @overload + def start_query( + self, statement: str, options: StartQueryOptions, **kwargs: Unpack[StartQueryOptionsKwargs] + ) -> BlockingQueryHandle: ... + @overload + def start_query( + self, statement: str, options: StartQueryOptions, *args: JSONType, **kwargs: Unpack[StartQueryOptionsKwargs] + ) -> BlockingQueryHandle: ... + @overload + def start_query( + self, statement: str, options: StartQueryOptions, *args: JSONType, **kwargs: str + ) -> BlockingQueryHandle: ... + @overload + def start_query(self, statement: str, *args: JSONType, **kwargs: str) -> BlockingQueryHandle: ... 
diff --git a/couchbase_analytics/protocol/streaming.py b/couchbase_analytics/protocol/streaming.py index d6372a8..cf2dcf1 100644 --- a/couchbase_analytics/protocol/streaming.py +++ b/couchbase_analytics/protocol/streaming.py @@ -26,12 +26,17 @@ from couchbase_analytics.common.errors import AnalyticsError, InternalSDKError, TimeoutError from couchbase_analytics.common.logging import LogLevel from couchbase_analytics.common.query import QueryMetadata -from couchbase_analytics.protocol._core.request_context import RequestContext +from couchbase_analytics.protocol._core.request_context import StreamingRequestContext from couchbase_analytics.protocol._core.retries import RetryHandler class HttpStreamingResponse: - def __init__(self, request_context: RequestContext, lazy_execute: Optional[bool] = None) -> None: + def __init__( + self, + request_context: StreamingRequestContext, + lazy_execute: Optional[bool] = None, + request_id: Optional[str] = None, + ) -> None: self._request_context = request_context if lazy_execute is not None: self._lazy_execute = lazy_execute @@ -39,6 +44,7 @@ def __init__(self, request_context: RequestContext, lazy_execute: Optional[bool] self._lazy_execute = False self._metadata: Optional[QueryMetadata] = None self._core_response: HttpCoreResponse + self._request_id = request_id @property def lazy_execute(self) -> bool: @@ -68,7 +74,7 @@ def _handle_iteration_abort(self) -> None: def _process_response( self, raw_response: Optional[ParsedResult] = None, handle_context_shutdown: Optional[bool] = False ) -> None: - json_response = self._request_context.process_response( + json_response = self._request_context.process_streaming_response( self.close, raw_response=raw_response, handle_context_shutdown=handle_context_shutdown ) self.set_metadata(json_data=json_response) @@ -98,7 +104,14 @@ def get_metadata(self) -> QueryMetadata: def set_metadata(self, json_data: Optional[Any] = None, raw_metadata: Optional[bytes] = None) -> None: try: - self._metadata 
= QueryMetadata(build_query_metadata(json_data=json_data, raw_metadata=raw_metadata)) + self._metadata = QueryMetadata( + build_query_metadata( + json_data=json_data, + raw_metadata=raw_metadata, + request_id=self._request_id, + log_fn=self._request_context.log_message, + ) + ) self._request_context.shutdown() except (AnalyticsError, ValueError) as err: self._request_context.shutdown(err) diff --git a/couchbase_analytics/query_handle.py b/couchbase_analytics/query_handle.py new file mode 100644 index 0000000..01c6769 --- /dev/null +++ b/couchbase_analytics/query_handle.py @@ -0,0 +1,18 @@ +# Copyright 2016-2025. Couchbase, Inc. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from couchbase_analytics.common.query_handle import BlockingQueryHandle as BlockingQueryHandle # noqa: F401 +from couchbase_analytics.common.query_handle import BlockingQueryResultHandle as BlockingQueryResultHandle # noqa: F401 +from couchbase_analytics.common.query_handle import BlockingQueryStatus as BlockingQueryStatus # noqa: F401 diff --git a/couchbase_analytics/result.py b/couchbase_analytics/result.py index 4f0e8e4..4b0cb37 100644 --- a/couchbase_analytics/result.py +++ b/couchbase_analytics/result.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
- from couchbase_analytics.common.result import AsyncQueryResult as AsyncQueryResult # noqa: F401 from couchbase_analytics.common.result import BlockingQueryResult as BlockingQueryResult # noqa: F401 from couchbase_analytics.common.result import QueryResult as QueryResult # noqa: F401 diff --git a/couchbase_analytics/scope.py b/couchbase_analytics/scope.py index 702becb..7a805e6 100644 --- a/couchbase_analytics/scope.py +++ b/couchbase_analytics/scope.py @@ -19,6 +19,7 @@ from concurrent.futures import Future from typing import TYPE_CHECKING, Union +from couchbase_analytics.query_handle import BlockingQueryHandle from couchbase_analytics.result import BlockingQueryResult if TYPE_CHECKING: @@ -114,3 +115,19 @@ def execute_query( """ # noqa: E501 return self._impl.execute_query(statement, *args, **kwargs) + + def start_query(self, statement: str, *args: object, **kwargs: object) -> BlockingQueryHandle: + """Executes a query against an Analytics scope in async mode. + + .. seealso:: + :meth:`couchbase_analytics.Cluster.start_query`: For how to execute cluster-level queries. + + Args: + statement: The SQL++ statement to execute. + options (:class:`~couchbase_analytics.options.StartQueryOptions`): Optional parameters for the query operation. 
+ **kwargs (Dict[str, Any]): keyword arguments that can be used in place or to override provided :class:`~couchbase_analytics.options.StartQueryOptions` + + Returns: + :class:`~couchbase_analytics.query_handle.BlockingQueryHandle`: An instance of a :class:`~couchbase_analytics.query_handle.BlockingQueryHandle` + """ # noqa: E501 + return self._impl.start_query(statement, *args, **kwargs) diff --git a/couchbase_analytics/scope.pyi b/couchbase_analytics/scope.pyi index c5d36d2..fa70bef 100644 --- a/couchbase_analytics/scope.pyi +++ b/couchbase_analytics/scope.pyi @@ -23,8 +23,9 @@ else: from typing import Unpack from couchbase_analytics import JSONType -from couchbase_analytics.options import QueryOptions, QueryOptionsKwargs +from couchbase_analytics.options import QueryOptions, QueryOptionsKwargs, StartQueryOptions, StartQueryOptionsKwargs from couchbase_analytics.protocol.database import Database as Database +from couchbase_analytics.query_handle import BlockingQueryHandle from couchbase_analytics.result import BlockingQueryResult class Scope: @@ -101,3 +102,23 @@ class Scope: def execute_query( self, statement: str, *args: JSONType, enable_cancel: bool, **kwargs: str ) -> Future[BlockingQueryResult]: ... + @overload + def start_query(self, statement: str) -> BlockingQueryHandle: ... + @overload + def start_query(self, statement: str, options: StartQueryOptions) -> BlockingQueryHandle: ... + @overload + def start_query(self, statement: str, **kwargs: Unpack[StartQueryOptionsKwargs]) -> BlockingQueryHandle: ... + @overload + def start_query( + self, statement: str, options: StartQueryOptions, **kwargs: Unpack[StartQueryOptionsKwargs] + ) -> BlockingQueryHandle: ... + @overload + def start_query( + self, statement: str, options: StartQueryOptions, *args: JSONType, **kwargs: Unpack[StartQueryOptionsKwargs] + ) -> BlockingQueryHandle: ... 
+ @overload + def start_query( + self, statement: str, options: StartQueryOptions, *args: JSONType, **kwargs: str + ) -> BlockingQueryHandle: ... + @overload + def start_query(self, statement: str, *args: JSONType, **kwargs: str) -> BlockingQueryHandle: ... diff --git a/couchbase_analytics/tests/connection_t.py b/couchbase_analytics/tests/connection_t.py index b9b0cf8..709124d 100644 --- a/couchbase_analytics/tests/connection_t.py +++ b/couchbase_analytics/tests/connection_t.py @@ -67,7 +67,7 @@ def test_connstr_options_max_retries(self) -> None: connstr = f'https://localhost?max_retries={max_retries}' client = _ClientAdapter(connstr, cred) req_builder = _RequestBuilder(client) - req = req_builder.build_base_query_request('SELECT 1=1') + req = req_builder.build_query_request('SELECT 1=1') assert req.max_retries == max_retries @pytest.mark.parametrize( @@ -99,7 +99,7 @@ def test_connstr_options_timeout(self, duration: str, expected_seconds: str) -> connstr = f'https://localhost?{to_query_str(opts)}' client = _ClientAdapter(connstr, cred) req_builder = _RequestBuilder(client) - req = req_builder.build_base_query_request('SELECT 1=1') + req = req_builder.build_query_request('SELECT 1=1') expected = float(expected_seconds) returned_timeout_opts = req.get_request_timeouts() assert isinstance(returned_timeout_opts, dict) diff --git a/couchbase_analytics/tests/options_t.py b/couchbase_analytics/tests/options_t.py index a6722cd..30e05de 100644 --- a/couchbase_analytics/tests/options_t.py +++ b/couchbase_analytics/tests/options_t.py @@ -186,10 +186,15 @@ def test_security_options_invalid_kwargs(self, opts: Dict[str, object]) -> None: [ ({}, None), ({'connect_timeout': timedelta(seconds=30)}, {'connect_timeout': 30}), + ({'handle_request_timeout': timedelta(seconds=20)}, {'handle_request_timeout': 20}), ({'query_timeout': timedelta(seconds=30)}, {'query_timeout': 30}), ( - {'connect_timeout': timedelta(seconds=60), 'query_timeout': timedelta(seconds=30)}, - 
{'connect_timeout': 60, 'query_timeout': 30}, + { + 'connect_timeout': timedelta(seconds=60), + 'handle_request_timeout': timedelta(seconds=20), + 'query_timeout': timedelta(seconds=30), + }, + {'connect_timeout': 60, 'handle_request_timeout': 20, 'query_timeout': 30}, ), ], ) @@ -202,10 +207,15 @@ def test_timeout_options(self, opts: TimeoutOptionsKwargs, expected_opts: Timeou 'opts, expected_opts', [ ({'connect_timeout': timedelta(seconds=30)}, {'connect_timeout': 30}), + ({'handle_request_timeout': timedelta(seconds=20)}, {'handle_request_timeout': 20}), ({'query_timeout': timedelta(seconds=30)}, {'query_timeout': 30}), ( - {'connect_timeout': timedelta(seconds=60), 'query_timeout': timedelta(seconds=30)}, - {'connect_timeout': 60, 'query_timeout': 30}, + { + 'connect_timeout': timedelta(seconds=60), + 'handle_request_timeout': timedelta(seconds=20), + 'query_timeout': timedelta(seconds=30), + }, + {'connect_timeout': 60, 'handle_request_timeout': 20, 'query_timeout': 30}, ), ], ) @@ -215,7 +225,12 @@ def test_timeout_options_kwargs(self, opts: Dict[str, object], expected_opts: Di assert expected_opts == client.connection_details.cluster_options.get('timeout_options') @pytest.mark.parametrize( - 'opts', [{'connect_timeout': timedelta(seconds=-1)}, {'query_timeout': timedelta(seconds=-1)}] + 'opts', + [ + {'connect_timeout': timedelta(seconds=-1)}, + {'handle_request_timeout': timedelta(seconds=-1)}, + {'query_timeout': timedelta(seconds=-1)}, + ], ) def test_timeout_options_must_be_positive(self, opts: TimeoutOptionsKwargs) -> None: cred = Credential.from_username_and_password('Administrator', 'password') @@ -223,7 +238,12 @@ def test_timeout_options_must_be_positive(self, opts: TimeoutOptionsKwargs) -> N _ClientAdapter('https://localhost', cred, ClusterOptions(timeout_options=TimeoutOptions(**opts))) @pytest.mark.parametrize( - 'opts', [{'connect_timeout': timedelta(seconds=-1)}, {'query_timeout': timedelta(seconds=-1)}] + 'opts', + [ + {'connect_timeout': 
timedelta(seconds=-1)}, + {'handle_request_timeout': timedelta(seconds=-1)}, + {'query_timeout': timedelta(seconds=-1)}, + ], ) def test_timeout_options_must_be_positive_kwargs(self, opts: Dict[str, object]) -> None: cred = Credential.from_username_and_password('Administrator', 'password') diff --git a/couchbase_analytics/tests/query_options_t.py b/couchbase_analytics/tests/query_options_t.py index 39d0b67..46d58c3 100644 --- a/couchbase_analytics/tests/query_options_t.py +++ b/couchbase_analytics/tests/query_options_t.py @@ -76,7 +76,7 @@ def test_options_deserializer( deserializer = DefaultJsonDeserializer() q_opts = QueryOptions(deserializer=deserializer) - req = request_builder.build_base_query_request(query_statment, q_opts) + req = request_builder.build_query_request(query_statment, q_opts) exp_opts: QueryOptionsTransformedKwargs = {} assert req.options == exp_opts assert req.deserializer == deserializer @@ -89,35 +89,35 @@ def test_options_deserializer_kwargs( deserializer = DefaultJsonDeserializer() kwargs: QueryOptionsKwargs = {'deserializer': deserializer} - req = request_builder.build_base_query_request(query_statment, **kwargs) + req = request_builder.build_query_request(query_statment, **kwargs) exp_opts: QueryOptionsTransformedKwargs = {} assert req.options == exp_opts assert req.deserializer == deserializer query_ctx.validate_query_context(req.body) - @pytest.mark.parametrize('max_retries', [5, 10, None]) + @pytest.mark.parametrize('max_retries', [5, 10, 0, None]) def test_options_max_retries( self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext, max_retries: Optional[int] ) -> None: if max_retries is not None: q_opts = QueryOptions(max_retries=max_retries) - req = request_builder.build_base_query_request(query_statment, q_opts) + req = request_builder.build_query_request(query_statment, q_opts) else: - req = request_builder.build_base_query_request(query_statment) + req = 
request_builder.build_query_request(query_statment) exp_opts: QueryOptionsTransformedKwargs = {} assert req.options == exp_opts assert req.max_retries == (max_retries if max_retries is not None else 7) query_ctx.validate_query_context(req.body) - @pytest.mark.parametrize('max_retries', [5, 10, None]) + @pytest.mark.parametrize('max_retries', [5, 10, 0, None]) def test_options_max_retries_kwargs( self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext, max_retries: Optional[int] ) -> None: if max_retries is not None: kwargs: QueryOptionsKwargs = {'max_retries': max_retries} - req = request_builder.build_base_query_request(query_statment, **kwargs) + req = request_builder.build_query_request(query_statment, **kwargs) else: - req = request_builder.build_base_query_request(query_statment) + req = request_builder.build_query_request(query_statment) exp_opts: QueryOptionsTransformedKwargs = {} assert req.options == exp_opts assert req.max_retries == (max_retries if max_retries is not None else 7) @@ -128,7 +128,7 @@ def test_options_named_parameters( ) -> None: params: Dict[str, JSONType] = {'foo': 'bar', 'baz': 1, 'quz': False} q_opts = QueryOptions(named_parameters=params) - req = request_builder.build_base_query_request(query_statment, q_opts) + req = request_builder.build_query_request(query_statment, q_opts) exp_opts: QueryOptionsTransformedKwargs = {'named_parameters': params} assert req.options == exp_opts query_ctx.validate_query_context(req.body) @@ -138,7 +138,7 @@ def test_options_named_parameters_kwargs( ) -> None: params: Dict[str, JSONType] = {'foo': 'bar', 'baz': 1, 'quz': False} kwargs: QueryOptionsKwargs = {'named_parameters': params} - req = request_builder.build_base_query_request(query_statment, **kwargs) + req = request_builder.build_query_request(query_statment, **kwargs) exp_opts: QueryOptionsTransformedKwargs = {'named_parameters': params} assert req.options == exp_opts query_ctx.validate_query_context(req.body) @@ 
-148,7 +148,7 @@ def test_options_positional_parameters( ) -> None: params: List[JSONType] = ['foo', 'bar', 1, False] q_opts = QueryOptions(positional_parameters=params) - req = request_builder.build_base_query_request(query_statment, q_opts) + req = request_builder.build_query_request(query_statment, q_opts) exp_opts: QueryOptionsTransformedKwargs = {'positional_parameters': params} assert req.options == exp_opts query_ctx.validate_query_context(req.body) @@ -158,7 +158,7 @@ def test_options_positional_parameters_kwargs( ) -> None: params: List[JSONType] = ['foo', 'bar', 1, False] kwargs: QueryOptionsKwargs = {'positional_parameters': params} - req = request_builder.build_base_query_request(query_statment, **kwargs) + req = request_builder.build_query_request(query_statment, **kwargs) exp_opts: QueryOptionsTransformedKwargs = {'positional_parameters': params} assert req.options == exp_opts query_ctx.validate_query_context(req.body) @@ -167,7 +167,7 @@ def test_options_raw(self, query_statment: str, request_builder: _RequestBuilder pos_params: List[JSONType] = ['foo', 'bar', 1, False] params: Dict[str, Any] = {'readonly': True, 'positional_params': pos_params} q_opts = QueryOptions(raw=params) - req = request_builder.build_base_query_request(query_statment, q_opts) + req = request_builder.build_query_request(query_statment, q_opts) exp_opts: QueryOptionsTransformedKwargs = {'raw': params} assert req.options == exp_opts query_ctx.validate_query_context(req.body) @@ -178,7 +178,7 @@ def test_options_raw_kwargs( pos_params: List[JSONType] = ['foo', 'bar', 1, False] params: Dict[str, Any] = {'readonly': True, 'positional_params': pos_params} kwargs: QueryOptionsKwargs = {'raw': params} - req = request_builder.build_base_query_request(query_statment, **kwargs) + req = request_builder.build_query_request(query_statment, **kwargs) exp_opts: QueryOptionsTransformedKwargs = {'raw': params} assert req.options == exp_opts query_ctx.validate_query_context(req.body) @@ -187,7 
+187,7 @@ def test_options_readonly( self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext ) -> None: q_opts = QueryOptions(readonly=True) - req = request_builder.build_base_query_request(query_statment, q_opts) + req = request_builder.build_query_request(query_statment, q_opts) exp_opts: QueryOptionsTransformedKwargs = {'readonly': True} assert req.options == exp_opts query_ctx.validate_query_context(req.body) @@ -196,7 +196,7 @@ def test_options_readonly_kwargs( self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext ) -> None: kwargs: QueryOptionsKwargs = {'readonly': True} - req = request_builder.build_base_query_request(query_statment, **kwargs) + req = request_builder.build_query_request(query_statment, **kwargs) exp_opts: QueryOptionsTransformedKwargs = {'readonly': True} assert req.options == exp_opts query_ctx.validate_query_context(req.body) @@ -207,7 +207,7 @@ def test_options_scan_consistency( from couchbase_analytics.query import QueryScanConsistency q_opts = QueryOptions(scan_consistency=QueryScanConsistency.REQUEST_PLUS) - req = request_builder.build_base_query_request(query_statment, q_opts) + req = request_builder.build_query_request(query_statment, q_opts) exp_opts: QueryOptionsTransformedKwargs = {'scan_consistency': QueryScanConsistency.REQUEST_PLUS.value} assert req.options == exp_opts query_ctx.validate_query_context(req.body) @@ -218,7 +218,7 @@ def test_options_scan_consistency_kwargs( from couchbase_analytics.query import QueryScanConsistency kwargs: QueryOptionsKwargs = {'scan_consistency': QueryScanConsistency.REQUEST_PLUS} - req = request_builder.build_base_query_request(query_statment, **kwargs) + req = request_builder.build_query_request(query_statment, **kwargs) exp_opts: QueryOptionsTransformedKwargs = {'scan_consistency': QueryScanConsistency.REQUEST_PLUS.value} assert req.options == exp_opts query_ctx.validate_query_context(req.body) @@ -227,7 +227,7 @@ def 
test_options_timeout( self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext ) -> None: q_opts = QueryOptions(timeout=timedelta(seconds=20)) - req = request_builder.build_base_query_request(query_statment, q_opts) + req = request_builder.build_query_request(query_statment, q_opts) exp_opts: QueryOptionsTransformedKwargs = {'timeout': 20.0} assert req.options == exp_opts # NOTE: we add time to the server timeout to ensure a client side timeout @@ -238,7 +238,7 @@ def test_options_timeout_kwargs( self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext ) -> None: kwargs: QueryOptionsKwargs = {'timeout': timedelta(seconds=20)} - req = request_builder.build_base_query_request(query_statment, **kwargs) + req = request_builder.build_query_request(query_statment, **kwargs) exp_opts: QueryOptionsTransformedKwargs = {'timeout': 20.0} assert req.options == exp_opts # NOTE: we add time to the server timeout to ensure a client side timeout @@ -248,14 +248,14 @@ def test_options_timeout_kwargs( def test_options_timeout_must_be_positive(self, query_statment: str, request_builder: _RequestBuilder) -> None: q_opts = QueryOptions(timeout=timedelta(seconds=-1)) with pytest.raises(ValueError): - request_builder.build_base_query_request(query_statment, q_opts) + request_builder.build_query_request(query_statment, q_opts) def test_options_timeout_must_be_positive_kwargs( self, query_statment: str, request_builder: _RequestBuilder ) -> None: kwargs: QueryOptionsKwargs = {'timeout': timedelta(seconds=-1)} with pytest.raises(ValueError): - request_builder.build_base_query_request(query_statment, **kwargs) + request_builder.build_query_request(query_statment, **kwargs) class ClusterQueryOptionsTests(QueryOptionsTestSuite): diff --git a/couchbase_analytics/tests/start_query_integration_t.py b/couchbase_analytics/tests/start_query_integration_t.py new file mode 100644 index 0000000..90d45ab --- /dev/null +++ 
b/couchbase_analytics/tests/start_query_integration_t.py @@ -0,0 +1,390 @@ +# Copyright 2016-2026. Couchbase, Inc. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +import json +from datetime import timedelta +from typing import Any, Dict + +import pytest + +from couchbase_analytics.common.request import RequestState +from couchbase_analytics.deserializer import PassthroughDeserializer +from couchbase_analytics.errors import AnalyticsError, QueryError, QueryNotFoundError, TimeoutError +from couchbase_analytics.options import FetchResultsOptions, StartQueryOptions +from couchbase_analytics.protocol.query_handle import BlockingQueryHandle, BlockingQueryStatus +from tests import YieldFixture +from tests.environments.base_environment import BlockingTestEnvironment + + +class QueryTestSuite: + TEST_MANIFEST = [ + 'test_cancel_prior_iterating', + 'test_cancel_while_iterating', + 'test_query_metadata', + 'test_query_metadata_not_available', + 'test_query_named_parameters', + 'test_query_named_parameters_no_options', + 'test_query_named_parameters_override', + 'test_query_passthrough_deserializer', + 'test_query_positional_params', + 'test_query_positional_params_no_option', + 'test_query_positional_params_override', + 'test_query_raises_exception_prior_to_iterating', + 'test_query_raw_options', + 'test_query_results', + 'test_query_status_not_found', + 'test_query_status_prior_to_results', + 
'test_query_timeout', + ] + + @pytest.fixture(scope='class') + def query_statement_limit2(self, test_env: BlockingTestEnvironment) -> str: + if test_env.use_scope: + return f'SELECT * FROM {test_env.collection_name} LIMIT 2;' + else: + return f'SELECT * FROM {test_env.fqdn} LIMIT 2;' + + @pytest.fixture(scope='class') + def query_statement_pos_params_limit2(self, test_env: BlockingTestEnvironment) -> str: + if test_env.use_scope: + return f'SELECT * FROM {test_env.collection_name} WHERE country = $1 LIMIT 2;' + else: + return f'SELECT * FROM {test_env.fqdn} WHERE country = $1 LIMIT 2;' + + @pytest.fixture(scope='class') + def query_statement_named_params_limit2(self, test_env: BlockingTestEnvironment) -> str: + if test_env.use_scope: + return f'SELECT * FROM {test_env.collection_name} WHERE country = $country LIMIT 2;' + else: + return f'SELECT * FROM {test_env.fqdn} WHERE country = $country LIMIT 2;' + + @pytest.fixture(scope='class') + def query_statement_limit5(self, test_env: BlockingTestEnvironment) -> str: + if test_env.use_scope: + return f'SELECT * FROM {test_env.collection_name} LIMIT 5;' + else: + return f'SELECT * FROM {test_env.fqdn} LIMIT 5;' + + def test_cancel_prior_iterating(self, test_env: BlockingTestEnvironment) -> None: + statement = 'FROM range(0, 100000) AS r SELECT *' + q_handle = test_env.cluster_or_scope.start_query(statement) + assert isinstance(q_handle, BlockingQueryHandle) + q_handle.cancel() + + # it takes a moment for the cancellation to propagate, so we'll retry fetching + # status a few times until we get an exception + BlockingTestEnvironment.try_n_times_till_exception(10, 2, q_handle.fetch_status) + + with pytest.raises(QueryError): + q_handle.fetch_status() + + q_handle.cancel() # should be idempotent and not raise + + def test_cancel_while_iterating( + self, request: pytest.FixtureRequest, test_env: BlockingTestEnvironment, query_statement_limit5: str + ) -> None: + q_handle = 
test_env.cluster_or_scope.start_query(query_statement_limit5) + result_handle, result = test_env.wait_for_query_results(q_handle) + assert result is not None + request.addfinalizer(result_handle.discard_results) + rows = [] + count = 0 + for row in result.rows(): + if count == 2: + result.cancel() + assert row is not None + rows.append(row) + count += 1 + + assert len(rows) == count + expected_state = RequestState.Cancelled + assert result._http_response._request_context.request_state == expected_state + with pytest.raises(RuntimeError): + result.metadata() + test_env.assert_streaming_response_state(result) + + def test_query_metadata( + self, request: pytest.FixtureRequest, test_env: BlockingTestEnvironment, query_statement_limit5: str + ) -> None: + q_handle = test_env.cluster_or_scope.start_query(query_statement_limit5) + result_handle, result = test_env.wait_for_query_results(q_handle) + assert result is not None + request.addfinalizer(result_handle.discard_results) + + expected_count = 5 + test_env.assert_rows(result, expected_count) + + metadata = result.metadata() + + assert len(metadata.warnings()) == 0 + assert len(metadata.request_id()) > 0 + + metrics = metadata.metrics() + + assert metrics.result_size() > 0 + assert metrics.result_count() == expected_count + assert metrics.processed_objects() > 0 + # sometimes we have a negative elapsed time which we set to 0 + assert metrics.elapsed_time() >= timedelta(0) + assert metrics.execution_time() > timedelta(0) + test_env.assert_streaming_response_state(result) + result_handle.discard_results() + + def test_query_metadata_not_available( + self, request: pytest.FixtureRequest, test_env: BlockingTestEnvironment, query_statement_limit5: str + ) -> None: + q_handle = test_env.cluster_or_scope.start_query(query_statement_limit5) + result_handle, result = test_env.wait_for_query_results(q_handle) + assert result is not None + request.addfinalizer(result_handle.discard_results) + + with pytest.raises(RuntimeError): + 
result.metadata() + + # Read one row + next(iter(result.rows())) + + with pytest.raises(RuntimeError): + result.metadata() + + # Iterate the rest of the rows + rows = list(result.rows()) + assert len(rows) == 4 + + metadata = result.metadata() + assert len(metadata.warnings()) == 0 + assert len(metadata.request_id()) > 0 + test_env.assert_streaming_response_state(result) + + def test_query_named_parameters( + self, + request: pytest.FixtureRequest, + test_env: BlockingTestEnvironment, + query_statement_named_params_limit2: str, + ) -> None: + named_parameters: Dict[str, Any] = {'country': 'United States'} + q_handle = test_env.cluster_or_scope.start_query( + query_statement_named_params_limit2, StartQueryOptions(named_parameters=named_parameters) + ) + result_handle, result = test_env.wait_for_query_results(q_handle) + assert result is not None + request.addfinalizer(result_handle.discard_results) + test_env.assert_rows(result, 2) + test_env.assert_streaming_response_state(result) + + def test_query_named_parameters_no_options( + self, + request: pytest.FixtureRequest, + test_env: BlockingTestEnvironment, + query_statement_named_params_limit2: str, + ) -> None: + q_handle = test_env.cluster_or_scope.start_query(query_statement_named_params_limit2, country='United States') + result_handle, result = test_env.wait_for_query_results(q_handle) + assert result is not None + request.addfinalizer(result_handle.discard_results) + test_env.assert_rows(result, 2) + test_env.assert_streaming_response_state(result) + + def test_query_named_parameters_override( + self, + request: pytest.FixtureRequest, + test_env: BlockingTestEnvironment, + query_statement_named_params_limit2: str, + ) -> None: + q_handle = test_env.cluster_or_scope.start_query( + query_statement_named_params_limit2, + StartQueryOptions(named_parameters={'country': 'abcdefg'}), + country='United States', + ) + result_handle, result = test_env.wait_for_query_results(q_handle) + assert result is not None + 
request.addfinalizer(result_handle.discard_results) + test_env.assert_rows(result, 2) + test_env.assert_streaming_response_state(result) + + def test_query_passthrough_deserializer( + self, request: pytest.FixtureRequest, test_env: BlockingTestEnvironment + ) -> None: + statement = 'FROM range(0, 10) AS num SELECT *' + q_handle = test_env.cluster_or_scope.start_query(statement) + result_handle, _ = test_env.wait_for_query_results(q_handle, return_only_result_handle=True) + request.addfinalizer(result_handle.discard_results) + result = result_handle.fetch_results(FetchResultsOptions(deserializer=PassthroughDeserializer())) + for idx, row in enumerate(result.rows()): + assert isinstance(row, bytes) + assert json.loads(row) == {'num': idx} + test_env.assert_streaming_response_state(result) + result_handle.discard_results() + + def test_query_positional_params( + self, request: pytest.FixtureRequest, test_env: BlockingTestEnvironment, query_statement_pos_params_limit2: str + ) -> None: + q_handle = test_env.cluster_or_scope.start_query( + query_statement_pos_params_limit2, StartQueryOptions(positional_parameters=['United States']) + ) + result_handle, result = test_env.wait_for_query_results(q_handle) + assert result is not None + request.addfinalizer(result_handle.discard_results) + test_env.assert_rows(result, 2) + test_env.assert_streaming_response_state(result) + + def test_query_positional_params_no_option( + self, request: pytest.FixtureRequest, test_env: BlockingTestEnvironment, query_statement_pos_params_limit2: str + ) -> None: + q_handle = test_env.cluster_or_scope.start_query(query_statement_pos_params_limit2, 'United States') + result_handle, result = test_env.wait_for_query_results(q_handle) + assert result is not None + request.addfinalizer(result_handle.discard_results) + test_env.assert_rows(result, 2) + test_env.assert_streaming_response_state(result) + + def test_query_positional_params_override( + self, request: pytest.FixtureRequest, test_env: 
BlockingTestEnvironment, query_statement_pos_params_limit2: str + ) -> None: + q_handle = test_env.cluster_or_scope.start_query( + query_statement_pos_params_limit2, StartQueryOptions(positional_parameters=['abcdefg']), 'United States' + ) + result_handle, result = test_env.wait_for_query_results(q_handle) + assert result is not None + request.addfinalizer(result_handle.discard_results) + test_env.assert_rows(result, 2) + test_env.assert_streaming_response_state(result) + + def test_query_raises_exception_prior_to_iterating(self, test_env: BlockingTestEnvironment) -> None: + statement = "I'm not N1QL!" + with pytest.raises(QueryError): + test_env.cluster_or_scope.start_query(statement) + + def test_query_raw_options( + self, request: pytest.FixtureRequest, test_env: BlockingTestEnvironment, query_statement_pos_params_limit2: str + ) -> None: + # via raw, we should be able to pass any option + # if using named params, need to match full name param in query + # which is different for when we pass in name_parameters via their specific + # query option (i.e. 
include the $ when using raw) + if test_env.use_scope: + statement = f'SELECT * FROM {test_env.collection_name} WHERE country = $country LIMIT $1;' + else: + statement = f'SELECT * FROM {test_env.fqdn} WHERE country = $country LIMIT $1;' + + q_handle = test_env.cluster_or_scope.start_query( + statement, StartQueryOptions(raw={'$country': 'United States', 'args': [2]}) + ) + result_handle, result = test_env.wait_for_query_results(q_handle) + assert result is not None + request.addfinalizer(result_handle.discard_results) + + test_env.assert_rows(result, 2) + + q_handle = test_env.cluster_or_scope.start_query( + query_statement_pos_params_limit2, StartQueryOptions(raw={'args': ['United States']}) + ) + result_handle1, result = test_env.wait_for_query_results(q_handle) + assert result is not None + request.addfinalizer(result_handle1.discard_results) + test_env.assert_rows(result, 2) + test_env.assert_streaming_response_state(result) + + def test_query_results(self, test_env: BlockingTestEnvironment, query_statement_limit5: str) -> None: + q_handle = test_env.cluster_or_scope.start_query(query_statement_limit5) + result_handle, _ = test_env.wait_for_query_results(q_handle, return_only_result_handle=True) + result = result_handle.fetch_results() + test_env.assert_rows(result, 5) + # fetch results it again + result = result_handle.fetch_results() + test_env.assert_rows(result, 5) + # now discard results + result_handle.discard_results() + # fetching results after discarding should raise + with pytest.raises(QueryNotFoundError): + result_handle.fetch_results() + + def test_query_status_not_found(self, test_env: BlockingTestEnvironment) -> None: + statement = 'SELECT sleep("some value", 1000) AS some_field;' + q_handle = test_env.cluster_or_scope.start_query(statement) + + result_handle, _ = test_env.wait_for_query_results(q_handle, return_only_result_handle=True) + result_handle.discard_results() + + with pytest.raises(QueryNotFoundError): + q_handle.fetch_status() + + 
def test_query_status_prior_to_results(self, test_env: BlockingTestEnvironment) -> None: + statement = 'SELECT sleep("some value", 1000) AS some_field;' + q_handle = test_env.cluster_or_scope.start_query(statement) + assert isinstance(q_handle, BlockingQueryHandle) + q_status = q_handle.fetch_status() + assert isinstance(q_status, BlockingQueryStatus) + assert q_status.results_ready() is False + with pytest.raises(AnalyticsError): + q_status.result_handle() + + # lets clean up the query + result_handle, _ = test_env.wait_for_query_results(q_handle, return_only_result_handle=True) + result_handle.discard_results() + + def test_query_timeout(self, test_env: BlockingTestEnvironment) -> None: + statement = 'SELECT sleep("some value", 10000) AS some_field;' + q_handle = test_env.cluster_or_scope.start_query(statement, StartQueryOptions(timeout=timedelta(seconds=2))) + BlockingTestEnvironment.try_n_times_till_exception(10, 2, q_handle.fetch_status) + with pytest.raises(TimeoutError): + q_handle.fetch_status() + + +class ClusterStartQueryTests(QueryTestSuite): + @pytest.fixture(scope='class', autouse=True) + def validate_test_manifest(self) -> None: + def valid_test_method(meth: str) -> bool: + attr = getattr(ClusterStartQueryTests, meth) + return callable(attr) and not meth.startswith('__') and meth.startswith('test') + + method_list = [meth for meth in dir(ClusterStartQueryTests) if valid_test_method(meth)] + test_list = set(QueryTestSuite.TEST_MANIFEST).symmetric_difference(method_list) + if test_list: + pytest.fail(f'Test manifest invalid. 
Missing/extra tests: {test_list}.') + + @pytest.fixture(scope='class', name='test_env') + def couchbase_test_environment( + self, sync_test_env: BlockingTestEnvironment + ) -> YieldFixture[BlockingTestEnvironment]: + sync_test_env.setup() + yield sync_test_env + sync_test_env.teardown() + + +class ScopeStartQueryTests(QueryTestSuite): + @pytest.fixture(scope='class', autouse=True) + def validate_test_manifest(self) -> None: + def valid_test_method(meth: str) -> bool: + attr = getattr(ScopeStartQueryTests, meth) + return callable(attr) and not meth.startswith('__') and meth.startswith('test') + + method_list = [meth for meth in dir(ScopeStartQueryTests) if valid_test_method(meth)] + test_list = set(QueryTestSuite.TEST_MANIFEST).symmetric_difference(method_list) + if test_list: + pytest.fail(f'Test manifest invalid. Missing/extra tests: {test_list}.') + + @pytest.fixture(scope='class', name='test_env') + def couchbase_test_environment( + self, sync_test_env: BlockingTestEnvironment + ) -> YieldFixture[BlockingTestEnvironment]: + sync_test_env.setup() + test_env = sync_test_env.enable_scope() + yield test_env + test_env.disable_scope() + test_env.teardown() diff --git a/couchbase_analytics/tests/start_query_options_t.py b/couchbase_analytics/tests/start_query_options_t.py new file mode 100644 index 0000000..14e8b5a --- /dev/null +++ b/couchbase_analytics/tests/start_query_options_t.py @@ -0,0 +1,274 @@ +# Copyright 2016-2025. Couchbase, Inc. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +from dataclasses import dataclass +from datetime import timedelta +from typing import Any, Dict, List, Optional, Union + +import pytest + +from couchbase_analytics import JSONType +from couchbase_analytics.credential import Credential +from couchbase_analytics.options import StartQueryOptions, StartQueryOptionsKwargs +from couchbase_analytics.protocol._core.client_adapter import _ClientAdapter +from couchbase_analytics.protocol._core.request import _RequestBuilder +from couchbase_analytics.protocol.options import StartQueryOptionsTransformedKwargs + + +@dataclass +class QueryContext: + database_name: Optional[str] = None + scope_name: Optional[str] = None + + def validate_query_context(self, body: Dict[str, Union[str, object]]) -> None: + if self.database_name is None or self.scope_name is None: + with pytest.raises(KeyError): + body['query_context'] + else: + assert body['query_context'] == f'default:`{self.database_name}`.`{self.scope_name}`' + + +class StartQueryOptionsTestSuite: + TEST_MANIFEST = [ + 'test_options_max_retries', + 'test_options_max_retries_kwargs', + 'test_options_named_parameters', + 'test_options_named_parameters_kwargs', + 'test_options_positional_parameters', + 'test_options_positional_parameters_kwargs', + 'test_options_raw', + 'test_options_raw_kwargs', + 'test_options_readonly', + 'test_options_readonly_kwargs', + 'test_options_scan_consistency', + 'test_options_scan_consistency_kwargs', + 'test_options_timeout', + 'test_options_timeout_kwargs', + 'test_options_timeout_must_be_positive', + 'test_options_timeout_must_be_positive_kwargs', + ] + + @pytest.fixture(scope='class') + def query_statment(self) -> str: + return 'SELECT * FROM default' + + @pytest.mark.parametrize('max_retries', [5, 10, 0, None]) + def test_options_max_retries( + self, query_statment: str, request_builder: 
_RequestBuilder, query_ctx: QueryContext, max_retries: Optional[int] + ) -> None: + if max_retries is not None: + q_opts = StartQueryOptions(max_retries=max_retries) + req = request_builder.build_start_query_request(query_statment, q_opts) + else: + req = request_builder.build_start_query_request(query_statment) + exp_opts: StartQueryOptionsTransformedKwargs = {} + assert req.options == exp_opts + assert req.max_retries == (max_retries if max_retries is not None else 7) + query_ctx.validate_query_context(req.body) + + @pytest.mark.parametrize('max_retries', [5, 10, 0, None]) + def test_options_max_retries_kwargs( + self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext, max_retries: Optional[int] + ) -> None: + if max_retries is not None: + kwargs: StartQueryOptionsKwargs = {'max_retries': max_retries} + req = request_builder.build_start_query_request(query_statment, **kwargs) + else: + req = request_builder.build_start_query_request(query_statment) + exp_opts: StartQueryOptionsTransformedKwargs = {} + assert req.options == exp_opts + assert req.max_retries == (max_retries if max_retries is not None else 7) + query_ctx.validate_query_context(req.body) + + def test_options_named_parameters( + self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext + ) -> None: + params: Dict[str, JSONType] = {'foo': 'bar', 'baz': 1, 'quz': False} + q_opts = StartQueryOptions(named_parameters=params) + req = request_builder.build_start_query_request(query_statment, q_opts) + exp_opts: StartQueryOptionsTransformedKwargs = {'named_parameters': params} + assert req.options == exp_opts + query_ctx.validate_query_context(req.body) + + def test_options_named_parameters_kwargs( + self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext + ) -> None: + params: Dict[str, JSONType] = {'foo': 'bar', 'baz': 1, 'quz': False} + kwargs: StartQueryOptionsKwargs = {'named_parameters': params} + req = 
request_builder.build_start_query_request(query_statment, **kwargs) + exp_opts: StartQueryOptionsTransformedKwargs = {'named_parameters': params} + assert req.options == exp_opts + query_ctx.validate_query_context(req.body) + + def test_options_positional_parameters( + self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext + ) -> None: + params: List[JSONType] = ['foo', 'bar', 1, False] + q_opts = StartQueryOptions(positional_parameters=params) + req = request_builder.build_start_query_request(query_statment, q_opts) + exp_opts: StartQueryOptionsTransformedKwargs = {'positional_parameters': params} + assert req.options == exp_opts + query_ctx.validate_query_context(req.body) + + def test_options_positional_parameters_kwargs( + self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext + ) -> None: + params: List[JSONType] = ['foo', 'bar', 1, False] + kwargs: StartQueryOptionsKwargs = {'positional_parameters': params} + req = request_builder.build_start_query_request(query_statment, **kwargs) + exp_opts: StartQueryOptionsTransformedKwargs = {'positional_parameters': params} + assert req.options == exp_opts + query_ctx.validate_query_context(req.body) + + def test_options_raw(self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext) -> None: + pos_params: List[JSONType] = ['foo', 'bar', 1, False] + params: Dict[str, Any] = {'readonly': True, 'positional_params': pos_params} + q_opts = StartQueryOptions(raw=params) + req = request_builder.build_start_query_request(query_statment, q_opts) + exp_opts: StartQueryOptionsTransformedKwargs = {'raw': params} + assert req.options == exp_opts + query_ctx.validate_query_context(req.body) + + def test_options_raw_kwargs( + self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext + ) -> None: + pos_params: List[JSONType] = ['foo', 'bar', 1, False] + params: Dict[str, Any] = {'readonly': True, 'positional_params': 
pos_params} + kwargs: StartQueryOptionsKwargs = {'raw': params} + req = request_builder.build_start_query_request(query_statment, **kwargs) + exp_opts: StartQueryOptionsTransformedKwargs = {'raw': params} + assert req.options == exp_opts + query_ctx.validate_query_context(req.body) + + def test_options_readonly( + self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext + ) -> None: + q_opts = StartQueryOptions(readonly=True) + req = request_builder.build_start_query_request(query_statment, q_opts) + exp_opts: StartQueryOptionsTransformedKwargs = {'readonly': True} + assert req.options == exp_opts + query_ctx.validate_query_context(req.body) + + def test_options_readonly_kwargs( + self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext + ) -> None: + kwargs: StartQueryOptionsKwargs = {'readonly': True} + req = request_builder.build_start_query_request(query_statment, **kwargs) + exp_opts: StartQueryOptionsTransformedKwargs = {'readonly': True} + assert req.options == exp_opts + query_ctx.validate_query_context(req.body) + + def test_options_scan_consistency( + self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext + ) -> None: + from couchbase_analytics.query import QueryScanConsistency + + q_opts = StartQueryOptions(scan_consistency=QueryScanConsistency.REQUEST_PLUS) + req = request_builder.build_start_query_request(query_statment, q_opts) + exp_opts: StartQueryOptionsTransformedKwargs = {'scan_consistency': QueryScanConsistency.REQUEST_PLUS.value} + assert req.options == exp_opts + query_ctx.validate_query_context(req.body) + + def test_options_scan_consistency_kwargs( + self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext + ) -> None: + from couchbase_analytics.query import QueryScanConsistency + + kwargs: StartQueryOptionsKwargs = {'scan_consistency': QueryScanConsistency.REQUEST_PLUS} + req = 
request_builder.build_start_query_request(query_statment, **kwargs) + exp_opts: StartQueryOptionsTransformedKwargs = {'scan_consistency': QueryScanConsistency.REQUEST_PLUS.value} + assert req.options == exp_opts + query_ctx.validate_query_context(req.body) + + def test_options_timeout( + self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext + ) -> None: + q_opts = StartQueryOptions(timeout=timedelta(seconds=20)) + req = request_builder.build_start_query_request(query_statment, q_opts) + exp_opts: StartQueryOptionsTransformedKwargs = {'timeout': 20.0} + assert req.options == exp_opts + # NOTE: we add time to the server timeout to ensure a client side timeout + assert req.body['timeout'] == '25000.0ms' + query_ctx.validate_query_context(req.body) + + def test_options_timeout_kwargs( + self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext + ) -> None: + kwargs: StartQueryOptionsKwargs = {'timeout': timedelta(seconds=20)} + req = request_builder.build_start_query_request(query_statment, **kwargs) + exp_opts: StartQueryOptionsTransformedKwargs = {'timeout': 20.0} + assert req.options == exp_opts + # NOTE: we add time to the server timeout to ensure a client side timeout + assert req.body['timeout'] == '25000.0ms' + query_ctx.validate_query_context(req.body) + + def test_options_timeout_must_be_positive(self, query_statment: str, request_builder: _RequestBuilder) -> None: + q_opts = StartQueryOptions(timeout=timedelta(seconds=-1)) + with pytest.raises(ValueError): + request_builder.build_start_query_request(query_statment, q_opts) + + def test_options_timeout_must_be_positive_kwargs( + self, query_statment: str, request_builder: _RequestBuilder + ) -> None: + kwargs: StartQueryOptionsKwargs = {'timeout': timedelta(seconds=-1)} + with pytest.raises(ValueError): + request_builder.build_start_query_request(query_statment, **kwargs) + + +class ClusterStartQueryOptionsTests(StartQueryOptionsTestSuite): + 
@pytest.fixture(scope='class', autouse=True) + def validate_test_manifest(self) -> None: + def valid_test_method(meth: str) -> bool: + attr = getattr(ClusterStartQueryOptionsTests, meth) + return callable(attr) and not meth.startswith('__') and meth.startswith('test') + + method_list = [meth for meth in dir(ClusterStartQueryOptionsTests) if valid_test_method(meth)] + test_list = set(StartQueryOptionsTestSuite.TEST_MANIFEST).symmetric_difference(method_list) + if test_list: + pytest.fail(f'Test manifest invalid. Missing/extra tests: {test_list}.') + + @pytest.fixture(scope='class', name='query_ctx') + def query_context(self) -> QueryContext: + return QueryContext() + + @pytest.fixture(scope='class') + def request_builder(self) -> _RequestBuilder: + cred = Credential.from_username_and_password('Administrator', 'password') + return _RequestBuilder(_ClientAdapter('https://localhost', cred)) + + +class ScopeStartQueryOptionsTests(StartQueryOptionsTestSuite): + @pytest.fixture(scope='class', autouse=True) + def validate_test_manifest(self) -> None: + def valid_test_method(meth: str) -> bool: + attr = getattr(ScopeStartQueryOptionsTests, meth) + return callable(attr) and not meth.startswith('__') and meth.startswith('test') + + method_list = [meth for meth in dir(ScopeStartQueryOptionsTests) if valid_test_method(meth)] + test_list = set(StartQueryOptionsTestSuite.TEST_MANIFEST).symmetric_difference(method_list) + if test_list: + pytest.fail(f'Test manifest invalid. 
Missing/extra tests: {test_list}.') + + @pytest.fixture(scope='class', name='query_ctx') + def query_context(self) -> QueryContext: + return QueryContext('test-database', 'test-scope') + + @pytest.fixture(scope='class') + def request_builder(self) -> _RequestBuilder: + cred = Credential.from_username_and_password('Administrator', 'password') + return _RequestBuilder(_ClientAdapter('https://localhost', cred), 'test-database', 'test-scope') diff --git a/couchbase_analytics/tests/test_server_t.py b/couchbase_analytics/tests/test_server_t.py index 3d43a77..844b5a0 100644 --- a/couchbase_analytics/tests/test_server_t.py +++ b/couchbase_analytics/tests/test_server_t.py @@ -18,7 +18,7 @@ from concurrent.futures import Future from datetime import timedelta -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING import pytest @@ -133,13 +133,8 @@ def test_error_retriable_http503(self, test_env: BlockingTestEnvironment, analyt statement = 'SELECT "Hello, data!" AS greeting' allowed_retries = 5 q_opts = QueryOptions(max_retries=allowed_retries, timeout=timedelta(seconds=10)) - ex: Union[pytest.ExceptionInfo[AnalyticsError], pytest.ExceptionInfo[QueryError]] - if analytics_error: - with pytest.raises(QueryError) as ex: - test_env.cluster_or_scope.execute_query(statement, q_opts) - else: - with pytest.raises(AnalyticsError) as ex: - test_env.cluster_or_scope.execute_query(statement, q_opts) + with pytest.raises(AnalyticsError) as ex: + test_env.cluster_or_scope.execute_query(statement, q_opts) test_env.assert_error_context_num_attempts(allowed_retries + 1, ex.value._context) test_env.assert_error_context_contains_last_dispatch(ex.value._context) diff --git a/docs/acouchbase_analytics_api/acouchbase_analytics.rst b/docs/acouchbase_analytics_api/acouchbase_analytics.rst index 1fef7ee..ade12e8 100644 --- a/docs/acouchbase_analytics_api/acouchbase_analytics.rst +++ b/docs/acouchbase_analytics_api/acouchbase_analytics.rst @@ -12,6 +12,9 @@ Asynchronous API 
:doc:`query` API reference for query (SQL++) operations. +:doc:`query_handle` + API reference for server async requests operations. + :doc:`options` API reference for operation options. @@ -42,6 +45,7 @@ Asynchronous API acouchbase_analytics_core credential query + query_handle options results errors diff --git a/docs/acouchbase_analytics_api/acouchbase_analytics_core.rst b/docs/acouchbase_analytics_api/acouchbase_analytics_core.rst index f07415d..3acd714 100644 --- a/docs/acouchbase_analytics_api/acouchbase_analytics_core.rst +++ b/docs/acouchbase_analytics_api/acouchbase_analytics_core.rst @@ -22,6 +22,12 @@ AsyncCluster See :ref:`AsyncCluster Overloads` for details on overloaded methods. .. automethod:: execute_query + + .. important:: + See :ref:`AsyncCluster Overloads` for details on overloaded methods. + + .. automethod:: start_query + .. automethod:: shutdown @@ -34,6 +40,7 @@ AsyncDatabase .. autoproperty:: name .. automethod:: scope + AsyncScope ============== @@ -46,3 +53,8 @@ AsyncScope See :ref:`AsyncScope Overloads` for details on overloaded methods. .. automethod:: execute_query + + .. important:: + See :ref:`AsyncScope Overloads` for details on overloaded methods. + + .. automethod:: start_query diff --git a/docs/acouchbase_analytics_api/errors.rst b/docs/acouchbase_analytics_api/errors.rst index 4ffd789..ec88a9f 100644 --- a/docs/acouchbase_analytics_api/errors.rst +++ b/docs/acouchbase_analytics_api/errors.rst @@ -26,6 +26,11 @@ QueryError .. autoproperty:: code .. autoproperty:: server_message +QueryNotFoundError +++++++++++++++++++++++++++++++++ +.. autoclass:: QueryNotFoundError + :no-index: + TimeoutError ++++++++++++++++++++++++++++++++ .. autoclass:: TimeoutError diff --git a/docs/acouchbase_analytics_api/options.rst b/docs/acouchbase_analytics_api/options.rst index e120bf5..0c05dfd 100644 --- a/docs/acouchbase_analytics_api/options.rst +++ b/docs/acouchbase_analytics_api/options.rst @@ -33,6 +33,17 @@ QueryOptions .. 
autoclass:: QueryOptions :no-index: +StartQueryOptions +++++++++++++++++++++++ +.. autoclass:: StartQueryOptions + :no-index: + +FetchResultsOptions +++++++++++++++++++++++ +.. autoclass:: FetchResultsOptions + :no-index: + + Option TypeDict Classes ========================= @@ -63,3 +74,17 @@ QueryOptionsKwargs :no-index: :members: :undoc-members: + +StartQueryOptionsKwargs ++++++++++++++++++++++++ +.. autoclass:: StartQueryOptionsKwargs + :no-index: + :members: + :undoc-members: + +FetchResultsOptionsKwargs ++++++++++++++++++++++++++ +.. autoclass:: FetchResultsOptionsKwargs + :no-index: + :members: + :undoc-members: diff --git a/docs/acouchbase_analytics_api/overloads/async_cluster_overloads.rst b/docs/acouchbase_analytics_api/overloads/async_cluster_overloads.rst index 4949db0..a0e2ef9 100644 --- a/docs/acouchbase_analytics_api/overloads/async_cluster_overloads.rst +++ b/docs/acouchbase_analytics_api/overloads/async_cluster_overloads.rst @@ -25,7 +25,7 @@ AsyncCluster execute_query(statement: str, *args: JSONType, **kwargs: str) -> Awaitable[AsyncQueryResult] :no-index: - Executes a query against a Capella analytics cluster. + Executes a query against an Analytics cluster. .. important:: The cancel API is **VOLATILE** and is subject to change at any time. @@ -43,6 +43,30 @@ AsyncCluster :returns: An `Awaitable` is returned. Once the `Awaitable` completes, an instance of a :class:`~acouchbase_analytics.result.AsyncQueryResult` will be available. :rtype: Awaitable[:class:`~acouchbase_analytics.result.AsyncQueryResult`] + .. 
py:method:: start_query(statement: str) -> Awaitable[AsyncQueryHandle] + start_query(statement: str, options: StartQueryOptions) -> Awaitable[AsyncQueryHandle] + start_query(statement: str, **kwargs: StartQueryOptionsKwargs) -> Awaitable[AsyncQueryHandle] + start_query(statement: str, options: StartQueryOptions, **kwargs: StartQueryOptionsKwargs) -> Awaitable[AsyncQueryHandle] + start_query(statement: str, options: StartQueryOptions, *args: JSONType, **kwargs: StartQueryOptionsKwargs) -> Awaitable[AsyncQueryHandle] + start_query(statement: str, options: StartQueryOptions, *args: JSONType, **kwargs: str) -> Awaitable[AsyncQueryHandle] + start_query(statement: str, *args: JSONType, **kwargs: str) -> Awaitable[AsyncQueryHandle] + :no-index: + + Executes a query against an Analytics cluster using the asynchronous server requests API. + + :param statement: The SQL++ statement to execute. + :type statement: str + :param options: Options to set for the query. + :type options: Optional[:class:`~acouchbase_analytics.options.StartQueryOptions`] + :param \*args: Can be used to pass in positional query placeholders. + :type \*args: Optional[:py:type:`~acouchbase_analytics.JSONType`] + :param \*\*kwargs: Keyword arguments that can be used in place or to override provided :class:`~acouchbase_analytics.options.StartQueryOptions`. + Can also be used to pass in named query placeholders. + :type \*\*kwargs: Optional[Union[:class:`~acouchbase_analytics.options.StartQueryOptionsKwargs`, str]] + + :returns: An `Awaitable` is returned. Once the `Awaitable` completes, an instance of a :class:`~acouchbase_analytics.query_handle.AsyncQueryHandle` will be available. + :rtype: Awaitable[:class:`~acouchbase_analytics.query_handle.AsyncQueryHandle`] + .. 
py:method:: create_instance(endpoint: str, credential: Credential) -> AsyncCluster create_instance(endpoint: str, credential: Credential, options: ClusterOptions) -> AsyncCluster create_instance(endpoint: str, credential: Credential, **kwargs: ClusterOptionsKwargs) -> AsyncCluster diff --git a/docs/acouchbase_analytics_api/overloads/async_scope_overloads.rst b/docs/acouchbase_analytics_api/overloads/async_scope_overloads.rst index c84a135..85f6955 100644 --- a/docs/acouchbase_analytics_api/overloads/async_scope_overloads.rst +++ b/docs/acouchbase_analytics_api/overloads/async_scope_overloads.rst @@ -16,16 +16,16 @@ AsyncScope .. py:class:: AysncScope :no-index: - .. py:method:: execute_query(statement: str) -> Future[AsyncQueryResult] - execute_query(statement: str, options: QueryOptions) -> Future[AsyncQueryResult] - execute_query(statement: str, **kwargs: QueryOptionsKwargs) -> Future[AsyncQueryResult] - execute_query(statement: str, options: QueryOptions, **kwargs: QueryOptionsKwargs) -> BlockingQueryResult - execute_query(statement: str, options: QueryOptions, *args: JSONType, **kwargs: QueryOptionsKwargs) -> Future[AsyncQueryResult] - execute_query(statement: str, options: QueryOptions, *args: JSONType, **kwargs: str) -> Future[AsyncQueryResult] - execute_query(statement: str, *args: JSONType, **kwargs: str) -> Future[AsyncQueryResult] + .. 
py:method:: execute_query(statement: str) -> Awaitable[AsyncQueryResult] + execute_query(statement: str, options: QueryOptions) -> Awaitable[AsyncQueryResult] + execute_query(statement: str, **kwargs: QueryOptionsKwargs) -> Awaitable[AsyncQueryResult] + execute_query(statement: str, options: QueryOptions, **kwargs: QueryOptionsKwargs) -> Awaitable[AsyncQueryResult] + execute_query(statement: str, options: QueryOptions, *args: JSONType, **kwargs: QueryOptionsKwargs) -> Awaitable[AsyncQueryResult] + execute_query(statement: str, options: QueryOptions, *args: JSONType, **kwargs: str) -> Awaitable[AsyncQueryResult] + execute_query(statement: str, *args: JSONType, **kwargs: str) -> Awaitable[AsyncQueryResult] :no-index: - Executes a query against a Capella analytics scope. + Executes a query against an analytics scope. .. important:: The cancel API is **VOLATILE** and is subject to change at any time. @@ -40,5 +40,29 @@ AsyncScope Can also be used to pass in named query placeholders. :type \*\*kwargs: Optional[Union[:class:`~acouchbase_analytics.options.QueryOptionsKwargs`, str]] - :returns: A :class:`~asyncio.Future` is returned. Once the :class:`~asyncio.Future` completes, an instance of a :class:`~acouchbase_analytics.result.AsyncQueryResult` will be available. - :rtype: Future[:class:`~acouchbase_analytics.result.AsyncQueryResult`] + :returns: :class:`~couchbase_analytics.result.AsyncQueryResult`: An instance of a :class:`~acouchbase_analytics.result.AsyncQueryResult`. + :rtype: Awaitable[:class:`~acouchbase_analytics.result.AsyncQueryResult`] + + .. 
py:method:: start_query(statement: str) -> Awaitable[AsyncQueryHandle] + start_query(statement: str, options: StartQueryOptions) -> Awaitable[AsyncQueryHandle] + start_query(statement: str, **kwargs: StartQueryOptionsKwargs) -> Awaitable[AsyncQueryHandle] + start_query(statement: str, options: StartQueryOptions, **kwargs: StartQueryOptionsKwargs) -> Awaitable[AsyncQueryHandle] + start_query(statement: str, options: StartQueryOptions, *args: JSONType, **kwargs: StartQueryOptionsKwargs) -> Awaitable[AsyncQueryHandle] + start_query(statement: str, options: StartQueryOptions, *args: JSONType, **kwargs: str) -> Awaitable[AsyncQueryHandle] + start_query(statement: str, *args: JSONType, **kwargs: str) -> Awaitable[AsyncQueryHandle] + :no-index: + + Executes a query against an analytics scope using the asynchronous server requests API. + + :param statement: The SQL++ statement to execute. + :type statement: str + :param options: Options to set for the query. + :type options: Optional[:class:`~acouchbase_analytics.options.StartQueryOptions`] + :param \*args: Can be used to pass in positional query placeholders. + :type \*args: Optional[:py:type:`~acouchbase_analytics.JSONType`] + :param \*\*kwargs: Keyword arguments that can be used in place or to override provided :class:`~acouchbase_analytics.options.StartQueryOptions`. + Can also be used to pass in named query placeholders. + :type \*\*kwargs: Optional[Union[:class:`~acouchbase_analytics.options.StartQueryOptionsKwargs`, str]] + + :returns: :class:`~acouchbase_analytics.query_handle.AsyncQueryHandle`: An instance of a :class:`~acouchbase_analytics.query_handle.AsyncQueryHandle` + :rtype: Awaitable[:class:`~acouchbase_analytics.query_handle.AsyncQueryHandle`] diff --git a/docs/acouchbase_analytics_api/query.rst b/docs/acouchbase_analytics_api/query.rst index 6d9869c..a87be50 100644 --- a/docs/acouchbase_analytics_api/query.rst +++ b/docs/acouchbase_analytics_api/query.rst @@ -24,6 +24,9 @@ Options .. 
autoclass:: QueryOptions :no-index: +.. autoclass:: StartQueryOptions + :no-index: + Results =============== diff --git a/docs/acouchbase_analytics_api/query_handle.rst b/docs/acouchbase_analytics_api/query_handle.rst new file mode 100644 index 0000000..c708143 --- /dev/null +++ b/docs/acouchbase_analytics_api/query_handle.rst @@ -0,0 +1,32 @@ +========================= +Server Async Request API +========================= + +.. contents:: + :local: + +.. module:: acouchbase_analytics.query_handle + +AsyncQueryHandle ++++++++++++++++++++ + +.. py:class:: AsyncQueryHandle + + .. automethod:: fetch_status + .. automethod:: cancel + +AsyncQueryResultHandle +++++++++++++++++++++++++ + +.. py:class:: AsyncQueryResultHandle + + .. automethod:: fetch_results + .. automethod:: discard_results + +AsyncQueryStatus ++++++++++++++++++++ + +.. py:class:: AsyncQueryStatus + + .. automethod:: results_ready + .. automethod:: result_handle diff --git a/docs/couchbase_analytics_api/couchbase_analytics.rst b/docs/couchbase_analytics_api/couchbase_analytics.rst index fc022bc..2ddb737 100644 --- a/docs/couchbase_analytics_api/couchbase_analytics.rst +++ b/docs/couchbase_analytics_api/couchbase_analytics.rst @@ -11,6 +11,9 @@ Synchronous API :doc:`query` API reference for query (SQL++) operations. +:doc:`query_handle` + API reference for server async requests operations. + :doc:`options` API reference for operation options. @@ -41,6 +44,7 @@ Synchronous API couchbase_analytics_core credential query + query_handle options results errors diff --git a/docs/couchbase_analytics_api/couchbase_analytics_core.rst b/docs/couchbase_analytics_api/couchbase_analytics_core.rst index 009ade0..9506449 100644 --- a/docs/couchbase_analytics_api/couchbase_analytics_core.rst +++ b/docs/couchbase_analytics_api/couchbase_analytics_core.rst @@ -21,6 +21,11 @@ Cluster See :ref:`Cluster Overloads` for details on overloaded methods. .. automethod:: execute_query + + .. 
important:: + See :ref:`Cluster Overloads` for details on overloaded methods. + + .. automethod:: start_query .. automethod:: shutdown @@ -45,3 +50,8 @@ Scope See :ref:`Scope Overloads` for details on overloaded methods. .. automethod:: execute_query + + .. important:: + See :ref:`Scope Overloads` for details on overloaded methods. + + .. automethod:: start_query diff --git a/docs/couchbase_analytics_api/errors.rst b/docs/couchbase_analytics_api/errors.rst index b87f371..e3adf92 100644 --- a/docs/couchbase_analytics_api/errors.rst +++ b/docs/couchbase_analytics_api/errors.rst @@ -23,6 +23,10 @@ QueryError .. autoproperty:: code .. autoproperty:: server_message +QueryNotFoundError +++++++++++++++++++++++++++++++++ +.. autoclass:: QueryNotFoundError + TimeoutError ++++++++++++++++++++++++++++++++ .. autoclass:: TimeoutError diff --git a/docs/couchbase_analytics_api/options.rst b/docs/couchbase_analytics_api/options.rst index f035605..0a83e98 100644 --- a/docs/couchbase_analytics_api/options.rst +++ b/docs/couchbase_analytics_api/options.rst @@ -29,6 +29,14 @@ QueryOptions ++++++++++++++++++++++ .. autoclass:: QueryOptions +StartQueryOptions +++++++++++++++++++++++ +.. autoclass:: StartQueryOptions + +FetchResultsOptions +++++++++++++++++++++++ +.. autoclass:: FetchResultsOptions + Option TypeDict Classes ========================= @@ -56,3 +64,15 @@ QueryOptionsKwargs .. autoclass:: QueryOptionsKwargs :members: :undoc-members: + +StartQueryOptionsKwargs ++++++++++++++++++++++++ +.. autoclass:: StartQueryOptionsKwargs + :members: + :undoc-members: + +FetchResultsOptionsKwargs ++++++++++++++++++++++++++ +.. 
autoclass:: FetchResultsOptionsKwargs + :members: + :undoc-members: diff --git a/docs/couchbase_analytics_api/overloads/cluster_overloads.rst b/docs/couchbase_analytics_api/overloads/cluster_overloads.rst index 41e53ea..43d8f36 100644 --- a/docs/couchbase_analytics_api/overloads/cluster_overloads.rst +++ b/docs/couchbase_analytics_api/overloads/cluster_overloads.rst @@ -57,6 +57,29 @@ Cluster a :class:`~concurrent.futures.Future` is returned. Once the :class:`~concurrent.futures.Future` completes, an instance of a :class:`~couchbase_analytics.result.BlockingQueryResult` will be available. :rtype: Union[Future[:class:`~couchbase_analytics.result.BlockingQueryResult`], :class:`~couchbase_analytics.result.BlockingQueryResult`] + .. py:method:: start_query(statement: str) -> BlockingQueryHandle + start_query(statement: str, options: StartQueryOptions) -> BlockingQueryHandle + start_query(statement: str, **kwargs: StartQueryOptionsKwargs) -> BlockingQueryHandle + start_query(statement: str, options: StartQueryOptions, **kwargs: StartQueryOptionsKwargs) -> BlockingQueryHandle + start_query(statement: str, options: StartQueryOptions, *args: JSONType, **kwargs: StartQueryOptionsKwargs) -> BlockingQueryHandle + start_query(statement: str, options: StartQueryOptions, *args: JSONType, **kwargs: str) -> BlockingQueryHandle + start_query(statement: str, *args: JSONType, **kwargs: str) -> BlockingQueryHandle + :no-index: + + Executes a query against an Analytics cluster using the asynchronous server requests API. + + :param statement: The SQL++ statement to execute. + :type statement: str + :param options: Options to set for the query. + :type options: Optional[:class:`~couchbase_analytics.options.StartQueryOptions`] + :type \*args: Optional[:py:type:`~couchbase_analytics.JSONType`] + :param \*\*kwargs: Keyword arguments that can be used in place or to override provided :class:`~couchbase_analytics.options.StartQueryOptions`. 
+ Can also be used to pass in named query placeholders. + :type \*\*kwargs: Optional[Union[:class:`~couchbase_analytics.options.StartQueryOptionsKwargs`, str]] + + :returns: An instance of :class:`~couchbase_analytics.query_handle.BlockingQueryHandle`. + :rtype: :class:`~couchbase_analytics.query_handle.BlockingQueryHandle` + .. py:method:: create_instance(endpoint: str, credential: Credential) -> Cluster create_instance(endpoint: str, credential: Credential, options: ClusterOptions) -> Cluster create_instance(endpoint: str, credential: Credential, **kwargs: ClusterOptionsKwargs) -> Cluster diff --git a/docs/couchbase_analytics_api/overloads/scope_overloads.rst b/docs/couchbase_analytics_api/overloads/scope_overloads.rst index 40c58d9..4479535 100644 --- a/docs/couchbase_analytics_api/overloads/scope_overloads.rst +++ b/docs/couchbase_analytics_api/overloads/scope_overloads.rst @@ -56,3 +56,26 @@ Scope :returns: An instance of :class:`~couchbase_analytics.result.BlockingQueryResult`. When a cancel token is provided a :class:`~concurrent.futures.Future` is returned. Once the :class:`~concurrent.futures.Future` completes, an instance of a :class:`~couchbase_analytics.result.BlockingQueryResult` will be available. :rtype: Union[Future[:class:`~couchbase_analytics.result.BlockingQueryResult`], :class:`~couchbase_analytics.result.BlockingQueryResult`] + + .. 
py:method:: start_query(statement: str) -> BlockingQueryHandle + start_query(statement: str, options: StartQueryOptions) -> BlockingQueryHandle + start_query(statement: str, **kwargs: StartQueryOptionsKwargs) -> BlockingQueryHandle + start_query(statement: str, options: StartQueryOptions, **kwargs: StartQueryOptionsKwargs) -> BlockingQueryHandle + start_query(statement: str, options: StartQueryOptions, *args: JSONType, **kwargs: StartQueryOptionsKwargs) -> BlockingQueryHandle + start_query(statement: str, options: StartQueryOptions, *args: JSONType, **kwargs: str) -> BlockingQueryHandle + start_query(statement: str, *args: JSONType, **kwargs: str) -> BlockingQueryHandle + :no-index: + + Executes a query against an Analytics scope using the asynchronous server requests API. + + :param statement: The SQL++ statement to execute. + :type statement: str + :param options: Options to set for the query. + :type options: Optional[:class:`~couchbase_analytics.options.StartQueryOptions`] + :type \*args: Optional[:py:type:`~couchbase_analytics.JSONType`] + :param \*\*kwargs: Keyword arguments that can be used in place or to override provided :class:`~couchbase_analytics.options.StartQueryOptions`. + Can also be used to pass in named query placeholders. + :type \*\*kwargs: Optional[Union[:class:`~couchbase_analytics.options.StartQueryOptionsKwargs`, str]] + + :returns: An instance of :class:`~couchbase_analytics.query_handle.BlockingQueryHandle`. + :rtype: :class:`~couchbase_analytics.query_handle.BlockingQueryHandle` diff --git a/docs/couchbase_analytics_api/query_handle.rst b/docs/couchbase_analytics_api/query_handle.rst new file mode 100644 index 0000000..595ed0d --- /dev/null +++ b/docs/couchbase_analytics_api/query_handle.rst @@ -0,0 +1,32 @@ +========================= +Server Async Request API +========================= + +.. contents:: + :local: + +.. module:: couchbase_analytics.query_handle + +BlockingQueryHandle ++++++++++++++++++++ + +..
py:class:: BlockingQueryHandle + + .. automethod:: fetch_status + .. automethod:: cancel + +BlockingQueryResultHandle ++++++++++++++++++++++++++ + +.. py:class:: BlockingQueryResultHandle + + .. automethod:: fetch_results + .. automethod:: discard_results + +BlockingQueryStatus ++++++++++++++++++++ + +.. py:class:: BlockingQueryStatus + + .. automethod:: results_ready + .. automethod:: result_handle diff --git a/tests/environments/base_environment.py b/tests/environments/base_environment.py index b2126a4..7821e1d 100644 --- a/tests/environments/base_environment.py +++ b/tests/environments/base_environment.py @@ -20,8 +20,20 @@ import logging import pathlib import sys +import time from os import path -from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypedDict, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + Tuple, + Type, + TypedDict, + Union, +) if sys.version_info < (3, 11): from typing_extensions import Unpack @@ -32,11 +44,17 @@ import pytest from acouchbase_analytics.cluster import AsyncCluster +from acouchbase_analytics.protocol._core.anyio_utils import get_time, sleep +from acouchbase_analytics.protocol.query_handle import AsyncQueryHandle, AsyncQueryResultHandle from acouchbase_analytics.result import AsyncQueryResult from acouchbase_analytics.scope import AsyncScope from couchbase_analytics.cluster import Cluster +from couchbase_analytics.common.query_handle import AsyncQueryHandle as _CoreAsyncQueryHandle +from couchbase_analytics.common.query_handle import BlockingQueryHandle as _CoreBlockingQueryHandle +from couchbase_analytics.common.request import RequestState from couchbase_analytics.credential import Credential from couchbase_analytics.options import ClusterOptions, SecurityOptions +from couchbase_analytics.protocol.query_handle import BlockingQueryHandle, BlockingQueryResultHandle from couchbase_analytics.result import BlockingQueryResult from couchbase_analytics.scope import Scope from tests 
import TEST_LOGGER_NAME, AnalyticsTestEnvironmentError @@ -312,6 +330,44 @@ def warmup_test_server(self) -> None: if exc is not None: raise exc + def wait_for_query_results( + self, + handle: _CoreBlockingQueryHandle, + delay: float = 2.5, + timeout: int = 120, + return_only_result_handle: Optional[bool] = False, + ) -> Tuple[BlockingQueryResultHandle, Optional[BlockingQueryResult]]: + assert isinstance(handle, BlockingQueryHandle) + expected_state = RequestState.Completed + assert handle._http_response._request_context.request_state == expected_state + + current_time = time.monotonic() + deadline = current_time + timeout # seconds + status = None + result_handle = None + while True: + try: + status = handle.fetch_status() + if status.results_ready(): + result_handle = status.result_handle() + break + except Exception: + raise + + current_time = time.monotonic() + delay_time = current_time + delay + if deadline < delay_time: + raise TimeoutError(f'Query results not ready within {timeout} seconds.') + + time.sleep(delay) + + assert isinstance(result_handle, BlockingQueryResultHandle) + if return_only_result_handle is True: + return result_handle, None + result = result_handle.fetch_results() + assert isinstance(result, BlockingQueryResult) + return result_handle, result + @classmethod def get_environment( cls, config: AnalyticsConfig, server_handler: Optional[WebServerHandler] = None @@ -351,6 +407,28 @@ def get_environment( env_opts['collection_name'] = config.collection_name return cls(config, **env_opts) + @staticmethod + def try_n_times_till_exception( + num_times: int, + seconds_between: Union[int, float], + func: Callable[..., Any], + *args: Any, + expected_exceptions: Tuple[Type[Exception], ...] 
= (Exception,), + raise_exception: Optional[bool] = False, + **kwargs: Any, + ) -> None: + for _ in range(num_times): + try: + func(*args, **kwargs) + time.sleep(seconds_between) + except expected_exceptions: + if raise_exception: + raise + # helpful to have this print statement when tests fail + return + except Exception: + raise + class AsyncTestEnvironment(TestEnvironment): def __init__(self, config: AnalyticsConfig, **kwargs: Unpack[TestEnvironmentOptionsKwargs]) -> None: @@ -516,6 +594,45 @@ def update_request_json(self, json: Dict[str, object]) -> None: raise AnalyticsTestEnvironmentError('No cluster available, cannot enable test server.') self._async_cluster._impl._client_adapter.update_request_json(json) + async def wait_for_query_results( + self, + handle: _CoreAsyncQueryHandle, + delay: float = 2.5, + timeout: int = 120, + return_only_result_handle: Optional[bool] = False, + ) -> Tuple[AsyncQueryResultHandle, Optional[AsyncQueryResult]]: + assert isinstance(handle, AsyncQueryHandle) + expected_state = RequestState.Completed + assert handle._http_response._request_context.request_state == expected_state + + current_time = get_time() + deadline = current_time + timeout # seconds + status = None + result_handle = None + while True: + try: + status = await handle.fetch_status() + if status.results_ready(): + result_handle = status.result_handle() + break + except Exception as ex: + logger.error(f'Error while fetching query status: {ex}') + raise + + current_time = get_time() + delay_time = current_time + delay + if deadline < delay_time: + raise TimeoutError(f'Query results not ready within {timeout} seconds.') + + await sleep(delay) + + assert isinstance(result_handle, AsyncQueryResultHandle) + if return_only_result_handle is True: + return result_handle, None + result = await result_handle.fetch_results() + assert isinstance(result, AsyncQueryResult) + return result_handle, result + async def warmup_test_server(self) -> None: row_count = 5 
self.set_url_path('/test_results') @@ -575,6 +692,28 @@ def get_environment( env_opts['collection_name'] = config.collection_name return cls(config, **env_opts) + @staticmethod + async def try_n_times_till_exception( + num_times: int, + seconds_between: Union[int, float], + func: Callable[..., Any], + *args: Any, + expected_exceptions: Tuple[Type[Exception], ...] = (Exception,), + raise_exception: Optional[bool] = False, + **kwargs: Any, + ) -> None: + for _ in range(num_times): + try: + await func(*args, **kwargs) + await sleep(seconds_between) + except expected_exceptions as ex: + if raise_exception: + raise + logger.error(f'Caught expected exception(s) {ex} from function {func.__name__}, as expected.') + return + except Exception: + raise + @pytest.fixture(scope='class', name='sync_test_env') def base_test_environment(analytics_config: AnalyticsConfig) -> BlockingTestEnvironment: diff --git a/tests/utils/_async_client_adapter.py b/tests/utils/_async_client_adapter.py index 5043edb..70aee69 100644 --- a/tests/utils/_async_client_adapter.py +++ b/tests/utils/_async_client_adapter.py @@ -14,12 +14,12 @@ # limitations under the License. 
-from typing import Dict +from typing import Dict, Optional, Union from httpx import URL, Response from acouchbase_analytics.protocol._core.client_adapter import _AsyncClientAdapter -from couchbase_analytics.protocol._core.request import QueryRequest +from couchbase_analytics.protocol._core.request import CancelRequest, HttpRequest, QueryRequest, StartQueryRequest def client_adapter_init_override(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] @@ -38,7 +38,11 @@ def client_adapter_init_override(self, *args, **kwargs) -> None: # type: ignore self._http_transport_cls = adapter._http_transport_cls -async def send_request_override(self: _AsyncClientAdapter, request: QueryRequest) -> Response: +async def send_request_override( + self: _AsyncClientAdapter, + request: Union[CancelRequest, HttpRequest, QueryRequest, StartQueryRequest], + stream: Optional[bool] = True, +) -> Response: if not hasattr(self, '_client'): raise RuntimeError('Client not created yet') @@ -56,7 +60,9 @@ async def send_request_override(self: _AsyncClientAdapter, request: QueryRequest url = URL(scheme=request.url.scheme, host=request.url.host, port=request.url.port, path=request.url.path) req = self._client.build_request(request.method, url, json=request_json, extensions=request_extensions) - return await self._client.send(req, stream=True) + if stream is None: + stream = True + return await self._client.send(req, stream=stream) def set_request_path(self: _AsyncClientAdapter, path: str) -> None: diff --git a/tests/utils/_client_adapter.py b/tests/utils/_client_adapter.py index 0acf76d..7e6cbe8 100644 --- a/tests/utils/_client_adapter.py +++ b/tests/utils/_client_adapter.py @@ -14,12 +14,12 @@ # limitations under the License. 
-from typing import Dict +from typing import Dict, Optional, Union from httpx import URL, Response from couchbase_analytics.protocol._core.client_adapter import _ClientAdapter -from couchbase_analytics.protocol._core.request import QueryRequest +from couchbase_analytics.protocol._core.request import CancelRequest, HttpRequest, QueryRequest, StartQueryRequest def client_adapter_init_override(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] @@ -39,7 +39,11 @@ def client_adapter_init_override(self, *args, **kwargs) -> None: # type: ignore self._http_transport_cls = adapter._http_transport_cls -def send_request_override(self: _ClientAdapter, request: QueryRequest) -> Response: +def send_request_override( + self: _ClientAdapter, + request: Union[CancelRequest, HttpRequest, QueryRequest, StartQueryRequest], + stream: Optional[bool] = True, +) -> Response: if not hasattr(self, '_client'): raise RuntimeError('Client not created yet') @@ -57,7 +61,7 @@ def send_request_override(self: _ClientAdapter, request: QueryRequest) -> Respon url = URL(scheme=request.url.scheme, host=request.url.host, port=request.url.port, path=request.url.path) req = self._client.build_request(request.method, url, json=request_json, extensions=request_extensions) - return self._client.send(req, stream=True) + return self._client.send(req, stream=stream) def set_request_path(self: _ClientAdapter, path: str) -> None: