Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stdlib] Support reversed for Dict #2340

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
77 changes: 75 additions & 2 deletions stdlib/src/builtin/reversed.mojo
Expand Up @@ -19,7 +19,7 @@ from .range import _StridedRangeIterator

from collections.list import _ListIter

from collections.dict import _DictKeyIter
from collections.dict import _DictKeyIter, _DictValueIter, _DictEntryIter

# ===----------------------------------------------------------------------=== #
# Reversible
Expand Down Expand Up @@ -125,6 +125,79 @@ fn reversed[
value: The dict to get the reversed iterator of.

Returns:
The reversed iterator of the dict.
The reversed iterator of the dict keys.
"""
return Reference(value)[].__reversed__[mutability, self_life]()


fn reversed[
mutability: __mlir_type.`i1`,
self_life: AnyLifetime[mutability].type,
K: KeyElement,
V: CollectionElement,
dict_mutability: __mlir_type.`i1`,
dict_lifetime: AnyLifetime[dict_mutability].type,
](
value: Reference[
_DictValueIter[K, V, dict_mutability, dict_lifetime],
mutability,
self_life,
]._mlir_type,
) -> _DictValueIter[K, V, dict_mutability, dict_lifetime, False]:
"""Get a reversed iterator of the input dict values.

**Note**: iterators are currently non-raising.

Parameters:
mutability: Whether the reference to the dict is mutable.
self_life: The lifetime of the dict.
K: The type of the keys in the dict.
V: The type of the values in the dict.
dict_mutability: Whether the reference to the dict values is mutable.
dict_lifetime: The lifetime of the dict values.

Args:
value: The dict values to get the reversed iterator of.

Returns:
The reversed iterator of the dict values.
"""
return Reference(value)[].__reversed__[mutability, self_life]()


fn reversed[
mutability: __mlir_type.`i1`,
self_life: AnyLifetime[mutability].type,
K: KeyElement,
V: CollectionElement,
dict_mutability: __mlir_type.`i1`,
dict_lifetime: AnyLifetime[dict_mutability].type,
](
value: Reference[
_DictEntryIter[K, V, dict_mutability, dict_lifetime],
mutability,
self_life,
]._mlir_type,
) -> _DictEntryIter[K, V, dict_mutability, dict_lifetime, False]:
"""Get a reversed iterator of the input dict items.

**Note**: iterators are currently non-raising.

Parameters:
mutability: Whether the reference to the dict is mutable.
self_life: The lifetime of the dict.
K: The type of the keys in the dict.
V: The type of the values in the dict.
dict_mutability: Whether the reference to the dict items is mutable.
dict_lifetime: The lifetime of the dict items.

Args:
value: The dict items to get the reversed iterator of.

Returns:
The reversed iterator of the dict items.
"""
var src = Reference(value)[].src
return _DictEntryIter[K, V, dict_mutability, dict_lifetime, False](
src[]._reserved, 0, src
)
17 changes: 15 additions & 2 deletions stdlib/src/collections/dict.mojo
Expand Up @@ -168,6 +168,7 @@ struct _DictValueIter[
V: CollectionElement,
dict_mutability: __mlir_type.`i1`,
dict_lifetime: AnyLifetime[dict_mutability].type,
forward: Bool = True,
]:
"""Iterator over Dict value references. These are mutable if the dict
is mutable.
Expand All @@ -177,15 +178,26 @@ struct _DictValueIter[
V: The value type of the elements in the dictionary.
dict_mutability: Whether the reference to the vector is mutable.
dict_lifetime: The lifetime of the List
forward: The iteration direction. `False` is backwards.
"""

alias ref_type = Reference[V, dict_mutability, dict_lifetime]

var iter: _DictEntryIter[K, V, dict_mutability, dict_lifetime]
var iter: _DictEntryIter[K, V, dict_mutability, dict_lifetime, forward]

fn __iter__(self) -> Self:
return self

fn __reversed__[
JoeLoser marked this conversation as resolved.
Show resolved Hide resolved
mutability: __mlir_type.`i1`, self_life: AnyLifetime[mutability].type
](self) -> _DictValueIter[K, V, dict_mutability, dict_lifetime, False]:
var src = self.iter.src
return _DictValueIter(
_DictEntryIter[K, V, dict_mutability, dict_lifetime, False](
src[]._reserved, 0, src
)
)

fn __next__(inout self) -> Self.ref_type:
var entry_ref = self.iter.__next__()
# Cast through a pointer to grant additional mutability because
Expand Down Expand Up @@ -683,7 +695,8 @@ struct Dict[K: KeyElement, V: CollectionElement](
return Self.__iter__(self)

fn values[
mutability: __mlir_type.`i1`, self_life: AnyLifetime[mutability].type
mutability: __mlir_type.`i1`,
self_life: AnyLifetime[mutability].type,
](
self: Reference[Self, mutability, self_life]._mlir_type,
) -> _DictValueIter[K, V, mutability, self_life]:
Expand Down
74 changes: 62 additions & 12 deletions stdlib/test/builtin/test_reversed.mojo
Expand Up @@ -30,19 +30,35 @@ def test_reversed_dict():
dict["b"] = 2
dict["c"] = 3
dict["d"] = 4
dict["a"] = 5
dict["a"] = 1

var keys = String("")
for key in reversed(dict):
keys += key[]

assert_equal(keys, "dcba")

var check: Int = 4
for val in reversed(dict.values()):
assert_equal(val[], check)
check -= 1

keys = String("")
check = 4
for item in reversed(dict.items()):
keys += item[].key
assert_equal(item[].value, check)
check -= 1

assert_equal(keys, "dcba")

# Order preserved

_ = dict.pop("a")
_ = dict.pop("c")

# dict: {'b': 2, 'd': 4}

keys = String("")
for key in dict:
keys += key[]
Expand All @@ -55,29 +71,63 @@ def test_reversed_dict():

assert_equal(keys, "db")

# got 4 and 2
check = 4
for val in reversed(dict.values()):
assert_equal(val[], check)
check -= 2

keys = String("")
check = 4
for item in reversed(dict.items()):
keys += item[].key
assert_equal(item[].value, check)
check -= 2

assert_equal(keys, "db")

# Refill dict
dict["c"] = 2
dict["a"] = 1
dict["b"] = 4
dict["d"] = 3

# dict: {'b': 4, 'd': 3, 'c': 2, 'a':1}
JoeLoser marked this conversation as resolved.
Show resolved Hide resolved

keys = String("")
check = 1
JoeLoser marked this conversation as resolved.
Show resolved Hide resolved
for item in reversed(dict.items()):
keys += item[].key
assert_equal(item[].value, check)
check += 1

assert_equal(keys, "acdb")

# Empty dict is iterable

_ = dict.pop("b")
_ = dict.pop("d")
var empty_dict = Dict[String, Int]()

keys = String("")
for key in reversed(dict):
for key in reversed(empty_dict):
keys += key[]

assert_equal(keys, "")

# Refill dict
check = 0
for val in reversed(empty_dict.values()):
# values is empty, should not reach here
check += 1

dict["d"] = 4
dict["a"] = 1
dict["b"] = 2
dict["e"] = 3
assert_equal(check, 0)
JoeLoser marked this conversation as resolved.
Show resolved Hide resolved

keys = String("")
for key in reversed(dict):
keys += key[]
check = 0
for item in reversed(empty_dict.items()):
keys += item[].key
check += item[].value

assert_equal(keys, "ebad")
assert_equal(keys, "")
assert_equal(check, 0)


def main():
Expand Down