Skip to content

Commit

Permalink
[External] [stdlib] Enhance Handling of Infinity and NaN in `assert_a…
Browse files Browse the repository at this point in the history
…lmost_equal` (#38991)

[External] [stdlib] Enhance Handling of Infinity and NaN in
`assert_almost_equal`

This PR enhances the `assert_almost_equal` function to correctly handle
cases involving infinity and NaN.

According to `test_assert_almost_equal` added to
`/test/testing/test_assertion.mojo`, the current implementation of
`assert_almost_equal` results in errors in the following cases:

```mojo
    alias float_type = DType.float32
    alias _inf = inf[float_type]()
    alias _nan = nan[float_type]()
    ...
    _should_succeed(
        SIMD[float_type, 2](-_inf, _inf), SIMD[float_type, 2](-_inf, _inf)
    )
    ...
    _should_fail(
        SIMD[float_type, 2](-_inf, 0.0),
        SIMD[float_type, 2](_inf, 0.0),
        rtol=0.1,
    )
    _should_fail(
        SIMD[float_type, 2](_inf, 0.0),
        SIMD[float_type, 2](0.0, 0.0),
        rtol=0.1,
    )
    ...
    _should_fail(
        SIMD[float_type, 2](_nan, 0.0),
        SIMD[float_type, 2](0.0, 0.0),
        equal_nan=True,
    )
```

This PR also:
- Eliminates the use of `and` and `or` in the `_isclose` function due to
the issue outlined in #2374.
- Explicitly reduces boolean vectors to boolean scalar values instead of
counting on implicit conversions for clarity.
- Avoids arithmetic operations in `_isclose` and `assert_almost_equal`
when the type is boolean, as these operations are not supported in this
case.
- Clarifies the behavior of `assert_almost_equal` in the docstring,
highlighting differences from similar functions such as
`numpy.testing.assert_allclose`.
- Adds the `inf` function to `utils/_numerics` along with corresponding
tests in `test/utils/test_numerics.mojo`.

ORIGINAL_AUTHOR=Leandro Augusto Lacerda Campos
<15185896+leandrolcampos@users.noreply.github.com>
PUBLIC_PR_LINK=#2375

Co-authored-by: Leandro Augusto Lacerda Campos <15185896+leandrolcampos@users.noreply.github.com>
Closes #2375
MODULAR_ORIG_COMMIT_REV_ID: 2e8b24461dd3279bc841877cc0167acaa104f273
  • Loading branch information
2 people authored and JoeLoser committed May 3, 2024
1 parent efd83e1 commit d602ead
Show file tree
Hide file tree
Showing 4 changed files with 224 additions and 24 deletions.
59 changes: 42 additions & 17 deletions stdlib/src/testing/testing.mojo
Expand Up @@ -19,7 +19,7 @@ from testing import assert_true
```
"""
from collections import Optional
from utils._numerics import isnan
from utils._numerics import isfinite, isnan
from builtin._location import __call_location, _SourceLocation

# ===----------------------------------------------------------------------=== #
Expand All @@ -41,21 +41,29 @@ fn _isclose(
rtol: Scalar[a.type],
equal_nan: Bool,
) -> SIMD[DType.bool, a.size]:
constrained[
a.type.is_bool() or a.type.is_integral() or a.type.is_floating_point(),
"input type must be boolean, integral, or floating-point",
]()

@parameter
if a.type.is_bool() or a.type.is_integral():
return a == b
else:
var both_nan = isnan(a) & isnan(b)
if equal_nan and both_nan.reduce_and():
return True

var res = (a == b)
var atol_vec = SIMD[a.type, a.size](atol)
var rtol_vec = SIMD[a.type, a.size](rtol)
res |= (
isfinite(a)
& isfinite(b)
& (_abs(a - b) <= (atol_vec.max(rtol_vec * _abs(a).max(_abs(b)))))
)

if equal_nan and isnan(a) and isnan(b):
return True

var atol_vec = SIMD[a.type, a.size](atol)
var rtol_vec = SIMD[a.type, a.size](rtol)
var res = _abs(a - b) <= (atol_vec.max(rtol_vec * _abs(a).max(_abs(b))))

if not equal_nan:
return res

return res.select(res, isnan(a) and isnan(b))
return res | both_nan if equal_nan else res


# ===----------------------------------------------------------------------=== #
Expand Down Expand Up @@ -254,6 +262,14 @@ fn assert_almost_equal[
"""Asserts that the input values are equal up to a tolerance. If it is
not then an Error is raised.
When the type is boolean or integral, then equality is checked. When the
type is floating-point, then this checks if the two input values are
numerically the close using the $abs(lhs - rhs) <= max(rtol * max(abs(lhs),
abs(rhs)), atol)$ formula.
Constraints:
The type must be boolean, integral, or floating-point.
Parameters:
type: The dtype of the left- and right-hand-side SIMD vectors.
size: The width of the left- and right-hand-side SIMD vectors.
Expand All @@ -262,22 +278,31 @@ fn assert_almost_equal[
lhs: The lhs of the equality.
rhs: The rhs of the equality.
msg: The message to print.
atol: The _absolute tolerance.
atol: The absolute tolerance.
rtol: The relative tolerance.
equal_nan: Whether to treat nans as equal.
Raises:
An Error with the provided message if assert fails and `None` otherwise.
"""
constrained[
type.is_bool() or type.is_integral() or type.is_floating_point(),
"type must be boolean, integral, or floating-point",
]()

var almost_equal = _isclose(
lhs, rhs, atol=atol, rtol=rtol, equal_nan=equal_nan
)
if not almost_equal:
var err = str(lhs) + " is not close to " + str(
rhs
) + " with a diff of " + _abs(lhs - rhs)
if not almost_equal.reduce_and():
var err = "AssertionError: " + str(lhs) + " is not close to " + str(rhs)

@parameter
if type.is_integral() or type.is_floating_point():
err += " with a diff of " + str(_abs(lhs - rhs))

if msg:
err += " (" + msg + ")"

raise _assert_error(err, __call_location())


Expand Down
54 changes: 54 additions & 0 deletions stdlib/src/utils/_numerics.mojo
Expand Up @@ -647,6 +647,60 @@ fn isnan[
](val.value, (signaling_nan_test | quiet_nan_test).value)


# ===----------------------------------------------------------------------===#
# inf
# ===----------------------------------------------------------------------===#


@always_inline("nodebug")
fn inf[type: DType]() -> Scalar[type]:
"""Gets a +inf value for the given dtype.
Constraints:
Can only be used for FP dtypes.
Parameters:
type: The value dtype.
Returns:
The +inf value of the given dtype.
"""

@parameter
if type == DType.float16:
return rebind[__mlir_type[`!pop.scalar<`, type.value, `>`]](
__mlir_op.`kgen.param.constant`[
_type = __mlir_type[`!pop.scalar<f16>`],
value = __mlir_attr[`#pop.simd<"inf"> : !pop.scalar<f16>`],
]()
)
elif type == DType.bfloat16:
return rebind[__mlir_type[`!pop.scalar<`, type.value, `>`]](
__mlir_op.`kgen.param.constant`[
_type = __mlir_type[`!pop.scalar<bf16>`],
value = __mlir_attr[`#pop.simd<"inf"> : !pop.scalar<bf16>`],
]()
)
elif type == DType.float32:
return rebind[__mlir_type[`!pop.scalar<`, type.value, `>`]](
__mlir_op.`kgen.param.constant`[
_type = __mlir_type[`!pop.scalar<f32>`],
value = __mlir_attr[`#pop.simd<"inf"> : !pop.scalar<f32>`],
]()
)
elif type == DType.float64:
return rebind[__mlir_type[`!pop.scalar<`, type.value, `>`]](
__mlir_op.`kgen.param.constant`[
_type = __mlir_type[`!pop.scalar<f64>`],
value = __mlir_attr[`#pop.simd<"inf"> : !pop.scalar<f64>`],
]()
)
else:
constrained[False, "+inf only support on floating point types"]()

