diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py index 2fcdcd8e..17503674 100644 --- a/src/array_api_extra/__init__.py +++ b/src/array_api_extra/__init__.py @@ -19,6 +19,7 @@ ) from ._lib._at import at from ._lib._funcs import ( + angle, apply_where, broadcast_shapes, default_dtype, @@ -32,6 +33,7 @@ # pylint: disable=duplicate-code __all__ = [ "__version__", + "angle", "apply_where", "argpartition", "at", diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index 97904ddb..b58cf3d0 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -818,3 +818,36 @@ def union1d(a: Array, b: Array, /, *, xp: ModuleType) -> Array: b = xp.reshape(b, (-1,)) # XXX: `sparse` returns NumPy arrays from `unique_values` return xp.asarray(xp.unique_values(xp.concat([a, b]))) + + +def angle(z: Array, deg: bool = False, /, *, xp: ModuleType | None = None) -> Array: + """ + Return the angle of the complex argument. + + Parameters + ---------- + z : Array + Input array. + deg : bool, optional + Return angle in degrees if True, radians if False (default). + xp : array_namespace, optional + The standard-compatible namespace for `z`. Default: infer. + + Returns + ------- + ndarray or scalar + The counterclockwise angle from the positive real axis on the complex + plane in the range ``(-pi, pi]``, with dtype as float64. + """ + if xp is None: + xp = array_namespace(z) + if xp.isdtype(z.dtype, "complex floating"): + zimage = xp.imag(z) + zreal = xp.real(z) + else: + zimage = xp.zeros_like(z, dtype=xp.float64) + zreal = xp.astype(z, xp.float64) + a = xp.atan2(zimage, zreal) + if deg: + a = a * 180 / xp.pi + return a diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 6a11e059..71d7bc57 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -13,6 +13,7 @@ from typing_extensions import override from array_api_extra import ( + angle, apply_where, argpartition, at, @@ -1881,3 +1882,30 @@ def test_device(self, xp: ModuleType, device: Device): a = xp.asarray([-1, 1, 0], device=device) b = xp.asarray([2, -2, 0], device=device) assert get_device(union1d(a, b)) == device + + +class TestAngle: + def test_simple(self, xp: ModuleType): + a = xp.asarray([1, 0]) + expected = xp.asarray([0.0, 0.0], dtype=xp.float64) + res = angle(a) + xp_assert_equal(res, expected) + + def test_complex(self, xp: ModuleType): + a = xp.asarray([1 + 1j, 1 - 1j, -1 + 1j, -1 - 1j]) + expected = xp.asarray([np.pi / 4, -np.pi / 4, 3 * np.pi / 4, -3 * np.pi / 4]) + res = angle(a) + xp_assert_equal(res, expected) + + def test_2d(self, xp: ModuleType): + a = xp.asarray([[1 + 1j, 1 - 1j], [-1 + 1j, -1 - 1j]]) + expected = xp.asarray( + [[np.pi / 4, -np.pi / 4], [3 * np.pi / 4, -3 * np.pi / 4]] + ) + res = angle(a) + xp_assert_equal(res, expected) + + @pytest.mark.skip_xp_backend(Backend.TORCH, reason="materialize 'meta' device") + def test_device(self, xp: ModuleType, device: Device): + a = xp.asarray([1 + 1j], device=device) + assert get_device(angle(a)) == device