Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stdlib] SIMD conformance to EqualityComparable #2412

Open
wants to merge 1 commit into
base: nightly
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
35 changes: 35 additions & 0 deletions stdlib/src/builtin/simd.mojo
Expand Up @@ -132,6 +132,7 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](
Ceilable,
CeilDivable,
CollectionElement,
EqualityComparable,
Floorable,
Hashable,
Intable,
Expand Down Expand Up @@ -869,6 +870,23 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](
self.value, rhs.value
)

# TODO there may be a better way to do this:
# the [__:None=None] parameter is to give this overload lower precedence,
# while still conforming to EqualityComparable
@always_inline("nodebug")
fn __eq__[__: None = None](self, rhs: Self) -> Bool:
"""Compares two SIMD vectors using all-equal-to comparison.

This overload allows EqualityComparable conformance.

Args:
rhs: The rhs of the operation.

Returns:
True if all lanes are equal, False otherwise.
"""
return all(self == rhs)

@always_inline("nodebug")
fn __ne__(self, rhs: Self) -> Self._Mask:
"""Compares two SIMD vectors using not-equal comparison.
Expand All @@ -892,6 +910,23 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](
self.value, rhs.value
)

# TODO there may be a better way to do this:
# the [__:None=None] parameter is to give this overload lower precedence,
# while still conforming to EqualityComparable
@always_inline("nodebug")
fn __ne__[__: None = None](self, rhs: Self) -> Bool:
"""Compares two SIMD vectors using any-not-equal comparison.

This overload allows EqualityComparable conformance.

Args:
rhs: The rhs of the operation.

Returns:
True if any lanes are not equal, False otherwise.
"""
return any(self != rhs)

@always_inline("nodebug")
fn __gt__(self, rhs: Self) -> Self._Mask:
"""Compares two SIMD vectors using greater-than comparison.
Expand Down
48 changes: 0 additions & 48 deletions stdlib/src/testing/testing.mojo
Expand Up @@ -152,29 +152,6 @@ fn assert_equal(lhs: String, rhs: String, msg: String = "") raises:
raise _assert_equal_error(lhs, rhs, msg, __call_location())


@always_inline
fn assert_equal[
type: DType, size: Int
](lhs: SIMD[type, size], rhs: SIMD[type, size], msg: String = "") raises:
"""Asserts that the input values are equal. If it is not then an
Error is raised.

Parameters:
type: The dtype of the left- and right-hand-side SIMD vectors.
size: The width of the left- and right-hand-side SIMD vectors.

Args:
lhs: The lhs of the equality.
rhs: The rhs of the equality.
msg: The message to be printed if the assertion fails.

Raises:
An Error with the provided message if assert fails and `None` otherwise.
"""
if any(lhs != rhs):
raise _assert_equal_error(str(lhs), str(rhs), msg, __call_location())


@always_inline
fn assert_not_equal[T: Testable](lhs: T, rhs: T, msg: String = "") raises:
"""Asserts that the input values are not equal. If it is not then an
Expand Down Expand Up @@ -214,31 +191,6 @@ fn assert_not_equal(lhs: String, rhs: String, msg: String = "") raises:
raise _assert_not_equal_error(lhs, rhs, msg, __call_location())


@always_inline
fn assert_not_equal[
type: DType, size: Int
](lhs: SIMD[type, size], rhs: SIMD[type, size], msg: String = "") raises:
"""Asserts that the input values are not equal. If it is not then an
Error is raised.

Parameters:
type: The dtype of the left- and right-hand-side SIMD vectors.
size: The width of the left- and right-hand-side SIMD vectors.

Args:
lhs: The lhs of the inequality.
rhs: The rhs of the inequality.
msg: The message to be printed if the assertion fails.

Raises:
An Error with the provided message if assert fails and `None` otherwise.
"""
if all(lhs == rhs):
raise _assert_not_equal_error(
str(lhs), str(rhs), msg, __call_location()
)


@always_inline
fn assert_almost_equal[
type: DType, size: Int
Expand Down
12 changes: 6 additions & 6 deletions stdlib/test/builtin/test_math.mojo
Expand Up @@ -63,12 +63,12 @@ def test_max():


def test_round():
assert_equal(0, round(0.0))
assert_equal(1, round(1.0))
assert_equal(1, round(1.1))
assert_equal(2, round(1.5))
assert_equal(2, round(1.9))
assert_equal(2, round(2.0))
assert_equal(0.0, round(0.0))
assert_equal(1.0, round(1.0))
assert_equal(1.0, round(1.1))
assert_equal(2.0, round(1.5))
assert_equal(2.0, round(1.9))
assert_equal(2.0, round(2.0))

Comment on lines 64 to 72
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a couple of tests that had to be modified. They were finding the path of least resistance through the String overload of assert_equal(). If String was EqualityComparable, those overloads could be removed and this could be avoided.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(when you add the .0, it uses the correct overload)

var lhs = SIMD[DType.float32, 4](1.1, 1.5, 1.9, 2.0)
var expected = SIMD[DType.float32, 4](1.0, 2.0, 2.0, 2.0)
Expand Down
24 changes: 24 additions & 0 deletions stdlib/test/builtin/test_simd.mojo
Expand Up @@ -1042,6 +1042,29 @@ def test_abs():
)


def test_equatable():
var s1 = SIMD[DType.int32, 2](1, 2)
var s2 = SIMD[DType.int32, 2](1, 6)
var s3 = SIMD[DType.int32, 2](6, 2)
var s4 = SIMD[DType.int32, 2](6, 6)

assert_equal(s1 == s1, SIMD[DType.bool, 2](True, True))
assert_equal(s1 != s1, SIMD[DType.bool, 2](False, False))
assert_equal(s1, s1)

assert_equal(s1 == s2, SIMD[DType.bool, 2](True, False))
assert_equal(s1 != s2, SIMD[DType.bool, 2](False, True))
assert_not_equal(s1, s2)

assert_equal(s1 == s3, SIMD[DType.bool, 2](False, True))
assert_equal(s1 != s3, SIMD[DType.bool, 2](True, False))
assert_not_equal(s1, s3)

assert_equal(s1 == s4, SIMD[DType.bool, 2](False, False))
assert_equal(s1 != s4, SIMD[DType.bool, 2](True, True))
assert_not_equal(s1, s4)


def test_min_max_clamp():
alias F = SIMD[DType.float32, 4]

Expand Down Expand Up @@ -1351,3 +1374,4 @@ def main():
test_trunc()
test_truthy()
test_reduce()
test_equatable()
2 changes: 1 addition & 1 deletion stdlib/test/builtin/test_string.mojo
Expand Up @@ -417,7 +417,7 @@ fn test_atof() raises:
assert_equal(1.0, atof(String("001.")))
assert_equal(+5.0, atof(String(" +005.")))
assert_equal(13.0, atof(String(" 013.f ")))
assert_equal(-89, atof(String("-89")))
assert_equal(-89.0, atof(String("-89")))
assert_equal(-0.3, atof(String(" -0.3")))
assert_equal(-69e3, atof(String(" -69E+3 ")))
assert_equal(123.2e1, atof(String(" 123.2E1 ")))
Expand Down