Skip to content

Commit

Permalink
Use __index__ for __getitem__ and __setitem__
Browse files Browse the repository at this point in the history
When indexing stdlib containers we should accept a generic type that calls
on the  __index__ method to allow types other than Int to be used but
doesn't allow Intable types that should not be used for such purposes
(such as Float)

Signed-off-by: Brian Grenier <grenierb96@gmail.com>
  • Loading branch information
bgreni committed May 9, 2024
1 parent 04e984a commit 06e9052
Show file tree
Hide file tree
Showing 25 changed files with 264 additions and 87 deletions.
21 changes: 17 additions & 4 deletions stdlib/src/builtin/bool.mojo
Expand Up @@ -57,7 +57,12 @@ trait Boolable:
@value
@register_passable("trivial")
struct Bool(
Stringable, CollectionElement, Boolable, EqualityComparable, Intable
Stringable,
CollectionElement,
Boolable,
EqualityComparable,
Intable,
Indexer,
):
"""The primitive Bool scalar value used in Mojo."""

Expand Down Expand Up @@ -299,10 +304,18 @@ struct Bool(
"""
return lhs ^ self

# ===----------------------------------------------------------------------=== #
# bool
# ===----------------------------------------------------------------------=== #

# ===----------------------------------------------------------------------=== #
# bool
# ===----------------------------------------------------------------------=== #
@always_inline("nodebug")
fn __index__(self) -> Int:
"""Convert this Bool to an integer for indexing purposes
Returns:
Bool as Int
"""
return self.__int__()


@always_inline
Expand Down
18 changes: 12 additions & 6 deletions stdlib/src/builtin/builtin_list.mojo
Expand Up @@ -138,16 +138,19 @@ struct VariadicList[type: AnyRegType](Sized):
return __mlir_op.`pop.variadic.size`(self.value)

@always_inline
fn __getitem__(self, index: Int) -> type:
fn __getitem__[indexer: Indexer](self, idx: indexer) -> type:
"""Gets a single element on the variadic list.
Parameters:
indexer: The type of the indexing value.
Args:
index: The index of the element to access on the list.
idx: The index of the element to access on the list.
Returns:
The element on the list corresponding to the given index.
"""
return __mlir_op.`pop.variadic.get`(self.value, index.value)
return __mlir_op.`pop.variadic.get`(self.value, index(idx).value)

@always_inline
fn __iter__(self) -> Self.IterType:
Expand Down Expand Up @@ -344,18 +347,21 @@ struct VariadicListMem[
# TODO: Fix for loops + _VariadicListIter to support a __nextref__ protocol
# allowing us to get rid of this and make foreach iteration clean.
@always_inline
fn __getitem__(self, index: Int) -> Self.reference_type:
fn __getitem__[indexer: Indexer](self, idx: indexer) -> Self.reference_type:
"""Gets a single element on the variadic list.
Parameters:
indexer: The type of the indexing value.
Args:
index: The index of the element to access on the list.
idx: The index of the element to access on the list.
Returns:
A low-level pointer to the element on the list corresponding to the
given index.
"""
return Self.reference_type(
__mlir_op.`pop.variadic.get`(self.value, index.value)
__mlir_op.`pop.variadic.get`(self.value, index(idx).value)
)

@always_inline
Expand Down
7 changes: 5 additions & 2 deletions stdlib/src/builtin/builtin_slice.mojo
Expand Up @@ -149,16 +149,19 @@ struct Slice(Sized, Stringable, EqualityComparable):
return len(range(self.start, self.end, self.step))

@always_inline
fn __getitem__(self, idx: Int) -> Int:
fn __getitem__[indexer: Indexer](self, idx: indexer) -> Int:
"""Get the slice index.
Parameters:
indexer: The type of the indexing value.
Args:
idx: The index.
Returns:
The slice index.
"""
return self.start + idx * self.step
return self.start + index(idx) * self.step

@always_inline("nodebug")
fn _has_end(self) -> Bool:
Expand Down
1 change: 1 addition & 0 deletions stdlib/src/builtin/int.mojo
Expand Up @@ -204,6 +204,7 @@ struct Int(
KeyElement,
Roundable,
Stringable,
Indexer,
):
"""This type represents an integer value."""

Expand Down
1 change: 1 addition & 0 deletions stdlib/src/builtin/int_literal.mojo
Expand Up @@ -28,6 +28,7 @@ struct IntLiteral(
Intable,
Roundable,
Stringable,
Indexer,
):
"""This type represents a static integer literal value with
infinite precision. They can't be materialized at runtime and
Expand Down
12 changes: 6 additions & 6 deletions stdlib/src/builtin/range.mojo
Expand Up @@ -82,8 +82,8 @@ struct _ZeroStartingRange(Sized, ReversibleRange):
return self.curr

@always_inline("nodebug")
fn __getitem__(self, idx: Int) -> Int:
return idx
fn __getitem__[indexer: Indexer](self, idx: indexer) -> Int:
return index(idx)

@always_inline("nodebug")
fn __reversed__(self) -> _StridedRangeIterator:
Expand Down Expand Up @@ -113,8 +113,8 @@ struct _SequentialRange(Sized, ReversibleRange):
return self.end - self.start if self.start < self.end else 0

@always_inline("nodebug")
fn __getitem__(self, idx: Int) -> Int:
return self.start + idx
fn __getitem__[indexer: Indexer](self, idx: indexer) -> Int:
return self.start + index(idx)

@always_inline("nodebug")
fn __reversed__(self) -> _StridedRangeIterator:
Expand Down Expand Up @@ -185,8 +185,8 @@ struct _StridedRange(Sized, ReversibleRange):
return _div_ceil_positive(abs(self.start - self.end), abs(self.step))

@always_inline("nodebug")
fn __getitem__(self, idx: Int) -> Int:
return self.start + idx * self.step
fn __getitem__[indexer: Indexer](self, idx: indexer) -> Int:
return self.start + index(idx) * self.step

@always_inline("nodebug")
fn __reversed__(self) -> _StridedRangeIterator:
Expand Down
46 changes: 39 additions & 7 deletions stdlib/src/builtin/simd.mojo
Expand Up @@ -132,6 +132,7 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](
Roundable,
Sized,
Stringable,
Indexer,
):
"""Represents a small vector that is backed by a hardware vector element.
Expand Down Expand Up @@ -502,6 +503,22 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](
rebind[Scalar[type]](self).value
)

@always_inline("nodebug")
fn __index__(self) -> Int:
"""Returns the value as an int if it is an integral value
Contraints:
Must be an integral value
Returns:
The value as an integer
"""
constrained[
type.is_integral() or type.is_bool(),
"expected integral or bool type",
]()
return self.__int__()

@always_inline
fn __str__(self) -> String:
"""Get the SIMD as a string.
Expand Down Expand Up @@ -1742,9 +1759,12 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](
# ===-------------------------------------------------------------------===#

@always_inline("nodebug")
fn __getitem__(self, idx: Int) -> Scalar[type]:
fn __getitem__[indexer: Indexer](self, idx: indexer) -> Scalar[type]:
"""Gets an element from the vector.
Parameters:
indexer: The type of the indexing value.
Args:
idx: The element index.
Expand All @@ -1753,32 +1773,44 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](
"""
return __mlir_op.`pop.simd.extractelement`[
_type = __mlir_type[`!pop.scalar<`, type.value, `>`]
](self.value, idx.value)
](self.value, index(idx).value)

