Skip to content

Commit

Permalink
Implement delivery callback for HQStreamTransport
Browse files Browse the repository at this point in the history
Summary:
I'm modifying `HQSession::HQStreamTransport::sendWebTransportStreamData` so that delivery callbacks are conveyed to the passed-in `DeliveryCallback*` if it is non-`nullptr`.

I'm also adding some state to keep track of the preface size written to the beginning of each stream, so that we can subtract it off when we fire the delivery callback.

Reviewed By: hanidamlaj

Differential Revision: D69619154

fbshipit-source-id: e9f4f4444ea3fe555fd3af6be9d47d3a779b0f81
  • Loading branch information
Aman Sharma authored and facebook-github-bot committed Feb 20, 2025
1 parent 1ff5e10 commit 81eb4cb
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 17 deletions.
46 changes: 32 additions & 14 deletions proxygen/lib/http/session/HQSession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
*/

#include <proxygen/lib/http/session/HQSession.h>
#include <proxygen/lib/http/webtransport/QuicWtDeliveryCallbackWrapper.h>

#include <proxygen/lib/http/HTTPPriorityFunctions.h>
#include <proxygen/lib/http/codec/HQControlCodec.h>
Expand Down Expand Up @@ -91,7 +92,8 @@ quic::Priority toQuicPriority(const proxygen::HTTPPriority& pri) {
return quic::Priority(pri.urgency, pri.incremental, pri.orderId);
}

