Skip to content

Commit

Permalink
Use __index__ for __getitem__ and __setitem__
Browse files Browse the repository at this point in the history
When indexing stdlib containers we should accept a generic type that calls
on the  __index__ method to allow types other than Int to be used but
doesn't allow Intable types that should not be used for such purposes
(such as Float)

Signed-off-by: Brian Grenier <grenierb96@gmail.com>
  • Loading branch information
bgreni committed May 6, 2024
1 parent fd304d2 commit 5c0c215
Show file tree
Hide file tree
Showing 15 changed files with 147 additions and 50 deletions.
16 changes: 15 additions & 1 deletion stdlib/src/builtin/bool.mojo
Expand Up @@ -47,7 +47,12 @@ trait Boolable:
@value
@register_passable("trivial")
struct Bool(
Stringable, CollectionElement, Boolable, EqualityComparable, Intable
Stringable,
CollectionElement,
Boolable,
EqualityComparable,
Intable,
Indexer,
):
"""The primitive Bool scalar value used in Mojo."""

Expand Down Expand Up @@ -276,6 +281,15 @@ struct Bool(
)
)

@always_inline("nodebug")
fn __index__(self) -> Int:
"""Convert this Bool to an integer for indexing purposes
Returns:
Bool as Int
"""
return self.__int__()


@always_inline
fn bool(value: None) -> Bool:
Expand Down
16 changes: 12 additions & 4 deletions stdlib/src/builtin/builtin_list.mojo
Expand Up @@ -138,16 +138,19 @@ struct VariadicList[type: AnyRegType](Sized):
return __mlir_op.`pop.variadic.size`(self.value)

@always_inline
fn __getitem__(self, index: Int) -> type:
fn __getitem__[indexer: Indexer](self, index: indexer) -> type:
"""Gets a single element on the variadic list.
Parameters:
indexer: The type of the indexing value.
Args:
index: The index of the element to access on the list.
Returns:
The element on the list corresponding to the given index.
"""
return __mlir_op.`pop.variadic.get`(self.value, index.value)
return __mlir_op.`pop.variadic.get`(self.value, index.__index__().value)

@always_inline
fn __iter__(self) -> Self.IterType:
Expand Down Expand Up @@ -344,9 +347,14 @@ struct VariadicListMem[
# TODO: Fix for loops + _VariadicListIter to support a __nextref__ protocol
# allowing us to get rid of this and make foreach iteration clean.
@always_inline
fn __getitem__(self, index: Int) -> Self.reference_type:
fn __getitem__[
indexer: Indexer
](self, index: indexer) -> Self.reference_type:
"""Gets a single element on the variadic list.
Parameters:
indexer: The type of the indexing value.
Args:
index: The index of the element to access on the list.
Expand All @@ -355,7 +363,7 @@ struct VariadicListMem[
given index.
"""
return Self.reference_type(
__mlir_op.`pop.variadic.get`(self.value, index.value)
__mlir_op.`pop.variadic.get`(self.value, index.__index__().value)
)

@always_inline
Expand Down
7 changes: 5 additions & 2 deletions stdlib/src/builtin/builtin_slice.mojo
Expand Up @@ -149,16 +149,19 @@ struct Slice(Sized, Stringable, EqualityComparable):
return len(range(self.start, self.end, self.step))

@always_inline
fn __getitem__(self, idx: Int) -> Int:
fn __getitem__[indexer: Indexer](self, idx: indexer) -> Int:
"""Get the slice index.
Parameters:
indexer: The type of the indexing value.
Args:
idx: The index.
Returns:
The slice index.
"""
return self.start + idx * self.step
return self.start + idx.__index__() * self.step

@always_inline("nodebug")
fn _has_end(self) -> Bool:
Expand Down
1 change: 1 addition & 0 deletions stdlib/src/builtin/int.mojo
Expand Up @@ -203,6 +203,7 @@ struct Int(
KeyElement,
Roundable,
Stringable,
Indexer,
):
"""This type represents an integer value."""

Expand Down
1 change: 1 addition & 0 deletions stdlib/src/builtin/int_literal.mojo
Expand Up @@ -27,6 +27,7 @@ struct IntLiteral(
Intable,
Roundable,
Stringable,
Indexer,
):
"""This type represents a static integer literal value with
infinite precision. They can't be materialized at runtime and
Expand Down
12 changes: 6 additions & 6 deletions stdlib/src/builtin/range.mojo
Expand Up @@ -82,8 +82,8 @@ struct _ZeroStartingRange(Sized, ReversibleRange):
return self.curr

@always_inline("nodebug")
fn __getitem__(self, idx: Int) -> Int:
return idx
fn __getitem__[indexer: Indexer](self, idx: indexer) -> Int:
return idx.__index__()

@always_inline("nodebug")
fn __reversed__(self) -> _StridedRangeIterator:
Expand Down Expand Up @@ -113,8 +113,8 @@ struct _SequentialRange(Sized, ReversibleRange):
return self.end - self.start if self.start < self.end else 0

@always_inline("nodebug")
fn __getitem__(self, idx: Int) -> Int:
return self.start + idx
fn __getitem__[indexer: Indexer](self, idx: indexer) -> Int:
return self.start + idx.__index__()

@always_inline("nodebug")
fn __reversed__(self) -> _StridedRangeIterator:
Expand Down Expand Up @@ -185,8 +185,8 @@ struct _StridedRange(Sized, ReversibleRange):
return _div_ceil_positive(abs(self.start - self.end), abs(self.step))

@always_inline("nodebug")
fn __getitem__(self, idx: Int) -> Int:
return self.start + idx * self.step
fn __getitem__[indexer: Indexer](self, idx: indexer) -> Int:
return self.start + idx.__index__() * self.step

@always_inline("nodebug")
fn __reversed__(self) -> _StridedRangeIterator:
Expand Down
46 changes: 39 additions & 7 deletions stdlib/src/builtin/simd.mojo
Expand Up @@ -131,6 +131,7 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](
Roundable,
Sized,
Stringable,
Indexer,
):
"""Represents a small vector that is backed by a hardware vector element.
Expand Down Expand Up @@ -499,6 +500,22 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](
rebind[Scalar[type]](self).value
)

