Reinitialize the column writers when flushing a row group
findinpath committed Jan 7, 2022
1 parent 44ab5fc commit 682eb69
Showing 7 changed files with 42 additions and 53 deletions.
@@ -83,10 +83,4 @@ public long getRetainedBytes()
{
return INSTANCE_SIZE + elementWriter.getRetainedBytes();
}

@Override
public void reset()
{
elementWriter.reset();
}
}
@@ -34,8 +34,6 @@ List<BufferData> getBuffer()

long getRetainedBytes();

void reset();

class BufferData
{
private final ColumnMetaData metaData;
@@ -87,11 +87,4 @@ public long getRetainedBytes()
{
return INSTANCE_SIZE + keyWriter.getRetainedBytes() + valueWriter.getRetainedBytes();
}

@Override
public void reset()
{
keyWriter.reset();
valueWriter.reset();
}
}
@@ -59,15 +59,17 @@ public class ParquetWriter

private static final int CHUNK_MAX_BYTES = toIntExact(DataSize.of(128, MEGABYTE).toBytes());

private final List<ColumnWriter> columnWriters;
private final OutputStreamSliceOutput outputStream;
private final ParquetWriterOptions writerOption;
private final MessageType messageType;
private final String createdBy;
private final int chunkMaxLogicalBytes;
private final Map<List<String>, Type> primitiveTypes;
private final CompressionCodecName compressionCodecName;

private final ImmutableList.Builder<RowGroup> rowGroupBuilder = ImmutableList.builder();

private List<ColumnWriter> columnWriters;
private int rows;
private long bufferedBytes;
private boolean closed;
Expand All @@ -85,17 +87,10 @@ public ParquetWriter(
{
this.outputStream = new OutputStreamSliceOutput(requireNonNull(outputStream, "outputstream is null"));
this.messageType = requireNonNull(messageType, "messageType is null");
requireNonNull(primitiveTypes, "primitiveTypes is null");
this.primitiveTypes = requireNonNull(primitiveTypes, "primitiveTypes is null");
this.writerOption = requireNonNull(writerOption, "writerOption is null");
requireNonNull(compressionCodecName, "compressionCodecName is null");

ParquetProperties parquetProperties = ParquetProperties.builder()
.withWriterVersion(PARQUET_1_0)
.withPageSize(writerOption.getMaxPageSize())
.build();

this.columnWriters = ParquetWriters.getColumnWriters(messageType, primitiveTypes, parquetProperties, compressionCodecName);

this.compressionCodecName = requireNonNull(compressionCodecName, "compressionCodecName is null");
initColumnWriters();
this.chunkMaxLogicalBytes = max(1, CHUNK_MAX_BYTES / 2);
this.createdBy = formatCreatedBy(requireNonNull(trinoVersion, "trinoVersion is null"));
}
@@ -164,7 +159,7 @@ private void writeChunk(Page page)
if (bufferedBytes >= writerOption.getMaxRowGroupSize()) {
columnWriters.forEach(ColumnWriter::close);
flush();
columnWriters.forEach(ColumnWriter::reset);
initColumnWriters();
rows = 0;
bufferedBytes = columnWriters.stream().mapToLong(ColumnWriter::getBufferedBytes).sum();
}
@@ -289,4 +284,14 @@ static String formatCreatedBy(String trinoVersion)
// Add "(build n/a)" suffix to satisfy Parquet's VersionParser expectations
return "Trino version " + trinoVersion + " (build n/a)";
}

private void initColumnWriters()
{
ParquetProperties parquetProperties = ParquetProperties.builder()
.withWriterVersion(PARQUET_1_0)
.withPageSize(writerOption.getMaxPageSize())
.build();

this.columnWriters = ParquetWriters.getColumnWriters(messageType, primitiveTypes, parquetProperties, compressionCodecName);
}
}
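
In short, the fix drops the per-writer reset() hook and instead rebuilds the column writers after every row group flush. A condensed sketch of the new flush path, assembled from the ParquetWriter changes above (the rest of the class and its fields are elided):

private void writeChunk(Page page)
        throws IOException
{
    // ... buffer the page into the current column writers ...
    if (bufferedBytes >= writerOption.getMaxRowGroupSize()) {
        columnWriters.forEach(ColumnWriter::close);
        flush();
        // previously: columnWriters.forEach(ColumnWriter::reset);
        initColumnWriters();
        rows = 0;
        bufferedBytes = columnWriters.stream().mapToLong(ColumnWriter::getBufferedBytes).sum();
    }
}

private void initColumnWriters()
{
    // Creating new writers gives each row group fresh page buffers, statistics
    // and encoding state (compare the removed reset() implementations above).
    ParquetProperties parquetProperties = ParquetProperties.builder()
            .withWriterVersion(PARQUET_1_0)
            .withPageSize(writerOption.getMaxPageSize())
            .build();

    this.columnWriters = ParquetWriters.getColumnWriters(messageType, primitiveTypes, parquetProperties, compressionCodecName);
}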
@@ -306,24 +306,4 @@ public long getRetainedBytes()
definitionLevelWriter.getAllocatedSize() +
repetitionLevelWriter.getAllocatedSize();
}

@Override
public void reset()
{
definitionLevelWriter.reset();
repetitionLevelWriter.reset();
primitiveValueWriter.reset();
pageBuffer.clear();
closed = false;

totalCompressedSize = 0;
totalUnCompressedSize = 0;
totalValues = 0;
encodings.clear();
dataPagesWithEncoding.clear();
dictionaryPagesWithEncoding.clear();
this.columnStatistics = Statistics.createStats(columnDescriptor.getPrimitiveType());

getDataStreamsCalled = false;
}
}
@@ -95,10 +95,4 @@ public long getRetainedBytes()
return INSTANCE_SIZE +
columnWriters.stream().mapToLong(ColumnWriter::getRetainedBytes).sum();
}

@Override
public void reset()
{
columnWriters.forEach(ColumnWriter::reset);
}
}
@@ -14,8 +14,15 @@
package io.trino.plugin.iceberg;

import io.trino.Session;
import io.trino.testing.MaterializedResult;
import io.trino.testing.sql.TestTable;
import org.testng.annotations.Test;

import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static org.apache.iceberg.FileFormat.PARQUET;
import static org.testng.Assert.assertEquals;

public class TestIcebergParquetConnectorTest
extends BaseIcebergConnectorTest
@@ -39,6 +46,24 @@ protected boolean supportsRowGroupStatistics(String typeName)
typeName.equalsIgnoreCase("timestamp(6) with time zone"));
}

@Test
public void testRowGroupResetDictionary()
{
try (TestTable table = new TestTable(
getQueryRunner()::execute,
"test_row_group_reset_dictionary",
"(plain_col varchar, dict_col int)")) {
String tableName = table.getName();
String values = IntStream.range(0, 100)
.mapToObj(i -> "('ABCDEFGHIJ" + i + "' , " + (i < 20 ? "1" : "null") + ")")
.collect(Collectors.joining(", "));
assertUpdate(withSmallRowGroups(getSession()), "INSERT INTO " + tableName + " VALUES " + values, 100);

MaterializedResult result = getDistributedQueryRunner().execute(String.format("SELECT * FROM %s", tableName));
assertEquals(result.getRowCount(), 100);
}
}

@Override
protected Session withSmallRowGroups(Session session)
{
