Skip to content

Commit

Permalink
[BEAM-14430] Adding a logical type support for Python callables to Ro…
Browse files Browse the repository at this point in the history
…w schema
  • Loading branch information
ihji committed May 10, 2022
1 parent 70b7567 commit 5a5e51e
Show file tree
Hide file tree
Showing 9 changed files with 225 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import org.apache.beam.sdk.schemas.Schema.LogicalType;
import org.apache.beam.sdk.schemas.Schema.TypeName;
import org.apache.beam.sdk.schemas.logicaltypes.MicrosInstant;
import org.apache.beam.sdk.schemas.logicaltypes.PythonCallable;
import org.apache.beam.sdk.schemas.logicaltypes.SchemaLogicalType;
import org.apache.beam.sdk.schemas.logicaltypes.UnknownLogicalType;
import org.apache.beam.sdk.util.SerializableUtils;
Expand Down Expand Up @@ -74,6 +75,7 @@ public class SchemaTranslation {
ImmutableMap.<String, Class<? extends LogicalType<?, ?>>>builder()
.put(MicrosInstant.IDENTIFIER, MicrosInstant.class)
.put(SchemaLogicalType.IDENTIFIER, SchemaLogicalType.class)
.put(PythonCallable.IDENTIFIER, PythonCallable.class)
.build();

public static SchemaApi.Schema schemaToProto(Schema schema, boolean serializeLogicalType) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.schemas.logicaltypes;

import org.apache.beam.sdk.annotations.Experimental;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.Schema.LogicalType;
import org.apache.beam.sdk.util.PythonCallableSource;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.checkerframework.checker.nullness.qual.Nullable;

/**
 * A {@link LogicalType} that exposes {@link PythonCallableSource} values as schema fields.
 *
 * <p>The value is represented by its Python source text, so the base type is a plain string and
 * the logical type takes no argument.
 */
@Experimental(Experimental.Kind.SCHEMAS)
public class PythonCallable implements LogicalType<PythonCallableSource, String> {
  /** URN under which this logical type is identified. */
  public static final String IDENTIFIER = "beam:logical_type:python_callable:v1";

  /** Returns {@link #IDENTIFIER}. */
  @Override
  public String getIdentifier() {
    return IDENTIFIER;
  }

  /** Always returns {@code null}: this logical type is not parameterized. */
  @Override
  public Schema.@Nullable FieldType getArgumentType() {
    return null;
  }

  /** Returns the string field type used as the underlying representation. */
  @Override
  public Schema.FieldType getBaseType() {
    return Schema.FieldType.STRING;
  }

  /** Unwraps the Python source text held by {@code input}. */
  @Override
  public @NonNull String toBaseType(@NonNull PythonCallableSource input) {
    final String sourceText = input.getPythonCallableCode();
    return sourceText;
  }

  /** Wraps the source text {@code base} in a {@link PythonCallableSource}. */
  @Override
  public @NonNull PythonCallableSource toInputType(@NonNull String base) {
    final PythonCallableSource wrapped = PythonCallableSource.of(base);
    return wrapped;
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ public static Schema.FieldType fieldFromType(
return fieldFromType(type, fieldValueTypeSupplier, new HashMap<Class, Schema>());
}

// TODO(BEAM-14458): support type inference for logical types
private static Schema.FieldType fieldFromType(
TypeDescriptor type,
FieldValueTypeSupplier fieldValueTypeSupplier,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.util;

import java.io.Serializable;

/**
* A wrapper object storing a Python code that can be evaluated to Python callables in Python SDK.
*/
public class PythonCallableSource implements Serializable {
private final String pythonCallableCode;

private PythonCallableSource(String pythonCallableCode) {
this.pythonCallableCode = pythonCallableCode;
}

public static PythonCallableSource of(String pythonCallableCode) {
// TODO(BEAM-14457): check syntactic correctness of Python code if possible
return new PythonCallableSource(pythonCallableCode);
}

public String getPythonCallableCode() {
return pythonCallableCode;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import org.apache.beam.sdk.schemas.logicaltypes.DateTime;
import org.apache.beam.sdk.schemas.logicaltypes.FixedBytes;
import org.apache.beam.sdk.schemas.logicaltypes.MicrosInstant;
import org.apache.beam.sdk.schemas.logicaltypes.PythonCallable;
import org.apache.beam.sdk.schemas.logicaltypes.SchemaLogicalType;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString;
Expand Down Expand Up @@ -132,6 +133,7 @@ public static Iterable<Schema> data() {
Field.of("decimal", FieldType.DECIMAL), Field.of("datetime", FieldType.DATETIME)))
.add(Schema.of(Field.of("fixed_bytes", FieldType.logicalType(FixedBytes.of(24)))))
.add(Schema.of(Field.of("micros_instant", FieldType.logicalType(new MicrosInstant()))))
.add(Schema.of(Field.of("python_callable", FieldType.logicalType(new PythonCallable()))))
.add(
Schema.of(
Field.of("field_with_option_atomic", FieldType.STRING)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.beam.sdk.extensions.python;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.SortedMap;
Expand Down Expand Up @@ -64,6 +65,7 @@ public class PythonExternalTransform<InputT extends PInput, OutputT extends POut
// We preserve the order here since Schemas care about the order of fields, but the order will
// not matter when applying kwargs on the Python side.
private SortedMap<String, Object> kwargsMap;
private Map<java.lang.Class<?>, Schema.FieldType> typeHints;

private @Nullable Object @NonNull [] argsArray;
private @Nullable Row providedKwargsRow;
Expand All @@ -72,6 +74,7 @@ private PythonExternalTransform(String fullyQualifiedName, String expansionServi
this.fullyQualifiedName = fullyQualifiedName;
this.expansionService = expansionService;
this.kwargsMap = new TreeMap<>();
this.typeHints = new HashMap<>();
argsArray = new Object[] {};
}

Expand Down Expand Up @@ -162,6 +165,26 @@ public PythonExternalTransform<InputT, OutputT> withKwargs(Row kwargs) {
return this;
}

/**
 * Registers the schema field type to use for arguments of the given Java class.
 *
 * <p>Type hints are especially useful for logical types since type inference does not work well
 * for logical types.
 *
 * @param argType A class object for the argument type.
 * @param fieldType A schema field type for the argument.
 * @return updated wrapper for the cross-language transform.
 * @throws IllegalArgumentException if a type hint is already registered for {@code argType}.
 */
public PythonExternalTransform<InputT, OutputT> withTypeHint(
    java.lang.Class<?> argType, Schema.FieldType fieldType) {
  final boolean alreadyHinted = typeHints.containsKey(argType);
  if (!alreadyHinted) {
    typeHints.put(argType, fieldType);
    return this;
  }
  // A class may only be hinted once; re-registration is rejected rather than overwritten.
  throw new IllegalArgumentException(
      String.format("typehint for arg type %s already exists", argType));
}

@VisibleForTesting
Row buildOrGetKwargsRow() {
if (providedKwargsRow != null) {
Expand All @@ -180,15 +203,17 @@ Row buildOrGetKwargsRow() {
// * Java primitives
// * Type String
// * Type Row
private static boolean isCustomType(java.lang.Class<?> type) {
// * Any Type explicitly annotated by withTypeHint()
private boolean isCustomType(java.lang.Class<?> type) {
  // Primitives, their wrappers, and String are never custom.
  if (ClassUtils.isPrimitiveOrWrapper(type) || type == String.class) {
    return false;
  }
  // Neither is any class that has an explicit type hint registered.
  if (typeHints.containsKey(type)) {
    return false;
  }
  // Everything else is custom unless it is a Row (or a subclass of Row).
  return !Row.class.isAssignableFrom(type);
}

// If the custom type has a registered schema, we use that. OTherwise we try to register it using
// If the custom type has a registered schema, we use that. Otherwise, we try to register it using
// 'JavaFieldSchema'.
private Row convertCustomValue(Object value) {
SerializableFunction<Object, Row> toRowFunc;
Expand Down Expand Up @@ -239,6 +264,8 @@ private Schema generateSchemaDirectly(
if (field instanceof Row) {
// Rows are used as is but other types are converted to proper field types.
builder.addRowField(fieldName, ((Row) field).getSchema());
} else if (typeHints.containsKey(field.getClass())) {
builder.addField(fieldName, typeHints.get(field.getClass()));
} else {
builder.addField(
fieldName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.SchemaTranslation;
import org.apache.beam.sdk.schemas.logicaltypes.PythonCallable;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.util.PythonCallableSource;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;
Expand All @@ -41,7 +43,7 @@
import org.junit.runners.JUnit4;

@RunWith(JUnit4.class)
public class ExternalPythonTransformTest implements Serializable {
public class PythonExternalTransformTest implements Serializable {
@Ignore("BEAM-14148")
@Test
public void trivialPythonTransform() {
Expand Down Expand Up @@ -184,6 +186,19 @@ public void generateArgsWithCustomType() {
assertEquals(456, (int) receivedRow.getRow("field1").getInt32("intField"));
}

/**
 * Verifies that a positional argument whose class has a type hint is placed in the args row as a
 * {@link PythonCallableSource} that still carries the original source text.
 */
@Test
public void generateArgsWithTypeHint() {
  PythonExternalTransform<?, ?> transform =
      PythonExternalTransform
          .<PCollection<KV<String, String>>, PCollection<KV<String, Iterable<String>>>>from(
              "DummyTransform")
          .withArgs(PythonCallableSource.of("dummy data"))
          .withTypeHint(
              PythonCallableSource.class, Schema.FieldType.logicalType(new PythonCallable()));
  Row receivedRow = transform.buildOrGetArgsRow();
  Object received = receivedRow.getValue("field0");
  assertTrue(received instanceof PythonCallableSource);
  // Checking the type alone would pass even if the wrapped source were lost or replaced.
  assertEquals("dummy data", ((PythonCallableSource) received).getPythonCallableCode());
}

@Test
public void generateKwargsEmpty() {
PythonExternalTransform<?, ?> transform =
Expand Down Expand Up @@ -274,6 +289,19 @@ public void generateKwargsWithCustomType() {
assertEquals(456, (int) receivedRow.getRow("customField1").getInt32("intField"));
}

/**
 * Verifies that a keyword argument whose class has a type hint is placed in the kwargs row as a
 * {@link PythonCallableSource} that still carries the original source text.
 */
@Test
public void generateKwargsWithTypeHint() {
  PythonExternalTransform<?, ?> transform =
      PythonExternalTransform
          .<PCollection<KV<String, String>>, PCollection<KV<String, Iterable<String>>>>from(
              "DummyTransform")
          .withKwarg("customField0", PythonCallableSource.of("dummy data"))
          .withTypeHint(
              PythonCallableSource.class, Schema.FieldType.logicalType(new PythonCallable()));
  Row receivedRow = transform.buildOrGetKwargsRow();
  Object received = receivedRow.getValue("customField0");
  assertTrue(received instanceof PythonCallableSource);
  // Checking the type alone would pass even if the wrapped source were lost or replaced.
  assertEquals("dummy data", ((PythonCallableSource) received).getPythonCallableCode());
}

@Test
public void generateKwargsFromMap() {
Map<String, Object> kwargsMap =
Expand Down
25 changes: 25 additions & 0 deletions sdks/python/apache_beam/typehints/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
from apache_beam.typehints.native_type_compatibility import extract_optional_type
from apache_beam.typehints.native_type_compatibility import match_is_named_tuple
from apache_beam.utils import proto_utils
from apache_beam.utils.python_callable import PythonCallableWithSource
from apache_beam.utils.timestamp import Timestamp

PYTHON_ANY_URN = "beam:logical:pythonsdk_any:v1"
Expand Down Expand Up @@ -559,3 +560,27 @@ def to_representation_type(self, value):
def to_language_type(self, value):
# type: (MicrosInstantRepresentation) -> Timestamp
return Timestamp(seconds=int(value.seconds), micros=int(value.micros))


@LogicalType.register_logical_type
class PythonCallable(NoArgumentLogicalType[PythonCallableWithSource, str]):
  """Logical type that represents a PythonCallableWithSource as its source string."""
  @classmethod
  def urn(cls):
    """Returns the URN identifying this logical type."""
    return "beam:logical_type:python_callable:v1"

  @classmethod
  def representation_type(cls):
    # type: () -> type

    """Returns the type used for the encoded representation (``str``)."""
    return str

  @classmethod
  def language_type(cls):
    """Returns the Python-side type of values (PythonCallableWithSource)."""
    return PythonCallableWithSource

  def to_representation_type(self, value):
    # type: (PythonCallableWithSource) -> str

    """Extracts the source string from the given callable wrapper."""
    source_text = value.get_source()
    return source_text

  def to_language_type(self, value):
    # type: (str) -> PythonCallableWithSource

    """Builds a callable wrapper from the given source string."""
    return PythonCallableWithSource(source=value)
41 changes: 41 additions & 0 deletions sdks/python/apache_beam/utils/python_callable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Python Callable utilities.
For internal use only; no backwards-compatibility guarantees.
"""


class PythonCallableWithSource(object):
  """Pairs a Python callable with the source expression it was built from.

  Proxy object that stores a callable together with its string form (source
  code). The string form is used when the object is encoded and transferred to
  foreign SDKs (non-Python SDKs); calling the instance invokes the evaluated
  callable.
  """
  def __init__(self, source):
    # type: (str) -> None

    """Evaluates ``source`` and keeps both the text and the resulting callable.

    Args:
      source: a Python expression that evaluates to a callable object.

    Raises:
      TypeError: if ``source`` evaluates to an object that is not callable.
    """
    # NOTE(security): eval() executes arbitrary code, so ``source`` must come
    # from a trusted party. Kept as-is; flagged for review rather than replaced.
    evaluated = eval(source)  # pylint: disable=eval-used
    # Fail at construction instead of at first call, where the cause is
    # harder to trace back to the offending source string.
    if not callable(evaluated):
      raise TypeError(
          'source must evaluate to a callable, got %r' % type(evaluated))
    self._source = source
    self._callable = evaluated

  def get_source(self):
    # type: () -> str

    """Returns the source string this object was constructed from."""
    return self._source

  def __call__(self, *args, **kwargs):
    """Invokes the evaluated callable with the given arguments."""
    return self._callable(*args, **kwargs)

0 comments on commit 5a5e51e

Please sign in to comment.