Skip to content

Commit

Permalink
[ARROW-5917][Java] Redesign the dictionary encoder
Browse files Browse the repository at this point in the history
  • Loading branch information
liyafan82 committed Aug 2, 2019
1 parent 06fd2da commit a916c15
Show file tree
Hide file tree
Showing 2 changed files with 228 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.arrow.algorithm.dictionary;

import org.apache.arrow.algorithm.search.VectorSearcher;
import org.apache.arrow.algorithm.sort.VectorValueComparator;
import org.apache.arrow.vector.BaseIntVector;
import org.apache.arrow.vector.BitVectorHelper;
import org.apache.arrow.vector.ValueVector;

/**
 * Dictionary encoder based on searching.
 *
 * <p>Encoding looks up every input element in the dictionary through
 * {@code VectorSearcher.binarySearch} and stores the index it returns in the output vector.
 * Decoding copies the dictionary entry at each stored index into the output vector.
 *
 * @param <E> encoded vector type.
 * @param <D> decoded vector type, which is also the dictionary type.
 */
public class SearchDictionaryEncoder<E extends BaseIntVector, D extends ValueVector> {

  /**
   * The dictionary for encoding/decoding.
   * It must be sorted.
   */
  protected final D dictionary;

  /**
   * The criteria by which the dictionary is sorted.
   */
  protected final VectorValueComparator<D> comparator;

  /**
   * A flag indicating if null should be encoded.
   * When it is {@code false}, null slots are passed through as nulls by both
   * {@link #encode} and {@link #decode} without consulting the dictionary.
   */
  protected final boolean encodeNull;

  /**
   * Constructs a dictionary encoder that does not encode nulls.
   *
   * @param dictionary the dictionary. It must be in sorted order.
   * @param comparator the criteria for sorting.
   */
  public SearchDictionaryEncoder(D dictionary, VectorValueComparator<D> comparator) {
    this(dictionary, comparator, false);
  }

  /**
   * Constructs a dictionary encoder.
   *
   * @param dictionary the dictionary. It must be in sorted order.
   * @param comparator the criteria for sorting.
   * @param encodeNull if null should be encoded.
   */
  public SearchDictionaryEncoder(D dictionary, VectorValueComparator<D> comparator, boolean encodeNull) {
    this.dictionary = dictionary;
    this.comparator = comparator;
    this.encodeNull = encodeNull;
  }

  /**
   * Encodes an input vector by binary search.
   * So the algorithm takes O(nlogm) time, where n is the length of the input vector,
   * and m is the length of the dictionary.
   *
   * <p>On success the value count of {@code output} is set to the value count of {@code input}.
   *
   * @param input the input vector.
   * @param output the output vector.
   * @throws IllegalStateException if an element of the input is not found in the dictionary.
   */
  public void encode(D input, E output) {
    // The value count of the input is loop invariant, so read it only once.
    final int valueCount = input.getValueCount();
    for (int i = 0; i < valueCount; i++) {
      if (!encodeNull && input.isNull(i)) {
        // Nulls are not encoded: clear the validity bit of the matching output slot.
        // NOTE(review): this writes to the validity buffer directly; presumably the caller
        // must have allocated the output with enough capacity -- confirm.
        BitVectorHelper.setValidityBit(output.getValidityBuffer(), i, 0);
        continue;
      }

      int index = VectorSearcher.binarySearch(dictionary, comparator, input, i);
      if (index == -1) {
        // Report the offending position so that the bad element can be located.
        throw new IllegalStateException(
            "The data element is not found in the dictionary (input index " + i + ")");
      }
      output.setWithPossibleTruncate(i, index);
    }
    output.setValueCount(valueCount);
  }

