Skip to content

Commit

Permalink
simd conformance to EqualityComparable
Browse files Browse the repository at this point in the history
constrain SIMD.__bool__() to size=1
change to explicit reduce for SIMD

Signed-off-by: Max Brylski <helehex@gmail.com>
  • Loading branch information
helehex committed Apr 29, 2024
1 parent 725c356 commit 2dcd097
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 117 deletions.
18 changes: 0 additions & 18 deletions stdlib/src/builtin/bool.mojo
Expand Up @@ -68,24 +68,6 @@ struct Bool(
_type = __mlir_type.`!pop.scalar<bool>`
](value)

@always_inline("nodebug")
fn __init__[width: Int](value: SIMD[DType.bool, width]) -> Bool:
"""Construct a Bool value given a SIMD value.
If there is more than a single element in the SIMD value, then value is
reduced using the and operator.
Parameters:
width: SIMD width.
Args:
value: The initial SIMD value.
Returns:
The constructed Bool value.
"""
return value.__bool__()

@always_inline("nodebug")
fn __init__[boolable: Boolable](value: boolable) -> Bool:
"""Implicitly convert a Boolable value to a Bool.
Expand Down
72 changes: 54 additions & 18 deletions stdlib/src/builtin/simd.mojo
Expand Up @@ -370,10 +370,13 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](
otherwise.
"""

constrained[size == 1, "size must be 1"]()

@parameter
if Self.element_type == DType.bool:
return self.reduce_and()
return (self != 0).reduce_and()
if type == DType.bool:
return rebind[Scalar[DType.bool]](self).value
else:
return rebind[Scalar[type]](self).cast[DType.bool]().value

