Skip to content
1 change: 1 addition & 0 deletions changelog/1007.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Render schema rejections originating in an `extensions:` block as a readable one-line message in `infrahubctl schema load`, instead of crashing with `ValueError: invalid literal for int()`.
148 changes: 102 additions & 46 deletions infrahub_sdk/ctl/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal

import typer
import yaml
Expand All @@ -24,6 +24,8 @@
if TYPE_CHECKING:
from .. import InfrahubClient

SchemaContainer = Literal["nodes", "generics", "relationships"]

app = AsyncTyper()
console = Console()

Expand All @@ -49,73 +51,127 @@ def validate_schema_content_and_exit(client: InfrahubClient, schemas: list[Schem
raise typer.Exit(1)


def display_schema_load_errors(response: dict[str, Any], schemas_data: list[SchemaFile]) -> None:
console.print("[red]Unable to load the schema:")
def display_schema_load_errors(
response: dict[str, Any], schemas_data: list[SchemaFile], output: Console | None = None
) -> None:
out = output or console
out.print("[red]Unable to load the schema:")
if "detail" not in response:
handle_non_detail_errors(response=response)
handle_non_detail_errors(response=response, output=out)
return

for error in response["detail"]:
loc_path = error.get("loc", [])
if not valid_error_path(loc_path=loc_path):
continue
_render_schema_error(error=error, loc_path=loc_path, schemas_data=schemas_data, output=out)


# if the len of the path is equal to 6, the error is at the root of the object
# if the len of the path is higher than 6, the error is in an attribute or a relationships
schema_index = int(loc_path[2])
def _render_schema_error(
error: dict[str, Any], loc_path: list[Any], schemas_data: list[SchemaFile], output: Console
) -> None:
# Two layout shapes for loc_path. tail is the part after the node index.
# Top-level: body / schemas / <si> / (nodes|generics) / <ni> / [<subtype> / <attr>]
# Extensions: body / schemas / <si> / extensions / (nodes|generics|relationships) / <ni> / [<subtype> / <attr>]
schema_index = int(loc_path[2])
is_extension = loc_path[3] == "extensions"
if is_extension:
container = loc_path[4]
node_index = int(loc_path[5])
tail = loc_path[6:]
else:
container = loc_path[3]
node_index = int(loc_path[4])
node = get_node(schemas_data=schemas_data, schema_index=schema_index, node_index=node_index)
tail = loc_path[5:]

node = get_node(
schemas_data=schemas_data,
schema_index=schema_index,
node_index=node_index,
container=container,
is_extension=is_extension,
)

if not node:
console.print("Node data not found.")
continue
if not node:
output.print("Node data not found.")
return

if len(loc_path) == 6:
loc_type = loc_path[-1]
input_str = error.get("input", None)
error_message = f"{loc_type} ({input_str}) | {error['msg']} ({error['type']})"
console.print(
f" Node: {node.get('namespace', None)}{node.get('name', None)} | {error_message}", markup=False
)
# Extensions reference an existing node by `kind`; new top-level nodes are identified by `namespace+name`.
node_label = (
(node.get("kind") or node.get("name") or "")
if is_extension
else f"{node.get('namespace', None)}{node.get('name', None)}"
)
path_suffix = f" (extensions/{container})" if is_extension else ""
input_str = error.get("input")
err_msg = error.get("msg", "No error message")
err_type = error.get("type", "unknown")

if len(tail) == 1:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also not your fault, but adding a few one-line comments in here would make it easier to understand this pretty opaque error string parsing

# Error on a direct field of the node (e.g. `name`, `namespace`).
loc_type = tail[0]
error_message = f"{loc_type} ({input_str}) | {err_msg} ({err_type})"
elif len(tail) > 1:
# Error nested inside a collection (e.g. attributes[2].kind, relationships[0].peer).
# loc_type is the collection name; attribute is either its index or the failing field name.
loc_type = tail[0]
attribute = tail[1]
input_label = _resolve_attribute_label(error_data=node.get(loc_type, []), attribute=attribute)
# Trim the trailing 's' so "attributes" → "Attribute" in the rendered label.
error_message = f"{loc_type[:-1].title()}: {input_label} ({input_str}) | {err_msg} ({err_type})"
else:
return

elif len(loc_path) > 6:
loc_type = loc_path[5]
error_data = node[loc_type]
attribute = loc_path[6]

if isinstance(attribute, str):
input_label = None
for data in error_data:
if data.get(attribute) is not None:
input_label = data.get("name", None)
break
else:
input_label = error_data[attribute].get("name", None)
output.print(f" Node: {node_label}{path_suffix} | {error_message}", markup=False)

input_str = error.get("input", None)
error_message = f"{loc_type[:-1].title()}: {input_label} ({input_str}) | {error['msg']} ({error['type']})"
console.print(
f" Node: {node.get('namespace', None)}{node.get('name', None)} | {error_message}", markup=False
)

def _resolve_attribute_label(error_data: list[dict[str, Any]], attribute: Any) -> str | None:
if isinstance(attribute, str):
for data in error_data:
if data.get(attribute) is not None:
return data.get("name", None)
return None
if isinstance(attribute, int) and 0 <= attribute < len(error_data):
return error_data[attribute].get("name", None)
return None

def handle_non_detail_errors(response: dict[str, Any]) -> None:

def handle_non_detail_errors(response: dict[str, Any], output: Console | None = None) -> None:
out = output or console
if "error" in response:
console.print(f" {response.get('error')}")
out.print(f" {response.get('error')}")
elif "errors" in response:
for error in response["errors"]:
console.print(f" {error.get('message')}")
out.print(f" {error.get('message')}")
else:
console.print(f" '{response}'")
out.print(f" '{response}'")


def valid_error_path(loc_path: list[Any]) -> bool:
return len(loc_path) >= 6 and loc_path[0] == "body" and loc_path[1] == "schemas"


def get_node(schemas_data: list[SchemaFile], schema_index: int, node_index: int) -> dict | None:
if schema_index < len(schemas_data) and node_index < len(schemas_data[schema_index].payload["nodes"]):
return schemas_data[schema_index].payload["nodes"][node_index]
if len(loc_path) < 6 or loc_path[0] != "body" or loc_path[1] != "schemas" or not isinstance(loc_path[2], int):
return False
if loc_path[3] == "extensions":
return (
len(loc_path) >= 7
and loc_path[4] in {"nodes", "generics", "relationships"}
and isinstance(loc_path[5], int)
)
return loc_path[3] in {"nodes", "generics"} and isinstance(loc_path[4], int)


def get_node(
schemas_data: list[SchemaFile],
schema_index: int,
node_index: int,
container: SchemaContainer = "nodes",
is_extension: bool = False,
) -> dict | None:
if schema_index >= len(schemas_data):
return None
payload = schemas_data[schema_index].payload
items = payload.get("extensions", {}).get(container, []) if is_extension else payload.get(container, [])
if node_index < len(items):
return items[node_index]
return None


Expand Down
72 changes: 72 additions & 0 deletions tests/integration/test_schema.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from pathlib import Path
from typing import Any

import pytest
from rich.console import Console

from infrahub_sdk import InfrahubClient
from infrahub_sdk.ctl.schema import display_schema_load_errors
from infrahub_sdk.exceptions import BranchNotFoundError
from infrahub_sdk.schema import NodeSchemaAPI
from infrahub_sdk.testing.docker import TestInfrahubDockerClient
from infrahub_sdk.yaml import SchemaFile


class TestInfrahubSchema(TestInfrahubDockerClient):
Expand Down Expand Up @@ -43,3 +47,71 @@ async def test_schema_load_many(
schema_nodes = await client.schema.all(refresh=True)
assert "InfraRack" in schema_nodes
assert "ProcurementContract" in schema_nodes


class TestInfrahubSchemaLoadErrorRendering(TestInfrahubDockerClient):
"""Render real server error responses through display_schema_load_errors.

These exist as integration tests so we catch any drift between the server's
validation error payload shape and the CLI renderer, particularly for
`extensions` paths which previously went unhandled.
"""

async def test_extension_top_level_field_error(self, client: InfrahubClient) -> None:
broken_schema = {
"version": "1.0",
"extensions": {
"nodes": [
{
"kind": "BuiltinTag",
"namespace": "Forbidden",
}
]
},
}

response = await client.schema.load(schemas=[broken_schema])

assert response.errors, "Server should reject a forbidden field on an extensions/nodes entry"
assert "detail" in response.errors

schemas_data = [SchemaFile(location=Path("broken.yml"), content=broken_schema)]
console = Console(width=1000)
with console.capture() as capture:
display_schema_load_errors(response=response.errors, schemas_data=schemas_data, output=console)
rendered = capture.get()

assert "Unable to load the schema" in rendered
assert "BuiltinTag" in rendered
assert "extensions/nodes" in rendered

async def test_extension_nested_attribute_error(self, client: InfrahubClient) -> None:
broken_schema = {
"version": "1.0",
"extensions": {
"nodes": [
{
"kind": "BuiltinTag",
"attributes": [
{"name": "speed", "kind": "Number", "made_up": True},
],
}
]
},
}

response = await client.schema.load(schemas=[broken_schema])

assert response.errors, "Server should reject a forbidden field on an extensions attribute entry"
assert "detail" in response.errors

schemas_data = [SchemaFile(location=Path("broken.yml"), content=broken_schema)]
console = Console(width=1000)
with console.capture() as capture:
display_schema_load_errors(response=response.errors, schemas_data=schemas_data, output=console)
rendered = capture.get()

assert "Unable to load the schema" in rendered
assert "BuiltinTag" in rendered
assert "extensions/nodes" in rendered
assert "speed" in rendered
31 changes: 30 additions & 1 deletion tests/unit/sdk/test_schema.py
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this would be easier to test if this code was written using dependency injection and then you could skip the patch() calls. sometimes patching is necessary for testing, but we generally try to avoid it when you can achieve a more robust test by moving the test up a layer or two in the test pyramid. in this case testing just display_schema_load_errors is fast, but brittle b/c it assumes a certain response structure from the schema load endpoint. I think that these would be better as integration tests similar to those in python_sdk/tests/integration/test_schema.py, then the test really ensures that the response from the server and the client handling work together correctly. sadly, integration tests are slower, but I think it is worth it in this case b/c we have clearly left schema extensions behind for a while and don't want to again

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Completed a full refactor, particularly around where to display the error, which was needed some time to think about how.

Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from rich.console import Console

from infrahub_sdk import Config, InfrahubClient, InfrahubClientSync
from infrahub_sdk.ctl.schema import display_schema_load_errors
from infrahub_sdk.ctl.schema import display_schema_load_errors, valid_error_path
from infrahub_sdk.exceptions import SchemaNotFoundError, ValidationError
from infrahub_sdk.protocols import BuiltinIPAddress, BuiltinIPAddressSync, BuiltinTag, BuiltinTagSync
from infrahub_sdk.schema import BranchSchema, InfrahubSchema, InfrahubSchemaBase, InfrahubSchemaSync, NodeSchemaAPI
Expand Down Expand Up @@ -490,6 +490,35 @@ async def test_display_schema_load_errors_details_when_error_is_in_attribute_or_
assert output == expected_console


@pytest.mark.parametrize(
"loc_path",
[
pytest.param(["body", "schemas", 0, "nodes", 0, "name"], id="top-level-nodes"),
pytest.param(["body", "schemas", 0, "generics", 1, "attributes", 0], id="top-level-generics"),
pytest.param(["body", "schemas", 0, "extensions", "nodes", 0, "kind"], id="extension-nodes"),
pytest.param(["body", "schemas", 0, "extensions", "generics", 0, "name"], id="extension-generics"),
pytest.param(["body", "schemas", 0, "extensions", "relationships", 2, "peer"], id="extension-relationships"),
],
)
def test_valid_error_path_accepts_known_shapes(loc_path: list) -> None:
assert valid_error_path(loc_path=loc_path)


@pytest.mark.parametrize(
"loc_path",
[
pytest.param(["body", "headers", "x-test"], id="wrong-root"),
pytest.param(["body", "schemas", "not-an-int", "nodes", 0, "name"], id="non-int-schema-index"),
pytest.param(["body", "schemas", 0, "wat", 0, "name"], id="unknown-container"),
pytest.param(["body", "schemas", 0, "extensions", "generics", "include_in_menu"], id="non-int-extension-index"),
pytest.param(["body", "schemas", 0, "extensions", "wat", 0, "name"], id="unknown-extension-container"),
pytest.param(["body", "schemas", 0, "nodes"], id="too-short"),
],
)
def test_valid_error_path_rejects_unknown_shapes(loc_path: list) -> None:
assert not valid_error_path(loc_path=loc_path)


def test_schema_base__get_schema_name__returns_correct_schema_name_for_protocols() -> None:
assert InfrahubSchemaBase._get_schema_name(schema=BuiltinTagSync) == "BuiltinTag"
assert InfrahubSchemaBase._get_schema_name(schema=BuiltinTag) == "BuiltinTag"
Expand Down