Skip to content

Commit

Permalink
[stdlib] Optimize normalize_index for unsigned types
Browse files Browse the repository at this point in the history
Use the Indexer trait in normalize_index to optimize for UInt, UInt8, UInt16, UInt32, and UInt64 types.

Signed-off-by: Yinon Burgansky <[email protected]>
  • Loading branch information
yinonburgansky committed Feb 1, 2025
1 parent b367ba8 commit 230ff56
Show file tree
Hide file tree
Showing 5 changed files with 170 additions and 65 deletions.
106 changes: 81 additions & 25 deletions stdlib/src/collections/_index_normalization.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -13,44 +13,100 @@
"""The utilities provided in this module help normalize the access
to data elements in arrays."""

from sys.intrinsics import _type_is_eq
from sys import sizeof


@always_inline
fn normalize_index[
I: Indexer, //, container_name: StringLiteral
](idx: I, length: UInt) -> UInt:
"""Normalize the given index value to a valid index value for the given container length.
If the provided value is negative, the `index + container_length` is returned.
Parameters:
I: A type that can be used as an index.
container_name: The name of the container. Used for the error message.
Args:
idx: The index value to normalize.
length: The container length to normalize the index for.
Returns:
The normalized index value.
"""

@parameter
if (
_type_is_eq[I, UInt]()
or _type_is_eq[I, UInt8]()
or _type_is_eq[I, UInt16]()
or _type_is_eq[I, UInt32]()
or _type_is_eq[I, UInt64]()
):
var i = UInt(index(idx))
# TODO: Consider a way to construct the error message after the assert has failed
# something like "Indexing into an empty container" if length == 0 else "..."
debug_assert[assert_mode="safe", cpu_only=True](
i < length,
container_name,
" index out of bounds: index (",
i,
") valid range: -", # can't print -UInt.MAX
length,
" <= index < ",
length,
)
return i
else:
var mlir_index = index(idx)
var i = UInt(mlir_index)
if Int(mlir_index) < 0:
i += length
# Checking the bounds after the normalization saves a comparison
# while allowing negative indexing into containers with length > Int.MAX.
# For a positive index this is trivially correct.
# For a negative index we can infer the full bounds check from
# the assert UInt(idx + length) < length, by considering 2 cases:
# when length > Int.MAX then:
# idx + length > idx + Int.MAX >= Int.MIN + Int.MAX = -1
# therefore idx + length >= 0
# when length <= Int.MAX then:
# UInt(idx + length) < length <= Int.MAX
# Which means UInt(idx + length) signed bit is off
# therefore idx + length >= 0
# in either case we can infer 0 <= idx + length < length
debug_assert[assert_mode="safe", cpu_only=True](
i < length,
container_name,
" index out of bounds: index (",
Int(mlir_index),
") valid range: -", # can't print -UInt.MAX
length,
" <= index < ",
length,
)
return i


