Skip to content

Commit

Permalink
[External] [stdlib] String comparisons implemented (#39768)
Browse files Browse the repository at this point in the history
[External] [stdlib] String comparisons implemented

For issue modularml#2346 (as an alternative to modularml#2378). All four comparisons
(`__lt__`, `__le__`, `__gt__`, & `__ge__`) uses a single `__lt__`
comparison (instead of checking less/greater than + potentially another
"equals to"-check, for `__le__` & `__ge__`). Sorry if this is considered
a duplicate PR, I only meant to give an alternative suggestion. This is
my first ever PR on GitHub.

StringLiterals also get comparisons.

ORIGINAL_AUTHOR=Simon Hellsten
<56205346+siitron@users.noreply.github.com>
PUBLIC_PR_LINK=modularml#2409

---------

Co-authored-by: Simon Hellsten <56205346+siitron@users.noreply.github.com>
Closes modularml#2409
MODULAR_ORIG_COMMIT_REV_ID: b2ed4756c2741fd27387fa295515f4a7222e0ca5
  • Loading branch information
modularbot and siitron committed May 11, 2024
1 parent e44a00d commit 782145d
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 0 deletions.
54 changes: 54 additions & 0 deletions stdlib/src/builtin/string.mojo
Expand Up @@ -890,6 +890,60 @@ struct String(
"""
return not (self == other)

@always_inline
fn __lt__(self, rhs: String) -> Bool:
"""Compare this String to the RHS using LT comparison.
Args:
rhs: The other String to compare against.
Returns:
True if this String is strictly less than the RHS String and False otherwise.
"""
var len1 = len(self)
var len2 = len(rhs)

if len1 < len2:
return memcmp(self.unsafe_ptr(), rhs.unsafe_ptr(), len1) <= 0
else:
return memcmp(self.unsafe_ptr(), rhs.unsafe_ptr(), len2) < 0

@always_inline
fn __le__(self, rhs: String) -> Bool:
"""Compare this String to the RHS using LE comparison.
Args:
rhs: The other String to compare against.
Returns:
True if this String is less than or equal to the RHS String and False otherwise.
"""
return not (rhs < self)

@always_inline
fn __gt__(self, rhs: String) -> Bool:
"""Compare this String to the RHS using GT comparison.
Args:
rhs: The other String to compare against.
Returns:
True if this String is strictly greater than the RHS String and False otherwise.
"""
return rhs < self

@always_inline
fn __ge__(self, rhs: String) -> Bool:
"""Compare this String to the RHS using GE comparison.
Args:
rhs: The other String to compare against.
Returns:
True if this String is greater than or equal to the RHS String and False otherwise.
"""
return not (self < rhs)

@always_inline
fn __add__(self, other: String) -> String:
"""Creates a string by appending another string at the end.
Expand Down
54 changes: 54 additions & 0 deletions stdlib/src/builtin/string_literal.mojo
Expand Up @@ -127,6 +127,60 @@ struct StringLiteral(
"""
return not self == rhs

@always_inline("nodebug")
fn __lt__(self, rhs: StringLiteral) -> Bool:
"""Compare this StringLiteral to the RHS using LT comparison.
Args:
rhs: The other StringLiteral to compare against.
Returns:
True if this StringLiteral is strictly less than the RHS StringLiteral and False otherwise.
"""
var len1 = len(self)
var len2 = len(rhs)

if len1 < len2:
return _memcmp(self.unsafe_ptr(), rhs.unsafe_ptr(), len1) <= 0
else:
return _memcmp(self.unsafe_ptr(), rhs.unsafe_ptr(), len2) < 0

@always_inline("nodebug")
fn __le__(self, rhs: StringLiteral) -> Bool:
"""Compare this StringLiteral to the RHS using LE comparison.
Args:
rhs: The other StringLiteral to compare against.
Returns:
True if this StringLiteral is less than or equal to the RHS StringLiteral and False otherwise.
"""
return not (rhs < self)

@always_inline("nodebug")
fn __gt__(self, rhs: StringLiteral) -> Bool:
"""Compare this StringLiteral to the RHS using GT comparison.
Args:
rhs: The other StringLiteral to compare against.
Returns:
True if this StringLiteral is strictly greater than the RHS StringLiteral and False otherwise.
"""
return rhs < self

@always_inline("nodebug")
fn __ge__(self, rhs: StringLiteral) -> Bool:
"""Compare this StringLiteral to the RHS using GE comparison.
Args:
rhs: The other StringLiteral to compare against.
Returns:
True if this StringLiteral is greater than or equal to the RHS StringLiteral and False otherwise.
"""
return not (self < rhs)

fn __hash__(self) -> Int:
"""Hash the underlying buffer using builtin hash.
Expand Down
41 changes: 41 additions & 0 deletions stdlib/test/builtin/test_string.mojo
Expand Up @@ -119,6 +119,46 @@ fn test_equality_operators() raises:
assert_not_equal(s0, "notabc")


fn test_comparison_operators() raises:
var abc = String("abc")
var de = String("de")
var ABC = String("ABC")
var ab = String("ab")
var abcd = String("abcd")

# Test less than and greater than
assert_true(String.__lt__(abc, de))
assert_false(String.__lt__(de, abc))
assert_false(String.__lt__(abc, abc))
assert_true(String.__lt__(ab, abc))
assert_true(String.__gt__(abc, ab))
assert_false(String.__gt__(abc, abcd))

# Test less than or equal to and greater than or equal to
assert_true(String.__le__(abc, de))
assert_true(String.__le__(abc, abc))
assert_false(String.__le__(de, abc))
assert_true(String.__ge__(abc, abc))
assert_false(String.__ge__(ab, abc))
assert_true(String.__ge__(abcd, abc))

# Test case sensitivity in comparison (assuming ASCII order)
assert_true(String.__gt__(abc, ABC))
assert_false(String.__le__(abc, ABC))

# Testing with implicit conversion
assert_true(String.__lt__(abc, "defgh"))
assert_false(String.__gt__(abc, "xyz"))
assert_true(String.__ge__(abc, "abc"))
assert_false(String.__le__(abc, "ab"))

# Test comparisons involving empty strings
assert_true(String.__lt__("", abc))
assert_false(String.__lt__(abc, ""))
assert_true(String.__le__("", ""))
assert_true(String.__ge__("", ""))


fn test_add() raises:
var s1 = String("123")
var s2 = String("abc")
Expand Down Expand Up @@ -717,6 +757,7 @@ def main():
test_constructors()
test_copy()
test_equality_operators()
test_comparison_operators()
test_add()
test_stringable()
test_repr()
Expand Down
29 changes: 29 additions & 0 deletions stdlib/test/builtin/test_string_literal.mojo
Expand Up @@ -78,6 +78,34 @@ def test_rfind():
assert_equal(-1, "abc".rfind("abcd"))


fn test_comparison_operators() raises:
# Test less than and greater than
assert_true(StringLiteral.__lt__("abc", "def"))
assert_false(StringLiteral.__lt__("def", "abc"))
assert_false(StringLiteral.__lt__("abc", "abc"))
assert_true(StringLiteral.__lt__("ab", "abc"))
assert_true(StringLiteral.__gt__("abc", "ab"))
assert_false(StringLiteral.__gt__("abc", "abcd"))

# Test less than or equal to and greater than or equal to
assert_true(StringLiteral.__le__("abc", "def"))
assert_true(StringLiteral.__le__("abc", "abc"))
assert_false(StringLiteral.__le__("def", "abc"))
assert_true(StringLiteral.__ge__("abc", "abc"))
assert_false(StringLiteral.__ge__("ab", "abc"))
assert_true(StringLiteral.__ge__("abcd", "abc"))

# Test case sensitivity in comparison (assuming ASCII order)
assert_true(StringLiteral.__gt__("abc", "ABC"))
assert_false(StringLiteral.__le__("abc", "ABC"))

# Test comparisons involving empty strings
assert_true(StringLiteral.__lt__("", "abc"))
assert_false(StringLiteral.__lt__("abc", ""))
assert_true(StringLiteral.__le__("", ""))
assert_true(StringLiteral.__ge__("", ""))


def test_hash():
# Test a couple basic hash behaviors.
# `test_hash.test_hash_bytes` has more comprehensive tests.
Expand Down Expand Up @@ -119,6 +147,7 @@ def main():
test_contains()
test_find()
test_rfind()
test_comparison_operators()
test_hash()
test_intable()
test_repr()

0 comments on commit 782145d

Please sign in to comment.