  /**
   * Decodes an input vector. The algorithm takes O(n) time,
   * where n is the length of the input vector.
   *
   * <p>On success the value count of {@code output} is set to the value count of {@code input}.
   *
   * @param input the input vector.
   * @param output the output vector.
   * @throws IndexOutOfBoundsException if an encoded index does not refer to a dictionary entry.
   */
  public void decode(E input, D output) {
    // Both counts are loop invariant, so read them only once.
    final int valueCount = input.getValueCount();
    final int dictionarySize = dictionary.getValueCount();
    for (int i = 0; i < valueCount; i++) {
      if (!encodeNull && input.isNull(i)) {
        // Nulls were not encoded: clear the validity bit of the matching output slot.
        BitVectorHelper.setValidityBit(output.getValidityBuffer(), i, 0);
        continue;
      }

      // Validate the index while it is still a long: narrowing first could wrap an
      // out-of-range value into one that looks like a valid dictionary position.
      final long index = input.getValueAsLong(i);
      if (index < 0 || index >= dictionarySize) {
        throw new IndexOutOfBoundsException(
            "The encoded index " + index + " at position " + i +
            " is outside the dictionary range [0, " + dictionarySize + ")");
      }
      output.copyFrom((int) index, i, dictionary);
    }
    output.setValueCount(valueCount);
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.arrow.algorithm.dictionary;

import static junit.framework.TestCase.assertTrue;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;

import java.nio.charset.StandardCharsets;
import java.util.Random;

import org.apache.arrow.algorithm.sort.DefaultVectorComparators;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.VarCharVector;

import org.junit.After;
import org.junit.Before;
import org.junit.Test;

/**
 * Test cases for {@link SearchDictionaryEncoder}.
 */
public class TestSearchDictionaryEncoder {

  /** Number of elements in the vector that gets encoded. */
  private static final int VECTOR_LENGTH = 50;

  /** Number of entries in the dictionary. */
  private static final int DICTIONARY_LENGTH = 10;

  /** Fixed seed, so that a failing run can be reproduced exactly. */
  private static final long RANDOM_SEED = 0L;

  private BufferAllocator allocator;

  @Before
  public void prepare() {
    allocator = new RootAllocator(1024 * 1024);
  }

  @After
  public void shutdown() {
    allocator.close();
  }

  /**
   * Returns the decimal text of {@code value} as bytes, using an explicit charset
   * so that the test data does not depend on the platform default.
   */
  private static byte[] toBytes(int value) {
    return String.valueOf(value).getBytes(StandardCharsets.UTF_8);
  }

  /**
   * Encodes a vector of decimal strings against a dictionary in which the string "i" sits at
   * index i, checks the produced indices, then decodes them again and checks the round trip.
   * Every tenth element of the raw vector is null and is expected to stay null.
   */
  @Test
  public void testEncodeAndDecode() {
    Random random = new Random(RANDOM_SEED);
    try (VarCharVector rawVector = new VarCharVector("original vector", allocator);
         IntVector encodedVector = new IntVector("encoded vector", allocator);
         VarCharVector dictionary = new VarCharVector("dictionary", allocator);
         VarCharVector decodedVector = new VarCharVector("decoded vector", allocator)) {

      // set up dictionary
      dictionary.allocateNew();
      for (int i = 0; i < DICTIONARY_LENGTH; i++) {
        // encode "i" as i
        dictionary.setSafe(i, toBytes(i));
      }
      dictionary.setValueCount(DICTIONARY_LENGTH);

      // set up raw vector
      rawVector.allocateNew(10 * VECTOR_LENGTH, VECTOR_LENGTH);
      for (int i = 0; i < VECTOR_LENGTH; i++) {
        if (i % 10 == 0) {
          rawVector.setNull(i);
        } else {
          // uniformly distributed in [0, DICTIONARY_LENGTH)
          int val = random.nextInt(DICTIONARY_LENGTH);
          rawVector.set(i, toBytes(val));
        }
      }
      rawVector.setValueCount(VECTOR_LENGTH);

      SearchDictionaryEncoder<IntVector, VarCharVector> encoder =
          new SearchDictionaryEncoder<>(dictionary, DefaultVectorComparators.createDefaultComparator(rawVector));

      // perform encoding
      encodedVector.allocateNew();
      encoder.encode(rawVector, encodedVector);

      // verify encoding results
      assertEquals(rawVector.getValueCount(), encodedVector.getValueCount());
      for (int i = 0; i < VECTOR_LENGTH; i++) {
        if (i % 10 == 0) {
          assertTrue(encodedVector.isNull(i));
        } else {
          // the dictionary maps "i" to i, so the index rendered as text is the raw value
          assertArrayEquals(rawVector.get(i), toBytes(encodedVector.get(i)));
        }
      }

      // perform decoding
      decodedVector.allocateNew();
      encoder.decode(encodedVector, decodedVector);

      // verify decoding results
      assertEquals(encodedVector.getValueCount(), decodedVector.getValueCount());
      for (int i = 0; i < VECTOR_LENGTH; i++) {
        if (i % 10 == 0) {
          assertTrue(decodedVector.isNull(i));
        } else {
          assertArrayEquals(toBytes(encodedVector.get(i)), decodedVector.get(i));
        }
      }
    }
  }
}

0 comments on commit a916c15

Please sign in to comment.