@always_inline
fn normalize_index[
ContainerType: Sized, //, container_name: StringLiteral
](idx: Int, container: ContainerType) -> Int:
I: Indexer, //, container_name: StringLiteral
](idx: I, length: Int) -> Int:
"""Normalize the given index value to a valid index value for the given container length.
If the provided value is negative, the `index + container_length` is returned.
Parameters:
ContainerType: The type of the container. Must have a `__len__` method.
I: A type that can be used as an index.
container_name: The name of the container. Used for the error message.
Args:
idx: The index value to normalize.
container: The container to normalize the index for.
length: The container length to normalize the index for.
Returns:
The normalized index value.
"""
debug_assert[assert_mode="safe", cpu_only=True](
len(container) > 0,
"indexing into a ",
container_name,
" that has 0 elements",
)
debug_assert[assert_mode="safe", cpu_only=True](
-len(container) <= idx < len(container),
container_name,
" has length: ",
len(container),
" index out of bounds: ",
idx,
" should be between ",
-len(container),
" and ",
len(container) - 1,
)
if idx >= 0:
return idx
return idx + len(container)
return Int(normalize_index[container_name](idx, UInt(length)))
25 changes: 4 additions & 21 deletions stdlib/src/collections/inline_array.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -233,15 +233,8 @@ struct InlineArray[
Returns:
A reference to the item at the given index.
"""

@parameter
if _type_is_eq[I, UInt]():
return self.unsafe_get(idx)
else:
var normalized_index = normalize_index["InlineArray"](
Int(idx), self
)
return self.unsafe_get(normalized_index)
var normalized_index = normalize_index["InlineArray"](idx, len(self))
return self.unsafe_get(normalized_index)

@always_inline
fn __getitem__[
Expand All @@ -257,18 +250,8 @@ struct InlineArray[
A reference to the item at the given index.
"""
constrained[-size <= Int(idx) < size, "Index must be within bounds."]()

@parameter
if _type_is_eq[I, UInt]():
return self.unsafe_get(idx)
else:
var normalized_idx = Int(idx)

@parameter
if Int(idx) < 0:
normalized_idx += size

return self.unsafe_get(normalized_idx)
alias normalized_index = normalize_index["InlineArray"](idx, size)
return self.unsafe_get(normalized_index)

# ===------------------------------------------------------------------=== #
# Trait implementations
Expand Down
2 changes: 1 addition & 1 deletion stdlib/src/collections/linked_list.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -625,7 +625,7 @@ struct LinkedList[
A pointer to the node at the specified index.
"""
var l = len(self)
var i = normalize_index[container_name="LinkedList"](index, self)
var i = normalize_index["LinkedList"](index, l)
debug_assert(0 <= i < l, "index out of bounds")
var mid = l // 2
if i <= mid:
Expand Down
2 changes: 1 addition & 1 deletion stdlib/src/collections/string/string.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -866,7 +866,7 @@ struct String(
A new string containing the character at the specified position.
"""
# TODO(#933): implement this for unicode when we support llvm intrinsic evaluation at compile time
var normalized_idx = normalize_index["String"](index(idx), self)
var normalized_idx = normalize_index["String"](idx, len(self))
var buf = Self._buffer_type(capacity=1)
buf.append(self._buffer[normalized_idx])
buf.append(0)
Expand Down
100 changes: 83 additions & 17 deletions stdlib/test/collections/test_index_normalization.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,93 @@ from testing import assert_equal


def test_out_of_bounds_message():
l = List[Int](1, 2)
# CHECK: index out of bounds: 2
_ = normalize_index["List"](2, l)
# CHECK: index out of bounds: -3
_ = normalize_index["List"](-3, l)
# CHECK: index out of bounds
_ = normalize_index[""](2, 2)
# CHECK: index out of bounds
_ = normalize_index[""](UInt(2), 2)
# CHECK: index out of bounds
_ = normalize_index[""](2, UInt(2))
# CHECK: index out of bounds
_ = normalize_index[""](UInt(2), UInt(2))
# CHECK: index out of bounds
_ = normalize_index[""](UInt8(2), 2)

l2 = List[Int]()
# CHECK: indexing into a List that has 0 elements
_ = normalize_index["List"](2, l2)
# CHECK: index out of bounds
_ = normalize_index[""](-3, 2)
# CHECK: index out of bounds
_ = normalize_index[""](-3, UInt(2))
# CHECK: index out of bounds
_ = normalize_index[""](Int8(-3), 2)

# CHECK: index out of bounds
_ = normalize_index[""](2, 0)
# CHECK: index out of bounds
_ = normalize_index[""](UInt(2), 0)
# CHECK: index out of bounds
_ = normalize_index[""](2, UInt(0))
# CHECK: index out of bounds
_ = normalize_index[""](UInt(2), UInt(0))

# CHECK: index out of bounds
_ = normalize_index[""](Int.MIN, 10)
# CHECK: index out of bounds
_ = normalize_index[""](Int.MIN, UInt(10))
# CHECK: index out of bounds
_ = normalize_index[""](Int.MAX, 10)
# CHECK: index out of bounds
_ = normalize_index[""](Int.MAX, UInt(10))
# CHECK: index out of bounds
_ = normalize_index[""](Int.MIN, Int.MAX)

# CHECK: index out of bounds
_ = normalize_index[""](UInt.MAX, 10)
# CHECK: index out of bounds
_ = normalize_index[""](UInt.MAX, UInt(10))
# CHECK: index out of bounds
_ = normalize_index[""](UInt.MAX, UInt.MAX)
# CHECK: index out of bounds
_ = normalize_index[""](UInt.MAX, UInt.MAX - 10)


def test_normalize_index():
container = List[Int](1, 1, 1, 1)
assert_equal(normalize_index[""](-4, container), 0)
assert_equal(normalize_index[""](-3, container), 1)
assert_equal(normalize_index[""](-2, container), 2)
assert_equal(normalize_index[""](-1, container), 3)
assert_equal(normalize_index[""](0, container), 0)
assert_equal(normalize_index[""](1, container), 1)
assert_equal(normalize_index[""](2, container), 2)
assert_equal(normalize_index[""](3, container), 3)
assert_equal(normalize_index[""](-3, 3), 0)
assert_equal(normalize_index[""](-2, 3), 1)
assert_equal(normalize_index[""](-1, 3), 2)
assert_equal(normalize_index[""](0, 3), 0)
assert_equal(normalize_index[""](1, 3), 1)
assert_equal(normalize_index[""](2, 3), 2)

assert_equal(normalize_index[""](-3, UInt(3)), 0)
assert_equal(normalize_index[""](-2, UInt(3)), 1)
assert_equal(normalize_index[""](-1, UInt(3)), 2)
assert_equal(normalize_index[""](0, UInt(3)), 0)
assert_equal(normalize_index[""](1, UInt(3)), 1)
assert_equal(normalize_index[""](2, UInt(3)), 2)

assert_equal(normalize_index[""](UInt(0), UInt(3)), 0)
assert_equal(normalize_index[""](UInt(1), UInt(3)), 1)
assert_equal(normalize_index[""](UInt(2), UInt(3)), 2)

assert_equal(normalize_index[""](Int8(-3), 3), 0)
assert_equal(normalize_index[""](Int8(-2), 3), 1)
assert_equal(normalize_index[""](Int8(-1), 3), 2)
assert_equal(normalize_index[""](Int8(0), 3), 0)
assert_equal(normalize_index[""](Int8(1), 3), 1)
assert_equal(normalize_index[""](Int8(2), 3), 2)

assert_equal(normalize_index[""](UInt8(0), 3), 0)
assert_equal(normalize_index[""](UInt8(1), 3), 1)
assert_equal(normalize_index[""](UInt8(2), 3), 2)

assert_equal(normalize_index[""](UInt(1), UInt.MAX), 1)
assert_equal(normalize_index[""](UInt.MAX - 5, UInt.MAX), UInt.MAX - 5)

assert_equal(normalize_index[""](-1, Int.MAX), Int.MAX - 1)
assert_equal(normalize_index[""](-10, Int.MAX), Int.MAX - 10)
assert_equal(normalize_index[""](-1, UInt.MAX), UInt.MAX - 1)
assert_equal(normalize_index[""](-10, UInt.MAX), UInt.MAX - 10)
assert_equal(normalize_index[""](-1, UInt(Int.MAX) + 1), UInt(Int.MAX))
assert_equal(normalize_index[""](Int.MIN, UInt(Int.MAX) + 1), 0)


def main():
Expand Down

0 comments on commit 230ff56

Please sign in to comment.