Avoid double allocation when passing strings via IntoParam #1713

Merged · 10 commits · Apr 28, 2022
41 changes: 32 additions & 9 deletions crates/libs/windows/src/core/heap.rs
@@ -32,20 +32,43 @@ pub unsafe fn heap_free(ptr: RawPtr) {
     }
 }
 
-/// Copy a slice of `T` into a freshly allocated buffer with an additional default `T` at the end.
+/// Copy an iterator of `T` into a freshly allocated buffer with an additional default `T` at the end.
 ///
-/// Returns a pointer to the beginning of the buffer
+/// Returns a pointer to the beginning of the buffer. This pointer must be freed when done using `heap_free`.
 ///
 /// # Panics
 ///
 /// This function panics if the heap allocation fails or if the pointer returned from
 /// the heap allocation is not properly aligned to `T`.
-pub fn heap_string<T: Copy + Default + Sized>(slice: &[T]) -> *const T {
-    unsafe {
-        let buffer = heap_alloc((slice.len() + 1) * std::mem::size_of::<T>()).expect("could not allocate string") as *mut T;
-        assert!(buffer.align_offset(std::mem::align_of::<T>()) == 0, "heap allocated buffer is not properly aligned");
-        buffer.copy_from_nonoverlapping(slice.as_ptr(), slice.len());
-        buffer.add(slice.len()).write(T::default());
-        buffer
+///
+/// # Safety
+/// `len` must not be less than the number of items in the iterator.
+pub unsafe fn string_from_iter<I, T>(iter: I, len: usize) -> *const T
+where
+    I: Iterator<Item = T>,
+    T: Copy + Default,
+{
+    let str_len = len + 1;
+    let ptr = heap_alloc(str_len * std::mem::size_of::<T>()).expect("could not allocate string") as *mut T;
+
+    // TODO: this assert is mostly redundant; HeapAlloc has an alignment of 8 and we currently only require alignments of 1 or 2.
+    // There is no meaningful string type with characters that require an alignment above 8.
+    assert_eq!(ptr.align_offset(std::mem::align_of::<T>()), 0, "heap allocated buffer is not properly aligned");
+
+    let mut encoder = iter.chain(core::iter::once(T::default()));
+
+    for i in 0..str_len {
+        core::ptr::write(
+            ptr.add(i),
+            match encoder.next() {
+                Some(encoded) => encoded,
+                None => break,
+            },
+        );
+    }
+
+    // TODO: ensure `encoder` is a fused iterator
+    assert!(encoder.next().is_none(), "encoder returned more characters than expected");
Contributor:

Won't this be fused because the chaining of once() is guaranteed to be fused? iter might not be fused, but the implementation will never call next() on it again because the chain moved on to the next one.

Contributor (author):

Nice catch, I didn't consider that. Once is indeed guaranteed to be fused, so that issue is solved!

Contributor:

asserting here seems wrong... I'm not sure we want to panic if the encoder contains more than len. If anything, a debug_assert! might be appropriate (though I think silently ignoring is likely ok).

Collaborator:

I'd also like to avoid panics in general. I have customers for whom this is inappropriate. If it means we have to embed this inside PCWSTR and PCSTR directly to avoid the issues with generality, then so be it.

Contributor @ryancerium (Apr 27, 2022):

Should this panic: string_from_iter("hello world".encode_utf16(), 5)

Collaborator:

That should never happen. My concern is that we're trying to harden a function that is only ever used internally so if there's a general concern about the safety of this function then we can either mark it unsafe or just get rid of the function entirely.


+    ptr
 }
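The fill loop in `string_from_iter` can be sketched in a testable form, using a `Vec<u16>` as a hypothetical stand-in for the `HeapAlloc`'d buffer (so the example runs outside Windows; `fill_from_iter` is an illustrative name, not part of the crate):

```rust
// Sketch of string_from_iter's fill loop against a Vec<u16> stand-in
// (hypothetical, for illustration only).
fn fill_from_iter<I: Iterator<Item = u16>>(iter: I, len: usize) -> Vec<u16> {
    let str_len = len + 1;
    let mut buffer = vec![0u16; str_len];
    // Chain a single default (NUL) terminator; `once` is fused, so the
    // combined iterator keeps yielding None once exhausted.
    let mut encoder = iter.chain(core::iter::once(0u16));
    for slot in buffer.iter_mut() {
        match encoder.next() {
            Some(unit) => *slot = unit,
            None => break,
        }
    }
    // Mirrors the trailing assert in the PR: any leftover items mean the
    // caller under-reported `len`.
    assert!(encoder.next().is_none(), "encoder returned more characters than expected");
    buffer
}

fn main() {
    let s = "hi";
    let buf = fill_from_iter(s.encode_utf16(), s.len());
    assert_eq!(buf, vec![0x68, 0x69, 0]);
}
```

Unlike the real function, the `Vec` version needs no `unsafe` and frees itself, which is what makes it convenient for checking the loop's edge cases.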
2 changes: 1 addition & 1 deletion crates/libs/windows/src/core/pcstr.rs
@@ -45,7 +45,7 @@ unsafe impl Abi for PCSTR {
 #[cfg(feature = "alloc")]
 impl<'a> IntoParam<'a, PCSTR> for &str {
     fn into_param(self) -> Param<'a, PCSTR> {
-        Param::Boxed(PCSTR(heap_string(self.as_bytes())))
+        Param::Boxed(PCSTR(unsafe { string_from_iter(self.as_bytes().iter().copied(), self.len()) }))
     }
 }
 #[cfg(feature = "alloc")]
4 changes: 2 additions & 2 deletions crates/libs/windows/src/core/pcwstr.rs
@@ -45,7 +45,7 @@ unsafe impl Abi for PCWSTR {
 #[cfg(feature = "alloc")]
 impl<'a> IntoParam<'a, PCWSTR> for &str {
     fn into_param(self) -> Param<'a, PCWSTR> {
-        Param::Boxed(PCWSTR(heap_string(&self.encode_utf16().collect::<alloc::vec::Vec<u16>>())))
+        Param::Boxed(PCWSTR(unsafe { string_from_iter(self.encode_utf16(), self.len()) }))
Contributor @ryancerium (Apr 24, 2022):

Will self.len() >= self.encode_utf16().count() always hold true? Intuitively it seems like it should: a 3-byte UTF-8 sequence representing two code points could be represented by either two u16 if they combine, or three u16 if they map 1:1. Worst case is that it over-allocates.

Edit: Based on everything I've been able to find, the worst case is over-allocation.

Contributor @tim-weis (Apr 24, 2022):

I'd been banging my head at this for a while now, too. It seems to hold true, so long as encode_utf16() doesn't do anything crazy, like decomposing code points.

The rationale I came up with is this: Each Unicode code point can be encoded as no more than two UTF-16 code units. For the inequality to hold, we can thus ignore all UTF-8 sequences that consist of two or more code units. That leaves the set of code points that require a single UTF-8 code unit. Those are the 128 ASCII characters, which are all encoded using a single code unit in UTF-16 as well.

In summary:

  • UTF-8 sequences of length 1 map to UTF-16 sequences of length 1
  • UTF-8 sequences of length 2 (or more) map to UTF-16 sequences of length 2 (or less)

While I couldn't find any formal documentation on what encode_utf16() does and doesn't do, the example listed under encode_utf16() at least seems to confirm the reasoning above: assert!(utf16_len <= utf8_len);.

Still a bit uneasy about all of this. Personally, I'd prefer a string_from_iter() that takes an iterator only, wastes some cycles on count() (ultimately doing the de-/encoding twice), and drops the unsafe along the way.

Thoughts?
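The inequality under discussion can be spot-checked empirically (a quick sketch, not a proof — samples chosen here to cover 1-, 2-, 3-, and 4-byte UTF-8 sequences):

```rust
fn main() {
    // UTF-16 code units never outnumber UTF-8 bytes for the same string:
    // ASCII maps 1:1, 2- and 3-byte sequences shrink to one u16, and
    // 4-byte sequences shrink to a two-u16 surrogate pair.
    for s in ["hello", "héllo", "日本語", "🦀🦀"] {
        let utf8_len = s.len();
        let utf16_len = s.encode_utf16().count();
        assert!(utf16_len <= utf8_len);
        println!("{s}: {utf8_len} UTF-8 bytes, {utf16_len} UTF-16 units");
    }
}
```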

Contributor (author) @AronParker (Apr 24, 2022):

Yes it will, given three assumptions:

  1. The UTF-8 encoding of a Unicode code point never uses fewer code units than its UTF-16 encoding. UTF-8 is a variable-width character encoding that takes four possible forms:

1 UTF-8 code unit => 1 UTF-16 code unit
2 UTF-8 code units => 1 UTF-16 code unit
3 UTF-8 code units => 1 UTF-16 code unit
4 UTF-8 code units => 2 UTF-16 code units

  2. No (de)normalization takes place, i.e. no decomposing or composing character transformations (e.g. A + ¨: LATIN CAPITAL LETTER A (U+0041) + COMBINING DIAERESIS (U+0308)). There is no indication in the documentation of the Rust standard library that any such transformation takes place, and looking at the source code validates that assumption.

  3. The input is valid UTF-8. We might have to look into edge cases around OsStr, since it consists of WTF-8 strings. Even WTF-8 strings uphold this invariant, because unpaired surrogates would still shrink to 1 UTF-16 code unit (and supplementary code points, which require 4 UTF-8 code units, to 2 UTF-16 code units).

I strongly believe an overrun cannot happen, but if there is concern we could also safeguard against it, either by truncating to the maximum number of UTF-16 code units or by panicking if it were to occur.
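The four forms listed above can be verified per code point using the std `char::len_utf8` and `char::len_utf16` methods (a sketch with one representative character per form):

```rust
fn main() {
    // One representative code point per UTF-8 sequence length:
    // 'A' U+0041 (1 byte), 'é' U+00E9 (2 bytes), '€' U+20AC (3 bytes),
    // '🦀' U+1F980 (4 bytes, a surrogate pair in UTF-16).
    for (c, utf8, utf16) in [('A', 1, 1), ('é', 2, 1), ('€', 3, 1), ('🦀', 4, 2)] {
        assert_eq!(c.len_utf8(), utf8);
        assert_eq!(c.len_utf16(), utf16);
        // In every form, UTF-16 never needs more code units than UTF-8.
        assert!(utf16 <= utf8);
    }
}
```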

Contributor:

Personally, I'd prefer a string_from_iter() that takes an iterator only, wastes some cycles on count() (ultimately doing the de-/encoding twice), and drops the unsafe along the way.

Is there a guarantee that an iterator can be iterated twice?

I doubt there's any de/normalization taking place in encode_utf16(); I'd be willing to bet dozens of dollars on it! Survey says... there's no denormalization, just bit twiddling as expected: https://doc.rust-lang.org/stable/src/core/char/methods.rs.html#1706

Also, the Rust documentation for encode_utf16() seems to indicate that the utf-16 u16 count is never greater than the utf-8 byte count:
https://doc.rust-lang.org/stable/src/core/str/mod.rs.html#998

Contributor (author):

> Also, the Rust documentation for encode_utf16() seems to indicate that the utf-16 u16 count is never greater than the utf-8 byte count:
> https://doc.rust-lang.org/stable/src/core/str/mod.rs.html#998

Nice find. Even though it's just an example I'd say it was intended as a guarantee.

The way I see it there are three ways in which we can resolve this:

  1. For each core::ptr::write we do, we assert that we are within bounds first (panic behaviour)
  2. We run .take(len + 1) on the iterator, which truncates any theoretical superfluous characters.
  3. We rethink our approach with iterators.

To be honest I'm not very happy about this function to begin with. I designed it the way I did in order to remain "backwards compatible" to heap_string. We are currently generic over T with an iterator we can only iterate once.

Why do we need to be generic over T? The only strings that the Windows API uses are UTF-8 and UTF-16 and some day there might be a use case for UTF-32. These have an alignment of 1, 2 and 4 respectively. As we are generic currently, we assert that the alignment of the output buffer matches the alignment of T. However, given that HeapAlloc has an alignment of 8+ and that we only ever use 4 as maximum alignment, that check is fairly redundant.

As it stands, this API has the following inputs: &OsStr and &str. And the following outputs: *const u8 and *const u16. We could create two functions instead: heap_string_narrow, heap_string_wide for *const u8 and *const u16, respectively. This would allow us to iterate over the string any number of times and ensure they are within bounds.

As I was writing this I actually have an even better idea, hold on let me implement it.
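Option 2 above (clipping with `take`) could look like the following sketch, again using a `Vec<u16>` as a hypothetical stand-in for the heap buffer; `truncating_fill` is an illustrative name, not code from the PR:

```rust
// Hypothetical sketch of option 2: bound the writes with `take` so a
// too-long encoder is clipped instead of overrunning the buffer or panicking.
fn truncating_fill(iter: impl Iterator<Item = u16>, len: usize) -> Vec<u16> {
    let str_len = len + 1;
    let mut buffer = vec![0u16; str_len];
    for (slot, unit) in buffer
        .iter_mut()
        .zip(iter.chain(core::iter::once(0)).take(str_len))
    {
        *slot = unit;
    }
    buffer
}

fn main() {
    // Deliberately under-report the length: the output is silently clipped,
    // no panic. Note the NUL terminator is lost when clipping occurs, which
    // is a real drawback of this option.
    let buf = truncating_fill("hello".encode_utf16(), 3);
    assert_eq!(buf, vec![0x68, 0x65, 0x6C, 0x6C]);
}
```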

Contributor:

What if you just embrace the fact that you know self is a &str and pass in the actual length instead of over-allocating? I don't know how expensive an over-allocation is relative to the recalculation.

string_from_iter(self.encode_utf16(), self.encode_utf16().count())

Contributor (author) @AronParker (Apr 25, 2022):

We would need to iterate over it twice, which provides some memory savings at the cost of run-time computation.

The amount of memory we waste is 0% if the string consists of ASCII only.
The amount of memory we waste is 50% if the string consists only of BMP code points that are not ASCII.
The amount of memory we waste is 50% if the string consists only of supplementary code points (which require surrogate pairs in UTF-16 (2 code units) and 4-byte sequences in UTF-8).

Given that most strings are ASCII or contain mostly ASCII, I believe over-allocating is a reasonable choice. When considering memory waste we have to look at the intention behind the IntoParam trait: it is a convenience API.

In heavily resource-constrained Windows environments such as Nano Server, with a minimum requirement of 512 MB, heap allocation is frowned upon, as is the case in most embedded environments. I don't think this convenience API is intended for such environments, and even if it were, for small strings the waste is still comparatively low.

Given the 2 GB memory requirement of Windows, the only way string allocation would be worrisome is with strings approaching 512 MB in length. Aside from the fact that it would be ludicrous to keep such huge strings in memory rather than maintaining/converting them yourself for performance reasons, these strings are only allocated for the duration of the function call and are freed immediately afterwards.

Given the use case, the small share of PCs with less than 4 GB of memory, and the fact that most strings don't occupy some ludicrous size, I'd say that over-allocating here is a reasonable choice. The main practical source of supplementary or non-ASCII BMP code points is text in foreign languages, which due to their larger character repertoires convey more entropy per character and are hence shorter anyway.

That being said, allocating the optimal length is still a desirable goal. Additionally, reading the memory while decoding the Unicode code points will likely bring it into the processor cache, which should make the second iteration faster. I'll rewrite this later to take advantage of that.
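The trade-off being discussed can be sketched directly: passing `s.len()` gives an upper bound in one pass and over-allocates for non-ASCII text, while a second pass with `count()` gives the exact UTF-16 length:

```rust
fn main() {
    for s in ["hello", "grüße", "日本語"] {
        let over = s.len();                   // upper bound, single pass
        let exact = s.encode_utf16().count(); // exact, costs a second pass
        assert!(exact <= over);
        println!("{s}: allocate {over} vs {exact} u16s ({} wasted)", over - exact);
    }
}
```

For pure ASCII the two lengths coincide, so the second pass buys nothing; the gap only appears for multi-byte UTF-8 sequences.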

     }
 }
#[cfg(feature = "alloc")]
@@ -58,7 +58,7 @@ impl<'a> IntoParam<'a, PCWSTR> for alloc::string::String {
 impl<'a> IntoParam<'a, PCWSTR> for &::std::ffi::OsStr {
     fn into_param(self) -> Param<'a, PCWSTR> {
         use ::std::os::windows::ffi::OsStrExt;
-        Param::Boxed(PCWSTR(heap_string(&self.encode_wide().collect::<alloc::vec::Vec<u16>>())))
+        Param::Boxed(PCWSTR(unsafe { string_from_iter(self.encode_wide(), self.len()) }))
     }
 }
#[cfg(feature = "alloc")]