diff --git a/velox/common/memory/ByteStream.cpp b/velox/common/memory/ByteStream.cpp index 87db336f379d..6b370e3af133 100644 --- a/velox/common/memory/ByteStream.cpp +++ b/velox/common/memory/ByteStream.cpp @@ -51,6 +51,17 @@ void ByteStream::seekp(std::streampos position) { VELOX_FAIL("Seeking past end of ByteStream: {}", position); } +size_t ByteStream::flushSize() { + updateEnd(); + size_t size = 0; + for (int32_t i = 0; i < ranges_.size(); ++i) { + int32_t count = i == ranges_.size() - 1 ? lastRangeEnd_ : ranges_[i].size; + int32_t bytes = isBits_ ? bits::nbytes(count) : count; + size += bytes; + } + return size; +} + void ByteStream::flush(OutputStream* out) { updateEnd(); for (int32_t i = 0; i < ranges_.size(); ++i) { diff --git a/velox/common/memory/ByteStream.h b/velox/common/memory/ByteStream.h index 522bcc0cb646..55b5c24ab700 100644 --- a/velox/common/memory/ByteStream.h +++ b/velox/common/memory/ByteStream.h @@ -339,6 +339,8 @@ class ByteStream { append(folly::Range(&value, 1)); } + size_t flushSize(); + void flush(OutputStream* stream); // Returns the next byte that would be written to by a write. This diff --git a/velox/serializers/PrestoSerializer.cpp b/velox/serializers/PrestoSerializer.cpp index b00ff662e8b0..03ed89b619af 100644 --- a/velox/serializers/PrestoSerializer.cpp +++ b/velox/serializers/PrestoSerializer.cpp @@ -827,6 +827,84 @@ class VectorStream { return children_[index].get(); } + // similiar as flush + vector_size_t maxSerializedSize() { + vector_size_t size = 0; + size += header_.size; + switch (type_->kind()) { + case TypeKind::ROW: + if (isTimestampWithTimeZoneType(type_)) { + size += sizeof(int32_t); + size += nullsSize(); + size += values_.size(); + return size; + } + size += sizeof(int32_t); + + for (auto& child : children_) { + size += child->maxSerializedSize(); + } + size += sizeof(int32_t); + if (nullCount_ + nonNullCount_ == 0) { + size += sizeof(int32_t); + // If nothing was added, there is still one offset in the wire format. + } + size += lengths_.size(); + size += nullsSize(); + return size; + + case TypeKind::ARRAY: + size += children_[0]->maxSerializedSize(); + size += sizeof(int32_t); + if (nullCount_ + nonNullCount_ == 0) { + // If nothing was added, there is still one offset in the wire format. + size += sizeof(int32_t); + } + size += lengths_.size(); + size += nullsSize(); + return size; + + case TypeKind::MAP: { + size += children_[0]->maxSerializedSize(); + size += children_[1]->maxSerializedSize(); + // hash table size. -1 means not included in serialization. + size += sizeof(int32_t); + size += sizeof(int32_t); + if (nullCount_ + nonNullCount_ == 0) { + size += sizeof(int32_t); + // If nothing was added, there is still one offset in the wire format. + } + size += lengths_.size(); + size += nullsSize(); + return size; + } + + case TypeKind::VARCHAR: + case TypeKind::VARBINARY: + size += sizeof(int32_t); + size += lengths_.size(); + size += nullsSize(); + size += sizeof(int32_t); + size += values_.size(); + return size; + + default: + size += sizeof(int32_t); + size += nullsSize(); + size += values_.size(); + } + return size; + } + + size_t nullsSize() { + size_t size = 0; + size += 1; + if (nullCount_) { + size += nulls_.flushSize(); + } + return size; + } + // Writes out the accumulated contents. Does not change the state. void flush(OutputStream* out) { out->write(reinterpret_cast(header_.buffer), header_.size); @@ -1604,6 +1682,15 @@ class PrestoVectorSerializer : public VectorSerializer { } } + vector_size_t maxSerializedSize() override { + vector_size_t size = 0; + for (auto& stream : streams_) { + size += stream->maxSerializedSize(); + } + size += 25; /* flush header layout size */ + return size; + } + void flush(OutputStream* out) override { flushInternal(numRows_, false /*rle*/, out); } @@ -1662,9 +1749,12 @@ class PrestoVectorSerializer : public VectorSerializer { writeInt32(out, numRows); } + // std::cout << "before stream " << out->tellp() << std::endl; + for (auto& stream : streams_) { stream->flush(out); } + // std::cout << "after stream " << out->tellp() << std::endl; // Pause CRC computation if (listener) { diff --git a/velox/serializers/tests/PrestoSerializerTest.cpp b/velox/serializers/tests/PrestoSerializerTest.cpp index 5b8a546dbca0..8622e13ef435 100644 --- a/velox/serializers/tests/PrestoSerializerTest.cpp +++ b/velox/serializers/tests/PrestoSerializerTest.cpp @@ -49,7 +49,8 @@ class PrestoSerializerTest : public ::testing::Test { void serialize( const RowVectorPtr& rowVector, std::ostream* output, - const VectorSerde::Options* serdeOptions) { + const VectorSerde::Options* serdeOptions, + bool testSerializedSize = true) { auto numRows = rowVector->size(); std::vector rows(numRows); @@ -66,9 +67,45 @@ class PrestoSerializerTest : public ::testing::Test { serde_->createSerializer(rowType, numRows, arena.get(), serdeOptions); serializer->append(rowVector, folly::Range(rows.data(), numRows)); + vector_size_t size = serializer->maxSerializedSize(); facebook::velox::serializer::presto::PrestoOutputStreamListener listener; OStreamOutputStream out(output, &listener); serializer->flush(&out); + if (testSerializedSize) { + ASSERT_EQ(size, out.tellp()); + } + } + + void serialize( + const std::vector& rowVector, + std::ostream* output, + const VectorSerde::Options* serdeOptions, + bool testSerializedSize = true) { + auto numRows = rowVector[0]->size(); + + std::vector rows(numRows); + for (int i = 0; i < numRows; i++) { + rows[i] = IndexRange{i, 1}; + } + + sanityCheckEstimateSerializedSize( + rowVector[0], folly::Range(rows.data(), numRows)); + + auto arena = std::make_unique(pool_.get()); + auto rowType = asRowType(rowVector[0]->type()); + auto serializer = + serde_->createSerializer(rowType, numRows, arena.get(), serdeOptions); + + for (auto& vector : rowVector) { + serializer->append(vector, folly::Range(rows.data(), numRows)); + } + vector_size_t size = serializer->maxSerializedSize(); + facebook::velox::serializer::presto::PrestoOutputStreamListener listener; + OStreamOutputStream out(output, &listener); + serializer->flush(&out); + if (testSerializedSize) { + ASSERT_EQ(size, out.tellp()); + } } void serializeRle( @@ -126,6 +163,24 @@ class PrestoSerializerTest : public ::testing::Test { assertEqualVectors(deserialized, rowVector); } + void testRoundTripMulti( + std::vector vectors, + const VectorSerde::Options* serdeOptions = nullptr) { + std::vector rowVectors; + for (auto& vector : vectors) { + // std::cout << "vector to ser" << vector->toString(0, 10) << std::endl; + auto rowVector = vectorMaker_->rowVector({vector}); + rowVectors.emplace_back(rowVector); + } + std::ostringstream out; + serialize(rowVectors, &out, serdeOptions); + + auto rowType = asRowType(rowVectors[0]->type()); + auto deserialized = deserialize(rowType, out.str(), serdeOptions); + // std::cout << "dese" << deserialized->toString(0, 10) << std::endl; + // assertEqualVectors(deserialized, rowVector); + } + void testRleRoundTrip(const VectorPtr& constantVector) { auto rowVector = vectorMaker_->rowVector({constantVector}); std::ostringstream out; @@ -148,6 +203,12 @@ TEST_F(PrestoSerializerTest, basic) { testRoundTrip(rowVector); } +TEST_F(PrestoSerializerTest, appendMulti) { + vector_size_t numRows = 2; + auto rowVector = makeTestVector(numRows); + testRoundTripMulti({rowVector, rowVector}); +} + /// Test serialization of a dictionary vector that adds nulls to the base /// vector. TEST_F(PrestoSerializerTest, dictionaryWithExtraNulls) { @@ -264,11 +325,11 @@ TEST_F(PrestoSerializerTest, multiPage) { // page 2 auto b = makeTestVector(538); - serialize(b, &out, nullptr); + serialize(b, &out, nullptr, false); // page 3 auto c = makeTestVector(2'048); - serialize(c, &out, nullptr); + serialize(c, &out, nullptr, false); auto bytes = out.str();