Skip to content

Commit

Permalink
[Improve][Connector-V2] update vectorType
Browse files Browse the repository at this point in the history
  • Loading branch information
corgy-w committed Aug 21, 2024
1 parent 570bbb3 commit 835e43b
Show file tree
Hide file tree
Showing 8 changed files with 158 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,6 @@ private int getBytesForValue(Object v, SeaTunnelDataType<?> dataType) {
case TIMESTAMP:
return 48;
case FLOAT_VECTOR:
return getArrayNotNullSize((Object[]) v) * 4;
case FLOAT16_VECTOR:
case BFLOAT16_VECTOR:
case BINARY_VECTOR:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,23 @@

package org.apache.seatunnel.api.table.type;

import org.apache.seatunnel.api.annotation.Experimental;

import java.nio.ByteBuffer;
import java.util.Map;
import java.util.Objects;

/**
* VectorType represents a vector type in SeaTunnel.
*
* <p>Experimental feature, use with caution
*/
@Experimental
public class VectorType<T> implements SeaTunnelDataType<T> {
private static final long serialVersionUID = 2L;

public static final VectorType<Float> VECTOR_FLOAT_TYPE =
new VectorType<>(Float.class, SqlType.FLOAT_VECTOR);
public static final VectorType<ByteBuffer> VECTOR_FLOAT_TYPE =
new VectorType<>(ByteBuffer.class, SqlType.FLOAT_VECTOR);

public static final VectorType<Map> VECTOR_SPARSE_FLOAT_TYPE =
new VectorType<>(Map.class, SqlType.SPARSE_FLOAT_VECTOR);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.seatunnel.common.utils;

import java.nio.Buffer;
import java.nio.ByteBuffer;

/**
 * Conversions between boxed numeric arrays and {@link ByteBuffer}s.
 *
 * <p>The {@code xArrayToByteBuffer} methods allocate a new heap buffer, write every element with
 * the buffer's default byte order and return it flipped (position 0, limit = bytes written).
 *
 * <p>The {@code byteBufferToXArray} methods decode the bytes between the buffer's current
 * position and its limit using the buffer's own byte order. They use absolute reads, so the
 * caller's buffer is left untouched (position, limit and mark are not modified) and the same
 * buffer can be converted more than once or handed to another consumer afterwards. Trailing bytes
 * that do not form a whole element are ignored.
 */
public class BufferUtil {

    /**
     * Packs the given values into a new buffer, two bytes per element.
     *
     * @param shortArray values to write; neither the array nor its elements may be null
     * @return a flipped buffer holding {@code shortArray.length * 2} bytes
     */
    public static ByteBuffer shortArrayToByteBuffer(Short[] shortArray) {
        ByteBuffer byteBuffer = ByteBuffer.allocate(shortArray.length * Short.BYTES);

        for (Short value : shortArray) {
            byteBuffer.putShort(value);
        }

        // Flip the buffer to prepare for reading. The call goes through the Buffer supertype
        // so the bytecode binds to Buffer.flip(); this keeps the class usable when the JDK used
        // for compilation is newer than the one used at run time (ByteBuffer only gained its
        // own covariant flip() in JDK 9).
        ((Buffer) byteBuffer).flip();

        return byteBuffer;
    }

    /**
     * Decodes the readable bytes of the buffer as 16-bit values.
     *
     * @param byteBuffer source buffer; it is not modified
     * @return one element per complete 2-byte group between position and limit
     */
    public static Short[] byteBufferToShortArray(ByteBuffer byteBuffer) {
        int start = byteBuffer.position();
        // Size from remaining(), not capacity(): the limit may be below the capacity.
        Short[] shortArray = new Short[byteBuffer.remaining() / Short.BYTES];

        for (int i = 0; i < shortArray.length; i++) {
            // Absolute read: does not advance the caller's position.
            shortArray[i] = byteBuffer.getShort(start + i * Short.BYTES);
        }

        return shortArray;
    }

    /**
     * Packs the given values into a new buffer, four bytes per element.
     *
     * @param floatArray values to write; neither the array nor its elements may be null
     * @return a flipped buffer holding {@code floatArray.length * 4} bytes
     */
    public static ByteBuffer floatArrayToByteBuffer(Float[] floatArray) {
        ByteBuffer byteBuffer = ByteBuffer.allocate(floatArray.length * Float.BYTES);

        for (float value : floatArray) {
            byteBuffer.putFloat(value);
        }

        // See shortArrayToByteBuffer for why flip() is invoked through Buffer.
        ((Buffer) byteBuffer).flip();

        return byteBuffer;
    }

    /**
     * Decodes the readable bytes of the buffer as 32-bit floats.
     *
     * @param byteBuffer source buffer; it is not modified
     * @return one element per complete 4-byte group between position and limit
     */
    public static Float[] byteBufferToFloatArray(ByteBuffer byteBuffer) {
        int start = byteBuffer.position();
        Float[] floatArray = new Float[byteBuffer.remaining() / Float.BYTES];

        for (int i = 0; i < floatArray.length; i++) {
            floatArray[i] = byteBuffer.getFloat(start + i * Float.BYTES);
        }

        return floatArray;
    }

    /**
     * Packs the given values into a new buffer, eight bytes per element.
     *
     * @param doubleArray values to write; neither the array nor its elements may be null
     * @return a flipped buffer holding {@code doubleArray.length * 8} bytes
     */
    public static ByteBuffer doubleArrayToByteBuffer(Double[] doubleArray) {
        ByteBuffer byteBuffer = ByteBuffer.allocate(doubleArray.length * Double.BYTES);

        for (double value : doubleArray) {
            byteBuffer.putDouble(value);
        }

        // See shortArrayToByteBuffer for why flip() is invoked through Buffer.
        ((Buffer) byteBuffer).flip();

        return byteBuffer;
    }

    /**
     * Decodes the readable bytes of the buffer as 64-bit doubles.
     *
     * @param byteBuffer source buffer; it is not modified
     * @return one element per complete 8-byte group between position and limit
     */
    public static Double[] byteBufferToDoubleArray(ByteBuffer byteBuffer) {
        int start = byteBuffer.position();
        Double[] doubleArray = new Double[byteBuffer.remaining() / Double.BYTES];

        for (int i = 0; i < doubleArray.length; i++) {
            doubleArray[i] = byteBuffer.getDouble(start + i * Double.BYTES);
        }

        return doubleArray;
    }

    /**
     * Packs the given values into a new buffer, four bytes per element.
     *
     * @param intArray values to write; neither the array nor its elements may be null
     * @return a flipped buffer holding {@code intArray.length * 4} bytes
     */
    public static ByteBuffer intArrayToByteBuffer(Integer[] intArray) {
        ByteBuffer byteBuffer = ByteBuffer.allocate(intArray.length * Integer.BYTES);

        for (int value : intArray) {
            byteBuffer.putInt(value);
        }

        // See shortArrayToByteBuffer for why flip() is invoked through Buffer.
        ((Buffer) byteBuffer).flip();

        return byteBuffer;
    }

    /**
     * Decodes the readable bytes of the buffer as 32-bit integers.
     *
     * @param byteBuffer source buffer; it is not modified
     * @return one element per complete 4-byte group between position and limit
     */
    public static Integer[] byteBufferToIntArray(ByteBuffer byteBuffer) {
        int start = byteBuffer.position();
        Integer[] intArray = new Integer[byteBuffer.remaining() / Integer.BYTES];

        for (int i = 0; i < intArray.length; i++) {
            intArray[i] = byteBuffer.getInt(start + i * Integer.BYTES);
        }

        return intArray;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@

package org.apache.seatunnel.connectors.seatunnel.fake.utils;

import org.apache.seatunnel.common.utils.BufferUtil;
import org.apache.seatunnel.connectors.seatunnel.fake.config.FakeConfig;

import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.RandomStringUtils;
import org.apache.commons.lang3.RandomUtils;

import java.math.BigDecimal;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.time.LocalDate;
import java.time.LocalDateTime;
Expand Down Expand Up @@ -178,14 +178,14 @@ public ByteBuffer randomBinaryVector() {
return ByteBuffer.wrap(RandomUtils.nextBytes(byteCount));
}

public Float[] randomFloatVector() {
public ByteBuffer randomFloatVector() {
Float[] floatVector = new Float[fakeConfig.getVectorDimension()];
for (int i = 0; i < fakeConfig.getVectorDimension(); i++) {
floatVector[i] =
RandomUtils.nextFloat(
fakeConfig.getVectorFloatMin(), fakeConfig.getVectorFloatMax());
}
return floatVector;
return BufferUtil.floatArrayToByteBuffer(floatVector);
}

public ByteBuffer randomFloat16Vector() {
Expand All @@ -196,7 +196,7 @@ public ByteBuffer randomFloat16Vector() {
fakeConfig.getVectorFloatMin(), fakeConfig.getVectorFloatMax());
float16Vector[i] = floatToFloat16(value);
}
return shortArrayToByteBuffer(float16Vector);
return BufferUtil.shortArrayToByteBuffer(float16Vector);
}

public ByteBuffer randomBFloat16Vector() {
Expand All @@ -207,7 +207,7 @@ public ByteBuffer randomBFloat16Vector() {
fakeConfig.getVectorFloatMin(), fakeConfig.getVectorFloatMax());
bfloat16Vector[i] = floatToBFloat16(value);
}
return shortArrayToByteBuffer(bfloat16Vector);
return BufferUtil.shortArrayToByteBuffer(bfloat16Vector);
}

public Map<Integer, Float> randomSparseFloatVector() {
Expand Down Expand Up @@ -242,20 +242,6 @@ private static short floatToFloat16(float value) {
return (short) (sign | (exponent << 10) | (mantissa >> 13));
}

/**
 * Writes every element of {@code shortArray} into a newly allocated buffer (two bytes per
 * element) and returns the buffer flipped, i.e. ready to be read from position 0.
 *
 * @param shortArray values to write
 * @return the flipped buffer containing {@code shortArray.length * 2} bytes
 */
private static ByteBuffer shortArrayToByteBuffer(Short[] shortArray) {
ByteBuffer byteBuffer = ByteBuffer.allocate(shortArray.length * 2);

for (Short value : shortArray) {
byteBuffer.putShort(value);
}

// Flip the buffer to prepare for reading. The cast makes the call bind to Buffer.flip(),
// which keeps the bytecode compatible when the compile-time and run-time JDK versions
// differ (ByteBuffer only gained its own covariant flip() in JDK 9).
((Buffer) byteBuffer).flip();

return byteBuffer;
}

private static short floatToBFloat16(float value) {
int intBits = Float.floatToIntBits(value);
return (short) (intBits >> 16);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.apache.seatunnel.api.table.type.SeaTunnelRow;
import org.apache.seatunnel.api.table.type.SqlType;
import org.apache.seatunnel.api.table.type.VectorType;
import org.apache.seatunnel.common.utils.BufferUtil;
import org.apache.seatunnel.common.utils.JsonUtils;
import org.apache.seatunnel.connectors.seatunnel.milvus.catalog.MilvusOptions;
import org.apache.seatunnel.connectors.seatunnel.milvus.config.MilvusSourceConfig;
Expand Down Expand Up @@ -320,16 +321,14 @@ public static Object convertBySeaTunnelType(SeaTunnelDataType<?> fieldType, Obje
case DATE:
return value.toString();
case FLOAT_VECTOR:
List<Float> vector = new ArrayList<>();
for (Object o : (Object[]) value) {
vector.add(Float.parseFloat(o.toString()));
}
return vector;
ByteBuffer floatVectorBuffer = (ByteBuffer) value;
Float[] floats = BufferUtil.byteBufferToFloatArray(floatVectorBuffer);
return Arrays.stream(floats).collect(Collectors.toList());
case BINARY_VECTOR:
case BFLOAT16_VECTOR:
case FLOAT16_VECTOR:
ByteBuffer binaryVector = (ByteBuffer) value;
return gson.toJsonTree(binaryVector.array());
ByteBuffer vector = (ByteBuffer) value;
return gson.toJsonTree(vector.array());
case SPARSE_FLOAT_VECTOR:
return JsonParser.parseString(JacksonUtils.toJsonString(value)).getAsJsonObject();
case FLOAT:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ private JsonObject buildMilvusData(SeaTunnelRow element) {

private void writeData2Collection() {
// default to use upsertReq, but upsert only works when autoID is disabled
System.out.println("enableUpsert: " + enableUpsert + ", autoId: " + autoId);
System.out.println("milvusDataCache: " + milvusDataCache);
System.out.println("collectionName: " + collectionName);
if (enableUpsert && !autoId) {
UpsertReq upsertReq =
UpsertReq.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.apache.seatunnel.api.table.type.SeaTunnelRow;
import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
import org.apache.seatunnel.common.exception.CommonErrorCode;
import org.apache.seatunnel.common.utils.BufferUtil;
import org.apache.seatunnel.connectors.seatunnel.milvus.config.MilvusSourceConfig;
import org.apache.seatunnel.connectors.seatunnel.milvus.exception.MilvusConnectionErrorCode;
import org.apache.seatunnel.connectors.seatunnel.milvus.exception.MilvusConnectorException;
Expand Down Expand Up @@ -220,7 +221,7 @@ public SeaTunnelRow convertToSeaTunnelRow(
for (int i = 0; i < list.size(); i++) {
arrays[i] = Float.parseFloat(list.get(i).toString());
}
fields[fieldIndex] = arrays;
fields[fieldIndex] = BufferUtil.floatArrayToByteBuffer(arrays);
break;
} else {
throw new MilvusConnectorException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.seatunnel.e2e.connector.v2.milvus;

import org.apache.seatunnel.common.utils.BufferUtil;
import org.apache.seatunnel.e2e.common.TestResource;
import org.apache.seatunnel.e2e.common.TestSuiteBase;
import org.apache.seatunnel.e2e.common.container.EngineType;
Expand Down Expand Up @@ -54,7 +55,6 @@
import lombok.extern.slf4j.Slf4j;

import java.io.IOException;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.sql.SQLException;
import java.util.ArrayList;
Expand Down Expand Up @@ -227,7 +227,7 @@ private void initSourceData() {
List<Float> vector = Arrays.asList((float) i, (float) i, (float) i, (float) i);
row.add(VECTOR_FIELD, gson.toJsonTree(vector));
Short[] shorts = {(short) i, (short) i, (short) i, (short) i};
ByteBuffer shortByteBuffer = shortArrayToByteBuffer(shorts);
ByteBuffer shortByteBuffer = BufferUtil.shortArrayToByteBuffer(shorts);
row.add(VECTOR_FIELD2, gson.toJsonTree(shortByteBuffer.array()));
ByteBuffer binaryByteBuffer = ByteBuffer.wrap(new byte[] {16});
row.add(VECTOR_FIELD3, gson.toJsonTree(binaryByteBuffer.array()));
Expand Down Expand Up @@ -362,18 +362,4 @@ public void testMultiFakeToMilvus(TestContainer container)
Assertions.assertTrue(fileds.contains("book_intro_3"));
Assertions.assertTrue(fileds.contains("book_intro_4"));
}

/**
 * Writes every element of {@code shortArray} into a newly allocated buffer (two bytes per
 * element) and returns the buffer flipped, i.e. ready to be read from position 0.
 *
 * @param shortArray values to write
 * @return the flipped buffer containing {@code shortArray.length * 2} bytes
 */
private static ByteBuffer shortArrayToByteBuffer(Short[] shortArray) {
ByteBuffer byteBuffer = ByteBuffer.allocate(shortArray.length * 2); // 2 bytes per short

for (Short value : shortArray) {
byteBuffer.putShort(value);
}

// Flip the buffer to prepare for reading. The cast makes the call bind to Buffer.flip(),
// which keeps the bytecode compatible when the compile-time and run-time JDK versions
// differ (ByteBuffer only gained its own covariant flip() in JDK 9).
((Buffer) byteBuffer).flip();

return byteBuffer;
}
}

0 comments on commit 835e43b

Please sign in to comment.