Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 86 additions & 2 deletions src/google/adk/telemetry/google_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@

from __future__ import annotations

import enum
import logging
import os
from typing import cast
from typing import Optional
from typing import TYPE_CHECKING

import google.auth
from google.auth.transport import mtls
from opentelemetry.sdk._logs import LogRecordProcessor
from opentelemetry.sdk._logs.export import BatchLogRecordProcessor
from opentelemetry.sdk.metrics.export import MetricReader
Expand All @@ -40,6 +42,19 @@
_GCP_LOG_NAME_ENV_VARIABLE_NAME = 'GOOGLE_CLOUD_DEFAULT_LOG_NAME'
_DEFAULT_LOG_NAME = 'adk-otel'

_DEFAULT_TELEMETRY_ENDPOINT = 'https://telemetry.googleapis.com/v1/traces'
_DEFAULT_MTLS_TELEMETRY_ENDPOINT = (
'https://telemetry.mtls.googleapis.com/v1/traces'
)


class MtlsEndpoint(enum.Enum):
"""Enum for the mTLS endpoint setting."""

AUTO = 'auto'
ALWAYS = 'always'
NEVER = 'never'


def get_gcp_exporters(
enable_cloud_tracing: bool = False,
Expand Down Expand Up @@ -100,10 +115,24 @@ def _get_gcp_span_exporter(credentials: Credentials) -> SpanProcessor:
from google.auth.transport.requests import AuthorizedSession
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter

session = AuthorizedSession(credentials=credentials)

use_client_cert = _use_client_cert_effective()
if use_client_cert:
client_cert_source = (
mtls.default_client_cert_source()
if mtls.has_default_client_cert_source()
else None
)
session.configure_mtls_channel()
endpoint = _get_api_endpoint(client_cert_source)
else:
endpoint = _DEFAULT_TELEMETRY_ENDPOINT

return BatchSpanProcessor(
OTLPSpanExporter(
session=AuthorizedSession(credentials=credentials),
endpoint='https://telemetry.googleapis.com/v1/traces',
session=session,
endpoint=endpoint,
)
)

Expand Down Expand Up @@ -158,3 +187,58 @@ def get_gcp_resource(project_id: Optional[str] = None) -> Resource:
' GCE, GKE or CloudRun related resource attributes may be missing'
)
return resource


def _get_api_endpoint(client_cert_source: bytes | None = None) -> str:
"""Returns API endpoint based on mTLS configuration and cert availability.

Args:
client_cert_source (bytes | None): The client certificate source.

Returns:
str: The API endpoint to be used.
"""
use_mtls_endpoint_str = os.getenv(
'GOOGLE_API_USE_MTLS_ENDPOINT', MtlsEndpoint.AUTO.value
).lower()

try:
use_mtls_endpoint = MtlsEndpoint(use_mtls_endpoint_str)
except ValueError:
logger.warning(
'Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be one of '
f'{[e.value for e in MtlsEndpoint]}. Defaulting to'
f' {MtlsEndpoint.AUTO.value}.'
)
use_mtls_endpoint = MtlsEndpoint.AUTO

if (use_mtls_endpoint == MtlsEndpoint.ALWAYS) or (
use_mtls_endpoint == MtlsEndpoint.AUTO and client_cert_source
):
return _DEFAULT_MTLS_TELEMETRY_ENDPOINT

return _DEFAULT_TELEMETRY_ENDPOINT


def _use_client_cert_effective() -> bool:
"""Returns whether client certificate should be used for mTLS.

This checks if the google-auth version supports should_use_client_cert
automatic mTLS enablement. Alternatively, it reads from the
GOOGLE_API_USE_CLIENT_CERTIFICATE env var.

Returns:
bool: whether client certificate should be used for mTLS.
"""
try:
return bool(mtls.should_use_client_cert())
except (ImportError, AttributeError):
use_client_cert_str = os.getenv(
'GOOGLE_API_USE_CLIENT_CERTIFICATE', 'false'
).lower()
if use_client_cert_str not in ('true', 'false'):
logger.warning(
'Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be'
' either `true` or `false`'
)
return use_client_cert_str == 'true'
88 changes: 88 additions & 0 deletions tests/unittests/telemetry/test_google_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@
from typing import Optional
from unittest import mock

