diff --git a/base/rawbigints.jl b/base/rawbigints.jl index 6508bea05be0f..a9bb18e163e2d 100644 --- a/base/rawbigints.jl +++ b/base/rawbigints.jl @@ -21,14 +21,21 @@ reversed_index(n::Int, i::Int) = n - i - 1 reversed_index(x, i::Int, v::Val) = reversed_index(elem_count(x, v), i)::Int split_bit_index(x::RawBigInt, i::Int) = divrem(i, word_length(x), RoundToZero) +function get_elem_words_raw(x::RawBigInt{T}, i::Int) where {T} + @boundscheck if (i < 0) || (elem_count(x, Val(:words)) ≤ i) + throw(BoundsError(x, i)) + end + d = x.d + j = i + 1 + (GC.@preserve d unsafe_load(Ptr{T}(pointer(d)), j))::T +end + """ `i` is the zero-based index of the wanted word in `x`, starting from the less significant words. """ -function get_elem(x::RawBigInt{T}, i::Int, ::Val{:words}, ::Val{:ascending}) where {T} - # `i` must be non-negative and less than `x.word_count` - d = x.d - (GC.@preserve d unsafe_load(Ptr{T}(pointer(d)), i + 1))::T +function get_elem(x::RawBigInt, i::Int, ::Val{:words}, ::Val{:ascending}) + @inbounds @inline get_elem_words_raw(x, i) end function get_elem(x, i::Int, v::Val, ::Val{:descending}) @@ -96,7 +103,8 @@ end """ Returns an integer of type `R`, consisting of the `len` most -significant bits of `x`. +significant bits of `x`. If there are less than `len` bits in `x`, +the least significant bits are zeroed. """ function truncated(::Type{R}, x::RawBigInt, len::Int) where {R<:Integer} ret = zero(R) @@ -104,17 +112,22 @@ function truncated(::Type{R}, x::RawBigInt, len::Int) where {R<:Integer} word_count, bit_count_in_word = split_bit_index(x, len) k = word_length(x) vals = (Val(:words), Val(:descending)) + lenx = elem_count(x, first(vals)) for w ∈ 0:(word_count - 1) ret <<= k - word = get_elem(x, w, vals...) - ret |= R(word) + if w < lenx + word = get_elem(x, w, vals...) + ret |= R(word) + end end if !iszero(bit_count_in_word) ret <<= bit_count_in_word - wrd = get_elem(x, word_count, vals...) - ret |= R(wrd >>> (k - bit_count_in_word)) + if word_count < lenx + wrd = get_elem(x, word_count, vals...) + ret |= R(wrd >>> (k - bit_count_in_word)) + end end end ret::R diff --git a/test/mpfr.jl b/test/mpfr.jl index 9a9698ba72c2c..63da732df1c09 100644 --- a/test/mpfr.jl +++ b/test/mpfr.jl @@ -1088,3 +1088,12 @@ end clear_flags() end end + +@testset "RawBigInt truncation OOB read" begin + @testset "T: $T" for T ∈ (UInt8, UInt16, UInt32, UInt64, UInt128) + v = Base.RawBigInt{T}("a"^sizeof(T), 1) + @testset "bit_count: $bit_count" for bit_count ∈ (0:10:80) + @test Base.truncated(UInt128, v, bit_count) isa Any + end + end +end