Skip to content

Commit

Permalink
Use __index__ for __getitem__ and __setitem__
Browse files Browse the repository at this point in the history
When indexing stdlib containers we should accept a generic type that calls
on the  __index__ method to allow types other than Int to be used but
doesn't allow Intable types that should not be used for such purposes
(such as Float)

Signed-off-by: Brian Grenier <grenierb96@gmail.com>
  • Loading branch information
bgreni committed May 8, 2024
1 parent 4ef113b commit aa76d86
Show file tree
Hide file tree
Showing 25 changed files with 255 additions and 80 deletions.
16 changes: 15 additions & 1 deletion stdlib/src/builtin/bool.mojo
Expand Up @@ -47,7 +47,12 @@ trait Boolable:
@value
@register_passable("trivial")
struct Bool(
Stringable, CollectionElement, Boolable, EqualityComparable, Intable
Stringable,
CollectionElement,
Boolable,
EqualityComparable,
Intable,
Indexer,
):
"""The primitive Bool scalar value used in Mojo."""

Expand Down Expand Up @@ -276,6 +281,15 @@ struct Bool(
)
)

@always_inline("nodebug")
fn __index__(self) -> Int:
"""Convert this Bool to an integer for indexing purposes
Returns:
Bool as Int
"""
return self.__int__()


@always_inline
fn bool(value: None) -> Bool:
Expand Down
16 changes: 12 additions & 4 deletions stdlib/src/builtin/builtin_list.mojo
Expand Up @@ -138,16 +138,19 @@ struct VariadicList[type: AnyRegType](Sized):
return __mlir_op.`pop.variadic.size`(self.value)

@always_inline
fn __getitem__(self, index: Int) -> type:
fn __getitem__[indexer: Indexer](self, index: indexer) -> type:
"""Gets a single element on the variadic list.
Parameters:
indexer: The type of the indexing value.
Args:
index: The index of the element to access on the list.
Returns:
The element on the list corresponding to the given index.
"""
return __mlir_op.`pop.variadic.get`(self.value, index.value)
return __mlir_op.`pop.variadic.get`(self.value, index.__index__().value)

@always_inline
fn __iter__(self) -> Self.IterType:
Expand Down Expand Up @@ -344,9 +347,14 @@ struct VariadicListMem[
# TODO: Fix for loops + _VariadicListIter to support a __nextref__ protocol
# allowing us to get rid of this and make foreach iteration clean.
@always_inline
fn __getitem__(self, index: Int) -> Self.reference_type:
fn __getitem__[
indexer: Indexer
](self, index: indexer) -> Self.reference_type:
"""Gets a single element on the variadic list.
Parameters:
indexer: The type of the indexing value.
Args:
index: The index of the element to access on the list.
Expand All @@ -355,7 +363,7 @@ struct VariadicListMem[
given index.
"""
return Self.reference_type(
__mlir_op.`pop.variadic.get`(self.value, index.value)
__mlir_op.`pop.variadic.get`(self.value, index.__index__().value)
)

@always_inline
Expand Down
7 changes: 5 additions & 2 deletions stdlib/src/builtin/builtin_slice.mojo
Expand Up @@ -149,16 +149,19 @@ struct Slice(Sized, Stringable, EqualityComparable):
return len(range(self.start, self.end, self.step))

@always_inline
fn __getitem__(self, idx: Int) -> Int:
fn __getitem__[indexer: Indexer](self, idx: indexer) -> Int:
"""Get the slice index.
Parameters:
indexer: The type of the indexing value.
Args:
idx: The index.
Returns:
The slice index.
"""
return self.start + idx * self.step
return self.start + idx.__index__() * self.step

@always_inline("nodebug")
fn _has_end(self) -> Bool:
Expand Down
1 change: 1 addition & 0 deletions stdlib/src/builtin/int.mojo
Expand Up @@ -204,6 +204,7 @@ struct Int(
KeyElement,
Roundable,
Stringable,
Indexer,
):
"""This type represents an integer value."""

Expand Down
1 change: 1 addition & 0 deletions stdlib/src/builtin/int_literal.mojo
Expand Up @@ -28,6 +28,7 @@ struct IntLiteral(
Intable,
Roundable,
Stringable,
Indexer,
):
"""This type represents a static integer literal value with
infinite precision. They can't be materialized at runtime and
Expand Down
12 changes: 6 additions & 6 deletions stdlib/src/builtin/range.mojo
Expand Up @@ -82,8 +82,8 @@ struct _ZeroStartingRange(Sized, ReversibleRange):
return self.curr

@always_inline("nodebug")
fn __getitem__(self, idx: Int) -> Int:
return idx
fn __getitem__[indexer: Indexer](self, idx: indexer) -> Int:
return idx.__index__()

@always_inline("nodebug")
fn __reversed__(self) -> _StridedRangeIterator:
Expand Down Expand Up @@ -113,8 +113,8 @@ struct _SequentialRange(Sized, ReversibleRange):
return self.end - self.start if self.start < self.end else 0

@always_inline("nodebug")
fn __getitem__(self, idx: Int) -> Int:
return self.start + idx
fn __getitem__[indexer: Indexer](self, idx: indexer) -> Int:
return self.start + idx.__index__()

@always_inline("nodebug")
fn __reversed__(self) -> _StridedRangeIterator:
Expand Down Expand Up @@ -185,8 +185,8 @@ struct _StridedRange(Sized, ReversibleRange):
return _div_ceil_positive(abs(self.start - self.end), abs(self.step))

@always_inline("nodebug")
fn __getitem__(self, idx: Int) -> Int:
return self.start + idx * self.step
fn __getitem__[indexer: Indexer](self, idx: indexer) -> Int:
return self.start + idx.__index__() * self.step

@always_inline("nodebug")
fn __reversed__(self) -> _StridedRangeIterator:
Expand Down
46 changes: 39 additions & 7 deletions stdlib/src/builtin/simd.mojo
Expand Up @@ -133,6 +133,7 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](
Roundable,
Sized,
Stringable,
Indexer,
):
"""Represents a small vector that is backed by a hardware vector element.
Expand Down Expand Up @@ -501,6 +502,22 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](
rebind[Scalar[type]](self).value
)

