From 182b14fd0c207ef0ec3fd17abd1bd201d2e82faf Mon Sep 17 00:00:00 2001 From: Nathan Zimmerberg <39104088+nhz2@users.noreply.github.com> Date: Fri, 5 Apr 2024 22:52:27 -0400 Subject: [PATCH] Fixes for `position` with nested `NoopStreams` (#203) --- .github/workflows/CI.yml | 7 +++++++ ext/TestExt.jl | 1 + fuzz/fuzz.jl | 3 +-- src/noop.jl | 36 +++++++++++++++++++++++++++-------- src/stream.jl | 8 ++++++-- test/codecdoubleframe.jl | 2 ++ test/codecnoop.jl | 41 +++++++++++++++++++++++++++++++++++++--- test/codecquadruple.jl | 12 ++++++------ 8 files changed, 89 insertions(+), 21 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 18a8d850..f5cdfcdf 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -26,6 +26,13 @@ jobs: - windows-latest arch: - x64 + include: + - os: ubuntu-latest + version: '1' + arch: x86 + - os: macOS-14 + version: '1' + arch: aarch64 steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v1 diff --git a/ext/TestExt.jl b/ext/TestExt.jl index b7c7de36..a1a15668 100644 --- a/ext/TestExt.jl +++ b/ext/TestExt.jl @@ -89,6 +89,7 @@ function TranscodingStreams.test_chunked_read(Encoder, Decoder) for chunk in chunks stream = TranscodingStream(Decoder(), buffer, stop_on_end=true) ok &= read(stream) == chunk + ok &= position(stream) == length(chunk) ok &= eof(stream) ok &= isreadable(stream) close(stream) diff --git a/fuzz/fuzz.jl b/fuzz/fuzz.jl index 10b48060..ef36dcc5 100644 --- a/fuzz/fuzz.jl +++ b/fuzz/fuzz.jl @@ -120,8 +120,7 @@ end for r in rs d = r(stream) append!(x, d) - # TODO fix position - # length(x) == position(stream) || return false + length(x) == position(stream) || return false end x == data[eachindex(x)] end diff --git a/src/noop.jl b/src/noop.jl index 9ece4e1f..e555ba07 100644 --- a/src/noop.jl +++ b/src/noop.jl @@ -53,16 +53,18 @@ Note that this method may return a wrong position when - some data have been inserted by `TranscodingStreams.unread`, or - the position of the wrapped stream has been changed outside of this package. """ -function Base.position(stream::NoopStream) +function Base.position(stream::NoopStream)::Int64 mode = stream.state.mode - if mode === :idle + if !isopen(stream) + throw_invalid_mode(mode) + elseif mode === :idle return Int64(0) + elseif has_sharedbuf(stream) + return position(stream.stream) elseif mode === :write return position(stream.stream) + buffersize(stream.buffer1) - elseif mode === :read + else # read return position(stream.stream) - buffersize(stream.buffer1) - else - throw_invalid_mode(mode) end @assert false "unreachable" end @@ -97,8 +99,25 @@ function Base.seekend(stream::NoopStream) return stream end -function Base.unsafe_write(stream::NoopStream, input::Ptr{UInt8}, nbytes::UInt) +function Base.write(stream::NoopStream, b::UInt8)::Int + changemode!(stream, :write) + if has_sharedbuf(stream) + # directly write data to the underlying stream + n = Int(write(stream.stream, b)) + return n + end + buffer1 = stream.buffer1 + marginsize(buffer1) > 0 || flushbuffer(stream) + return writebyte!(buffer1, b) +end + +function Base.unsafe_write(stream::NoopStream, input::Ptr{UInt8}, nbytes::UInt)::Int changemode!(stream, :write) + if has_sharedbuf(stream) + # directly write data to the underlying stream + n = Int(unsafe_write(stream.stream, input, nbytes)) + return n + end buffer = stream.buffer1 if marginsize(buffer) ≥ nbytes copydata!(buffer, input, nbytes) @@ -106,7 +125,8 @@ function Base.unsafe_write(stream::NoopStream, input::Ptr{UInt8}, nbytes::UInt) else flushbuffer(stream) # directly write data to the underlying stream - return unsafe_write(stream.stream, input, nbytes) + n = Int(unsafe_write(stream.stream, input, nbytes)) + return n end end @@ -152,7 +172,7 @@ function fillbuffer(stream::NoopStream; eager::Bool = false)::Int changemode!(stream, :read) buffer = stream.buffer1 @assert buffer === stream.buffer2 - if stream.stream isa TranscodingStream && buffer === stream.stream.buffer1 + if has_sharedbuf(stream) # Delegate the operation when buffers are shared. underlying_mode::Symbol = stream.stream.state.mode if underlying_mode === :idle || underlying_mode === :read diff --git a/src/stream.jl b/src/stream.jl index 0eae4982..347f7407 100644 --- a/src/stream.jl +++ b/src/stream.jl @@ -157,6 +157,10 @@ end # throw ArgumentError that mode is invalid. throw_invalid_mode(mode) = throw(ArgumentError(string("invalid mode :", mode))) +# Return true if the stream shares buffers with underlying stream +function has_sharedbuf(stream::TranscodingStream)::Bool + stream.stream isa TranscodingStream && stream.buffer2 === stream.stream.buffer1 +end # Base IO Functions # ----------------- @@ -264,7 +268,7 @@ function Base.position(stream::TranscodingStream) mode = stream.state.mode if mode === :idle return Int64(0) - elseif mode === :read + elseif mode === :read || mode === :stop return stats(stream).out elseif mode === :write return stats(stream).in @@ -584,7 +588,7 @@ function stats(stream::TranscodingStream) buffer2 = stream.buffer2 if mode === :idle transcoded_in = transcoded_out = in = out = 0 - elseif mode === :read + elseif mode === :read || mode === :stop transcoded_in = buffer2.transcoded transcoded_out = buffer1.transcoded in = transcoded_in + buffersize(buffer2) diff --git a/test/codecdoubleframe.jl b/test/codecdoubleframe.jl index f3b5076d..1ebeee02 100644 --- a/test/codecdoubleframe.jl +++ b/test/codecdoubleframe.jl @@ -272,6 +272,7 @@ DoubleFrameDecoderStream(stream::IO; kwargs...) = TranscodingStream(DoubleFrameD ) )) @test read(s1) == b"" + @test position(s1) == 0 @test eof(s1) s2 = NoopStream( @@ -281,6 +282,7 @@ DoubleFrameDecoderStream(stream::IO; kwargs...) = TranscodingStream(DoubleFrameD ) ) @test read(s2) == b"" + @test position(s1) == 0 @test eof(s2) end diff --git a/test/codecnoop.jl b/test/codecnoop.jl index 3583cbdc..8b988749 100644 --- a/test/codecnoop.jl +++ b/test/codecnoop.jl @@ -50,11 +50,11 @@ close(stream) stream = TranscodingStream(Noop(), IOBuffer(b"foobarbaz")) - @test position(stream) === 0 + @test position(stream) === Int64(0) read(stream, UInt8) - @test position(stream) === 1 + @test position(stream) === Int64(1) read(stream) - @test position(stream) === 9 + @test position(stream) === Int64(9) data = collect(0x00:0x0f) stream = TranscodingStream(Noop(), IOBuffer(data)) @@ -368,6 +368,41 @@ @test position(stream) == pos end end + + @testset "writing nested NoopStream sharedbuf=$(sharedbuf)" for sharedbuf in (true, false) + stream = NoopStream(NoopStream(IOBuffer()); sharedbuf, bufsize=4) + @test position(stream) == 0 + write(stream, 0x01) + @test position(stream) == 1 + flush(stream) + @test position(stream) == 1 + write(stream, "abc") + @test position(stream) == 4 + flush(stream) + @test position(stream) == 4 + for i in 1:10 + write(stream, 0x01) + @test position(stream) == 4 + i + end + end + + @testset "reading nested NoopStream sharedbuf=$(sharedbuf)" for sharedbuf in (true, false) + stream = NoopStream(NoopStream(IOBuffer("abcdefghijk")); sharedbuf, bufsize=4) + @test position(stream) == 0 + @test !eof(stream) + @test position(stream) == 0 + @test read(stream, UInt8) == b"a"[1] + @test position(stream) == 1 + @test read(stream, 3) == b"bcd" + @test position(stream) == 4 + @test !eof(stream) + @test position(stream) == 4 + @test read(stream) == b"efghijk" + @test position(stream) == 11 + @test eof(stream) + @test position(stream) == 11 + end + end @testset "seek doesn't delete data" begin diff --git a/test/codecquadruple.jl b/test/codecquadruple.jl index 1e17db94..31a0ef53 100644 --- a/test/codecquadruple.jl +++ b/test/codecquadruple.jl @@ -51,19 +51,19 @@ end close(stream) stream = TranscodingStream(QuadrupleCodec(), IOBuffer("foo")) - @test position(stream) === 0 + @test position(stream) === Int64(0) read(stream, 3) - @test position(stream) === 3 + @test position(stream) === Int64(3) read(stream, UInt8) - @test position(stream) === 4 + @test position(stream) === Int64(4) close(stream) stream = TranscodingStream(QuadrupleCodec(), IOBuffer()) - @test position(stream) === 0 + @test position(stream) === Int64(0) write(stream, 0x00) - @test position(stream) === 1 + @test position(stream) === Int64(1) write(stream, "foo") - @test position(stream) === 4 + @test position(stream) === Int64(4) close(stream) # Buffers are shared.