Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stdlib] Enhance Handling of Infinity and NaN in assert_almost_equal #2375

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
58 changes: 41 additions & 17 deletions stdlib/src/testing/testing.mojo
Expand Up @@ -19,7 +19,7 @@ from testing import assert_true
```
"""
from collections import Optional
from utils._numerics import isnan
from utils._numerics import isfinite, isnan
from builtin._location import __call_location, _SourceLocation

# ===----------------------------------------------------------------------=== #
Expand All @@ -36,21 +36,29 @@ fn _isclose(
rtol: Scalar[a.type],
equal_nan: Bool,
) -> SIMD[DType.bool, a.size]:
constrained[
a.type.is_bool() or a.type.is_integral() or a.type.is_floating_point(),
"input type must be boolean, integral, or floating-point",
]()

@parameter
if a.type.is_bool() or a.type.is_integral():
return a == b
else:
leandrolcampos marked this conversation as resolved.
Show resolved Hide resolved
var both_nan = isnan(a) & isnan(b)
if equal_nan and both_nan.reduce_and():
return True

var res = (a == b)
var atol_vec = SIMD[a.type, a.size](atol)
var rtol_vec = SIMD[a.type, a.size](rtol)
res |= (
isfinite(a)
& isfinite(b)
& (abs(a - b) <= (atol_vec.max(rtol_vec * abs(a).max(abs(b)))))
)

if equal_nan and isnan(a) and isnan(b):
return True

var atol_vec = SIMD[a.type, a.size](atol)
var rtol_vec = SIMD[a.type, a.size](rtol)
var res = abs(a - b) <= (atol_vec.max(rtol_vec * abs(a).max(abs(b))))

if not equal_nan:
return res

return res.select(res, isnan(a) and isnan(b))
return res | both_nan if equal_nan else res


# ===----------------------------------------------------------------------=== #
Expand Down Expand Up @@ -249,6 +257,14 @@ fn assert_almost_equal[
"""Asserts that the input values are equal up to a tolerance. If it is
not then an Error is raised.

When the type is boolean or integral, then equality is checked. When the
type is floating-point, then this checks if the two input values are
numerically the close using the $abs(lhs - rhs) <= max(rtol * max(abs(lhs),
abs(rhs)), atol)$ formula.

Constraints:
The type must be boolean, integral, or floating-point.

Parameters:
type: The dtype of the left- and right-hand-side SIMD vectors.
size: The width of the left- and right-hand-side SIMD vectors.
Expand All @@ -257,23 +273,31 @@ fn assert_almost_equal[
lhs: The lhs of the equality.
rhs: The rhs of the equality.
msg: The message to print.
atol: The _absolute tolerance.
atol: The absolute tolerance.
rtol: The relative tolerance.
equal_nan: Whether to treat nans as equal.

Raises:
An Error with the provided message if assert fails and `None` otherwise.
"""
constrained[
type.is_bool() or type.is_integral() or type.is_floating_point(),
"type must be boolean, integral, or floating-point",
]()

var almost_equal = _isclose(
lhs, rhs, atol=atol, rtol=rtol, equal_nan=equal_nan
)
if not almost_equal:
var err = str(lhs) + " is not close to " + str(
rhs
) + " with a diff of " + abs(lhs - rhs)
if not almost_equal.reduce_and():
var err = str(lhs) + " is not close to " + str(rhs)

@parameter
if type.is_integral() or type.is_floating_point():
err += " with a diff of " + str(abs(lhs - rhs))

if msg:
err += " (" + msg + ")"

raise _assert_error(err, __call_location())


Expand Down
54 changes: 54 additions & 0 deletions stdlib/src/utils/_numerics.mojo
Expand Up @@ -647,6 +647,60 @@ fn isnan[
](val.value, (signaling_nan_test | quiet_nan_test).value)


# ===----------------------------------------------------------------------===#
# inf
# ===----------------------------------------------------------------------===#


@always_inline("nodebug")
fn inf[type: DType]() -> Scalar[type]:
"""Gets a +inf value for the given dtype.
Constraints:
Can only be used for FP dtypes.
Parameters:
type: The value dtype.
Returns:
The +inf value of the given dtype.
"""

@parameter
if type == DType.float16:
return rebind[__mlir_type[`!pop.scalar<`, type.value, `>`]](
__mlir_op.`kgen.param.constant`[
_type = __mlir_type[`!pop.scalar<f16>`],
value = __mlir_attr[`#pop.simd<"inf"> : !pop.scalar<f16>`],
]()
)
elif type == DType.bfloat16:
return rebind[__mlir_type[`!pop.scalar<`, type.value, `>`]](
__mlir_op.`kgen.param.constant`[
_type = __mlir_type[`!pop.scalar<bf16>`],
value = __mlir_attr[`#pop.simd<"inf"> : !pop.scalar<bf16>`],
]()
)
elif type == DType.float32:
return rebind[__mlir_type[`!pop.scalar<`, type.value, `>`]](
__mlir_op.`kgen.param.constant`[
_type = __mlir_type[`!pop.scalar<f32>`],
value = __mlir_attr[`#pop.simd<"inf"> : !pop.scalar<f32>`],
]()
)
elif type == DType.float64:
return rebind[__mlir_type[`!pop.scalar<`, type.value, `>`]](
__mlir_op.`kgen.param.constant`[
_type = __mlir_type[`!pop.scalar<f64>`],
value = __mlir_attr[`#pop.simd<"inf"> : !pop.scalar<f64>`],
]()
)
else:
constrained[False, "+inf only support on floating point types"]()

return 0