@always_inline("nodebug")
fn __index__(self) -> Int:
"""Returns the value as an int if it is an integral value
Contraints:
Must be an integral value
Returns:
The value as an integer
"""
constrained[
type.is_integral() or type.is_bool(),
"expected integral or bool type",
]()
return self.__int__()

@always_inline
fn __str__(self) -> String:
"""Get the SIMD as a string.
Expand Down Expand Up @@ -1741,9 +1758,12 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](
# ===-------------------------------------------------------------------===#

@always_inline("nodebug")
fn __getitem__(self, idx: Int) -> Scalar[type]:
fn __getitem__[indexer: Indexer](self, idx: indexer) -> Scalar[type]:
"""Gets an element from the vector.
Parameters:
indexer: The type of the indexing value.
Args:
idx: The element index.
Expand All @@ -1752,32 +1772,44 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](
"""
return __mlir_op.`pop.simd.extractelement`[
_type = __mlir_type[`!pop.scalar<`, type.value, `>`]
](self.value, idx.value)
](self.value, idx.__index__().value)

@always_inline("nodebug")
fn __setitem__(inout self, idx: Int, val: Scalar[type]):
fn __setitem__[
indexer: Indexer
](inout self, idx: indexer, val: Scalar[type]):
"""Sets an element in the vector.
Parameters:
indexer: The type of the indexing value.
Args:
idx: The index to set.
val: The value to set.
"""
self.value = __mlir_op.`pop.simd.insertelement`(
self.value, val.value, idx.value
self.value, val.value, idx.__index__().value
)

@always_inline("nodebug")
fn __setitem__(
inout self, idx: Int, val: __mlir_type[`!pop.scalar<`, type.value, `>`]
fn __setitem__[
indexer: Indexer
](
inout self,
idx: indexer,
val: __mlir_type[`!pop.scalar<`, type.value, `>`],
):
"""Sets an element in the vector.
Parameters:
indexer: The type of the indexing value.
Args:
idx: The index to set.
val: The value to set.
"""
self.value = __mlir_op.`pop.simd.insertelement`(
self.value, val, idx.value
self.value, val, idx.__index__().value
)

fn __hash__(self) -> Int:
Expand Down
14 changes: 9 additions & 5 deletions stdlib/src/builtin/string.mojo
Expand Up @@ -683,21 +683,25 @@ struct String(
"""
return len(self) > 0

fn __getitem__(self, idx: Int) -> String:
fn __getitem__[indexer: Indexer](self, idx: indexer) -> String:
"""Gets the character at the specified position.
Parameters:
indexer: The type of the indexing value.
Args:
idx: The index value.
Returns:
A new string containing the character at the specified position.
"""
if idx < 0:
return self.__getitem__(len(self) + idx)
var index = idx.__index__()
if index < 0:
return self.__getitem__(len(self) + index)

debug_assert(0 <= idx < len(self), "index must be in range")
debug_assert(0 <= index < len(self), "index must be in range")
var buf = Self._buffer_type(capacity=1)
buf.append(self._buffer[idx])
buf.append(self._buffer[index])
buf.append(0)
return String(buf^)

Expand Down
18 changes: 18 additions & 0 deletions stdlib/src/builtin/value.mojo
Expand Up @@ -145,3 +145,21 @@ trait RepresentableCollectionElement(CollectionElement, Representable):
"""

pass


trait Indexer:
"""This trait denotes a type that can be used to index a container that
handles integral index values.
This solves the issue of being able to index data structures such as `List` with the various
integral types without being too broad and allowing types that should not be used such as float point
values.
"""

fn __index__(self) -> Int:
"""Return the index value
Returns:
The index value of the object
"""
...

0 comments on commit aa76d86

Please sign in to comment.