return 0


# ===----------------------------------------------------------------------===#
# isinf
# ===----------------------------------------------------------------------===#
Expand Down
114 changes: 108 additions & 6 deletions stdlib/test/testing/test_assertion.mojo
Expand Up @@ -13,13 +13,14 @@
# RUN: %mojo -debug-level full %s

from testing import (
assert_almost_equal,
assert_equal,
assert_false,
assert_not_equal,
assert_raises,
assert_true,
assert_false,
assert_almost_equal,
)
from utils._numerics import inf, nan


@value
Expand Down Expand Up @@ -61,26 +62,127 @@ def test_assert_messages():
try:
assert_true(False)
except e:
assert_true("test_assertion.mojo:62:20: AssertionError:" in str(e))
assert_true("test_assertion.mojo:63:20: AssertionError:" in str(e))

try:
assert_false(True)
except e:
assert_true("test_assertion.mojo:67:21: AssertionError:" in str(e))
assert_true("test_assertion.mojo:68:21: AssertionError:" in str(e))

try:
assert_equal(1, 0)
except e:
assert_true("test_assertion.mojo:72:21: AssertionError:" in str(e))
assert_true("test_assertion.mojo:73:21: AssertionError:" in str(e))

try:
assert_not_equal(0, 0)
except e:
assert_true("test_assertion.mojo:77:25: AssertionError:" in str(e))
assert_true("test_assertion.mojo:78:25: AssertionError:" in str(e))