# ===----------------------------------------------------------------------===#
# isinf
# ===----------------------------------------------------------------------===#
Expand Down
114 changes: 108 additions & 6 deletions stdlib/test/testing/test_assertion.mojo
Expand Up @@ -13,13 +13,14 @@
# RUN: %mojo -debug-level full %s

from testing import (
assert_almost_equal,
assert_equal,
assert_false,
assert_not_equal,
assert_raises,
assert_true,
assert_false,
assert_almost_equal,
)
from utils._numerics import inf, nan


@value
Expand Down Expand Up @@ -61,26 +62,127 @@ def test_assert_messages():
try:
assert_true(False)
except e:
assert_true("test_assertion.mojo:62:20: AssertionError:" in str(e))
assert_true("test_assertion.mojo:63:20: AssertionError:" in str(e))

try:
assert_false(True)
except e:
assert_true("test_assertion.mojo:67:21: AssertionError:" in str(e))
assert_true("test_assertion.mojo:68:21: AssertionError:" in str(e))

try:
assert_equal(1, 0)
except e:
assert_true("test_assertion.mojo:72:21: AssertionError:" in str(e))
assert_true("test_assertion.mojo:73:21: AssertionError:" in str(e))

try:
assert_not_equal(0, 0)
except e:
assert_true("test_assertion.mojo:77:25: AssertionError:" in str(e))
assert_true("test_assertion.mojo:78:25: AssertionError:" in str(e))


def test_assert_almost_equal():
alias float_type = DType.float32
alias _inf = inf[float_type]()
alias _nan = nan[float_type]()

@parameter
def _should_succeed[
type: DType, size: Int
](
lhs: SIMD[type, size],
rhs: SIMD[type, size],
*,
atol: Scalar[type] = 0,
rtol: Scalar[type] = 0,
equal_nan: Bool = False,
):
var msg = "`test_assert_almost_equal` should have succeeded"
assert_almost_equal(
lhs, rhs, msg=msg, atol=atol, rtol=rtol, equal_nan=equal_nan
)

_should_succeed[DType.bool, 1](True, True)
_should_succeed(SIMD[DType.int32, 2](0, 1), SIMD[DType.int32, 2](0, 1))
_should_succeed(
SIMD[float_type, 2](-_inf, _inf), SIMD[float_type, 2](-_inf, _inf)
)
_should_succeed(
SIMD[float_type, 2](-_nan, _nan),
SIMD[float_type, 2](-_nan, _nan),
equal_nan=True,
)
_should_succeed(
SIMD[float_type, 2](1.0, -1.1),
SIMD[float_type, 2](1.1, -1.0),
atol=0.11,
)
_should_succeed(
SIMD[float_type, 2](1.0, -1.1),
SIMD[float_type, 2](1.1, -1.0),
rtol=0.10,
)

@parameter
def _should_fail[
type: DType, size: Int
](
lhs: SIMD[type, size],
rhs: SIMD[type, size],
*,
atol: Scalar[type] = 0,
rtol: Scalar[type] = 0,
equal_nan: Bool = False,
):
var msg = "`test_assert_almost_equal` should have failed"
with assert_raises(contains=msg):
assert_almost_equal(
lhs, rhs, msg=msg, atol=atol, rtol=rtol, equal_nan=equal_nan
)

_should_fail[DType.bool, 1](True, False)
_should_fail(
SIMD[DType.int32, 2](0, 1), SIMD[DType.int32, 2](0, -1), atol=5.0
)
_should_fail(
SIMD[float_type, 2](-_inf, 0.0),
SIMD[float_type, 2](_inf, 0.0),
rtol=0.1,
)
_should_fail(
SIMD[float_type, 2](_inf, 0.0),
SIMD[float_type, 2](0.0, 0.0),
rtol=0.1,
)
_should_fail(
SIMD[float_type, 2](_nan, 0.0),
SIMD[float_type, 2](_nan, 0.0),
equal_nan=False,
)
_should_fail(
SIMD[float_type, 2](_nan, 0.0),
SIMD[float_type, 2](0.0, 0.0),
equal_nan=False,
)
_should_fail(
SIMD[float_type, 2](_nan, 0.0),
SIMD[float_type, 2](0.0, 0.0),
equal_nan=True,
)
_should_fail(
SIMD[float_type, 2](1.0, 0.0),
SIMD[float_type, 2](1.1, 0.0),
atol=0.05,
)
_should_fail(
SIMD[float_type, 2](-1.0, 0.0),
SIMD[float_type, 2](-1.1, 0.0),
rtol=0.05,
)


def main():
test_assert_equal_is_generic()
test_assert_not_equal_is_generic()
test_assert_equal_with_simd()
test_assert_messages()
test_assert_almost_equal()
21 changes: 20 additions & 1 deletion stdlib/test/utils/test_numerics.mojo
Expand Up @@ -12,8 +12,9 @@
# ===----------------------------------------------------------------------=== #
# RUN: %mojo %s

from utils._numerics import FPUtils
from sys.info import has_neon
from testing import assert_equal, assert_true, assert_false
from utils._numerics import FPUtils, inf, isinf

alias FPU64 = FPUtils[DType.float64]

Expand Down Expand Up @@ -45,5 +46,23 @@ fn test_numerics() raises:
assert_equal(FPU64.get_mantissa(FPU64.pack(True, 6, 12)), 12)


fn test_inf() raises:
@parameter
fn _test_inf[type: DType]() raises:
var val = inf[type]()
var msg = "`test_inf` failed for `type == " + str(type) + "`"
assert_true((val > 0.0) & isinf(val), msg=msg)

@parameter
if not has_neon():
# "bf16 is not supported for ARM architectures"
_test_inf[DType.bfloat16]()

_test_inf[DType.float16]()
_test_inf[DType.float32]()
_test_inf[DType.float64]()


def main():
test_numerics()
test_inf()