diff --git a/aircompressor-LICENSE b/aircompressor-LICENSE new file mode 100644 index 0000000..51fca54 --- /dev/null +++ b/aircompressor-LICENSE @@ -0,0 +1,11 @@ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/build.gradle b/build.gradle index cd1578d..b9be241 100644 --- a/build.gradle +++ b/build.gradle @@ -81,6 +81,9 @@ license { matching('**/io/github/steveice10/opennbt/**') { header = file('opennbt-LICENSE') } + matching('**/io/airlift/compress/**') { + header = file('aircompressor-LICENSE') + } } tasks.build.dependsOn proguardJar \ No newline at end of file diff --git a/src/main/java/com/unascribed/nbted/CommandProcessor.java b/src/main/java/com/unascribed/nbted/CommandProcessor.java index fdcf2ff..03c5ad6 100644 --- a/src/main/java/com/unascribed/nbted/CommandProcessor.java +++ b/src/main/java/com/unascribed/nbted/CommandProcessor.java @@ -435,7 +435,7 @@ public CommandProcessor(NBTTag _root, TagPrinter _printer, FileInfo _fileInfo) { } else if (set.has("big-endian")) { endianness = Endianness.BIG; } else { - endianness = fileInfo.endianness; + endianness = fileInfo.endianness == null ? Endianness.BIG : fileInfo.endianness; } Compression compression; if (set.has("compression")) { @@ -495,6 +495,16 @@ public CommandProcessor(NBTTag _root, TagPrinter _printer, FileInfo _fileInfo) { if (!prompt("You are saving an NBT file with a JSON extension. Are you sure you want to do this?", false)) { return; } + } else if (compression == Compression.ZSTD) { + if (outFile.getName().endsWith(".dat") || outFile.getName().endsWith(".nbt")) { + if (!prompt("You are saving a non-standard Zstd NBT file with a standard extension. Are you sure you want to do this?", true)) { + return; + } + } else if (!outFile.getName().endsWith(".zat") && !outFile.getName().endsWith(".znbt")) { + if (!prompt("You are saving a non-standard Zstd NBT file with an unknown extension. Are you sure you want to do this?", true)) { + return; + } + } } else if (!outFile.getName().endsWith(".dat") && !outFile.getName().endsWith(".nbt")) { if (!prompt("You are saving an NBT file with a nonstandard extension. 
Are you sure you want to do this?", true)) { return; diff --git a/src/main/java/com/unascribed/nbted/Compression.java b/src/main/java/com/unascribed/nbted/Compression.java index 171cb06..23f9ad7 100644 --- a/src/main/java/com/unascribed/nbted/Compression.java +++ b/src/main/java/com/unascribed/nbted/Compression.java @@ -26,12 +26,17 @@ import java.util.zip.GZIPOutputStream; import java.util.zip.InflaterInputStream; +import io.airlift.compress.zstd.ZstdInputStream; +import io.airlift.compress.zstd.ZstdOutputStream; + public enum Compression { NONE("None"), DEFLATE("Deflate"), - GZIP("GZip"); + GZIP("GZip"), + ZSTD("ZStandard"), + ; private final String name; - private Compression(String name) { + Compression(String name) { this.name = name; } @@ -41,6 +46,7 @@ public InputStream wrap(InputStream is) throws IOException { case NONE: return is; case DEFLATE: return new InflaterInputStream(is); case GZIP: return new GZIPInputStream(is); + case ZSTD: return new ZstdInputStream(is); default: throw new AssertionError("missing case for "+this); } } @@ -51,6 +57,7 @@ public OutputStream wrap(OutputStream os) throws IOException { case NONE: return os; case DEFLATE: return new DeflaterOutputStream(os); case GZIP: return new GZIPOutputStream(os); + case ZSTD: return new ZstdOutputStream(os); default: throw new AssertionError("missing case for "+this); } } diff --git a/src/main/java/com/unascribed/nbted/NBTEd.java b/src/main/java/com/unascribed/nbted/NBTEd.java index 6532407..b4bd409 100644 --- a/src/main/java/com/unascribed/nbted/NBTEd.java +++ b/src/main/java/com/unascribed/nbted/NBTEd.java @@ -325,6 +325,8 @@ public static void main(String[] args) throws Exception { detectedCompressionMethod = Compression.GZIP; } else if (magic8 == 0x78) { detectedCompressionMethod = Compression.DEFLATE; + } else if (magic16 == 0xb528) { + detectedCompressionMethod = Compression.ZSTD; } else { detectedCompressionMethod = Compression.NONE; } diff --git a/src/main/java/io/airlift/compress/Compressor.java b/src/main/java/io/airlift/compress/Compressor.java new file mode 100644 index 0000000..a62aebc --- /dev/null +++ b/src/main/java/io/airlift/compress/Compressor.java @@ -0,0 +1,29 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.airlift.compress; + +import java.nio.ByteBuffer; + +public interface Compressor +{ + int maxCompressedLength(int uncompressedSize); + + /** + * @return number of bytes written to the output + */ + int compress(byte[] input, int inputOffset, int inputLength, byte[] output, int outputOffset, int maxOutputLength); + + void compress(ByteBuffer input, ByteBuffer output); +} diff --git a/src/main/java/io/airlift/compress/Decompressor.java b/src/main/java/io/airlift/compress/Decompressor.java new file mode 100644 index 0000000..6665a90 --- /dev/null +++ b/src/main/java/io/airlift/compress/Decompressor.java @@ -0,0 +1,29 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.airlift.compress; + +import java.nio.ByteBuffer; + +public interface Decompressor +{ + /** + * @return number of bytes written to the output + */ + int decompress(byte[] input, int inputOffset, int inputLength, byte[] output, int outputOffset, int maxOutputLength) + throws MalformedInputException; + + void decompress(ByteBuffer input, ByteBuffer output) + throws MalformedInputException; +} diff --git a/src/main/java/io/airlift/compress/IncompatibleJvmException.java b/src/main/java/io/airlift/compress/IncompatibleJvmException.java new file mode 100644 index 0000000..c721346 --- /dev/null +++ b/src/main/java/io/airlift/compress/IncompatibleJvmException.java @@ -0,0 +1,24 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.airlift.compress; + +public class IncompatibleJvmException + extends RuntimeException +{ + public IncompatibleJvmException(String message) + { + super(message); + } +} diff --git a/src/main/java/io/airlift/compress/MalformedInputException.java b/src/main/java/io/airlift/compress/MalformedInputException.java new file mode 100644 index 0000000..7125b04 --- /dev/null +++ b/src/main/java/io/airlift/compress/MalformedInputException.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.airlift.compress; + +public class MalformedInputException + extends RuntimeException +{ + private final long offset; + + public MalformedInputException(long offset) + { + this(offset, "Malformed input"); + } + + public MalformedInputException(long offset, String reason) + { + super(reason + ": offset=" + offset); + this.offset = offset; + } + + public long getOffset() + { + return offset; + } +} diff --git a/src/main/java/io/airlift/compress/zstd/BitInputStream.java b/src/main/java/io/airlift/compress/zstd/BitInputStream.java new file mode 100644 index 0000000..aa59fd6 --- /dev/null +++ b/src/main/java/io/airlift/compress/zstd/BitInputStream.java @@ -0,0 +1,207 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.airlift.compress.zstd; + +import static io.airlift.compress.zstd.Constants.SIZE_OF_LONG; +import static io.airlift.compress.zstd.UnsafeUtil.UNSAFE; +import static io.airlift.compress.zstd.Util.highestBit; +import static io.airlift.compress.zstd.Util.verify; + +/** + * Bit streams are encoded as a byte-aligned little-endian stream. Thus, bits are laid out + * in the following manner, and the stream is read from right to left. + *
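+ * The encoder emits bits forward, but the decoder consumes them in reverse:
+ * initialization seeks to the highest set bit of the last byte (the
+ * end-of-stream marker), and reads proceed from there back toward byte 0.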
+ * ... [16 17 18 19 20 21 22 23] [8 9 10 11 12 13 14 15] [0 1 2 3 4 5 6 7] + */ +class BitInputStream +{ + private BitInputStream() + { + } + + public static boolean isEndOfStream(long startAddress, long currentAddress, int bitsConsumed) + { + return startAddress == currentAddress && bitsConsumed == Long.SIZE; + } + + static long readTail(Object inputBase, long inputAddress, int inputSize) + { + long bits = UNSAFE.getByte(inputBase, inputAddress) & 0xFF; + + switch (inputSize) { + case 7: + bits |= (UNSAFE.getByte(inputBase, inputAddress + 6) & 0xFFL) << 48; + case 6: + bits |= (UNSAFE.getByte(inputBase, inputAddress + 5) & 0xFFL) << 40; + case 5: + bits |= (UNSAFE.getByte(inputBase, inputAddress + 4) & 0xFFL) << 32; + case 4: + bits |= (UNSAFE.getByte(inputBase, inputAddress + 3) & 0xFFL) << 24; + case 3: + bits |= (UNSAFE.getByte(inputBase, inputAddress + 2) & 0xFFL) << 16; + case 2: + bits |= (UNSAFE.getByte(inputBase, inputAddress + 1) & 0xFFL) << 8; + } + + return bits; + } + + /** + * @return numberOfBits in the low order bits of a long + */ + public static long peekBits(int bitsConsumed, long bitContainer, int numberOfBits) + { + return (((bitContainer << bitsConsumed) >>> 1) >>> (63 - numberOfBits)); + } + + /** + * numberOfBits must be > 0 + * + * @return numberOfBits in the low order bits of a long + */ + public static long peekBitsFast(int bitsConsumed, long bitContainer, int numberOfBits) + { + return ((bitContainer << bitsConsumed) >>> (64 - numberOfBits)); + } + + static class Initializer + { + private final Object inputBase; + private final long startAddress; + private final long endAddress; + private long bits; + private long currentAddress; + private int bitsConsumed; + + public Initializer(Object inputBase, long startAddress, long endAddress) + { + this.inputBase = inputBase; + this.startAddress = startAddress; + this.endAddress = endAddress; + } + + public long getBits() + { + return bits; + } + + public long getCurrentAddress() + { + return currentAddress; + } + + public int getBitsConsumed() + { + return bitsConsumed; + } + + public void initialize() + { + verify(endAddress - startAddress >= 1, startAddress, "Bitstream is empty"); + + int lastByte = UNSAFE.getByte(inputBase, endAddress - 1) & 0xFF; + verify(lastByte != 0, endAddress, "Bitstream end mark not present"); + + bitsConsumed = SIZE_OF_LONG - highestBit(lastByte); + + int inputSize = (int) (endAddress - startAddress); + if (inputSize >= SIZE_OF_LONG) { /* normal case */ + currentAddress = endAddress - SIZE_OF_LONG; + bits = UNSAFE.getLong(inputBase, currentAddress); + } + else { + currentAddress = startAddress; + bits = readTail(inputBase, startAddress, inputSize); + + bitsConsumed += (SIZE_OF_LONG - inputSize) * 8; + } + } + } + + static final class Loader + { + private final Object inputBase; + private final long startAddress; + private long bits; + private long currentAddress; + private int bitsConsumed; + private boolean overflow; + + public Loader(Object inputBase, long startAddress, long currentAddress, long bits, int bitsConsumed) + { + this.inputBase = inputBase; + this.startAddress = startAddress; + this.bits = bits; + this.currentAddress = currentAddress; + this.bitsConsumed = bitsConsumed; + } + + public long getBits() + { + return bits; + } + + public long getCurrentAddress() + { + return currentAddress; + } + + public int getBitsConsumed() + { + return bitsConsumed; + } + + public boolean isOverflow() + { + return overflow; + } + + public boolean load() + { + if (bitsConsumed > 64) { + overflow = 
true; + return true; + } + + else if (currentAddress == startAddress) { + return true; + } + + int bytes = bitsConsumed >>> 3; // divide by 8 + if (currentAddress >= startAddress + SIZE_OF_LONG) { + if (bytes > 0) { + currentAddress -= bytes; + bits = UNSAFE.getLong(inputBase, currentAddress); + } + bitsConsumed &= 0b111; + } + else if (currentAddress - bytes < startAddress) { + bytes = (int) (currentAddress - startAddress); + currentAddress = startAddress; + bitsConsumed -= bytes * SIZE_OF_LONG; + bits = UNSAFE.getLong(inputBase, startAddress); + return true; + } + else { + currentAddress -= bytes; + bitsConsumed -= bytes * SIZE_OF_LONG; + bits = UNSAFE.getLong(inputBase, currentAddress); + } + + return false; + } + } +} diff --git a/src/main/java/io/airlift/compress/zstd/BitOutputStream.java b/src/main/java/io/airlift/compress/zstd/BitOutputStream.java new file mode 100644 index 0000000..65dad1a --- /dev/null +++ b/src/main/java/io/airlift/compress/zstd/BitOutputStream.java @@ -0,0 +1,91 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.airlift.compress.zstd; + +import static io.airlift.compress.zstd.Constants.SIZE_OF_LONG; +import static io.airlift.compress.zstd.UnsafeUtil.UNSAFE; +import static io.airlift.compress.zstd.Util.checkArgument; + +class BitOutputStream +{ + private static final long[] BIT_MASK = { + 0x0, 0x1, 0x3, 0x7, 0xF, 0x1F, + 0x3F, 0x7F, 0xFF, 0x1FF, 0x3FF, 0x7FF, + 0xFFF, 0x1FFF, 0x3FFF, 0x7FFF, 0xFFFF, 0x1FFFF, + 0x3FFFF, 0x7FFFF, 0xFFFFF, 0x1FFFFF, 0x3FFFFF, 0x7FFFFF, + 0xFFFFFF, 0x1FFFFFF, 0x3FFFFFF, 0x7FFFFFF, 0xFFFFFFF, 0x1FFFFFFF, + 0x3FFFFFFF, 0x7FFFFFFF}; // up to 31 bits + + private final Object outputBase; + private final long outputAddress; + private final long outputLimit; + + private long container; + private int bitCount; + private long currentAddress; + + public BitOutputStream(Object outputBase, long outputAddress, int outputSize) + { + checkArgument(outputSize >= SIZE_OF_LONG, "Output buffer too small"); + + this.outputBase = outputBase; + this.outputAddress = outputAddress; + outputLimit = this.outputAddress + outputSize - SIZE_OF_LONG; + + currentAddress = this.outputAddress; + } + + public void addBits(int value, int bits) + { + container |= (value & BIT_MASK[bits]) << bitCount; + bitCount += bits; + } + + /** + * Note: leading bits of value must be 0 + */ + public void addBitsFast(int value, int bits) + { + container |= ((long) value) << bitCount; + bitCount += bits; + } + + public void flush() + { + int bytes = bitCount >>> 3; + + UNSAFE.putLong(outputBase, currentAddress, container); + currentAddress += bytes; + + if (currentAddress > outputLimit) { + currentAddress = outputLimit; + } + + bitCount &= 7; + container >>>= bytes * 8; + } + + public int close() + { + addBitsFast(1, 1); // end mark + flush(); + + if (currentAddress >= outputLimit) { + return 0; + } + + return (int) ((currentAddress - outputAddress) + (bitCount > 0 ? 
1 : 0)); + } +} diff --git a/src/main/java/io/airlift/compress/zstd/BlockCompressionState.java b/src/main/java/io/airlift/compress/zstd/BlockCompressionState.java new file mode 100644 index 0000000..b1caea0 --- /dev/null +++ b/src/main/java/io/airlift/compress/zstd/BlockCompressionState.java @@ -0,0 +1,77 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.airlift.compress.zstd; + +import java.util.Arrays; + +class BlockCompressionState +{ + public final int[] hashTable; + public final int[] chainTable; + + private final long baseAddress; + + // starting point of the window with respect to baseAddress + private int windowBaseOffset; + + public BlockCompressionState(CompressionParameters parameters, long baseAddress) + { + this.baseAddress = baseAddress; + hashTable = new int[1 << parameters.getHashLog()]; + chainTable = new int[1 << parameters.getChainLog()]; // TODO: chain table not used by Strategy.FAST + } + + public void slideWindow(int slideWindowSize) + { + for (int i = 0; i < hashTable.length; i++) { + int newValue = hashTable[i] - slideWindowSize; + // if new value is negative, set it to zero branchless + newValue = newValue & (~(newValue >> 31)); + hashTable[i] = newValue; + } + for (int i = 0; i < chainTable.length; i++) { + int newValue = chainTable[i] - slideWindowSize; + // if new value is negative, set it to zero branchless + newValue = newValue & (~(newValue >> 31)); + chainTable[i] = newValue; + } + } + + public void reset() + { + Arrays.fill(hashTable, 0); + Arrays.fill(chainTable, 0); + } + + public void enforceMaxDistance(long inputLimit, int maxDistance) + { + int distance = (int) (inputLimit - baseAddress); + + int newOffset = distance - maxDistance; + if (windowBaseOffset < newOffset) { + windowBaseOffset = newOffset; + } + } + + public long getBaseAddress() + { + return baseAddress; + } + + public int getWindowBaseOffset() + { + return windowBaseOffset; + } +} diff --git a/src/main/java/io/airlift/compress/zstd/BlockCompressor.java b/src/main/java/io/airlift/compress/zstd/BlockCompressor.java new file mode 100644 index 0000000..3a327cc --- /dev/null +++ b/src/main/java/io/airlift/compress/zstd/BlockCompressor.java @@ -0,0 +1,22 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.airlift.compress.zstd; + +interface BlockCompressor +{ + BlockCompressor UNSUPPORTED = (inputBase, inputAddress, inputSize, sequenceStore, blockCompressionState, offsets, parameters) -> { throw new UnsupportedOperationException(); }; + + int compressBlock(Object inputBase, long inputAddress, int inputSize, SequenceStore output, BlockCompressionState state, RepeatedOffsets offsets, CompressionParameters parameters); +} diff --git a/src/main/java/io/airlift/compress/zstd/CompressionContext.java b/src/main/java/io/airlift/compress/zstd/CompressionContext.java new file mode 100644 index 0000000..bb4437c --- /dev/null +++ b/src/main/java/io/airlift/compress/zstd/CompressionContext.java @@ -0,0 +1,57 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.airlift.compress.zstd; + +import static io.airlift.compress.zstd.Constants.MAX_BLOCK_SIZE; +import static io.airlift.compress.zstd.Util.checkArgument; + +class CompressionContext +{ + public final CompressionParameters parameters; + public final RepeatedOffsets offsets = new RepeatedOffsets(); + public final BlockCompressionState blockCompressionState; + public final SequenceStore sequenceStore; + + public final SequenceEncodingContext sequenceEncodingContext = new SequenceEncodingContext(); + + public final HuffmanCompressionContext huffmanContext = new HuffmanCompressionContext(); + + public CompressionContext(CompressionParameters parameters, long baseAddress, int inputSize) + { + this.parameters = parameters; + + int windowSize = Math.max(1, Math.min(parameters.getWindowSize(), inputSize)); + int blockSize = Math.min(MAX_BLOCK_SIZE, windowSize); + int divider = (parameters.getSearchLength() == 3) ? 3 : 4; + + int maxSequences = blockSize / divider; + + sequenceStore = new SequenceStore(blockSize, maxSequences); + + blockCompressionState = new BlockCompressionState(parameters, baseAddress); + } + + public void slideWindow(int slideWindowSize) + { + checkArgument(slideWindowSize > 0, "slideWindowSize must be positive"); + blockCompressionState.slideWindow(slideWindowSize); + } + + public void commit() + { + offsets.commit(); + huffmanContext.saveChanges(); + } +} diff --git a/src/main/java/io/airlift/compress/zstd/CompressionParameters.java b/src/main/java/io/airlift/compress/zstd/CompressionParameters.java new file mode 100644 index 0000000..bfdc62f --- /dev/null +++ b/src/main/java/io/airlift/compress/zstd/CompressionParameters.java @@ -0,0 +1,325 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.airlift.compress.zstd; + +import static io.airlift.compress.zstd.Constants.MAX_BLOCK_SIZE; +import static io.airlift.compress.zstd.Constants.MAX_WINDOW_LOG; +import static io.airlift.compress.zstd.Constants.MIN_WINDOW_LOG; +import static io.airlift.compress.zstd.Util.cycleLog; +import static io.airlift.compress.zstd.Util.highestBit; + +class CompressionParameters +{ + private static final int MIN_HASH_LOG = 6; + + public static final int DEFAULT_COMPRESSION_LEVEL = 3; + private static final int MAX_COMPRESSION_LEVEL = 22; + + private final int windowLog; // largest match distance : larger == more compression, more memory needed during decompression + private final int windowSize; // computed: 1 << windowLog + private final int blockSize; // computed: min(MAX_BLOCK_SIZE, windowSize) + private final int chainLog; // fully searched segment : larger == more compression, slower, more memory (useless for fast) + private final int hashLog; // dispatch table : larger == faster, more memory + private final int searchLog; // nb of searches : larger == more compression, slower + private final int searchLength; // match length searched : larger == faster decompression, sometimes less compression + private final int targetLength; // acceptable match size for optimal parser (only) : larger == more compression, slower + private final Strategy strategy; + + private static final CompressionParameters[][] DEFAULT_COMPRESSION_PARAMETERS = new CompressionParameters[][] { + { + // default + new CompressionParameters(19, 12, 13, 1, 6, 1, Strategy.FAST), /* base for negative levels */ + new CompressionParameters(19, 13, 14, 1, 7, 0, Strategy.FAST), /* level 1 */ + new CompressionParameters(19, 15, 16, 1, 6, 0, Strategy.FAST), /* level 2 */ + new CompressionParameters(20, 16, 17, 1, 5, 1, Strategy.DFAST), /* level 3 */ + new CompressionParameters(20, 18, 18, 1, 5, 1, Strategy.DFAST), /* level 4 */ + new CompressionParameters(20, 18, 18, 2, 5, 2, Strategy.GREEDY), /* level 5 */ + new CompressionParameters(21, 18, 19, 2, 5, 4, Strategy.LAZY), /* level 6 */ + new CompressionParameters(21, 18, 19, 3, 5, 8, Strategy.LAZY2), /* level 7 */ + new CompressionParameters(21, 19, 19, 3, 5, 16, Strategy.LAZY2), /* level 8 */ + new CompressionParameters(21, 19, 20, 4, 5, 16, Strategy.LAZY2), /* level 9 */ + new CompressionParameters(21, 20, 21, 4, 5, 16, Strategy.LAZY2), /* level 10 */ + new CompressionParameters(21, 21, 22, 4, 5, 16, Strategy.LAZY2), /* level 11 */ + new CompressionParameters(22, 20, 22, 5, 5, 16, Strategy.LAZY2), /* level 12 */ + new CompressionParameters(22, 21, 22, 4, 5, 32, Strategy.BTLAZY2), /* level 13 */ + new CompressionParameters(22, 21, 22, 5, 5, 32, Strategy.BTLAZY2), /* level 14 */ + new CompressionParameters(22, 22, 22, 6, 5, 32, Strategy.BTLAZY2), /* level 15 */ + new CompressionParameters(22, 21, 22, 4, 5, 48, Strategy.BTOPT), /* level 16 */ + new CompressionParameters(23, 22, 22, 4, 4, 64, Strategy.BTOPT), /* level 17 */ + new CompressionParameters(23, 23, 22, 6, 3, 256, Strategy.BTOPT), /* level 18 */ + new CompressionParameters(23, 24, 22, 7, 3, 256, Strategy.BTULTRA), /* level 19 */ + new CompressionParameters(25, 25, 23, 7, 3, 256, Strategy.BTULTRA), /* level 20 */ + new CompressionParameters(26, 26, 24, 7, 3, 512, Strategy.BTULTRA), /* level 21 */ + new CompressionParameters(27, 27, 25, 9, 3, 999, Strategy.BTULTRA) /* level 22 */ + }, + { + // for size <= 256 KB + 
new CompressionParameters(18, 12, 13, 1, 5, 1, Strategy.FAST), /* base for negative levels */ + new CompressionParameters(18, 13, 14, 1, 6, 0, Strategy.FAST), /* level 1 */ + new CompressionParameters(18, 14, 14, 1, 5, 1, Strategy.DFAST), /* level 2 */ + new CompressionParameters(18, 16, 16, 1, 4, 1, Strategy.DFAST), /* level 3 */ + new CompressionParameters(18, 16, 17, 2, 5, 2, Strategy.GREEDY), /* level 4.*/ + new CompressionParameters(18, 18, 18, 3, 5, 2, Strategy.GREEDY), /* level 5.*/ + new CompressionParameters(18, 18, 19, 3, 5, 4, Strategy.LAZY), /* level 6.*/ + new CompressionParameters(18, 18, 19, 4, 4, 4, Strategy.LAZY), /* level 7 */ + new CompressionParameters(18, 18, 19, 4, 4, 8, Strategy.LAZY2), /* level 8 */ + new CompressionParameters(18, 18, 19, 5, 4, 8, Strategy.LAZY2), /* level 9 */ + new CompressionParameters(18, 18, 19, 6, 4, 8, Strategy.LAZY2), /* level 10 */ + new CompressionParameters(18, 18, 19, 5, 4, 16, Strategy.BTLAZY2), /* level 11.*/ + new CompressionParameters(18, 19, 19, 6, 4, 16, Strategy.BTLAZY2), /* level 12.*/ + new CompressionParameters(18, 19, 19, 8, 4, 16, Strategy.BTLAZY2), /* level 13 */ + new CompressionParameters(18, 18, 19, 4, 4, 24, Strategy.BTOPT), /* level 14.*/ + new CompressionParameters(18, 18, 19, 4, 3, 24, Strategy.BTOPT), /* level 15.*/ + new CompressionParameters(18, 19, 19, 6, 3, 64, Strategy.BTOPT), /* level 16.*/ + new CompressionParameters(18, 19, 19, 8, 3, 128, Strategy.BTOPT), /* level 17.*/ + new CompressionParameters(18, 19, 19, 10, 3, 256, Strategy.BTOPT), /* level 18.*/ + new CompressionParameters(18, 19, 19, 10, 3, 256, Strategy.BTULTRA), /* level 19.*/ + new CompressionParameters(18, 19, 19, 11, 3, 512, Strategy.BTULTRA), /* level 20.*/ + new CompressionParameters(18, 19, 19, 12, 3, 512, Strategy.BTULTRA), /* level 21.*/ + new CompressionParameters(18, 19, 19, 13, 3, 999, Strategy.BTULTRA) /* level 22.*/ + }, + { + // for size <= 128 KB + new CompressionParameters(17, 12, 12, 1, 5, 1, Strategy.FAST), /* base for negative levels */ + new CompressionParameters(17, 12, 13, 1, 6, 0, Strategy.FAST), /* level 1 */ + new CompressionParameters(17, 13, 15, 1, 5, 0, Strategy.FAST), /* level 2 */ + new CompressionParameters(17, 15, 16, 2, 5, 1, Strategy.DFAST), /* level 3 */ + new CompressionParameters(17, 17, 17, 2, 4, 1, Strategy.DFAST), /* level 4 */ + new CompressionParameters(17, 16, 17, 3, 4, 2, Strategy.GREEDY), /* level 5 */ + new CompressionParameters(17, 17, 17, 3, 4, 4, Strategy.LAZY), /* level 6 */ + new CompressionParameters(17, 17, 17, 3, 4, 8, Strategy.LAZY2), /* level 7 */ + new CompressionParameters(17, 17, 17, 4, 4, 8, Strategy.LAZY2), /* level 8 */ + new CompressionParameters(17, 17, 17, 5, 4, 8, Strategy.LAZY2), /* level 9 */ + new CompressionParameters(17, 17, 17, 6, 4, 8, Strategy.LAZY2), /* level 10 */ + new CompressionParameters(17, 17, 17, 7, 4, 8, Strategy.LAZY2), /* level 11 */ + new CompressionParameters(17, 18, 17, 6, 4, 16, Strategy.BTLAZY2), /* level 12 */ + new CompressionParameters(17, 18, 17, 8, 4, 16, Strategy.BTLAZY2), /* level 13.*/ + new CompressionParameters(17, 18, 17, 4, 4, 32, Strategy.BTOPT), /* level 14.*/ + new CompressionParameters(17, 18, 17, 6, 3, 64, Strategy.BTOPT), /* level 15.*/ + new CompressionParameters(17, 18, 17, 7, 3, 128, Strategy.BTOPT), /* level 16.*/ + new CompressionParameters(17, 18, 17, 7, 3, 256, Strategy.BTOPT), /* level 17.*/ + new CompressionParameters(17, 18, 17, 8, 3, 256, Strategy.BTOPT), /* level 18.*/ + new CompressionParameters(17, 18, 17, 8, 3, 256, 
Strategy.BTULTRA), /* level 19.*/ + new CompressionParameters(17, 18, 17, 9, 3, 256, Strategy.BTULTRA), /* level 20.*/ + new CompressionParameters(17, 18, 17, 10, 3, 256, Strategy.BTULTRA), /* level 21.*/ + new CompressionParameters(17, 18, 17, 11, 3, 512, Strategy.BTULTRA) /* level 22.*/ + }, + { + // for size <= 16 KB + new CompressionParameters(14, 12, 13, 1, 5, 1, Strategy.FAST), /* base for negative levels */ + new CompressionParameters(14, 14, 15, 1, 5, 0, Strategy.FAST), /* level 1 */ + new CompressionParameters(14, 14, 15, 1, 4, 0, Strategy.FAST), /* level 2 */ + new CompressionParameters(14, 14, 14, 2, 4, 1, Strategy.DFAST), /* level 3.*/ + new CompressionParameters(14, 14, 14, 4, 4, 2, Strategy.GREEDY), /* level 4.*/ + new CompressionParameters(14, 14, 14, 3, 4, 4, Strategy.LAZY), /* level 5.*/ + new CompressionParameters(14, 14, 14, 4, 4, 8, Strategy.LAZY2), /* level 6 */ + new CompressionParameters(14, 14, 14, 6, 4, 8, Strategy.LAZY2), /* level 7 */ + new CompressionParameters(14, 14, 14, 8, 4, 8, Strategy.LAZY2), /* level 8.*/ + new CompressionParameters(14, 15, 14, 5, 4, 8, Strategy.BTLAZY2), /* level 9.*/ + new CompressionParameters(14, 15, 14, 9, 4, 8, Strategy.BTLAZY2), /* level 10.*/ + new CompressionParameters(14, 15, 14, 3, 4, 12, Strategy.BTOPT), /* level 11.*/ + new CompressionParameters(14, 15, 14, 6, 3, 16, Strategy.BTOPT), /* level 12.*/ + new CompressionParameters(14, 15, 14, 6, 3, 24, Strategy.BTOPT), /* level 13.*/ + new CompressionParameters(14, 15, 15, 6, 3, 48, Strategy.BTOPT), /* level 14.*/ + new CompressionParameters(14, 15, 15, 6, 3, 64, Strategy.BTOPT), /* level 15.*/ + new CompressionParameters(14, 15, 15, 6, 3, 96, Strategy.BTOPT), /* level 16.*/ + new CompressionParameters(14, 15, 15, 6, 3, 128, Strategy.BTOPT), /* level 17.*/ + new CompressionParameters(14, 15, 15, 8, 3, 256, Strategy.BTOPT), /* level 18.*/ + new CompressionParameters(14, 15, 15, 6, 3, 256, Strategy.BTULTRA), /* level 19.*/ + new CompressionParameters(14, 15, 15, 8, 3, 256, Strategy.BTULTRA), /* level 20.*/ + new CompressionParameters(14, 15, 15, 9, 3, 256, Strategy.BTULTRA), /* level 21.*/ + new CompressionParameters(14, 15, 15, 10, 3, 512, Strategy.BTULTRA) /* level 22.*/ + } + }; + + public enum Strategy + { + // from faster to stronger + + // YC: fast is a "single probe" strategy : at every position, we attempt to find a match, and give up if we don't find any. similar to lz4. + FAST(BlockCompressor.UNSUPPORTED), + + // YC: double_fast is a 2 attempts strategies. They are not symmetrical by the way. One attempt is "normal" while the second one looks for "long matches". It was + // empirically found that this was the best trade off. As can be guessed, it's slower than single-attempt, but find more and better matches, so compresses better. + DFAST(new DoubleFastBlockCompressor()), + + // YC: greedy uses a hash chain strategy. Every position is hashed, and all positions with same hash are chained. The algorithm goes through all candidates. There are + // diminishing returns in going deeper and deeper, so after a nb of attempts (which can be selected), it abandons the search. The best (longest) match wins. If there is + // one winner, it's immediately encoded. + GREEDY(BlockCompressor.UNSUPPORTED), + + // YC: lazy will do something similar to greedy, but will not encode immediately. It will search again at next position, in case it would find something better. + // It's actually fairly common to have a small match at position p hiding a more worthy one at position p+1. 
This obviously increases the search workload. But the + // resulting compressed stream generally contains larger matches, hence compresses better. + LAZY(BlockCompressor.UNSUPPORTED), + + // YC: lazy2 is same as lazy, but deeper. It will search at P, P+1 and then P+2 in case it would find something even better. More workload. Better matches. + LAZY2(BlockCompressor.UNSUPPORTED), + + // YC: btlazy2 is like lazy2, but trades the hash chain for a binary tree. This becomes necessary, as the nb of attempts becomes prohibitively expensive. The binary tree + // complexity increases with log of search depth, instead of proportionally with search depth. So searching deeper in history quickly becomes the dominant operation. + // btlazy2 cuts into that. But it costs 2x more memory. It's also relatively "slow", even when trying to cut its parameters to make it perform faster. So it's really + // a high compression strategy. + BTLAZY2(BlockCompressor.UNSUPPORTED), + + // YC: btopt is, well, a hell of lot more complex. + // It will compute and find multiple matches per position, will dynamically compare every path from point P to P+N, reverse the graph to find cheapest path, iterate on + // batches of overlapping matches, etc. It's much more expensive. But the compression ratio is also much better. + BTOPT(BlockCompressor.UNSUPPORTED), + + // YC: btultra is about the same, but doesn't cut as many corners (btopt "abandons" more quickly unpromising little gains). Slower, stronger. + BTULTRA(BlockCompressor.UNSUPPORTED); + + private final BlockCompressor compressor; + + Strategy(BlockCompressor compressor) + { + this.compressor = compressor; + } + + public BlockCompressor getCompressor() + { + return compressor; + } + } + + public CompressionParameters(int windowLog, int chainLog, int hashLog, int searchLog, int searchLength, int targetLength, Strategy strategy) + { + this.windowLog = windowLog; + this.windowSize = 1 << windowLog; + this.blockSize = Math.min(MAX_BLOCK_SIZE, windowSize); + this.chainLog = chainLog; + this.hashLog = hashLog; + this.searchLog = searchLog; + this.searchLength = searchLength; + this.targetLength = targetLength; + this.strategy = strategy; + } + + public int getWindowLog() + { + return windowLog; + } + + public int getWindowSize() + { + return windowSize; + } + + public int getBlockSize() + { + return blockSize; + } + + public int getSearchLength() + { + return searchLength; + } + + public int getChainLog() + { + return chainLog; + } + + public int getHashLog() + { + return hashLog; + } + + public int getSearchLog() + { + return searchLog; + } + + public int getTargetLength() + { + return targetLength; + } + + public Strategy getStrategy() + { + return strategy; + } + + public static CompressionParameters compute(int compressionLevel, int estimatedInputSize) + { + CompressionParameters defaultParameters = getDefaultParameters(compressionLevel, estimatedInputSize); + if (estimatedInputSize < 0) { + return defaultParameters; + } + + int targetLength = defaultParameters.targetLength; + int windowLog = defaultParameters.windowLog; + int chainLog = defaultParameters.chainLog; + int hashLog = defaultParameters.hashLog; + int searchLog = defaultParameters.searchLog; + int searchLength = defaultParameters.searchLength; + Strategy strategy = defaultParameters.strategy; + + if (compressionLevel < 0) { + targetLength = -compressionLevel; // acceleration factor + } + + // resize windowLog if input is small enough, to use less memory + long maxWindowResize = 1L << (MAX_WINDOW_LOG - 1); + if 
(estimatedInputSize < maxWindowResize) { + int hashSizeMin = 1 << MIN_HASH_LOG; + int inputSizeLog = (estimatedInputSize < hashSizeMin) ? MIN_HASH_LOG : highestBit(estimatedInputSize - 1) + 1; + if (windowLog > inputSizeLog) { + windowLog = inputSizeLog; + } + } + + if (hashLog > windowLog + 1) { + hashLog = windowLog + 1; + } + + int cycleLog = cycleLog(chainLog, strategy); + if (cycleLog > windowLog) { + chainLog -= (cycleLog - windowLog); + } + + if (windowLog < MIN_WINDOW_LOG) { + windowLog = MIN_WINDOW_LOG; + } + + return new CompressionParameters(windowLog, chainLog, hashLog, searchLog, searchLength, targetLength, strategy); + } + + private static CompressionParameters getDefaultParameters(int compressionLevel, long estimatedInputSize) + { + int table = 0; + + if (estimatedInputSize >= 0) { + if (estimatedInputSize <= 16 * 1024) { + table = 3; + } + else if (estimatedInputSize <= 128 * 1024) { + table = 2; + } + else if (estimatedInputSize <= 256 * 1024) { + table = 1; + } + } + + int row = DEFAULT_COMPRESSION_LEVEL; + + if (compressionLevel != 0) { // TODO: figure out better way to indicate default compression level + row = Math.min(Math.max(0, compressionLevel), MAX_COMPRESSION_LEVEL); + } + + return DEFAULT_COMPRESSION_PARAMETERS[table][row]; + } +} diff --git a/src/main/java/io/airlift/compress/zstd/Constants.java b/src/main/java/io/airlift/compress/zstd/Constants.java new file mode 100644 index 0000000..6da44ef --- /dev/null +++ b/src/main/java/io/airlift/compress/zstd/Constants.java @@ -0,0 +1,84 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.airlift.compress.zstd; + +class Constants +{ + public static final int SIZE_OF_BYTE = 1; + public static final int SIZE_OF_SHORT = 2; + public static final int SIZE_OF_INT = 4; + public static final int SIZE_OF_LONG = 8; + + public static final int MAGIC_NUMBER = 0xFD2FB528; + + public static final int MIN_WINDOW_LOG = 10; + public static final int MAX_WINDOW_LOG = 31; + + public static final int SIZE_OF_BLOCK_HEADER = 3; + + public static final int MIN_SEQUENCES_SIZE = 1; + public static final int MIN_BLOCK_SIZE = 1 // block type tag + + 1 // min size of raw or rle length header + + MIN_SEQUENCES_SIZE; + public static final int MAX_BLOCK_SIZE = 128 * 1024; + + public static final int REPEATED_OFFSET_COUNT = 3; + + // block types + public static final int RAW_BLOCK = 0; + public static final int RLE_BLOCK = 1; + public static final int COMPRESSED_BLOCK = 2; + + // sequence encoding types + public static final int SEQUENCE_ENCODING_BASIC = 0; + public static final int SEQUENCE_ENCODING_RLE = 1; + public static final int SEQUENCE_ENCODING_COMPRESSED = 2; + public static final int SEQUENCE_ENCODING_REPEAT = 3; + + public static final int MAX_LITERALS_LENGTH_SYMBOL = 35; + public static final int MAX_MATCH_LENGTH_SYMBOL = 52; + public static final int MAX_OFFSET_CODE_SYMBOL = 31; + public static final int DEFAULT_MAX_OFFSET_CODE_SYMBOL = 28; + + public static final int LITERAL_LENGTH_TABLE_LOG = 9; + public static final int MATCH_LENGTH_TABLE_LOG = 9; + public static final int OFFSET_TABLE_LOG = 8; + + // literal block types + public static final int RAW_LITERALS_BLOCK = 0; + public static final int RLE_LITERALS_BLOCK = 1; + public static final int COMPRESSED_LITERALS_BLOCK = 2; + public static final int TREELESS_LITERALS_BLOCK = 3; + + public static final int LONG_NUMBER_OF_SEQUENCES = 0x7F00; + + public static final int[] LITERALS_LENGTH_BITS = {0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 2, 2, 3, 3, + 4, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16}; + + public static final int[] MATCH_LENGTH_BITS = {0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 2, 2, 3, 3, + 4, 4, 5, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16}; + + private Constants() + { + } +} diff --git a/src/main/java/io/airlift/compress/zstd/DoubleFastBlockCompressor.java b/src/main/java/io/airlift/compress/zstd/DoubleFastBlockCompressor.java new file mode 100644 index 0000000..e02a0e2 --- /dev/null +++ b/src/main/java/io/airlift/compress/zstd/DoubleFastBlockCompressor.java @@ -0,0 +1,262 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.airlift.compress.zstd; + +import static io.airlift.compress.zstd.Constants.SIZE_OF_INT; +import static io.airlift.compress.zstd.Constants.SIZE_OF_LONG; +import static io.airlift.compress.zstd.UnsafeUtil.UNSAFE; + +class DoubleFastBlockCompressor + implements BlockCompressor +{ + private static final int MIN_MATCH = 3; + private static final int SEARCH_STRENGTH = 8; + private static final int REP_MOVE = Constants.REPEATED_OFFSET_COUNT - 1; + + public int compressBlock(Object inputBase, final long inputAddress, int inputSize, SequenceStore output, BlockCompressionState state, RepeatedOffsets offsets, CompressionParameters parameters) + { + int matchSearchLength = Math.max(parameters.getSearchLength(), 4); + + // Offsets in hash tables are relative to baseAddress. Hash tables can be reused across calls to compressBlock as long as + // baseAddress is kept constant. + // We don't want to generate sequences that point before the current window limit, so we "filter" out all results from looking up in the hash tables + // beyond that point. + final long baseAddress = state.getBaseAddress(); + final long windowBaseAddress = baseAddress + state.getWindowBaseOffset(); + + int[] longHashTable = state.hashTable; + int longHashBits = parameters.getHashLog(); + + int[] shortHashTable = state.chainTable; + int shortHashBits = parameters.getChainLog(); + + final long inputEnd = inputAddress + inputSize; + final long inputLimit = inputEnd - SIZE_OF_LONG; // We read a long at a time for computing the hashes + + long input = inputAddress; + long anchor = inputAddress; + + int offset1 = offsets.getOffset0(); + int offset2 = offsets.getOffset1(); + + int savedOffset = 0; + + if (input - windowBaseAddress == 0) { + input++; + } + int maxRep = (int) (input - windowBaseAddress); + + if (offset2 > maxRep) { + savedOffset = offset2; + offset2 = 0; + } + + if (offset1 > maxRep) { + savedOffset = offset1; + offset1 = 0; + } + + while (input < inputLimit) { // < instead of <=, because repcode check at (input+1) + int shortHash = hash(inputBase, input, shortHashBits, matchSearchLength); + long shortMatchAddress = baseAddress + shortHashTable[shortHash]; + + int longHash = hash8(UNSAFE.getLong(inputBase, input), longHashBits); + long longMatchAddress = baseAddress + longHashTable[longHash]; + + // update hash tables + int current = (int) (input - baseAddress); + longHashTable[longHash] = current; + shortHashTable[shortHash] = current; + + int matchLength; + int offset; + + if (offset1 > 0 && UNSAFE.getInt(inputBase, input + 1 - offset1) == UNSAFE.getInt(inputBase, input + 1)) { + // found a repeated sequence of at least 4 bytes, separated by offset1 + matchLength = count(inputBase, input + 1 + SIZE_OF_INT, inputEnd, input + 1 + SIZE_OF_INT - offset1) + SIZE_OF_INT; + input++; + output.storeSequence(inputBase, anchor, (int) (input - anchor), 0, matchLength - MIN_MATCH); + } + else { + // check prefix long match + if (longMatchAddress > windowBaseAddress && UNSAFE.getLong(inputBase, longMatchAddress) == UNSAFE.getLong(inputBase, input)) { + matchLength = count(inputBase, input + SIZE_OF_LONG, inputEnd, longMatchAddress + SIZE_OF_LONG) + SIZE_OF_LONG; + offset = (int) (input - longMatchAddress); + while (input > anchor && longMatchAddress > windowBaseAddress && UNSAFE.getByte(inputBase, input - 1) == UNSAFE.getByte(inputBase, longMatchAddress - 1)) { + input--; + longMatchAddress--; + matchLength++; + } + } + else { + // check prefix short match + if (shortMatchAddress > windowBaseAddress && 
UNSAFE.getInt(inputBase, shortMatchAddress) == UNSAFE.getInt(inputBase, input)) { + int nextOffsetHash = hash8(UNSAFE.getLong(inputBase, input + 1), longHashBits); + long nextOffsetMatchAddress = baseAddress + longHashTable[nextOffsetHash]; + longHashTable[nextOffsetHash] = current + 1; + + // check prefix long +1 match + if (nextOffsetMatchAddress > windowBaseAddress && UNSAFE.getLong(inputBase, nextOffsetMatchAddress) == UNSAFE.getLong(inputBase, input + 1)) { + matchLength = count(inputBase, input + 1 + SIZE_OF_LONG, inputEnd, nextOffsetMatchAddress + SIZE_OF_LONG) + SIZE_OF_LONG; + input++; + offset = (int) (input - nextOffsetMatchAddress); + while (input > anchor && nextOffsetMatchAddress > windowBaseAddress && UNSAFE.getByte(inputBase, input - 1) == UNSAFE.getByte(inputBase, nextOffsetMatchAddress - 1)) { + input--; + nextOffsetMatchAddress--; + matchLength++; + } + } + else { + // if no long +1 match, explore the short match we found + matchLength = count(inputBase, input + SIZE_OF_INT, inputEnd, shortMatchAddress + SIZE_OF_INT) + SIZE_OF_INT; + offset = (int) (input - shortMatchAddress); + while (input > anchor && shortMatchAddress > windowBaseAddress && UNSAFE.getByte(inputBase, input - 1) == UNSAFE.getByte(inputBase, shortMatchAddress - 1)) { + input--; + shortMatchAddress--; + matchLength++; + } + } + } + else { + input += ((input - anchor) >> SEARCH_STRENGTH) + 1; + continue; + } + } + + offset2 = offset1; + offset1 = offset; + + output.storeSequence(inputBase, anchor, (int) (input - anchor), offset + REP_MOVE, matchLength - MIN_MATCH); + } + + input += matchLength; + anchor = input; + + if (input <= inputLimit) { + // Fill Table + longHashTable[hash8(UNSAFE.getLong(inputBase, baseAddress + current + 2), longHashBits)] = current + 2; + shortHashTable[hash(inputBase, baseAddress + current + 2, shortHashBits, matchSearchLength)] = current + 2; + + longHashTable[hash8(UNSAFE.getLong(inputBase, input - 2), longHashBits)] = (int) (input - 2 - baseAddress); + shortHashTable[hash(inputBase, input - 2, shortHashBits, matchSearchLength)] = (int) (input - 2 - baseAddress); + + while (input <= inputLimit && offset2 > 0 && UNSAFE.getInt(inputBase, input) == UNSAFE.getInt(inputBase, input - offset2)) { + int repetitionLength = count(inputBase, input + SIZE_OF_INT, inputEnd, input + SIZE_OF_INT - offset2) + SIZE_OF_INT; + + // swap offset2 <=> offset1 + int temp = offset2; + offset2 = offset1; + offset1 = temp; + + shortHashTable[hash(inputBase, input, shortHashBits, matchSearchLength)] = (int) (input - baseAddress); + longHashTable[hash8(UNSAFE.getLong(inputBase, input), longHashBits)] = (int) (input - baseAddress); + + output.storeSequence(inputBase, anchor, 0, 0, repetitionLength - MIN_MATCH); + + input += repetitionLength; + anchor = input; + } + } + } + + // save reps for next block + offsets.saveOffset0(offset1 != 0 ? offset1 : savedOffset); + offsets.saveOffset1(offset2 != 0 ? 
offset2 : savedOffset); + + // return the last literals size + return (int) (inputEnd - anchor); + } + + // TODO: same as LZ4RawCompressor.count + + /** + * matchAddress must be < inputAddress + */ + public static int count(Object inputBase, final long inputAddress, final long inputLimit, final long matchAddress) + { + long input = inputAddress; + long match = matchAddress; + + int remaining = (int) (inputLimit - inputAddress); + + // first, compare long at a time + int count = 0; + while (count < remaining - (SIZE_OF_LONG - 1)) { + long diff = UNSAFE.getLong(inputBase, match) ^ UNSAFE.getLong(inputBase, input); + if (diff != 0) { + return count + (Long.numberOfTrailingZeros(diff) >> 3); + } + + count += SIZE_OF_LONG; + input += SIZE_OF_LONG; + match += SIZE_OF_LONG; + } + + while (count < remaining && UNSAFE.getByte(inputBase, match) == UNSAFE.getByte(inputBase, input)) { + count++; + input++; + match++; + } + + return count; + } + + private static int hash(Object inputBase, long inputAddress, int bits, int matchSearchLength) + { + switch (matchSearchLength) { + case 8: + return hash8(UNSAFE.getLong(inputBase, inputAddress), bits); + case 7: + return hash7(UNSAFE.getLong(inputBase, inputAddress), bits); + case 6: + return hash6(UNSAFE.getLong(inputBase, inputAddress), bits); + case 5: + return hash5(UNSAFE.getLong(inputBase, inputAddress), bits); + default: + return hash4(UNSAFE.getInt(inputBase, inputAddress), bits); + } + } + + private static final int PRIME_4_BYTES = 0x9E3779B1; + private static final long PRIME_5_BYTES = 0xCF1BBCDCBBL; + private static final long PRIME_6_BYTES = 0xCF1BBCDCBF9BL; + private static final long PRIME_7_BYTES = 0xCF1BBCDCBFA563L; + private static final long PRIME_8_BYTES = 0xCF1BBCDCB7A56463L; + + private static int hash4(int value, int bits) + { + return (value * PRIME_4_BYTES) >>> (Integer.SIZE - bits); + } + + private static int hash5(long value, int bits) + { + return (int) (((value << (Long.SIZE - 40)) * PRIME_5_BYTES) >>> (Long.SIZE - bits)); + } + + private static int hash6(long value, int bits) + { + return (int) (((value << (Long.SIZE - 48)) * PRIME_6_BYTES) >>> (Long.SIZE - bits)); + } + + private static int hash7(long value, int bits) + { + return (int) (((value << (Long.SIZE - 56)) * PRIME_7_BYTES) >>> (Long.SIZE - bits)); + } + + private static int hash8(long value, int bits) + { + return (int) ((value * PRIME_8_BYTES) >>> (Long.SIZE - bits)); + } +} diff --git a/src/main/java/io/airlift/compress/zstd/FiniteStateEntropy.java b/src/main/java/io/airlift/compress/zstd/FiniteStateEntropy.java new file mode 100644 index 0000000..59cc2f9 --- /dev/null +++ b/src/main/java/io/airlift/compress/zstd/FiniteStateEntropy.java @@ -0,0 +1,552 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.airlift.compress.zstd; + +import static io.airlift.compress.zstd.BitInputStream.peekBits; +import static io.airlift.compress.zstd.Constants.SIZE_OF_INT; +import static io.airlift.compress.zstd.Constants.SIZE_OF_LONG; +import static io.airlift.compress.zstd.Constants.SIZE_OF_SHORT; +import static io.airlift.compress.zstd.UnsafeUtil.UNSAFE; +import static io.airlift.compress.zstd.Util.checkArgument; +import static io.airlift.compress.zstd.Util.verify; +import static sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET; + +class FiniteStateEntropy +{ + public static final int MAX_SYMBOL = 255; + public static final int MAX_TABLE_LOG = 12; + public static final int MIN_TABLE_LOG = 5; + + private static final int[] REST_TO_BEAT = new int[] {0, 473195, 504333, 520860, 550000, 700000, 750000, 830000}; + private static final short UNASSIGNED = -2; + + private FiniteStateEntropy() + { + } + + public static int decompress(FiniteStateEntropy.Table table, final Object inputBase, final long inputAddress, final long inputLimit, byte[] outputBuffer) + { + final Object outputBase = outputBuffer; + final long outputAddress = ARRAY_BYTE_BASE_OFFSET; + final long outputLimit = outputAddress + outputBuffer.length; + + long input = inputAddress; + long output = outputAddress; + + // initialize bit stream + BitInputStream.Initializer initializer = new BitInputStream.Initializer(inputBase, input, inputLimit); + initializer.initialize(); + int bitsConsumed = initializer.getBitsConsumed(); + long currentAddress = initializer.getCurrentAddress(); + long bits = initializer.getBits(); + + // initialize first FSE stream + int state1 = (int) peekBits(bitsConsumed, bits, table.log2Size); + bitsConsumed += table.log2Size; + + BitInputStream.Loader loader = new BitInputStream.Loader(inputBase, input, currentAddress, bits, bitsConsumed); + loader.load(); + bits = loader.getBits(); + bitsConsumed = loader.getBitsConsumed(); + currentAddress = loader.getCurrentAddress(); + + // initialize second FSE stream + int state2 = (int) peekBits(bitsConsumed, bits, table.log2Size); + bitsConsumed += table.log2Size; + + loader = new BitInputStream.Loader(inputBase, input, currentAddress, bits, bitsConsumed); + loader.load(); + bits = loader.getBits(); + bitsConsumed = loader.getBitsConsumed(); + currentAddress = loader.getCurrentAddress(); + + byte[] symbols = table.symbol; + byte[] numbersOfBits = table.numberOfBits; + int[] newStates = table.newState; + + // decode 4 symbols per loop + while (output <= outputLimit - 4) { + int numberOfBits; + + UNSAFE.putByte(outputBase, output, symbols[state1]); + numberOfBits = numbersOfBits[state1]; + state1 = (int) (newStates[state1] + peekBits(bitsConsumed, bits, numberOfBits)); + bitsConsumed += numberOfBits; + + UNSAFE.putByte(outputBase, output + 1, symbols[state2]); + numberOfBits = numbersOfBits[state2]; + state2 = (int) (newStates[state2] + peekBits(bitsConsumed, bits, numberOfBits)); + bitsConsumed += numberOfBits; + + UNSAFE.putByte(outputBase, output + 2, symbols[state1]); + numberOfBits = numbersOfBits[state1]; + state1 = (int) (newStates[state1] + peekBits(bitsConsumed, bits, numberOfBits)); + bitsConsumed += numberOfBits; + + UNSAFE.putByte(outputBase, output + 3, symbols[state2]); + numberOfBits = numbersOfBits[state2]; + state2 = (int) (newStates[state2] + peekBits(bitsConsumed, bits, numberOfBits)); + bitsConsumed += numberOfBits; + + output += SIZE_OF_INT; + + loader = new BitInputStream.Loader(inputBase, input, currentAddress, bits, bitsConsumed); + boolean done = 
loader.load(); + bitsConsumed = loader.getBitsConsumed(); + bits = loader.getBits(); + currentAddress = loader.getCurrentAddress(); + if (done) { + break; + } + } + + while (true) { + verify(output <= outputLimit - 2, input, "Output buffer is too small"); + UNSAFE.putByte(outputBase, output++, symbols[state1]); + int numberOfBits = numbersOfBits[state1]; + state1 = (int) (newStates[state1] + peekBits(bitsConsumed, bits, numberOfBits)); + bitsConsumed += numberOfBits; + + loader = new BitInputStream.Loader(inputBase, input, currentAddress, bits, bitsConsumed); + loader.load(); + bitsConsumed = loader.getBitsConsumed(); + bits = loader.getBits(); + currentAddress = loader.getCurrentAddress(); + + if (loader.isOverflow()) { + UNSAFE.putByte(outputBase, output++, symbols[state2]); + break; + } + + verify(output <= outputLimit - 2, input, "Output buffer is too small"); + UNSAFE.putByte(outputBase, output++, symbols[state2]); + int numberOfBits1 = numbersOfBits[state2]; + state2 = (int) (newStates[state2] + peekBits(bitsConsumed, bits, numberOfBits1)); + bitsConsumed += numberOfBits1; + + loader = new BitInputStream.Loader(inputBase, input, currentAddress, bits, bitsConsumed); + loader.load(); + bitsConsumed = loader.getBitsConsumed(); + bits = loader.getBits(); + currentAddress = loader.getCurrentAddress(); + + if (loader.isOverflow()) { + UNSAFE.putByte(outputBase, output++, symbols[state1]); + break; + } + } + + return (int) (output - outputAddress); + } + + public static int compress(Object outputBase, long outputAddress, int outputSize, byte[] input, int inputSize, FseCompressionTable table) + { + return compress(outputBase, outputAddress, outputSize, input, ARRAY_BYTE_BASE_OFFSET, inputSize, table); + } + + public static int compress(Object outputBase, long outputAddress, int outputSize, Object inputBase, long inputAddress, int inputSize, FseCompressionTable table) + { + checkArgument(outputSize >= SIZE_OF_LONG, "Output buffer too small"); + + final long start = inputAddress; + final long inputLimit = start + inputSize; + + long input = inputLimit; + + if (inputSize <= 2) { + return 0; + } + + BitOutputStream stream = new BitOutputStream(outputBase, outputAddress, outputSize); + + int state1; + int state2; + + if ((inputSize & 1) != 0) { + input--; + state1 = table.begin(UNSAFE.getByte(inputBase, input)); + + input--; + state2 = table.begin(UNSAFE.getByte(inputBase, input)); + + input--; + state1 = table.encode(stream, state1, UNSAFE.getByte(inputBase, input)); + + stream.flush(); + } + else { + input--; + state2 = table.begin(UNSAFE.getByte(inputBase, input)); + + input--; + state1 = table.begin(UNSAFE.getByte(inputBase, input)); + } + + // join to mod 4 + inputSize -= 2; + + if ((SIZE_OF_LONG * 8 > MAX_TABLE_LOG * 4 + 7) && (inputSize & 2) != 0) { /* test bit 2 */ + input--; + state2 = table.encode(stream, state2, UNSAFE.getByte(inputBase, input)); + + input--; + state1 = table.encode(stream, state1, UNSAFE.getByte(inputBase, input)); + + stream.flush(); + } + + // 2 or 4 encoding per loop + while (input > start) { + input--; + state2 = table.encode(stream, state2, UNSAFE.getByte(inputBase, input)); + + if (SIZE_OF_LONG * 8 < MAX_TABLE_LOG * 2 + 7) { + stream.flush(); + } + + input--; + state1 = table.encode(stream, state1, UNSAFE.getByte(inputBase, input)); + + if (SIZE_OF_LONG * 8 > MAX_TABLE_LOG * 4 + 7) { + input--; + state2 = table.encode(stream, state2, UNSAFE.getByte(inputBase, input)); + + input--; + state1 = table.encode(stream, state1, UNSAFE.getByte(inputBase, input)); + } + + 
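+ // descriptive note (added annotation): with a 64-bit bit container and MAX_TABLE_LOG = 12, at most 4 * 12 + 7 = 55 bits accumulate per iteration, so this single end-of-loop flush is sufficient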
stream.flush(); + } + + table.finish(stream, state2); + table.finish(stream, state1); + + return stream.close(); + } + + public static int optimalTableLog(int maxTableLog, int inputSize, int maxSymbol) + { + if (inputSize <= 1) { + throw new IllegalArgumentException(); // not supported. Use RLE instead + } + + int result = maxTableLog; + + result = Math.min(result, Util.highestBit((inputSize - 1)) - 2); // we may be able to reduce accuracy if input is small + + // Need a minimum to safely represent all symbol values + result = Math.max(result, Util.minTableLog(inputSize, maxSymbol)); + + result = Math.max(result, MIN_TABLE_LOG); + result = Math.min(result, MAX_TABLE_LOG); + + return result; + } + + public static int normalizeCounts(short[] normalizedCounts, int tableLog, int[] counts, int total, int maxSymbol) + { + checkArgument(tableLog >= MIN_TABLE_LOG, "Unsupported FSE table size"); + checkArgument(tableLog <= MAX_TABLE_LOG, "FSE table size too large"); + checkArgument(tableLog >= Util.minTableLog(total, maxSymbol), "FSE table size too small"); + + long scale = 62 - tableLog; + long step = (1L << 62) / total; + long vstep = 1L << (scale - 20); + + int stillToDistribute = 1 << tableLog; + + int largest = 0; + short largestProbability = 0; + int lowThreshold = total >>> tableLog; + + for (int symbol = 0; symbol <= maxSymbol; symbol++) { + if (counts[symbol] == total) { + throw new IllegalArgumentException(); // TODO: should have been RLE-compressed by upper layers + } + if (counts[symbol] == 0) { + normalizedCounts[symbol] = 0; + continue; + } + if (counts[symbol] <= lowThreshold) { + normalizedCounts[symbol] = -1; + stillToDistribute--; + } + else { + short probability = (short) ((counts[symbol] * step) >>> scale); + if (probability < 8) { + long restToBeat = vstep * REST_TO_BEAT[probability]; + long delta = counts[symbol] * step - (((long) probability) << scale); + if (delta > restToBeat) { + probability++; + } + } + if (probability > largestProbability) { + largestProbability = probability; + largest = symbol; + } + normalizedCounts[symbol] = probability; + stillToDistribute -= probability; + } + } + + if (-stillToDistribute >= (normalizedCounts[largest] >>> 1)) { + // corner case. Need another normalization method + // TODO size_t const errorCode = FSE_normalizeM2(normalizedCounter, tableLog, count, total, maxSymbolValue); + normalizeCounts2(normalizedCounts, tableLog, counts, total, maxSymbol); + } + else { + normalizedCounts[largest] += (short) stillToDistribute; + } + + return tableLog; + } + + private static int normalizeCounts2(short[] normalizedCounts, int tableLog, int[] counts, int total, int maxSymbol) + { + int distributed = 0; + + int lowThreshold = total >>> tableLog; // minimum count below which frequency in the normalized table is "too small" (~ < 1) + int lowOne = (total * 3) >>> (tableLog + 1); // 1.5 * lowThreshold. 
If count in (lowThreshold, lowOne] => assign frequency 1 + + for (int i = 0; i <= maxSymbol; i++) { + if (counts[i] == 0) { + normalizedCounts[i] = 0; + } + else if (counts[i] <= lowThreshold) { + normalizedCounts[i] = -1; + distributed++; + total -= counts[i]; + } + else if (counts[i] <= lowOne) { + normalizedCounts[i] = 1; + distributed++; + total -= counts[i]; + } + else { + normalizedCounts[i] = UNASSIGNED; + } + } + + int normalizationFactor = 1 << tableLog; + int toDistribute = normalizationFactor - distributed; + + if ((total / toDistribute) > lowOne) { + /* risk of rounding to zero */ + lowOne = ((total * 3) / (toDistribute * 2)); + for (int i = 0; i <= maxSymbol; i++) { + if ((normalizedCounts[i] == UNASSIGNED) && (counts[i] <= lowOne)) { + normalizedCounts[i] = 1; + distributed++; + total -= counts[i]; + } + } + toDistribute = normalizationFactor - distributed; + } + + if (distributed == maxSymbol + 1) { + // all values are pretty poor; + // probably incompressible data (should have already been detected); + // find max, then give all remaining points to max + int maxValue = 0; + int maxCount = 0; + for (int i = 0; i <= maxSymbol; i++) { + if (counts[i] > maxCount) { + maxValue = i; + maxCount = counts[i]; + } + } + normalizedCounts[maxValue] += (short) toDistribute; + return 0; + } + + if (total == 0) { + // all of the symbols were low enough for the lowOne or lowThreshold + for (int i = 0; toDistribute > 0; i = (i + 1) % (maxSymbol + 1)) { + if (normalizedCounts[i] > 0) { + toDistribute--; + normalizedCounts[i]++; + } + } + return 0; + } + + // TODO: simplify/document this code + long vStepLog = 62 - tableLog; + long mid = (1L << (vStepLog - 1)) - 1; + long rStep = (((1L << vStepLog) * toDistribute) + mid) / total; /* scale on remaining */ + long tmpTotal = mid; + for (int i = 0; i <= maxSymbol; i++) { + if (normalizedCounts[i] == UNASSIGNED) { + long end = tmpTotal + (counts[i] * rStep); + int sStart = (int) (tmpTotal >>> vStepLog); + int sEnd = (int) (end >>> vStepLog); + int weight = sEnd - sStart; + + if (weight < 1) { + throw new AssertionError(); + } + normalizedCounts[i] = (short) weight; + tmpTotal = end; + } + } + + return 0; + } + + public static int writeNormalizedCounts(Object outputBase, long outputAddress, int outputSize, short[] normalizedCounts, int maxSymbol, int tableLog) + { + checkArgument(tableLog <= MAX_TABLE_LOG, "FSE table too large"); + checkArgument(tableLog >= MIN_TABLE_LOG, "FSE table too small"); + + long output = outputAddress; + long outputLimit = outputAddress + outputSize; + + int tableSize = 1 << tableLog; + + int bitCount = 0; + + // encode table size + int bitStream = (tableLog - MIN_TABLE_LOG); + bitCount += 4; + + int remaining = tableSize + 1; // +1 for extra accuracy + int threshold = tableSize; + int tableBitCount = tableLog + 1; + + int symbol = 0; + + boolean previousIs0 = false; + while (remaining > 1) { + if (previousIs0) { + // From RFC 8478, section 4.1.1: + // When a symbol has a probability of zero, it is followed by a 2-bit + // repeat flag. This repeat flag tells how many probabilities of zeroes + // follow the current one. It provides a number ranging from 0 to 3. + // If it is a 3, another 2-bit repeat flag follows, and so on. 
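+ // worked example (added annotation): a run of five zero-probability symbols after the current one is emitted as repeat flag 3 (three zeros, another flag follows) and then repeat flag 2 (two more zeros, stop)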
+ int start = symbol; + + // find run of symbols with count 0 + while (normalizedCounts[symbol] == 0) { + symbol++; + } + + // encode in batches of 8 repeat sequences in one shot (representing 24 symbols total) + while (symbol >= start + 24) { + start += 24; + bitStream |= (0b11_11_11_11_11_11_11_11 << bitCount); + checkArgument(output + SIZE_OF_SHORT <= outputLimit, "Output buffer too small"); + + UNSAFE.putShort(outputBase, output, (short) bitStream); + output += SIZE_OF_SHORT; + + // flush now, so no need to increase bitCount by 16 + bitStream >>>= Short.SIZE; + } + + // encode remaining in batches of 3 symbols + while (symbol >= start + 3) { + start += 3; + bitStream |= 0b11 << bitCount; + bitCount += 2; + } + + // encode tail + bitStream |= (symbol - start) << bitCount; + bitCount += 2; + + // flush bitstream if necessary + if (bitCount > 16) { + checkArgument(output + SIZE_OF_SHORT <= outputLimit, "Output buffer too small"); + + UNSAFE.putShort(outputBase, output, (short) bitStream); + output += SIZE_OF_SHORT; + + bitStream >>>= Short.SIZE; + bitCount -= Short.SIZE; + } + } + + int count = normalizedCounts[symbol++]; + int max = (2 * threshold - 1) - remaining; + remaining -= count < 0 ? -count : count; + count++; /* +1 for extra accuracy */ + if (count >= threshold) { + count += max; + } + bitStream |= count << bitCount; + bitCount += tableBitCount; + bitCount -= (count < max ? 1 : 0); + previousIs0 = (count == 1); + + if (remaining < 1) { + throw new AssertionError(); + } + + while (remaining < threshold) { + tableBitCount--; + threshold >>= 1; + } + + // flush bitstream if necessary + if (bitCount > 16) { + checkArgument(output + SIZE_OF_SHORT <= outputLimit, "Output buffer too small"); + + UNSAFE.putShort(outputBase, output, (short) bitStream); + output += SIZE_OF_SHORT; + + bitStream >>>= Short.SIZE; + bitCount -= Short.SIZE; + } + } + + // flush remaining bitstream + checkArgument(output + SIZE_OF_SHORT <= outputLimit, "Output buffer too small"); + UNSAFE.putShort(outputBase, output, (short) bitStream); + output += (bitCount + 7) / 8; + + checkArgument(symbol <= maxSymbol + 1, "Error"); // TODO + + return (int) (output - outputAddress); + } + + public static final class Table + { + int log2Size; + final int[] newState; + final byte[] symbol; + final byte[] numberOfBits; + + public Table(int log2Capacity) + { + int capacity = 1 << log2Capacity; + newState = new int[capacity]; + symbol = new byte[capacity]; + numberOfBits = new byte[capacity]; + } + + public Table(int log2Size, int[] newState, byte[] symbol, byte[] numberOfBits) + { + int size = 1 << log2Size; + if (newState.length != size || symbol.length != size || numberOfBits.length != size) { + throw new IllegalArgumentException("Expected arrays to match provided size"); + } + + this.log2Size = log2Size; + this.newState = newState; + this.symbol = symbol; + this.numberOfBits = numberOfBits; + } + } +} diff --git a/src/main/java/io/airlift/compress/zstd/FrameHeader.java b/src/main/java/io/airlift/compress/zstd/FrameHeader.java new file mode 100644 index 0000000..91aafdf --- /dev/null +++ b/src/main/java/io/airlift/compress/zstd/FrameHeader.java @@ -0,0 +1,87 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.airlift.compress.zstd; + +import java.util.Objects; +import java.util.StringJoiner; + +import static io.airlift.compress.zstd.Util.checkState; +import static java.lang.Math.min; +import static java.lang.Math.toIntExact; + +class FrameHeader +{ + final long headerSize; + final int windowSize; + final long contentSize; + final long dictionaryId; + final boolean hasChecksum; + + public FrameHeader(long headerSize, int windowSize, long contentSize, long dictionaryId, boolean hasChecksum) + { + checkState(windowSize >= 0 || contentSize >= 0, "Invalid frame header: contentSize or windowSize must be set"); + this.headerSize = headerSize; + this.windowSize = windowSize; + this.contentSize = contentSize; + this.dictionaryId = dictionaryId; + this.hasChecksum = hasChecksum; + } + + public int computeRequiredOutputBufferLookBackSize() + { + if (contentSize < 0) { + return windowSize; + } + if (windowSize < 0) { + return toIntExact(contentSize); + } + return toIntExact(min(windowSize, contentSize)); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + FrameHeader that = (FrameHeader) o; + return headerSize == that.headerSize && + windowSize == that.windowSize && + contentSize == that.contentSize && + dictionaryId == that.dictionaryId && + hasChecksum == that.hasChecksum; + } + + @Override + public int hashCode() + { + return Objects.hash(headerSize, windowSize, contentSize, dictionaryId, hasChecksum); + } + + @Override + public String toString() + { + return new StringJoiner(", ", FrameHeader.class.getSimpleName() + "[", "]") + .add("headerSize=" + headerSize) + .add("windowSize=" + windowSize) + .add("contentSize=" + contentSize) + .add("dictionaryId=" + dictionaryId) + .add("hasChecksum=" + hasChecksum) + .toString(); + } +} diff --git a/src/main/java/io/airlift/compress/zstd/FseCompressionTable.java b/src/main/java/io/airlift/compress/zstd/FseCompressionTable.java new file mode 100644 index 0000000..8272d35 --- /dev/null +++ b/src/main/java/io/airlift/compress/zstd/FseCompressionTable.java @@ -0,0 +1,159 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.airlift.compress.zstd; + +import static io.airlift.compress.zstd.FiniteStateEntropy.MAX_SYMBOL; + +class FseCompressionTable +{ + private final short[] nextState; + private final int[] deltaNumberOfBits; + private final int[] deltaFindState; + + private int log2Size; + + public FseCompressionTable(int maxTableLog, int maxSymbol) + { + nextState = new short[1 << maxTableLog]; + deltaNumberOfBits = new int[maxSymbol + 1]; + deltaFindState = new int[maxSymbol + 1]; + } + + public static FseCompressionTable newInstance(short[] normalizedCounts, int maxSymbol, int tableLog) + { + FseCompressionTable result = new FseCompressionTable(tableLog, maxSymbol); + result.initialize(normalizedCounts, maxSymbol, tableLog); + + return result; + } + + public void initializeRleTable(int symbol) + { + log2Size = 0; + + nextState[0] = 0; + nextState[1] = 0; + + deltaFindState[symbol] = 0; + deltaNumberOfBits[symbol] = 0; + } + + public void initialize(short[] normalizedCounts, int maxSymbol, int tableLog) + { + int tableSize = 1 << tableLog; + + byte[] table = new byte[tableSize]; // TODO: allocate in workspace + int highThreshold = tableSize - 1; + + // TODO: make sure FseCompressionTable has enough size + log2Size = tableLog; + + // For explanations on how to distribute symbol values over the table: + // http://fastcompression.blogspot.fr/2014/02/fse-distributing-symbol-values.html + + // symbol start positions + int[] cumulative = new int[MAX_SYMBOL + 2]; // TODO: allocate in workspace + cumulative[0] = 0; + for (int i = 1; i <= maxSymbol + 1; i++) { + if (normalizedCounts[i - 1] == -1) { // Low probability symbol + cumulative[i] = cumulative[i - 1] + 1; + table[highThreshold--] = (byte) (i - 1); + } + else { + cumulative[i] = cumulative[i - 1] + normalizedCounts[i - 1]; + } + } + cumulative[maxSymbol + 1] = tableSize + 1; + + // Spread symbols + int position = spreadSymbols(normalizedCounts, maxSymbol, tableSize, highThreshold, table); + + if (position != 0) { + throw new AssertionError("Spread symbols failed"); + } + + // Build table + for (int i = 0; i < tableSize; i++) { + byte symbol = table[i]; + nextState[cumulative[symbol]++] = (short) (tableSize + i); /* TableU16 : sorted by symbol order; gives next state value */ + } + + // Build symbol transformation table + int total = 0; + for (int symbol = 0; symbol <= maxSymbol; symbol++) { + switch (normalizedCounts[symbol]) { + case 0: + deltaNumberOfBits[symbol] = ((tableLog + 1) << 16) - tableSize; + break; + case -1: + case 1: + deltaNumberOfBits[symbol] = (tableLog << 16) - tableSize; + deltaFindState[symbol] = total - 1; + total++; + break; + default: + int maxBitsOut = tableLog - Util.highestBit(normalizedCounts[symbol] - 1); + int minStatePlus = normalizedCounts[symbol] << maxBitsOut; + deltaNumberOfBits[symbol] = (maxBitsOut << 16) - minStatePlus; + deltaFindState[symbol] = total - normalizedCounts[symbol]; + total += normalizedCounts[symbol]; + break; + } + } + } + + public int begin(byte symbol) + { + int outputBits = (deltaNumberOfBits[symbol] + (1 << 15)) >>> 16; + int base = ((outputBits << 16) - deltaNumberOfBits[symbol]) >>> outputBits; + return nextState[base + deltaFindState[symbol]]; + } + + public int encode(BitOutputStream stream, int state, int symbol) + { + int outputBits = (state + deltaNumberOfBits[symbol]) >>> 16; + stream.addBits(state, outputBits); + return nextState[(state >>> outputBits) + deltaFindState[symbol]]; + } + + public void finish(BitOutputStream stream, int state) + { + stream.addBits(state, 
log2Size); + stream.flush(); + } + + private static int calculateStep(int tableSize) + { + return (tableSize >>> 1) + (tableSize >>> 3) + 3; + } + + public static int spreadSymbols(short[] normalizedCounters, int maxSymbolValue, int tableSize, int highThreshold, byte[] symbols) + { + int mask = tableSize - 1; + int step = calculateStep(tableSize); + + int position = 0; + for (byte symbol = 0; symbol <= maxSymbolValue; symbol++) { + for (int i = 0; i < normalizedCounters[symbol]; i++) { + symbols[position] = symbol; + do { + position = (position + step) & mask; + } + while (position > highThreshold); + } + } + return position; + } +} diff --git a/src/main/java/io/airlift/compress/zstd/FseTableReader.java b/src/main/java/io/airlift/compress/zstd/FseTableReader.java new file mode 100644 index 0000000..b3b73bc --- /dev/null +++ b/src/main/java/io/airlift/compress/zstd/FseTableReader.java @@ -0,0 +1,170 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.airlift.compress.zstd; + +import static io.airlift.compress.zstd.FiniteStateEntropy.MAX_SYMBOL; +import static io.airlift.compress.zstd.FiniteStateEntropy.MIN_TABLE_LOG; +import static io.airlift.compress.zstd.UnsafeUtil.UNSAFE; +import static io.airlift.compress.zstd.Util.highestBit; +import static io.airlift.compress.zstd.Util.verify; + +class FseTableReader +{ + private final short[] nextSymbol = new short[MAX_SYMBOL + 1]; + private final short[] normalizedCounters = new short[MAX_SYMBOL + 1]; + + public int readFseTable(FiniteStateEntropy.Table table, Object inputBase, long inputAddress, long inputLimit, int maxSymbol, int maxTableLog) + { + // read table headers + long input = inputAddress; + verify(inputLimit - inputAddress >= 4, input, "Not enough input bytes"); + + int threshold; + int symbolNumber = 0; + boolean previousIsZero = false; + + int bitStream = UNSAFE.getInt(inputBase, input); + + int tableLog = (bitStream & 0xF) + MIN_TABLE_LOG; + + int numberOfBits = tableLog + 1; + bitStream >>>= 4; + int bitCount = 4; + + verify(tableLog <= maxTableLog, input, "FSE table size exceeds maximum allowed size"); + + int remaining = (1 << tableLog) + 1; + threshold = 1 << tableLog; + + while (remaining > 1 && symbolNumber <= maxSymbol) { + if (previousIsZero) { + int n0 = symbolNumber; + while ((bitStream & 0xFFFF) == 0xFFFF) { + n0 += 24; + if (input < inputLimit - 5) { + input += 2; + bitStream = (UNSAFE.getInt(inputBase, input) >>> bitCount); + } + else { + // end of bit stream + bitStream >>>= 16; + bitCount += 16; + } + } + while ((bitStream & 3) == 3) { + n0 += 3; + bitStream >>>= 2; + bitCount += 2; + } + n0 += bitStream & 3; + bitCount += 2; + + verify(n0 <= maxSymbol, input, "Symbol larger than max value"); + + while (symbolNumber < n0) { + normalizedCounters[symbolNumber++] = 0; + } + if ((input <= inputLimit - 7) || (input + (bitCount >>> 3) <= inputLimit - 4)) { + input += bitCount >>> 3; + bitCount &= 7; + bitStream = UNSAFE.getInt(inputBase, input) >>> bitCount; + } + else { + bitStream >>>= 2; + } + } + 
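+ // descriptive note (added annotation): decode the next normalized count per RFC 8478 section 4.1.1 — values below max = (2 * threshold - 1) - remaining fit in one bit fewer (numberOfBits - 1), larger values consume the full numberOfBits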
+ short max = (short) ((2 * threshold - 1) - remaining); + short count; + + if ((bitStream & (threshold - 1)) < max) { + count = (short) (bitStream & (threshold - 1)); + bitCount += numberOfBits - 1; + } + else { + count = (short) (bitStream & (2 * threshold - 1)); + if (count >= threshold) { + count -= max; + } + bitCount += numberOfBits; + } + count--; // extra accuracy + + remaining -= Math.abs(count); + normalizedCounters[symbolNumber++] = count; + previousIsZero = count == 0; + while (remaining < threshold) { + numberOfBits--; + threshold >>>= 1; + } + + if ((input <= inputLimit - 7) || (input + (bitCount >> 3) <= inputLimit - 4)) { + input += bitCount >>> 3; + bitCount &= 7; + } + else { + bitCount -= (int) (8 * (inputLimit - 4 - input)); + input = inputLimit - 4; + } + bitStream = UNSAFE.getInt(inputBase, input) >>> (bitCount & 31); + } + + verify(remaining == 1 && bitCount <= 32, input, "Input is corrupted"); + + maxSymbol = symbolNumber - 1; + verify(maxSymbol <= MAX_SYMBOL, input, "Max symbol value too large (too many symbols for FSE)"); + + input += (bitCount + 7) >> 3; + + // populate decoding table + int symbolCount = maxSymbol + 1; + int tableSize = 1 << tableLog; + int highThreshold = tableSize - 1; + + table.log2Size = tableLog; + + for (byte symbol = 0; symbol < symbolCount; symbol++) { + if (normalizedCounters[symbol] == -1) { + table.symbol[highThreshold--] = symbol; + nextSymbol[symbol] = 1; + } + else { + nextSymbol[symbol] = normalizedCounters[symbol]; + } + } + + int position = FseCompressionTable.spreadSymbols(normalizedCounters, maxSymbol, tableSize, highThreshold, table.symbol); + + // position must reach all cells once, otherwise normalizedCounter is incorrect + verify(position == 0, input, "Input is corrupted"); + + for (int i = 0; i < tableSize; i++) { + byte symbol = table.symbol[i]; + short nextState = nextSymbol[symbol]++; + table.numberOfBits[i] = (byte) (tableLog - highestBit(nextState)); + table.newState[i] = (short) ((nextState << table.numberOfBits[i]) - tableSize); + } + + return (int) (input - inputAddress); + } + + public static void initializeRleTable(FiniteStateEntropy.Table table, byte value) + { + table.log2Size = 0; + table.symbol[0] = value; + table.newState[0] = 0; + table.numberOfBits[0] = 0; + } +} diff --git a/src/main/java/io/airlift/compress/zstd/Histogram.java b/src/main/java/io/airlift/compress/zstd/Histogram.java new file mode 100644 index 0000000..579b632 --- /dev/null +++ b/src/main/java/io/airlift/compress/zstd/Histogram.java @@ -0,0 +1,66 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.airlift.compress.zstd; + +import java.util.Arrays; + +import static io.airlift.compress.zstd.UnsafeUtil.UNSAFE; +import static sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET; + +class Histogram +{ + private Histogram() + { + } + + // TODO: count parallel heuristic for large inputs + private static void count(Object inputBase, long inputAddress, int inputSize, int[] counts) + { + long input = inputAddress; + + Arrays.fill(counts, 0); + + for (int i = 0; i < inputSize; i++) { + int symbol = UNSAFE.getByte(inputBase, input) & 0xFF; + input++; + counts[symbol]++; + } + } + + public static int findLargestCount(int[] counts, int maxSymbol) + { + int max = 0; + for (int i = 0; i <= maxSymbol; i++) { + if (counts[i] > max) { + max = counts[i]; + } + } + + return max; + } + + public static int findMaxSymbol(int[] counts, int maxSymbol) + { + while (counts[maxSymbol] == 0) { + maxSymbol--; + } + return maxSymbol; + } + + public static void count(byte[] input, int length, int[] counts) + { + count(input, ARRAY_BYTE_BASE_OFFSET, length, counts); + } +} diff --git a/src/main/java/io/airlift/compress/zstd/Huffman.java b/src/main/java/io/airlift/compress/zstd/Huffman.java new file mode 100644 index 0000000..7df2143 --- /dev/null +++ b/src/main/java/io/airlift/compress/zstd/Huffman.java @@ -0,0 +1,324 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.airlift.compress.zstd; + +import java.util.Arrays; + +import static io.airlift.compress.zstd.BitInputStream.isEndOfStream; +import static io.airlift.compress.zstd.BitInputStream.peekBitsFast; +import static io.airlift.compress.zstd.Constants.SIZE_OF_INT; +import static io.airlift.compress.zstd.Constants.SIZE_OF_SHORT; +import static io.airlift.compress.zstd.UnsafeUtil.UNSAFE; +import static io.airlift.compress.zstd.Util.isPowerOf2; +import static io.airlift.compress.zstd.Util.verify; + +class Huffman +{ + public static final int MAX_SYMBOL = 255; + public static final int MAX_SYMBOL_COUNT = MAX_SYMBOL + 1; + + public static final int MAX_TABLE_LOG = 12; + public static final int MIN_TABLE_LOG = 5; + public static final int MAX_FSE_TABLE_LOG = 6; + + // stats + private final byte[] weights = new byte[MAX_SYMBOL + 1]; + private final int[] ranks = new int[MAX_TABLE_LOG + 1]; + + // table + private int tableLog = -1; + private final byte[] symbols = new byte[1 << MAX_TABLE_LOG]; + private final byte[] numbersOfBits = new byte[1 << MAX_TABLE_LOG]; + + private final FseTableReader reader = new FseTableReader(); + private final FiniteStateEntropy.Table fseTable = new FiniteStateEntropy.Table(MAX_FSE_TABLE_LOG); + + public boolean isLoaded() + { + return tableLog != -1; + } + + public int readTable(final Object inputBase, final long inputAddress, final int size) + { + Arrays.fill(ranks, 0); + long input = inputAddress; + + // read table header + verify(size > 0, input, "Not enough input bytes"); + int inputSize = UNSAFE.getByte(inputBase, input++) & 0xFF; + + int outputSize; + if (inputSize >= 128) { + outputSize = inputSize - 127; + inputSize = ((outputSize + 1) / 2); + + verify(inputSize + 1 <= size, input, "Not enough input bytes"); + verify(outputSize <= MAX_SYMBOL + 1, input, "Input is corrupted"); + + for (int i = 0; i < outputSize; i += 2) { + int value = UNSAFE.getByte(inputBase, input + i / 2) & 0xFF; + weights[i] = (byte) (value >>> 4); + weights[i + 1] = (byte) (value & 0b1111); + } + } + else { + verify(inputSize + 1 <= size, input, "Not enough input bytes"); + + long inputLimit = input + inputSize; + input += reader.readFseTable(fseTable, inputBase, input, inputLimit, FiniteStateEntropy.MAX_SYMBOL, MAX_FSE_TABLE_LOG); + outputSize = FiniteStateEntropy.decompress(fseTable, inputBase, input, inputLimit, weights); + } + + int totalWeight = 0; + for (int i = 0; i < outputSize; i++) { + ranks[weights[i]]++; + totalWeight += (1 << weights[i]) >> 1; // TODO same as 1 << (weights[n] - 1)? + } + verify(totalWeight != 0, input, "Input is corrupted"); + + tableLog = Util.highestBit(totalWeight) + 1; + verify(tableLog <= MAX_TABLE_LOG, input, "Input is corrupted"); + + int total = 1 << tableLog; + int rest = total - totalWeight; + verify(isPowerOf2(rest), input, "Input is corrupted"); + + int lastWeight = Util.highestBit(rest) + 1; + + weights[outputSize] = (byte) lastWeight; + ranks[lastWeight]++; + + int numberOfSymbols = outputSize + 1; + + // populate table + int nextRankStart = 0; + for (int i = 1; i < tableLog + 1; ++i) { + int current = nextRankStart; + nextRankStart += ranks[i] << (i - 1); + ranks[i] = current; + } + + for (int n = 0; n < numberOfSymbols; n++) { + int weight = weights[n]; + int length = (1 << weight) >> 1; // TODO: 1 << (weight - 1) ?? 
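+ // descriptive note (added annotation): a symbol with weight w occupies (1 << w) >> 1 consecutive table slots, each decoded by reading tableLog + 1 - w bits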
+ + byte symbol = (byte) n; + byte numberOfBits = (byte) (tableLog + 1 - weight); + for (int i = ranks[weight]; i < ranks[weight] + length; i++) { + symbols[i] = symbol; + numbersOfBits[i] = numberOfBits; + } + ranks[weight] += length; + } + + verify(ranks[1] >= 2 && (ranks[1] & 1) == 0, input, "Input is corrupted"); + + return inputSize + 1; + } + + public void decodeSingleStream(final Object inputBase, final long inputAddress, final long inputLimit, final Object outputBase, final long outputAddress, final long outputLimit) + { + BitInputStream.Initializer initializer = new BitInputStream.Initializer(inputBase, inputAddress, inputLimit); + initializer.initialize(); + + long bits = initializer.getBits(); + int bitsConsumed = initializer.getBitsConsumed(); + long currentAddress = initializer.getCurrentAddress(); + + int tableLog = this.tableLog; + byte[] numbersOfBits = this.numbersOfBits; + byte[] symbols = this.symbols; + + // 4 symbols at a time + long output = outputAddress; + long fastOutputLimit = outputLimit - 4; + while (output < fastOutputLimit) { + BitInputStream.Loader loader = new BitInputStream.Loader(inputBase, inputAddress, currentAddress, bits, bitsConsumed); + boolean done = loader.load(); + bits = loader.getBits(); + bitsConsumed = loader.getBitsConsumed(); + currentAddress = loader.getCurrentAddress(); + if (done) { + break; + } + + bitsConsumed = decodeSymbol(outputBase, output, bits, bitsConsumed, tableLog, numbersOfBits, symbols); + bitsConsumed = decodeSymbol(outputBase, output + 1, bits, bitsConsumed, tableLog, numbersOfBits, symbols); + bitsConsumed = decodeSymbol(outputBase, output + 2, bits, bitsConsumed, tableLog, numbersOfBits, symbols); + bitsConsumed = decodeSymbol(outputBase, output + 3, bits, bitsConsumed, tableLog, numbersOfBits, symbols); + output += SIZE_OF_INT; + } + + decodeTail(inputBase, inputAddress, currentAddress, bitsConsumed, bits, outputBase, output, outputLimit); + } + + public void decode4Streams(final Object inputBase, final long inputAddress, final long inputLimit, final Object outputBase, final long outputAddress, final long outputLimit) + { + verify(inputLimit - inputAddress >= 10, inputAddress, "Input is corrupted"); // jump table + 1 byte per stream + + long start1 = inputAddress + 3 * SIZE_OF_SHORT; // for the shorts we read below + long start2 = start1 + (UNSAFE.getShort(inputBase, inputAddress) & 0xFFFF); + long start3 = start2 + (UNSAFE.getShort(inputBase, inputAddress + 2) & 0xFFFF); + long start4 = start3 + (UNSAFE.getShort(inputBase, inputAddress + 4) & 0xFFFF); + + BitInputStream.Initializer initializer = new BitInputStream.Initializer(inputBase, start1, start2); + initializer.initialize(); + int stream1bitsConsumed = initializer.getBitsConsumed(); + long stream1currentAddress = initializer.getCurrentAddress(); + long stream1bits = initializer.getBits(); + + initializer = new BitInputStream.Initializer(inputBase, start2, start3); + initializer.initialize(); + int stream2bitsConsumed = initializer.getBitsConsumed(); + long stream2currentAddress = initializer.getCurrentAddress(); + long stream2bits = initializer.getBits(); + + initializer = new BitInputStream.Initializer(inputBase, start3, start4); + initializer.initialize(); + int stream3bitsConsumed = initializer.getBitsConsumed(); + long stream3currentAddress = initializer.getCurrentAddress(); + long stream3bits = initializer.getBits(); + + initializer = new BitInputStream.Initializer(inputBase, start4, inputLimit); + initializer.initialize(); + int stream4bitsConsumed = 
initializer.getBitsConsumed(); + long stream4currentAddress = initializer.getCurrentAddress(); + long stream4bits = initializer.getBits(); + + int segmentSize = (int) ((outputLimit - outputAddress + 3) / 4); + + long outputStart2 = outputAddress + segmentSize; + long outputStart3 = outputStart2 + segmentSize; + long outputStart4 = outputStart3 + segmentSize; + + long output1 = outputAddress; + long output2 = outputStart2; + long output3 = outputStart3; + long output4 = outputStart4; + + long fastOutputLimit = outputLimit - 7; + int tableLog = this.tableLog; + byte[] numbersOfBits = this.numbersOfBits; + byte[] symbols = this.symbols; + + while (output4 < fastOutputLimit) { + stream1bitsConsumed = decodeSymbol(outputBase, output1, stream1bits, stream1bitsConsumed, tableLog, numbersOfBits, symbols); + stream2bitsConsumed = decodeSymbol(outputBase, output2, stream2bits, stream2bitsConsumed, tableLog, numbersOfBits, symbols); + stream3bitsConsumed = decodeSymbol(outputBase, output3, stream3bits, stream3bitsConsumed, tableLog, numbersOfBits, symbols); + stream4bitsConsumed = decodeSymbol(outputBase, output4, stream4bits, stream4bitsConsumed, tableLog, numbersOfBits, symbols); + + stream1bitsConsumed = decodeSymbol(outputBase, output1 + 1, stream1bits, stream1bitsConsumed, tableLog, numbersOfBits, symbols); + stream2bitsConsumed = decodeSymbol(outputBase, output2 + 1, stream2bits, stream2bitsConsumed, tableLog, numbersOfBits, symbols); + stream3bitsConsumed = decodeSymbol(outputBase, output3 + 1, stream3bits, stream3bitsConsumed, tableLog, numbersOfBits, symbols); + stream4bitsConsumed = decodeSymbol(outputBase, output4 + 1, stream4bits, stream4bitsConsumed, tableLog, numbersOfBits, symbols); + + stream1bitsConsumed = decodeSymbol(outputBase, output1 + 2, stream1bits, stream1bitsConsumed, tableLog, numbersOfBits, symbols); + stream2bitsConsumed = decodeSymbol(outputBase, output2 + 2, stream2bits, stream2bitsConsumed, tableLog, numbersOfBits, symbols); + stream3bitsConsumed = decodeSymbol(outputBase, output3 + 2, stream3bits, stream3bitsConsumed, tableLog, numbersOfBits, symbols); + stream4bitsConsumed = decodeSymbol(outputBase, output4 + 2, stream4bits, stream4bitsConsumed, tableLog, numbersOfBits, symbols); + + stream1bitsConsumed = decodeSymbol(outputBase, output1 + 3, stream1bits, stream1bitsConsumed, tableLog, numbersOfBits, symbols); + stream2bitsConsumed = decodeSymbol(outputBase, output2 + 3, stream2bits, stream2bitsConsumed, tableLog, numbersOfBits, symbols); + stream3bitsConsumed = decodeSymbol(outputBase, output3 + 3, stream3bits, stream3bitsConsumed, tableLog, numbersOfBits, symbols); + stream4bitsConsumed = decodeSymbol(outputBase, output4 + 3, stream4bits, stream4bitsConsumed, tableLog, numbersOfBits, symbols); + + output1 += SIZE_OF_INT; + output2 += SIZE_OF_INT; + output3 += SIZE_OF_INT; + output4 += SIZE_OF_INT; + + BitInputStream.Loader loader = new BitInputStream.Loader(inputBase, start1, stream1currentAddress, stream1bits, stream1bitsConsumed); + boolean done = loader.load(); + stream1bitsConsumed = loader.getBitsConsumed(); + stream1bits = loader.getBits(); + stream1currentAddress = loader.getCurrentAddress(); + + if (done) { + break; + } + + loader = new BitInputStream.Loader(inputBase, start2, stream2currentAddress, stream2bits, stream2bitsConsumed); + done = loader.load(); + stream2bitsConsumed = loader.getBitsConsumed(); + stream2bits = loader.getBits(); + stream2currentAddress = loader.getCurrentAddress(); + + if (done) { + break; + } + + loader = new 
BitInputStream.Loader(inputBase, start3, stream3currentAddress, stream3bits, stream3bitsConsumed); + done = loader.load(); + stream3bitsConsumed = loader.getBitsConsumed(); + stream3bits = loader.getBits(); + stream3currentAddress = loader.getCurrentAddress(); + if (done) { + break; + } + + loader = new BitInputStream.Loader(inputBase, start4, stream4currentAddress, stream4bits, stream4bitsConsumed); + done = loader.load(); + stream4bitsConsumed = loader.getBitsConsumed(); + stream4bits = loader.getBits(); + stream4currentAddress = loader.getCurrentAddress(); + if (done) { + break; + } + } + + verify(output1 <= outputStart2 && output2 <= outputStart3 && output3 <= outputStart4, inputAddress, "Input is corrupted"); + + /// finish streams one by one + decodeTail(inputBase, start1, stream1currentAddress, stream1bitsConsumed, stream1bits, outputBase, output1, outputStart2); + decodeTail(inputBase, start2, stream2currentAddress, stream2bitsConsumed, stream2bits, outputBase, output2, outputStart3); + decodeTail(inputBase, start3, stream3currentAddress, stream3bitsConsumed, stream3bits, outputBase, output3, outputStart4); + decodeTail(inputBase, start4, stream4currentAddress, stream4bitsConsumed, stream4bits, outputBase, output4, outputLimit); + } + + private void decodeTail(final Object inputBase, final long startAddress, long currentAddress, int bitsConsumed, long bits, final Object outputBase, long outputAddress, final long outputLimit) + { + int tableLog = this.tableLog; + byte[] numbersOfBits = this.numbersOfBits; + byte[] symbols = this.symbols; + + // closer to the end + while (outputAddress < outputLimit) { + BitInputStream.Loader loader = new BitInputStream.Loader(inputBase, startAddress, currentAddress, bits, bitsConsumed); + boolean done = loader.load(); + bitsConsumed = loader.getBitsConsumed(); + bits = loader.getBits(); + currentAddress = loader.getCurrentAddress(); + if (done) { + break; + } + + bitsConsumed = decodeSymbol(outputBase, outputAddress++, bits, bitsConsumed, tableLog, numbersOfBits, symbols); + } + + // not more data in bit stream, so no need to reload + while (outputAddress < outputLimit) { + bitsConsumed = decodeSymbol(outputBase, outputAddress++, bits, bitsConsumed, tableLog, numbersOfBits, symbols); + } + + verify(isEndOfStream(startAddress, currentAddress, bitsConsumed), startAddress, "Bit stream is not fully consumed"); + } + + private static int decodeSymbol(Object outputBase, long outputAddress, long bitContainer, int bitsConsumed, int tableLog, byte[] numbersOfBits, byte[] symbols) + { + int value = (int) peekBitsFast(bitsConsumed, bitContainer, tableLog); + UNSAFE.putByte(outputBase, outputAddress, symbols[value]); + return bitsConsumed + numbersOfBits[value]; + } +} diff --git a/src/main/java/io/airlift/compress/zstd/HuffmanCompressionContext.java b/src/main/java/io/airlift/compress/zstd/HuffmanCompressionContext.java new file mode 100644 index 0000000..750f775 --- /dev/null +++ b/src/main/java/io/airlift/compress/zstd/HuffmanCompressionContext.java @@ -0,0 +1,62 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.airlift.compress.zstd; + +class HuffmanCompressionContext +{ + private final HuffmanTableWriterWorkspace tableWriterWorkspace = new HuffmanTableWriterWorkspace(); + private final HuffmanCompressionTableWorkspace compressionTableWorkspace = new HuffmanCompressionTableWorkspace(); + + private HuffmanCompressionTable previousTable = new HuffmanCompressionTable(Huffman.MAX_SYMBOL_COUNT); + private HuffmanCompressionTable temporaryTable = new HuffmanCompressionTable(Huffman.MAX_SYMBOL_COUNT); + + private HuffmanCompressionTable previousCandidate = previousTable; + private HuffmanCompressionTable temporaryCandidate = temporaryTable; + + public HuffmanCompressionTable getPreviousTable() + { + return previousTable; + } + + public HuffmanCompressionTable borrowTemporaryTable() + { + previousCandidate = temporaryTable; + temporaryCandidate = previousTable; + + return temporaryTable; + } + + public void discardTemporaryTable() + { + previousCandidate = previousTable; + temporaryCandidate = temporaryTable; + } + + public void saveChanges() + { + temporaryTable = temporaryCandidate; + previousTable = previousCandidate; + } + + public HuffmanCompressionTableWorkspace getCompressionTableWorkspace() + { + return compressionTableWorkspace; + } + + public HuffmanTableWriterWorkspace getTableWriterWorkspace() + { + return tableWriterWorkspace; + } +} diff --git a/src/main/java/io/airlift/compress/zstd/HuffmanCompressionTable.java b/src/main/java/io/airlift/compress/zstd/HuffmanCompressionTable.java new file mode 100644 index 0000000..5eef8fe --- /dev/null +++ b/src/main/java/io/airlift/compress/zstd/HuffmanCompressionTable.java @@ -0,0 +1,438 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.airlift.compress.zstd; + +import java.util.Arrays; + +import static io.airlift.compress.zstd.Huffman.MAX_FSE_TABLE_LOG; +import static io.airlift.compress.zstd.Huffman.MAX_SYMBOL; +import static io.airlift.compress.zstd.Huffman.MAX_SYMBOL_COUNT; +import static io.airlift.compress.zstd.Huffman.MAX_TABLE_LOG; +import static io.airlift.compress.zstd.Huffman.MIN_TABLE_LOG; +import static io.airlift.compress.zstd.UnsafeUtil.UNSAFE; +import static io.airlift.compress.zstd.Util.checkArgument; +import static io.airlift.compress.zstd.Util.minTableLog; + +final class HuffmanCompressionTable +{ + private final short[] values; + private final byte[] numberOfBits; + + private int maxSymbol; + private int maxNumberOfBits; + + public HuffmanCompressionTable(int capacity) + { + this.values = new short[capacity]; + this.numberOfBits = new byte[capacity]; + } + + public static int optimalNumberOfBits(int maxNumberOfBits, int inputSize, int maxSymbol) + { + if (inputSize <= 1) { + throw new IllegalArgumentException(); // not supported. 
Use RLE instead + } + + int result = maxNumberOfBits; + + result = Math.min(result, Util.highestBit((inputSize - 1)) - 1); // we may be able to reduce accuracy if input is small + + // Need a minimum to safely represent all symbol values + result = Math.max(result, minTableLog(inputSize, maxSymbol)); + + result = Math.max(result, MIN_TABLE_LOG); // absolute minimum for Huffman + result = Math.min(result, MAX_TABLE_LOG); // absolute maximum for Huffman + + return result; + } + + public void initialize(int[] counts, int maxSymbol, int maxNumberOfBits, HuffmanCompressionTableWorkspace workspace) + { + checkArgument(maxSymbol <= MAX_SYMBOL, "Max symbol value too large"); + + workspace.reset(); + + NodeTable nodeTable = workspace.nodeTable; + nodeTable.reset(); + + int lastNonZero = buildTree(counts, maxSymbol, nodeTable); + + // enforce max table log + maxNumberOfBits = setMaxHeight(nodeTable, lastNonZero, maxNumberOfBits, workspace); + checkArgument(maxNumberOfBits <= MAX_TABLE_LOG, "Max number of bits larger than max table size"); + + // populate table + int symbolCount = maxSymbol + 1; + for (int node = 0; node < symbolCount; node++) { + int symbol = nodeTable.symbols[node]; + numberOfBits[symbol] = nodeTable.numberOfBits[node]; + } + + short[] entriesPerRank = workspace.entriesPerRank; + short[] valuesPerRank = workspace.valuesPerRank; + + for (int n = 0; n <= lastNonZero; n++) { + entriesPerRank[nodeTable.numberOfBits[n]]++; + } + + // determine starting value per rank + short startingValue = 0; + for (int rank = maxNumberOfBits; rank > 0; rank--) { + valuesPerRank[rank] = startingValue; // get starting value within each rank + startingValue += entriesPerRank[rank]; + startingValue >>>= 1; + } + + for (int n = 0; n <= maxSymbol; n++) { + values[n] = valuesPerRank[numberOfBits[n]]++; // assign value within rank, symbol order + } + + this.maxSymbol = maxSymbol; + this.maxNumberOfBits = maxNumberOfBits; + } + + private int buildTree(int[] counts, int maxSymbol, NodeTable nodeTable) + { + // populate the leaves of the node table from the histogram of counts + // in descending order by count, ascending by symbol value. 
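+ // descriptive note (added annotation): leaves occupy slots [0, maxSymbol] of the node table; intermediate nodes are appended from slot MAX_SYMBOL_COUNT upward, with the root ending up at MAX_SYMBOL_COUNT + lastNonZero - 1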
+ short current = 0; + + for (int symbol = 0; symbol <= maxSymbol; symbol++) { + int count = counts[symbol]; + + // simple insertion sort + int position = current; + while (position > 1 && count > nodeTable.count[position - 1]) { + nodeTable.copyNode(position - 1, position); + position--; + } + + nodeTable.count[position] = count; + nodeTable.symbols[position] = symbol; + + current++; + } + + int lastNonZero = maxSymbol; + while (nodeTable.count[lastNonZero] == 0) { + lastNonZero--; + } + + // populate the non-leaf nodes + short nonLeafStart = MAX_SYMBOL_COUNT; + current = nonLeafStart; + + int currentLeaf = lastNonZero; + + // combine the two smallest leaves to create the first intermediate node + int currentNonLeaf = current; + nodeTable.count[current] = nodeTable.count[currentLeaf] + nodeTable.count[currentLeaf - 1]; + nodeTable.parents[currentLeaf] = current; + nodeTable.parents[currentLeaf - 1] = current; + current++; + currentLeaf -= 2; + + int root = MAX_SYMBOL_COUNT + lastNonZero - 1; + + // fill in sentinels + for (int n = current; n <= root; n++) { + nodeTable.count[n] = 1 << 30; + } + + // create parents + while (current <= root) { + int child1; + if (currentLeaf >= 0 && nodeTable.count[currentLeaf] < nodeTable.count[currentNonLeaf]) { + child1 = currentLeaf--; + } + else { + child1 = currentNonLeaf++; + } + + int child2; + if (currentLeaf >= 0 && nodeTable.count[currentLeaf] < nodeTable.count[currentNonLeaf]) { + child2 = currentLeaf--; + } + else { + child2 = currentNonLeaf++; + } + + nodeTable.count[current] = nodeTable.count[child1] + nodeTable.count[child2]; + nodeTable.parents[child1] = current; + nodeTable.parents[child2] = current; + current++; + } + + // distribute weights + nodeTable.numberOfBits[root] = 0; + for (int n = root - 1; n >= nonLeafStart; n--) { + short parent = nodeTable.parents[n]; + nodeTable.numberOfBits[n] = (byte) (nodeTable.numberOfBits[parent] + 1); + } + + for (int n = 0; n <= lastNonZero; n++) { + short parent = nodeTable.parents[n]; + nodeTable.numberOfBits[n] = (byte) (nodeTable.numberOfBits[parent] + 1); + } + + return lastNonZero; + } + + // TODO: consider encoding 2 symbols at a time + // - need a table with 256x256 entries with + // - the concatenated bits for the corresponding pair of symbols + // - the sum of bits for the corresponding pair of symbols + // - read 2 symbols at a time from the input + public void encodeSymbol(BitOutputStream output, int symbol) + { + output.addBitsFast(values[symbol], numberOfBits[symbol]); + } + + public int write(Object outputBase, long outputAddress, int outputSize, HuffmanTableWriterWorkspace workspace) + { + byte[] weights = workspace.weights; + + long output = outputAddress; + + int maxNumberOfBits = this.maxNumberOfBits; + int maxSymbol = this.maxSymbol; + + // convert to weights per RFC 8478 section 4.2.1 + for (int symbol = 0; symbol < maxSymbol; symbol++) { + int bits = numberOfBits[symbol]; + + if (bits == 0) { + weights[symbol] = 0; + } + else { + weights[symbol] = (byte) (maxNumberOfBits + 1 - bits); + } + } + + // attempt weights compression by FSE + int size = compressWeights(outputBase, output + 1, outputSize - 1, weights, maxSymbol, workspace); + + if (maxSymbol > 127 && size > 127) { + // This should never happen. Since weights are in the range [0, 12], they can be compressed optimally to ~3.7 bits per symbol for a uniform distribution. + // Since maxSymbol has to be <= MAX_SYMBOL (255), this is 119 bytes + FSE headers. 
+ throw new AssertionError(); + } + + if (size != 0 && size != 1 && size < maxSymbol / 2) { + // Go with FSE only if: + // - the weights are compressible + // - the compressed size is better than what we'd get with the raw encoding below + // - the compressed size is <= 127 bytes, which is the most that the encoding can hold for FSE-compressed weights (see RFC 8478 section 4.2.1.1). This is implied + // by the maxSymbol / 2 check, since maxSymbol must be <= 255 + UNSAFE.putByte(outputBase, output, (byte) size); + return size + 1; // header + size + } + else { + // Use raw encoding (4 bits per entry) + + // #entries = #symbols - 1 since last symbol is implicit. Thus, #entries = (maxSymbol + 1) - 1 = maxSymbol + int entryCount = maxSymbol; + + size = (entryCount + 1) / 2; // ceil(#entries / 2) + checkArgument(size + 1 /* header */ <= outputSize, "Output size too small"); // 2 entries per byte + + // encode number of symbols + // header = #entries + 127 per RFC + UNSAFE.putByte(outputBase, output, (byte) (127 + entryCount)); + output++; + + weights[maxSymbol] = 0; // last weight is implicit, so set to 0 so that it doesn't get encoded below + for (int i = 0; i < entryCount; i += 2) { + UNSAFE.putByte(outputBase, output, (byte) ((weights[i] << 4) + weights[i + 1])); + output++; + } + + return (int) (output - outputAddress); + } + } + + /** + * Can this table encode all symbols with non-zero count? + */ + public boolean isValid(int[] counts, int maxSymbol) + { + if (maxSymbol > this.maxSymbol) { + // some non-zero count symbols cannot be encoded by the current table + return false; + } + + for (int symbol = 0; symbol <= maxSymbol; ++symbol) { + if (counts[symbol] != 0 && numberOfBits[symbol] == 0) { + return false; + } + } + return true; + } + + public int estimateCompressedSize(int[] counts, int maxSymbol) + { + int numberOfBits = 0; + for (int symbol = 0; symbol <= Math.min(maxSymbol, this.maxSymbol); symbol++) { + numberOfBits += this.numberOfBits[symbol] * counts[symbol]; + } + + return numberOfBits >>> 3; // convert to bytes + } + + // http://fastcompression.blogspot.com/2015/07/huffman-revisited-part-3-depth-limited.html + private static int setMaxHeight(NodeTable nodeTable, int lastNonZero, int maxNumberOfBits, HuffmanCompressionTableWorkspace workspace) + { + int largestBits = nodeTable.numberOfBits[lastNonZero]; + + if (largestBits <= maxNumberOfBits) { + return largestBits; // early exit: no elements > maxNumberOfBits + } + + // there are several too large elements (at least >= 2) + int totalCost = 0; + int baseCost = 1 << (largestBits - maxNumberOfBits); + int n = lastNonZero; + + while (nodeTable.numberOfBits[n] > maxNumberOfBits) { + totalCost += baseCost - (1 << (largestBits - nodeTable.numberOfBits[n])); + nodeTable.numberOfBits[n] = (byte) maxNumberOfBits; + n--; + } // n stops at the first index (from the top) where nodeTable.numberOfBits[n] <= maxNumberOfBits + + while (nodeTable.numberOfBits[n] == maxNumberOfBits) { + n--; // n ends at index of smallest symbol using < maxNumberOfBits + } + + // renormalize totalCost + totalCost >>>= (largestBits - maxNumberOfBits); // note: totalCost is necessarily a multiple of baseCost + + // repay normalized cost + int noSymbol = 0xF0F0F0F0; + int[] rankLast = workspace.rankLast; + Arrays.fill(rankLast, noSymbol); + + // Get pos of last (smallest) symbol per rank + int currentNbBits = maxNumberOfBits; + for (int pos = n; pos >= 0; pos--) { + if (nodeTable.numberOfBits[pos] >= currentNbBits) { + continue; + } + currentNbBits = nodeTable.numberOfBits[pos]; // < 
maxNumberOfBits + rankLast[maxNumberOfBits - currentNbBits] = pos; + } + + while (totalCost > 0) { + int numberOfBitsToDecrease = Util.highestBit(totalCost) + 1; + for (; numberOfBitsToDecrease > 1; numberOfBitsToDecrease--) { + int highPosition = rankLast[numberOfBitsToDecrease]; + int lowPosition = rankLast[numberOfBitsToDecrease - 1]; + if (highPosition == noSymbol) { + continue; + } + if (lowPosition == noSymbol) { + break; + } + int highTotal = nodeTable.count[highPosition]; + int lowTotal = 2 * nodeTable.count[lowPosition]; + if (highTotal <= lowTotal) { + break; + } + } + + // only triggered when no more rank 1 symbol left => find closest one (note : there is necessarily at least one !) + // HUF_MAX_TABLELOG test just to please gcc 5+; but it should not be necessary + while ((numberOfBitsToDecrease <= MAX_TABLE_LOG) && (rankLast[numberOfBitsToDecrease] == noSymbol)) { + numberOfBitsToDecrease++; + } + totalCost -= 1 << (numberOfBitsToDecrease - 1); + if (rankLast[numberOfBitsToDecrease - 1] == noSymbol) { + rankLast[numberOfBitsToDecrease - 1] = rankLast[numberOfBitsToDecrease]; // this rank is no longer empty + } + nodeTable.numberOfBits[rankLast[numberOfBitsToDecrease]]++; + if (rankLast[numberOfBitsToDecrease] == 0) { /* special case, reached largest symbol */ + rankLast[numberOfBitsToDecrease] = noSymbol; + } + else { + rankLast[numberOfBitsToDecrease]--; + if (nodeTable.numberOfBits[rankLast[numberOfBitsToDecrease]] != maxNumberOfBits - numberOfBitsToDecrease) { + rankLast[numberOfBitsToDecrease] = noSymbol; // this rank is now empty + } + } + } + + while (totalCost < 0) { // Sometimes, cost correction overshoot + if (rankLast[1] == noSymbol) { /* special case : no rank 1 symbol (using maxNumberOfBits-1); let's create one from largest rank 0 (using maxNumberOfBits) */ + while (nodeTable.numberOfBits[n] == maxNumberOfBits) { + n--; + } + nodeTable.numberOfBits[n + 1]--; + rankLast[1] = n + 1; + totalCost++; + continue; + } + nodeTable.numberOfBits[rankLast[1] + 1]--; + rankLast[1]++; + totalCost++; + } + + return maxNumberOfBits; + } + + /** + * All elements within weightTable must be <= Huffman.MAX_TABLE_LOG + */ + private static int compressWeights(Object outputBase, long outputAddress, int outputSize, byte[] weights, int weightsLength, HuffmanTableWriterWorkspace workspace) + { + if (weightsLength <= 1) { + return 0; // Not compressible + } + + // Scan input and build symbol stats + int[] counts = workspace.counts; + Histogram.count(weights, weightsLength, counts); + int maxSymbol = Histogram.findMaxSymbol(counts, MAX_TABLE_LOG); + int maxCount = Histogram.findLargestCount(counts, maxSymbol); + + if (maxCount == weightsLength) { + return 1; // only a single symbol in source + } + if (maxCount == 1) { + return 0; // each symbol present maximum once => not compressible + } + + short[] normalizedCounts = workspace.normalizedCounts; + + int tableLog = FiniteStateEntropy.optimalTableLog(MAX_FSE_TABLE_LOG, weightsLength, maxSymbol); + FiniteStateEntropy.normalizeCounts(normalizedCounts, tableLog, counts, weightsLength, maxSymbol); + + long output = outputAddress; + long outputLimit = outputAddress + outputSize; + + // Write table description header + int headerSize = FiniteStateEntropy.writeNormalizedCounts(outputBase, output, outputSize, normalizedCounts, maxSymbol, tableLog); + output += headerSize; + + // Compress + FseCompressionTable compressionTable = workspace.fseTable; + compressionTable.initialize(normalizedCounts, maxSymbol, tableLog); + int compressedSize = 
FiniteStateEntropy.compress(outputBase, output, (int) (outputLimit - output), weights, weightsLength, compressionTable); + if (compressedSize == 0) { + return 0; + } + output += compressedSize; + + return (int) (output - outputAddress); + } +} diff --git a/src/main/java/io/airlift/compress/zstd/HuffmanCompressionTableWorkspace.java b/src/main/java/io/airlift/compress/zstd/HuffmanCompressionTableWorkspace.java new file mode 100644 index 0000000..1969020 --- /dev/null +++ b/src/main/java/io/airlift/compress/zstd/HuffmanCompressionTableWorkspace.java @@ -0,0 +1,34 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.airlift.compress.zstd; + +import java.util.Arrays; + +class HuffmanCompressionTableWorkspace +{ + public final NodeTable nodeTable = new NodeTable((2 * Huffman.MAX_SYMBOL_COUNT - 1)); // number of nodes in binary tree with MAX_SYMBOL_COUNT leaves + + public final short[] entriesPerRank = new short[Huffman.MAX_TABLE_LOG + 1]; + public final short[] valuesPerRank = new short[Huffman.MAX_TABLE_LOG + 1]; + + // for setMaxHeight + public final int[] rankLast = new int[Huffman.MAX_TABLE_LOG + 2]; + + public void reset() + { + Arrays.fill(entriesPerRank, (short) 0); + Arrays.fill(valuesPerRank, (short) 0); + } +} diff --git a/src/main/java/io/airlift/compress/zstd/HuffmanCompressor.java b/src/main/java/io/airlift/compress/zstd/HuffmanCompressor.java new file mode 100644 index 0000000..b1ab679 --- /dev/null +++ b/src/main/java/io/airlift/compress/zstd/HuffmanCompressor.java @@ -0,0 +1,137 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.airlift.compress.zstd; + +import static io.airlift.compress.zstd.Constants.SIZE_OF_LONG; +import static io.airlift.compress.zstd.Constants.SIZE_OF_SHORT; +import static io.airlift.compress.zstd.UnsafeUtil.UNSAFE; + +class HuffmanCompressor +{ + private HuffmanCompressor() + { + } + + public static int compress4streams(Object outputBase, long outputAddress, int outputSize, Object inputBase, long inputAddress, int inputSize, HuffmanCompressionTable table) + { + long input = inputAddress; + long inputLimit = inputAddress + inputSize; + long output = outputAddress; + long outputLimit = outputAddress + outputSize; + + int segmentSize = (inputSize + 3) / 4; + + if (outputSize < 6 /* jump table */ + 1 /* first stream */ + 1 /* second stream */ + 1 /* third stream */ + 8 /* 8 bytes minimum needed by the bitstream encoder */) { + return 0; // minimum space to compress successfully + } + + if (inputSize <= 6 + 1 + 1 + 1) { // jump table + one byte per stream + return 0; // no saving possible: input too small + } + + output += SIZE_OF_SHORT + SIZE_OF_SHORT + SIZE_OF_SHORT; // jump table + + int compressedSize; + + // first segment + compressedSize = compressSingleStream(outputBase, output, (int) (outputLimit - output), inputBase, input, segmentSize, table); + if (compressedSize == 0) { + return 0; + } + UNSAFE.putShort(outputBase, outputAddress, (short) compressedSize); + output += compressedSize; + input += segmentSize; + + // second segment + compressedSize = compressSingleStream(outputBase, output, (int) (outputLimit - output), inputBase, input, segmentSize, table); + if (compressedSize == 0) { + return 0; + } + UNSAFE.putShort(outputBase, outputAddress + SIZE_OF_SHORT, (short) compressedSize); + output += compressedSize; + input += segmentSize; + + // third segment + compressedSize = compressSingleStream(outputBase, output, (int) (outputLimit - output), inputBase, input, segmentSize, table); + if (compressedSize == 0) { + return 0; + } + UNSAFE.putShort(outputBase, outputAddress + SIZE_OF_SHORT + SIZE_OF_SHORT, (short) compressedSize); + output += compressedSize; + input += segmentSize; + + // fourth segment + compressedSize = compressSingleStream(outputBase, output, (int) (outputLimit - output), inputBase, input, (int) (inputLimit - input), table); + if (compressedSize == 0) { + return 0; + } + output += compressedSize; + + return (int) (output - outputAddress); + } + + public static int compressSingleStream(Object outputBase, long outputAddress, int outputSize, Object inputBase, long inputAddress, int inputSize, HuffmanCompressionTable table) + { + if (outputSize < SIZE_OF_LONG) { + return 0; + } + + BitOutputStream bitstream = new BitOutputStream(outputBase, outputAddress, outputSize); + long input = inputAddress; + + int n = inputSize & ~3; // join to mod 4 + + switch (inputSize & 3) { + case 3: + table.encodeSymbol(bitstream, UNSAFE.getByte(inputBase, input + n + 2) & 0xFF); + if (SIZE_OF_LONG * 8 < Huffman.MAX_TABLE_LOG * 4 + 7) { + bitstream.flush(); + } + // fall-through + case 2: + table.encodeSymbol(bitstream, UNSAFE.getByte(inputBase, input + n + 1) & 0xFF); + if (SIZE_OF_LONG * 8 < Huffman.MAX_TABLE_LOG * 2 + 7) { + bitstream.flush(); + } + // fall-through + case 1: + table.encodeSymbol(bitstream, UNSAFE.getByte(inputBase, input + n + 0) & 0xFF); + bitstream.flush(); + // fall-through + case 0: /* fall-through */ + default: + break; + } + + for (; n > 0; n -= 4) { // note: n & 3 == 0 at this stage + table.encodeSymbol(bitstream, UNSAFE.getByte(inputBase, input + n - 
1) & 0xFF); + if (SIZE_OF_LONG * 8 < Huffman.MAX_TABLE_LOG * 2 + 7) { + bitstream.flush(); + } + table.encodeSymbol(bitstream, UNSAFE.getByte(inputBase, input + n - 2) & 0xFF); + if (SIZE_OF_LONG * 8 < Huffman.MAX_TABLE_LOG * 4 + 7) { + bitstream.flush(); + } + table.encodeSymbol(bitstream, UNSAFE.getByte(inputBase, input + n - 3) & 0xFF); + if (SIZE_OF_LONG * 8 < Huffman.MAX_TABLE_LOG * 2 + 7) { + bitstream.flush(); + } + table.encodeSymbol(bitstream, UNSAFE.getByte(inputBase, input + n - 4) & 0xFF); + bitstream.flush(); + } + + return bitstream.close(); + } +} diff --git a/src/main/java/io/airlift/compress/zstd/HuffmanTableWriterWorkspace.java b/src/main/java/io/airlift/compress/zstd/HuffmanTableWriterWorkspace.java new file mode 100644 index 0000000..ebbef76 --- /dev/null +++ b/src/main/java/io/airlift/compress/zstd/HuffmanTableWriterWorkspace.java @@ -0,0 +1,30 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.airlift.compress.zstd; + +import static io.airlift.compress.zstd.Huffman.MAX_FSE_TABLE_LOG; +import static io.airlift.compress.zstd.Huffman.MAX_SYMBOL; +import static io.airlift.compress.zstd.Huffman.MAX_TABLE_LOG; + +class HuffmanTableWriterWorkspace +{ + // for encoding weights + public final byte[] weights = new byte[MAX_SYMBOL]; // the weight for the last symbol is implicit + + // for compressing weights + public final int[] counts = new int[MAX_TABLE_LOG + 1]; + public final short[] normalizedCounts = new short[MAX_TABLE_LOG + 1]; + public final FseCompressionTable fseTable = new FseCompressionTable(MAX_FSE_TABLE_LOG, MAX_TABLE_LOG); +} diff --git a/src/main/java/io/airlift/compress/zstd/NodeTable.java b/src/main/java/io/airlift/compress/zstd/NodeTable.java new file mode 100644 index 0000000..acdc3b1 --- /dev/null +++ b/src/main/java/io/airlift/compress/zstd/NodeTable.java @@ -0,0 +1,49 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.airlift.compress.zstd; + +import java.util.Arrays; + +class NodeTable +{ + int[] count; + short[] parents; + int[] symbols; + byte[] numberOfBits; + + public NodeTable(int size) + { + count = new int[size]; + parents = new short[size]; + symbols = new int[size]; + numberOfBits = new byte[size]; + } + + public void reset() + { + Arrays.fill(count, 0); + Arrays.fill(parents, (short) 0); + Arrays.fill(symbols, 0); + Arrays.fill(numberOfBits, (byte) 0); + } + + public void copyNode(int from, int to) + { + count[to] = count[from]; + parents[to] = parents[from]; + symbols[to] = symbols[from]; + numberOfBits[to] = numberOfBits[from]; + } +} diff --git a/src/main/java/io/airlift/compress/zstd/RepeatedOffsets.java b/src/main/java/io/airlift/compress/zstd/RepeatedOffsets.java new file mode 100644 index 0000000..969c038 --- /dev/null +++ b/src/main/java/io/airlift/compress/zstd/RepeatedOffsets.java @@ -0,0 +1,50 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.airlift.compress.zstd; + +class RepeatedOffsets +{ + private int offset0 = 1; + private int offset1 = 4; + + private int tempOffset0; + private int tempOffset1; + + public int getOffset0() + { + return offset0; + } + + public int getOffset1() + { + return offset1; + } + + public void saveOffset0(int offset) + { + tempOffset0 = offset; + } + + public void saveOffset1(int offset) + { + tempOffset1 = offset; + } + + public void commit() + { + offset0 = tempOffset0; + offset1 = tempOffset1; + } +} diff --git a/src/main/java/io/airlift/compress/zstd/SequenceEncoder.java b/src/main/java/io/airlift/compress/zstd/SequenceEncoder.java new file mode 100644 index 0000000..18b6601 --- /dev/null +++ b/src/main/java/io/airlift/compress/zstd/SequenceEncoder.java @@ -0,0 +1,352 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+ package io.airlift.compress.zstd;
+
+ import static io.airlift.compress.zstd.Constants.DEFAULT_MAX_OFFSET_CODE_SYMBOL;
+ import static io.airlift.compress.zstd.Constants.LITERALS_LENGTH_BITS;
+ import static io.airlift.compress.zstd.Constants.LITERAL_LENGTH_TABLE_LOG;
+ import static io.airlift.compress.zstd.Constants.LONG_NUMBER_OF_SEQUENCES;
+ import static io.airlift.compress.zstd.Constants.MATCH_LENGTH_BITS;
+ import static io.airlift.compress.zstd.Constants.MATCH_LENGTH_TABLE_LOG;
+ import static io.airlift.compress.zstd.Constants.MAX_LITERALS_LENGTH_SYMBOL;
+ import static io.airlift.compress.zstd.Constants.MAX_MATCH_LENGTH_SYMBOL;
+ import static io.airlift.compress.zstd.Constants.MAX_OFFSET_CODE_SYMBOL;
+ import static io.airlift.compress.zstd.Constants.OFFSET_TABLE_LOG;
+ import static io.airlift.compress.zstd.Constants.SEQUENCE_ENCODING_BASIC;
+ import static io.airlift.compress.zstd.Constants.SEQUENCE_ENCODING_COMPRESSED;
+ import static io.airlift.compress.zstd.Constants.SEQUENCE_ENCODING_RLE;
+ import static io.airlift.compress.zstd.Constants.SIZE_OF_SHORT;
+ import static io.airlift.compress.zstd.FiniteStateEntropy.optimalTableLog;
+ import static io.airlift.compress.zstd.UnsafeUtil.UNSAFE;
+ import static io.airlift.compress.zstd.Util.checkArgument;
+
+ class SequenceEncoder
+ {
+ private static final int DEFAULT_LITERAL_LENGTH_NORMALIZED_COUNTS_LOG = 6;
+ private static final short[] DEFAULT_LITERAL_LENGTH_NORMALIZED_COUNTS = {4, 3, 2, 2, 2, 2, 2, 2,
+ 2, 2, 2, 2, 2, 1, 1, 1,
+ 2, 2, 2, 2, 2, 2, 2, 2,
+ 2, 3, 2, 1, 1, 1, 1, 1,
+ -1, -1, -1, -1};
+
+ private static final int DEFAULT_MATCH_LENGTH_NORMALIZED_COUNTS_LOG = 6;
+ private static final short[] DEFAULT_MATCH_LENGTH_NORMALIZED_COUNTS = {1, 4, 3, 2, 2, 2, 2, 2,
+ 2, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, -1, -1,
+ -1, -1, -1, -1, -1};
+
+ private static final int DEFAULT_OFFSET_NORMALIZED_COUNTS_LOG = 5;
+ private static final short[] DEFAULT_OFFSET_NORMALIZED_COUNTS = {1, 1, 1, 1, 1, 1, 2, 2,
+ 2, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ -1, -1, -1, -1, -1};
+
+ private static final FseCompressionTable DEFAULT_LITERAL_LENGTHS_TABLE = FseCompressionTable.newInstance(DEFAULT_LITERAL_LENGTH_NORMALIZED_COUNTS, MAX_LITERALS_LENGTH_SYMBOL, DEFAULT_LITERAL_LENGTH_NORMALIZED_COUNTS_LOG);
+ private static final FseCompressionTable DEFAULT_MATCH_LENGTHS_TABLE = FseCompressionTable.newInstance(DEFAULT_MATCH_LENGTH_NORMALIZED_COUNTS, MAX_MATCH_LENGTH_SYMBOL, DEFAULT_MATCH_LENGTH_NORMALIZED_COUNTS_LOG);
+ private static final FseCompressionTable DEFAULT_OFFSETS_TABLE = FseCompressionTable.newInstance(DEFAULT_OFFSET_NORMALIZED_COUNTS, DEFAULT_MAX_OFFSET_CODE_SYMBOL, DEFAULT_OFFSET_NORMALIZED_COUNTS_LOG);
+
+ private SequenceEncoder()
+ {
+ }
+
+ public static int compressSequences(Object outputBase, final long outputAddress, int outputSize, SequenceStore sequences, CompressionParameters.Strategy strategy, SequenceEncodingContext workspace)
+ {
+ long output = outputAddress;
+ long outputLimit = outputAddress + outputSize;
+
+ checkArgument(outputLimit - output > 3 /* max sequence count size */ + 1 /* encoding type flags */, "Output buffer too small");
+
+ int sequenceCount = sequences.sequenceCount;
+ if (sequenceCount < 0x7F) {
+ UNSAFE.putByte(outputBase, output, (byte) sequenceCount);
+ output++;
+ }
+ else if (sequenceCount < LONG_NUMBER_OF_SEQUENCES) {
+ UNSAFE.putByte(outputBase, output, (byte) (sequenceCount >>> 8 | 0x80));
+ UNSAFE.putByte(outputBase, output + 1, (byte) sequenceCount);
+ output += SIZE_OF_SHORT;
+ }
+ else {
+ UNSAFE.putByte(outputBase, output, (byte) 0xFF);
+ output++;
+ UNSAFE.putShort(outputBase, output, (short) (sequenceCount - LONG_NUMBER_OF_SEQUENCES));
+ output += SIZE_OF_SHORT;
+ }
+
+ if (sequenceCount == 0) {
+ return (int) (output - outputAddress);
+ }
+
+ // flags for FSE encoding type
+ long headerAddress = output++;
+
+ int maxSymbol;
+ int largestCount;
+
+ // literal lengths
+ int[] counts = workspace.counts;
+ Histogram.count(sequences.literalLengthCodes, sequenceCount, workspace.counts);
+ maxSymbol = Histogram.findMaxSymbol(counts, MAX_LITERALS_LENGTH_SYMBOL);
+ largestCount = Histogram.findLargestCount(counts, maxSymbol);
+
+ int literalsLengthEncodingType = selectEncodingType(largestCount, sequenceCount, DEFAULT_LITERAL_LENGTH_NORMALIZED_COUNTS_LOG, true, strategy);
+
+ FseCompressionTable literalLengthTable;
+ switch (literalsLengthEncodingType) {
+ case SEQUENCE_ENCODING_RLE:
+ UNSAFE.putByte(outputBase, output, sequences.literalLengthCodes[0]);
+ output++;
+ workspace.literalLengthTable.initializeRleTable(maxSymbol);
+ literalLengthTable = workspace.literalLengthTable;
+ break;
+ case SEQUENCE_ENCODING_BASIC:
+ literalLengthTable = DEFAULT_LITERAL_LENGTHS_TABLE;
+ break;
+ case SEQUENCE_ENCODING_COMPRESSED:
+ output += buildCompressionTable(
+ workspace.literalLengthTable,
+ outputBase,
+ output,
+ outputLimit,
+ sequenceCount,
+ LITERAL_LENGTH_TABLE_LOG,
+ sequences.literalLengthCodes,
+ workspace.counts,
+ maxSymbol,
+ workspace.normalizedCounts);
+ literalLengthTable = workspace.literalLengthTable;
+ break;
+ default:
+ throw new UnsupportedOperationException("not yet implemented");
+ }
+
+ // offsets
+ Histogram.count(sequences.offsetCodes, sequenceCount, workspace.counts);
+ maxSymbol = Histogram.findMaxSymbol(counts, MAX_OFFSET_CODE_SYMBOL);
+ largestCount = Histogram.findLargestCount(counts, maxSymbol);
+
+ // We can only use the basic table if max <= DEFAULT_MAX_OFFSET_CODE_SYMBOL, otherwise the offsets are too large.
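+ // For illustration: DEFAULT_OFFSET_NORMALIZED_COUNTS above has 29 entries, so the predefined
+ // table only assigns probabilities to offset codes 0 through 28. Offset codes are computed as
+ // Util.highestBit of the stored offset value (see SequenceStore.generateCodes), so a block
+ // whose largest match offset is on the order of 2^29 or more produces a code the predefined
+ // table cannot encode, and the encoder must fall back to an RLE or freshly compressed table.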
+ boolean defaultAllowed = maxSymbol < DEFAULT_MAX_OFFSET_CODE_SYMBOL;
+
+ int offsetEncodingType = selectEncodingType(largestCount, sequenceCount, DEFAULT_OFFSET_NORMALIZED_COUNTS_LOG, defaultAllowed, strategy);
+
+ FseCompressionTable offsetCodeTable;
+ switch (offsetEncodingType) {
+ case SEQUENCE_ENCODING_RLE:
+ UNSAFE.putByte(outputBase, output, sequences.offsetCodes[0]);
+ output++;
+ workspace.offsetCodeTable.initializeRleTable(maxSymbol);
+ offsetCodeTable = workspace.offsetCodeTable;
+ break;
+ case SEQUENCE_ENCODING_BASIC:
+ offsetCodeTable = DEFAULT_OFFSETS_TABLE;
+ break;
+ case SEQUENCE_ENCODING_COMPRESSED:
+ output += buildCompressionTable(
+ workspace.offsetCodeTable,
+ outputBase,
+ output,
+ outputLimit,
+ sequenceCount,
+ OFFSET_TABLE_LOG,
+ sequences.offsetCodes,
+ workspace.counts,
+ maxSymbol,
+ workspace.normalizedCounts);
+ offsetCodeTable = workspace.offsetCodeTable;
+ break;
+ default:
+ throw new UnsupportedOperationException("not yet implemented");
+ }
+
+ // match lengths
+ Histogram.count(sequences.matchLengthCodes, sequenceCount, workspace.counts);
+ maxSymbol = Histogram.findMaxSymbol(counts, MAX_MATCH_LENGTH_SYMBOL);
+ largestCount = Histogram.findLargestCount(counts, maxSymbol);
+
+ int matchLengthEncodingType = selectEncodingType(largestCount, sequenceCount, DEFAULT_MATCH_LENGTH_NORMALIZED_COUNTS_LOG, true, strategy);
+
+ FseCompressionTable matchLengthTable;
+ switch (matchLengthEncodingType) {
+ case SEQUENCE_ENCODING_RLE:
+ UNSAFE.putByte(outputBase, output, sequences.matchLengthCodes[0]);
+ output++;
+ workspace.matchLengthTable.initializeRleTable(maxSymbol);
+ matchLengthTable = workspace.matchLengthTable;
+ break;
+ case SEQUENCE_ENCODING_BASIC:
+ matchLengthTable = DEFAULT_MATCH_LENGTHS_TABLE;
+ break;
+ case SEQUENCE_ENCODING_COMPRESSED:
+ output += buildCompressionTable(
+ workspace.matchLengthTable,
+ outputBase,
+ output,
+ outputLimit,
+ sequenceCount,
+ MATCH_LENGTH_TABLE_LOG,
+ sequences.matchLengthCodes,
+ workspace.counts,
+ maxSymbol,
+ workspace.normalizedCounts);
+ matchLengthTable = workspace.matchLengthTable;
+ break;
+ default:
+ throw new UnsupportedOperationException("not yet implemented");
+ }
+
+ // flags
+ UNSAFE.putByte(outputBase, headerAddress, (byte) ((literalsLengthEncodingType << 6) | (offsetEncodingType << 4) | (matchLengthEncodingType << 2)));
+
+ output += encodeSequences(outputBase, output, outputLimit, matchLengthTable, offsetCodeTable, literalLengthTable, sequences);
+
+ return (int) (output - outputAddress);
+ }
+
+ private static int buildCompressionTable(FseCompressionTable table, Object outputBase, long output, long outputLimit, int sequenceCount, int maxTableLog, byte[] codes, int[] counts, int maxSymbol, short[] normalizedCounts)
+ {
+ int tableLog = optimalTableLog(maxTableLog, sequenceCount, maxSymbol);
+
+ // this is a minor optimization. The last symbol is embedded in the initial FSE state, so it's not part of the bitstream. We can omit it from the
+ // statistics (but only if its count is > 1). This makes the statistics a tiny bit more accurate.
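+ // For example: if the last sequence's code occurs 5 times among 100 sequences, the histogram
+ // passed to normalizeCounts below becomes 4 occurrences over 99 sequences, which is exactly
+ // what the bitstream will contain. When the count is 1, decrementing would leave the symbol
+ // with zero probability even though the initial state still has to encode it, hence the
+ // count > 1 guard below.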
+ if (counts[codes[sequenceCount - 1]] > 1) { + counts[codes[sequenceCount - 1]]--; + sequenceCount--; + } + + FiniteStateEntropy.normalizeCounts(normalizedCounts, tableLog, counts, sequenceCount, maxSymbol); + table.initialize(normalizedCounts, maxSymbol, tableLog); + + return FiniteStateEntropy.writeNormalizedCounts(outputBase, output, (int) (outputLimit - output), normalizedCounts, maxSymbol, tableLog); // TODO: pass outputLimit directly + } + + private static int encodeSequences( + Object outputBase, + long output, + long outputLimit, + FseCompressionTable matchLengthTable, + FseCompressionTable offsetsTable, + FseCompressionTable literalLengthTable, + SequenceStore sequences) + { + byte[] matchLengthCodes = sequences.matchLengthCodes; + byte[] offsetCodes = sequences.offsetCodes; + byte[] literalLengthCodes = sequences.literalLengthCodes; + + BitOutputStream blockStream = new BitOutputStream(outputBase, output, (int) (outputLimit - output)); + + int sequenceCount = sequences.sequenceCount; + + // first symbols + int matchLengthState = matchLengthTable.begin(matchLengthCodes[sequenceCount - 1]); + int offsetState = offsetsTable.begin(offsetCodes[sequenceCount - 1]); + int literalLengthState = literalLengthTable.begin(literalLengthCodes[sequenceCount - 1]); + + blockStream.addBits(sequences.literalLengths[sequenceCount - 1], LITERALS_LENGTH_BITS[literalLengthCodes[sequenceCount - 1]]); + blockStream.addBits(sequences.matchLengths[sequenceCount - 1], MATCH_LENGTH_BITS[matchLengthCodes[sequenceCount - 1]]); + blockStream.addBits(sequences.offsets[sequenceCount - 1], offsetCodes[sequenceCount - 1]); + blockStream.flush(); + + if (sequenceCount >= 2) { + for (int n = sequenceCount - 2; n >= 0; n--) { + byte literalLengthCode = literalLengthCodes[n]; + byte offsetCode = offsetCodes[n]; + byte matchLengthCode = matchLengthCodes[n]; + + int literalLengthBits = LITERALS_LENGTH_BITS[literalLengthCode]; + int offsetBits = offsetCode; + int matchLengthBits = MATCH_LENGTH_BITS[matchLengthCode]; + + // (7) + offsetState = offsetsTable.encode(blockStream, offsetState, offsetCode); // 15 + matchLengthState = matchLengthTable.encode(blockStream, matchLengthState, matchLengthCode); // 24 + literalLengthState = literalLengthTable.encode(blockStream, literalLengthState, literalLengthCode); // 33 + + if ((offsetBits + matchLengthBits + literalLengthBits >= 64 - 7 - (LITERAL_LENGTH_TABLE_LOG + MATCH_LENGTH_TABLE_LOG + OFFSET_TABLE_LOG))) { + blockStream.flush(); /* (7)*/ + } + + blockStream.addBits(sequences.literalLengths[n], literalLengthBits); + if (((literalLengthBits + matchLengthBits) > 24)) { + blockStream.flush(); + } + + blockStream.addBits(sequences.matchLengths[n], matchLengthBits); + if ((offsetBits + matchLengthBits + literalLengthBits > 56)) { + blockStream.flush(); + } + + blockStream.addBits(sequences.offsets[n], offsetBits); // 31 + blockStream.flush(); // (7) + } + } + + matchLengthTable.finish(blockStream, matchLengthState); + offsetsTable.finish(blockStream, offsetState); + literalLengthTable.finish(blockStream, literalLengthState); + + int streamSize = blockStream.close(); + checkArgument(streamSize > 0, "Output buffer too small"); + + return streamSize; + } + + private static int selectEncodingType( + int largestCount, + int sequenceCount, + int defaultNormalizedCountsLog, + boolean isDefaultTableAllowed, + CompressionParameters.Strategy strategy) + { + if (largestCount == sequenceCount) { // => all entries are equal + if (isDefaultTableAllowed && sequenceCount <= 2) { + /* Prefer 
set_basic over set_rle when there are 2 or fewer symbols, + * since RLE uses 1 byte, but set_basic uses 5-6 bits per symbol. + * If basic encoding isn't possible, always choose RLE. + */ + return SEQUENCE_ENCODING_BASIC; + } + + return SEQUENCE_ENCODING_RLE; + } + + if (strategy.ordinal() < CompressionParameters.Strategy.LAZY.ordinal()) { // TODO: more robust check. Maybe encapsulate in strategy objects + if (isDefaultTableAllowed) { + int factor = 10 - strategy.ordinal(); // TODO more robust. Move it to strategy + int baseLog = 3; + long minNumberOfSequences = ((1L << defaultNormalizedCountsLog) * factor) >> baseLog; /* 28-36 for offset, 56-72 for lengths */ + + if ((sequenceCount < minNumberOfSequences) || (largestCount < (sequenceCount >> (defaultNormalizedCountsLog - 1)))) { + /* The format allows default tables to be repeated, but it isn't useful. + * When using simple heuristics to select encoding type, we don't want + * to confuse these tables with dictionaries. When running more careful + * analysis, we don't need to waste time checking both repeating tables + * and default tables. + */ + return SEQUENCE_ENCODING_BASIC; + } + } + } + else { + // TODO implement when other strategies are supported + throw new UnsupportedOperationException("not yet implemented"); + } + + return SEQUENCE_ENCODING_COMPRESSED; + } +} diff --git a/src/main/java/io/airlift/compress/zstd/SequenceEncodingContext.java b/src/main/java/io/airlift/compress/zstd/SequenceEncodingContext.java new file mode 100644 index 0000000..663650e --- /dev/null +++ b/src/main/java/io/airlift/compress/zstd/SequenceEncodingContext.java @@ -0,0 +1,31 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.airlift.compress.zstd; + +import static io.airlift.compress.zstd.Constants.MAX_LITERALS_LENGTH_SYMBOL; +import static io.airlift.compress.zstd.Constants.MAX_MATCH_LENGTH_SYMBOL; +import static io.airlift.compress.zstd.Constants.MAX_OFFSET_CODE_SYMBOL; + +class SequenceEncodingContext +{ + private static final int MAX_SEQUENCES = Math.max(MAX_LITERALS_LENGTH_SYMBOL, MAX_MATCH_LENGTH_SYMBOL); + + public final FseCompressionTable literalLengthTable = new FseCompressionTable(Constants.LITERAL_LENGTH_TABLE_LOG, MAX_LITERALS_LENGTH_SYMBOL); + public final FseCompressionTable offsetCodeTable = new FseCompressionTable(Constants.OFFSET_TABLE_LOG, MAX_OFFSET_CODE_SYMBOL); + public final FseCompressionTable matchLengthTable = new FseCompressionTable(Constants.MATCH_LENGTH_TABLE_LOG, MAX_MATCH_LENGTH_SYMBOL); + + public final int[] counts = new int[MAX_SEQUENCES + 1]; + public final short[] normalizedCounts = new short[MAX_SEQUENCES + 1]; +} diff --git a/src/main/java/io/airlift/compress/zstd/SequenceStore.java b/src/main/java/io/airlift/compress/zstd/SequenceStore.java new file mode 100644 index 0000000..39ee9ca --- /dev/null +++ b/src/main/java/io/airlift/compress/zstd/SequenceStore.java @@ -0,0 +1,161 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.airlift.compress.zstd; + +import static io.airlift.compress.zstd.Constants.SIZE_OF_LONG; +import static io.airlift.compress.zstd.UnsafeUtil.UNSAFE; +import static sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET; + +class SequenceStore +{ + public final byte[] literalsBuffer; + public int literalsLength; + + public final int[] offsets; + public final int[] literalLengths; + public final int[] matchLengths; + public int sequenceCount; + + public final byte[] literalLengthCodes; + public final byte[] matchLengthCodes; + public final byte[] offsetCodes; + + public LongField longLengthField; + public int longLengthPosition; + + public enum LongField + { + LITERAL, MATCH + } + + private static final byte[] LITERAL_LENGTH_CODE = {0, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, + 16, 16, 17, 17, 18, 18, 19, 19, + 20, 20, 20, 20, 21, 21, 21, 21, + 22, 22, 22, 22, 22, 22, 22, 22, + 23, 23, 23, 23, 23, 23, 23, 23, + 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24}; + + private static final byte[] MATCH_LENGTH_CODE = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + 32, 32, 33, 33, 34, 34, 35, 35, 36, 36, 36, 36, 37, 37, 37, 37, + 38, 38, 38, 38, 38, 38, 38, 38, 39, 39, 39, 39, 39, 39, 39, 39, + 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, + 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, + 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, + 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42}; + + public SequenceStore(int blockSize, int maxSequences) + { + offsets = new int[maxSequences]; + literalLengths = new int[maxSequences]; + matchLengths = new int[maxSequences]; + + literalLengthCodes = new byte[maxSequences]; + matchLengthCodes = new byte[maxSequences]; + offsetCodes = new byte[maxSequences]; + + literalsBuffer = new byte[blockSize]; + + reset(); + } + + public void appendLiterals(Object inputBase, long inputAddress, int inputSize) + { + UNSAFE.copyMemory(inputBase, inputAddress, literalsBuffer, ARRAY_BYTE_BASE_OFFSET + literalsLength, inputSize); + literalsLength += inputSize; + } + + public void storeSequence(Object literalBase, long literalAddress, int literalLength, int offsetCode, int matchLengthBase) + { + long input = literalAddress; + long output = ARRAY_BYTE_BASE_OFFSET + literalsLength; + int copied = 0; + do { + UNSAFE.putLong(literalsBuffer, output, UNSAFE.getLong(literalBase, input)); + input += SIZE_OF_LONG; + output += SIZE_OF_LONG; + copied += SIZE_OF_LONG; + } + while (copied < literalLength); + + literalsLength += literalLength; + + if (literalLength > 65535) { + longLengthField = LongField.LITERAL; + longLengthPosition = sequenceCount; + } + literalLengths[sequenceCount] = literalLength; + + offsets[sequenceCount] = offsetCode + 1; + + if (matchLengthBase > 65535) { + longLengthField = LongField.MATCH; + longLengthPosition = sequenceCount; + } + + matchLengths[sequenceCount] = matchLengthBase; + + sequenceCount++; + } + + public void reset() + { + literalsLength = 0; + sequenceCount = 0; + longLengthField = null; + } + + public void generateCodes() + { + for (int i = 0; i < sequenceCount; ++i) { + literalLengthCodes[i] = (byte) literalLengthToCode(literalLengths[i]); + offsetCodes[i] = (byte) Util.highestBit(offsets[i]); + matchLengthCodes[i] = (byte) matchLengthToCode(matchLengths[i]); + } + + if (longLengthField == LongField.LITERAL) { + literalLengthCodes[longLengthPosition] = 
Constants.MAX_LITERALS_LENGTH_SYMBOL; + } + if (longLengthField == LongField.MATCH) { + matchLengthCodes[longLengthPosition] = Constants.MAX_MATCH_LENGTH_SYMBOL; + } + } + + private static int literalLengthToCode(int literalLength) + { + if (literalLength >= 64) { + return Util.highestBit(literalLength) + 19; + } + else { + return LITERAL_LENGTH_CODE[literalLength]; + } + } + + /* + * matchLengthBase = matchLength - MINMATCH + * (that's how it's stored in SequenceStore) + */ + private static int matchLengthToCode(int matchLengthBase) + { + if (matchLengthBase >= 128) { + return Util.highestBit(matchLengthBase) + 36; + } + else { + return MATCH_LENGTH_CODE[matchLengthBase]; + } + } +} diff --git a/src/main/java/io/airlift/compress/zstd/UnsafeUtil.java b/src/main/java/io/airlift/compress/zstd/UnsafeUtil.java new file mode 100644 index 0000000..a083bc4 --- /dev/null +++ b/src/main/java/io/airlift/compress/zstd/UnsafeUtil.java @@ -0,0 +1,65 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.airlift.compress.zstd; + +import io.airlift.compress.IncompatibleJvmException; +import sun.misc.Unsafe; + +import java.lang.reflect.Field; +import java.nio.Buffer; +import java.nio.ByteOrder; + +import static java.lang.String.format; + +final class UnsafeUtil +{ + public static final Unsafe UNSAFE; + private static final long ADDRESS_OFFSET; + + private UnsafeUtil() {} + + static { + ByteOrder order = ByteOrder.nativeOrder(); + if (!order.equals(ByteOrder.LITTLE_ENDIAN)) { + throw new IncompatibleJvmException(format("Zstandard requires a little endian platform (found %s)", order)); + } + + try { + Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe"); + theUnsafe.setAccessible(true); + UNSAFE = (Unsafe) theUnsafe.get(null); + } + catch (Exception e) { + throw new IncompatibleJvmException("Zstandard requires access to sun.misc.Unsafe"); + } + + try { + // fetch the address field for direct buffers + ADDRESS_OFFSET = UNSAFE.objectFieldOffset(Buffer.class.getDeclaredField("address")); + } + catch (NoSuchFieldException e) { + throw new IncompatibleJvmException("Zstandard requires access to java.nio.Buffer raw address field"); + } + } + + public static long getAddress(Buffer buffer) + { + if (!buffer.isDirect()) { + throw new IllegalArgumentException("buffer is not direct"); + } + + return UNSAFE.getLong(buffer, ADDRESS_OFFSET); + } +} diff --git a/src/main/java/io/airlift/compress/zstd/Util.java b/src/main/java/io/airlift/compress/zstd/Util.java new file mode 100644 index 0000000..9cb1cd9 --- /dev/null +++ b/src/main/java/io/airlift/compress/zstd/Util.java @@ -0,0 +1,134 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.airlift.compress.zstd; + +import io.airlift.compress.MalformedInputException; + +import static io.airlift.compress.zstd.Constants.SIZE_OF_SHORT; +import static io.airlift.compress.zstd.UnsafeUtil.UNSAFE; + +final class Util +{ + private Util() + { + } + + public static int highestBit(int value) + { + return 31 - Integer.numberOfLeadingZeros(value); + } + + public static boolean isPowerOf2(int value) + { + return (value & (value - 1)) == 0; + } + + public static int mask(int bits) + { + return (1 << bits) - 1; + } + + public static void verify(boolean condition, long offset, String reason) + { + if (!condition) { + throw new MalformedInputException(offset, reason); + } + } + + public static void checkArgument(boolean condition, String reason) + { + if (!condition) { + throw new IllegalArgumentException(reason); + } + } + + static void checkPositionIndexes(int start, int end, int size) + { + // Carefully optimized for execution by hotspot (explanatory comment above) + if (start < 0 || end < start || end > size) { + throw new IndexOutOfBoundsException(badPositionIndexes(start, end, size)); + } + } + + private static String badPositionIndexes(int start, int end, int size) + { + if (start < 0 || start > size) { + return badPositionIndex(start, size, "start index"); + } + if (end < 0 || end > size) { + return badPositionIndex(end, size, "end index"); + } + // end < start + return String.format("end index (%s) must not be less than start index (%s)", end, start); + } + + private static String badPositionIndex(int index, int size, String desc) + { + if (index < 0) { + return String.format("%s (%s) must not be negative", desc, index); + } + else if (size < 0) { + throw new IllegalArgumentException("negative size: " + size); + } + else { // index > size + return String.format("%s (%s) must not be greater than size (%s)", desc, index, size); + } + } + + public static void checkState(boolean condition, String reason) + { + if (!condition) { + throw new IllegalStateException(reason); + } + } + + public static MalformedInputException fail(long offset, String reason) + { + throw new MalformedInputException(offset, reason); + } + + public static int cycleLog(int hashLog, CompressionParameters.Strategy strategy) + { + int cycleLog = hashLog; + if (strategy == CompressionParameters.Strategy.BTLAZY2 || strategy == CompressionParameters.Strategy.BTOPT || strategy == CompressionParameters.Strategy.BTULTRA) { + cycleLog = hashLog - 1; + } + return cycleLog; + } + + public static int get24BitLittleEndian(Object inputBase, long inputAddress) + { + return (UNSAFE.getShort(inputBase, inputAddress) & 0xFFFF) + | ((UNSAFE.getByte(inputBase, inputAddress + SIZE_OF_SHORT) & 0xFF) << Short.SIZE); + } + + public static void put24BitLittleEndian(Object outputBase, long outputAddress, int value) + { + UNSAFE.putShort(outputBase, outputAddress, (short) value); + UNSAFE.putByte(outputBase, outputAddress + SIZE_OF_SHORT, (byte) (value >>> Short.SIZE)); + } + + // provides the minimum logSize to safely represent a distribution + public static int minTableLog(int inputSize, int 
maxSymbolValue) + { + if (inputSize <= 1) { + throw new IllegalArgumentException("Not supported. RLE should be used instead"); // TODO + } + + int minBitsSrc = highestBit((inputSize - 1)) + 1; + int minBitsSymbols = highestBit(maxSymbolValue) + 2; + return Math.min(minBitsSrc, minBitsSymbols); + } +} diff --git a/src/main/java/io/airlift/compress/zstd/XxHash64.java b/src/main/java/io/airlift/compress/zstd/XxHash64.java new file mode 100644 index 0000000..d87e38f --- /dev/null +++ b/src/main/java/io/airlift/compress/zstd/XxHash64.java @@ -0,0 +1,290 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.airlift.compress.zstd; + +import java.io.IOException; +import java.io.InputStream; + +import static io.airlift.compress.zstd.Constants.SIZE_OF_LONG; +import static io.airlift.compress.zstd.UnsafeUtil.UNSAFE; +import static io.airlift.compress.zstd.Util.checkPositionIndexes; +import static java.lang.Long.rotateLeft; +import static java.lang.Math.min; +import static sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET; + +// forked from https://github.com/airlift/slice +final class XxHash64 +{ + private static final long PRIME64_1 = 0x9E3779B185EBCA87L; + private static final long PRIME64_2 = 0xC2B2AE3D27D4EB4FL; + private static final long PRIME64_3 = 0x165667B19E3779F9L; + private static final long PRIME64_4 = 0x85EBCA77C2b2AE63L; + private static final long PRIME64_5 = 0x27D4EB2F165667C5L; + + private static final long DEFAULT_SEED = 0; + + private final long seed; + + private static final long BUFFER_ADDRESS = ARRAY_BYTE_BASE_OFFSET; + private final byte[] buffer = new byte[32]; + private int bufferSize; + + private long bodyLength; + + private long v1; + private long v2; + private long v3; + private long v4; + + public XxHash64() + { + this(DEFAULT_SEED); + } + + private XxHash64(long seed) + { + this.seed = seed; + this.v1 = seed + PRIME64_1 + PRIME64_2; + this.v2 = seed + PRIME64_2; + this.v3 = seed; + this.v4 = seed - PRIME64_1; + } + + public XxHash64 update(byte[] data) + { + return update(data, 0, data.length); + } + + public XxHash64 update(byte[] data, int offset, int length) + { + checkPositionIndexes(offset, offset + length, data.length); + updateHash(data, ARRAY_BYTE_BASE_OFFSET + offset, length); + return this; + } + + public long hash() + { + long hash; + if (bodyLength > 0) { + hash = computeBody(); + } + else { + hash = seed + PRIME64_5; + } + + hash += bodyLength + bufferSize; + + return updateTail(hash, buffer, BUFFER_ADDRESS, 0, bufferSize); + } + + private long computeBody() + { + long hash = rotateLeft(v1, 1) + rotateLeft(v2, 7) + rotateLeft(v3, 12) + rotateLeft(v4, 18); + + hash = update(hash, v1); + hash = update(hash, v2); + hash = update(hash, v3); + hash = update(hash, v4); + + return hash; + } + + private void updateHash(Object base, long address, int length) + { + if (bufferSize > 0) { + int available = min(32 - bufferSize, length); + + UNSAFE.copyMemory(base, address, buffer, BUFFER_ADDRESS + bufferSize, available); + + bufferSize += available; + 
address += available; + length -= available; + + if (bufferSize == 32) { + updateBody(buffer, BUFFER_ADDRESS, bufferSize); + bufferSize = 0; + } + } + + if (length >= 32) { + int index = updateBody(base, address, length); + address += index; + length -= index; + } + + if (length > 0) { + UNSAFE.copyMemory(base, address, buffer, BUFFER_ADDRESS, length); + bufferSize = length; + } + } + + private int updateBody(Object base, long address, int length) + { + int remaining = length; + while (remaining >= 32) { + v1 = mix(v1, UNSAFE.getLong(base, address)); + v2 = mix(v2, UNSAFE.getLong(base, address + 8)); + v3 = mix(v3, UNSAFE.getLong(base, address + 16)); + v4 = mix(v4, UNSAFE.getLong(base, address + 24)); + + address += 32; + remaining -= 32; + } + + int index = length - remaining; + bodyLength += index; + return index; + } + + public static long hash(long value) + { + long hash = DEFAULT_SEED + PRIME64_5 + SIZE_OF_LONG; + hash = updateTail(hash, value); + hash = finalShuffle(hash); + + return hash; + } + + public static long hash(InputStream in) + throws IOException + { + return hash(DEFAULT_SEED, in); + } + + public static long hash(long seed, InputStream in) + throws IOException + { + XxHash64 hash = new XxHash64(seed); + byte[] buffer = new byte[8192]; + while (true) { + int length = in.read(buffer); + if (length == -1) { + break; + } + hash.update(buffer, 0, length); + } + return hash.hash(); + } + + public static long hash(long seed, Object base, long address, int length) + { + long hash; + if (length >= 32) { + hash = updateBody(seed, base, address, length); + } + else { + hash = seed + PRIME64_5; + } + + hash += length; + + // round to the closest 32 byte boundary + // this is the point up to which updateBody() processed + int index = length & 0xFFFFFFE0; + + return updateTail(hash, base, address, index, length); + } + + private static long updateTail(long hash, Object base, long address, int index, int length) + { + while (index <= length - 8) { + hash = updateTail(hash, UNSAFE.getLong(base, address + index)); + index += 8; + } + + if (index <= length - 4) { + hash = updateTail(hash, UNSAFE.getInt(base, address + index)); + index += 4; + } + + while (index < length) { + hash = updateTail(hash, UNSAFE.getByte(base, address + index)); + index++; + } + + hash = finalShuffle(hash); + + return hash; + } + + private static long updateBody(long seed, Object base, long address, int length) + { + long v1 = seed + PRIME64_1 + PRIME64_2; + long v2 = seed + PRIME64_2; + long v3 = seed; + long v4 = seed - PRIME64_1; + + int remaining = length; + while (remaining >= 32) { + v1 = mix(v1, UNSAFE.getLong(base, address)); + v2 = mix(v2, UNSAFE.getLong(base, address + 8)); + v3 = mix(v3, UNSAFE.getLong(base, address + 16)); + v4 = mix(v4, UNSAFE.getLong(base, address + 24)); + + address += 32; + remaining -= 32; + } + + long hash = rotateLeft(v1, 1) + rotateLeft(v2, 7) + rotateLeft(v3, 12) + rotateLeft(v4, 18); + + hash = update(hash, v1); + hash = update(hash, v2); + hash = update(hash, v3); + hash = update(hash, v4); + + return hash; + } + + private static long mix(long current, long value) + { + return rotateLeft(current + value * PRIME64_2, 31) * PRIME64_1; + } + + private static long update(long hash, long value) + { + long temp = hash ^ mix(0, value); + return temp * PRIME64_1 + PRIME64_4; + } + + private static long updateTail(long hash, long value) + { + long temp = hash ^ mix(0, value); + return rotateLeft(temp, 27) * PRIME64_1 + PRIME64_4; + } + + private static long updateTail(long hash, int 
value) + { + long unsigned = value & 0xFFFF_FFFFL; + long temp = hash ^ (unsigned * PRIME64_1); + return rotateLeft(temp, 23) * PRIME64_2 + PRIME64_3; + } + + private static long updateTail(long hash, byte value) + { + int unsigned = value & 0xFF; + long temp = hash ^ (unsigned * PRIME64_5); + return rotateLeft(temp, 11) * PRIME64_1; + } + + private static long finalShuffle(long hash) + { + hash ^= hash >>> 33; + hash *= PRIME64_2; + hash ^= hash >>> 29; + hash *= PRIME64_3; + hash ^= hash >>> 32; + return hash; + } +} diff --git a/src/main/java/io/airlift/compress/zstd/ZstdCompressor.java b/src/main/java/io/airlift/compress/zstd/ZstdCompressor.java new file mode 100644 index 0000000..757c572 --- /dev/null +++ b/src/main/java/io/airlift/compress/zstd/ZstdCompressor.java @@ -0,0 +1,127 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.airlift.compress.zstd; + +import io.airlift.compress.Compressor; + +import java.nio.Buffer; +import java.nio.ByteBuffer; + +import static io.airlift.compress.zstd.Constants.MAX_BLOCK_SIZE; +import static io.airlift.compress.zstd.UnsafeUtil.getAddress; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; +import static sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET; + +public class ZstdCompressor + implements Compressor +{ + @Override + public int maxCompressedLength(int uncompressedSize) + { + int result = uncompressedSize + (uncompressedSize >>> 8); + + if (uncompressedSize < MAX_BLOCK_SIZE) { + result += (MAX_BLOCK_SIZE - uncompressedSize) >>> 11; + } + + return result; + } + + @Override + public int compress(byte[] input, int inputOffset, int inputLength, byte[] output, int outputOffset, int maxOutputLength) + { + verifyRange(input, inputOffset, inputLength); + verifyRange(output, outputOffset, maxOutputLength); + + long inputAddress = ARRAY_BYTE_BASE_OFFSET + inputOffset; + long outputAddress = ARRAY_BYTE_BASE_OFFSET + outputOffset; + + return ZstdFrameCompressor.compress(input, inputAddress, inputAddress + inputLength, output, outputAddress, outputAddress + maxOutputLength, CompressionParameters.DEFAULT_COMPRESSION_LEVEL); + } + + @Override + public void compress(ByteBuffer inputBuffer, ByteBuffer outputBuffer) + { + // Java 9+ added an overload of various methods in ByteBuffer. When compiling with Java 11+ and targeting Java 8 bytecode + // the resulting signatures are invalid for JDK 8, so accesses below result in NoSuchMethodError. 
Accessing the + // methods through the interface class works around the problem + // Sidenote: we can't target "javac --release 8" because Unsafe is not available in the signature data for that profile + Buffer input = inputBuffer; + Buffer output = outputBuffer; + + Object inputBase; + long inputAddress; + long inputLimit; + if (input.isDirect()) { + inputBase = null; + long address = getAddress(input); + inputAddress = address + input.position(); + inputLimit = address + input.limit(); + } + else if (input.hasArray()) { + inputBase = input.array(); + inputAddress = ARRAY_BYTE_BASE_OFFSET + input.arrayOffset() + input.position(); + inputLimit = ARRAY_BYTE_BASE_OFFSET + input.arrayOffset() + input.limit(); + } + else { + throw new IllegalArgumentException("Unsupported input ByteBuffer implementation " + input.getClass().getName()); + } + + Object outputBase; + long outputAddress; + long outputLimit; + if (output.isDirect()) { + outputBase = null; + long address = getAddress(output); + outputAddress = address + output.position(); + outputLimit = address + output.limit(); + } + else if (output.hasArray()) { + outputBase = output.array(); + outputAddress = ARRAY_BYTE_BASE_OFFSET + output.arrayOffset() + output.position(); + outputLimit = ARRAY_BYTE_BASE_OFFSET + output.arrayOffset() + output.limit(); + } + else { + throw new IllegalArgumentException("Unsupported output ByteBuffer implementation " + output.getClass().getName()); + } + + // HACK: Assure JVM does not collect Slice wrappers while compressing, since the + // collection may trigger freeing of the underlying memory resulting in a segfault + // There is no other known way to signal to the JVM that an object should not be + // collected in a block, and technically, the JVM is allowed to eliminate these locks. + synchronized (input) { + synchronized (output) { + int written = ZstdFrameCompressor.compress( + inputBase, + inputAddress, + inputLimit, + outputBase, + outputAddress, + outputLimit, + CompressionParameters.DEFAULT_COMPRESSION_LEVEL); + output.position(output.position() + written); + } + } + } + + private static void verifyRange(byte[] data, int offset, int length) + { + requireNonNull(data, "data is null"); + if (offset < 0 || length < 0 || offset + length > data.length) { + throw new IllegalArgumentException(format("Invalid offset or length (%s, %s) in array of length %s", offset, length, data.length)); + } + } +} diff --git a/src/main/java/io/airlift/compress/zstd/ZstdDecompressor.java b/src/main/java/io/airlift/compress/zstd/ZstdDecompressor.java new file mode 100644 index 0000000..75a2485 --- /dev/null +++ b/src/main/java/io/airlift/compress/zstd/ZstdDecompressor.java @@ -0,0 +1,120 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.airlift.compress.zstd; + +import io.airlift.compress.Decompressor; +import io.airlift.compress.MalformedInputException; + +import java.nio.Buffer; +import java.nio.ByteBuffer; + +import static io.airlift.compress.zstd.UnsafeUtil.getAddress; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; +import static sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET; + +public class ZstdDecompressor + implements Decompressor +{ + private final ZstdFrameDecompressor decompressor = new ZstdFrameDecompressor(); + + @Override + public int decompress(byte[] input, int inputOffset, int inputLength, byte[] output, int outputOffset, int maxOutputLength) + throws MalformedInputException + { + verifyRange(input, inputOffset, inputLength); + verifyRange(output, outputOffset, maxOutputLength); + + long inputAddress = ARRAY_BYTE_BASE_OFFSET + inputOffset; + long inputLimit = inputAddress + inputLength; + long outputAddress = ARRAY_BYTE_BASE_OFFSET + outputOffset; + long outputLimit = outputAddress + maxOutputLength; + + return decompressor.decompress(input, inputAddress, inputLimit, output, outputAddress, outputLimit); + } + + @Override + public void decompress(ByteBuffer inputBuffer, ByteBuffer outputBuffer) + throws MalformedInputException + { + // Java 9+ added an overload of various methods in ByteBuffer. When compiling with Java 11+ and targeting Java 8 bytecode + // the resulting signatures are invalid for JDK 8, so accesses below result in NoSuchMethodError. Accessing the + // methods through the interface class works around the problem + // Sidenote: we can't target "javac --release 8" because Unsafe is not available in the signature data for that profile + Buffer input = inputBuffer; + Buffer output = outputBuffer; + + Object inputBase; + long inputAddress; + long inputLimit; + if (input.isDirect()) { + inputBase = null; + long address = getAddress(input); + inputAddress = address + input.position(); + inputLimit = address + input.limit(); + } + else if (input.hasArray()) { + inputBase = input.array(); + inputAddress = ARRAY_BYTE_BASE_OFFSET + input.arrayOffset() + input.position(); + inputLimit = ARRAY_BYTE_BASE_OFFSET + input.arrayOffset() + input.limit(); + } + else { + throw new IllegalArgumentException("Unsupported input ByteBuffer implementation " + input.getClass().getName()); + } + + Object outputBase; + long outputAddress; + long outputLimit; + if (output.isDirect()) { + outputBase = null; + long address = getAddress(output); + outputAddress = address + output.position(); + outputLimit = address + output.limit(); + } + else if (output.hasArray()) { + outputBase = output.array(); + outputAddress = ARRAY_BYTE_BASE_OFFSET + output.arrayOffset() + output.position(); + outputLimit = ARRAY_BYTE_BASE_OFFSET + output.arrayOffset() + output.limit(); + } + else { + throw new IllegalArgumentException("Unsupported output ByteBuffer implementation " + output.getClass().getName()); + } + + // HACK: Assure JVM does not collect Slice wrappers while decompressing, since the + // collection may trigger freeing of the underlying memory resulting in a segfault + // There is no other known way to signal to the JVM that an object should not be + // collected in a block, and technically, the JVM is allowed to eliminate these locks. 
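+ // (On Java 9+, java.lang.ref.Reference.reachabilityFence(input) and reachabilityFence(output)
+ // after the last use would express this intent directly; the synchronized blocks below are
+ // the closest Java 8 equivalent, at the cost of briefly taking the buffers' monitors.)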
+ synchronized (input) { + synchronized (output) { + int written = decompressor.decompress(inputBase, inputAddress, inputLimit, outputBase, outputAddress, outputLimit); + output.position(output.position() + written); + } + } + } + + public static long getDecompressedSize(byte[] input, int offset, int length) + { + int baseAddress = ARRAY_BYTE_BASE_OFFSET + offset; + return ZstdFrameDecompressor.getDecompressedSize(input, baseAddress, baseAddress + length); + } + + private static void verifyRange(byte[] data, int offset, int length) + { + requireNonNull(data, "data is null"); + if (offset < 0 || length < 0 || offset + length > data.length) { + throw new IllegalArgumentException(format("Invalid offset or length (%s, %s) in array of length %s", offset, length, data.length)); + } + } +} diff --git a/src/main/java/io/airlift/compress/zstd/ZstdFrameCompressor.java b/src/main/java/io/airlift/compress/zstd/ZstdFrameCompressor.java new file mode 100644 index 0000000..691037b --- /dev/null +++ b/src/main/java/io/airlift/compress/zstd/ZstdFrameCompressor.java @@ -0,0 +1,448 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.airlift.compress.zstd; + +import static io.airlift.compress.zstd.Constants.COMPRESSED_BLOCK; +import static io.airlift.compress.zstd.Constants.COMPRESSED_LITERALS_BLOCK; +import static io.airlift.compress.zstd.Constants.MAGIC_NUMBER; +import static io.airlift.compress.zstd.Constants.MIN_BLOCK_SIZE; +import static io.airlift.compress.zstd.Constants.MIN_WINDOW_LOG; +import static io.airlift.compress.zstd.Constants.RAW_BLOCK; +import static io.airlift.compress.zstd.Constants.RAW_LITERALS_BLOCK; +import static io.airlift.compress.zstd.Constants.RLE_LITERALS_BLOCK; +import static io.airlift.compress.zstd.Constants.SIZE_OF_BLOCK_HEADER; +import static io.airlift.compress.zstd.Constants.SIZE_OF_INT; +import static io.airlift.compress.zstd.Constants.SIZE_OF_SHORT; +import static io.airlift.compress.zstd.Constants.TREELESS_LITERALS_BLOCK; +import static io.airlift.compress.zstd.Huffman.MAX_SYMBOL; +import static io.airlift.compress.zstd.Huffman.MAX_SYMBOL_COUNT; +import static io.airlift.compress.zstd.UnsafeUtil.UNSAFE; +import static io.airlift.compress.zstd.Util.checkArgument; +import static io.airlift.compress.zstd.Util.put24BitLittleEndian; +import static sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET; + +class ZstdFrameCompressor +{ + static final int MAX_FRAME_HEADER_SIZE = 14; + + private static final int CHECKSUM_FLAG = 0b100; + private static final int SINGLE_SEGMENT_FLAG = 0b100000; + + private static final int MINIMUM_LITERALS_SIZE = 63; + + // the maximum table log allowed for literal encoding per RFC 8478, section 4.2.1 + private static final int MAX_HUFFMAN_TABLE_LOG = 11; + + private ZstdFrameCompressor() + { + } + + // visible for testing + static int writeMagic(final Object outputBase, final long outputAddress, final long outputLimit) + { + checkArgument(outputLimit - outputAddress >= SIZE_OF_INT, "Output buffer too small"); + + UNSAFE.putInt(outputBase, 
outputAddress, MAGIC_NUMBER); + return SIZE_OF_INT; + } + + // visible for testing + static int writeFrameHeader(final Object outputBase, final long outputAddress, final long outputLimit, int inputSize, int windowSize) + { + checkArgument(outputLimit - outputAddress >= MAX_FRAME_HEADER_SIZE, "Output buffer too small"); + + long output = outputAddress; + + int contentSizeDescriptor = 0; + if (inputSize != -1) { + contentSizeDescriptor = (inputSize >= 256 ? 1 : 0) + (inputSize >= 65536 + 256 ? 1 : 0); + } + int frameHeaderDescriptor = (contentSizeDescriptor << 6) | CHECKSUM_FLAG; // dictionary ID missing + + boolean singleSegment = inputSize != -1 && windowSize >= inputSize; + if (singleSegment) { + frameHeaderDescriptor |= SINGLE_SEGMENT_FLAG; + } + + UNSAFE.putByte(outputBase, output, (byte) frameHeaderDescriptor); + output++; + + if (!singleSegment) { + int base = Integer.highestOneBit(windowSize); + + int exponent = 32 - Integer.numberOfLeadingZeros(base) - 1; + if (exponent < MIN_WINDOW_LOG) { + throw new IllegalArgumentException("Minimum window size is " + (1 << MIN_WINDOW_LOG)); + } + + int remainder = windowSize - base; + if (remainder % (base / 8) != 0) { + throw new IllegalArgumentException("Window size of magnitude 2^" + exponent + " must be multiple of " + (base / 8)); + } + + // mantissa is guaranteed to be between 0-7 + int mantissa = remainder / (base / 8); + int encoded = ((exponent - MIN_WINDOW_LOG) << 3) | mantissa; + + UNSAFE.putByte(outputBase, output, (byte) encoded); + output++; + } + + switch (contentSizeDescriptor) { + case 0: + if (singleSegment) { + UNSAFE.putByte(outputBase, output++, (byte) inputSize); + } + break; + case 1: + UNSAFE.putShort(outputBase, output, (short) (inputSize - 256)); + output += SIZE_OF_SHORT; + break; + case 2: + UNSAFE.putInt(outputBase, output, inputSize); + output += SIZE_OF_INT; + break; + default: + throw new AssertionError(); + } + + return (int) (output - outputAddress); + } + + // visible for testing + static int writeChecksum(Object outputBase, long outputAddress, long outputLimit, Object inputBase, long inputAddress, long inputLimit) + { + checkArgument(outputLimit - outputAddress >= SIZE_OF_INT, "Output buffer too small"); + + int inputSize = (int) (inputLimit - inputAddress); + + long hash = XxHash64.hash(0, inputBase, inputAddress, inputSize); + + UNSAFE.putInt(outputBase, outputAddress, (int) hash); + + return SIZE_OF_INT; + } + + public static int compress(Object inputBase, long inputAddress, long inputLimit, Object outputBase, long outputAddress, long outputLimit, int compressionLevel) + { + int inputSize = (int) (inputLimit - inputAddress); + + CompressionParameters parameters = CompressionParameters.compute(compressionLevel, inputSize); + + long output = outputAddress; + + output += writeMagic(outputBase, output, outputLimit); + output += writeFrameHeader(outputBase, output, outputLimit, inputSize, parameters.getWindowSize()); + output += compressFrame(inputBase, inputAddress, inputLimit, outputBase, output, outputLimit, parameters); + output += writeChecksum(outputBase, output, outputLimit, inputBase, inputAddress, inputLimit); + + return (int) (output - outputAddress); + } + + private static int compressFrame(Object inputBase, long inputAddress, long inputLimit, Object outputBase, long outputAddress, long outputLimit, CompressionParameters parameters) + { + int blockSize = parameters.getBlockSize(); + + int outputSize = (int) (outputLimit - outputAddress); + int remaining = (int) (inputLimit - inputAddress); + + long 
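+        /* Framing sketch with illustrative numbers: at the common 128 KiB block size,
+           a 300 KiB input becomes two full blocks plus a final 44 KiB block, and only
+           that last block's header carries the last-block bit. */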
output = outputAddress; + long input = inputAddress; + + CompressionContext context = new CompressionContext(parameters, inputAddress, remaining); + do { + checkArgument(outputSize >= SIZE_OF_BLOCK_HEADER + MIN_BLOCK_SIZE, "Output buffer too small"); + + boolean lastBlock = blockSize >= remaining; + blockSize = Math.min(blockSize, remaining); + + int compressedSize = writeCompressedBlock(inputBase, input, blockSize, outputBase, output, outputSize, context, lastBlock); + + input += blockSize; + remaining -= blockSize; + output += compressedSize; + outputSize -= compressedSize; + } + while (remaining > 0); + + return (int) (output - outputAddress); + } + + static int writeCompressedBlock(Object inputBase, long input, int blockSize, Object outputBase, long output, int outputSize, CompressionContext context, boolean lastBlock) + { + checkArgument(lastBlock || blockSize == context.parameters.getBlockSize(), "Only last block can be smaller than block size"); + + int compressedSize = 0; + if (blockSize > 0) { + compressedSize = compressBlock(inputBase, input, blockSize, outputBase, output + SIZE_OF_BLOCK_HEADER, outputSize - SIZE_OF_BLOCK_HEADER, context); + } + + if (compressedSize == 0) { // block is not compressible + checkArgument(blockSize + SIZE_OF_BLOCK_HEADER <= outputSize, "Output size too small"); + + int blockHeader = (lastBlock ? 1 : 0) | (RAW_BLOCK << 1) | (blockSize << 3); + put24BitLittleEndian(outputBase, output, blockHeader); + UNSAFE.copyMemory(inputBase, input, outputBase, output + SIZE_OF_BLOCK_HEADER, blockSize); + compressedSize = SIZE_OF_BLOCK_HEADER + blockSize; + } + else { + int blockHeader = (lastBlock ? 1 : 0) | (COMPRESSED_BLOCK << 1) | (compressedSize << 3); + put24BitLittleEndian(outputBase, output, blockHeader); + compressedSize += SIZE_OF_BLOCK_HEADER; + } + return compressedSize; + } + + private static int compressBlock(Object inputBase, long inputAddress, int inputSize, Object outputBase, long outputAddress, int outputSize, CompressionContext context) + { + if (inputSize < MIN_BLOCK_SIZE + SIZE_OF_BLOCK_HEADER + 1) { + // don't even attempt compression below a certain input size + return 0; + } + + CompressionParameters parameters = context.parameters; + context.blockCompressionState.enforceMaxDistance(inputAddress + inputSize, parameters.getWindowSize()); + context.sequenceStore.reset(); + + int lastLiteralsSize = parameters.getStrategy() + .getCompressor() + .compressBlock(inputBase, inputAddress, inputSize, context.sequenceStore, context.blockCompressionState, context.offsets, parameters); + + long lastLiteralsAddress = inputAddress + inputSize - lastLiteralsSize; + + // append [lastLiteralsAddress .. 
lastLiteralsSize] to sequenceStore literals buffer + context.sequenceStore.appendLiterals(inputBase, lastLiteralsAddress, lastLiteralsSize); + + // convert length/offsets into codes + context.sequenceStore.generateCodes(); + + long outputLimit = outputAddress + outputSize; + long output = outputAddress; + + int compressedLiteralsSize = encodeLiterals( + context.huffmanContext, + parameters, + outputBase, + output, + (int) (outputLimit - output), + context.sequenceStore.literalsBuffer, + context.sequenceStore.literalsLength); + output += compressedLiteralsSize; + + int compressedSequencesSize = SequenceEncoder.compressSequences(outputBase, output, (int) (outputLimit - output), context.sequenceStore, parameters.getStrategy(), context.sequenceEncodingContext); + + int compressedSize = compressedLiteralsSize + compressedSequencesSize; + if (compressedSize == 0) { + // not compressible + return compressedSize; + } + + // Check compressibility + int maxCompressedSize = inputSize - calculateMinimumGain(inputSize, parameters.getStrategy()); + if (compressedSize > maxCompressedSize) { + return 0; // not compressed + } + + // confirm repeated offsets and entropy tables + context.commit(); + + return compressedSize; + } + + private static int encodeLiterals( + HuffmanCompressionContext context, + CompressionParameters parameters, + Object outputBase, + long outputAddress, + int outputSize, + byte[] literals, + int literalsSize) + { + // TODO: move this to Strategy + boolean bypassCompression = (parameters.getStrategy() == CompressionParameters.Strategy.FAST) && (parameters.getTargetLength() > 0); + if (bypassCompression || literalsSize <= MINIMUM_LITERALS_SIZE) { + return rawLiterals(outputBase, outputAddress, outputSize, literals, ARRAY_BYTE_BASE_OFFSET, literalsSize); + } + + int headerSize = 3 + (literalsSize >= 1024 ? 1 : 0) + (literalsSize >= 16384 ? 
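+        /* RFC 8478 allows a 3-, 4-, or 5-byte compressed-literals header holding two
+           10-, 14-, or 18-bit size fields; e.g. literalsSize = 5000 needs 14 bits, so
+           this expression evaluates to 4. */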
1 : 0); + + checkArgument(headerSize + 1 <= outputSize, "Output buffer too small"); + + int[] counts = new int[MAX_SYMBOL_COUNT]; // TODO: preallocate + Histogram.count(literals, literalsSize, counts); + int maxSymbol = Histogram.findMaxSymbol(counts, MAX_SYMBOL); + int largestCount = Histogram.findLargestCount(counts, maxSymbol); + + long literalsAddress = ARRAY_BYTE_BASE_OFFSET; + if (largestCount == literalsSize) { + // all bytes in input are equal + return rleLiterals(outputBase, outputAddress, outputSize, literals, ARRAY_BYTE_BASE_OFFSET, literalsSize); + } + else if (largestCount <= (literalsSize >>> 7) + 4) { + // heuristic: probably not compressible enough + return rawLiterals(outputBase, outputAddress, outputSize, literals, ARRAY_BYTE_BASE_OFFSET, literalsSize); + } + + HuffmanCompressionTable previousTable = context.getPreviousTable(); + HuffmanCompressionTable table; + int serializedTableSize; + boolean reuseTable; + + boolean canReuse = previousTable.isValid(counts, maxSymbol); + + // heuristic: use existing table for small inputs if valid + // TODO: move to Strategy + boolean preferReuse = parameters.getStrategy().ordinal() < CompressionParameters.Strategy.LAZY.ordinal() && literalsSize <= 1024; + if (preferReuse && canReuse) { + table = previousTable; + reuseTable = true; + serializedTableSize = 0; + } + else { + HuffmanCompressionTable newTable = context.borrowTemporaryTable(); + + newTable.initialize( + counts, + maxSymbol, + HuffmanCompressionTable.optimalNumberOfBits(MAX_HUFFMAN_TABLE_LOG, literalsSize, maxSymbol), + context.getCompressionTableWorkspace()); + + serializedTableSize = newTable.write(outputBase, outputAddress + headerSize, outputSize - headerSize, context.getTableWriterWorkspace()); + + // Check if using previous huffman table is beneficial + if (canReuse && previousTable.estimateCompressedSize(counts, maxSymbol) <= serializedTableSize + newTable.estimateCompressedSize(counts, maxSymbol)) { + table = previousTable; + reuseTable = true; + serializedTableSize = 0; + context.discardTemporaryTable(); + } + else { + table = newTable; + reuseTable = false; + } + } + + int compressedSize; + boolean singleStream = literalsSize < 256; + if (singleStream) { + compressedSize = HuffmanCompressor.compressSingleStream(outputBase, outputAddress + headerSize + serializedTableSize, outputSize - headerSize - serializedTableSize, literals, literalsAddress, literalsSize, table); + } + else { + compressedSize = HuffmanCompressor.compress4streams(outputBase, outputAddress + headerSize + serializedTableSize, outputSize - headerSize - serializedTableSize, literals, literalsAddress, literalsSize, table); + } + + int totalSize = serializedTableSize + compressedSize; + int minimumGain = calculateMinimumGain(literalsSize, parameters.getStrategy()); + + if (compressedSize == 0 || totalSize >= literalsSize - minimumGain) { + // incompressible or no savings + + // discard any temporary table we might have borrowed above + context.discardTemporaryTable(); + + return rawLiterals(outputBase, outputAddress, outputSize, literals, ARRAY_BYTE_BASE_OFFSET, literalsSize); + } + + int encodingType = reuseTable ? TREELESS_LITERALS_BLOCK : COMPRESSED_LITERALS_BLOCK; + + // Build header + switch (headerSize) { + case 3: { // 2 - 2 - 10 - 10 + int header = encodingType | ((singleStream ? 
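+                /* Layout, low bits first: 2-bit literals block type, 2-bit size format
+                   (0 = one Huffman stream, 1 = four), a 10-bit regenerated size, and a
+                   10-bit compressed size; e.g. literalsSize = 100 with totalSize = 60
+                   packs as type | format | (100 << 4) | (60 << 14). */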
0 : 1) << 2) | (literalsSize << 4) | (totalSize << 14); + put24BitLittleEndian(outputBase, outputAddress, header); + break; + } + case 4: { // 2 - 2 - 14 - 14 + int header = encodingType | (2 << 2) | (literalsSize << 4) | (totalSize << 18); + UNSAFE.putInt(outputBase, outputAddress, header); + break; + } + case 5: { // 2 - 2 - 18 - 18 + int header = encodingType | (3 << 2) | (literalsSize << 4) | (totalSize << 22); + UNSAFE.putInt(outputBase, outputAddress, header); + UNSAFE.putByte(outputBase, outputAddress + SIZE_OF_INT, (byte) (totalSize >>> 10)); + break; + } + default: // not possible : headerSize is {3,4,5} + throw new IllegalStateException(); + } + + return headerSize + totalSize; + } + + private static int rleLiterals(Object outputBase, long outputAddress, int outputSize, Object inputBase, long inputAddress, int inputSize) + { + int headerSize = 1 + (inputSize > 31 ? 1 : 0) + (inputSize > 4095 ? 1 : 0); + + switch (headerSize) { + case 1: // 2 - 1 - 5 + UNSAFE.putByte(outputBase, outputAddress, (byte) (RLE_LITERALS_BLOCK | (inputSize << 3))); + break; + case 2: // 2 - 2 - 12 + UNSAFE.putShort(outputBase, outputAddress, (short) (RLE_LITERALS_BLOCK | (1 << 2) | (inputSize << 4))); + break; + case 3: // 2 - 2 - 20 + UNSAFE.putInt(outputBase, outputAddress, RLE_LITERALS_BLOCK | 3 << 2 | inputSize << 4); + break; + default: // impossible. headerSize is {1,2,3} + throw new IllegalStateException(); + } + + UNSAFE.putByte(outputBase, outputAddress + headerSize, UNSAFE.getByte(inputBase, inputAddress)); + + return headerSize + 1; + } + + private static int calculateMinimumGain(int inputSize, CompressionParameters.Strategy strategy) + { + // TODO: move this to Strategy to avoid hardcoding a specific strategy here + int minLog = strategy == CompressionParameters.Strategy.BTULTRA ? 7 : 6; + return (inputSize >>> minLog) + 2; + } + + private static int rawLiterals(Object outputBase, long outputAddress, int outputSize, Object inputBase, long inputAddress, int inputSize) + { + int headerSize = 1; + if (inputSize >= 32) { + headerSize++; + } + if (inputSize >= 4096) { + headerSize++; + } + + checkArgument(inputSize + headerSize <= outputSize, "Output buffer too small"); + + switch (headerSize) { + case 1: + UNSAFE.putByte(outputBase, outputAddress, (byte) (RAW_LITERALS_BLOCK | (inputSize << 3))); + break; + case 2: + UNSAFE.putShort(outputBase, outputAddress, (short) (RAW_LITERALS_BLOCK | (1 << 2) | (inputSize << 4))); + break; + case 3: + put24BitLittleEndian(outputBase, outputAddress, RAW_LITERALS_BLOCK | (3 << 2) | (inputSize << 4)); + break; + default: + throw new AssertionError(); + } + + // TODO: ensure this test is correct + checkArgument(inputSize + 1 <= outputSize, "Output buffer too small"); + + UNSAFE.copyMemory(inputBase, inputAddress, outputBase, outputAddress + headerSize, inputSize); + + return headerSize + inputSize; + } +} diff --git a/src/main/java/io/airlift/compress/zstd/ZstdFrameDecompressor.java b/src/main/java/io/airlift/compress/zstd/ZstdFrameDecompressor.java new file mode 100644 index 0000000..ce76709 --- /dev/null +++ b/src/main/java/io/airlift/compress/zstd/ZstdFrameDecompressor.java @@ -0,0 +1,980 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.airlift.compress.zstd; + +import io.airlift.compress.MalformedInputException; + +import java.util.Arrays; + +import static io.airlift.compress.zstd.BitInputStream.peekBits; +import static io.airlift.compress.zstd.Constants.COMPRESSED_BLOCK; +import static io.airlift.compress.zstd.Constants.COMPRESSED_LITERALS_BLOCK; +import static io.airlift.compress.zstd.Constants.DEFAULT_MAX_OFFSET_CODE_SYMBOL; +import static io.airlift.compress.zstd.Constants.LITERALS_LENGTH_BITS; +import static io.airlift.compress.zstd.Constants.LITERAL_LENGTH_TABLE_LOG; +import static io.airlift.compress.zstd.Constants.LONG_NUMBER_OF_SEQUENCES; +import static io.airlift.compress.zstd.Constants.MAGIC_NUMBER; +import static io.airlift.compress.zstd.Constants.MATCH_LENGTH_BITS; +import static io.airlift.compress.zstd.Constants.MATCH_LENGTH_TABLE_LOG; +import static io.airlift.compress.zstd.Constants.MAX_BLOCK_SIZE; +import static io.airlift.compress.zstd.Constants.MAX_LITERALS_LENGTH_SYMBOL; +import static io.airlift.compress.zstd.Constants.MAX_MATCH_LENGTH_SYMBOL; +import static io.airlift.compress.zstd.Constants.MIN_BLOCK_SIZE; +import static io.airlift.compress.zstd.Constants.MIN_SEQUENCES_SIZE; +import static io.airlift.compress.zstd.Constants.MIN_WINDOW_LOG; +import static io.airlift.compress.zstd.Constants.OFFSET_TABLE_LOG; +import static io.airlift.compress.zstd.Constants.RAW_BLOCK; +import static io.airlift.compress.zstd.Constants.RAW_LITERALS_BLOCK; +import static io.airlift.compress.zstd.Constants.RLE_BLOCK; +import static io.airlift.compress.zstd.Constants.RLE_LITERALS_BLOCK; +import static io.airlift.compress.zstd.Constants.SEQUENCE_ENCODING_BASIC; +import static io.airlift.compress.zstd.Constants.SEQUENCE_ENCODING_COMPRESSED; +import static io.airlift.compress.zstd.Constants.SEQUENCE_ENCODING_REPEAT; +import static io.airlift.compress.zstd.Constants.SEQUENCE_ENCODING_RLE; +import static io.airlift.compress.zstd.Constants.SIZE_OF_BLOCK_HEADER; +import static io.airlift.compress.zstd.Constants.SIZE_OF_BYTE; +import static io.airlift.compress.zstd.Constants.SIZE_OF_INT; +import static io.airlift.compress.zstd.Constants.SIZE_OF_LONG; +import static io.airlift.compress.zstd.Constants.SIZE_OF_SHORT; +import static io.airlift.compress.zstd.Constants.TREELESS_LITERALS_BLOCK; +import static io.airlift.compress.zstd.UnsafeUtil.UNSAFE; +import static io.airlift.compress.zstd.Util.fail; +import static io.airlift.compress.zstd.Util.get24BitLittleEndian; +import static io.airlift.compress.zstd.Util.mask; +import static io.airlift.compress.zstd.Util.verify; +import static java.lang.String.format; +import static sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET; + +class ZstdFrameDecompressor +{ + private static final int[] DEC_32_TABLE = {4, 1, 2, 1, 4, 4, 4, 4}; + private static final int[] DEC_64_TABLE = {0, 0, 0, -1, 0, 1, 2, 3}; + + private static final int V07_MAGIC_NUMBER = 0xFD2FB527; + + static final int MAX_WINDOW_SIZE = 1 << 23; + + private static final int[] LITERALS_LENGTH_BASE = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 18, 20, 22, 24, 28, 32, 40, 48, 64, 0x80, 
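+            /* Baselines for literals-length codes: codes 0-15 stand for themselves,
+               while higher codes add extra bits read from the bitstream; e.g. code 20
+               (baseline 24, 2 extra bits) covers lengths 24 through 27. */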
0x100, 0x200, 0x400, 0x800, 0x1000, + 0x2000, 0x4000, 0x8000, 0x10000}; + + private static final int[] MATCH_LENGTH_BASE = { + 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, + 35, 37, 39, 41, 43, 47, 51, 59, 67, 83, 99, 0x83, 0x103, 0x203, 0x403, 0x803, + 0x1003, 0x2003, 0x4003, 0x8003, 0x10003}; + + private static final int[] OFFSET_CODES_BASE = { + 0, 1, 1, 5, 0xD, 0x1D, 0x3D, 0x7D, + 0xFD, 0x1FD, 0x3FD, 0x7FD, 0xFFD, 0x1FFD, 0x3FFD, 0x7FFD, + 0xFFFD, 0x1FFFD, 0x3FFFD, 0x7FFFD, 0xFFFFD, 0x1FFFFD, 0x3FFFFD, 0x7FFFFD, + 0xFFFFFD, 0x1FFFFFD, 0x3FFFFFD, 0x7FFFFFD, 0xFFFFFFD}; + + private static final FiniteStateEntropy.Table DEFAULT_LITERALS_LENGTH_TABLE = new FiniteStateEntropy.Table( + 6, + new int[] { + 0, 16, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32, 0, 0, 0, 0, 32, 0, 0, 32, 0, 32, 0, 32, 0, 0, 32, 0, 32, 0, 32, 0, 0, 16, 32, 0, 0, 48, 16, 32, 32, 32, + 32, 32, 32, 32, 32, 0, 32, 32, 32, 32, 32, 32, 0, 0, 0, 0}, + new byte[] { + 0, 0, 1, 3, 4, 6, 7, 9, 10, 12, 14, 16, 18, 19, 21, 22, 24, 25, 26, 27, 29, 31, 0, 1, 2, 4, 5, 7, 8, 10, 11, 13, 16, 17, 19, 20, 22, 23, 25, 25, 26, 28, 30, 0, + 1, 2, 3, 5, 6, 8, 9, 11, 12, 15, 17, 18, 20, 21, 23, 24, 35, 34, 33, 32}, + new byte[] { + 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 4, 4, 5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 4, 4, 5, 6, 6, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, + 6, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6}); + + private static final FiniteStateEntropy.Table DEFAULT_OFFSET_CODES_TABLE = new FiniteStateEntropy.Table( + 5, + new int[] {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 0, 0, 0, 0, 16, 0, 0, 0, 16, 0, 0, 0, 0, 0, 0, 0}, + new byte[] {0, 6, 9, 15, 21, 3, 7, 12, 18, 23, 5, 8, 14, 20, 2, 7, 11, 17, 22, 4, 8, 13, 19, 1, 6, 10, 16, 28, 27, 26, 25, 24}, + new byte[] {5, 4, 5, 5, 5, 5, 4, 5, 5, 5, 5, 4, 5, 5, 5, 4, 5, 5, 5, 5, 4, 5, 5, 5, 4, 5, 5, 5, 5, 5, 5, 5}); + + private static final FiniteStateEntropy.Table DEFAULT_MATCH_LENGTH_TABLE = new FiniteStateEntropy.Table( + 6, + new int[] { + 0, 0, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 0, 32, 0, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32, 48, 16, 32, 32, 32, 32, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + new byte[] { + 0, 1, 2, 3, 5, 6, 8, 10, 13, 16, 19, 22, 25, 28, 31, 33, 35, 37, 39, 41, 43, 45, 1, 2, 3, 4, 6, 7, 9, 12, 15, 18, 21, 24, 27, 30, 32, 34, 36, 38, 40, 42, 44, 1, + 1, 2, 4, 5, 7, 8, 11, 14, 17, 20, 23, 26, 29, 52, 51, 50, 49, 48, 47, 46}, + new byte[] { + 6, 4, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6}); + + private final byte[] literals = new byte[MAX_BLOCK_SIZE + SIZE_OF_LONG]; // extra space to allow for long-at-a-time copy + + // current buffer containing literals + private Object literalsBase; + private long literalsAddress; + private long literalsLimit; + + private final int[] previousOffsets = new int[3]; + + private final FiniteStateEntropy.Table literalsLengthTable = new FiniteStateEntropy.Table(LITERAL_LENGTH_TABLE_LOG); + private final FiniteStateEntropy.Table offsetCodesTable = new FiniteStateEntropy.Table(OFFSET_TABLE_LOG); + private final FiniteStateEntropy.Table matchLengthTable = new FiniteStateEntropy.Table(MATCH_LENGTH_TABLE_LOG); + + private FiniteStateEntropy.Table currentLiteralsLengthTable; + private FiniteStateEntropy.Table currentOffsetCodesTable; + private 
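+    /* The current* references persist across blocks within a frame so that a block
+       using the "repeat" sequence encoding can reuse the previous block's FSE
+       tables; reset() clears them at each frame boundary. */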
FiniteStateEntropy.Table currentMatchLengthTable; + + private final Huffman huffman = new Huffman(); + private final FseTableReader fse = new FseTableReader(); + + public int decompress( + final Object inputBase, + final long inputAddress, + final long inputLimit, + final Object outputBase, + final long outputAddress, + final long outputLimit) + { + if (outputAddress == outputLimit) { + return 0; + } + + long input = inputAddress; + long output = outputAddress; + + while (input < inputLimit) { + reset(); + long outputStart = output; + input += verifyMagic(inputBase, input, inputLimit); + + FrameHeader frameHeader = readFrameHeader(inputBase, input, inputLimit); + input += frameHeader.headerSize; + + boolean lastBlock; + do { + verify(input + SIZE_OF_BLOCK_HEADER <= inputLimit, input, "Not enough input bytes"); + + // read block header + int header = get24BitLittleEndian(inputBase, input); + input += SIZE_OF_BLOCK_HEADER; + + lastBlock = (header & 1) != 0; + int blockType = (header >>> 1) & 0b11; + int blockSize = (header >>> 3) & 0x1F_FFFF; // 21 bits + + int decodedSize; + switch (blockType) { + case RAW_BLOCK: + verify(inputAddress + blockSize <= inputLimit, input, "Not enough input bytes"); + decodedSize = decodeRawBlock(inputBase, input, blockSize, outputBase, output, outputLimit); + input += blockSize; + break; + case RLE_BLOCK: + verify(inputAddress + 1 <= inputLimit, input, "Not enough input bytes"); + decodedSize = decodeRleBlock(blockSize, inputBase, input, outputBase, output, outputLimit); + input += 1; + break; + case COMPRESSED_BLOCK: + verify(inputAddress + blockSize <= inputLimit, input, "Not enough input bytes"); + decodedSize = decodeCompressedBlock(inputBase, input, blockSize, outputBase, output, outputLimit, frameHeader.windowSize, outputAddress); + input += blockSize; + break; + default: + throw fail(input, "Invalid block type"); + } + + output += decodedSize; + } + while (!lastBlock); + + if (frameHeader.hasChecksum) { + int decodedFrameSize = (int) (output - outputStart); + + long hash = XxHash64.hash(0, outputBase, outputStart, decodedFrameSize); + + int checksum = UNSAFE.getInt(inputBase, input); + if (checksum != (int) hash) { + throw new MalformedInputException(input, format("Bad checksum. 
Expected: %s, actual: %s", Integer.toHexString(checksum), Integer.toHexString((int) hash))); + } + + input += SIZE_OF_INT; + } + } + + return (int) (output - outputAddress); + } + + void reset() + { + previousOffsets[0] = 1; + previousOffsets[1] = 4; + previousOffsets[2] = 8; + + currentLiteralsLengthTable = null; + currentOffsetCodesTable = null; + currentMatchLengthTable = null; + } + + static int decodeRawBlock(Object inputBase, long inputAddress, int blockSize, Object outputBase, long outputAddress, long outputLimit) + { + verify(outputAddress + blockSize <= outputLimit, inputAddress, "Output buffer too small"); + + UNSAFE.copyMemory(inputBase, inputAddress, outputBase, outputAddress, blockSize); + return blockSize; + } + + static int decodeRleBlock(int size, Object inputBase, long inputAddress, Object outputBase, long outputAddress, long outputLimit) + { + verify(outputAddress + size <= outputLimit, inputAddress, "Output buffer too small"); + + long output = outputAddress; + long value = UNSAFE.getByte(inputBase, inputAddress) & 0xFFL; + + int remaining = size; + if (remaining >= SIZE_OF_LONG) { + long packed = value + | (value << 8) + | (value << 16) + | (value << 24) + | (value << 32) + | (value << 40) + | (value << 48) + | (value << 56); + + do { + UNSAFE.putLong(outputBase, output, packed); + output += SIZE_OF_LONG; + remaining -= SIZE_OF_LONG; + } + while (remaining >= SIZE_OF_LONG); + } + + for (int i = 0; i < remaining; i++) { + UNSAFE.putByte(outputBase, output, (byte) value); + output++; + } + + return size; + } + + int decodeCompressedBlock( + Object inputBase, + final long inputAddress, + int blockSize, + Object outputBase, + long outputAddress, + long outputLimit, + int windowSize, + long outputAbsoluteBaseAddress) + { + long inputLimit = inputAddress + blockSize; + long input = inputAddress; + + verify(blockSize <= MAX_BLOCK_SIZE, input, "Expected match length table to be present"); + verify(blockSize >= MIN_BLOCK_SIZE, input, "Compressed block size too small"); + + // decode literals + int literalsBlockType = UNSAFE.getByte(inputBase, input) & 0b11; + + switch (literalsBlockType) { + case RAW_LITERALS_BLOCK: { + input += decodeRawLiterals(inputBase, input, inputLimit); + break; + } + case RLE_LITERALS_BLOCK: { + input += decodeRleLiterals(inputBase, input, blockSize); + break; + } + case TREELESS_LITERALS_BLOCK: + verify(huffman.isLoaded(), input, "Dictionary is corrupted"); + case COMPRESSED_LITERALS_BLOCK: { + input += decodeCompressedLiterals(inputBase, input, blockSize, literalsBlockType); + break; + } + default: + throw fail(input, "Invalid literals block encoding type"); + } + + verify(windowSize <= MAX_WINDOW_SIZE, input, "Window size too large (not yet supported)"); + + return decompressSequences( + inputBase, input, inputAddress + blockSize, + outputBase, outputAddress, outputLimit, + literalsBase, literalsAddress, literalsLimit, + outputAbsoluteBaseAddress); + } + + private int decompressSequences( + final Object inputBase, final long inputAddress, final long inputLimit, + final Object outputBase, final long outputAddress, final long outputLimit, + final Object literalsBase, final long literalsAddress, final long literalsLimit, + long outputAbsoluteBaseAddress) + { + final long fastOutputLimit = outputLimit - SIZE_OF_LONG; + final long fastMatchOutputLimit = fastOutputLimit - SIZE_OF_LONG; + + long input = inputAddress; + long output = outputAddress; + + long literalsInput = literalsAddress; + + int size = (int) (inputLimit - inputAddress); + verify(size >= 
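+        /* Worked example of the sequence-count encoding parsed below: a first byte of
+           0x23 means 35 sequences; 0x85 means ((0x85 - 128) << 8) plus the next byte;
+           0xFF means a further little-endian short plus LONG_NUMBER_OF_SEQUENCES
+           (0x7F00 per RFC 8478). */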
MIN_SEQUENCES_SIZE, input, "Not enough input bytes"); + + // decode header + int sequenceCount = UNSAFE.getByte(inputBase, input++) & 0xFF; + if (sequenceCount != 0) { + if (sequenceCount == 255) { + verify(input + SIZE_OF_SHORT <= inputLimit, input, "Not enough input bytes"); + sequenceCount = (UNSAFE.getShort(inputBase, input) & 0xFFFF) + LONG_NUMBER_OF_SEQUENCES; + input += SIZE_OF_SHORT; + } + else if (sequenceCount > 127) { + verify(input < inputLimit, input, "Not enough input bytes"); + sequenceCount = ((sequenceCount - 128) << 8) + (UNSAFE.getByte(inputBase, input++) & 0xFF); + } + + verify(input + SIZE_OF_INT <= inputLimit, input, "Not enough input bytes"); + + byte type = UNSAFE.getByte(inputBase, input++); + + int literalsLengthType = (type & 0xFF) >>> 6; + int offsetCodesType = (type >>> 4) & 0b11; + int matchLengthType = (type >>> 2) & 0b11; + + input = computeLiteralsTable(literalsLengthType, inputBase, input, inputLimit); + input = computeOffsetsTable(offsetCodesType, inputBase, input, inputLimit); + input = computeMatchLengthTable(matchLengthType, inputBase, input, inputLimit); + + // decompress sequences + BitInputStream.Initializer initializer = new BitInputStream.Initializer(inputBase, input, inputLimit); + initializer.initialize(); + int bitsConsumed = initializer.getBitsConsumed(); + long bits = initializer.getBits(); + long currentAddress = initializer.getCurrentAddress(); + + FiniteStateEntropy.Table currentLiteralsLengthTable = this.currentLiteralsLengthTable; + FiniteStateEntropy.Table currentOffsetCodesTable = this.currentOffsetCodesTable; + FiniteStateEntropy.Table currentMatchLengthTable = this.currentMatchLengthTable; + + int literalsLengthState = (int) peekBits(bitsConsumed, bits, currentLiteralsLengthTable.log2Size); + bitsConsumed += currentLiteralsLengthTable.log2Size; + + int offsetCodesState = (int) peekBits(bitsConsumed, bits, currentOffsetCodesTable.log2Size); + bitsConsumed += currentOffsetCodesTable.log2Size; + + int matchLengthState = (int) peekBits(bitsConsumed, bits, currentMatchLengthTable.log2Size); + bitsConsumed += currentMatchLengthTable.log2Size; + + int[] previousOffsets = this.previousOffsets; + + byte[] literalsLengthNumbersOfBits = currentLiteralsLengthTable.numberOfBits; + int[] literalsLengthNewStates = currentLiteralsLengthTable.newState; + byte[] literalsLengthSymbols = currentLiteralsLengthTable.symbol; + + byte[] matchLengthNumbersOfBits = currentMatchLengthTable.numberOfBits; + int[] matchLengthNewStates = currentMatchLengthTable.newState; + byte[] matchLengthSymbols = currentMatchLengthTable.symbol; + + byte[] offsetCodesNumbersOfBits = currentOffsetCodesTable.numberOfBits; + int[] offsetCodesNewStates = currentOffsetCodesTable.newState; + byte[] offsetCodesSymbols = currentOffsetCodesTable.symbol; + + while (sequenceCount > 0) { + sequenceCount--; + + BitInputStream.Loader loader = new BitInputStream.Loader(inputBase, input, currentAddress, bits, bitsConsumed); + loader.load(); + bitsConsumed = loader.getBitsConsumed(); + bits = loader.getBits(); + currentAddress = loader.getCurrentAddress(); + if (loader.isOverflow()) { + verify(sequenceCount == 0, input, "Not all sequences were consumed"); + break; + } + + // decode sequence + int literalsLengthCode = literalsLengthSymbols[literalsLengthState]; + int matchLengthCode = matchLengthSymbols[matchLengthState]; + int offsetCode = offsetCodesSymbols[offsetCodesState]; + + int literalsLengthBits = LITERALS_LENGTH_BITS[literalsLengthCode]; + int matchLengthBits = 
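+                /* In the offset decoding below, offsetCode doubles as the number of
+                   extra bits to read; decoded codes 0 and 1 route through the
+                   repeat-offset history in previousOffsets, which reset() seeds as
+                   {1, 4, 8}. */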
MATCH_LENGTH_BITS[matchLengthCode]; + int offsetBits = offsetCode; + + int offset = OFFSET_CODES_BASE[offsetCode]; + if (offsetCode > 0) { + offset += peekBits(bitsConsumed, bits, offsetBits); + bitsConsumed += offsetBits; + } + + if (offsetCode <= 1) { + if (literalsLengthCode == 0) { + offset++; + } + + if (offset != 0) { + int temp; + if (offset == 3) { + temp = previousOffsets[0] - 1; + } + else { + temp = previousOffsets[offset]; + } + + if (temp == 0) { + temp = 1; + } + + if (offset != 1) { + previousOffsets[2] = previousOffsets[1]; + } + previousOffsets[1] = previousOffsets[0]; + previousOffsets[0] = temp; + + offset = temp; + } + else { + offset = previousOffsets[0]; + } + } + else { + previousOffsets[2] = previousOffsets[1]; + previousOffsets[1] = previousOffsets[0]; + previousOffsets[0] = offset; + } + + int matchLength = MATCH_LENGTH_BASE[matchLengthCode]; + if (matchLengthCode > 31) { + matchLength += peekBits(bitsConsumed, bits, matchLengthBits); + bitsConsumed += matchLengthBits; + } + + int literalsLength = LITERALS_LENGTH_BASE[literalsLengthCode]; + if (literalsLengthCode > 15) { + literalsLength += peekBits(bitsConsumed, bits, literalsLengthBits); + bitsConsumed += literalsLengthBits; + } + + int totalBits = literalsLengthBits + matchLengthBits + offsetBits; + if (totalBits > 64 - 7 - (LITERAL_LENGTH_TABLE_LOG + MATCH_LENGTH_TABLE_LOG + OFFSET_TABLE_LOG)) { + BitInputStream.Loader loader1 = new BitInputStream.Loader(inputBase, input, currentAddress, bits, bitsConsumed); + loader1.load(); + + bitsConsumed = loader1.getBitsConsumed(); + bits = loader1.getBits(); + currentAddress = loader1.getCurrentAddress(); + } + + int numberOfBits; + + numberOfBits = literalsLengthNumbersOfBits[literalsLengthState]; + literalsLengthState = (int) (literalsLengthNewStates[literalsLengthState] + peekBits(bitsConsumed, bits, numberOfBits)); // <= 9 bits + bitsConsumed += numberOfBits; + + numberOfBits = matchLengthNumbersOfBits[matchLengthState]; + matchLengthState = (int) (matchLengthNewStates[matchLengthState] + peekBits(bitsConsumed, bits, numberOfBits)); // <= 9 bits + bitsConsumed += numberOfBits; + + numberOfBits = offsetCodesNumbersOfBits[offsetCodesState]; + offsetCodesState = (int) (offsetCodesNewStates[offsetCodesState] + peekBits(bitsConsumed, bits, numberOfBits)); // <= 8 bits + bitsConsumed += numberOfBits; + + final long literalOutputLimit = output + literalsLength; + final long matchOutputLimit = literalOutputLimit + matchLength; + + verify(matchOutputLimit <= outputLimit, input, "Output buffer too small"); + long literalEnd = literalsInput + literalsLength; + verify(literalEnd <= literalsLimit, input, "Input is corrupted"); + + long matchAddress = literalOutputLimit - offset; + verify(matchAddress >= outputAbsoluteBaseAddress, input, "Input is corrupted"); + + if (literalOutputLimit > fastOutputLimit) { + executeLastSequence(outputBase, output, literalOutputLimit, matchOutputLimit, fastOutputLimit, literalsInput, matchAddress); + } + else { + // copy literals. 
literalOutputLimit <= fastOutputLimit, so we can copy + // long at a time with over-copy + output = copyLiterals(outputBase, literalsBase, output, literalsInput, literalOutputLimit); + copyMatch(outputBase, fastOutputLimit, output, offset, matchOutputLimit, matchAddress, matchLength, fastMatchOutputLimit); + } + output = matchOutputLimit; + literalsInput = literalEnd; + } + } + + // last literal segment + output = copyLastLiteral(outputBase, literalsBase, literalsLimit, output, literalsInput); + + return (int) (output - outputAddress); + } + + private static long copyLastLiteral(Object outputBase, Object literalsBase, long literalsLimit, long output, long literalsInput) + { + long lastLiteralsSize = literalsLimit - literalsInput; + UNSAFE.copyMemory(literalsBase, literalsInput, outputBase, output, lastLiteralsSize); + output += lastLiteralsSize; + return output; + } + + private static void copyMatch(Object outputBase, + long fastOutputLimit, + long output, + int offset, + long matchOutputLimit, + long matchAddress, + int matchLength, + long fastMatchOutputLimit) + { + matchAddress = copyMatchHead(outputBase, output, offset, matchAddress); + output += SIZE_OF_LONG; + matchLength -= SIZE_OF_LONG; // first 8 bytes copied above + + copyMatchTail(outputBase, fastOutputLimit, output, matchOutputLimit, matchAddress, matchLength, fastMatchOutputLimit); + } + + private static void copyMatchTail(Object outputBase, long fastOutputLimit, long output, long matchOutputLimit, long matchAddress, int matchLength, long fastMatchOutputLimit) + { + // fastMatchOutputLimit is just fastOutputLimit - SIZE_OF_LONG. It needs to be passed in so that it can be computed once for the + // whole invocation to decompressSequences. Otherwise, we'd just compute it here. + // If matchOutputLimit is < fastMatchOutputLimit, we know that even after the head (8 bytes) has been copied, the output pointer + // will be within fastOutputLimit, so it's safe to copy blindly before checking the limit condition + if (matchOutputLimit < fastMatchOutputLimit) { + int copied = 0; + do { + UNSAFE.putLong(outputBase, output, UNSAFE.getLong(outputBase, matchAddress)); + output += SIZE_OF_LONG; + matchAddress += SIZE_OF_LONG; + copied += SIZE_OF_LONG; + } + while (copied < matchLength); + } + else { + while (output < fastOutputLimit) { + UNSAFE.putLong(outputBase, output, UNSAFE.getLong(outputBase, matchAddress)); + matchAddress += SIZE_OF_LONG; + output += SIZE_OF_LONG; + } + + while (output < matchOutputLimit) { + UNSAFE.putByte(outputBase, output++, UNSAFE.getByte(outputBase, matchAddress++)); + } + } + } + + private static long copyMatchHead(Object outputBase, long output, int offset, long matchAddress) + { + // copy match + if (offset < 8) { + // 8 bytes apart so that we can copy long-at-a-time below + int increment32 = DEC_32_TABLE[offset]; + int decrement64 = DEC_64_TABLE[offset]; + + UNSAFE.putByte(outputBase, output, UNSAFE.getByte(outputBase, matchAddress)); + UNSAFE.putByte(outputBase, output + 1, UNSAFE.getByte(outputBase, matchAddress + 1)); + UNSAFE.putByte(outputBase, output + 2, UNSAFE.getByte(outputBase, matchAddress + 2)); + UNSAFE.putByte(outputBase, output + 3, UNSAFE.getByte(outputBase, matchAddress + 3)); + matchAddress += increment32; + + UNSAFE.putInt(outputBase, output + 4, UNSAFE.getInt(outputBase, matchAddress)); + matchAddress -= decrement64; + } + else { + UNSAFE.putLong(outputBase, output, UNSAFE.getLong(outputBase, matchAddress)); + matchAddress += SIZE_OF_LONG; + } + return matchAddress; + } + + private static 
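+    /* copyLiterals deliberately over-copies in 8-byte strides and then snaps the
+       output pointer back to literalOutputLimit; callers take this path only while
+       at least SIZE_OF_LONG bytes of slack remain before the output limit, so the
+       spill always lands in writable space. */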
long copyLiterals(Object outputBase, Object literalsBase, long output, long literalsInput, long literalOutputLimit) + { + long literalInput = literalsInput; + do { + UNSAFE.putLong(outputBase, output, UNSAFE.getLong(literalsBase, literalInput)); + output += SIZE_OF_LONG; + literalInput += SIZE_OF_LONG; + } + while (output < literalOutputLimit); + output = literalOutputLimit; // correction in case we over-copied + return output; + } + + private long computeMatchLengthTable(int matchLengthType, Object inputBase, long input, long inputLimit) + { + switch (matchLengthType) { + case SEQUENCE_ENCODING_RLE: + verify(input < inputLimit, input, "Not enough input bytes"); + + byte value = UNSAFE.getByte(inputBase, input++); + verify(value <= MAX_MATCH_LENGTH_SYMBOL, input, "Value exceeds expected maximum value"); + + FseTableReader.initializeRleTable(matchLengthTable, value); + currentMatchLengthTable = matchLengthTable; + break; + case SEQUENCE_ENCODING_BASIC: + currentMatchLengthTable = DEFAULT_MATCH_LENGTH_TABLE; + break; + case SEQUENCE_ENCODING_REPEAT: + verify(currentMatchLengthTable != null, input, "Expected match length table to be present"); + break; + case SEQUENCE_ENCODING_COMPRESSED: + input += fse.readFseTable(matchLengthTable, inputBase, input, inputLimit, MAX_MATCH_LENGTH_SYMBOL, MATCH_LENGTH_TABLE_LOG); + currentMatchLengthTable = matchLengthTable; + break; + default: + throw fail(input, "Invalid match length encoding type"); + } + return input; + } + + private long computeOffsetsTable(int offsetCodesType, Object inputBase, long input, long inputLimit) + { + switch (offsetCodesType) { + case SEQUENCE_ENCODING_RLE: + verify(input < inputLimit, input, "Not enough input bytes"); + + byte value = UNSAFE.getByte(inputBase, input++); + verify(value <= DEFAULT_MAX_OFFSET_CODE_SYMBOL, input, "Value exceeds expected maximum value"); + + FseTableReader.initializeRleTable(offsetCodesTable, value); + currentOffsetCodesTable = offsetCodesTable; + break; + case SEQUENCE_ENCODING_BASIC: + currentOffsetCodesTable = DEFAULT_OFFSET_CODES_TABLE; + break; + case SEQUENCE_ENCODING_REPEAT: + verify(currentOffsetCodesTable != null, input, "Expected match length table to be present"); + break; + case SEQUENCE_ENCODING_COMPRESSED: + input += fse.readFseTable(offsetCodesTable, inputBase, input, inputLimit, DEFAULT_MAX_OFFSET_CODE_SYMBOL, OFFSET_TABLE_LOG); + currentOffsetCodesTable = offsetCodesTable; + break; + default: + throw fail(input, "Invalid offset code encoding type"); + } + return input; + } + + private long computeLiteralsTable(int literalsLengthType, Object inputBase, long input, long inputLimit) + { + switch (literalsLengthType) { + case SEQUENCE_ENCODING_RLE: + verify(input < inputLimit, input, "Not enough input bytes"); + + byte value = UNSAFE.getByte(inputBase, input++); + verify(value <= MAX_LITERALS_LENGTH_SYMBOL, input, "Value exceeds expected maximum value"); + + FseTableReader.initializeRleTable(literalsLengthTable, value); + currentLiteralsLengthTable = literalsLengthTable; + break; + case SEQUENCE_ENCODING_BASIC: + currentLiteralsLengthTable = DEFAULT_LITERALS_LENGTH_TABLE; + break; + case SEQUENCE_ENCODING_REPEAT: + verify(currentLiteralsLengthTable != null, input, "Expected match length table to be present"); + break; + case SEQUENCE_ENCODING_COMPRESSED: + input += fse.readFseTable(literalsLengthTable, inputBase, input, inputLimit, MAX_LITERALS_LENGTH_SYMBOL, LITERAL_LENGTH_TABLE_LOG); + currentLiteralsLengthTable = literalsLengthTable; + break; + default: + throw fail(input, "Invalid 
literals length encoding type"); + } + return input; + } + + private void executeLastSequence(Object outputBase, long output, long literalOutputLimit, long matchOutputLimit, long fastOutputLimit, long literalInput, long matchAddress) + { + // copy literals + if (output < fastOutputLimit) { + // wild copy + do { + UNSAFE.putLong(outputBase, output, UNSAFE.getLong(literalsBase, literalInput)); + output += SIZE_OF_LONG; + literalInput += SIZE_OF_LONG; + } + while (output < fastOutputLimit); + + literalInput -= output - fastOutputLimit; + output = fastOutputLimit; + } + + while (output < literalOutputLimit) { + UNSAFE.putByte(outputBase, output, UNSAFE.getByte(literalsBase, literalInput)); + output++; + literalInput++; + } + + // copy match + while (output < matchOutputLimit) { + UNSAFE.putByte(outputBase, output, UNSAFE.getByte(outputBase, matchAddress)); + output++; + matchAddress++; + } + } + + private int decodeCompressedLiterals(Object inputBase, final long inputAddress, int blockSize, int literalsBlockType) + { + long input = inputAddress; + verify(blockSize >= 5, input, "Not enough input bytes"); + + // compressed + int compressedSize; + int uncompressedSize; + boolean singleStream = false; + int headerSize; + int type = (UNSAFE.getByte(inputBase, input) >> 2) & 0b11; + switch (type) { + case 0: + singleStream = true; + case 1: { + int header = UNSAFE.getInt(inputBase, input); + + headerSize = 3; + uncompressedSize = (header >>> 4) & mask(10); + compressedSize = (header >>> 14) & mask(10); + break; + } + case 2: { + int header = UNSAFE.getInt(inputBase, input); + + headerSize = 4; + uncompressedSize = (header >>> 4) & mask(14); + compressedSize = (header >>> 18) & mask(14); + break; + } + case 3: { + // read 5 little-endian bytes + long header = UNSAFE.getByte(inputBase, input) & 0xFF | + (UNSAFE.getInt(inputBase, input + 1) & 0xFFFF_FFFFL) << 8; + + headerSize = 5; + uncompressedSize = (int) ((header >>> 4) & mask(18)); + compressedSize = (int) ((header >>> 22) & mask(18)); + break; + } + default: + throw fail(input, "Invalid literals header size type"); + } + + verify(uncompressedSize <= MAX_BLOCK_SIZE, input, "Block exceeds maximum size"); + verify(headerSize + compressedSize <= blockSize, input, "Input is corrupted"); + + input += headerSize; + + long inputLimit = input + compressedSize; + if (literalsBlockType != TREELESS_LITERALS_BLOCK) { + input += huffman.readTable(inputBase, input, compressedSize); + } + + literalsBase = literals; + literalsAddress = ARRAY_BYTE_BASE_OFFSET; + literalsLimit = ARRAY_BYTE_BASE_OFFSET + uncompressedSize; + + if (singleStream) { + huffman.decodeSingleStream(inputBase, input, inputLimit, literals, literalsAddress, literalsLimit); + } + else { + huffman.decode4Streams(inputBase, input, inputLimit, literals, literalsAddress, literalsLimit); + } + + return headerSize + compressedSize; + } + + private int decodeRleLiterals(Object inputBase, final long inputAddress, int blockSize) + { + long input = inputAddress; + int outputSize; + + int type = (UNSAFE.getByte(inputBase, input) >> 2) & 0b11; + switch (type) { + case 0: + case 2: + outputSize = (UNSAFE.getByte(inputBase, input) & 0xFF) >>> 3; + input++; + break; + case 1: + outputSize = (UNSAFE.getShort(inputBase, input) & 0xFFFF) >>> 4; + input += 2; + break; + case 3: + // we need at least 4 bytes (3 for the header, 1 for the payload) + verify(blockSize >= SIZE_OF_INT, input, "Not enough input bytes"); + outputSize = (UNSAFE.getInt(inputBase, input) & 0xFF_FFFF) >>> 4; + input += 3; + break; + default: 
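+                /* Unreachable in practice: type is masked to two bits and all four
+                   values are handled above, but the throw keeps outputSize
+                   definitely assigned for the compiler. */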
+ throw fail(input, "Invalid RLE literals header encoding type"); + } + + verify(outputSize <= MAX_BLOCK_SIZE, input, "Output exceeds maximum block size"); + + byte value = UNSAFE.getByte(inputBase, input++); + Arrays.fill(literals, 0, outputSize + SIZE_OF_LONG, value); + + literalsBase = literals; + literalsAddress = ARRAY_BYTE_BASE_OFFSET; + literalsLimit = ARRAY_BYTE_BASE_OFFSET + outputSize; + + return (int) (input - inputAddress); + } + + private int decodeRawLiterals(Object inputBase, final long inputAddress, long inputLimit) + { + long input = inputAddress; + int type = (UNSAFE.getByte(inputBase, input) >> 2) & 0b11; + + int literalSize; + switch (type) { + case 0: + case 2: + literalSize = (UNSAFE.getByte(inputBase, input) & 0xFF) >>> 3; + input++; + break; + case 1: + literalSize = (UNSAFE.getShort(inputBase, input) & 0xFFFF) >>> 4; + input += 2; + break; + case 3: + // read 3 little-endian bytes + int header = ((UNSAFE.getByte(inputBase, input) & 0xFF) | + ((UNSAFE.getShort(inputBase, input + 1) & 0xFFFF) << 8)); + + literalSize = header >>> 4; + input += 3; + break; + default: + throw fail(input, "Invalid raw literals header encoding type"); + } + + verify(input + literalSize <= inputLimit, input, "Not enough input bytes"); + + // Set literals pointer to [input, literalSize], but only if we can copy 8 bytes at a time during sequence decoding + // Otherwise, copy literals into buffer that's big enough to guarantee that + if (literalSize > (inputLimit - input) - SIZE_OF_LONG) { + literalsBase = literals; + literalsAddress = ARRAY_BYTE_BASE_OFFSET; + literalsLimit = ARRAY_BYTE_BASE_OFFSET + literalSize; + + UNSAFE.copyMemory(inputBase, input, literals, literalsAddress, literalSize); + Arrays.fill(literals, literalSize, literalSize + SIZE_OF_LONG, (byte) 0); + } + else { + literalsBase = inputBase; + literalsAddress = input; + literalsLimit = literalsAddress + literalSize; + } + input += literalSize; + + return (int) (input - inputAddress); + } + + static FrameHeader readFrameHeader(final Object inputBase, final long inputAddress, final long inputLimit) + { + long input = inputAddress; + verify(input < inputLimit, input, "Not enough input bytes"); + + int frameHeaderDescriptor = UNSAFE.getByte(inputBase, input++) & 0xFF; + boolean singleSegment = (frameHeaderDescriptor & 0b100000) != 0; + int dictionaryDescriptor = frameHeaderDescriptor & 0b11; + int contentSizeDescriptor = frameHeaderDescriptor >>> 6; + + int headerSize = 1 + + (singleSegment ? 0 : 1) + + (dictionaryDescriptor == 0 ? 0 : (1 << (dictionaryDescriptor - 1))) + + (contentSizeDescriptor == 0 ? (singleSegment ? 
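+                /* Worked example: descriptor 0x44 (content-size code 1, checksum flag
+                   set, no single-segment flag, no dictionary) gives
+                   1 + 1 + 0 + 2 = 4 header bytes. */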
1 : 0) : (1 << contentSizeDescriptor)); + + verify(headerSize <= inputLimit - inputAddress, input, "Not enough input bytes"); + + // decode window size + int windowSize = -1; + if (!singleSegment) { + int windowDescriptor = UNSAFE.getByte(inputBase, input++) & 0xFF; + int exponent = windowDescriptor >>> 3; + int mantissa = windowDescriptor & 0b111; + + int base = 1 << (MIN_WINDOW_LOG + exponent); + windowSize = base + (base / 8) * mantissa; + } + + // decode dictionary id + long dictionaryId = -1; + switch (dictionaryDescriptor) { + case 1: + dictionaryId = UNSAFE.getByte(inputBase, input) & 0xFF; + input += SIZE_OF_BYTE; + break; + case 2: + dictionaryId = UNSAFE.getShort(inputBase, input) & 0xFFFF; + input += SIZE_OF_SHORT; + break; + case 3: + dictionaryId = UNSAFE.getInt(inputBase, input) & 0xFFFF_FFFFL; + input += SIZE_OF_INT; + break; + } + verify(dictionaryId == -1, input, "Custom dictionaries not supported"); + + // decode content size + long contentSize = -1; + switch (contentSizeDescriptor) { + case 0: + if (singleSegment) { + contentSize = UNSAFE.getByte(inputBase, input) & 0xFF; + input += SIZE_OF_BYTE; + } + break; + case 1: + contentSize = UNSAFE.getShort(inputBase, input) & 0xFFFF; + contentSize += 256; + input += SIZE_OF_SHORT; + break; + case 2: + contentSize = UNSAFE.getInt(inputBase, input) & 0xFFFF_FFFFL; + input += SIZE_OF_INT; + break; + case 3: + contentSize = UNSAFE.getLong(inputBase, input); + input += SIZE_OF_LONG; + break; + } + + boolean hasChecksum = (frameHeaderDescriptor & 0b100) != 0; + + return new FrameHeader( + input - inputAddress, + windowSize, + contentSize, + dictionaryId, + hasChecksum); + } + + public static long getDecompressedSize(final Object inputBase, final long inputAddress, final long inputLimit) + { + long input = inputAddress; + input += verifyMagic(inputBase, input, inputLimit); + return readFrameHeader(inputBase, input, inputLimit).contentSize; + } + + static int verifyMagic(Object inputBase, long inputAddress, long inputLimit) + { + verify(inputLimit - inputAddress >= 4, inputAddress, "Not enough input bytes"); + + int magic = UNSAFE.getInt(inputBase, inputAddress); + if (magic != MAGIC_NUMBER) { + if (magic == V07_MAGIC_NUMBER) { + throw new MalformedInputException(inputAddress, "Data encoded in unsupported ZSTD v0.7 format"); + } + throw new MalformedInputException(inputAddress, "Invalid magic prefix: " + Integer.toHexString(magic)); + } + + return SIZE_OF_INT; + } +} diff --git a/src/main/java/io/airlift/compress/zstd/ZstdIncrementalFrameDecompressor.java b/src/main/java/io/airlift/compress/zstd/ZstdIncrementalFrameDecompressor.java new file mode 100644 index 0000000..171f172 --- /dev/null +++ b/src/main/java/io/airlift/compress/zstd/ZstdIncrementalFrameDecompressor.java @@ -0,0 +1,391 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.airlift.compress.zstd; + +import io.airlift.compress.MalformedInputException; + +import java.util.Arrays; + +import static io.airlift.compress.zstd.Constants.COMPRESSED_BLOCK; +import static io.airlift.compress.zstd.Constants.MAX_BLOCK_SIZE; +import static io.airlift.compress.zstd.Constants.RAW_BLOCK; +import static io.airlift.compress.zstd.Constants.RLE_BLOCK; +import static io.airlift.compress.zstd.Constants.SIZE_OF_BLOCK_HEADER; +import static io.airlift.compress.zstd.Constants.SIZE_OF_INT; +import static io.airlift.compress.zstd.UnsafeUtil.UNSAFE; +import static io.airlift.compress.zstd.Util.checkArgument; +import static io.airlift.compress.zstd.Util.checkState; +import static io.airlift.compress.zstd.Util.fail; +import static io.airlift.compress.zstd.Util.verify; +import static io.airlift.compress.zstd.ZstdFrameDecompressor.MAX_WINDOW_SIZE; +import static io.airlift.compress.zstd.ZstdFrameDecompressor.decodeRawBlock; +import static io.airlift.compress.zstd.ZstdFrameDecompressor.decodeRleBlock; +import static io.airlift.compress.zstd.ZstdFrameDecompressor.readFrameHeader; +import static io.airlift.compress.zstd.ZstdFrameDecompressor.verifyMagic; +import static java.lang.Math.max; +import static java.lang.Math.min; +import static java.lang.Math.toIntExact; +import static java.lang.String.format; +import static sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET; + +public class ZstdIncrementalFrameDecompressor +{ + private enum State { + INITIAL, + READ_FRAME_MAGIC, + READ_FRAME_HEADER, + READ_BLOCK_HEADER, + READ_BLOCK, + READ_BLOCK_CHECKSUM, + FLUSH_OUTPUT + } + + private final ZstdFrameDecompressor frameDecompressor = new ZstdFrameDecompressor(); + + private State state = State.INITIAL; + private FrameHeader frameHeader; + private int blockHeader = -1; + + private int inputConsumed; + private int outputBufferUsed; + + private int inputRequired; + private int requestedOutputSize; + + // current window buffer + private byte[] windowBase = new byte[0]; + private long windowAddress = ARRAY_BYTE_BASE_OFFSET; + private long windowLimit = ARRAY_BYTE_BASE_OFFSET; + private long windowPosition = ARRAY_BYTE_BASE_OFFSET; + + private XxHash64 partialHash; + + public boolean isAtStoppingPoint() + { + return state == State.READ_FRAME_MAGIC; + } + + public int getInputConsumed() + { + return inputConsumed; + } + + public int getOutputBufferUsed() + { + return outputBufferUsed; + } + + public int getInputRequired() + { + return inputRequired; + } + + public int getRequestedOutputSize() + { + return requestedOutputSize; + } + + public void partialDecompress( + final Object inputBase, + final long inputAddress, + final long inputLimit, + final byte[] outputArray, + final int outputOffset, + final int outputLimit) + { + if (inputRequired > inputLimit - inputAddress) { + throw new IllegalArgumentException(format( + "Required %s input bytes, but only %s input bytes were supplied", + inputRequired, + inputLimit - inputAddress)); + } + if (requestedOutputSize > 0 && outputOffset >= outputLimit) { + throw new IllegalArgumentException("Not enough space in output buffer to output"); + } + + long input = inputAddress; + int output = outputOffset; + + while (true) { + // Flush ready output + { + int flushableOutputSize = computeFlushableOutputSize(frameHeader); + if (flushableOutputSize > 0) { + int freeOutputSize = outputLimit - output; + if (freeOutputSize > 0) { + int copySize = min(freeOutputSize, flushableOutputSize); + System.arraycopy(windowBase, toIntExact(windowAddress - ARRAY_BYTE_BASE_OFFSET), 
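+                        /* Flush step: bytes beyond the frame's required look-back
+                           window are handed to the caller, and the running XxHash64
+                           (kept only when the frame advertises a checksum) is fed
+                           exactly the bytes flushed. */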
outputArray, output, copySize);
+                    if (partialHash != null) {
+                        partialHash.update(outputArray, output, copySize);
+                    }
+                    windowAddress += copySize;
+                    output += copySize;
+                    flushableOutputSize -= copySize;
+                }
+                if (flushableOutputSize > 0) {
+                    requestOutput(inputAddress, outputOffset, input, output, flushableOutputSize);
+                    return;
+                }
+            }
+        }
+            // verify data was completely flushed
+            checkState(computeFlushableOutputSize(frameHeader) == 0, "Expected output to be flushed");
+
+            if (state == State.READ_FRAME_MAGIC || state == State.INITIAL) {
+                if (inputLimit - input < 4) {
+                    inputRequired(inputAddress, outputOffset, input, output, 4);
+                    return;
+                }
+                input += verifyMagic(inputBase, input, inputLimit);
+                state = State.READ_FRAME_HEADER;
+            }
+
+            if (state == State.READ_FRAME_HEADER) {
+                if (inputLimit - input < 1) {
+                    inputRequired(inputAddress, outputOffset, input, output, 1);
+                    return;
+                }
+                int frameHeaderSize = determineFrameHeaderSize(inputBase, input, inputLimit);
+                if (inputLimit - input < frameHeaderSize) {
+                    inputRequired(inputAddress, outputOffset, input, output, frameHeaderSize);
+                    return;
+                }
+                frameHeader = readFrameHeader(inputBase, input, inputLimit);
+                verify(frameHeaderSize == frameHeader.headerSize, input, "Unexpected frame header size");
+                input += frameHeaderSize;
+                state = State.READ_BLOCK_HEADER;
+
+                reset();
+                if (frameHeader.hasChecksum) {
+                    partialHash = new XxHash64();
+                }
+            }
+            else {
+                verify(frameHeader != null, input, "Frame header is not set");
+            }
+
+            if (state == State.READ_BLOCK_HEADER) {
+                long inputBufferSize = inputLimit - input;
+                if (inputBufferSize < SIZE_OF_BLOCK_HEADER) {
+                    inputRequired(inputAddress, outputOffset, input, output, SIZE_OF_BLOCK_HEADER);
+                    return;
+                }
+                if (inputBufferSize >= SIZE_OF_INT) {
+                    blockHeader = UNSAFE.getInt(inputBase, input) & 0xFF_FFFF;
+                }
+                else {
+                    // assemble the 24-bit header byte-wise: reading a full int here
+                    // could run past inputLimit when exactly three bytes remain
+                    blockHeader = UNSAFE.getByte(inputBase, input) & 0xFF |
+                            (UNSAFE.getByte(inputBase, input + 1) & 0xFF) << 8 |
+                            (UNSAFE.getByte(inputBase, input + 2) & 0xFF) << 16;
+                }
+                input += SIZE_OF_BLOCK_HEADER;
+                state = State.READ_BLOCK;
+            }
+            else {
+                verify(blockHeader != -1, input, "Block header is not set");
+            }
+
+            boolean lastBlock = (blockHeader & 1) != 0;
+            if (state == State.READ_BLOCK) {
+                int blockType = (blockHeader >>> 1) & 0b11;
+                int blockSize = (blockHeader >>> 3) & 0x1F_FFFF; // 21 bits
+
+                resizeWindowBufferIfNecessary(frameHeader, blockType, blockSize);
+
+                int decodedSize;
+                switch (blockType) {
+                    case RAW_BLOCK: {
+                        if (inputLimit - input < blockSize) {
+                            inputRequired(inputAddress, outputOffset, input, output, blockSize);
+                            return;
+                        }
+                        verify(windowLimit - windowPosition >= blockSize, input, "window buffer is too small");
+                        decodedSize = decodeRawBlock(inputBase, input, blockSize, windowBase, windowPosition, windowLimit);
+                        input += blockSize;
+                        break;
+                    }
+                    case RLE_BLOCK: {
+                        if (inputLimit - input < 1) {
+                            inputRequired(inputAddress, outputOffset, input, output, 1);
+                            return;
+                        }
+                        verify(windowLimit - windowPosition >= blockSize, input, "window buffer is too small");
+                        decodedSize = decodeRleBlock(blockSize, inputBase, input, windowBase, windowPosition, windowLimit);
+                        input += 1;
+                        break;
+                    }
+                    case COMPRESSED_BLOCK: {
+                        if (inputLimit - input < blockSize) {
+                            inputRequired(inputAddress, outputOffset, input, output, blockSize);
+                            return;
+                        }
+                        verify(windowLimit - windowPosition >= MAX_BLOCK_SIZE, input, "window buffer is too small");
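+                        /* A compressed block may expand to as much as MAX_BLOCK_SIZE
+                           (128 KiB) regardless of its on-wire blockSize, so the window
+                           must reserve that much space before decoding. */
+                        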
decodedSize = frameDecompressor.decodeCompressedBlock(inputBase, input, blockSize, windowBase, windowPosition, windowLimit, frameHeader.windowSize, windowAddress); + input += blockSize; + break; + } + default: + throw fail(input, "Invalid block type"); + } + windowPosition += decodedSize; + if (lastBlock) { + state = State.READ_BLOCK_CHECKSUM; + } + else { + state = State.READ_BLOCK_HEADER; + } + } + + if (state == State.READ_BLOCK_CHECKSUM) { + if (frameHeader.hasChecksum) { + if (inputLimit - input < SIZE_OF_INT) { + inputRequired(inputAddress, outputOffset, input, output, SIZE_OF_INT); + return; + } + + // read checksum + int checksum = UNSAFE.getInt(inputBase, input); + input += SIZE_OF_INT; + + checkState(partialHash != null, "Partial hash not set"); + + // hash remaining frame data + int pendingOutputSize = toIntExact(windowPosition - windowAddress); + partialHash.update(windowBase, toIntExact(windowAddress - ARRAY_BYTE_BASE_OFFSET), pendingOutputSize); + + // verify hash + long hash = partialHash.hash(); + if (checksum != (int) hash) { + throw new MalformedInputException(input, format("Bad checksum. Expected: %s, actual: %s", Integer.toHexString(checksum), Integer.toHexString((int) hash))); + } + } + state = State.READ_FRAME_MAGIC; + frameHeader = null; + blockHeader = -1; + } + } + } + + private void reset() + { + frameDecompressor.reset(); + + windowAddress = ARRAY_BYTE_BASE_OFFSET; + windowPosition = ARRAY_BYTE_BASE_OFFSET; + } + + private int computeFlushableOutputSize(FrameHeader frameHeader) + { + return max(0, toIntExact(windowPosition - windowAddress - (frameHeader == null ? 0 : frameHeader.computeRequiredOutputBufferLookBackSize()))); + } + + private void resizeWindowBufferIfNecessary(FrameHeader frameHeader, int blockType, int blockSize) + { + int maxBlockOutput; + if (blockType == RAW_BLOCK || blockType == RLE_BLOCK) { + maxBlockOutput = blockSize; + } + else { + maxBlockOutput = MAX_BLOCK_SIZE; + } + + // if window buffer is full, move content to head of buffer and continue + if (windowLimit - windowPosition < MAX_BLOCK_SIZE) { + // output should have been flushed at the top of this method + int requiredWindowSize = frameHeader.computeRequiredOutputBufferLookBackSize(); + checkState(windowPosition - windowAddress <= requiredWindowSize, "Expected output to be flushed"); + + int windowContentsSize = toIntExact(windowPosition - windowAddress); + + // if window content is currently offset from the array base, move to the front + if (windowAddress != ARRAY_BYTE_BASE_OFFSET) { + // copy the window contents to the head of the window buffer + System.arraycopy(windowBase, toIntExact(windowAddress - ARRAY_BYTE_BASE_OFFSET), windowBase, 0, windowContentsSize); + windowAddress = ARRAY_BYTE_BASE_OFFSET; + windowPosition = windowAddress + windowContentsSize; + } + checkState(windowAddress == ARRAY_BYTE_BASE_OFFSET, "Window should be packed"); + + // if window free space is still too small, grow array + if (windowLimit - windowPosition < maxBlockOutput) { + // if content size is set and smaller than the required window size, use the content size + int newWindowSize; + if (frameHeader.contentSize >= 0 && frameHeader.contentSize < requiredWindowSize) { + newWindowSize = toIntExact(frameHeader.contentSize); + } + else { + // double the current necessary window size + newWindowSize = (windowContentsSize + maxBlockOutput) * 2; + // limit to 4x the required window size (or block size if larger) + newWindowSize = min(newWindowSize, max(requiredWindowSize, MAX_BLOCK_SIZE) * 4); + // limit to 
the max window size with one max sized block + newWindowSize = min(newWindowSize, MAX_WINDOW_SIZE + MAX_BLOCK_SIZE); + // must allocate at least enough space for a max sized block + newWindowSize = max(windowContentsSize + maxBlockOutput, newWindowSize); + checkState(windowContentsSize + maxBlockOutput <= newWindowSize, "Computed new window size buffer is not large enough"); + } + windowBase = Arrays.copyOf(windowBase, newWindowSize); + windowLimit = newWindowSize + ARRAY_BYTE_BASE_OFFSET; + } + + checkState(windowLimit - windowPosition >= maxBlockOutput, "window buffer is too small"); + } + } + + private static int determineFrameHeaderSize(final Object inputBase, final long inputAddress, final long inputLimit) + { + verify(inputAddress < inputLimit, inputAddress, "Not enough input bytes"); + + int frameHeaderDescriptor = UNSAFE.getByte(inputBase, inputAddress) & 0xFF; + boolean singleSegment = (frameHeaderDescriptor & 0b100000) != 0; + int dictionaryDescriptor = frameHeaderDescriptor & 0b11; + int contentSizeDescriptor = frameHeaderDescriptor >>> 6; + + return 1 + + (singleSegment ? 0 : 1) + + (dictionaryDescriptor == 0 ? 0 : (1 << (dictionaryDescriptor - 1))) + + (contentSizeDescriptor == 0 ? (singleSegment ? 1 : 0) : (1 << contentSizeDescriptor)); + } + + private void requestOutput(long inputAddress, int outputOffset, long input, int output, int requestedOutputSize) + { + updateInputOutputState(inputAddress, outputOffset, input, output); + + checkArgument(requestedOutputSize >= 0, "requestedOutputSize is negative"); + this.requestedOutputSize = requestedOutputSize; + + this.inputRequired = 0; + } + + private void inputRequired(long inputAddress, int outputOffset, long input, int output, int inputRequired) + { + updateInputOutputState(inputAddress, outputOffset, input, output); + + checkState(inputRequired >= 0, "inputRequired is negative"); + this.inputRequired = inputRequired; + + this.requestedOutputSize = 0; + } + + private void updateInputOutputState(long inputAddress, int outputOffset, long input, int output) + { + inputConsumed = (int) (input - inputAddress); + checkState(inputConsumed >= 0, "inputConsumed is negative"); + outputBufferUsed = output - outputOffset; + checkState(outputBufferUsed >= 0, "outputBufferUsed is negative"); + } +} diff --git a/src/main/java/io/airlift/compress/zstd/ZstdInputStream.java b/src/main/java/io/airlift/compress/zstd/ZstdInputStream.java new file mode 100644 index 0000000..53b1ad3 --- /dev/null +++ b/src/main/java/io/airlift/compress/zstd/ZstdInputStream.java @@ -0,0 +1,152 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.airlift.compress.zstd; + +import java.io.IOException; +import java.io.InputStream; +import java.util.Arrays; + +import static io.airlift.compress.zstd.Util.checkPositionIndexes; +import static io.airlift.compress.zstd.Util.checkState; +import static java.lang.Math.max; +import static java.util.Objects.requireNonNull; +import static sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET; + +public class ZstdInputStream + extends InputStream +{ + private static final int MIN_BUFFER_SIZE = 4096; + + private final InputStream inputStream; + private final ZstdIncrementalFrameDecompressor decompressor = new ZstdIncrementalFrameDecompressor(); + + private byte[] inputBuffer = new byte[decompressor.getInputRequired()]; + private int inputBufferOffset; + private int inputBufferLimit; + + private byte[] singleByteOutputBuffer; + + private boolean closed; + + public ZstdInputStream(InputStream inputStream) + { + this.inputStream = requireNonNull(inputStream, "inputStream is null"); + } + + @Override + public int read() + throws IOException + { + if (singleByteOutputBuffer == null) { + singleByteOutputBuffer = new byte[1]; + } + int readSize = read(singleByteOutputBuffer, 0, 1); + checkState(readSize != 0, "A zero read size should never be returned"); + if (readSize != 1) { + return -1; + } + return singleByteOutputBuffer[0] & 0xFF; + } + + @Override + public int read(final byte[] outputBuffer, final int outputOffset, final int outputLength) + throws IOException + { + if (closed) { + throw new IOException("Stream is closed"); + } + + if (outputBuffer == null) { + throw new NullPointerException(); + } + checkPositionIndexes(outputOffset, outputOffset + outputLength, outputBuffer.length); + if (outputLength == 0) { + return 0; + } + + final int outputLimit = outputOffset + outputLength; + int outputUsed = 0; + while (outputUsed < outputLength) { + boolean enoughInput = fillInputBufferIfNecessary(decompressor.getInputRequired()); + if (!enoughInput) { + if (decompressor.isAtStoppingPoint()) { + return outputUsed > 0 ? 
outputUsed : -1; + } + throw new IOException("Not enough input bytes"); + } + + decompressor.partialDecompress( + inputBuffer, + inputBufferOffset + ARRAY_BYTE_BASE_OFFSET, + inputBufferLimit + ARRAY_BYTE_BASE_OFFSET, + outputBuffer, + outputOffset + outputUsed, + outputLimit); + + inputBufferOffset += decompressor.getInputConsumed(); + outputUsed += decompressor.getOutputBufferUsed(); + } + return outputUsed; + } + + private boolean fillInputBufferIfNecessary(int requiredSize) + throws IOException + { + if (inputBufferLimit - inputBufferOffset >= requiredSize) { + return true; + } + + // compact existing buffered data to the front of the buffer + if (inputBufferOffset > 0) { + int copySize = inputBufferLimit - inputBufferOffset; + System.arraycopy(inputBuffer, inputBufferOffset, inputBuffer, 0, copySize); + inputBufferOffset = 0; + inputBufferLimit = copySize; + } + + if (inputBuffer.length < requiredSize) { + inputBuffer = Arrays.copyOf(inputBuffer, max(requiredSize, MIN_BUFFER_SIZE)); + } + + while (inputBufferLimit < inputBuffer.length) { + int readSize = inputStream.read(inputBuffer, inputBufferLimit, inputBuffer.length - inputBufferLimit); + if (readSize < 0) { + break; + } + inputBufferLimit += readSize; + } + return inputBufferLimit >= requiredSize; + } + + @Override + public int available() + throws IOException + { + if (closed) { + return 0; + } + return decompressor.getRequestedOutputSize(); + } + + @Override + public void close() + throws IOException + { + if (!closed) { + closed = true; + inputStream.close(); + } + } +} diff --git a/src/main/java/io/airlift/compress/zstd/ZstdOutputStream.java b/src/main/java/io/airlift/compress/zstd/ZstdOutputStream.java new file mode 100644 index 0000000..87c7762 --- /dev/null +++ b/src/main/java/io/airlift/compress/zstd/ZstdOutputStream.java @@ -0,0 +1,219 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package io.airlift.compress.zstd;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.Arrays;
+
+import static io.airlift.compress.zstd.CompressionParameters.DEFAULT_COMPRESSION_LEVEL;
+import static io.airlift.compress.zstd.Constants.SIZE_OF_BLOCK_HEADER;
+import static io.airlift.compress.zstd.Constants.SIZE_OF_LONG;
+import static io.airlift.compress.zstd.Util.checkState;
+import static java.lang.Math.max;
+import static java.lang.Math.min;
+import static java.util.Objects.requireNonNull;
+import static sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET;
+
+public class ZstdOutputStream
+ extends OutputStream
+{
+ private final OutputStream outputStream;
+ private final CompressionContext context;
+ private final int maxBufferSize;
+
+ private XxHash64 partialHash;
+
+ private byte[] uncompressed = new byte[0];
+ private final byte[] compressed;
+
+ // start of unprocessed data in uncompressed buffer
+ private int uncompressedOffset;
+ // end of unprocessed data in uncompressed buffer
+ private int uncompressedPosition;
+
+ private boolean closed;
+
+ public ZstdOutputStream(OutputStream outputStream)
+ throws IOException
+ {
+ this.outputStream = requireNonNull(outputStream, "outputStream is null");
+ this.context = new CompressionContext(CompressionParameters.compute(DEFAULT_COMPRESSION_LEVEL, -1), ARRAY_BYTE_BASE_OFFSET, Integer.MAX_VALUE);
+ this.maxBufferSize = context.parameters.getWindowSize() * 4;
+
+ // create output buffer large enough for a single block
+ int bufferSize = context.parameters.getBlockSize() + SIZE_OF_BLOCK_HEADER;
+ // todo is the "+ (bufferSize >>> 8)" required here?
+ // add extra long to give code more leeway
+ this.compressed = new byte[bufferSize + (bufferSize >>> 8) + SIZE_OF_LONG];
+ }
+
+ @Override
+ public void write(int b)
+ throws IOException
+ {
+ if (closed) {
+ throw new IOException("Stream is closed");
+ }
+
+ growBufferIfNecessary(1);
+
+ uncompressed[uncompressedPosition++] = (byte) b;
+
+ compressIfNecessary();
+ }
+
+ @Override
+ public void write(byte[] buffer)
+ throws IOException
+ {
+ write(buffer, 0, buffer.length);
+ }
+
+ @Override
+ public void write(byte[] buffer, int offset, int length)
+ throws IOException
+ {
+ if (closed) {
+ throw new IOException("Stream is closed");
+ }
+
+ growBufferIfNecessary(length);
+
+ while (length > 0) {
+ int writeSize = min(length, uncompressed.length - uncompressedPosition);
+ System.arraycopy(buffer, offset, uncompressed, uncompressedPosition, writeSize);
+
+ uncompressedPosition += writeSize;
+ length -= writeSize;
+ offset += writeSize;
+
+ compressIfNecessary();
+ }
+ }
+
+ private void growBufferIfNecessary(int length)
+ {
+ if (uncompressedPosition + length <= uncompressed.length || uncompressed.length >= maxBufferSize) {
+ return;
+ }
+
+ // assume we will need double the current required space
+ int newSize = (uncompressed.length + length) * 2;
+ // limit to max buffer size
+ newSize = min(newSize, maxBufferSize);
+ // allocate at least a minimal buffer to start
+ newSize = max(newSize, context.parameters.getBlockSize());
+ uncompressed = Arrays.copyOf(uncompressed, newSize);
+ }
+
+ private void compressIfNecessary()
+ throws IOException
+ {
+ // only flush when the buffer is at max size, is completely full, and holds more than a window plus one block of data
+ if (uncompressed.length >= maxBufferSize &&
+ uncompressedPosition == uncompressed.length &&
+ uncompressed.length - context.parameters.getWindowSize() > context.parameters.getBlockSize()) {
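+ // flush full blocks only; writeChunk keeps a window of history (plus any partial block) buffered for match finding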
+ writeChunk(false);
+ }
+ }
+
+ // visible for Hadoop stream
+ void finishWithoutClosingSource()
+ throws IOException
+ {
+ if (!closed) {
+ writeChunk(true);
+ closed = true;
+ }
+ }
+
+ @Override
+ public void close()
+ throws IOException
+ {
+ // finishing is guarded, so a repeated close (or a close after
+ // finishWithoutClosingSource) does not write a second final block
+ finishWithoutClosingSource();
+ outputStream.close();
+ }
+
+ private void writeChunk(boolean lastChunk)
+ throws IOException
+ {
+ int chunkSize;
+ if (lastChunk) {
+ // write all the data
+ chunkSize = uncompressedPosition - uncompressedOffset;
+ }
+ else {
+ int blockSize = context.parameters.getBlockSize();
+ chunkSize = uncompressedPosition - uncompressedOffset - context.parameters.getWindowSize() - blockSize;
+ checkState(chunkSize > blockSize, "Must write at least one full block");
+ // only write full blocks
+ chunkSize = (chunkSize / blockSize) * blockSize;
+ }
+
+ // if first write
+ if (partialHash == null) {
+ partialHash = new XxHash64();
+
+ // if this is also the last chunk we know the exact size, otherwise, this is traditional streaming
+ int inputSize = lastChunk ? chunkSize : -1;
+
+ int outputAddress = ARRAY_BYTE_BASE_OFFSET;
+ outputAddress += ZstdFrameCompressor.writeMagic(compressed, outputAddress, outputAddress + 4);
+ outputAddress += ZstdFrameCompressor.writeFrameHeader(compressed, outputAddress, outputAddress + 14, inputSize, context.parameters.getWindowSize());
+ outputStream.write(compressed, 0, outputAddress - ARRAY_BYTE_BASE_OFFSET);
+ }
+
+ partialHash.update(uncompressed, uncompressedOffset, chunkSize);
+
+ // write one block at a time
+ // note this is a do while to ensure that zero length input gets at least one block written
+ do {
+ int blockSize = min(chunkSize, context.parameters.getBlockSize());
+ int compressedSize = ZstdFrameCompressor.writeCompressedBlock(
+ uncompressed,
+ ARRAY_BYTE_BASE_OFFSET + uncompressedOffset,
+ blockSize,
+ compressed,
+ ARRAY_BYTE_BASE_OFFSET,
+ compressed.length,
+ context,
+ lastChunk && blockSize == chunkSize);
+ outputStream.write(compressed, 0, compressedSize);
+ uncompressedOffset += blockSize;
+ chunkSize -= blockSize;
+ }
+ while (chunkSize > 0);
+
+ if (lastChunk) {
+ // write the checksum: the low 32 bits of the XxHash64 of the uncompressed data, little-endian
+ int hash = (int) partialHash.hash();
+ outputStream.write(hash);
+ outputStream.write(hash >> 8);
+ outputStream.write(hash >> 16);
+ outputStream.write(hash >> 24);
+ }
+ else {
+ // slide window forward, leaving the entire window and the unprocessed data
+ int slideWindowSize = uncompressedOffset - context.parameters.getWindowSize();
+ context.slideWindow(slideWindowSize);
+
+ System.arraycopy(uncompressed, slideWindowSize, uncompressed, 0, context.parameters.getWindowSize() + (uncompressedPosition - uncompressedOffset));
+ uncompressedOffset -= slideWindowSize;
+ uncompressedPosition -= slideWindowSize;
+ }
+ }
+}
diff --git a/version.txt b/version.txt
index 5625e59..6085e94 100644
--- a/version.txt
+++ b/version.txt
@@ -1 +1 @@
-1.2
+1.2.1