@always_inline("nodebug")
fn __setitem__(inout self, idx: Int, val: Scalar[type]):
fn __setitem__[
indexer: Indexer
](inout self, idx: indexer, val: Scalar[type]):
"""Sets an element in the vector.
Parameters:
indexer: The type of the indexing value.
Args:
idx: The index to set.
val: The value to set.
"""
self.value = __mlir_op.`pop.simd.insertelement`(
self.value, val.value, idx.value
self.value, val.value, index(idx).value
)

@always_inline("nodebug")
fn __setitem__(
inout self, idx: Int, val: __mlir_type[`!pop.scalar<`, type.value, `>`]
fn __setitem__[
indexer: Indexer
](
inout self,
idx: indexer,
val: __mlir_type[`!pop.scalar<`, type.value, `>`],
):
"""Sets an element in the vector.
Parameters:
indexer: The type of the indexing value.
Args:
idx: The index to set.
val: The value to set.
"""
self.value = __mlir_op.`pop.simd.insertelement`(
self.value, val, idx.value
self.value, val, index(idx).value
)

fn __hash__(self) -> Int:
Expand Down
14 changes: 9 additions & 5 deletions stdlib/src/builtin/string.mojo
Expand Up @@ -687,21 +687,25 @@ struct String(
"""
return len(self) > 0

fn __getitem__(self, idx: Int) -> String:
fn __getitem__[indexer: Indexer](self, idx: indexer) -> String:
"""Gets the character at the specified position.
Parameters:
indexer: The type of the indexing value.
Args:
idx: The index value.
Returns:
A new string containing the character at the specified position.
"""
if idx < 0:
return self.__getitem__(len(self) + idx)
var index_val = index(idx)
if index_val < 0:
return self.__getitem__(len(self) + index_val)

debug_assert(0 <= idx < len(self), "index must be in range")
debug_assert(0 <= index_val < len(self), "index must be in range")
var buf = Self._buffer_type(capacity=1)
buf.append(self._buffer[idx])
buf.append(self._buffer[index_val])
buf.append(0)
return String(buf^)

Expand Down
23 changes: 23 additions & 0 deletions stdlib/src/builtin/value.mojo
Expand Up @@ -145,3 +145,26 @@ trait RepresentableCollectionElement(CollectionElement, Representable):
"""

pass


trait Indexer:
"""This trait denotes a type that can be used to index a container that
handles integral index values.
This solves the issue of being able to index data structures such as `List` with the various
integral types without being too broad and allowing types that should not be used such as float point
values.
"""

fn __index__(self) -> Int:
"""Return the index value
Returns:
The index value of the object
"""
...


@always_inline("nodebug")
fn index[indexer: Indexer](idx: indexer) -> Int:
return idx.__index__()

0 comments on commit 06e9052

Please sign in to comment.