@staticmethod
@always_inline("nodebug")
Expand Down Expand Up @@ -609,7 +612,7 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](
"""
constrained[type.is_numeric(), "the type must be numeric"]()

if rhs == 0:
if (rhs == 0).reduce_and():
# this should raise an exception.
return 0

Expand All @@ -621,7 +624,7 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](
elif type.is_unsigned():
return div
else:
if self > 0 and rhs > 0:
if ((self > 0) & (rhs > 0)).reduce_and():
return div

var mod = self - div * rhs
Expand All @@ -640,7 +643,7 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](
"""
constrained[type.is_numeric(), "the type must be numeric"]()

if rhs == 0:
if (rhs == 0).reduce_and():
# this should raise an exception.
return 0

Expand Down Expand Up @@ -764,15 +767,17 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](
# while still conforming to EqualityComparable
@always_inline("nodebug")
fn __eq__[__: None = None](self, rhs: Self) -> Bool:
"""Compares two SIMD vectors using equal-to comparison.
"""Compares two SIMD vectors using all-equal-to comparison.
This overload allows EqualityComparable conformance.
Args:
rhs: The rhs of the operation.
Returns:
True if every lane is equal, False otherwise.
True if all lanes are equal, False otherwise.
"""
return (self == rhs).reduce_and()
return _all_equal(self, rhs)

@always_inline("nodebug")
fn __ne__(self, rhs: Self) -> SIMD[DType.bool, size]:
Expand Down Expand Up @@ -802,15 +807,17 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](
# while still conforming to EqualityComparable
@always_inline("nodebug")
fn __ne__[__: None = None](self, rhs: Self) -> Bool:
"""Compares two SIMD vectors using equal-to comparison.
"""Compares two SIMD vectors using any-not-equal comparison.
This overload allows EqualityComparable conformance.
Args:
rhs: The rhs of the operation.
Returns:
True if any lanes are not equal, False otherwise.
"""
return (self != rhs).reduce_or()
return _any_not_equal(self, rhs)

@always_inline("nodebug")
fn __gt__(self, rhs: Self) -> SIMD[DType.bool, size]:
Expand Down Expand Up @@ -2054,8 +2061,10 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](
"""

@parameter
if size == 1:
return self.cast[DType.bool]()[0].value
if type != DType.bool:
return self.cast[DType.bool]().reduce_and()
elif size == 1:
return rebind[Scalar[DType.bool]](self)
return llvm_intrinsic[
"llvm.vector.reduce.and", Scalar[DType.bool], has_side_effect=False
](self)
Expand All @@ -2072,8 +2081,10 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](
"""

@parameter
if size == 1:
return self.cast[DType.bool]()[0].value
if type != DType.bool:
return self.cast[DType.bool]().reduce_or()
elif size == 1:
return rebind[Scalar[DType.bool]](self)
return llvm_intrinsic[
"llvm.vector.reduce.or", Scalar[DType.bool], has_side_effect=False
](self)
Expand Down Expand Up @@ -2307,7 +2318,7 @@ fn _pow[
@parameter
if rhs_type.is_floating_point() and lhs_type == rhs_type:
var rhs_quotient = _floor(rhs)
if rhs >= 0 and rhs_quotient == rhs:
if ((rhs >= 0) & (rhs_quotient == rhs)).reduce_and():
return _pow(lhs, rhs_quotient.cast[_integral_type_of[rhs_type]()]())

var result = SIMD[lhs_type, simd_width]()
Expand All @@ -2321,9 +2332,9 @@ fn _pow[
return result
elif rhs_type.is_integral():
# Common cases
if rhs == 2:
if (rhs == 2).reduce_and():
return lhs * lhs
if rhs == 3:
if (rhs == 3).reduce_and():
return lhs * lhs * lhs

var result = SIMD[lhs_type, simd_width]()
Expand Down Expand Up @@ -2487,6 +2498,31 @@ fn _f32_to_bfloat16[
return _simd_apply[wrapper_fn, DType.bfloat16, size](val)


# ===----------------------------------------------------------------------===#
# Comparison
# ===----------------------------------------------------------------------===#


@always_inline("nodebug")
fn _all_equal(a: SIMD, b: __type_of(a)) -> Bool:
return (a == b).reduce_and()


@always_inline("nodebug")
fn _all_not_equal(a: SIMD, b: __type_of(a)) -> Bool:
return (a != b).reduce_and()


@always_inline("nodebug")
fn _any_equal(a: SIMD, b: __type_of(a)) -> Bool:
return (a == b).reduce_or()


@always_inline("nodebug")
fn _any_not_equal(a: SIMD, b: __type_of(a)) -> Bool:
return (a != b).reduce_or()


# ===----------------------------------------------------------------------===#
# Limits
# ===----------------------------------------------------------------------===#
Expand Down
56 changes: 3 additions & 53 deletions stdlib/src/testing/testing.mojo
Expand Up @@ -45,7 +45,7 @@ fn _isclose(
if a.type.is_bool() or a.type.is_integral():
return a == b

if equal_nan and isnan(a) and isnan(b):
if equal_nan and (isnan(a) & isnan(b)).reduce_and():
return True

var atol_vec = SIMD[a.type, a.size](atol)
Expand All @@ -55,7 +55,7 @@ fn _isclose(
if not equal_nan:
return res

return res.select(res, isnan(a) and isnan(b))
return res.select(res, isnan(a) & isnan(b))


# ===----------------------------------------------------------------------=== #
Expand Down Expand Up @@ -149,31 +149,6 @@ fn assert_equal(lhs: String, rhs: String, msg: String = "") raises:
raise _assert_equal_error(lhs, rhs, msg, __call_location())


@always_inline
fn assert_equal[
type: DType, size: Int
](lhs: SIMD[type, size], rhs: SIMD[type, size], msg: String = "") raises:
"""Asserts that the input values are equal. If it is not then an
Error is raised.
Parameters:
type: The dtype of the left- and right-hand-side SIMD vectors.
size: The width of the left- and right-hand-side SIMD vectors.
Args:
lhs: The lhs of the equality.
rhs: The rhs of the equality.
msg: The message to be printed if the assertion fails.
Raises:
An Error with the provided message if assert fails and `None` otherwise.
"""
# `if lhs != rhs:` is not enough. `reduce_or()` must be used here,
# otherwise, if any of the elements are equal, the error is not triggered.
if (lhs != rhs).reduce_or():
raise _assert_equal_error(str(lhs), str(rhs), msg, __call_location())


@always_inline
fn assert_not_equal[T: Testable](lhs: T, rhs: T, msg: String = "") raises:
"""Asserts that the input values are not equal. If it is not then an
Expand Down Expand Up @@ -213,31 +188,6 @@ fn assert_not_equal(lhs: String, rhs: String, msg: String = "") raises:
raise _assert_not_equal_error(lhs, rhs, msg, __call_location())


@always_inline
fn assert_not_equal[
type: DType, size: Int
](lhs: SIMD[type, size], rhs: SIMD[type, size], msg: String = "") raises:
"""Asserts that the input values are not equal. If it is not then an
Error is raised.
Parameters:
type: The dtype of the left- and right-hand-side SIMD vectors.
size: The width of the left- and right-hand-side SIMD vectors.
Args:
lhs: The lhs of the inequality.
rhs: The rhs of the inequality.
msg: The message to be printed if the assertion fails.
Raises:
An Error with the provided message if assert fails and `None` otherwise.
"""
if lhs == rhs:
raise _assert_not_equal_error(
str(lhs), str(rhs), msg, __call_location()
)


@always_inline
fn assert_almost_equal[
type: DType, size: Int
Expand Down Expand Up @@ -272,7 +222,7 @@ fn assert_almost_equal[
var almost_equal = _isclose(
lhs, rhs, atol=atol, rtol=rtol, equal_nan=equal_nan
)
if not almost_equal:
if not almost_equal.reduce_and():
var err = str(lhs) + " is not close to " + str(
rhs
) + " with a diff of " + _abs(lhs - rhs)
Expand Down
4 changes: 1 addition & 3 deletions stdlib/src/utils/_numerics.mojo
Expand Up @@ -636,9 +636,7 @@ fn isnan[
if type == DType.bfloat16:
alias int_dtype = _integral_type_of[type]()
var int_val = bitcast[int_dtype, simd_width](val)
return int_val & SIMD[int_dtype, simd_width](0x7FFF) > SIMD[
int_dtype, simd_width
](0x7F80)
return int_val & 0x7FFF > 0x7F80

alias signaling_nan_test: UInt32 = 0x0001
alias quiet_nan_test: UInt32 = 0x0002
Expand Down
41 changes: 16 additions & 25 deletions stdlib/test/builtin/test_simd.mojo
Expand Up @@ -14,7 +14,7 @@

from sys import has_neon, simdwidthof

from testing import assert_equal, assert_not_equal, assert_true
from testing import assert_equal, assert_not_equal, assert_true, assert_false


def test_cast():
Expand Down Expand Up @@ -122,16 +122,19 @@ def test_truthy():
@parameter
fn test_dtype[type: DType]() raises:
# # Scalars of 0-values are false-y, 1-values are truth-y
assert_equal(False, Scalar[type](False).__bool__())
assert_equal(True, Scalar[type](True).__bool__())
assert_false(Scalar[type](False))
assert_true(Scalar[type](True))

# # SIMD vectors are truth-y if _all_ values are truth-y
assert_equal(True, SIMD[type, 2](True, True).__bool__())
# # SIMD vectors with size > 1 must be explicitly reduced
assert_true(SIMD[type, 2](True, True).reduce_and())
assert_false(SIMD[type, 2](False, True).reduce_and())
assert_false(SIMD[type, 2](True, False).reduce_and())
assert_false(SIMD[type, 2](False, False).reduce_and())

# # SIMD vectors are false-y if _any_ values are false-y
assert_equal(False, SIMD[type, 2](False, True).__bool__())
assert_equal(False, SIMD[type, 2](True, False).__bool__())
assert_equal(False, SIMD[type, 2](False, False).__bool__())
assert_true(SIMD[type, 2](True, True).reduce_or())
assert_true(SIMD[type, 2](False, True).reduce_or())
assert_true(SIMD[type, 2](True, False).reduce_or())
assert_false(SIMD[type, 2](False, False).reduce_or())

@parameter
fn test_dtype_unrolled[i: Int]() raises:
Expand Down Expand Up @@ -716,38 +719,26 @@ def test_mul_with_overflow():


def test_equatable():
@parameter
fn eq[T: EqualityComparable](a: T, b: T) -> Bool:
return a == b

@parameter
fn ne[T: EqualityComparable](a: T, b: T) -> Bool:
return a != b

var s1 = SIMD[DType.int32, 2](1, 2)
var s2 = SIMD[DType.int32, 2](1, 6)
var s3 = SIMD[DType.int32, 2](6, 2)
var s4 = SIMD[DType.int32, 2](6, 6)

assert_equal(s1 == s1, SIMD[DType.bool, 2](True, True), "s1 == s1")
assert_equal(s1 != s1, SIMD[DType.bool, 2](False, False), "s1 != s1")
assert_equal(eq(s1, s1), True, "eq(s1, s1)")
assert_equal(ne(s1, s1), False, "ne(s1, s1)")
assert_equal(s1, s1, "eq(s1, s1)")

assert_equal(s1 == s2, SIMD[DType.bool, 2](True, False), "s1 == s2")
assert_equal(s1 != s2, SIMD[DType.bool, 2](False, True), "s1 != s2")
assert_equal(eq(s1, s2), False, "eq(s1, s2)")
assert_equal(ne(s1, s2), True, "ne(s1, s2)")
assert_not_equal(s1, s2, "ne(s1, s2)")

assert_equal(s1 == s3, SIMD[DType.bool, 2](False, True), "s1 == s3")
assert_equal(s1 != s3, SIMD[DType.bool, 2](True, False), "s1 != s3")
assert_equal(eq(s1, s3), False, "eq(s1, s3)")
assert_equal(ne(s1, s3), True, "ne(s1, s3)")
assert_not_equal(s1, s3, "ne(s1, s3)")

assert_equal(s1 == s4, SIMD[DType.bool, 2](False, False), "s1 == s4")
assert_equal(s1 != s4, SIMD[DType.bool, 2](True, True), "s1 != s4")
assert_equal(eq(s1, s4), False, "eq(s1, s4)")
assert_equal(ne(s1, s4), True, "ne(s1, s4)")
assert_not_equal(s1, s4, "ne(s1, s4)")


def main():
Expand Down

0 comments on commit 2dcd097

Please sign in to comment.