def test_assert_almost_equal():
alias float_type = DType.float32
alias _inf = inf[float_type]()
alias _nan = nan[float_type]()

@parameter
def _should_succeed[
type: DType, size: Int
](
lhs: SIMD[type, size],
rhs: SIMD[type, size],
*,
atol: Scalar[type] = 0,
rtol: Scalar[type] = 0,
equal_nan: Bool = False,
):
var msg = "`test_assert_almost_equal` should have succeeded"
assert_almost_equal(
lhs, rhs, msg=msg, atol=atol, rtol=rtol, equal_nan=equal_nan
)

_should_succeed[DType.bool, 1](True, True)
_should_succeed(SIMD[DType.int32, 2](0, 1), SIMD[DType.int32, 2](0, 1))
_should_succeed(
SIMD[float_type, 2](-_inf, _inf), SIMD[float_type, 2](-_inf, _inf)
)
_should_succeed(
SIMD[float_type, 2](-_nan, _nan),
SIMD[float_type, 2](-_nan, _nan),
equal_nan=True,
)
_should_succeed(
SIMD[float_type, 2](1.0, -1.1),
SIMD[float_type, 2](1.1, -1.0),
atol=0.11,
)
_should_succeed(
SIMD[float_type, 2](1.0, -1.1),
SIMD[float_type, 2](1.1, -1.0),
rtol=0.10,
)

@parameter
def _should_fail[
type: DType, size: Int
](
lhs: SIMD[type, size],
rhs: SIMD[type, size],
*,
atol: Scalar[type] = 0,
rtol: Scalar[type] = 0,
equal_nan: Bool = False,
):
var msg = "`test_assert_almost_equal` should have failed"
with assert_raises(contains=msg):
assert_almost_equal(
lhs, rhs, msg=msg, atol=atol, rtol=rtol, equal_nan=equal_nan
)

_should_fail[DType.bool, 1](True, False)
_should_fail(
SIMD[DType.int32, 2](0, 1), SIMD[DType.int32, 2](0, -1), atol=5.0
)
_should_fail(
SIMD[float_type, 2](-_inf, 0.0),
SIMD[float_type, 2](_inf, 0.0),
rtol=0.1,
)
_should_fail(
SIMD[float_type, 2](_inf, 0.0),
SIMD[float_type, 2](0.0, 0.0),
rtol=0.1,
)
_should_fail(
SIMD[float_type, 2](_nan, 0.0),
SIMD[float_type, 2](_nan, 0.0),
equal_nan=False,
)
_should_fail(
SIMD[float_type, 2](_nan, 0.0),
SIMD[float_type, 2](0.0, 0.0),
equal_nan=False,
)
_should_fail(
SIMD[float_type, 2](_nan, 0.0),
SIMD[float_type, 2](0.0, 0.0),
equal_nan=True,
)
_should_fail(
SIMD[float_type, 2](1.0, 0.0),
SIMD[float_type, 2](1.1, 0.0),
atol=0.05,
)
_should_fail(
SIMD[float_type, 2](-1.0, 0.0),
SIMD[float_type, 2](-1.1, 0.0),
rtol=0.05,
)


def main():
test_assert_equal_is_generic()
test_assert_not_equal_is_generic()
test_assert_equal_with_simd()
test_assert_messages()
test_assert_almost_equal()
21 changes: 20 additions & 1 deletion stdlib/test/utils/test_numerics.mojo
Expand Up @@ -12,8 +12,9 @@
# ===----------------------------------------------------------------------=== #
# RUN: %mojo %s

from utils._numerics import FPUtils
from sys.info import has_neon
from testing import assert_equal, assert_true, assert_false
from utils._numerics import FPUtils, inf, isinf

alias FPU64 = FPUtils[DType.float64]

Expand Down Expand Up @@ -45,5 +46,23 @@ fn test_numerics() raises:
assert_equal(FPU64.get_mantissa(FPU64.pack(True, 6, 12)), 12)


fn test_inf() raises:
@parameter
fn _test_inf[type: DType]() raises:
var val = inf[type]()
var msg = "`test_inf` failed for `type == " + str(type) + "`"
assert_true((val > 0.0) & isinf(val), msg=msg)

@parameter
if not has_neon():
# "bf16 is not supported for ARM architectures"
_test_inf[DType.bfloat16]()

_test_inf[DType.float16]()
_test_inf[DType.float32]()
_test_inf[DType.float64]()


def main():
test_numerics()
test_inf()

0 comments on commit d602ead

Please sign in to comment.