From cfadb8c7d8528fe432a13df2ae44e1ec802f5214 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Mon, 27 Jan 2025 16:58:35 +0000
Subject: [PATCH 1/2] "Remove multi-stream related code"

---
 .../decoders/_core/VideoDecoder.cpp          | 136 ++++++------------
 src/torchcodec/decoders/_core/VideoDecoder.h |   7 +-
 2 files changed, 47 insertions(+), 96 deletions(-)

diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp
index 8dcb1bb4..ef567957 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -449,11 +449,7 @@ int VideoDecoder::getBestStreamIndex(AVMediaType mediaType) {
 void VideoDecoder::addVideoStreamDecoder(
     int preferredStreamIndex,
     const VideoStreamOptions& videoStreamOptions) {
-  if (activeStreamIndices_.count(preferredStreamIndex) > 0) {
-    throw std::invalid_argument(
-        "Stream with index " + std::to_string(preferredStreamIndex) +
-        " is already active.");
-  }
+  TORCH_CHECK(activeStreamIndex_ == -1, "Can only add one single stream.");
   TORCH_CHECK(formatContext_.get() != nullptr);

   AVCodecOnlyUseForCallingAVFindBestStream avCodec = nullptr;
@@ -520,7 +516,7 @@ void VideoDecoder::addVideoStreamDecoder(
   }
   codecContext->time_base = streamInfo.stream->time_base;

-  activeStreamIndices_.insert(streamIndex);
+  activeStreamIndex_ = streamIndex;
   updateMetadataWithCodecContext(streamInfo.streamIndex, codecContext);
   streamInfo.videoStreamOptions = videoStreamOptions;
 }
@@ -740,53 +736,39 @@ bool VideoDecoder::canWeAvoidSeekingForStream(
 // AVFormatContext if it is needed. We can skip seeking in certain cases. See
 // the comment of canWeAvoidSeeking() for details.
 void VideoDecoder::maybeSeekToBeforeDesiredPts() {
-  if (activeStreamIndices_.size() == 0) {
+  if (activeStreamIndex_ == -1) {
     return;
   }
-  for (int streamIndex : activeStreamIndices_) {
-    StreamInfo& streamInfo = streamInfos_[streamIndex];
-    // clang-format off: clang format clashes
-    streamInfo.discardFramesBeforePts = secondsToClosestPts(*desiredPtsSeconds_, streamInfo.timeBase);
-    // clang-format on
-  }
+  StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
+  streamInfo.discardFramesBeforePts =
+      secondsToClosestPts(*desiredPtsSeconds_, streamInfo.timeBase);

   decodeStats_.numSeeksAttempted++;
-  // See comment for canWeAvoidSeeking() for details on why this optimization
-  // works.
-  bool mustSeek = false;
-  for (int streamIndex : activeStreamIndices_) {
-    StreamInfo& streamInfo = streamInfos_[streamIndex];
-    int64_t desiredPtsForStream = *desiredPtsSeconds_ * streamInfo.timeBase.den;
-    if (!canWeAvoidSeekingForStream(
-            streamInfo, streamInfo.currentPts, desiredPtsForStream)) {
-      mustSeek = true;
-      break;
-    }
-  }
-  if (!mustSeek) {
+
+  int64_t desiredPtsForStream = *desiredPtsSeconds_ * streamInfo.timeBase.den;
+  if (canWeAvoidSeekingForStream(
+          streamInfo, streamInfo.currentPts, desiredPtsForStream)) {
     decodeStats_.numSeeksSkipped++;
     return;
   }
-  int firstActiveStreamIndex = *activeStreamIndices_.begin();
-  const auto& firstStreamInfo = streamInfos_[firstActiveStreamIndex];
   int64_t desiredPts =
-      secondsToClosestPts(*desiredPtsSeconds_, firstStreamInfo.timeBase);
+      secondsToClosestPts(*desiredPtsSeconds_, streamInfo.timeBase);
   // For some encodings like H265, FFMPEG sometimes seeks past the point we
   // set as the max_ts. So we use our own index to give it the exact pts of
   // the key frame that we want to seek to.
   // See https://github.com/pytorch/torchcodec/issues/179 for more details.
   // See https://trac.ffmpeg.org/ticket/11137 for the underlying ffmpeg bug.
-  if (!firstStreamInfo.keyFrames.empty()) {
+  if (!streamInfo.keyFrames.empty()) {
     int desiredKeyFrameIndex = getKeyFrameIndexForPtsUsingScannedIndex(
-        firstStreamInfo.keyFrames, desiredPts);
+        streamInfo.keyFrames, desiredPts);
     desiredKeyFrameIndex = std::max(desiredKeyFrameIndex, 0);
-    desiredPts = firstStreamInfo.keyFrames[desiredKeyFrameIndex].pts;
+    desiredPts = streamInfo.keyFrames[desiredKeyFrameIndex].pts;
   }

   int ffmepgStatus = avformat_seek_file(
       formatContext_.get(),
-      firstStreamInfo.streamIndex,
+      streamInfo.streamIndex,
       INT64_MIN,
       desiredPts,
       desiredPts,
       0);
@@ -797,15 +779,12 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
         getFFMPEGErrorStringFromErrorCode(ffmepgStatus));
   }
   decodeStats_.numFlushes++;
-  for (int streamIndex : activeStreamIndices_) {
-    StreamInfo& streamInfo = streamInfos_[streamIndex];
-    avcodec_flush_buffers(streamInfo.codecContext.get());
-  }
+  avcodec_flush_buffers(streamInfo.codecContext.get());
 }

 VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
-    std::function<bool(int, AVFrame*)> filterFunction) {
-  if (activeStreamIndices_.size() == 0) {
+    std::function<bool(AVFrame*)> filterFunction) {
+  if (activeStreamIndex_ == -1) {
     throw std::runtime_error("No active streams configured.");
   }

@@ -817,44 +796,25 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
     desiredPtsSeconds_ = std::nullopt;
   }

+  StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
+
   // Need to get the next frame or error from PopFrame.
   UniqueAVFrame avFrame(av_frame_alloc());
   AutoAVPacket autoAVPacket;
   int ffmpegStatus = AVSUCCESS;
   bool reachedEOF = false;
-  int frameStreamIndex = -1;
   while (true) {
-    frameStreamIndex = -1;
-    bool gotPermanentErrorOnAnyActiveStream = false;
-
-    // Get a frame on an active stream. Note that we don't know ahead of time
-    // which streams have frames to receive, so we linearly try the active
-    // streams.
-    for (int streamIndex : activeStreamIndices_) {
-      StreamInfo& streamInfo = streamInfos_[streamIndex];
-      ffmpegStatus =
-          avcodec_receive_frame(streamInfo.codecContext.get(), avFrame.get());
-
-      if (ffmpegStatus != AVSUCCESS && ffmpegStatus != AVERROR(EAGAIN)) {
-        gotPermanentErrorOnAnyActiveStream = true;
-        break;
-      }
+    ffmpegStatus =
+        avcodec_receive_frame(streamInfo.codecContext.get(), avFrame.get());

-      if (ffmpegStatus == AVSUCCESS) {
-        frameStreamIndex = streamIndex;
-        break;
-      }
-    }
-
-    if (gotPermanentErrorOnAnyActiveStream) {
+    if (ffmpegStatus != AVSUCCESS && ffmpegStatus != AVERROR(EAGAIN)) {
+      // Non-retriable error
       break;
     }

     decodeStats_.numFramesReceivedByDecoder++;

-    // Is this the kind of frame we're looking for?
-    if (ffmpegStatus == AVSUCCESS &&
-        filterFunction(frameStreamIndex, avFrame.get())) {
+    if (ffmpegStatus == AVSUCCESS && filterFunction(avFrame.get())) {
       // Yes, this is the frame we'll return; break out of the decoding loop.
       break;
     } else if (ffmpegStatus == AVSUCCESS) {
@@ -879,18 +839,15 @@
     decodeStats_.numPacketsRead++;

     if (ffmpegStatus == AVERROR_EOF) {
-      // End of file reached. We must drain all codecs by sending a nullptr
+      // End of file reached. We must drain the codec by sending a nullptr
       // packet.
-      for (int streamIndex : activeStreamIndices_) {
-        StreamInfo& streamInfo = streamInfos_[streamIndex];
-        ffmpegStatus = avcodec_send_packet(
-            streamInfo.codecContext.get(),
-            /*avpkt=*/nullptr);
-        if (ffmpegStatus < AVSUCCESS) {
-          throw std::runtime_error(
-              "Could not flush decoder: " +
-              getFFMPEGErrorStringFromErrorCode(ffmpegStatus));
-        }
+      ffmpegStatus = avcodec_send_packet(
+          streamInfo.codecContext.get(),
+          /*avpkt=*/nullptr);
+      if (ffmpegStatus < AVSUCCESS) {
+        throw std::runtime_error(
+            "Could not flush decoder: " +
+            getFFMPEGErrorStringFromErrorCode(ffmpegStatus));
       }

       // We've reached the end of file so we can't read any more packets from
@@ -906,15 +863,14 @@
           getFFMPEGErrorStringFromErrorCode(ffmpegStatus));
     }

-    if (activeStreamIndices_.count(packet->stream_index) == 0) {
-      // This packet is not for any of the active streams.
+    if (packet->stream_index != activeStreamIndex_) {
       continue;
     }

     // We got a valid packet. Send it to the decoder, and we'll receive it in
     // the next iteration.
-    ffmpegStatus = avcodec_send_packet(
-        streamInfos_[packet->stream_index].codecContext.get(), packet.get());
+    ffmpegStatus =
+        avcodec_send_packet(streamInfo.codecContext.get(), packet.get());
     if (ffmpegStatus < AVSUCCESS) {
       throw std::runtime_error(
           "Could not push packet to decoder: " +
@@ -941,11 +897,10 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
   // haven't received as frames. Eventually we will either hit AVERROR_EOF from
   // av_receive_frame() or the user will have seeked to a different location in
   // the file and that will flush the decoder.
-  StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex];
-  activeStreamInfo.currentPts = avFrame->pts;
-  activeStreamInfo.currentDuration = getDuration(avFrame);
+  streamInfo.currentPts = avFrame->pts;
+  streamInfo.currentDuration = getDuration(avFrame);

-  return AVFrameStream(std::move(avFrame), frameStreamIndex);
+  return AVFrameStream(std::move(avFrame), activeStreamIndex_);
 }

 VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
@@ -1110,8 +1065,8 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAtNoDemux(

   setCursorPtsInSeconds(seconds);
   AVFrameStream avFrameStream =
-      decodeAVFrame([seconds, this](int frameStreamIndex, AVFrame* avFrame) {
-        StreamInfo& streamInfo = streamInfos_[frameStreamIndex];
+      decodeAVFrame([seconds, this](AVFrame* avFrame) {
+        StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
         double frameStartTime = ptsToSeconds(avFrame->pts, streamInfo.timeBase);
         double frameEndTime = ptsToSeconds(
             avFrame->pts + getDuration(avFrame), streamInfo.timeBase);
@@ -1510,11 +1465,10 @@ VideoDecoder::FrameOutput VideoDecoder::getNextFrameNoDemux() {

 VideoDecoder::FrameOutput VideoDecoder::getNextFrameNoDemuxInternal(
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
-  AVFrameStream avFrameStream =
-      decodeAVFrame([this](int frameStreamIndex, AVFrame* avFrame) {
-        StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex];
-        return avFrame->pts >= activeStreamInfo.discardFramesBeforePts;
-      });
+  AVFrameStream avFrameStream = decodeAVFrame([this](AVFrame* avFrame) {
+    StreamInfo& activeStreamInfo = streamInfos_[activeStreamIndex_];
+    return avFrame->pts >= activeStreamInfo.discardFramesBeforePts;
+  });
   return convertAVFrameToFrameOutput(avFrameStream, preAllocatedOutputTensor);
 }

diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h
index 8261634d..5d146938 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.h
+++ b/src/torchcodec/decoders/_core/VideoDecoder.h
@@ -404,8 +404,7 @@ class VideoDecoder {
       const enum AVColorSpace colorspace);

   void maybeSeekToBeforeDesiredPts();
-  AVFrameStream decodeAVFrame(
-      std::function<bool(int, AVFrame*)> filterFunction);
+  AVFrameStream decodeAVFrame(std::function<bool(AVFrame*)> filterFunction);
   // Once we create a decoder can update the metadata with the codec context.
   // For example, for video streams, we can add the height and width of the
   // decoded stream.
@@ -435,9 +434,7 @@ class VideoDecoder {
   ContainerMetadata containerMetadata_;
   UniqueAVFormatContext formatContext_;
   std::map<int, StreamInfo> streamInfos_;
-  // Stores the stream indices of the active streams, i.e. the streams we are
-  // decoding and returning to the user.
-  std::set<int> activeStreamIndices_;
+  int activeStreamIndex_ = -1;
   // Set when the user wants to seek and stores the desired pts that the user
   // wants to seek to.
   std::optional<double> desiredPtsSeconds_;

From f352cfc4987283fb8e37f1a71988456352c88040 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Wed, 29 Jan 2025 16:01:05 +0000
Subject: [PATCH 2/2] Use NO_ACTIVE_STREAM

---
 src/torchcodec/decoders/_core/VideoDecoder.cpp | 8 +++++---
 src/torchcodec/decoders/_core/VideoDecoder.h   | 3 ++-
 2 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp
index f70a38d2..778a1b3e 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -435,7 +435,9 @@ int VideoDecoder::getBestStreamIndex(AVMediaType mediaType) {
 void VideoDecoder::addVideoStreamDecoder(
     int preferredStreamIndex,
     const VideoStreamOptions& videoStreamOptions) {
-  TORCH_CHECK(activeStreamIndex_ == -1, "Can only add one single stream.");
+  TORCH_CHECK(
+      activeStreamIndex_ == NO_ACTIVE_STREAM,
+      "Can only add one single stream.");
   TORCH_CHECK(formatContext_.get() != nullptr);

   AVCodecOnlyUseForCallingAVFindBestStream avCodec = nullptr;
@@ -722,7 +724,7 @@ bool VideoDecoder::canWeAvoidSeekingForStream(
 // AVFormatContext if it is needed. We can skip seeking in certain cases. See
 // the comment of canWeAvoidSeeking() for details.
 void VideoDecoder::maybeSeekToBeforeDesiredPts() {
-  if (activeStreamIndex_ == -1) {
+  if (activeStreamIndex_ == NO_ACTIVE_STREAM) {
     return;
   }
   StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
@@ -770,7 +772,7 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {

 VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
     std::function<bool(AVFrame*)> filterFunction) {
-  if (activeStreamIndex_ == -1) {
+  if (activeStreamIndex_ == NO_ACTIVE_STREAM) {
     throw std::runtime_error("No active streams configured.");
   }

diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h
index 4b80816e..696b2fa2 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.h
+++ b/src/torchcodec/decoders/_core/VideoDecoder.h
@@ -468,7 +468,8 @@ class VideoDecoder {
   ContainerMetadata containerMetadata_;
   UniqueAVFormatContext formatContext_;
   std::map<int, StreamInfo> streamInfos_;
-  int activeStreamIndex_ = -1;
+  const int NO_ACTIVE_STREAM = -2;
+  int activeStreamIndex_ = NO_ACTIVE_STREAM;
   // Set when the user wants to seek and stores the desired pts that the user
   // wants to seek to.
   std::optional<double> desiredPtsSeconds_;