Skip to content

Commit

Permalink
Use __index__ for __getitem__ and __setitem__
Browse files Browse the repository at this point in the history
When indexing stdlib containers we should accept a generic type that calls
on the  __index__ method to allow types other than Int to be used but
doesn't allow Intable types that should not be used for such purposes
(such as Float)

Signed-off-by: Brian Grenier <grenierb96@gmail.com>
  • Loading branch information
bgreni committed Apr 30, 2024
1 parent ee52186 commit 8902489
Show file tree
Hide file tree
Showing 15 changed files with 113 additions and 49 deletions.
16 changes: 15 additions & 1 deletion stdlib/src/builtin/bool.mojo
Expand Up @@ -47,7 +47,12 @@ trait Boolable:
@value
@register_passable("trivial")
struct Bool(
Stringable, CollectionElement, Boolable, EqualityComparable, Intable
Stringable,
CollectionElement,
Boolable,
EqualityComparable,
Intable,
Indexer,
):
"""The primitive Bool scalar value used in Mojo."""

Expand Down Expand Up @@ -276,6 +281,15 @@ struct Bool(
)
)

@always_inline("nodebug")
fn __index__(self) -> Int:
"""Convert this Bool to an integer for indexing purposes
Returns:
Bool as Int
"""
return self.__int__()


@always_inline
fn bool(value: None) -> Bool:
Expand Down
10 changes: 6 additions & 4 deletions stdlib/src/builtin/builtin_list.mojo
Expand Up @@ -138,7 +138,7 @@ struct VariadicList[type: AnyRegType](Sized):
return __mlir_op.`pop.variadic.size`(self.value)

@always_inline
fn __getitem__(self, index: Int) -> type:
fn __getitem__[indexer: Indexer](self, index: indexer) -> type:
"""Gets a single element on the variadic list.
Args:
Expand All @@ -147,7 +147,7 @@ struct VariadicList[type: AnyRegType](Sized):
Returns:
The element on the list corresponding to the given index.
"""
return __mlir_op.`pop.variadic.get`(self.value, index.value)
return __mlir_op.`pop.variadic.get`(self.value, index.__index__().value)

@always_inline
fn __iter__(self) -> Self.IterType:
Expand Down Expand Up @@ -344,7 +344,9 @@ struct VariadicListMem[
# TODO: Fix for loops + _VariadicListIter to support a __nextref__ protocol
# allowing us to get rid of this and make foreach iteration clean.
@always_inline
fn __getitem__(self, index: Int) -> Self.reference_type:
fn __getitem__[
indexer: Indexer
](self, index: indexer) -> Self.reference_type:
"""Gets a single element on the variadic list.
Args:
Expand All @@ -355,7 +357,7 @@ struct VariadicListMem[
given index.
"""
return Self.reference_type(
__mlir_op.`pop.variadic.get`(self.value, index.value)
__mlir_op.`pop.variadic.get`(self.value, index.__index__().value)
)

@always_inline
Expand Down
4 changes: 2 additions & 2 deletions stdlib/src/builtin/builtin_slice.mojo
Expand Up @@ -155,7 +155,7 @@ struct Slice(Sized, Stringable, EqualityComparable):
return len(range(self.start, self.end, self.step))

@always_inline
fn __getitem__(self, idx: Int) -> Int:
fn __getitem__[indexer: Indexer](self, idx: indexer) -> Int:
"""Get the slice index.
Args:
Expand All @@ -164,7 +164,7 @@ struct Slice(Sized, Stringable, EqualityComparable):
Returns:
The slice index.
"""
return self.start + idx * self.step
return self.start + idx.__index__() * self.step

@always_inline("nodebug")
fn _has_end(self) -> Bool:
Expand Down
4 changes: 3 additions & 1 deletion stdlib/src/builtin/int.mojo
Expand Up @@ -192,7 +192,9 @@ fn int(value: String, base: Int = 10) raises -> Int:
@lldb_formatter_wrapping_type
@value
@register_passable("trivial")
struct Int(Absable, Intable, Stringable, KeyElement, Boolable, Formattable):
struct Int(
Absable, Intable, Stringable, KeyElement, Boolable, Formattable, Indexer
):
"""This type represents an integer value."""

var value: __mlir_type.index
Expand Down
4 changes: 3 additions & 1 deletion stdlib/src/builtin/int_literal.mojo
Expand Up @@ -16,7 +16,9 @@
@value
@nonmaterializable(Int)
@register_passable("trivial")
struct IntLiteral(Absable, Intable, Stringable, Boolable, EqualityComparable):
struct IntLiteral(
Absable, Intable, Stringable, Boolable, EqualityComparable, Indexer
):
"""This type represents a static integer literal value with
infinite precision. They can't be materialized at runtime and
must be lowered to other integer types (like Int), but allow for
Expand Down
12 changes: 6 additions & 6 deletions stdlib/src/builtin/range.mojo
Expand Up @@ -92,8 +92,8 @@ struct _ZeroStartingRange(Sized, ReversibleRange):
return self.curr

@always_inline("nodebug")
fn __getitem__(self, idx: Int) -> Int:
return idx
fn __getitem__[indexer: Indexer](self, idx: indexer) -> Int:
return idx.__index__()

@always_inline("nodebug")
fn __reversed__(self) -> _StridedRangeIterator:
Expand Down Expand Up @@ -123,8 +123,8 @@ struct _SequentialRange(Sized, ReversibleRange):
return self.end - self.start if self.start < self.end else 0

@always_inline("nodebug")
fn __getitem__(self, idx: Int) -> Int:
return self.start + idx
fn __getitem__[indexer: Indexer](self, idx: indexer) -> Int:
return self.start + idx.__index__()

@always_inline("nodebug")
fn __reversed__(self) -> _StridedRangeIterator:
Expand Down Expand Up @@ -195,8 +195,8 @@ struct _StridedRange(Sized, ReversibleRange):
return _div_ceil_positive(_abs(self.start - self.end), _abs(self.step))

@always_inline("nodebug")
fn __getitem__(self, idx: Int) -> Int:
return self.start + idx * self.step
fn __getitem__[indexer: Indexer](self, idx: indexer) -> Int:
return self.start + idx.__index__() * self.step

@always_inline("nodebug")
fn __reversed__(self) -> _StridedRangeIterator:
Expand Down
37 changes: 30 additions & 7 deletions stdlib/src/builtin/simd.mojo
Expand Up @@ -120,6 +120,7 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](
Stringable,
Hashable,
Boolable,
Indexer,
):
"""Represents a small vector that is backed by a hardware vector element.
Expand Down Expand Up @@ -488,6 +489,22 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](
rebind[Scalar[type]](self).value
)

