From f254569fce624c5718eafecdc584a5d681991626 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 6 May 2026 15:28:35 -0700 Subject: [PATCH] feat: Add mTLS support to Google Cloud Telemetry exporter This change enables the Google Cloud Telemetry exporter to use mTLS endpoints. It checks for the availability of client certificates and respects the GOOGLE_API_USE_CLIENT_CERTIFICATE environment variables to determine whether to use the mTLS-specific endpoint and configure the session accordingly. PiperOrigin-RevId: 911581237 --- src/google/adk/telemetry/google_cloud.py | 88 ++++++++++++++++++- .../unittests/telemetry/test_google_cloud.py | 88 +++++++++++++++++++ 2 files changed, 174 insertions(+), 2 deletions(-) diff --git a/src/google/adk/telemetry/google_cloud.py b/src/google/adk/telemetry/google_cloud.py index dee8f3f554..3526bb22a1 100644 --- a/src/google/adk/telemetry/google_cloud.py +++ b/src/google/adk/telemetry/google_cloud.py @@ -14,6 +14,7 @@ from __future__ import annotations +import enum import logging import os from typing import cast @@ -21,6 +22,7 @@ from typing import TYPE_CHECKING import google.auth +from google.auth.transport import mtls from opentelemetry.sdk._logs import LogRecordProcessor from opentelemetry.sdk._logs.export import BatchLogRecordProcessor from opentelemetry.sdk.metrics.export import MetricReader @@ -40,6 +42,19 @@ _GCP_LOG_NAME_ENV_VARIABLE_NAME = 'GOOGLE_CLOUD_DEFAULT_LOG_NAME' _DEFAULT_LOG_NAME = 'adk-otel' +_DEFAULT_TELEMETRY_ENDPOINT = 'https://telemetry.googleapis.com/v1/traces' +_DEFAULT_MTLS_TELEMETRY_ENDPOINT = ( + 'https://telemetry.mtls.googleapis.com/v1/traces' +) + + +class MtlsEndpoint(enum.Enum): + """Enum for the mTLS endpoint setting.""" + + AUTO = 'auto' + ALWAYS = 'always' + NEVER = 'never' + def get_gcp_exporters( enable_cloud_tracing: bool = False, @@ -100,10 +115,24 @@ def _get_gcp_span_exporter(credentials: Credentials) -> SpanProcessor: from google.auth.transport.requests import AuthorizedSession from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter + session = AuthorizedSession(credentials=credentials) + + use_client_cert = _use_client_cert_effective() + if use_client_cert: + client_cert_source = ( + mtls.default_client_cert_source() + if mtls.has_default_client_cert_source() + else None + ) + session.configure_mtls_channel() + endpoint = _get_api_endpoint(client_cert_source) + else: + endpoint = _DEFAULT_TELEMETRY_ENDPOINT + return BatchSpanProcessor( OTLPSpanExporter( - session=AuthorizedSession(credentials=credentials), - endpoint='https://telemetry.googleapis.com/v1/traces', + session=session, + endpoint=endpoint, ) ) @@ -158,3 +187,58 @@ def get_gcp_resource(project_id: Optional[str] = None) -> Resource: ' GCE, GKE or CloudRun related resource attributes may be missing' ) return resource + + +def _get_api_endpoint(client_cert_source: bytes | None = None) -> str: + """Returns API endpoint based on mTLS configuration and cert availability. + + Args: + client_cert_source (bytes | None): The client certificate source. + + Returns: + str: The API endpoint to be used. + """ + use_mtls_endpoint_str = os.getenv( + 'GOOGLE_API_USE_MTLS_ENDPOINT', MtlsEndpoint.AUTO.value + ).lower() + + try: + use_mtls_endpoint = MtlsEndpoint(use_mtls_endpoint_str) + except ValueError: + logger.warning( + 'Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be one of ' + f'{[e.value for e in MtlsEndpoint]}. Defaulting to' + f' {MtlsEndpoint.AUTO.value}.' + ) + use_mtls_endpoint = MtlsEndpoint.AUTO + + if (use_mtls_endpoint == MtlsEndpoint.ALWAYS) or ( + use_mtls_endpoint == MtlsEndpoint.AUTO and client_cert_source + ): + return _DEFAULT_MTLS_TELEMETRY_ENDPOINT + + return _DEFAULT_TELEMETRY_ENDPOINT + + +def _use_client_cert_effective() -> bool: + """Returns whether client certificate should be used for mTLS. + + This checks if the google-auth version supports should_use_client_cert + automatic mTLS enablement. Alternatively, it reads from the + GOOGLE_API_USE_CLIENT_CERTIFICATE env var. + + Returns: + bool: whether client certificate should be used for mTLS. + """ + try: + return bool(mtls.should_use_client_cert()) + except (ImportError, AttributeError): + use_client_cert_str = os.getenv( + 'GOOGLE_API_USE_CLIENT_CERTIFICATE', 'false' + ).lower() + if use_client_cert_str not in ('true', 'false'): + logger.warning( + 'Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be' + ' either `true` or `false`' + ) + return use_client_cert_str == 'true' diff --git a/tests/unittests/telemetry/test_google_cloud.py b/tests/unittests/telemetry/test_google_cloud.py index 0199e7b4b6..7a5457f1c8 100644 --- a/tests/unittests/telemetry/test_google_cloud.py +++ b/tests/unittests/telemetry/test_google_cloud.py @@ -16,6 +16,11 @@ from typing import Optional from unittest import mock +from google.adk.telemetry.google_cloud import _DEFAULT_MTLS_TELEMETRY_ENDPOINT +from google.adk.telemetry.google_cloud import _DEFAULT_TELEMETRY_ENDPOINT +from google.adk.telemetry.google_cloud import _get_api_endpoint +from google.adk.telemetry.google_cloud import _get_gcp_span_exporter +from google.adk.telemetry.google_cloud import _use_client_cert_effective from google.adk.telemetry.google_cloud import get_gcp_exporters from google.adk.telemetry.google_cloud import get_gcp_resource import pytest @@ -89,3 +94,86 @@ def test_get_gcp_resource( otel_resource.attributes.get("gcp.project_id", None) == expected_project_id ) + + +@mock.patch("google.auth.transport.mtls.should_use_client_cert") +def test_use_client_cert_effective_from_mtls(mock_should_use): + mock_should_use.return_value = True + assert _use_client_cert_effective() + + mock_should_use.return_value = False + assert not _use_client_cert_effective() + + +def test_use_client_cert_effective_from_env(monkeypatch, caplog): + with mock.patch( + "google.auth.transport.mtls.should_use_client_cert", + side_effect=AttributeError, + ): + monkeypatch.setenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "true") + assert _use_client_cert_effective() + + monkeypatch.setenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + assert not _use_client_cert_effective() + + # Test invalid value defaults to False + monkeypatch.setenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "maybe") + assert not _use_client_cert_effective() + assert ( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be" + " either `true` or `false`" + in caplog.text + ) + + +@pytest.mark.parametrize( + "env_val, cert_source, expected", + [ + ("auto", b"cert", _DEFAULT_MTLS_TELEMETRY_ENDPOINT), + ("auto", None, _DEFAULT_TELEMETRY_ENDPOINT), + ("always", None, _DEFAULT_MTLS_TELEMETRY_ENDPOINT), + ("never", b"cert", _DEFAULT_TELEMETRY_ENDPOINT), + ("invalid", None, _DEFAULT_TELEMETRY_ENDPOINT), + ], +) +def test_get_api_endpoint(env_val, cert_source, expected, monkeypatch, caplog): + monkeypatch.setenv("GOOGLE_API_USE_MTLS_ENDPOINT", env_val) + if env_val == "invalid": + assert _get_api_endpoint(cert_source) == expected + assert ( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be one of" + in caplog.text + ) + else: + assert _get_api_endpoint(cert_source) == expected + + +@mock.patch("google.auth.transport.requests.AuthorizedSession") +@mock.patch( + "opentelemetry.exporter.otlp.proto.http.trace_exporter.OTLPSpanExporter" +) +@mock.patch("google.adk.telemetry.google_cloud.BatchSpanProcessor") +@mock.patch("google.adk.telemetry.google_cloud._use_client_cert_effective") +@mock.patch("google.auth.transport.mtls.has_default_client_cert_source") +@mock.patch("google.auth.transport.mtls.default_client_cert_source") +def test_get_gcp_span_exporter_mtls( + mock_default_cert, + mock_has_cert, + mock_use_cert, + mock_batch, + mock_exporter, + mock_session, +): + credentials = mock.Mock() + mock_use_cert.return_value = True + mock_has_cert.return_value = True + mock_default_cert.return_value = b"cert" + + _get_gcp_span_exporter(credentials) + + mock_session.assert_called_once_with(credentials=credentials) + mock_session.return_value.configure_mtls_channel.assert_called_once() + mock_exporter.assert_called_once_with( + session=mock_session.return_value, + endpoint=_DEFAULT_MTLS_TELEMETRY_ENDPOINT, + )