diff --git a/ingestion/src/main/java/feast/store/serving/redis/FeatureRowToRedisMutationDoFn.java b/ingestion/src/main/java/feast/store/serving/redis/FeatureRowToRedisMutationDoFn.java
index 9bc503f987..c453c5c920 100644
--- a/ingestion/src/main/java/feast/store/serving/redis/FeatureRowToRedisMutationDoFn.java
+++ b/ingestion/src/main/java/feast/store/serving/redis/FeatureRowToRedisMutationDoFn.java
@@ -24,6 +24,9 @@
 import feast.store.serving.redis.RedisCustomIO.RedisMutation;
 import feast.types.FeatureRowProto.FeatureRow;
 import feast.types.FieldProto.Field;
+import feast.types.ValueProto.Value;
+import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.stream.Collectors;
@@ -42,17 +45,43 @@ public FeatureRowToRedisMutationDoFn(Map<String, FeatureSetSpec> featureSetSpecs
 
+  /**
+   * Builds the Redis key for a feature row from its feature set name and entity fields.
+   *
+   * <p>Entities are added in sorted entity-name order, whatever order the row lists them in.
+   * If the row repeats an entity field, only its first occurrence is used.
+   *
+   * @param featureRow row supplying the feature set name and the entity fields
+   * @return key holding the feature set name and one field per entity of the feature set
+   * @throws IllegalArgumentException if the row has no field for one of the entities
+   */
   private RedisKey getKey(FeatureRow featureRow) {
     FeatureSetSpec featureSetSpec = featureSetSpecs.get(featureRow.getFeatureSet());
-    Set<String> entityNames =
+    List<String> entityNames =
         featureSetSpec.getEntitiesList().stream()
             .map(EntitySpec::getName)
-            .collect(Collectors.toSet());
+            .sorted()
+            .collect(Collectors.toList());
 
+    Map<String, Field> entityFields = new HashMap<>();
     Builder redisKeyBuilder = RedisKey.newBuilder().setFeatureSet(featureRow.getFeatureSet());
     for (Field field : featureRow.getFieldsList()) {
       if (entityNames.contains(field.getName())) {
-        redisKeyBuilder.addEntities(field);
+        entityFields.putIfAbsent(
+            field.getName(),
+            Field.newBuilder().setName(field.getName()).setValue(field.getValue()).build());
       }
     }
+    for (String entityName : entityNames) {
+      Field entityField = entityFields.get(entityName);
+      if (entityField == null) {
+        // The generated protobuf builder rejects null with a message-less
+        // NullPointerException, so fail here with the missing entity named instead.
+        throw new IllegalArgumentException(
+            String.format(
+                "Feature row for feature set '%s' has no value for entity '%s'",
+                featureRow.getFeatureSet(), entityName));
+      }
+      redisKeyBuilder.addEntities(entityField);
+    }
     return redisKeyBuilder.build();
   }
 
diff --git a/ingestion/src/test/java/feast/store/serving/redis/FeatureRowToRedisMutationDoFnTest.java b/ingestion/src/test/java/feast/store/serving/redis/FeatureRowToRedisMutationDoFnTest.java
new file mode 100644
index 0000000000..6e0db2dd49
--- /dev/null
+++ b/ingestion/src/test/java/feast/store/serving/redis/FeatureRowToRedisMutationDoFnTest.java
@@ -0,0 +1,139 @@
+package feast.store.serving.redis;
+
+import static org.junit.Assert.*;
+
+import com.google.protobuf.Timestamp;
+import feast.core.FeatureSetProto.EntitySpec;
+import feast.core.FeatureSetProto.FeatureSetSpec;
+import feast.core.FeatureSetProto.FeatureSpec;
+import feast.ingestion.transform.ValidateFeatureRows;
+import feast.storage.RedisProto.RedisKey;
+import feast.store.serving.redis.RedisCustomIO.Method;
+import feast.store.serving.redis.RedisCustomIO.RedisMutation;
+import feast.test.TestUtil;
+import feast.types.FeatureRowProto.FeatureRow;
+import feast.types.FieldProto.Field;
+import feast.types.ValueProto.Value;
+import feast.types.ValueProto.ValueType.Enum;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.beam.sdk.extensions.protobuf.ProtoCoder;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.sdk.values.PCollection;
+import org.junit.Rule;
+import org.junit.Test;
+
+/**
+ * Tests the Redis key that {@link FeatureRowToRedisMutationDoFn} produces for rows whose entity
+ * fields are repeated or listed out of name order.
+ */
+public class FeatureRowToRedisMutationDoFnTest {
+
+  @Rule
+  public transient TestPipeline p = TestPipeline.create();
+
+  // Feature set with two entities and two features, used by every test in this class.
+  private FeatureSetSpec fs = FeatureSetSpec.newBuilder()
+      .setName("feature_set")
+      .setVersion(1)
+      .addEntities(
+          EntitySpec.newBuilder()
+              .setName("entity_id_primary")
+              .setValueType(Enum.INT32)
+              .build())
+      .addEntities(
+          EntitySpec.newBuilder()
+              .setName("entity_id_secondary")
+              .setValueType(Enum.STRING)
+              .build())
+      .addFeatures(
+          FeatureSpec.newBuilder().setName("feature_1").setValueType(Enum.STRING).build())
+      .addFeatures(
+          FeatureSpec.newBuilder().setName("feature_2").setValueType(Enum.INT64).build())
+      .build();
+
+  /** A repeated entity field contributes only its first occurrence to the key. */
+  @Test
+  public void shouldConvertRowWithDuplicateEntitiesToValidKey() {
+    Map<String, FeatureSetSpec> featureSetSpecs = new HashMap<>();
+    featureSetSpecs.put("feature_set", fs);
+
+    FeatureRow offendingRow = FeatureRow.newBuilder()
+        .setFeatureSet("feature_set")
+        .setEventTimestamp(Timestamp.newBuilder().setSeconds(10))
+        .addFields(Field.newBuilder().setName("entity_id_primary")
+            .setValue(Value.newBuilder().setInt32Val(1)))
+        .addFields(Field.newBuilder().setName("entity_id_primary")
+            .setValue(Value.newBuilder().setInt32Val(2)))
+        .addFields(Field.newBuilder().setName("entity_id_secondary")
+            .setValue(Value.newBuilder().setStringVal("a")))
+        .build();
+
+    PCollection<RedisMutation> output = p
+        .apply(Create.of(Collections.singletonList(offendingRow)))
+        .setCoder(ProtoCoder.of(FeatureRow.class))
+        .apply(ParDo.of(new FeatureRowToRedisMutationDoFn(featureSetSpecs)));
+
+    RedisKey expectedKey = RedisKey.newBuilder()
+        .setFeatureSet("feature_set")
+        .addEntities(Field.newBuilder().setName("entity_id_primary")
+            .setValue(Value.newBuilder().setInt32Val(1)))
+        .addEntities(Field.newBuilder().setName("entity_id_secondary")
+            .setValue(Value.newBuilder().setStringVal("a")))
+        .build();
+
+    PAssert.that(output).satisfies((SerializableFunction<Iterable<RedisMutation>, Void>) input -> {
+      input.forEach(rm -> {
+        // JUnit assertions, unlike the assert keyword, run even when the JVM has -ea off.
+        assertArrayEquals(expectedKey.toByteArray(), rm.getKey());
+        assertArrayEquals(offendingRow.toByteArray(), rm.getValue());
+      });
+      return null;
+    });
+    p.run();
+  }
+
+  /** Entity fields listed out of name order still yield a key with entities sorted by name. */
+  @Test
+  public void shouldConvertRowWithOutOfOrderEntitiesToValidKey() {
+    Map<String, FeatureSetSpec> featureSetSpecs = new HashMap<>();
+    featureSetSpecs.put("feature_set", fs);
+
+    FeatureRow offendingRow = FeatureRow.newBuilder()
+        .setFeatureSet("feature_set")
+        .setEventTimestamp(Timestamp.newBuilder().setSeconds(10))
+        .addFields(Field.newBuilder().setName("entity_id_secondary")
+            .setValue(Value.newBuilder().setStringVal("a")))
+        .addFields(Field.newBuilder().setName("entity_id_primary")
+            .setValue(Value.newBuilder().setInt32Val(1)))
+        .build();
+
+    PCollection<RedisMutation> output = p
+        .apply(Create.of(Collections.singletonList(offendingRow)))
+        .setCoder(ProtoCoder.of(FeatureRow.class))
+        .apply(ParDo.of(new FeatureRowToRedisMutationDoFn(featureSetSpecs)));
+
+    RedisKey expectedKey = RedisKey.newBuilder()
+        .setFeatureSet("feature_set")
+        .addEntities(Field.newBuilder().setName("entity_id_primary")
+            .setValue(Value.newBuilder().setInt32Val(1)))
+        .addEntities(Field.newBuilder().setName("entity_id_secondary")
+            .setValue(Value.newBuilder().setStringVal("a")))
+        .build();
+
+    PAssert.that(output).satisfies((SerializableFunction<Iterable<RedisMutation>, Void>) input -> {
+      input.forEach(rm -> {
+        assertArrayEquals(expectedKey.toByteArray(), rm.getKey());
+        assertArrayEquals(offendingRow.toByteArray(), rm.getValue());
+      });
+      return null;
+    });
+    p.run();
+  }
+}