Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 32 additions & 5 deletions src/array_api_extra/_lib/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
def _check_ns_shape_dtype(
actual: Array,
desired: Array,
check_namespace: bool,
check_dtype: bool,
check_shape: bool,
check_scalar: bool,
Expand All @@ -47,6 +48,8 @@ def _check_ns_shape_dtype(
The array produced by the tested function.
desired : Array
The expected array (typically hardcoded).
check_namespace : bool, default: True
Whether to check agreement between actual and desired namespace.
check_dtype, check_shape : bool, default: True
Whether to check agreement between actual and desired dtypes and shapes
check_scalar : bool, default: False
Expand All @@ -60,8 +63,9 @@ def _check_ns_shape_dtype(
actual_xp = array_namespace(actual) # Raises on scalars and lists
desired_xp = array_namespace(desired)

msg = f"namespaces do not match: {actual_xp} != f{desired_xp}"
assert actual_xp == desired_xp, msg
if check_namespace:
msg = f"namespaces do not match: {actual_xp} != f{desired_xp}"
assert actual_xp == desired_xp, msg

# Dask uses nan instead of None for unknown shapes
actual_shape = cast(tuple[float, ...], actual.shape)
Expand Down Expand Up @@ -139,6 +143,7 @@ def xp_assert_equal(
desired: Array,
*,
err_msg: str = "",
check_namespace: bool = True,
check_dtype: bool = True,
check_shape: bool = True,
check_scalar: bool = False,
Expand All @@ -154,6 +159,8 @@ def xp_assert_equal(
The expected array (typically hardcoded).
err_msg : str, optional
Error message to display on failure.
check_namespace : bool, default: True
Whether to check agreement between actual and desired namespace.
check_dtype, check_shape : bool, default: True
Whether to check agreement between actual and desired dtypes and shapes
check_scalar : bool, default: False
Expand All @@ -165,7 +172,14 @@ def xp_assert_equal(
xp_assert_close : Similar function for inexact equality checks.
numpy.testing.assert_array_equal : Similar function for NumPy arrays.
"""
xp = _check_ns_shape_dtype(actual, desired, check_dtype, check_shape, check_scalar)
xp = _check_ns_shape_dtype(
actual,
desired,
check_namespace,
check_dtype,
check_shape,
check_scalar,
)
if not _is_materializable(actual):
return
actual_np = as_numpy_array(actual, xp=xp)
Expand All @@ -178,6 +192,7 @@ def xp_assert_less(
y: Array,
*,
err_msg: str = "",
check_namespace: bool = True,
check_dtype: bool = True,
check_shape: bool = True,
check_scalar: bool = False,
Expand All @@ -191,6 +206,8 @@ def xp_assert_less(
The arrays to compare according to ``x < y`` (elementwise).
err_msg : str, optional
Error message to display on failure.
check_namespace : bool, default: True
Whether to check agreement between actual and desired namespace.
check_dtype, check_shape : bool, default: True
Whether to check agreement between actual and desired dtypes and shapes
check_scalar : bool, default: False
Expand All @@ -202,7 +219,7 @@ def xp_assert_less(
xp_assert_close : Similar function for inexact equality checks.
numpy.testing.assert_array_equal : Similar function for NumPy arrays.
"""
xp = _check_ns_shape_dtype(x, y, check_dtype, check_shape, check_scalar)
xp = _check_ns_shape_dtype(x, y, check_namespace, check_dtype, check_shape, check_scalar)
if not _is_materializable(x):
return
x_np = as_numpy_array(x, xp=xp)
Expand All @@ -217,6 +234,7 @@ def xp_assert_close(
rtol: float | None = None,
atol: float = 0,
err_msg: str = "",
check_namespace: bool = True,
check_dtype: bool = True,
check_shape: bool = True,
check_scalar: bool = False,
Expand All @@ -236,6 +254,8 @@ def xp_assert_close(
Absolute tolerance. Default: 0.
err_msg : str, optional
Error message to display on failure.
check_namespace : bool, default: True
Whether to check agreement between actual and desired namespace.
check_dtype, check_shape : bool, default: True
Whether to check agreement between actual and desired dtypes and shapes
check_scalar : bool, default: False
Expand All @@ -252,7 +272,14 @@ def xp_assert_close(
-----
The default `atol` and `rtol` differ from `xp.all(xpx.isclose(a, b))`.
"""
xp = _check_ns_shape_dtype(actual, desired, check_dtype, check_shape, check_scalar)
xp = _check_ns_shape_dtype(
actual,
desired,
check_namespace,
check_dtype,
check_shape,
check_scalar,
)
if not _is_materializable(actual):
return

Expand Down
Loading