Skip to content

Commit

Permalink
[Improve][Connector-V2] optimized code
Browse files Browse the repository at this point in the history
  • Loading branch information
corgy-w committed Aug 23, 2024
1 parent d1e026a commit a6d9d5b
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@
import java.nio.Buffer;
import java.nio.ByteBuffer;

public class BufferUtil {
public class BufferUtils {

public static ByteBuffer shortArrayToByteBuffer(Short[] shortArray) {
public static ByteBuffer toByteBuffer(Short[] shortArray) {
ByteBuffer byteBuffer = ByteBuffer.allocate(shortArray.length * 2);

for (Short value : shortArray) {
Expand All @@ -52,7 +52,7 @@ public static ByteBuffer shortArrayToByteBuffer(Short[] shortArray) {
return byteBuffer;
}

public static Short[] byteBufferToShortArray(ByteBuffer byteBuffer) {
public static Short[] toShortArray(ByteBuffer byteBuffer) {
Short[] shortArray = new Short[byteBuffer.capacity() / 2];

for (int i = 0; i < shortArray.length; i++) {
Expand All @@ -62,7 +62,7 @@ public static Short[] byteBufferToShortArray(ByteBuffer byteBuffer) {
return shortArray;
}

public static ByteBuffer floatArrayToByteBuffer(Float[] floatArray) {
public static ByteBuffer toByteBuffer(Float[] floatArray) {
ByteBuffer byteBuffer = ByteBuffer.allocate(floatArray.length * 4);

for (float value : floatArray) {
Expand All @@ -74,7 +74,7 @@ public static ByteBuffer floatArrayToByteBuffer(Float[] floatArray) {
return byteBuffer;
}

public static Float[] byteBufferToFloatArray(ByteBuffer byteBuffer) {
public static Float[] toFloatArray(ByteBuffer byteBuffer) {
Float[] floatArray = new Float[byteBuffer.capacity() / 4];

for (int i = 0; i < floatArray.length; i++) {
Expand All @@ -84,7 +84,7 @@ public static Float[] byteBufferToFloatArray(ByteBuffer byteBuffer) {
return floatArray;
}

public static ByteBuffer doubleArrayToByteBuffer(Double[] doubleArray) {
public static ByteBuffer toByteBuffer(Double[] doubleArray) {
ByteBuffer byteBuffer = ByteBuffer.allocate(doubleArray.length * 8);

for (double value : doubleArray) {
Expand All @@ -96,7 +96,7 @@ public static ByteBuffer doubleArrayToByteBuffer(Double[] doubleArray) {
return byteBuffer;
}

public static Double[] byteBufferToDoubleArray(ByteBuffer byteBuffer) {
public static Double[] toDoubleArray(ByteBuffer byteBuffer) {
Double[] doubleArray = new Double[byteBuffer.capacity() / 8];

for (int i = 0; i < doubleArray.length; i++) {
Expand All @@ -106,7 +106,7 @@ public static Double[] byteBufferToDoubleArray(ByteBuffer byteBuffer) {
return doubleArray;
}

public static ByteBuffer intArrayToByteBuffer(Integer[] intArray) {
public static ByteBuffer toByteBuffer(Integer[] intArray) {
ByteBuffer byteBuffer = ByteBuffer.allocate(intArray.length * 4);

for (int value : intArray) {
Expand All @@ -118,7 +118,7 @@ public static ByteBuffer intArrayToByteBuffer(Integer[] intArray) {
return byteBuffer;
}

public static Integer[] byteBufferToIntArray(ByteBuffer byteBuffer) {
public static Integer[] toIntArray(ByteBuffer byteBuffer) {
Integer[] intArray = new Integer[byteBuffer.capacity() / 4];

for (int i = 0; i < intArray.length; i++) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.seatunnel.common.utils;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

import java.nio.ByteBuffer;

public class BufferUtilsTest {

@Test
public void testToByteBufferAndToShortArray() {
Short[] shortArray = {1, 2, 3, 4, 5};
ByteBuffer byteBuffer = BufferUtils.toByteBuffer(shortArray);
Short[] resultArray = BufferUtils.toShortArray(byteBuffer);

Assertions.assertArrayEquals(shortArray, resultArray, "Short array conversion failed");
}

@Test
public void testToByteBufferAndToFloatArray() {
Float[] floatArray = {1.1f, 2.2f, 3.3f, 4.4f, 5.5f};
ByteBuffer byteBuffer = BufferUtils.toByteBuffer(floatArray);
Float[] resultArray = BufferUtils.toFloatArray(byteBuffer);

Assertions.assertArrayEquals(floatArray, resultArray, "Float array conversion failed");
}

@Test
public void testToByteBufferAndToDoubleArray() {
Double[] doubleArray = {1.1, 2.2, 3.3, 4.4, 5.5};
ByteBuffer byteBuffer = BufferUtils.toByteBuffer(doubleArray);
Double[] resultArray = BufferUtils.toDoubleArray(byteBuffer);

Assertions.assertArrayEquals(doubleArray, resultArray, "Double array conversion failed");
}

@Test
public void testToByteBufferAndToIntArray() {
Integer[] intArray = {1, 2, 3, 4, 5};
ByteBuffer byteBuffer = BufferUtils.toByteBuffer(intArray);
Integer[] resultArray = BufferUtils.toIntArray(byteBuffer);

Assertions.assertArrayEquals(intArray, resultArray, "Integer array conversion failed");
}

@Test
public void testEmptyArrayConversion() {
// Test empty arrays
Short[] shortArray = {};
ByteBuffer shortBuffer = BufferUtils.toByteBuffer(shortArray);
Short[] shortResultArray = BufferUtils.toShortArray(shortBuffer);
Assertions.assertArrayEquals(
shortArray, shortResultArray, "Empty Short array conversion failed");

Float[] floatArray = {};
ByteBuffer floatBuffer = BufferUtils.toByteBuffer(floatArray);
Float[] floatResultArray = BufferUtils.toFloatArray(floatBuffer);
Assertions.assertArrayEquals(
floatArray, floatResultArray, "Empty Float array conversion failed");

Double[] doubleArray = {};
ByteBuffer doubleBuffer = BufferUtils.toByteBuffer(doubleArray);
Double[] doubleResultArray = BufferUtils.toDoubleArray(doubleBuffer);
Assertions.assertArrayEquals(
doubleArray, doubleResultArray, "Empty Double array conversion failed");

Integer[] intArray = {};
ByteBuffer intBuffer = BufferUtils.toByteBuffer(intArray);
Integer[] intResultArray = BufferUtils.toIntArray(intBuffer);
Assertions.assertArrayEquals(
intArray, intResultArray, "Empty Integer array conversion failed");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.seatunnel.connectors.seatunnel.fake.utils;

import org.apache.seatunnel.common.utils.BufferUtil;
import org.apache.seatunnel.common.utils.BufferUtils;
import org.apache.seatunnel.connectors.seatunnel.fake.config.FakeConfig;

import org.apache.commons.collections4.CollectionUtils;
Expand Down Expand Up @@ -185,7 +185,7 @@ public ByteBuffer randomFloatVector() {
RandomUtils.nextFloat(
fakeConfig.getVectorFloatMin(), fakeConfig.getVectorFloatMax());
}
return BufferUtil.floatArrayToByteBuffer(floatVector);
return BufferUtils.toByteBuffer(floatVector);
}

public ByteBuffer randomFloat16Vector() {
Expand All @@ -196,7 +196,7 @@ public ByteBuffer randomFloat16Vector() {
fakeConfig.getVectorFloatMin(), fakeConfig.getVectorFloatMax());
float16Vector[i] = floatToFloat16(value);
}
return BufferUtil.shortArrayToByteBuffer(float16Vector);
return BufferUtils.toByteBuffer(float16Vector);
}

public ByteBuffer randomBFloat16Vector() {
Expand All @@ -207,7 +207,7 @@ public ByteBuffer randomBFloat16Vector() {
fakeConfig.getVectorFloatMin(), fakeConfig.getVectorFloatMax());
bfloat16Vector[i] = floatToBFloat16(value);
}
return BufferUtil.shortArrayToByteBuffer(bfloat16Vector);
return BufferUtils.toByteBuffer(bfloat16Vector);
}

public Map<Integer, Float> randomSparseFloatVector() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import org.apache.seatunnel.api.table.type.SqlType;
import org.apache.seatunnel.common.exception.CommonError;
import org.apache.seatunnel.common.exception.CommonErrorCodeDeprecated;
import org.apache.seatunnel.common.utils.BufferUtil;
import org.apache.seatunnel.common.utils.BufferUtils;
import org.apache.seatunnel.connectors.seatunnel.jdbc.exception.JdbcConnectorErrorCode;
import org.apache.seatunnel.connectors.seatunnel.jdbc.exception.JdbcConnectorException;
import org.apache.seatunnel.connectors.seatunnel.jdbc.internal.converter.AbstractJdbcRowConverter;
Expand Down Expand Up @@ -94,7 +94,7 @@ public SeaTunnelRow toInternal(ResultSet rs, TableSchema tableSchema) throws SQL
for (int i = 0; i < objects.length; i++) {
arrays[i] = Float.parseFloat(objects[i].toString());
}
fields[fieldIndex] = BufferUtil.floatArrayToByteBuffer(arrays);
fields[fieldIndex] = BufferUtils.toByteBuffer(arrays);
break;
case DOUBLE:
fields[fieldIndex] = JdbcFieldTypeUtils.getDouble(rs, resultSetIndex);
Expand Down Expand Up @@ -177,7 +177,7 @@ public PreparedStatement toExternal(
if (row.getField(fieldIndex) instanceof ByteBuffer) {
ByteBuffer byteBuffer = (ByteBuffer) row.getField(fieldIndex);
// Convert ByteBuffer to Float[]
Float[] floatArray = BufferUtil.byteBufferToFloatArray(byteBuffer);
Float[] floatArray = BufferUtils.toFloatArray(byteBuffer);
StringBuilder vector = new StringBuilder();
vector.append("[");
for (Float aFloat : floatArray) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
import org.apache.seatunnel.api.table.type.SeaTunnelRow;
import org.apache.seatunnel.api.table.type.SqlType;
import org.apache.seatunnel.api.table.type.VectorType;
import org.apache.seatunnel.common.utils.BufferUtil;
import org.apache.seatunnel.common.utils.BufferUtils;
import org.apache.seatunnel.common.utils.JsonUtils;
import org.apache.seatunnel.connectors.seatunnel.milvus.catalog.MilvusOptions;
import org.apache.seatunnel.connectors.seatunnel.milvus.config.MilvusSourceConfig;
Expand Down Expand Up @@ -322,7 +322,7 @@ public static Object convertBySeaTunnelType(SeaTunnelDataType<?> fieldType, Obje
return value.toString();
case FLOAT_VECTOR:
ByteBuffer floatVectorBuffer = (ByteBuffer) value;
Float[] floats = BufferUtil.byteBufferToFloatArray(floatVectorBuffer);
Float[] floats = BufferUtils.toFloatArray(floatVectorBuffer);
return Arrays.stream(floats).collect(Collectors.toList());
case BINARY_VECTOR:
case BFLOAT16_VECTOR:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import org.apache.seatunnel.api.table.type.SeaTunnelRow;
import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
import org.apache.seatunnel.common.exception.CommonErrorCode;
import org.apache.seatunnel.common.utils.BufferUtil;
import org.apache.seatunnel.common.utils.BufferUtils;
import org.apache.seatunnel.connectors.seatunnel.milvus.config.MilvusSourceConfig;
import org.apache.seatunnel.connectors.seatunnel.milvus.exception.MilvusConnectionErrorCode;
import org.apache.seatunnel.connectors.seatunnel.milvus.exception.MilvusConnectorException;
Expand Down Expand Up @@ -221,7 +221,7 @@ public SeaTunnelRow convertToSeaTunnelRow(
for (int i = 0; i < list.size(); i++) {
arrays[i] = Float.parseFloat(list.get(i).toString());
}
fields[fieldIndex] = BufferUtil.floatArrayToByteBuffer(arrays);
fields[fieldIndex] = BufferUtils.toByteBuffer(arrays);
break;
} else {
throw new MilvusConnectorException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.seatunnel.e2e.connector.v2.milvus;

import org.apache.seatunnel.common.utils.BufferUtil;
import org.apache.seatunnel.common.utils.BufferUtils;
import org.apache.seatunnel.e2e.common.TestResource;
import org.apache.seatunnel.e2e.common.TestSuiteBase;
import org.apache.seatunnel.e2e.common.container.EngineType;
Expand Down Expand Up @@ -227,7 +227,7 @@ private void initSourceData() {
List<Float> vector = Arrays.asList((float) i, (float) i, (float) i, (float) i);
row.add(VECTOR_FIELD, gson.toJsonTree(vector));
Short[] shorts = {(short) i, (short) i, (short) i, (short) i};
ByteBuffer shortByteBuffer = BufferUtil.shortArrayToByteBuffer(shorts);
ByteBuffer shortByteBuffer = BufferUtils.toByteBuffer(shorts);
row.add(VECTOR_FIELD2, gson.toJsonTree(shortByteBuffer.array()));
ByteBuffer binaryByteBuffer = ByteBuffer.wrap(new byte[] {16});
row.add(VECTOR_FIELD3, gson.toJsonTree(binaryByteBuffer.array()));
Expand Down

0 comments on commit a6d9d5b

Please sign in to comment.