diff --git a/library/alloc/tests/str.rs b/library/alloc/tests/str.rs index 9689196ef21ac..4d182be02c9e9 100644 --- a/library/alloc/tests/str.rs +++ b/library/alloc/tests/str.rs @@ -1631,6 +1631,18 @@ fn strslice_issue_16878() { assert!(!"00abc01234567890123456789abc".contains("bcabc")); } +#[test] +fn strslice_issue_104726() { + // Edge-case in the simd_contains impl. + // The first and last byte are the same so it backtracks by one byte + // which aligns with the end of the string. Previously incorrect offset calculations + // lead to out-of-bounds slicing. + #[rustfmt::skip] + let needle = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaba"; + let haystack = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab"; + assert!(!haystack.contains(needle)); +} + #[test] #[cfg_attr(miri, ignore)] // Miri is too slow fn test_strslice_contains() { diff --git a/library/core/src/str/pattern.rs b/library/core/src/str/pattern.rs index c5be32861f9a5..d76d6f8b2a2d9 100644 --- a/library/core/src/str/pattern.rs +++ b/library/core/src/str/pattern.rs @@ -1741,6 +1741,7 @@ fn simd_contains(needle: &str, haystack: &str) -> Option { use crate::simd::{SimdPartialEq, ToBitMask}; let first_probe = needle[0]; + let last_byte_offset = needle.len() - 1; // the offset used for the 2nd vector let second_probe_offset = if needle.len() == 2 { @@ -1758,7 +1759,7 @@ fn simd_contains(needle: &str, haystack: &str) -> Option { }; // do a naive search if the haystack is too small to fit - if haystack.len() < Block::LANES + second_probe_offset { + if haystack.len() < Block::LANES + last_byte_offset { return Some(haystack.windows(needle.len()).any(|c| c == needle)); } @@ -1815,7 +1816,7 @@ fn simd_contains(needle: &str, haystack: &str) -> Option { // The loop condition must ensure that there's enough headroom to read LANE bytes, // and not only at the current index but also at the index shifted by block_offset const UNROLL: usize = 4; - while i + second_probe_offset + UNROLL * Block::LANES < haystack.len() && !result { + while i + last_byte_offset + UNROLL * Block::LANES < haystack.len() && !result { let mut masks = [0u16; UNROLL]; for j in 0..UNROLL { masks[j] = test_chunk(i + j * Block::LANES); @@ -1828,7 +1829,7 @@ fn simd_contains(needle: &str, haystack: &str) -> Option { } i += UNROLL * Block::LANES; } - while i + second_probe_offset + Block::LANES < haystack.len() && !result { + while i + last_byte_offset + Block::LANES < haystack.len() && !result { let mask = test_chunk(i); if mask != 0 { result |= check_mask(i, mask, result); @@ -1840,7 +1841,7 @@ fn simd_contains(needle: &str, haystack: &str) -> Option { // This simply repeats the same procedure but as right-aligned chunk instead // of a left-aligned one. The last byte must be exactly flush with the string end so // we don't miss a single byte or read out of bounds. - let i = haystack.len() - second_probe_offset - Block::LANES; + let i = haystack.len() - last_byte_offset - Block::LANES; let mask = test_chunk(i); if mask != 0 { result |= check_mask(i, mask, result); @@ -1860,6 +1861,7 @@ fn simd_contains(needle: &str, haystack: &str) -> Option { #[cfg(all(target_arch = "x86_64", target_feature = "sse2"))] // only called on x86 #[inline] unsafe fn small_slice_eq(x: &[u8], y: &[u8]) -> bool { + debug_assert_eq!(x.len(), y.len()); // This function is adapted from // https://github.com/BurntSushi/memchr/blob/8037d11b4357b0f07be2bb66dc2659d9cf28ad32/src/memmem/util.rs#L32