Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stdlib] Support reversed for Dict #2340

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/changelog.md
Expand Up @@ -18,6 +18,8 @@ what we publish.

### ⭐️ New

- `Dict()` now support `reversed` for `dict.items()` and `dict.values()`. ([PR #2340](https://github.com/modularml/mojo/pull/2340) by [@jayzhan211](https://github.com/jayzhan211))

- `int()` can now take a string and a specified base to parse an integer from a
string: `int("ff", 16)` returns `255`. Additionally, if a base of zero is
specified, the string will be parsed as if it was an integer literal, with the
Expand Down
77 changes: 75 additions & 2 deletions stdlib/src/builtin/reversed.mojo
Expand Up @@ -19,7 +19,7 @@ from .range import _StridedRangeIterator

from collections.list import _ListIter

from collections.dict import _DictKeyIter
from collections.dict import _DictKeyIter, _DictValueIter, _DictEntryIter

# ===----------------------------------------------------------------------=== #
# Reversible
Expand Down Expand Up @@ -119,6 +119,79 @@ fn reversed[
value: The dict to get the reversed iterator of.

Returns:
The reversed iterator of the dict.
The reversed iterator of the dict keys.
"""
return value[].__reversed__()


fn reversed[
mutability: __mlir_type.`i1`,
self_life: AnyLifetime[mutability].type,
K: KeyElement,
V: CollectionElement,
dict_mutability: __mlir_type.`i1`,
dict_lifetime: AnyLifetime[dict_mutability].type,
](
value: Reference[
_DictValueIter[K, V, dict_mutability, dict_lifetime],
mutability,
self_life,
]._mlir_type,
) -> _DictValueIter[K, V, dict_mutability, dict_lifetime, False]:
"""Get a reversed iterator of the input dict values.

**Note**: iterators are currently non-raising.

Parameters:
mutability: Whether the reference to the dict is mutable.
self_life: The lifetime of the dict.
K: The type of the keys in the dict.
V: The type of the values in the dict.
dict_mutability: Whether the reference to the dict values is mutable.
dict_lifetime: The lifetime of the dict values.

Args:
value: The dict values to get the reversed iterator of.

Returns:
The reversed iterator of the dict values.
"""
return Reference(value)[].__reversed__[mutability, self_life]()


fn reversed[
mutability: __mlir_type.`i1`,
self_life: AnyLifetime[mutability].type,
K: KeyElement,
V: CollectionElement,
dict_mutability: __mlir_type.`i1`,
dict_lifetime: AnyLifetime[dict_mutability].type,
](
value: Reference[
_DictEntryIter[K, V, dict_mutability, dict_lifetime],
mutability,
self_life,
]._mlir_type,
) -> _DictEntryIter[K, V, dict_mutability, dict_lifetime, False]:
"""Get a reversed iterator of the input dict items.

**Note**: iterators are currently non-raising.

Parameters:
mutability: Whether the reference to the dict is mutable.
self_life: The lifetime of the dict.
K: The type of the keys in the dict.
V: The type of the values in the dict.
dict_mutability: Whether the reference to the dict items is mutable.
dict_lifetime: The lifetime of the dict items.

Args:
value: The dict items to get the reversed iterator of.

Returns:
The reversed iterator of the dict items.
"""
var src = Reference(value)[].src
return _DictEntryIter[K, V, dict_mutability, dict_lifetime, False](
src[]._reserved, 0, src
)
14 changes: 13 additions & 1 deletion stdlib/src/collections/dict.mojo
Expand Up @@ -163,6 +163,7 @@ struct _DictValueIter[
V: CollectionElement,
dict_mutability: __mlir_type.`i1`,
dict_lifetime: AnyLifetime[dict_mutability].type,
forward: Bool = True,
]:
"""Iterator over Dict value references. These are mutable if the dict
is mutable.
Expand All @@ -172,15 +173,26 @@ struct _DictValueIter[
V: The value type of the elements in the dictionary.
dict_mutability: Whether the reference to the vector is mutable.
dict_lifetime: The lifetime of the List
forward: The iteration direction. `False` is backwards.
"""

alias ref_type = Reference[V, dict_mutability, dict_lifetime]

var iter: _DictEntryIter[K, V, dict_mutability, dict_lifetime]
var iter: _DictEntryIter[K, V, dict_mutability, dict_lifetime, forward]

fn __iter__(self) -> Self:
return self

fn __reversed__[
JoeLoser marked this conversation as resolved.
Show resolved Hide resolved
mutability: __mlir_type.`i1`, self_life: AnyLifetime[mutability].type
](self) -> _DictValueIter[K, V, dict_mutability, dict_lifetime, False]:
var src = self.iter.src
return _DictValueIter(
_DictEntryIter[K, V, dict_mutability, dict_lifetime, False](
src[]._reserved, 0, src
)
)

fn __next__(inout self) -> Self.ref_type:
var entry_ref = self.iter.__next__()
# Cast through a pointer to grant additional mutability because
Expand Down
74 changes: 62 additions & 12 deletions stdlib/test/builtin/test_reversed.mojo
Expand Up @@ -30,19 +30,35 @@ def test_reversed_dict():
dict["b"] = 2
dict["c"] = 3
dict["d"] = 4
dict["a"] = 5
dict["a"] = 1

var keys = String("")
for key in reversed(dict):
keys += key[]

assert_equal(keys, "dcba")

var check: Int = 4
for val in reversed(dict.values()):
assert_equal(val[], check)
check -= 1

keys = String("")
check = 4
for item in reversed(dict.items()):
keys += item[].key
assert_equal(item[].value, check)
check -= 1

assert_equal(keys, "dcba")

# Order preserved

_ = dict.pop("a")
_ = dict.pop("c")

# dict: {'b': 2, 'd': 4}

keys = String("")
for key in dict:
keys += key[]
Expand All @@ -55,29 +71,63 @@ def test_reversed_dict():

assert_equal(keys, "db")

# got 4 and 2
check = 4
for val in reversed(dict.values()):
assert_equal(val[], check)
check -= 2

keys = String("")
check = 4
for item in reversed(dict.items()):
keys += item[].key
assert_equal(item[].value, check)
check -= 2

assert_equal(keys, "db")

# Refill dict
dict["c"] = 2
dict["a"] = 1
dict["b"] = 4
dict["d"] = 3

# dict: {'b': 4, 'd': 3, 'c': 2, 'a': 1}

keys = String("")
check = 1
JoeLoser marked this conversation as resolved.
Show resolved Hide resolved
for item in reversed(dict.items()):
keys += item[].key
assert_equal(item[].value, check)
check += 1

assert_equal(keys, "acdb")

# Empty dict is iterable

_ = dict.pop("b")
_ = dict.pop("d")
var empty_dict = Dict[String, Int]()

keys = String("")
for key in reversed(dict):
for key in reversed(empty_dict):
keys += key[]

assert_equal(keys, "")

# Refill dict
check = 0
for val in reversed(empty_dict.values()):
# values is empty, should not reach here
check += 1

dict["d"] = 4
dict["a"] = 1
dict["b"] = 2
dict["e"] = 3
assert_equal(check, 0)
JoeLoser marked this conversation as resolved.
Show resolved Hide resolved

keys = String("")
for key in reversed(dict):
keys += key[]
check = 0
for item in reversed(empty_dict.items()):
keys += item[].key
check += item[].value

assert_equal(keys, "ebad")
assert_equal(keys, "")
assert_equal(check, 0)


def main():
Expand Down