diff --git a/src/google/adk/telemetry/google_cloud.py b/src/google/adk/telemetry/google_cloud.py index dee8f3f554..3526bb22a1 100644 --- a/src/google/adk/telemetry/google_cloud.py +++ b/src/google/adk/telemetry/google_cloud.py @@ -14,6 +14,7 @@ from __future__ import annotations +import enum import logging import os from typing import cast @@ -21,6 +22,7 @@ from typing import TYPE_CHECKING import google.auth +from google.auth.transport import mtls from opentelemetry.sdk._logs import LogRecordProcessor from opentelemetry.sdk._logs.export import BatchLogRecordProcessor from opentelemetry.sdk.metrics.export import MetricReader @@ -40,6 +42,19 @@ _GCP_LOG_NAME_ENV_VARIABLE_NAME = 'GOOGLE_CLOUD_DEFAULT_LOG_NAME' _DEFAULT_LOG_NAME = 'adk-otel' +_DEFAULT_TELEMETRY_ENDPOINT = 'https://telemetry.googleapis.com/v1/traces' +_DEFAULT_MTLS_TELEMETRY_ENDPOINT = ( + 'https://telemetry.mtls.googleapis.com/v1/traces' +) + + +class MtlsEndpoint(enum.Enum): + """Enum for the mTLS endpoint setting.""" + + AUTO = 'auto' + ALWAYS = 'always' + NEVER = 'never' + def get_gcp_exporters( enable_cloud_tracing: bool = False, @@ -100,10 +115,24 @@ def _get_gcp_span_exporter(credentials: Credentials) -> SpanProcessor: from google.auth.transport.requests import AuthorizedSession from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter + session = AuthorizedSession(credentials=credentials) + + use_client_cert = _use_client_cert_effective() + if use_client_cert: + client_cert_source = ( + mtls.default_client_cert_source() + if mtls.has_default_client_cert_source() + else None + ) + session.configure_mtls_channel() + endpoint = _get_api_endpoint(client_cert_source) + else: + endpoint = _DEFAULT_TELEMETRY_ENDPOINT + return BatchSpanProcessor( OTLPSpanExporter( - session=AuthorizedSession(credentials=credentials), - endpoint='https://telemetry.googleapis.com/v1/traces', + session=session, + endpoint=endpoint, ) ) @@ -158,3 +187,58 @@ def get_gcp_resource(project_id: Optional[str] = None) -> Resource: ' GCE, GKE or CloudRun related resource attributes may be missing' ) return resource + + +def _get_api_endpoint(client_cert_source: bytes | None = None) -> str: + """Returns API endpoint based on mTLS configuration and cert availability. + + Args: + client_cert_source (bytes | None): The client certificate source. + + Returns: + str: The API endpoint to be used. + """ + use_mtls_endpoint_str = os.getenv( + 'GOOGLE_API_USE_MTLS_ENDPOINT', MtlsEndpoint.AUTO.value + ).lower() + + try: + use_mtls_endpoint = MtlsEndpoint(use_mtls_endpoint_str) + except ValueError: + logger.warning( + 'Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be one of ' + f'{[e.value for e in MtlsEndpoint]}. Defaulting to' + f' {MtlsEndpoint.AUTO.value}.' + ) + use_mtls_endpoint = MtlsEndpoint.AUTO + + if (use_mtls_endpoint == MtlsEndpoint.ALWAYS) or ( + use_mtls_endpoint == MtlsEndpoint.AUTO and client_cert_source + ): + return _DEFAULT_MTLS_TELEMETRY_ENDPOINT + + return _DEFAULT_TELEMETRY_ENDPOINT + + +def _use_client_cert_effective() -> bool: + """Returns whether client certificate should be used for mTLS. + + This checks if the google-auth version supports should_use_client_cert + automatic mTLS enablement. Alternatively, it reads from the + GOOGLE_API_USE_CLIENT_CERTIFICATE env var. + + Returns: + bool: whether client certificate should be used for mTLS. + """ + try: + return bool(mtls.should_use_client_cert()) + except (ImportError, AttributeError): + use_client_cert_str = os.getenv( + 'GOOGLE_API_USE_CLIENT_CERTIFICATE', 'false' + ).lower() + if use_client_cert_str not in ('true', 'false'): + logger.warning( + 'Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be' + ' either `true` or `false`' + ) + return use_client_cert_str == 'true' diff --git a/tests/unittests/telemetry/test_google_cloud.py b/tests/unittests/telemetry/test_google_cloud.py index 0199e7b4b6..7a5457f1c8 100644 --- a/tests/unittests/telemetry/test_google_cloud.py +++ b/tests/unittests/telemetry/test_google_cloud.py @@ -16,6 +16,11 @@ from typing import Optional from unittest import mock +from google.adk.telemetry.google_cloud import _DEFAULT_MTLS_TELEMETRY_ENDPOINT +from google.adk.telemetry.google_cloud import _DEFAULT_TELEMETRY_ENDPOINT +from google.adk.telemetry.google_cloud import _get_api_endpoint +from google.adk.telemetry.google_cloud import _get_gcp_span_exporter +from google.adk.telemetry.google_cloud import _use_client_cert_effective from google.adk.telemetry.google_cloud import get_gcp_exporters from google.adk.telemetry.google_cloud import get_gcp_resource import pytest @@ -89,3 +94,86 @@ def test_get_gcp_resource( otel_resource.attributes.get("gcp.project_id", None) == expected_project_id ) + + +@mock.patch("google.auth.transport.mtls.should_use_client_cert") +def test_use_client_cert_effective_from_mtls(mock_should_use): + mock_should_use.return_value = True + assert _use_client_cert_effective() + + mock_should_use.return_value = False + assert not _use_client_cert_effective() + + +def test_use_client_cert_effective_from_env(monkeypatch, caplog): + with mock.patch( + "google.auth.transport.mtls.should_use_client_cert", + side_effect=AttributeError, + ): + monkeypatch.setenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "true") + assert _use_client_cert_effective() + + monkeypatch.setenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + assert not _use_client_cert_effective() + + # Test invalid value defaults to False + monkeypatch.setenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "maybe") + assert not _use_client_cert_effective() + assert ( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be" + " either `true` or `false`" + in caplog.text + ) + + +@pytest.mark.parametrize( + "env_val, cert_source, expected", + [ + ("auto", b"cert", _DEFAULT_MTLS_TELEMETRY_ENDPOINT), + ("auto", None, _DEFAULT_TELEMETRY_ENDPOINT), + ("always", None, _DEFAULT_MTLS_TELEMETRY_ENDPOINT), + ("never", b"cert", _DEFAULT_TELEMETRY_ENDPOINT), + ("invalid", None, _DEFAULT_TELEMETRY_ENDPOINT), + ], +) +def test_get_api_endpoint(env_val, cert_source, expected, monkeypatch, caplog): + monkeypatch.setenv("GOOGLE_API_USE_MTLS_ENDPOINT", env_val) + if env_val == "invalid": + assert _get_api_endpoint(cert_source) == expected + assert ( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be one of" + in caplog.text + ) + else: + assert _get_api_endpoint(cert_source) == expected + + +@mock.patch("google.auth.transport.requests.AuthorizedSession") +@mock.patch( + "opentelemetry.exporter.otlp.proto.http.trace_exporter.OTLPSpanExporter" +) +@mock.patch("google.adk.telemetry.google_cloud.BatchSpanProcessor") +@mock.patch("google.adk.telemetry.google_cloud._use_client_cert_effective") +@mock.patch("google.auth.transport.mtls.has_default_client_cert_source") +@mock.patch("google.auth.transport.mtls.default_client_cert_source") +def test_get_gcp_span_exporter_mtls( + mock_default_cert, + mock_has_cert, + mock_use_cert, + mock_batch, + mock_exporter, + mock_session, +): + credentials = mock.Mock() + mock_use_cert.return_value = True + mock_has_cert.return_value = True + mock_default_cert.return_value = b"cert" + + _get_gcp_span_exporter(credentials) + + mock_session.assert_called_once_with(credentials=credentials) + mock_session.return_value.configure_mtls_channel.assert_called_once() + mock_exporter.assert_called_once_with( + session=mock_session.return_value, + endpoint=_DEFAULT_MTLS_TELEMETRY_ENDPOINT, + )