@always_inline("nodebug")
fn __index__(self) -> Int:
"""Returns the value as an int if it is an integral value
Contraints:
Must be an integral value
Returns:
The value as an integer
"""
constrained[
type.is_integral() or type.is_bool(),
"expected integral or bool type",
]()
return self.__int__()

@always_inline
fn __str__(self) -> String:
"""Get the SIMD as a string.
Expand Down Expand Up @@ -1544,7 +1561,7 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](
# ===-------------------------------------------------------------------===#

@always_inline("nodebug")
fn __getitem__(self, idx: Int) -> Scalar[type]:
fn __getitem__[indexer: Indexer](self, idx: indexer) -> Scalar[type]:
"""Gets an element from the vector.
Args:
Expand All @@ -1555,23 +1572,29 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](
"""
return __mlir_op.`pop.simd.extractelement`[
_type = __mlir_type[`!pop.scalar<`, type.value, `>`]
](self.value, idx.value)
](self.value, idx.__index__().value)

@always_inline("nodebug")
fn __setitem__(inout self, idx: Int, val: Scalar[type]):
fn __setitem__[
indexer: Indexer
](inout self, idx: indexer, val: Scalar[type]):
"""Sets an element in the vector.
Args:
idx: The index to set.
val: The value to set.
"""
self.value = __mlir_op.`pop.simd.insertelement`(
self.value, val.value, idx.value
self.value, val.value, idx.__index__().value
)

@always_inline("nodebug")
fn __setitem__(
inout self, idx: Int, val: __mlir_type[`!pop.scalar<`, type.value, `>`]
fn __setitem__[
indexer: Indexer
](
inout self,
idx: indexer,
val: __mlir_type[`!pop.scalar<`, type.value, `>`],
):
"""Sets an element in the vector.
Expand All @@ -1580,7 +1603,7 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](
val: The value to set.
"""
self.value = __mlir_op.`pop.simd.insertelement`(
self.value, val, idx.value
self.value, val, idx.__index__().value
)

fn __hash__(self) -> Int:
Expand Down
10 changes: 10 additions & 0 deletions stdlib/src/builtin/value.mojo
Expand Up @@ -133,3 +133,13 @@ trait ComparableCollectionElement(CollectionElement, EqualityComparable):
"""

pass


trait Indexer:
fn __index__(self) -> Int:
"""Return the index value
Returns:
The index value of the object
"""
...
16 changes: 8 additions & 8 deletions stdlib/src/collections/vector.mojo
Expand Up @@ -182,7 +182,7 @@ struct InlinedFixedVector[
return self.current_size

@always_inline
fn __getitem__(self, i: Int) -> type:
fn __getitem__[indexer: Indexer](self, i: indexer) -> type:
"""Gets a vector element at the given index.
Args:
Expand All @@ -191,12 +191,12 @@ struct InlinedFixedVector[
Returns:
The element at the given index.
"""
var normalized_idx = i.__index__()
debug_assert(
-self.current_size <= i < self.current_size,
-self.current_size <= normalized_idx < self.current_size,
"index must be within bounds",
)
var normalized_idx = i
if i < 0:
if normalized_idx < 0:
normalized_idx += len(self)

if normalized_idx < Self.static_size:
Expand All @@ -205,20 +205,20 @@ struct InlinedFixedVector[
return self.dynamic_data[normalized_idx - Self.static_size]

@always_inline
fn __setitem__(inout self, i: Int, value: type):
fn __setitem__[indexer: Indexer](inout self, i: indexer, value: type):
"""Sets a vector element at the given index.
Args:
i: The index of the element.
value: The value to assign.
"""
var normalized_idx = i.__index__()
debug_assert(
-self.current_size <= i < self.current_size,
-self.current_size <= normalized_idx < self.current_size,
"index must be within bounds",
)

var normalized_idx = i
if i < 0:
if normalized_idx < 0:
normalized_idx += len(self)

if normalized_idx < Self.static_size:
Expand Down
10 changes: 5 additions & 5 deletions stdlib/src/memory/unsafe.mojo
Expand Up @@ -289,19 +289,19 @@ struct LegacyPointer[
)

@always_inline("nodebug")
fn __refitem__[T: Intable](self, offset: T) -> Self._mlir_ref_type:
fn __refitem__[T: Indexer](self, offset: T) -> Self._mlir_ref_type:
"""Enable subscript syntax `ref[idx]` to access the element.
Parameters:
T: The Intable type of the offset.
T: The Indexer type of the offset.
Args:
offset: The offset to load from.
Returns:
The MLIR reference for the Mojo compiler to use.
"""
return (self + offset).__refitem__()
return (self + offset.__index__()).__refitem__()

# ===------------------------------------------------------------------=== #
# Load/Store
Expand Down Expand Up @@ -714,7 +714,7 @@ struct DTypePointer[
return arg.get_legacy_pointer()

@always_inline("nodebug")
fn __getitem__[T: Intable](self, offset: T) -> Scalar[type]:
fn __getitem__[T: Indexer](self, offset: T) -> Scalar[type]:
"""Loads a single element (SIMD of size 1) from the pointer at the
specified index.
Expand All @@ -727,7 +727,7 @@ struct DTypePointer[
Returns:
The loaded value.
"""
return self.load(offset)
return self.load(offset.__index__())

@always_inline("nodebug")
fn __setitem__[T: Intable](self, offset: T, val: Scalar[type]):
Expand Down
8 changes: 7 additions & 1 deletion stdlib/src/python/object.mojo
Expand Up @@ -101,7 +101,13 @@ struct _PyIter(Sized):

@register_passable
struct PythonObject(
Intable, Stringable, SizedRaising, Boolable, CollectionElement, KeyElement
Intable,
Stringable,
SizedRaising,
Boolable,
CollectionElement,
KeyElement,
Indexer,
):
"""A Python object."""

Expand Down
8 changes: 4 additions & 4 deletions stdlib/src/utils/index.mojo
Expand Up @@ -335,11 +335,11 @@ struct StaticIntTuple[size: Int](Sized, Stringable, EqualityComparable):
return size

@always_inline("nodebug")
fn __getitem__[intable: Intable](self, index: intable) -> Int:
fn __getitem__[indexer: Indexer](self, index: indexer) -> Int:
"""Gets an element from the tuple by index.
Parameters:
intable: The intable type.
indexer: The index type.
Args:
index: The element index.
Expand All @@ -362,11 +362,11 @@ struct StaticIntTuple[size: Int](Sized, Stringable, EqualityComparable):
self.data.__setitem__[index](val)

@always_inline("nodebug")
fn __setitem__[intable: Intable](inout self, index: intable, val: Int):
fn __setitem__[indexer: Indexer](inout self, index: indexer, val: Int):
"""Sets an element in the tuple at the given index.
Parameters:
intable: The intable type.
indexer: The index type.
Args:
index: The element index.
Expand Down

0 comments on commit 8902489

Please sign in to comment.