Skip to content

Commit

Permalink
[External] [stdlib] Add _count_utf8_continuation_bytes() (#49049)
Browse files Browse the repository at this point in the history
[External] [stdlib] Add `_count_utf8_continuation_bytes()`

Add `_count_utf8_continuation_bytes()`

ORIGINAL_AUTHOR=martinvuyk
<[email protected]>
PUBLIC_PR_LINK=#3529

Co-authored-by: martinvuyk <[email protected]>
Closes #3529
MODULAR_ORIG_COMMIT_REV_ID: 994f648ac650ccd29096946d29b290e855bce057
  • Loading branch information
modularbot and martinvuyk committed Dec 17, 2024
1 parent b95eaba commit 3f7bd01
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 37 deletions.
10 changes: 10 additions & 0 deletions stdlib/src/builtin/string_literal.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,16 @@ struct StringLiteral(
unsafe_pointer=self.unsafe_ptr(), length=self.byte_length()
)

fn __reversed__(self) -> _StringSliceIter[StaticConstantOrigin, False]:
"""Iterate backwards over the string, returning immutable references.
Returns:
A reversed iterator over the string.
"""
return _StringSliceIter[StaticConstantOrigin, False](
unsafe_pointer=self.unsafe_ptr(), length=self.byte_length()
)

fn __getitem__[IndexerType: Indexer](self, idx: IndexerType) -> String:
"""Gets the character at the specified position.
Expand Down
8 changes: 3 additions & 5 deletions stdlib/src/collections/string.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -1127,8 +1127,8 @@ struct String(
count=other_len + 1,
)

fn __iter__(ref [_]self) -> _StringSliceIter[__origin_of(self)]:
"""Iterate over elements of the string, returning immutable references.
fn __iter__(self) -> _StringSliceIter[__origin_of(self)]:
"""Iterate over the string, returning immutable references.
Returns:
An iterator of references to the string elements.
Expand All @@ -1137,9 +1137,7 @@ struct String(
unsafe_pointer=self.unsafe_ptr(), length=self.byte_length()
)

fn __reversed__(
ref [_]self,
) -> _StringSliceIter[__origin_of(self), False]:
fn __reversed__(self) -> _StringSliceIter[__origin_of(self), False]:
"""Iterate backwards over the string, returning immutable references.
Returns:
Expand Down
80 changes: 53 additions & 27 deletions stdlib/src/utils/string_slice.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,32 @@ alias StaticString = StringSlice[StaticConstantOrigin]
"""An immutable static string slice."""


fn _count_utf8_continuation_bytes(span: Span[UInt8]) -> Int:
alias sizes = (256, 128, 64, 32, 16, 8)
var ptr = span.unsafe_ptr()
var num_bytes = len(span)
var amnt: Int = 0
var processed = 0

@parameter
for i in range(len(sizes)):
alias s = sizes.get[i, Int]()

@parameter
if simdwidthof[DType.uint8]() >= s:
var rest = num_bytes - processed
for _ in range(rest // s):
var vec = (ptr + processed).load[width=s]()
var comp = (vec & 0b1100_0000) == 0b1000_0000
amnt += int(comp.cast[DType.uint8]().reduce_add())
processed += s

for i in range(num_bytes - processed):
amnt += int((ptr[processed + i] & 0b1100_0000) == 0b1000_0000)

return amnt


fn _unicode_codepoint_utf8_byte_length(c: Int) -> Int:
debug_assert(
0 <= c <= 0x10FFFF, "Value: ", c, " is not a valid Unicode code point"
Expand Down Expand Up @@ -147,10 +173,9 @@ struct _StringSliceIter[
self.index = 0 if forward else length
self.ptr = unsafe_pointer
self.length = length
self.continuation_bytes = 0
for i in range(length):
if _utf8_byte_type(unsafe_pointer[i]) == 1:
self.continuation_bytes += 1
alias S = Span[UInt8, StaticConstantOrigin]
var s = S(unsafe_ptr=self.ptr, len=self.length)
self.continuation_bytes = _count_utf8_continuation_bytes(s)

fn __iter__(self) -> Self:
return self
Expand Down Expand Up @@ -244,8 +269,7 @@ struct StringSlice[

@always_inline
fn __init__(inout self, *, owned unsafe_from_utf8: Span[UInt8, origin]):
"""
Construct a new StringSlice from a sequence of UTF-8 encoded bytes.
"""Construct a new StringSlice from a sequence of UTF-8 encoded bytes.
Safety:
`unsafe_from_utf8` MUST be valid UTF-8 encoded data.
Expand All @@ -257,9 +281,8 @@ struct StringSlice[
self._slice = unsafe_from_utf8^

fn __init__(inout self, *, unsafe_from_utf8_strref: StringRef):
"""
Construct a new StringSlice from a StringRef pointing to UTF-8 encoded
bytes.
"""Construct a new StringSlice from a StringRef pointing to UTF-8
encoded bytes.
Safety:
- `unsafe_from_utf8_strref` MUST point to data that is valid for
Expand All @@ -285,8 +308,7 @@ struct StringSlice[
unsafe_from_utf8_ptr: UnsafePointer[UInt8],
len: Int,
):
"""
Construct a StringSlice from a pointer to a sequence of UTF-8 encoded
"""Construct a StringSlice from a pointer to a sequence of UTF-8 encoded
bytes and a length.
Safety:
Expand Down Expand Up @@ -335,13 +357,10 @@ struct StringSlice[
Returns:
The length in Unicode codepoints.
"""
var unicode_length = self.byte_length()

for i in range(unicode_length):
if _utf8_byte_type(self._slice[i]) == 1:
unicode_length -= 1

return unicode_length
var b_len = self.byte_length()
alias S = Span[UInt8, StaticConstantOrigin]
var s = S(unsafe_ptr=self.unsafe_ptr(), len=b_len)
return b_len - _count_utf8_continuation_bytes(s)

fn format_to(self, inout writer: Formatter):
"""
Expand Down Expand Up @@ -372,7 +391,8 @@ struct StringSlice[
rhs: The string slice to compare against.
Returns:
True if the string slices are equal in length and contain the same elements, False otherwise.
True if the string slices are equal in length and contain the same
elements, False otherwise.
"""
if not self and not rhs:
return True
Expand All @@ -394,7 +414,8 @@ struct StringSlice[
rhs: The string to compare against.
Returns:
True if the string slice is equal to the input string in length and contain the same bytes, False otherwise.
True if the string slice is equal to the input string in length and
contain the same bytes, False otherwise.
"""
return self == rhs.as_string_slice()

Expand All @@ -406,7 +427,8 @@ struct StringSlice[
rhs: The literal to compare against.
Returns:
True if the string slice is equal to the input literal in length and contain the same bytes, False otherwise.
True if the string slice is equal to the input literal in length and
contain the same bytes, False otherwise.
"""
return self == rhs.as_string_slice()

Expand All @@ -419,7 +441,8 @@ struct StringSlice[
rhs: The string slice to compare against.
Returns:
True if the string slices are not equal in length or contents, False otherwise.
True if the string slices are not equal in length or contents, False
otherwise.
"""
return not self == rhs

Expand All @@ -431,7 +454,8 @@ struct StringSlice[
rhs: The string slice to compare against.
Returns:
True if the string and slice are not equal in length or contents, False otherwise.
True if the string and slice are not equal in length or contents,
False otherwise.
"""
return not self == rhs

Expand All @@ -443,12 +467,13 @@ struct StringSlice[
rhs: The string literal to compare against.
Returns:
True if the slice is not equal to the literal in length or contents, False otherwise.
True if the slice is not equal to the literal in length or contents,
False otherwise.
"""
return not self == rhs

fn __iter__(self) -> _StringSliceIter[origin]:
"""Iterate over elements of the string slice, returning immutable references.
"""Iterate over the string, returning immutable references.
Returns:
An iterator of references to the string elements.
Expand All @@ -473,7 +498,7 @@ struct StringSlice[

@always_inline
fn as_bytes(self) -> Span[UInt8, origin]:
"""Get the sequence of encoded bytes as a slice of the underlying string.
"""Get the sequence of encoded bytes of the underlying string.
Returns:
A slice containing the underlying sequence of encoded bytes.
Expand Down Expand Up @@ -519,7 +544,8 @@ struct StringSlice[
pass

fn _from_start(self, start: Int) -> Self:
"""Gets the `StringSlice` pointing to the substring after the specified slice start position.
"""Gets the `StringSlice` pointing to the substring after the specified
slice start position.
If start is negative, it is interpreted as the number of characters
from the end of the string to start at.
Expand Down
13 changes: 8 additions & 5 deletions stdlib/test/collections/test_string.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -1338,17 +1338,20 @@ def test_string_iter():
"álO",
"етйувтсвардЗ",
)
var utf8_sequence_lengths = List(5, 12, 9, 5, 7, 6, 5, 5, 2, 3, 12)
var items_amount_characters = List(5, 12, 9, 5, 7, 6, 5, 5, 2, 3, 12)
for item_idx in range(len(items)):
var item = items[item_idx]
var utf8_sequence_len = 0
var ptr = item.unsafe_ptr()
var amnt_characters = 0
var byte_idx = 0
for v in item:
var byte_len = v.byte_length()
assert_equal(item[byte_idx : byte_idx + byte_len], v)
for i in range(byte_len):
assert_equal(ptr[byte_idx + i], v.unsafe_ptr()[i])
byte_idx += byte_len
utf8_sequence_len += 1
assert_equal(utf8_sequence_len, utf8_sequence_lengths[item_idx])
amnt_characters += 1

assert_equal(amnt_characters, items_amount_characters[item_idx])
var concat = String("")
for v in item.__reversed__():
concat += v
Expand Down
29 changes: 29 additions & 0 deletions stdlib/test/utils/test_string_slice.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ from testing import assert_equal, assert_true, assert_false

from utils import Span, StringSlice
from utils._utf8_validation import _is_valid_utf8
from utils.string_slice import _count_utf8_continuation_bytes


fn test_string_literal_byte_span() raises:
Expand Down Expand Up @@ -387,6 +388,33 @@ def test_combination_10_good_10_bad_utf8_sequences():
assert_false(validate_utf8(sequence))


def test_count_utf8_continuation_bytes():
alias c = UInt8(0b1000_0000)
alias b1 = UInt8(0b0100_0000)
alias b2 = UInt8(0b1100_0000)
alias b3 = UInt8(0b1110_0000)
alias b4 = UInt8(0b1111_0000)

def _test(amnt: Int, items: List[UInt8]):
p = items.unsafe_ptr()
span = Span[UInt8, StaticConstantOrigin](unsafe_ptr=p, len=len(items))
assert_equal(amnt, _count_utf8_continuation_bytes(span))

_test(5, List[UInt8](c, c, c, c, c))
_test(2, List[UInt8](b2, c, b2, c, b1))
_test(2, List[UInt8](b2, c, b1, b2, c))
_test(2, List[UInt8](b2, c, b2, c, b1))
_test(2, List[UInt8](b2, c, b1, b2, c))
_test(2, List[UInt8](b1, b2, c, b2, c))
_test(2, List[UInt8](b3, c, c, b1, b1))
_test(2, List[UInt8](b1, b1, b3, c, c))
_test(2, List[UInt8](b1, b3, c, c, b1))
_test(3, List[UInt8](b1, b4, c, c, c))
_test(3, List[UInt8](b4, c, c, c, b1))
_test(3, List[UInt8](b3, c, c, b2, c))
_test(3, List[UInt8](b2, c, b3, c, c))


fn main() raises:
test_string_literal_byte_span()
test_string_byte_span()
Expand All @@ -403,3 +431,4 @@ fn main() raises:
test_combination_good_bad_utf8_sequences()
test_combination_10_good_utf8_sequences()
test_combination_10_good_10_bad_utf8_sequences()
test_count_utf8_continuation_bytes()

0 comments on commit 3f7bd01

Please sign in to comment.