Skip to content

Commit

Permalink
Use __index__ for __getitem__ and __setitem__
Browse files Browse the repository at this point in the history
When indexing stdlib containers we should accept a generic type that calls
on the  __index__ method to allow types other than Int to be used but
doesn't allow Intable types that should not be used for such purposes
(such as Float)

Signed-off-by: Brian Grenier <grenierb96@gmail.com>
  • Loading branch information
bgreni committed Apr 23, 2024
1 parent b176fe7 commit c930369
Show file tree
Hide file tree
Showing 7 changed files with 33 additions and 17 deletions.
2 changes: 1 addition & 1 deletion stdlib/src/builtin/int.mojo
Expand Up @@ -169,7 +169,7 @@ fn int[T: IntableRaising](value: T) raises -> Int:
@lldb_formatter_wrapping_type
@value
@register_passable("trivial")
struct Int(Intable, Stringable, KeyElement, Boolable):
struct Int(Intable, Stringable, KeyElement, Boolable, Indexer):
"""This type represents an integer value."""

var value: __mlir_type.index
Expand Down
2 changes: 1 addition & 1 deletion stdlib/src/builtin/int_literal.mojo
Expand Up @@ -16,7 +16,7 @@
@value
@nonmaterializable(Int)
@register_passable("trivial")
struct IntLiteral(Intable, Stringable, Boolable, EqualityComparable):
struct IntLiteral(Intable, Stringable, Boolable, EqualityComparable, Indexer):
"""This type represents a static integer literal value with
infinite precision. They can't be materialized at runtime and
must be lowered to other integer types (like Int), but allow for
Expand Down
12 changes: 6 additions & 6 deletions stdlib/src/builtin/range.mojo
Expand Up @@ -92,8 +92,8 @@ struct _ZeroStartingRange(Sized, ReversibleRange):
return self.curr

@always_inline("nodebug")
fn __getitem__(self, idx: Int) -> Int:
return idx
fn __getitem__[indexer: Indexer](self, idx: indexer) -> Int:
return idx.__index__()

@always_inline("nodebug")
fn __reversed__(self) -> _StridedRangeIterator:
Expand Down Expand Up @@ -121,8 +121,8 @@ struct _SequentialRange(Sized, ReversibleRange):
return _max(0, self.end - self.start)

@always_inline("nodebug")
fn __getitem__(self, idx: Int) -> Int:
return self.start + idx
fn __getitem__[indexer: Indexer](self, idx: indexer) -> Int:
return self.start + idx.__index__()

@always_inline("nodebug")
fn __reversed__(self) -> _StridedRangeIterator:
Expand Down Expand Up @@ -192,8 +192,8 @@ struct _StridedRange(Sized, ReversibleRange):
return _div_ceil_positive(_abs(self.start - self.end), _abs(self.step))

@always_inline("nodebug")
fn __getitem__(self, idx: Int) -> Int:
return self.start + idx * self.step
fn __getitem__[indexer: Indexer](self, idx: indexer) -> Int:
return self.start + idx.__index__() * self.step

@always_inline("nodebug")
fn __reversed__(self) -> _StridedRangeIterator:
Expand Down
6 changes: 6 additions & 0 deletions stdlib/src/builtin/simd.mojo
Expand Up @@ -119,6 +119,7 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](
Stringable,
Hashable,
Boolable,
Indexer
):
"""Represents a small vector that is backed by a hardware vector element.
Expand Down Expand Up @@ -513,6 +514,11 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](
rebind[Scalar[type]](self).value
)

@always_inline("nodebug")
fn __index__(self) -> Int:
constrained[type.is_integral(), "Expected an integral type"]()
return int(self)

@always_inline
fn __str__(self) -> String:
"""Get the SIMD as a string.
Expand Down
9 changes: 9 additions & 0 deletions stdlib/src/builtin/value.mojo
Expand Up @@ -121,3 +121,12 @@ trait StringableCollectionElement(CollectionElement, Stringable):
"""

pass

trait Indexer:
fn __index__(self) -> Int:
"""Return the index value
Returns:
The index value of the object
"""
...
17 changes: 9 additions & 8 deletions stdlib/src/collections/vector.mojo
Expand Up @@ -182,7 +182,7 @@ struct InlinedFixedVector[
return self.current_size

@always_inline
fn __getitem__(self, i: Int) -> type:
fn __getitem__[indexer: Indexer](self, i: indexer) -> type:
"""Gets a vector element at the given index.
Args:
Expand All @@ -191,12 +191,12 @@ struct InlinedFixedVector[
Returns:
The element at the given index.
"""
var normalized_idx = i.__index__()
debug_assert(
-self.current_size <= i < self.current_size,
-self.current_size <= normalized_idx < self.current_size,
"index must be within bounds",
)
var normalized_idx = i
if i < 0:
if normalized_idx < 0:
normalized_idx += len(self)

if normalized_idx < Self.static_size:
Expand All @@ -205,20 +205,21 @@ struct InlinedFixedVector[
return self.dynamic_data[normalized_idx - Self.static_size]

@always_inline
fn __setitem__(inout self, i: Int, value: type):
fn __setitem__[indexer: Indexer](inout self, i: indexer, value: type):
"""Sets a vector element at the given index.
Args:
i: The index of the element.
value: The value to assign.
"""
var normalized_idx = i.__index__()
debug_assert(
-self.current_size <= i < self.current_size,
-self.current_size <= normalized_idx < self.current_size,
"index must be within bounds",
)

var normalized_idx = i
if i < 0:

if normalized_idx < 0:
normalized_idx += len(self)

if normalized_idx < Self.static_size:
Expand Down
2 changes: 1 addition & 1 deletion stdlib/src/python/object.mojo
Expand Up @@ -101,7 +101,7 @@ struct _PyIter(Sized):

@register_passable
struct PythonObject(
Intable, Stringable, SizedRaising, Boolable, CollectionElement, KeyElement
Intable, Stringable, SizedRaising, Boolable, CollectionElement, KeyElement, Indexer
):
"""A Python object."""

Expand Down

0 comments on commit c930369

Please sign in to comment.