Avoid double allocation when passing strings via IntoParam
#1713
Conversation
`IntoParam`

Might be worth adding a NUL-in-the-middle test.
```diff
@@ -45,7 +45,7 @@ unsafe impl Abi for PCWSTR {
 #[cfg(feature = "alloc")]
 impl<'a> IntoParam<'a, PCWSTR> for &str {
     fn into_param(self) -> Param<'a, PCWSTR> {
-        Param::Boxed(PCWSTR(heap_string(&self.encode_utf16().collect::<alloc::vec::Vec<u16>>())))
+        Param::Boxed(PCWSTR(unsafe { string_from_iter(self.encode_utf16(), self.len()) }))
```
Will `self.len() >= self.encode_utf16().count()` always hold true? Intuitively it seems like it should: a 3-byte UTF-8 string holding two code points could be represented by either two `u16`s if they combine, or three `u16`s if they map 1:1. Worst case is that it over-allocates.

Edit: Based on everything I've been able to find, the worst case is over-allocation.
I'd been banging my head at this for a while now, too. It seems to hold true, so long as `encode_utf16()` doesn't do anything crazy, like decomposing code points.
The rationale I came up with is this: Each Unicode code point can be encoded as no more than two UTF-16 code units. For the inequality to hold, we can thus ignore all UTF-8 sequences that consist of two or more code units. That leaves the set of code points that require a single UTF-8 code unit. Those are the 128 ASCII characters, which are all encoded using a single code unit in UTF-16 as well.
In summary:
- UTF-8 sequences of length 1 map to UTF-16 sequences of length 1
- UTF-8 sequences of length 2 (or more) map to UTF-16 sequences of length 2 (or less)
While I couldn't find any formal documentation on what `encode_utf16()` does and doesn't do, the example listed under `encode_utf16()` at least seems to confirm the reasoning above: `assert!(utf16_len <= utf8_len);`.
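The inequality can be spot-checked with a quick sketch (sample strings chosen here for illustration; this is a sanity check of the reasoning above, not a proof):

```rust
fn main() {
    // For each sample, the number of UTF-16 code units should never
    // exceed the number of UTF-8 bytes.
    for s in ["hello", "héllo", "日本語", "💖 emoji"] {
        let utf8_len = s.len(); // bytes in the UTF-8 encoding
        let utf16_len = s.encode_utf16().count(); // u16 code units
        assert!(utf16_len <= utf8_len, "inequality violated for {s:?}");
    }
}
```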
Still a bit uneasy about all of this. Personally, I'd prefer a `string_from_iter()` that takes an iterator only, wastes some cycles on `count()` (ultimately doing the de-/encoding twice), and drops the `unsafe` along the way.

Thoughts?
Yes, it will, given three assumptions:

- The number of UTF-8 code units (bytes) needed to encode a Unicode code point is always greater than or equal to the number of UTF-16 code units. UTF-8 is a variable-width encoding that takes four possible forms:
  - 1 UTF-8 byte => 1 UTF-16 code unit
  - 2 UTF-8 bytes => 1 UTF-16 code unit
  - 3 UTF-8 bytes => 1 UTF-16 code unit
  - 4 UTF-8 bytes => 2 UTF-16 code units
- No (de)normalization takes place, i.e. no decomposing or composing character transformations (e.g. A + ¨: LATIN CAPITAL LETTER A (U+0041) + COMBINING DIAERESIS (U+0308)). There is no indication in the documentation of the Rust standard library that any such normalization or transformation takes place, and looking at the source code only validates that assumption.
- Input is valid UTF-8. We might have to look into edge cases around OsStr, since it consists of WTF-8 strings. Even WTF-8 strings uphold this invariant, because unpaired surrogate 16-bit code units would still shrink to 1 UTF-16 code unit (2 UTF-16 code units for supplementary code points, which require 4 UTF-8 bytes).

I strongly believe that it should not be the case, but if there is concern we could also safeguard it, either by truncating at the maximum number of UTF-16 code units or by panicking if that were to occur.
> Personally, I'd prefer a `string_from_iter()` that takes an iterator only, wastes some cycles on `count()` (ultimately doing the de-/encoding twice), and drops the `unsafe` along the way.
Is there a guarantee that an iterator can be iterated twice?
I doubt there's any de/normalization taking place in `encode_utf16()`. I'd be willing to bet dozens of dollars on it! Survey says... There's no denormalization, just bit twiddling as expected: https://doc.rust-lang.org/stable/src/core/char/methods.rs.html#1706

Also, the Rust documentation for `encode_utf16()` seems to indicate that the UTF-16 `u16` count is never greater than the UTF-8 byte count: https://doc.rust-lang.org/stable/src/core/str/mod.rs.html#998
> Also, the Rust documentation for `encode_utf16()` seems to indicate that the UTF-16 `u16` count is never greater than the UTF-8 byte count: https://doc.rust-lang.org/stable/src/core/str/mod.rs.html#998

Nice find. Even though it's just an example, I'd say it was intended as a guarantee.
The way I see it, there are three ways in which we can resolve this:
- For each `core::ptr::write` we do, we assert that we are within bounds first (panic behaviour)
- We run `.take(len + 1)` on the iterator, which truncates any theoretical superfluous characters.
- We rethink our approach with iterators.
To be honest, I'm not very happy about this function to begin with. I designed it the way I did in order to remain "backwards compatible" with `heap_string`. We are currently generic over T with an iterator we can only iterate once.

Why do we need to be generic over T? The only strings the Windows API uses are UTF-8 and UTF-16, and some day there might be a use case for UTF-32. These have an alignment of 1, 2 and 4 respectively. As we are generic currently, we assert that the alignment of the output buffer matches the alignment of T. However, given that `HeapAlloc` has an alignment of 8+ and that we only ever use 4 as maximum alignment, that check is fairly redundant.
As it stands, this API has the following inputs: `&OsStr` and `&str`, and the following outputs: `*const u8` and `*const u16`. We could create two functions instead: `heap_string_narrow` and `heap_string_wide`, for `*const u8` and `*const u16` respectively. This would allow us to iterate over the string any number of times and ensure they are within bounds.
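A rough sketch of what the two proposed helpers could look like (the names are from the discussion above; `Vec` stands in for the real `HeapAlloc`-backed buffer, so this is illustrative only). The wide variant iterates the string twice, once to size the buffer exactly and once to fill it:

```rust
// Hypothetical narrow variant: bytes plus a trailing null.
fn heap_string_narrow(s: &str) -> Vec<u8> {
    let mut buf = Vec::with_capacity(s.len() + 1);
    buf.extend_from_slice(s.as_bytes());
    buf.push(0);
    buf
}

// Hypothetical wide variant: iterate once to count, again to transcode.
fn heap_string_wide(s: &str) -> Vec<u16> {
    let mut buf = Vec::with_capacity(s.encode_utf16().count() + 1);
    buf.extend(s.encode_utf16());
    buf.push(0);
    buf
}

fn main() {
    assert_eq!(heap_string_narrow("ok"), [111, 107, 0]);
    assert_eq!(heap_string_wide("ok"), [111, 107, 0]);
}
```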
As I was writing this I actually have an even better idea, hold on let me implement it.
Good idea! Probably good for another pull request though, we should try to solve / validate this implementation first.
I just pushed a new version that has the following advantages:
There are two more issues to be solved:
crates/libs/windows/src/core/heap.rs (outdated)

```rust
// TODO ensure `encoder` is a fused iterator
assert!(encoder.next().is_none(), "encoder returned more characters than expected");
```
Won't this be fused, because the chaining of `once()` is guaranteed to be fused? `iter` might not be fused, but the implementation will never call `next()` on it again because the chain moved on to the next one.
Nice catch, I didn't consider that. `Once` is indeed guaranteed to be fused, so that issue is solved!
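The pattern in question, as a minimal sketch: chaining `once()` means the final null is always yielded, and once the chain is exhausted it keeps returning `None`:

```rust
fn main() {
    let mut encoder = "ab".encode_utf16().chain(core::iter::once(0u16));
    assert_eq!(encoder.next(), Some(u16::from(b'a')));
    assert_eq!(encoder.next(), Some(u16::from(b'b')));
    assert_eq!(encoder.next(), Some(0)); // the appended null from once()
    assert_eq!(encoder.next(), None); // exhausted...
    assert_eq!(encoder.next(), None); // ...and stays exhausted
}
```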
Asserting here seems wrong... I'm not sure we want to panic if the encoder contains more than `len` items. If anything, a `debug_assert!` might be appropriate (though I think silently ignoring is likely ok).
I'd also like to avoid panics in general. I have customers for whom this is inappropriate. If it means we have to embed this inside `PCWSTR` and `PCSTR` directly to avoid the issues with generality, then so be it.
Should this panic: `string_from_iter("hello world".encode_utf16(), 5)`?
That should never happen. My concern is that we're trying to harden a function that is only ever used internally, so if there's a general concern about the safety of this function, then we can either mark it unsafe or just get rid of the function entirely.
```diff
@@ -45,7 +45,7 @@ unsafe impl Abi for PCWSTR {
 #[cfg(feature = "alloc")]
 impl<'a> IntoParam<'a, PCWSTR> for &str {
     fn into_param(self) -> Param<'a, PCWSTR> {
-        Param::Boxed(PCWSTR(heap_string(&self.encode_utf16().collect::<alloc::vec::Vec<u16>>())))
+        Param::Boxed(PCWSTR(unsafe { string_from_iter(self.encode_utf16(), self.len()) }))
```
What if you just embrace the fact that you know `self` is a `&str` and pass in the actual length instead of over-allocating? I don't know how expensive an over-allocation is relative to the recalculation.

`string_from_iter(self.encode_utf16(), self.encode_utf16().count())`
We would need to iterate over it twice, which provides some memory savings at the cost of run-time computation.

- The amount of memory we waste is 0% if the string consists of ASCII only.
- The amount of memory we waste is 50% if the string consists only of BMP code points that are not ASCII.
- The amount of memory we waste is 50% if the string consists of supplementary code points only (these require surrogate pairs in UTF-16 (2 code units) and 4-byte sequences in UTF-8 (4 bytes)).
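To make those percentages concrete, a quick sketch (sample strings chosen here for illustration): `s.len()` is the over-allocated `u16` slot count, `encode_utf16().count()` the number actually used.

```rust
// Fraction of over-allocated u16 slots left unused for a given string.
fn waste(s: &str) -> f64 {
    let allocated = s.len(); // u16 slots reserved when sizing by UTF-8 bytes
    let used = s.encode_utf16().count(); // u16 slots actually written
    (allocated - used) as f64 / allocated as f64
}

fn main() {
    assert_eq!(waste("ascii"), 0.0); // ASCII only: no waste
    assert_eq!(waste("äöü"), 0.5); // non-ASCII BMP code points: 50% waste
    assert_eq!(waste("💖💖"), 0.5); // supplementary code points: 50% waste
}
```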
Given that most strings are ASCII or contain mostly ASCII, I believe it is a reasonable choice. When considering memory waste, we have to look at the intention of the `IntoParam` trait: it is a convenience API.

In heavily resource-constrained Windows environments such as Nano Server, with a minimum requirement of 512 MB, heap allocation is frowned upon, as is the case in most embedded environments. I don't think this convenience API is intended for Windows Nano environments, and even if it were, for small strings the waste is still comparatively low.

Given the 2 GB memory requirement of Windows, the only way in which string allocation would be worrisome would be strings with a length of 512 MB. Aside from the fact that it would be ludicrous to keep such huge strings in memory and not maintain/convert them yourself for performance reasons, these strings are also only allocated for the duration of the function call; they are freed immediately after.

Given the use case, the very low number of PCs with less than 4 GB of memory, and the fact that most strings don't occupy some ludicrous size, I'd say that over-allocating here is a reasonable choice to make. The main practical source of supplementary or non-ASCII BMP code points is foreign languages, which due to their larger character repertoires convey more entropy per character anyway and hence produce shorter strings.

That being said, allocating the optimal length is still a desirable goal. Additionally, reading the memory while decoding the Unicode code points will likely bring it into the processor cache, which should make the second iteration faster. I'll rewrite something later that takes advantage of this.
Thanks @AronParker - agree with your analysis. It doesn't seem worthwhile in this case.

Great work @AronParker

Thanks guys! Thanks to @kennykerr for the iterator idea, and once again thanks @ryancerium for your extensive and helpful reviews; it made the code a lot better and more robust in the end!
crates/libs/windows/src/core/heap.rs (outdated)

```rust
for i in 0..len {
    // SAFETY: ptr points to an allocation object of size `len`, indices accessed are always lower than `len`
    unsafe {
        core::ptr::write(
            ptr.add(i),
            match encoder.next() {
                Some(encoded) => encoded,
                None => break,
            },
        );
    }
}
```
Make this a bit clearer perhaps:

```rust
for (offset, c) in (0..len).zip(encoder) {
    // SAFETY: ptr points to an allocation object of size `len`, indices accessed are always lower than `len`
    unsafe { core::ptr::write(ptr.add(offset), c); }
}
```
This looks much better! Unfortunately putting the encoder into `.zip(self)` would consume it, which wouldn't allow our assertion afterwards:

`assert!(encoder.next().is_none(), "encoder returned more characters than expected");`

Your version does look much better and I'd take it in a heartbeat, but we'd need a non-consuming version of zip, unfortunately.
It would be unusual, but is it invalid to request encoding fewer characters than there are?
Technically, in my opinion, it is impossible, but since there was uncertainty in the reviews about silent truncation I kept it in. It could be possible that somehow an invalid length gets passed in the future, so it'd be helpful to be made aware via a panic rather than silent truncation. I'm not hugely opinionated on this issue; we can go for the more elegant `.zip` at the expense of potential bugs in the encoder or in what gets passed to the function.
I wasn't worried about truncation, I was worried about writing past the allocated memory with the unsafe `ptr.write()`. Your outer loop and/or the `.zip()` makes that a non-issue. No strong opinion on this either, I can see it both ways. We're not going to overwrite anything no matter what.
You can store the zipped iterator into a local variable and still do the check afterwards if you want.
The `Zip` iterator instance will return `None` because the range iterator is finished. You can never get to the original `encoder` iterator again because it's inside the `Zip` iterator instance.
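A small sketch of why the post-loop check can't see leftovers once `zip` is involved: when the range side finishes, the `Zip` returns `None` even though the inner iterator still holds an element (values here are arbitrary stand-ins for the real encoder):

```rust
fn main() {
    let encoder = [10u16, 20, 30].into_iter(); // stand-in for the real encoder
    let len = 2; // pretend the caller under-reported the length
    let mut zipped = (0..len).zip(encoder);
    assert_eq!(zipped.next(), Some((0, 10)));
    assert_eq!(zipped.next(), Some((1, 20)));
    // Range exhausted: Zip yields None; the leftover 30 is unreachable.
    assert_eq!(zipped.next(), None);
}
```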
Possibly rust-lang/rust#95604. There have been regressions regarding stdcall symbol name mangling.

Thanks Chris! Hey @rylev I noticed you commented on that - could this be related?

Here's the fix: rust-lang/rust#96444

@AronParker there's a temporary fix for the build until nightly is fixed, if you'd like to merge master.
Alright seems good! Yeah I'm quite happy with the way this is right now. We could apply ryancerium's
We could guarantee null-termination, but it would still potentially leave characters in the iterator. Honestly, if this is an internal-only function, I'd opt for readable, because the crate owns all 3 places it gets called.

```rust
// Allocate space for a terminating null
let alloc_len = len + 1;
let ptr = heap_alloc(alloc_len * std::mem::size_of::<T>()).expect("could not allocate string") as *mut T;
let encoder = iter.take(len)                   // read up to len chars from the iterator
    .chain(core::iter::once(T::default()))     // append the terminating null
    .enumerate();                              // memory offset of the char
for (offset, c) in encoder {
    unsafe { core::ptr::write(ptr.add(offset), c); }
}
```
In a way we guarantee null-termination already: if the encoder does not fully write the transcoded result, including the null-terminator, it panics, which is pretty much the nuclear option that is impossible to sweep under the rug. I'm open to changing it, but I'd still say that silent truncation is worse than a hard error. Under normal circumstances, with people using the function correctly, it should never have any issues. But in case issues do arise (e.g. someone passed the wrong len), I believe it is more advantageous to panic rather than silently truncate, because it makes the error most noticeable. But if others share this preference I'm open to changing it, of course.

I agree, I think it's good to have that assertion for that assumption there.
crates/libs/windows/src/core/heap.rs (outdated)

```rust
// SAFETY: ptr points to an allocation object of size `len`, indices accessed are always lower than `len`
unsafe {
    core::ptr::write(
        ptr.add(i),
```
nit: can you use `ptr.add(i).write(...)` instead?
Of course! My bad.
> You can store the zipped iterator into a local variable and still do the check afterwards if you want.

Pardon, but how do I do that? The zip iterator consumes both inputs and only yields elements if both iterators have an element; what can I get from the zipped iterator besides None afterwards?
crates/libs/windows/src/core/heap.rs (outdated)

```rust
for i in 0..len {
    // SAFETY: ptr points to an allocation object of size `len`, indices accessed are always lower than `len`
    unsafe {
        core::ptr::write(
            ptr.add(i),
            match encoder.next() {
                Some(encoded) => encoded,
                None => break,
            },
        );
    }
}
```
You can store the zipped iterator into a local variable and still do the check afterwards if you want.
crates/libs/windows/src/core/heap.rs (outdated)

```rust
// TODO ensure `encoder` is a fused iterator
assert!(encoder.next().is_none(), "encoder returned more characters than expected");
```
Asserting here seems wrong... I'm not sure we want to panic if the encoder contains more than `len` items. If anything, a `debug_assert!` might be appropriate (though I think silently ignoring is likely ok).
crates/libs/windows/src/core/heap.rs (outdated)

```rust
            ptr.add(i),
            match encoder.next() {
                Some(encoded) => encoded,
                None => break,
```
It might be useful to do the following here instead of the assert in code a few lines later:

`debug_assert!(i == len - 1);`

Essentially, while this code is always safe (i.e., we'll never try to write to unallocated memory), if the iterator's length and `len` don't match, we either end up allocating too little memory or too much. It seems reasonable to help users of this function not make that mistake.
Do you mean `debug_assert!(i < len)` and not `==`?
I meant `==`, because I believe we expect that length will always be equal to the number of elements in the iterator.
crates/libs/windows/src/core/heap.rs (outdated)

```rust
    buffer
/// This function panics if the heap allocation fails, the alignment requirements of 'T' surpass
/// 8 (HeapAlloc's alignment) or if len is less than the number of items in the iterator.
pub fn string_from_iter<I, T>(iter: I, len: usize) -> *const T
```
nit: it always bothered me a bit that we're using the term `string` here when this function is more general than that. It might be nice to name it in a way that describes more closely what's actually happening.

In fact, it might make sense for this to only copy an iterator, with the caller responsible for adding the trailing null byte. This function would then lose the `Default` bound and the caller would call it like so:

`copy_from_iterator(self.as_bytes().iter().copied().chain(core::iter::once(0)), self.len() + 1);`

The caller is a bit more verbose, but it's way clearer what's actually happening.
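A sketch of the suggested shape (the `copy_from_iterator` name and calling convention are taken from the comment above; `Vec` stands in for the heap allocation, so this is illustrative only):

```rust
// Hypothetical helper: copy up to `len` items from the iterator; the caller
// is responsible for chaining the trailing null themselves.
fn copy_from_iterator<T, I: Iterator<Item = T>>(iter: I, len: usize) -> Vec<T> {
    let mut buf = Vec::with_capacity(len);
    buf.extend(iter.take(len));
    buf
}

fn main() {
    let s = "ab";
    let out = copy_from_iterator(
        s.as_bytes().iter().copied().chain(core::iter::once(0)),
        s.len() + 1,
    );
    assert_eq!(out, [97, 98, 0]);
}
```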
I just pushed a commit that contains the suggestions of ryancerium and rylev. I've renamed the function to
This is a work of art. I applaud your willingness to change tactics after excellent discussion. I'm approving just for the heck of it!
crates/libs/windows/src/core/heap.rs (outdated)

```diff
@@ -32,20 +32,32 @@ pub unsafe fn heap_free(ptr: RawPtr) {
     }
 }

-/// Copy a slice of `T` into a freshly allocated buffer with an additional default `T` at the end.
+/// Copy len elements of an iterator of type `T` into a freshly allocated buffer with an additional default `T` at the end.
```
No more default `T` at the end.
Ah nice catch! Gonna fix it immediately.
Thank you very much for your positivity and continuous valuable input. It's been a pleasure to work with you!
0ad384c
Thanks folks! If there's no more feedback, I'm happy to merge this change.
I really like how this turned out! Thanks so much for your patience and contribution here 🧡
Same here, I love how clean this is! Thanks everyone again for your valuable feedback!
Introduces an iterator-based approach as discussed in #1712. The current behavior when passing a `&str` to a function using the `IntoParam` trait works as follows:

1. Create a `vec: Vec<u16>` and transcode the UTF-8 from `s: &str` into the `vec` as UTF-16.
2. Allocate a buffer of size `vec.len + 1`
3. Copy `vec` into the newly allocated buffer
4. Append null terminator `(\0)`

This pull request changes the behaviour as follows:

1. Allocate a buffer of size `s.len() + 1` from `s: &str`
2. Transcode `s` into the newly allocated buffer using UTF-16.
3. Append null terminator `(\0)`

This saves one copy and at least one allocation (likely 3-4 on average). Collecting into a `Vec<u16>`, as currently done using `OsStr::encode_wide` and `Str::encode_utf16`, mostly allocates 2-3 times, as described here: rust-lang/rust#96297 (comment)

Fixes #1712