diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaTranslation.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaTranslation.java index e1112b2472de..9a63c2d87819 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaTranslation.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaTranslation.java @@ -43,6 +43,7 @@ import org.apache.beam.sdk.schemas.Schema.LogicalType; import org.apache.beam.sdk.schemas.Schema.TypeName; import org.apache.beam.sdk.schemas.logicaltypes.MicrosInstant; +import org.apache.beam.sdk.schemas.logicaltypes.PythonCallable; import org.apache.beam.sdk.schemas.logicaltypes.SchemaLogicalType; import org.apache.beam.sdk.schemas.logicaltypes.UnknownLogicalType; import org.apache.beam.sdk.util.SerializableUtils; @@ -74,6 +75,7 @@ public class SchemaTranslation { ImmutableMap.>>builder() .put(MicrosInstant.IDENTIFIER, MicrosInstant.class) .put(SchemaLogicalType.IDENTIFIER, SchemaLogicalType.class) + .put(PythonCallable.IDENTIFIER, PythonCallable.class) .build(); public static SchemaApi.Schema schemaToProto(Schema schema, boolean serializeLogicalType) { diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/logicaltypes/PythonCallable.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/logicaltypes/PythonCallable.java new file mode 100644 index 000000000000..6bd43cb8ba89 --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/logicaltypes/PythonCallable.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.schemas.logicaltypes; + +import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.Schema.LogicalType; +import org.apache.beam.sdk.util.PythonCallableSource; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** A logical type for PythonCallableSource objects. */ +@Experimental(Experimental.Kind.SCHEMAS) +public class PythonCallable implements LogicalType { + public static final String IDENTIFIER = "beam:logical_type:python_callable:v1"; + + @Override + public String getIdentifier() { + return IDENTIFIER; + } + + @Override + public Schema.@Nullable FieldType getArgumentType() { + return null; + } + + @Override + public Schema.FieldType getBaseType() { + return Schema.FieldType.STRING; + } + + @Override + public @NonNull String toBaseType(@NonNull PythonCallableSource input) { + return input.getPythonCallableCode(); + } + + @Override + public @NonNull PythonCallableSource toInputType(@NonNull String base) { + return PythonCallableSource.of(base); + } +} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/StaticSchemaInference.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/StaticSchemaInference.java index 103405037bed..a1437c2d0ccd 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/StaticSchemaInference.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/StaticSchemaInference.java 
/**
 * A wrapper object storing Python source code that can be evaluated to a Python callable in the
 * Python SDK.
 *
 * <p>Instances are created through {@link #of(String)}; the wrapped code is never {@code null}.
 */
public class PythonCallableSource implements Serializable {
  private final String pythonCallableCode;

  private PythonCallableSource(String pythonCallableCode) {
    this.pythonCallableCode = pythonCallableCode;
  }

  /**
   * Creates a wrapper around the given Python code.
   *
   * @param pythonCallableCode the Python source code; must not be {@code null}
   * @return a new wrapper holding {@code pythonCallableCode}
   * @throws NullPointerException if {@code pythonCallableCode} is {@code null}
   */
  public static PythonCallableSource of(String pythonCallableCode) {
    // Fail here rather than later, when the code is read back as a non-null schema value.
    if (pythonCallableCode == null) {
      throw new NullPointerException("pythonCallableCode must not be null");
    }
    // TODO(BEAM-14457): check syntactic correctness of Python code if possible
    return new PythonCallableSource(pythonCallableCode);
  }

  /** Returns the wrapped Python source code, exactly as passed to {@link #of(String)}. */
  public String getPythonCallableCode() {
    return pythonCallableCode;
  }
}
b/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/PythonExternalTransform.java index c412acd220ee..30f72429e5e9 100644 --- a/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/PythonExternalTransform.java +++ b/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/PythonExternalTransform.java @@ -18,6 +18,7 @@ package org.apache.beam.sdk.extensions.python; import java.util.Arrays; +import java.util.HashMap; import java.util.Map; import java.util.Set; import java.util.SortedMap; @@ -64,6 +65,7 @@ public class PythonExternalTransform kwargsMap; + private Map, Schema.FieldType> typeHints; private @Nullable Object @NonNull [] argsArray; private @Nullable Row providedKwargsRow; @@ -72,6 +74,7 @@ private PythonExternalTransform(String fullyQualifiedName, String expansionServi this.fullyQualifiedName = fullyQualifiedName; this.expansionService = expansionService; this.kwargsMap = new TreeMap<>(); + this.typeHints = new HashMap<>(); argsArray = new Object[] {}; } @@ -162,6 +165,26 @@ public PythonExternalTransform withKwargs(Row kwargs) { return this; } + /** + * Specifies the field type of arguments. + * + *

Type hints are especially useful for logical types since type inference does not work well + * for logical types. + * + * @param argType A class object for the argument type. + * @param fieldType A schema field type for the argument. + * @return updated wrapper for the cross-language transform. + */ + public PythonExternalTransform withTypeHint( + java.lang.Class argType, Schema.FieldType fieldType) { + if (typeHints.containsKey(argType)) { + throw new IllegalArgumentException( + String.format("typehint for arg type %s already exists", argType)); + } + typeHints.put(argType, fieldType); + return this; + } + @VisibleForTesting Row buildOrGetKwargsRow() { if (providedKwargsRow != null) { @@ -180,15 +203,17 @@ Row buildOrGetKwargsRow() { // * Java primitives // * Type String // * Type Row - private static boolean isCustomType(java.lang.Class type) { + // * Any Type explicitly annotated by withTypeHint() + private boolean isCustomType(java.lang.Class type) { boolean val = !(ClassUtils.isPrimitiveOrWrapper(type) || type == String.class + || typeHints.containsKey(type) || Row.class.isAssignableFrom(type)); return val; } - // If the custom type has a registered schema, we use that. OTherwise we try to register it using + // If the custom type has a registered schema, we use that. Otherwise, we try to register it using // 'JavaFieldSchema'. private Row convertCustomValue(Object value) { SerializableFunction toRowFunc; @@ -239,6 +264,8 @@ private Schema generateSchemaDirectly( if (field instanceof Row) { // Rows are used as is but other types are converted to proper field types. 
builder.addRowField(fieldName, ((Row) field).getSchema()); + } else if (typeHints.containsKey(field.getClass())) { + builder.addField(fieldName, typeHints.get(field.getClass())); } else { builder.addField( fieldName, diff --git a/sdks/java/extensions/python/src/test/java/org/apache/beam/sdk/extensions/python/ExternalPythonTransformTest.java b/sdks/java/extensions/python/src/test/java/org/apache/beam/sdk/extensions/python/PythonExternalTransformTest.java similarity index 89% rename from sdks/java/extensions/python/src/test/java/org/apache/beam/sdk/extensions/python/ExternalPythonTransformTest.java rename to sdks/java/extensions/python/src/test/java/org/apache/beam/sdk/extensions/python/PythonExternalTransformTest.java index 60deebfc6e66..cfe7428ba2e5 100644 --- a/sdks/java/extensions/python/src/test/java/org/apache/beam/sdk/extensions/python/ExternalPythonTransformTest.java +++ b/sdks/java/extensions/python/src/test/java/org/apache/beam/sdk/extensions/python/PythonExternalTransformTest.java @@ -27,9 +27,11 @@ import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.SchemaTranslation; +import org.apache.beam.sdk.schemas.logicaltypes.PythonCallable; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.util.PythonCallableSource; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.Row; @@ -41,7 +43,7 @@ import org.junit.runners.JUnit4; @RunWith(JUnit4.class) -public class ExternalPythonTransformTest implements Serializable { +public class PythonExternalTransformTest implements Serializable { @Ignore("BEAM-14148") @Test public void trivialPythonTransform() { @@ -184,6 +186,19 @@ public void generateArgsWithCustomType() { assertEquals(456, (int) receivedRow.getRow("field1").getInt32("intField")); } + @Test + public void 
generateArgsWithTypeHint() { + PythonExternalTransform transform = + PythonExternalTransform + .>, PCollection>>>from( + "DummyTransform") + .withArgs(PythonCallableSource.of("dummy data")) + .withTypeHint( + PythonCallableSource.class, Schema.FieldType.logicalType(new PythonCallable())); + Row receivedRow = transform.buildOrGetArgsRow(); + assertTrue(receivedRow.getValue("field0") instanceof PythonCallableSource); + } + @Test public void generateKwargsEmpty() { PythonExternalTransform transform = @@ -274,6 +289,19 @@ public void generateKwargsWithCustomType() { assertEquals(456, (int) receivedRow.getRow("customField1").getInt32("intField")); } + @Test + public void generateKwargsWithTypeHint() { + PythonExternalTransform transform = + PythonExternalTransform + .>, PCollection>>>from( + "DummyTransform") + .withKwarg("customField0", PythonCallableSource.of("dummy data")) + .withTypeHint( + PythonCallableSource.class, Schema.FieldType.logicalType(new PythonCallable())); + Row receivedRow = transform.buildOrGetKwargsRow(); + assertTrue(receivedRow.getValue("customField0") instanceof PythonCallableSource); + } + @Test public void generateKwargsFromMap() { Map kwargsMap = diff --git a/sdks/python/apache_beam/typehints/schemas.py b/sdks/python/apache_beam/typehints/schemas.py index 5a04ba51722b..74feb1466343 100644 --- a/sdks/python/apache_beam/typehints/schemas.py +++ b/sdks/python/apache_beam/typehints/schemas.py @@ -78,6 +78,7 @@ from apache_beam.typehints.native_type_compatibility import extract_optional_type from apache_beam.typehints.native_type_compatibility import match_is_named_tuple from apache_beam.utils import proto_utils +from apache_beam.utils.python_callable import PythonCallableWithSource from apache_beam.utils.timestamp import Timestamp PYTHON_ANY_URN = "beam:logical:pythonsdk_any:v1" @@ -559,3 +560,27 @@ def to_representation_type(self, value): def to_language_type(self, value): # type: (MicrosInstantRepresentation) -> Timestamp return 
@LogicalType.register_logical_type
class PythonCallable(NoArgumentLogicalType[PythonCallableWithSource, str]):
  """Logical type mapping PythonCallableWithSource to its source string."""
  @classmethod
  def urn(cls):
    # type: () -> str

    """Returns the URN identifying this logical type."""
    return "beam:logical_type:python_callable:v1"

  @classmethod
  def language_type(cls):
    # type: () -> type

    """Returns the Python-side type, PythonCallableWithSource."""
    return PythonCallableWithSource

  @classmethod
  def representation_type(cls):
    # type: () -> type

    """Returns the representation type, str."""
    return str

  def to_language_type(self, value):
    # type: (str) -> PythonCallableWithSource

    """Wraps a source string in a PythonCallableWithSource."""
    return PythonCallableWithSource(value)

  def to_representation_type(self, value):
    # type: (PythonCallableWithSource) -> str

    """Returns the source string held by the given callable wrapper."""
    return value.get_source()
class PythonCallableWithSource(object):
  """Represents a Python callable object together with its source code.

  Proxy object that stores a callable alongside its string form (source code).
  The string form is used when the object is encoded and transferred to foreign
  SDKs (non-Python SDKs).

  SECURITY: the source is passed to eval(), which executes arbitrary Python
  code. Nothing here validates or sandboxes the code, so only construct
  instances from sources you trust.
  """
  def __init__(self, source):
    # type: (str) -> None

    """Evaluates ``source`` and keeps both the text and the resulting object.

    Args:
      source: a Python expression, as text, that evaluates to a callable.

    Raises:
      TypeError: if ``source`` is not a str.
    """
    # eval() also accepts bytes and code objects, but get_source() is declared
    # to return str, so reject anything else up front.
    if not isinstance(source, str):
      raise TypeError(
          'source must be a str, got %s' % type(source).__name__)
    self._source = source
    # See the SECURITY note in the class docstring.
    self._callable = eval(source)  # pylint: disable=eval-used

  def get_source(self):
    # type: () -> str

    """Returns the source text exactly as passed to the constructor."""
    return self._source

  def __call__(self, *args, **kwargs):
    """Calls the evaluated object with the given arguments."""
    return self._callable(*args, **kwargs)