bool writeWTStreamPrefaceToSock(
// Returns the number of bytes written. 0 if error.
uint32_t writeWTStreamPrefaceToSock(
quic::QuicSocket& sock,
quic::StreamId wtStreamId,
quic::StreamId wtSessionId,
Expand All @@ -101,14 +103,14 @@ bool writeWTStreamPrefaceToSock(
proxygen::hq::writeWTStreamPreface(writeBuf, streamType, wtSessionId);
if (!res) {
LOG(ERROR) << "Failed to write WT stream preface";
return false;
return 0;
}
auto writeRes = sock.writeChain(wtStreamId, writeBuf.move(), false);
if (writeRes.hasError()) {
LOG(ERROR) << "Failed to write stream preface to socket";
return false;
return 0;
}
return true;
return res.value();
}
} // namespace

Expand Down Expand Up @@ -3828,15 +3830,18 @@ HQSession::HQStreamTransport::newWebTransportBidiStream() {
return folly::makeUnexpected(
WebTransport::ErrorCode::STREAM_CREATION_ERROR);
}
if (!writeWTStreamPrefaceToSock(*session_.sock_,
*id,
getEgressStreamId(),
hq::WebTransportStreamType::BIDI)) {
auto numPrefaceBytesWritten =
writeWTStreamPrefaceToSock(*session_.sock_,
*id,
getEgressStreamId(),
hq::WebTransportStreamType::BIDI);
if (numPrefaceBytesWritten == 0) {
LOG(ERROR) << "Failed to write bidirectional stream preface";
// TODO: resetStream/stopSending?
return folly::makeUnexpected(
WebTransport::ErrorCode::STREAM_CREATION_ERROR);
}
streamIdToPrefaceSize_[*id] = numPrefaceBytesWritten;
return *id;
}

Expand All @@ -3849,14 +3854,17 @@ HQSession::HQStreamTransport::newWebTransportUniStream() {
return folly::makeUnexpected(
WebTransport::ErrorCode::STREAM_CREATION_ERROR);
}
if (!writeWTStreamPrefaceToSock(*session_.sock_,
*id,
getEgressStreamId(),
hq::WebTransportStreamType::UNI)) {
auto numPrefaceBytesWritten =
writeWTStreamPrefaceToSock(*session_.sock_,
*id,
getEgressStreamId(),
hq::WebTransportStreamType::UNI);
if (numPrefaceBytesWritten == 0) {
LOG(ERROR) << "Failed to write unidirectional stream preface";
return folly::makeUnexpected(
WebTransport::ErrorCode::STREAM_CREATION_ERROR);
}
streamIdToPrefaceSize_[*id] = numPrefaceBytesWritten;
return *id;
}

Expand All @@ -3865,8 +3873,17 @@ HQSession::HQStreamTransport::sendWebTransportStreamData(
HTTPCodec::StreamID id,
std::unique_ptr<folly::IOBuf> data,
bool eof,
WebTransport::DeliveryCallback* /* deliveryCallback */) {
auto res = session_.sock_->writeChain(id, std::move(data), eof);
WebTransportImpl::DeliveryCallback* deliveryCallback) {
std::unique_ptr<QuicWtDeliveryCallbackWrapper> deliveryCallbackWrapper =
nullptr;
if (deliveryCallback) {
uint32_t prefaceSize =
streamIdToPrefaceSize_.contains(id) ? streamIdToPrefaceSize_[id] : 0;
deliveryCallbackWrapper = std::make_unique<QuicWtDeliveryCallbackWrapper>(
deliveryCallback, prefaceSize);
}
auto res = session_.sock_->writeChain(
id, std::move(data), eof, deliveryCallbackWrapper.release());
if (res.hasError()) {
LOG(ERROR) << "Failed to write WT stream data";
return folly::makeUnexpected(WebTransport::ErrorCode::SEND_ERROR);
Expand Down Expand Up @@ -3895,6 +3912,7 @@ HQSession::HQStreamTransport::notifyPendingWriteOnStream(
folly::Expected<folly::Unit, WebTransport::ErrorCode>
HQSession::HQStreamTransport::resetWebTransportEgress(HTTPCodec::StreamID id,
uint32_t errorCode) {
streamIdToPrefaceSize_.erase(id);
if (session_.sock_) {
auto res = session_.sock_->resetStream(
id,
Expand Down
7 changes: 6 additions & 1 deletion proxygen/lib/http/session/HQSession.h
Original file line number Diff line number Diff line change
Expand Up @@ -1837,7 +1837,7 @@ class HQSession
HTTPCodec::StreamID /*id*/,
std::unique_ptr<folly::IOBuf> /*data*/,
bool /*eof*/,
WebTransport::DeliveryCallback* /* deliveryCallback */) override;
WebTransportImpl::DeliveryCallback* /* deliveryCallback */) override;

folly::Expected<folly::Unit, WebTransport::ErrorCode>
notifyPendingWriteOnStream(HTTPCodec::StreamID,
Expand Down Expand Up @@ -1884,6 +1884,11 @@ class HQSession
HTTPCodec::StreamID /*id*/,
folly::Optional<uint32_t> /*errorCode*/) override;

private:
// We keep track of this so that we know how many bytes we need to subtract
// off when we fire delivery callbacks.
folly::F14FastMap<HTTPCodec::StreamID, uint32_t> streamIdToPrefaceSize_;

}; // HQStreamTransport

#ifdef _MSC_VER
Expand Down
16 changes: 14 additions & 2 deletions proxygen/lib/http/session/test/HQUpstreamSessionTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2245,19 +2245,31 @@ class HQUpstreamSessionTestWebTransport : public HQUpstreamSessionTest {
WebTransport* wt_{nullptr};
};

class MockDeliveryCallback : public WebTransport::DeliveryCallback {
public:
MOCK_METHOD(void, onDelivery, (uint64_t, uint32_t), (noexcept));

MOCK_METHOD(void, onDeliveryCancelled, (uint64_t, uint32_t), (noexcept));
};

TEST_P(HQUpstreamSessionTestWebTransport, BidirectionalStream) {
InSequence enforceOrder;
// Create a bidi WT stream
auto stream = wt_->createBidiStream().value();
auto id = stream.readHandle->getID();
// small write
stream.writeHandle->writeStreamData(makeBuf(10), false, nullptr);
auto mockCallback1 = std::make_unique<StrictMock<MockDeliveryCallback>>();
EXPECT_CALL(*mockCallback1, onDelivery(id, 10)).Times(1);
stream.writeHandle->writeStreamData(makeBuf(10), false, mockCallback1.get());
eventBase_.loopOnce();

// shrink the fcw to force it to block
socketDriver_->setStreamFlowControlWindow(id, 100);
bool writeComplete = false;
stream.writeHandle->writeStreamData(makeBuf(65536), false, nullptr);
auto mockCallback2 = std::make_unique<StrictMock<MockDeliveryCallback>>();
EXPECT_CALL(*mockCallback2, onDelivery(id, 65536 + 10)).Times(1);
stream.writeHandle->writeStreamData(
makeBuf(65536), false, mockCallback2.get());
stream.writeHandle->awaitWritable().value().via(&eventBase_).then([&](auto) {
VLOG(4) << "big write complete";
// after it completes, write FIN
Expand Down

0 comments on commit 81eb4cb

Please sign in to comment.