diff --git a/changelog/1007.fixed.md b/changelog/1007.fixed.md new file mode 100644 index 00000000..d87d256a --- /dev/null +++ b/changelog/1007.fixed.md @@ -0,0 +1 @@ +Render schema rejections originating in an `extensions:` block as a readable one-line message in `infrahubctl schema load`, instead of crashing with `ValueError: invalid literal for int()`. diff --git a/infrahub_sdk/ctl/schema.py b/infrahub_sdk/ctl/schema.py index 67bd9923..21f9c7a9 100644 --- a/infrahub_sdk/ctl/schema.py +++ b/infrahub_sdk/ctl/schema.py @@ -4,7 +4,7 @@ import time from datetime import datetime, timezone from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal import typer import yaml @@ -24,6 +24,8 @@ if TYPE_CHECKING: from .. import InfrahubClient +SchemaContainer = Literal["nodes", "generics", "relationships"] + app = AsyncTyper() console = Console() @@ -49,73 +51,127 @@ def validate_schema_content_and_exit(client: InfrahubClient, schemas: list[Schem raise typer.Exit(1) -def display_schema_load_errors(response: dict[str, Any], schemas_data: list[SchemaFile]) -> None: - console.print("[red]Unable to load the schema:") +def display_schema_load_errors( + response: dict[str, Any], schemas_data: list[SchemaFile], output: Console | None = None +) -> None: + out = output or console + out.print("[red]Unable to load the schema:") if "detail" not in response: - handle_non_detail_errors(response=response) + handle_non_detail_errors(response=response, output=out) return for error in response["detail"]: loc_path = error.get("loc", []) if not valid_error_path(loc_path=loc_path): continue + _render_schema_error(error=error, loc_path=loc_path, schemas_data=schemas_data, output=out) + - # if the len of the path is equal to 6, the error is at the root of the object - # if the len of the path is higher than 6, the error is in an attribute or a relationships - schema_index = int(loc_path[2]) +def _render_schema_error( + error: dict[str, Any], loc_path: list[Any], schemas_data: list[SchemaFile], output: Console +) -> None: + # Two layout shapes for loc_path. tail is the part after the node index. + # Top-level: body / schemas / / (nodes|generics) / / [ / ] + # Extensions: body / schemas / / extensions / (nodes|generics|relationships) / / [ / ] + schema_index = int(loc_path[2]) + is_extension = loc_path[3] == "extensions" + if is_extension: + container = loc_path[4] + node_index = int(loc_path[5]) + tail = loc_path[6:] + else: + container = loc_path[3] node_index = int(loc_path[4]) - node = get_node(schemas_data=schemas_data, schema_index=schema_index, node_index=node_index) + tail = loc_path[5:] + + node = get_node( + schemas_data=schemas_data, + schema_index=schema_index, + node_index=node_index, + container=container, + is_extension=is_extension, + ) - if not node: - console.print("Node data not found.") - continue + if not node: + output.print("Node data not found.") + return - if len(loc_path) == 6: - loc_type = loc_path[-1] - input_str = error.get("input", None) - error_message = f"{loc_type} ({input_str}) | {error['msg']} ({error['type']})" - console.print( - f" Node: {node.get('namespace', None)}{node.get('name', None)} | {error_message}", markup=False - ) + # Extensions reference an existing node by `kind`; new top-level nodes are identified by `namespace+name`. + node_label = ( + (node.get("kind") or node.get("name") or "") + if is_extension + else f"{node.get('namespace', None)}{node.get('name', None)}" + ) + path_suffix = f" (extensions/{container})" if is_extension else "" + input_str = error.get("input") + err_msg = error.get("msg", "No error message") + err_type = error.get("type", "unknown") + + if len(tail) == 1: + # Error on a direct field of the node (e.g. `name`, `namespace`). + loc_type = tail[0] + error_message = f"{loc_type} ({input_str}) | {err_msg} ({err_type})" + elif len(tail) > 1: + # Error nested inside a collection (e.g. attributes[2].kind, relationships[0].peer). + # loc_type is the collection name; attribute is either its index or the failing field name. + loc_type = tail[0] + attribute = tail[1] + input_label = _resolve_attribute_label(error_data=node.get(loc_type, []), attribute=attribute) + # Trim the trailing 's' so "attributes" → "Attribute" in the rendered label. + error_message = f"{loc_type[:-1].title()}: {input_label} ({input_str}) | {err_msg} ({err_type})" + else: + return - elif len(loc_path) > 6: - loc_type = loc_path[5] - error_data = node[loc_type] - attribute = loc_path[6] - - if isinstance(attribute, str): - input_label = None - for data in error_data: - if data.get(attribute) is not None: - input_label = data.get("name", None) - break - else: - input_label = error_data[attribute].get("name", None) + output.print(f" Node: {node_label}{path_suffix} | {error_message}", markup=False) - input_str = error.get("input", None) - error_message = f"{loc_type[:-1].title()}: {input_label} ({input_str}) | {error['msg']} ({error['type']})" - console.print( - f" Node: {node.get('namespace', None)}{node.get('name', None)} | {error_message}", markup=False - ) +def _resolve_attribute_label(error_data: list[dict[str, Any]], attribute: Any) -> str | None: + if isinstance(attribute, str): + for data in error_data: + if data.get(attribute) is not None: + return data.get("name", None) + return None + if isinstance(attribute, int) and 0 <= attribute < len(error_data): + return error_data[attribute].get("name", None) + return None -def handle_non_detail_errors(response: dict[str, Any]) -> None: + +def handle_non_detail_errors(response: dict[str, Any], output: Console | None = None) -> None: + out = output or console if "error" in response: - console.print(f" {response.get('error')}") + out.print(f" {response.get('error')}") elif "errors" in response: for error in response["errors"]: - console.print(f" {error.get('message')}") + out.print(f" {error.get('message')}") else: - console.print(f" '{response}'") + out.print(f" '{response}'") def valid_error_path(loc_path: list[Any]) -> bool: - return len(loc_path) >= 6 and loc_path[0] == "body" and loc_path[1] == "schemas" - - -def get_node(schemas_data: list[SchemaFile], schema_index: int, node_index: int) -> dict | None: - if schema_index < len(schemas_data) and node_index < len(schemas_data[schema_index].payload["nodes"]): - return schemas_data[schema_index].payload["nodes"][node_index] + if len(loc_path) < 6 or loc_path[0] != "body" or loc_path[1] != "schemas" or not isinstance(loc_path[2], int): + return False + if loc_path[3] == "extensions": + return ( + len(loc_path) >= 7 + and loc_path[4] in {"nodes", "generics", "relationships"} + and isinstance(loc_path[5], int) + ) + return loc_path[3] in {"nodes", "generics"} and isinstance(loc_path[4], int) + + +def get_node( + schemas_data: list[SchemaFile], + schema_index: int, + node_index: int, + container: SchemaContainer = "nodes", + is_extension: bool = False, +) -> dict | None: + if schema_index >= len(schemas_data): + return None + payload = schemas_data[schema_index].payload + items = payload.get("extensions", {}).get(container, []) if is_extension else payload.get(container, []) + if node_index < len(items): + return items[node_index] return None diff --git a/tests/integration/test_schema.py b/tests/integration/test_schema.py index 314a3ffa..42307e3f 100644 --- a/tests/integration/test_schema.py +++ b/tests/integration/test_schema.py @@ -1,11 +1,15 @@ +from pathlib import Path from typing import Any import pytest +from rich.console import Console from infrahub_sdk import InfrahubClient +from infrahub_sdk.ctl.schema import display_schema_load_errors from infrahub_sdk.exceptions import BranchNotFoundError from infrahub_sdk.schema import NodeSchemaAPI from infrahub_sdk.testing.docker import TestInfrahubDockerClient +from infrahub_sdk.yaml import SchemaFile class TestInfrahubSchema(TestInfrahubDockerClient): @@ -43,3 +47,71 @@ async def test_schema_load_many( schema_nodes = await client.schema.all(refresh=True) assert "InfraRack" in schema_nodes assert "ProcurementContract" in schema_nodes + + +class TestInfrahubSchemaLoadErrorRendering(TestInfrahubDockerClient): + """Render real server error responses through display_schema_load_errors. + + These exist as integration tests so we catch any drift between the server's + validation error payload shape and the CLI renderer, particularly for + `extensions` paths which previously went unhandled. + """ + + async def test_extension_top_level_field_error(self, client: InfrahubClient) -> None: + broken_schema = { + "version": "1.0", + "extensions": { + "nodes": [ + { + "kind": "BuiltinTag", + "namespace": "Forbidden", + } + ] + }, + } + + response = await client.schema.load(schemas=[broken_schema]) + + assert response.errors, "Server should reject a forbidden field on an extensions/nodes entry" + assert "detail" in response.errors + + schemas_data = [SchemaFile(location=Path("broken.yml"), content=broken_schema)] + console = Console(width=1000) + with console.capture() as capture: + display_schema_load_errors(response=response.errors, schemas_data=schemas_data, output=console) + rendered = capture.get() + + assert "Unable to load the schema" in rendered + assert "BuiltinTag" in rendered + assert "extensions/nodes" in rendered + + async def test_extension_nested_attribute_error(self, client: InfrahubClient) -> None: + broken_schema = { + "version": "1.0", + "extensions": { + "nodes": [ + { + "kind": "BuiltinTag", + "attributes": [ + {"name": "speed", "kind": "Number", "made_up": True}, + ], + } + ] + }, + } + + response = await client.schema.load(schemas=[broken_schema]) + + assert response.errors, "Server should reject a forbidden field on an extensions attribute entry" + assert "detail" in response.errors + + schemas_data = [SchemaFile(location=Path("broken.yml"), content=broken_schema)] + console = Console(width=1000) + with console.capture() as capture: + display_schema_load_errors(response=response.errors, schemas_data=schemas_data, output=console) + rendered = capture.get() + + assert "Unable to load the schema" in rendered + assert "BuiltinTag" in rendered + assert "extensions/nodes" in rendered + assert "speed" in rendered diff --git a/tests/unit/sdk/test_schema.py b/tests/unit/sdk/test_schema.py index 308fe64a..fe390537 100644 --- a/tests/unit/sdk/test_schema.py +++ b/tests/unit/sdk/test_schema.py @@ -8,7 +8,7 @@ from rich.console import Console from infrahub_sdk import Config, InfrahubClient, InfrahubClientSync -from infrahub_sdk.ctl.schema import display_schema_load_errors +from infrahub_sdk.ctl.schema import display_schema_load_errors, valid_error_path from infrahub_sdk.exceptions import SchemaNotFoundError, ValidationError from infrahub_sdk.protocols import BuiltinIPAddress, BuiltinIPAddressSync, BuiltinTag, BuiltinTagSync from infrahub_sdk.schema import BranchSchema, InfrahubSchema, InfrahubSchemaBase, InfrahubSchemaSync, NodeSchemaAPI @@ -490,6 +490,35 @@ async def test_display_schema_load_errors_details_when_error_is_in_attribute_or_ assert output == expected_console +@pytest.mark.parametrize( + "loc_path", + [ + pytest.param(["body", "schemas", 0, "nodes", 0, "name"], id="top-level-nodes"), + pytest.param(["body", "schemas", 0, "generics", 1, "attributes", 0], id="top-level-generics"), + pytest.param(["body", "schemas", 0, "extensions", "nodes", 0, "kind"], id="extension-nodes"), + pytest.param(["body", "schemas", 0, "extensions", "generics", 0, "name"], id="extension-generics"), + pytest.param(["body", "schemas", 0, "extensions", "relationships", 2, "peer"], id="extension-relationships"), + ], +) +def test_valid_error_path_accepts_known_shapes(loc_path: list) -> None: + assert valid_error_path(loc_path=loc_path) + + +@pytest.mark.parametrize( + "loc_path", + [ + pytest.param(["body", "headers", "x-test"], id="wrong-root"), + pytest.param(["body", "schemas", "not-an-int", "nodes", 0, "name"], id="non-int-schema-index"), + pytest.param(["body", "schemas", 0, "wat", 0, "name"], id="unknown-container"), + pytest.param(["body", "schemas", 0, "extensions", "generics", "include_in_menu"], id="non-int-extension-index"), + pytest.param(["body", "schemas", 0, "extensions", "wat", 0, "name"], id="unknown-extension-container"), + pytest.param(["body", "schemas", 0, "nodes"], id="too-short"), + ], +) +def test_valid_error_path_rejects_unknown_shapes(loc_path: list) -> None: + assert not valid_error_path(loc_path=loc_path) + + def test_schema_base__get_schema_name__returns_correct_schema_name_for_protocols() -> None: assert InfrahubSchemaBase._get_schema_name(schema=BuiltinTagSync) == "BuiltinTag" assert InfrahubSchemaBase._get_schema_name(schema=BuiltinTag) == "BuiltinTag"