Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove multi-stream related code #483

Merged
merged 4 commits into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 45 additions & 91 deletions src/torchcodec/decoders/_core/VideoDecoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -449,11 +449,7 @@ int VideoDecoder::getBestStreamIndex(AVMediaType mediaType) {
void VideoDecoder::addVideoStreamDecoder(
int preferredStreamIndex,
const VideoStreamOptions& videoStreamOptions) {
if (activeStreamIndices_.count(preferredStreamIndex) > 0) {
throw std::invalid_argument(
"Stream with index " + std::to_string(preferredStreamIndex) +
" is already active.");
}
TORCH_CHECK(activeStreamIndex_ == -1, "Can only add one single stream.");
TORCH_CHECK(formatContext_.get() != nullptr);

AVCodecOnlyUseForCallingAVFindBestStream avCodec = nullptr;
Expand Down Expand Up @@ -520,7 +516,7 @@ void VideoDecoder::addVideoStreamDecoder(
}

codecContext->time_base = streamInfo.stream->time_base;
activeStreamIndices_.insert(streamIndex);
activeStreamIndex_ = streamIndex;
updateMetadataWithCodecContext(streamInfo.streamIndex, codecContext);
streamInfo.videoStreamOptions = videoStreamOptions;

Expand Down Expand Up @@ -740,53 +736,39 @@ bool VideoDecoder::canWeAvoidSeekingForStream(
// AVFormatContext if it is needed. We can skip seeking in certain cases. See
// the comment of canWeAvoidSeekingForStream() for details.
void VideoDecoder::maybeSeekToBeforeDesiredPts() {
if (activeStreamIndices_.size() == 0) {
if (activeStreamIndex_ == -1) {
return;
}
for (int streamIndex : activeStreamIndices_) {
StreamInfo& streamInfo = streamInfos_[streamIndex];
// clang-format off: clang format clashes
streamInfo.discardFramesBeforePts = secondsToClosestPts(*desiredPtsSeconds_, streamInfo.timeBase);
// clang-format on
}
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
streamInfo.discardFramesBeforePts =
secondsToClosestPts(*desiredPtsSeconds_, streamInfo.timeBase);

decodeStats_.numSeeksAttempted++;
// See comment for canWeAvoidSeekingForStream() for details on why this
// optimization works.
bool mustSeek = false;
for (int streamIndex : activeStreamIndices_) {
StreamInfo& streamInfo = streamInfos_[streamIndex];
int64_t desiredPtsForStream = *desiredPtsSeconds_ * streamInfo.timeBase.den;
if (!canWeAvoidSeekingForStream(
streamInfo, streamInfo.currentPts, desiredPtsForStream)) {
mustSeek = true;
break;
}
}
if (!mustSeek) {

int64_t desiredPtsForStream = *desiredPtsSeconds_ * streamInfo.timeBase.den;
if (canWeAvoidSeekingForStream(
streamInfo, streamInfo.currentPts, desiredPtsForStream)) {
decodeStats_.numSeeksSkipped++;
return;
}
int firstActiveStreamIndex = *activeStreamIndices_.begin();
const auto& firstStreamInfo = streamInfos_[firstActiveStreamIndex];
int64_t desiredPts =
secondsToClosestPts(*desiredPtsSeconds_, firstStreamInfo.timeBase);
secondsToClosestPts(*desiredPtsSeconds_, streamInfo.timeBase);

// For some encodings like H265, FFMPEG sometimes seeks past the point we
// set as the max_ts. So we use our own index to give it the exact pts of
// the key frame that we want to seek to.
// See https://github.com/pytorch/torchcodec/issues/179 for more details.
// See https://trac.ffmpeg.org/ticket/11137 for the underlying ffmpeg bug.
if (!firstStreamInfo.keyFrames.empty()) {
if (!streamInfo.keyFrames.empty()) {
int desiredKeyFrameIndex = getKeyFrameIndexForPtsUsingScannedIndex(
firstStreamInfo.keyFrames, desiredPts);
streamInfo.keyFrames, desiredPts);
desiredKeyFrameIndex = std::max(desiredKeyFrameIndex, 0);
desiredPts = firstStreamInfo.keyFrames[desiredKeyFrameIndex].pts;
desiredPts = streamInfo.keyFrames[desiredKeyFrameIndex].pts;
}

int ffmepgStatus = avformat_seek_file(
formatContext_.get(),
firstStreamInfo.streamIndex,
streamInfo.streamIndex,
INT64_MIN,
desiredPts,
desiredPts,
Expand All @@ -797,15 +779,12 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
getFFMPEGErrorStringFromErrorCode(ffmepgStatus));
}
decodeStats_.numFlushes++;
for (int streamIndex : activeStreamIndices_) {
StreamInfo& streamInfo = streamInfos_[streamIndex];
avcodec_flush_buffers(streamInfo.codecContext.get());
}
avcodec_flush_buffers(streamInfo.codecContext.get());
}

VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
std::function<bool(int, AVFrame*)> filterFunction) {
if (activeStreamIndices_.size() == 0) {
std::function<bool(AVFrame*)> filterFunction) {
if (activeStreamIndex_ == -1) {
throw std::runtime_error("No active streams configured.");
}

Expand All @@ -817,44 +796,25 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
desiredPtsSeconds_ = std::nullopt;
}

StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];

// Need to get the next frame or error from PopFrame.
UniqueAVFrame avFrame(av_frame_alloc());
AutoAVPacket autoAVPacket;
int ffmpegStatus = AVSUCCESS;
bool reachedEOF = false;
int frameStreamIndex = -1;
while (true) {
frameStreamIndex = -1;
bool gotPermanentErrorOnAnyActiveStream = false;

// Get a frame on an active stream. Note that we don't know ahead of time
// which streams have frames to receive, so we linearly try the active
// streams.
for (int streamIndex : activeStreamIndices_) {
StreamInfo& streamInfo = streamInfos_[streamIndex];
ffmpegStatus =
avcodec_receive_frame(streamInfo.codecContext.get(), avFrame.get());

if (ffmpegStatus != AVSUCCESS && ffmpegStatus != AVERROR(EAGAIN)) {
gotPermanentErrorOnAnyActiveStream = true;
break;
}
ffmpegStatus =
avcodec_receive_frame(streamInfo.codecContext.get(), avFrame.get());

if (ffmpegStatus == AVSUCCESS) {
frameStreamIndex = streamIndex;
break;
}
}

if (gotPermanentErrorOnAnyActiveStream) {
if (ffmpegStatus != AVSUCCESS && ffmpegStatus != AVERROR(EAGAIN)) {
// Non-retriable error
break;
}

decodeStats_.numFramesReceivedByDecoder++;

// Is this the kind of frame we're looking for?
if (ffmpegStatus == AVSUCCESS &&
filterFunction(frameStreamIndex, avFrame.get())) {
if (ffmpegStatus == AVSUCCESS && filterFunction(avFrame.get())) {
// Yes, this is the frame we'll return; break out of the decoding loop.
break;
} else if (ffmpegStatus == AVSUCCESS) {
Expand All @@ -879,18 +839,15 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
decodeStats_.numPacketsRead++;

if (ffmpegStatus == AVERROR_EOF) {
// End of file reached. We must drain all codecs by sending a nullptr
// End of file reached. We must drain the codec by sending a nullptr
// packet.
for (int streamIndex : activeStreamIndices_) {
StreamInfo& streamInfo = streamInfos_[streamIndex];
ffmpegStatus = avcodec_send_packet(
streamInfo.codecContext.get(),
/*avpkt=*/nullptr);
if (ffmpegStatus < AVSUCCESS) {
throw std::runtime_error(
"Could not flush decoder: " +
getFFMPEGErrorStringFromErrorCode(ffmpegStatus));
}
ffmpegStatus = avcodec_send_packet(
streamInfo.codecContext.get(),
/*avpkt=*/nullptr);
if (ffmpegStatus < AVSUCCESS) {
throw std::runtime_error(
"Could not flush decoder: " +
getFFMPEGErrorStringFromErrorCode(ffmpegStatus));
}

// We've reached the end of file so we can't read any more packets from
Expand All @@ -906,15 +863,14 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
getFFMPEGErrorStringFromErrorCode(ffmpegStatus));
}

if (activeStreamIndices_.count(packet->stream_index) == 0) {
// This packet is not for any of the active streams.
if (packet->stream_index != activeStreamIndex_) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW, this is the "demux" part now. And I think we should be calling av_read_frame within its own while loop, for as long as the received packet doesn't belong to the target stream.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, what we're currently doing is inefficient, because if the packet is not from the right stream, we first have to make a call to avcodec_receive_frame() that we know will fail. We're still doing the correct thing (I think); it's just inefficient.

continue;
}

// We got a valid packet. Send it to the decoder, and we'll receive it in
// the next iteration.
ffmpegStatus = avcodec_send_packet(
streamInfos_[packet->stream_index].codecContext.get(), packet.get());
ffmpegStatus =
avcodec_send_packet(streamInfo.codecContext.get(), packet.get());
if (ffmpegStatus < AVSUCCESS) {
throw std::runtime_error(
"Could not push packet to decoder: " +
Expand All @@ -941,11 +897,10 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
// haven't received as frames. Eventually we will either hit AVERROR_EOF from
// avcodec_receive_frame() or the user will have seeked to a different
// location in the file and that will flush the decoder.
StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex];
activeStreamInfo.currentPts = avFrame->pts;
activeStreamInfo.currentDuration = getDuration(avFrame);
streamInfo.currentPts = avFrame->pts;
streamInfo.currentDuration = getDuration(avFrame);

return AVFrameStream(std::move(avFrame), frameStreamIndex);
return AVFrameStream(std::move(avFrame), activeStreamIndex_);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't really need this AVFrameStream struct anymore, since we know the streamIndex is always activeStreamIndex_. We could remove it, but I guess this is better done in another PR.

}

VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
Expand Down Expand Up @@ -1110,8 +1065,8 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAtNoDemux(

setCursorPtsInSeconds(seconds);
AVFrameStream avFrameStream =
decodeAVFrame([seconds, this](int frameStreamIndex, AVFrame* avFrame) {
StreamInfo& streamInfo = streamInfos_[frameStreamIndex];
decodeAVFrame([seconds, this](AVFrame* avFrame) {
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
double frameStartTime = ptsToSeconds(avFrame->pts, streamInfo.timeBase);
double frameEndTime = ptsToSeconds(
avFrame->pts + getDuration(avFrame), streamInfo.timeBase);
Expand Down Expand Up @@ -1510,11 +1465,10 @@ VideoDecoder::FrameOutput VideoDecoder::getNextFrameNoDemux() {

VideoDecoder::FrameOutput VideoDecoder::getNextFrameNoDemuxInternal(
std::optional<torch::Tensor> preAllocatedOutputTensor) {
AVFrameStream avFrameStream =
decodeAVFrame([this](int frameStreamIndex, AVFrame* avFrame) {
StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex];
return avFrame->pts >= activeStreamInfo.discardFramesBeforePts;
});
AVFrameStream avFrameStream = decodeAVFrame([this](AVFrame* avFrame) {
StreamInfo& activeStreamInfo = streamInfos_[activeStreamIndex_];
return avFrame->pts >= activeStreamInfo.discardFramesBeforePts;
});
return convertAVFrameToFrameOutput(avFrameStream, preAllocatedOutputTensor);
}

Expand Down
7 changes: 2 additions & 5 deletions src/torchcodec/decoders/_core/VideoDecoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -404,8 +404,7 @@ class VideoDecoder {
const enum AVColorSpace colorspace);

void maybeSeekToBeforeDesiredPts();
AVFrameStream decodeAVFrame(
std::function<bool(int, AVFrame*)> filterFunction);
AVFrameStream decodeAVFrame(std::function<bool(AVFrame*)> filterFunction);
// Once we create a decoder we can update the metadata with the codec context.
// For example, for video streams, we can add the height and width of the
// decoded stream.
Expand Down Expand Up @@ -435,9 +434,7 @@ class VideoDecoder {
ContainerMetadata containerMetadata_;
UniqueAVFormatContext formatContext_;
std::map<int, StreamInfo> streamInfos_;
// Stores the stream indices of the active streams, i.e. the streams we are
// decoding and returning to the user.
std::set<int> activeStreamIndices_;
int activeStreamIndex_ = -1;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer to both make this a named constant, and to use something other than -1. Perhaps:

const int NO_ACTIVE_STREAM = -2;

Obviously 0 and above are potential active stream numbers, so they can't be used as such an indicator. But I'm worried that -1 can get confusing, because that is used to ask FFmpeg to find the best stream in some API calls. So I'd like to also treat -1 as a special stream value.

// Set when the user wants to seek and stores the desired pts that the user
// wants to seek to.
std::optional<double> desiredPtsSeconds_;
Expand Down
Loading