diff --git a/pyiceberg/catalog/rest/__init__.py b/pyiceberg/catalog/rest/__init__.py index b3a80e11aa..5244fe4b5b 100644 --- a/pyiceberg/catalog/rest/__init__.py +++ b/pyiceberg/catalog/rest/__init__.py @@ -29,7 +29,7 @@ from tenacity import RetryCallState, retry, retry_if_exception_type, stop_after_attempt from pyiceberg import __version__ -from pyiceberg.catalog import BOTOCORE_SESSION, TOKEN, URI, WAREHOUSE_LOCATION, Catalog, PropertiesUpdateSummary +from pyiceberg.catalog import BOTOCORE_SESSION, TOKEN, URI, WAREHOUSE_LOCATION, Catalog, MetastoreCatalog, PropertiesUpdateSummary from pyiceberg.catalog.rest.auth import AUTH_MANAGER, AuthManager, AuthManagerAdapter, AuthManagerFactory, LegacyOAuth2AuthManager from pyiceberg.catalog.rest.response import _handle_non_200_response from pyiceberg.catalog.rest.scan_planning import ( @@ -69,6 +69,7 @@ ) from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionSpec, assign_fresh_partition_spec_ids from pyiceberg.schema import Schema, assign_fresh_schema_ids +from pyiceberg.serializers import FromInputFile from pyiceberg.table import ( CommitTableRequest, CommitTableResponse, @@ -79,12 +80,10 @@ TableIdentifier, TableProperties, ) -from pyiceberg.table.metadata import TableMetadata +from pyiceberg.table.locations import load_location_provider +from pyiceberg.table.metadata import TableMetadata, new_table_metadata from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder, assign_fresh_sort_order_ids -from pyiceberg.table.update import ( - TableRequirement, - TableUpdate, -) +from pyiceberg.table.update import AssertMetadataLocation, SetTableMetadataLocationUpdate, TableRequirement, TableUpdate from pyiceberg.typedef import EMPTY_DICT, UTF8, IcebergBaseModel, Identifier, Properties from pyiceberg.types import transform_dict_value_to_str from pyiceberg.utils.deprecated import deprecation_message @@ -377,8 +376,11 @@ class ListViewsResponse(IcebergBaseModel): _PLANNING_RESPONSE_ADAPTER = TypeAdapter(PlanningResponse) +PREFER_CLIENT_SIDE_METADATA = "prefer-client-side-metadata" +METADATA_LOCATION = "metadata-location" + -class RestCatalog(Catalog): +class RestCatalog(MetastoreCatalog): uri: str _session: Session _auth_manager: AuthManager | None @@ -878,9 +880,27 @@ def _create_table( namespace_and_table = self._split_identifier_for_path(identifier) if location: location = location.rstrip("/") + if self.properties.get(PREFER_CLIENT_SIDE_METADATA): + namespace_identifier = Catalog.namespace_from(identifier) + namespace = Catalog.namespace_to_string(namespace_identifier) + table_name = Catalog.table_name_from(identifier) + + table_location = self._resolve_table_location(location, namespace, table_name) + location_provider = load_location_provider(table_location=table_location, table_properties=properties) + metadata = new_table_metadata( + location=table_location, + schema=fresh_schema, + partition_spec=fresh_partition_spec, + sort_order=fresh_sort_order, + properties=properties, + ) + metadata_location = location_provider.new_table_metadata_file_location() + io = load_file_io(properties=self.properties, location=metadata_location) + self._write_metadata(metadata, io, metadata_location) + properties = {METADATA_LOCATION: metadata_location, **properties} request = CreateTableRequest( name=self._identifier_to_validated_tuple(identifier)[-1], - location=location, + location=table_location, table_schema=fresh_schema, partition_spec=fresh_partition_spec, write_order=fresh_sort_order, @@ -896,7 +916,15 @@ def _create_table( response.raise_for_status() except HTTPError as exc: _handle_non_200_response(exc, {409: TableAlreadyExistsError, 404: NoSuchNamespaceError}) - return TableResponse.model_validate_json(response.text) + tr = TableResponse.model_validate_json(response.text) + if self.properties.get(PREFER_CLIENT_SIDE_METADATA): + return TableResponse( + metadata_location=metadata_location, + metadata=metadata, + config=properties, + storage_credentials=tr.storage_credentials, + ) + return tr @retry(**_RETRY_ARGS) def create_table( @@ -1028,6 +1056,24 @@ def list_tables(self, namespace: str | Identifier) -> list[Identifier]: @retry(**_RETRY_ARGS) def load_table(self, identifier: str | Identifier) -> Table: + tr = self._load_table(identifier) + if self.properties.get(PREFER_CLIENT_SIDE_METADATA): + metadata_location = tr.metadata_location + if not metadata_location: + raise ValueError("Metadata location is required for client-side metadata loading") + io = load_file_io(properties=self.properties, location=metadata_location) + file = io.new_input(metadata_location) + metadata = FromInputFile.table_metadata(file) + return Table( + identifier=identifier, + metadata=metadata, + metadata_location=metadata_location, + io=self._load_file_io(metadata.properties, metadata_location), + catalog=self, + ) + return tr + + def _load_table(self, identifier: str | Identifier) -> Table: self._check_endpoint(Capability.V1_LOAD_TABLE) params = {} if mode := self.properties.get(SNAPSHOT_LOADING_MODE): @@ -1147,6 +1193,18 @@ def commit_table( """ self._check_endpoint(Capability.V1_UPDATE_TABLE) identifier = table.name() + if self.properties.get(PREFER_CLIENT_SIDE_METADATA): + updated_staged_table = self._update_and_stage_table(table, identifier, requirements, updates) + if table and updated_staged_table.metadata == table.metadata: + # no changes, do nothing + return CommitTableResponse(metadata=table.metadata, metadata_location=table.metadata_location) + self._write_metadata( + metadata=updated_staged_table.metadata, + io=updated_staged_table.io, + metadata_path=updated_staged_table.metadata_location, + ) + requirements = (AssertMetadataLocation(metadata_location=table.metadata_location), *requirements) + updates = (SetTableMetadataLocationUpdate(metadata_location=updated_staged_table.metadata_location), *updates) table_identifier = TableIdentifier(namespace=identifier[:-1], name=identifier[-1]) table_request = CommitTableRequest(identifier=table_identifier, requirements=requirements, updates=updates) @@ -1171,6 +1229,12 @@ def commit_table( 504: CommitStateUnknownException, }, ) + + if self.properties.get(PREFER_CLIENT_SIDE_METADATA): + return CommitTableResponse( + metadata_location=updated_staged_table.metadata_location, + metadata=updated_staged_table.metadata, + ) return CommitTableResponse.model_validate_json(response.text) @retry(**_RETRY_ARGS) diff --git a/pyiceberg/table/update/__init__.py b/pyiceberg/table/update/__init__.py index e892b838c9..858e045a37 100644 --- a/pyiceberg/table/update/__init__.py +++ b/pyiceberg/table/update/__init__.py @@ -220,6 +220,11 @@ class RemovePartitionStatisticsUpdate(IcebergBaseModel): snapshot_id: int = Field(alias="snapshot-id") +class SetTableMetadataLocationUpdate(IcebergBaseModel): + action: Literal["set-table-metadata-location"] = Field(default="set-table-metadata-location") + metadata_location: str = Field(alias="metadata-location") + + TableUpdate = Annotated[ AssignUUIDUpdate | UpgradeFormatVersionUpdate @@ -241,7 +246,8 @@ class RemovePartitionStatisticsUpdate(IcebergBaseModel): | RemovePartitionSpecsUpdate | RemoveSchemasUpdate | SetPartitionStatisticsUpdate - | RemovePartitionStatisticsUpdate, + | RemovePartitionStatisticsUpdate + | SetTableMetadataLocationUpdate, Field(discriminator="action"), ] @@ -905,6 +911,22 @@ def validate(self, base_metadata: TableMetadata | None) -> None: ) +class AssertMetadataLocation(ValidatableTableRequirement): + """The table's metadata location must match the requirement's `metadata-location`.""" + + type: Literal["assert-table-metadata-location"] = Field(default="assert-table-metadata-location") + metadata_location: str = Field(..., alias="metadata-location") + + def validate(self, base_metadata: TableMetadata | None) -> None: + if base_metadata is None: + raise CommitFailedException("Requirement failed: current table metadata is missing") + elif self.metadata_location != base_metadata.metadata_location: + raise CommitFailedException( + f"Requirement failed: metadata location has changed: " + f"expected {self.metadata_location}, found {base_metadata.metadata_location}" + ) + + TableRequirement = Annotated[ AssertCreate | AssertTableUUID @@ -913,7 +935,8 @@ def validate(self, base_metadata: TableMetadata | None) -> None: | AssertCurrentSchemaId | AssertLastAssignedPartitionId | AssertDefaultSpecId - | AssertDefaultSortOrderId, + | AssertDefaultSortOrderId + | AssertMetadataLocation, Field(discriminator="type"), ]