diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ArrayColumnWriter.java b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ArrayColumnWriter.java index 275540dea9fe..a46263d2f575 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ArrayColumnWriter.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ArrayColumnWriter.java @@ -83,10 +83,4 @@ public long getRetainedBytes() { return INSTANCE_SIZE + elementWriter.getRetainedBytes(); } - - @Override - public void reset() - { - elementWriter.reset(); - } } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ColumnWriter.java b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ColumnWriter.java index b2ed20f0922e..b82d4a4942a1 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ColumnWriter.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ColumnWriter.java @@ -34,8 +34,6 @@ List getBuffer() long getRetainedBytes(); - void reset(); - class BufferData { private final ColumnMetaData metaData; diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/MapColumnWriter.java b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/MapColumnWriter.java index 016bb00c04f8..8102259d9218 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/MapColumnWriter.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/MapColumnWriter.java @@ -87,11 +87,4 @@ public long getRetainedBytes() { return INSTANCE_SIZE + keyWriter.getRetainedBytes() + valueWriter.getRetainedBytes(); } - - @Override - public void reset() - { - keyWriter.reset(); - valueWriter.reset(); - } } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriter.java b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriter.java index a26bdf99dcdd..f785fb1d7e29 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriter.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriter.java @@ -59,15 +59,17 @@ public class ParquetWriter private static final int CHUNK_MAX_BYTES = toIntExact(DataSize.of(128, MEGABYTE).toBytes()); - private final List columnWriters; private final OutputStreamSliceOutput outputStream; private final ParquetWriterOptions writerOption; private final MessageType messageType; private final String createdBy; private final int chunkMaxLogicalBytes; + private final Map, Type> primitiveTypes; + private final CompressionCodecName compressionCodecName; private final ImmutableList.Builder rowGroupBuilder = ImmutableList.builder(); + private List columnWriters; private int rows; private long bufferedBytes; private boolean closed; @@ -85,17 +87,10 @@ public ParquetWriter( { this.outputStream = new OutputStreamSliceOutput(requireNonNull(outputStream, "outputstream is null")); this.messageType = requireNonNull(messageType, "messageType is null"); - requireNonNull(primitiveTypes, "primitiveTypes is null"); + this.primitiveTypes = requireNonNull(primitiveTypes, "primitiveTypes is null"); this.writerOption = requireNonNull(writerOption, "writerOption is null"); - requireNonNull(compressionCodecName, "compressionCodecName is null"); - - ParquetProperties parquetProperties = ParquetProperties.builder() - .withWriterVersion(PARQUET_1_0) - .withPageSize(writerOption.getMaxPageSize()) - .build(); - - this.columnWriters = ParquetWriters.getColumnWriters(messageType, primitiveTypes, parquetProperties, compressionCodecName); - + this.compressionCodecName = requireNonNull(compressionCodecName, "compressionCodecName is null"); + initColumnWriters(); this.chunkMaxLogicalBytes = max(1, CHUNK_MAX_BYTES / 2); this.createdBy = formatCreatedBy(requireNonNull(trinoVersion, "trinoVersion is null")); } @@ -164,7 +159,7 @@ private void writeChunk(Page page) if (bufferedBytes >= writerOption.getMaxRowGroupSize()) { columnWriters.forEach(ColumnWriter::close); flush(); - columnWriters.forEach(ColumnWriter::reset); + initColumnWriters(); rows = 0; bufferedBytes = columnWriters.stream().mapToLong(ColumnWriter::getBufferedBytes).sum(); } @@ -289,4 +284,14 @@ static String formatCreatedBy(String trinoVersion) // Add "(build n/a)" suffix to satisfy Parquet's VersionParser expectations return "Trino version " + trinoVersion + " (build n/a)"; } + + private void initColumnWriters() + { + ParquetProperties parquetProperties = ParquetProperties.builder() + .withWriterVersion(PARQUET_1_0) + .withPageSize(writerOption.getMaxPageSize()) + .build(); + + this.columnWriters = ParquetWriters.getColumnWriters(messageType, primitiveTypes, parquetProperties, compressionCodecName); + } } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/PrimitiveColumnWriter.java b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/PrimitiveColumnWriter.java index cca70680901d..f1996a7029b4 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/PrimitiveColumnWriter.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/PrimitiveColumnWriter.java @@ -306,24 +306,4 @@ public long getRetainedBytes() definitionLevelWriter.getAllocatedSize() + repetitionLevelWriter.getAllocatedSize(); } - - @Override - public void reset() - { - definitionLevelWriter.reset(); - repetitionLevelWriter.reset(); - primitiveValueWriter.reset(); - pageBuffer.clear(); - closed = false; - - totalCompressedSize = 0; - totalUnCompressedSize = 0; - totalValues = 0; - encodings.clear(); - dataPagesWithEncoding.clear(); - dictionaryPagesWithEncoding.clear(); - this.columnStatistics = Statistics.createStats(columnDescriptor.getPrimitiveType()); - - getDataStreamsCalled = false; - } } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/StructColumnWriter.java b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/StructColumnWriter.java index 53860c85679e..aa84ceda0300 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/StructColumnWriter.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/StructColumnWriter.java @@ -95,10 +95,4 @@ public long getRetainedBytes() return INSTANCE_SIZE + columnWriters.stream().mapToLong(ColumnWriter::getRetainedBytes).sum(); } - - @Override - public void reset() - { - columnWriters.forEach(ColumnWriter::reset); - } } diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergParquetConnectorTest.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergParquetConnectorTest.java index 100b134d95e4..2fc53800aa49 100644 --- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergParquetConnectorTest.java +++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergParquetConnectorTest.java @@ -14,8 +14,15 @@ package io.trino.plugin.iceberg; import io.trino.Session; +import io.trino.testing.MaterializedResult; +import io.trino.testing.sql.TestTable; +import org.testng.annotations.Test; + +import java.util.stream.Collectors; +import java.util.stream.IntStream; import static org.apache.iceberg.FileFormat.PARQUET; +import static org.testng.Assert.assertEquals; public class TestIcebergParquetConnectorTest extends BaseIcebergConnectorTest @@ -39,6 +46,24 @@ protected boolean supportsRowGroupStatistics(String typeName) typeName.equalsIgnoreCase("timestamp(6) with time zone")); } + @Test + public void testRowGroupResetDictionary() + { + try (TestTable table = new TestTable( + getQueryRunner()::execute, + "test_row_group_reset_dictionary", + "(plain_col varchar, dict_col int)")) { + String tableName = table.getName(); + String values = IntStream.range(0, 100) + .mapToObj(i -> "('ABCDEFGHIJ" + i + "' , " + (i < 20 ? "1" : "null") + ")") + .collect(Collectors.joining(", ")); + assertUpdate(withSmallRowGroups(getSession()), "INSERT INTO " + tableName + " VALUES " + values, 100); + + MaterializedResult result = getDistributedQueryRunner().execute(String.format("SELECT * FROM %s", tableName)); + assertEquals(result.getRowCount(), 100); + } + } + @Override protected Session withSmallRowGroups(Session session) {