diff --git a/src/array_api_extra/_lib/_testing.py b/src/array_api_extra/_lib/_testing.py index 30e2f1ef..17d6209f 100644 --- a/src/array_api_extra/_lib/_testing.py +++ b/src/array_api_extra/_lib/_testing.py @@ -34,6 +34,7 @@ def _check_ns_shape_dtype( actual: Array, desired: Array, + check_namespace: bool, check_dtype: bool, check_shape: bool, check_scalar: bool, @@ -47,6 +48,8 @@ def _check_ns_shape_dtype( The array produced by the tested function. desired : Array The expected array (typically hardcoded). + check_namespace : bool, default: True + Whether to check agreement between actual and desired namespace. check_dtype, check_shape : bool, default: True Whether to check agreement between actual and desired dtypes and shapes check_scalar : bool, default: False @@ -60,8 +63,9 @@ def _check_ns_shape_dtype( actual_xp = array_namespace(actual) # Raises on scalars and lists desired_xp = array_namespace(desired) - msg = f"namespaces do not match: {actual_xp} != f{desired_xp}" - assert actual_xp == desired_xp, msg + if check_namespace: + msg = f"namespaces do not match: {actual_xp} != f{desired_xp}" + assert actual_xp == desired_xp, msg # Dask uses nan instead of None for unknown shapes actual_shape = cast(tuple[float, ...], actual.shape) @@ -139,6 +143,7 @@ def xp_assert_equal( desired: Array, *, err_msg: str = "", + check_namespace: bool = True, check_dtype: bool = True, check_shape: bool = True, check_scalar: bool = False, @@ -154,6 +159,8 @@ def xp_assert_equal( The expected array (typically hardcoded). err_msg : str, optional Error message to display on failure. + check_namespace : bool, default: True + Whether to check agreement between actual and desired namespace. check_dtype, check_shape : bool, default: True Whether to check agreement between actual and desired dtypes and shapes check_scalar : bool, default: False @@ -165,7 +172,14 @@ def xp_assert_equal( xp_assert_close : Similar function for inexact equality checks. numpy.testing.assert_array_equal : Similar function for NumPy arrays. """ - xp = _check_ns_shape_dtype(actual, desired, check_dtype, check_shape, check_scalar) + xp = _check_ns_shape_dtype( + actual, + desired, + check_namespace, + check_dtype, + check_shape, + check_scalar, + ) if not _is_materializable(actual): return actual_np = as_numpy_array(actual, xp=xp) @@ -178,6 +192,7 @@ def xp_assert_less( y: Array, *, err_msg: str = "", + check_namespace: bool = True, check_dtype: bool = True, check_shape: bool = True, check_scalar: bool = False, @@ -191,6 +206,8 @@ def xp_assert_less( The arrays to compare according to ``x < y`` (elementwise). err_msg : str, optional Error message to display on failure. + check_namespace : bool, default: True + Whether to check agreement between actual and desired namespace. check_dtype, check_shape : bool, default: True Whether to check agreement between actual and desired dtypes and shapes check_scalar : bool, default: False @@ -202,7 +219,7 @@ def xp_assert_less( xp_assert_close : Similar function for inexact equality checks. numpy.testing.assert_array_equal : Similar function for NumPy arrays. """ - xp = _check_ns_shape_dtype(x, y, check_dtype, check_shape, check_scalar) + xp = _check_ns_shape_dtype(x, y, check_namespace, check_dtype, check_shape, check_scalar) if not _is_materializable(x): return x_np = as_numpy_array(x, xp=xp) @@ -217,6 +234,7 @@ def xp_assert_close( rtol: float | None = None, atol: float = 0, err_msg: str = "", + check_namespace: bool = True, check_dtype: bool = True, check_shape: bool = True, check_scalar: bool = False, @@ -236,6 +254,8 @@ def xp_assert_close( Absolute tolerance. Default: 0. err_msg : str, optional Error message to display on failure. + check_namespace : bool, default: True + Whether to check agreement between actual and desired namespace. check_dtype, check_shape : bool, default: True Whether to check agreement between actual and desired dtypes and shapes check_scalar : bool, default: False @@ -252,7 +272,14 @@ def xp_assert_close( ----- The default `atol` and `rtol` differ from `xp.all(xpx.isclose(a, b))`. """ - xp = _check_ns_shape_dtype(actual, desired, check_dtype, check_shape, check_scalar) + xp = _check_ns_shape_dtype( + actual, + desired, + check_namespace, + check_dtype, + check_shape, + check_scalar, + ) if not _is_materializable(actual): return