from google.adk.telemetry.google_cloud import _DEFAULT_MTLS_TELEMETRY_ENDPOINT
from google.adk.telemetry.google_cloud import _DEFAULT_TELEMETRY_ENDPOINT
from google.adk.telemetry.google_cloud import _get_api_endpoint
from google.adk.telemetry.google_cloud import _get_gcp_span_exporter
from google.adk.telemetry.google_cloud import _use_client_cert_effective
from google.adk.telemetry.google_cloud import get_gcp_exporters
from google.adk.telemetry.google_cloud import get_gcp_resource
import pytest
Expand Down Expand Up @@ -89,3 +94,86 @@ def test_get_gcp_resource(
otel_resource.attributes.get("gcp.project_id", None)
== expected_project_id
)


@mock.patch("google.auth.transport.mtls.should_use_client_cert")
def test_use_client_cert_effective_from_mtls(mock_should_use):
mock_should_use.return_value = True
assert _use_client_cert_effective()

mock_should_use.return_value = False
assert not _use_client_cert_effective()


def test_use_client_cert_effective_from_env(monkeypatch, caplog):
with mock.patch(
"google.auth.transport.mtls.should_use_client_cert",
side_effect=AttributeError,
):
monkeypatch.setenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "true")
assert _use_client_cert_effective()

monkeypatch.setenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")
assert not _use_client_cert_effective()

# Test invalid value defaults to False
monkeypatch.setenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "maybe")
assert not _use_client_cert_effective()
assert (
"Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be"
" either `true` or `false`"
in caplog.text
)


@pytest.mark.parametrize(
"env_val, cert_source, expected",
[
("auto", b"cert", _DEFAULT_MTLS_TELEMETRY_ENDPOINT),
("auto", None, _DEFAULT_TELEMETRY_ENDPOINT),
("always", None, _DEFAULT_MTLS_TELEMETRY_ENDPOINT),
("never", b"cert", _DEFAULT_TELEMETRY_ENDPOINT),
("invalid", None, _DEFAULT_TELEMETRY_ENDPOINT),
],
)
def test_get_api_endpoint(env_val, cert_source, expected, monkeypatch, caplog):
monkeypatch.setenv("GOOGLE_API_USE_MTLS_ENDPOINT", env_val)
if env_val == "invalid":
assert _get_api_endpoint(cert_source) == expected
assert (
"Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be one of"
in caplog.text
)
else:
assert _get_api_endpoint(cert_source) == expected


@mock.patch("google.auth.transport.requests.AuthorizedSession")
@mock.patch(
"opentelemetry.exporter.otlp.proto.http.trace_exporter.OTLPSpanExporter"
)
@mock.patch("google.adk.telemetry.google_cloud.BatchSpanProcessor")
@mock.patch("google.adk.telemetry.google_cloud._use_client_cert_effective")
@mock.patch("google.auth.transport.mtls.has_default_client_cert_source")
@mock.patch("google.auth.transport.mtls.default_client_cert_source")
def test_get_gcp_span_exporter_mtls(
mock_default_cert,
mock_has_cert,
mock_use_cert,
mock_batch,
mock_exporter,
mock_session,
):
credentials = mock.Mock()
mock_use_cert.return_value = True
mock_has_cert.return_value = True
mock_default_cert.return_value = b"cert"

_get_gcp_span_exporter(credentials)

mock_session.assert_called_once_with(credentials=credentials)
mock_session.return_value.configure_mtls_channel.assert_called_once()
mock_exporter.assert_called_once_with(
session=mock_session.return_value,
endpoint=_DEFAULT_MTLS_TELEMETRY_ENDPOINT,
)
Loading