@always_inline("nodebug")
fn __index__(self) -> Int:
"""Returns the value as an int if it is an integral value
Contraints:
Must be an integral value
Returns:
The value as an integer
"""
constrained[
type.is_integral() or type.is_bool(),
"expected integral or bool type",
]()
return self.__int__()

@always_inline
fn __str__(self) -> String:
"""Get the SIMD as a string.
Expand Down Expand Up @@ -1707,9 +1724,12 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](
# ===-------------------------------------------------------------------===#

@always_inline("nodebug")
fn __getitem__(self, idx: Int) -> Scalar[type]:
fn __getitem__[indexer: Indexer](self, idx: indexer) -> Scalar[type]:
"""Gets an element from the vector.
Parameters:
indexer: The type of the indexing value.
Args:
idx: The element index.
Expand All @@ -1718,32 +1738,44 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](
"""
return __mlir_op.`pop.simd.extractelement`[
_type = __mlir_type[`!pop.scalar<`, type.value, `>`]
](self.value, idx.value)
](self.value, idx.__index__().value)

@always_inline("nodebug")
fn __setitem__(inout self, idx: Int, val: Scalar[type]):
fn __setitem__[
indexer: Indexer
](inout self, idx: indexer, val: Scalar[type]):
"""Sets an element in the vector.
Parameters:
indexer: The type of the indexing value.
Args:
idx: The index to set.
val: The value to set.
"""
self.value = __mlir_op.`pop.simd.insertelement`(
self.value, val.value, idx.value
self.value, val.value, idx.__index__().value
)

@always_inline("nodebug")
fn __setitem__(
inout self, idx: Int, val: __mlir_type[`!pop.scalar<`, type.value, `>`]
fn __setitem__[
indexer: Indexer
](
inout self,
idx: indexer,
val: __mlir_type[`!pop.scalar<`, type.value, `>`],
):
"""Sets an element in the vector.
Parameters:
indexer: The type of the indexing value.
Args:
idx: The index to set.
val: The value to set.
"""
self.value = __mlir_op.`pop.simd.insertelement`(
self.value, val, idx.value
self.value, val, idx.__index__().value
)

fn __hash__(self) -> Int:
Expand Down
18 changes: 18 additions & 0 deletions stdlib/src/builtin/value.mojo
Expand Up @@ -145,3 +145,21 @@ trait RepresentableCollectionElement(CollectionElement, Representable):
"""

pass


trait Indexer:
"""This trait denotes a type that can be used to index a container that
handles integral index values.
This solves the issue of being able to index data structures such as `List` with the various
integral types without being too broad and allowing types that should not be used such as float point
values.
"""

fn __index__(self) -> Int:
"""Return the index value
Returns:
The index value of the object
"""
...
22 changes: 14 additions & 8 deletions stdlib/src/collections/vector.mojo
Expand Up @@ -182,21 +182,24 @@ struct InlinedFixedVector[
return self.current_size

@always_inline
fn __getitem__(self, i: Int) -> type:
fn __getitem__[indexer: Indexer](self, i: indexer) -> type:
"""Gets a vector element at the given index.
Parameters:
indexer: The type of the indexing value.
Args:
i: The index of the element.
Returns:
The element at the given index.
"""
var normalized_idx = i.__index__()
debug_assert(
-self.current_size <= i < self.current_size,
-self.current_size <= normalized_idx < self.current_size,
"index must be within bounds",
)
var normalized_idx = i
if i < 0:
if normalized_idx < 0:
normalized_idx += len(self)

if normalized_idx < Self.static_size:
Expand All @@ -205,20 +208,23 @@ struct InlinedFixedVector[
return self.dynamic_data[normalized_idx - Self.static_size]

@always_inline
fn __setitem__(inout self, i: Int, value: type):
fn __setitem__[indexer: Indexer](inout self, i: indexer, value: type):
"""Sets a vector element at the given index.
Parameters:
indexer: The type of the indexing value.
Args:
i: The index of the element.
value: The value to assign.
"""
var normalized_idx = i.__index__()
debug_assert(
-self.current_size <= i < self.current_size,
-self.current_size <= normalized_idx < self.current_size,
"index must be within bounds",
)

var normalized_idx = i
if i < 0:
if normalized_idx < 0:
normalized_idx += len(self)

if normalized_idx < Self.static_size:
Expand Down
16 changes: 8 additions & 8 deletions stdlib/src/memory/unsafe.mojo
Expand Up @@ -289,19 +289,19 @@ struct LegacyPointer[
)

@always_inline("nodebug")
fn __refitem__[T: Intable](self, offset: T) -> Self._mlir_ref_type:
fn __refitem__[T: Indexer](self, offset: T) -> Self._mlir_ref_type:
"""Enable subscript syntax `ref[idx]` to access the element.
Parameters:
T: The Intable type of the offset.
T: The Indexer type of the offset.
Args:
offset: The offset to load from.
Returns:
The MLIR reference for the Mojo compiler to use.
"""
return (self + offset).__refitem__()
return (self + offset.__index__()).__refitem__()

# ===------------------------------------------------------------------=== #
# Load/Store
Expand Down Expand Up @@ -714,7 +714,7 @@ struct DTypePointer[
return arg.get_legacy_pointer()

@always_inline("nodebug")
fn __getitem__[T: Intable](self, offset: T) -> Scalar[type]:
fn __getitem__[T: Indexer](self, offset: T) -> Scalar[type]:
"""Loads a single element (SIMD of size 1) from the pointer at the
specified index.
Expand All @@ -727,20 +727,20 @@ struct DTypePointer[
Returns:
The loaded value.
"""
return self.load(offset)
return self.load(offset.__index__())

@always_inline("nodebug")
fn __setitem__[T: Intable](self, offset: T, val: Scalar[type]):
fn __setitem__[T: Indexer](self, offset: T, val: Scalar[type]):
"""Stores a single element value at the given offset.
Parameters:
T: The Intable type of the offset.
T: The type of the indexing value.
Args:
offset: The offset to store to.
val: The value to store.
"""
return self.store(offset, val)
return self.store(offset.__index__(), val)

# ===------------------------------------------------------------------=== #
# Comparisons
Expand Down
8 changes: 7 additions & 1 deletion stdlib/src/python/object.mojo
Expand Up @@ -101,7 +101,13 @@ struct _PyIter(Sized):

@register_passable
struct PythonObject(
Intable, Stringable, SizedRaising, Boolable, CollectionElement, KeyElement
Intable,
Stringable,
SizedRaising,
Boolable,
CollectionElement,
KeyElement,
Indexer,
):
"""A Python object."""

Expand Down

0 comments on commit 5c0c215

Please sign in to comment.