Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stdlib] String comparisons implemented #2409

Closed
wants to merge 13 commits into from
Closed
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
54 changes: 54 additions & 0 deletions stdlib/src/builtin/string.mojo
Expand Up @@ -734,6 +734,60 @@ struct String(
"""
return not (self == other)

@always_inline
fn __lt__(self, rhs: String) -> Bool:
"""Compare this String to the RHS using LT comparison.

Args:
rhs: The other String to compare against.

Returns:
True if this String is strictly less than the RHS String and False otherwise.
"""
var len1 = len(self)
var len2 = len(rhs)

if len1 < len2:
return memcmp(self._as_ptr(), rhs._as_ptr(), len1) <= 0
else:
return memcmp(self._as_ptr(), rhs._as_ptr(), len2) < 0

@always_inline
fn __le__(self, rhs: String) -> Bool:
"""Compare this String to the RHS using LE comparison.

Args:
rhs: The other String to compare against.

Returns:
True if this String is less than or equal to the RHS String and False otherwise.
"""
return not (rhs < self)

@always_inline
fn __gt__(self, rhs: String) -> Bool:
"""Compare this String to the RHS using GT comparison.

Args:
rhs: The other String to compare against.

Returns:
True if this String is strictly greater than the RHS String and False otherwise.
"""
return rhs < self

@always_inline
fn __ge__(self, rhs: String) -> Bool:
"""Compare this String to the RHS using GE comparison.

Args:
rhs: The other String to compare against.

Returns:
True if this String is greater than or equal to the RHS String and False otherwise.
"""
return not (self < rhs)

@always_inline
fn __add__(self, other: String) -> String:
"""Creates a string by appending another string at the end.
Expand Down
54 changes: 54 additions & 0 deletions stdlib/src/builtin/string_literal.mojo
Expand Up @@ -129,6 +129,60 @@ struct StringLiteral(
"""
return not self == rhs

@always_inline("nodebug")
fn __lt__(self, rhs: StringLiteral) -> Bool:
"""Compare this StringLiteral to the RHS using LT comparison.

Args:
rhs: The other StringLiteral to compare against.

Returns:
True if this StringLiteral is strictly less than the RHS StringLiteral and False otherwise.
"""
var len1 = len(self)
var len2 = len(rhs)

if len1 < len2:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion: you can consolidate the duplication between this and the implementation in String by implementing the comparison on StringRef

Copy link
Author

@siitron siitron Apr 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean something like return StringRef(self) < StringRef(rhs) as __lt__ for StringLiteral (and return self._strref_dangerous() < rhs._strref_dangerous() for String), etc.?

I didn't even consider StringRef (forgot it existed), and I tried to match __lt__ with their respective __eq__-methods. Right now String, StringLiteral, and StringRef does __eq__ in three separate ways, so I thought there was some reason for keeping them apart ATM(?). Would it still be safe to use memory.memcmp in the StringRef __lt__ implementation (considering it currently looks like StringLiteral has to use it's own local _memcmp, and StringRef's __eq__ also avoids it), or should StringRef's __lt__ be done like its __eq__?

What do you think of:

    @always_inline("nodebug")
    fn __lt__(self, rhs: StringRef) -> Bool:
        """Compare this StringRef to the RHS using LT comparison.

        Args:
            rhs: The other StringRef to compare against.

        Returns:
            True if this StringRef is strictly less than the RHS StringRef and False otherwise.
        """
        var len1 = len(self)
        var len2 = len(rhs)

        if len1 < len2:
            return memcmp(self.data, rhs.data, len1) <= 0
        else:
            return memcmp(self.data, rhs.data, len2) < 0

return _memcmp(self.data(), rhs.data(), len1) <= 0
else:
return _memcmp(self.data(), rhs.data(), len2) < 0

@always_inline("nodebug")
fn __le__(self, rhs: StringLiteral) -> Bool:
"""Compare this StringLiteral to the RHS using LE comparison.

Args:
rhs: The other StringLiteral to compare against.

Returns:
True if this StringLiteral is less than or equal to the RHS StringLiteral and False otherwise.
"""
return not (rhs < self)

@always_inline("nodebug")
fn __gt__(self, rhs: StringLiteral) -> Bool:
"""Compare this StringLiteral to the RHS using GT comparison.

Args:
rhs: The other StringLiteral to compare against.

Returns:
True if this StringLiteral is strictly greater than the RHS StringLiteral and False otherwise.
"""
return rhs < self

@always_inline("nodebug")
fn __ge__(self, rhs: StringLiteral) -> Bool:
"""Compare this StringLiteral to the RHS using GE comparison.

Args:
rhs: The other StringLiteral to compare against.

Returns:
True if this StringLiteral is greater than or equal to the RHS StringLiteral and False otherwise.
"""
return not (self < rhs)

fn __hash__(self) -> Int:
"""Hash the underlying buffer using builtin hash.

Expand Down
51 changes: 51 additions & 0 deletions stdlib/test/builtin/test_string.mojo
Expand Up @@ -98,6 +98,56 @@ fn test_equality_operators() raises:
assert_not_equal(s0, "notabc")


fn test_comparison_operators() raises:
var abc = String("abc")
var de = String("de")
var ABC = String("ABC")
var ab = String("ab")
var abcd = String("abcd")

# Test less than and greater than
assert_true(abc < de)
assert_false(de < abc)
assert_false(abc < abc)
assert_true(ab < abc)
assert_true(abc > ab)
assert_false(abc > abcd)

# Test less than or equal to and greater than or equal to
assert_true(abc <= de)
assert_true(abc <= abc)
assert_false(de <= abc)
assert_true(abc >= abc)
assert_false(ab >= abc)
assert_true(abcd >= abc)

# Test case sensitivity in comparison (assuming ASCII order)
assert_true(abc > ABC)
assert_false(abc <= ABC)

# Testing with implicit conversion
assert_true(abc < "defgh")
assert_false(abc > "xyz")
assert_true(abc >= "abc")
assert_false(abc <= "ab")

# Test against empty strings
assert_true(str("") < abc)
assert_false(abc < "")
assert_true(str("") <= "")
assert_true(str("") >= "")

# Test comparisons involving default constructed empty string
assert_true(String() < abc)
assert_false(abc < String())
assert_true(String() <= abc)
assert_false(abc <= String())
assert_true(String() <= "")
assert_true(String() >= "")
assert_false(abc <= String())
assert_true(String() >= String())


fn test_add() raises:
var s1 = String("123")
var s2 = String("abc")
Expand Down Expand Up @@ -697,6 +747,7 @@ def main():
test_constructors()
test_copy()
test_equality_operators()
test_comparison_operators()
test_add()
test_stringable()
test_string_join()
Expand Down
29 changes: 29 additions & 0 deletions stdlib/test/builtin/test_string_literal.mojo
Expand Up @@ -78,6 +78,34 @@ def test_rfind():
assert_equal(-1, "abc".rfind("abcd"))


fn test_comparison_operators() raises:
# Test less than and greater than
assert_true("abc" < "def")
assert_false("def" < "abc")
assert_false("abc" < "abc")
assert_true("ab" < "abc")
assert_true("abc" > "ab")
assert_false("abc" > "abcd")

# Test less than or equal to and greater than or equal to
assert_true("abc" <= "def")
assert_true("abc" <= "abc")
assert_false("def" <= "abc")
assert_true("abc" >= "abc")
assert_false("ab" >= "abc")
assert_true("abcd" >= "abc")

# Test case sensitivity in comparison (assuming ASCII order)
assert_true("abc" > "ABC")
assert_false("abc" <= "ABC")

# Test against empty strings
assert_true("" < "abc")
assert_false("abc" < "")
assert_true("" <= "")
assert_true("" >= "")


def test_hash():
# Test a couple basic hash behaviors.
# `test_hash.test_hash_bytes` has more comprehensive tests.
Expand All @@ -99,5 +127,6 @@ def main():
test_contains()
test_find()
test_rfind()
test_comparison_operators()
test_hash()
test_intable()