Import_spark #326

Merged
merged 13 commits on Jul 15, 2024
3 changes: 1 addition & 2 deletions CHANGELOG.md
@@ -10,8 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added
- Add support for dbt manifest file (#104)
- Fix import of pyspark for type-checking when pyspark isn't required as a module (#312)
- `datacontract import --format spark`: Import from Spark tables (#326)

## [0.10.9] - 2024-07-03

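The new changelog entry corresponds to both a CLI command and a programmatic API. A minimal sketch of the programmatic equivalent, assuming the source table has been registered as a temp view in the active SparkSession (the sample data is illustrative; the same flow is exercised in the tests further down):

```python
from pyspark.sql import SparkSession
from datacontract.data_contract import DataContract

spark = SparkSession.builder.getOrCreate()
spark.createDataFrame(
    [{"id": "1", "name": "John Doe"}]
).createOrReplaceTempView("users")

# Programmatic equivalent of: datacontract import --format spark --source users
data_contract = DataContract().import_from_source("spark", "users")
print(data_contract.to_yaml())
```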
180 changes: 95 additions & 85 deletions README.md

Large diffs are not rendered by default.

6 changes: 5 additions & 1 deletion datacontract/imports/importer.py
@@ -10,7 +10,10 @@ def __init__(self, import_format) -> None:

@abstractmethod
def import_source(
self, data_contract_specification: DataContractSpecification, source: str, import_args: dict
self,
data_contract_specification: DataContractSpecification,
source: str,
import_args: dict,
) -> dict:
pass

@@ -24,6 +27,7 @@ class ImportFormat(str, Enum):
bigquery = "bigquery"
odcs = "odcs"
unity = "unity"
spark = "spark"

@classmethod
def get_suported_formats(cls):
25 changes: 20 additions & 5 deletions datacontract/imports/importer_factory.py
@@ -46,29 +46,44 @@ def load_module_class(module_path, class_name):

importer_factory = ImporterFactory()
importer_factory.register_lazy_importer(
name=ImportFormat.avro, module_path="datacontract.imports.avro_importer", class_name="AvroImporter"
name=ImportFormat.avro,
module_path="datacontract.imports.avro_importer",
class_name="AvroImporter",
)
importer_factory.register_lazy_importer(
name=ImportFormat.bigquery,
module_path="datacontract.imports.bigquery_importer",
class_name="BigQueryImporter",
)
importer_factory.register_lazy_importer(
name=ImportFormat.glue, module_path="datacontract.imports.glue_importer", class_name="GlueImporter"
name=ImportFormat.glue,
module_path="datacontract.imports.glue_importer",
class_name="GlueImporter",
)
importer_factory.register_lazy_importer(
name=ImportFormat.jsonschema,
module_path="datacontract.imports.jsonschema_importer",
class_name="JsonSchemaImporter",
)
importer_factory.register_lazy_importer(
name=ImportFormat.odcs, module_path="datacontract.imports.odcs_importer", class_name="OdcsImporter"
name=ImportFormat.odcs,
module_path="datacontract.imports.odcs_importer",
class_name="OdcsImporter",
)
importer_factory.register_lazy_importer(
name=ImportFormat.sql, module_path="datacontract.imports.sql_importer", class_name="SqlImporter"
name=ImportFormat.sql,
module_path="datacontract.imports.sql_importer",
class_name="SqlImporter",
)
importer_factory.register_lazy_importer(
name=ImportFormat.unity, module_path="datacontract.imports.unity_importer", class_name="UnityImporter"
name=ImportFormat.unity,
module_path="datacontract.imports.unity_importer",
class_name="UnityImporter",
)
importer_factory.register_lazy_importer(
name=ImportFormat.spark,
module_path="datacontract.imports.spark_importer",
class_name="SparkImporter",
)
importer_factory.register_lazy_importer(
name=ImportFormat.dbt, module_path="datacontract.imports.dbt_importer", class_name="DbtManifestImporter"
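The Spark importer is registered lazily like the others, so pyspark is only imported once `--format spark` is actually requested. A rough sketch of the idea behind `register_lazy_importer` and `load_module_class` (a simplified stand-in; the real `ImporterFactory` may differ in detail):

```python
import importlib


def load_module_class(module_path, class_name):
    # Import the importer's module only when it is first needed, so optional
    # dependencies such as pyspark are not pulled in for unrelated formats.
    module = importlib.import_module(module_path)
    return getattr(module, class_name)


class LazyImporterFactory:  # hypothetical stand-in for ImporterFactory
    def __init__(self):
        self._lazy_importers = {}

    def register_lazy_importer(self, name, module_path, class_name):
        # Only the import coordinates are stored at registration time.
        self._lazy_importers[name] = (module_path, class_name)

    def create(self, name):  # hypothetical accessor name
        module_path, class_name = self._lazy_importers[name]
        return load_module_class(module_path, class_name)(name)
```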
134 changes: 134 additions & 0 deletions datacontract/imports/spark_importer.py
@@ -0,0 +1,134 @@
from pyspark.sql import DataFrame, SparkSession, types
from datacontract.imports.importer import Importer
from datacontract.model.data_contract_specification import (
DataContractSpecification,
Model,
Field,
Server,
)


class SparkImporter(Importer):
def import_source(
self,
data_contract_specification: DataContractSpecification,
source: str,
import_args: dict,
) -> dict:
"""
Imports data from a Spark source into the data contract specification.

Args:
data_contract_specification: The data contract specification object.
source: The source string indicating the Spark tables to read.
import_args: Additional arguments for the import process.

Returns:
dict: The updated data contract specification.
"""
return import_spark(data_contract_specification, source)


def import_spark(data_contract_specification: DataContractSpecification, source: str) -> DataContractSpecification:
"""
Reads Spark tables and updates the data contract specification with their schemas.

Args:
data_contract_specification: The data contract specification to update.
source: A comma-separated string of Spark temporary views to read.

Returns:
DataContractSpecification: The updated data contract specification.
"""
spark = SparkSession.builder.getOrCreate()
data_contract_specification.servers["local"] = Server(type="dataframe")
for temp_view in source.split(","):
temp_view = temp_view.strip()
df = spark.read.table(temp_view)
data_contract_specification.models[temp_view] = import_from_spark_df(df)
return data_contract_specification


def import_from_spark_df(df: DataFrame) -> Model:
"""
Converts a Spark DataFrame into a Model.

Args:
df: The Spark DataFrame to convert.

Returns:
Model: The generated data contract model.
"""
model = Model()
schema = df.schema

for field in schema:
model.fields[field.name] = _field_from_spark(field)

return model


def _field_from_spark(spark_field: types.StructField) -> Field:
"""
Converts a Spark StructField into a Field object for the data contract.

Args:
spark_field: The Spark StructField to convert.

Returns:
Field: The corresponding Field object.
"""
field_type = _data_type_from_spark(spark_field.dataType)
field = Field()
field.type = field_type
field.required = not spark_field.nullable

if field_type == "array":
    # ArrayType.elementType is a DataType, not a StructField, so wrap it before recursing
    array_type = spark_field.dataType
    field.items = _field_from_spark(types.StructField(spark_field.name, array_type.elementType, array_type.containsNull))

if field_type == "struct":
field.fields = {sf.name: _field_from_spark(sf) for sf in spark_field.dataType.fields}

return field


def _data_type_from_spark(spark_type: types.DataType) -> str:
"""
Maps Spark data types to the Data Contract type system.

Args:
spark_type: The Spark data type to map.

Returns:
str: The corresponding Data Contract type.
"""
if isinstance(spark_type, types.StringType):
return "string"
elif isinstance(spark_type, types.IntegerType):
return "integer"
elif isinstance(spark_type, types.LongType):
return "long"
elif isinstance(spark_type, types.FloatType):
return "float"
elif isinstance(spark_type, types.DoubleType):
return "double"
elif isinstance(spark_type, types.StructType):
return "struct"
elif isinstance(spark_type, types.ArrayType):
return "array"
elif isinstance(spark_type, types.TimestampType):
return "timestamp"
elif isinstance(spark_type, types.TimestampNTZType):
return "timestamp_ntz"
elif isinstance(spark_type, types.DateType):
return "date"
elif isinstance(spark_type, types.BooleanType):
return "boolean"
elif isinstance(spark_type, types.BinaryType):
return "bytes"
elif isinstance(spark_type, types.DecimalType):
return "decimal"
elif isinstance(spark_type, types.NullType):
return "null"
else:
raise ValueError(f"Unsupported Spark type: {spark_type}")
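To illustrate the mapping above, a small usage sketch (not part of the diff) of what `import_from_spark_df` yields for a schema with nested and array fields; the field names are made up for the example:

```python
from pyspark.sql import SparkSession, types

from datacontract.imports.spark_importer import import_from_spark_df

spark = SparkSession.builder.getOrCreate()
schema = types.StructType(
    [
        types.StructField("id", types.StringType(), nullable=False),
        types.StructField("tags", types.ArrayType(types.StringType())),
        types.StructField(
            "address",
            types.StructType([types.StructField("city", types.StringType())]),
        ),
    ]
)
df = spark.createDataFrame([], schema=schema)

model = import_from_spark_df(df)
assert model.fields["id"].type == "string" and model.fields["id"].required
assert model.fields["tags"].type == "array"
assert model.fields["tags"].items.type == "string"
assert model.fields["address"].type == "struct"
assert model.fields["address"].fields["city"].type == "string"
```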
162 changes: 162 additions & 0 deletions tests/test_import_spark.py
@@ -0,0 +1,162 @@
import yaml
import pytest
from pyspark.sql import types

from pyspark.sql import SparkSession

from datacontract.data_contract import DataContract

from typer.testing import CliRunner
from datacontract.cli import app


expected = """
dataContractSpecification: 0.9.3
id: my-data-contract-id
info:
title: My Data Contract
version: 0.0.1
servers:
local:
type: dataframe
Contributor: I think we should not add this here, as there is no server type dataframe at the moment at datacontract.com

Contributor: Yes, there is. At least in the schema.

models:
users:
fields:
id:
type: string
required: false
name:
type: string
required: false
address:
type: struct
required: false
fields:
number:
type: integer
required: false
street:
type: string
required: false
city:
type: string
required: false
"""


@pytest.fixture(scope="session")
def spark(tmp_path_factory) -> SparkSession:
"""Create and configure a Spark session."""
spark = (
SparkSession.builder.appName("datacontract-dataframe-unittest")
.config(
"spark.sql.warehouse.dir",
f"{tmp_path_factory.mktemp('spark')}/spark-warehouse",
)
.config("spark.streaming.stopGracefullyOnShutdown", "true")
.config(
"spark.jars.packages",
"org.apache.spark:spark-sql-kafka-0-10_2.12:3.5.0,org.apache.spark:spark-avro_2.12:3.5.0",
)
.getOrCreate()
)
spark.sparkContext.setLogLevel("WARN")
print(f"Using PySpark version {spark.version}")
return spark


def test_cli(spark: SparkSession):
df_user = spark.createDataFrame(
data=[
{
"id": "1",
"name": "John Doe",
"address": {
"number": 123,
"street": "Maple Street",
"city": "Anytown",
},
}
],
schema=types.StructType(
[
types.StructField("id", types.StringType()),
types.StructField("name", types.StringType()),
types.StructField(
"address",
types.StructType(
[
types.StructField("number", types.IntegerType()),
types.StructField("street", types.StringType()),
types.StructField("city", types.StringType()),
]
),
),
]
),
)

df_user.createOrReplaceTempView("users")
runner = CliRunner()
result = runner.invoke(
app,
[
"import",
"--format",
"spark",
"--source",
"users",
],
)

output = result.stdout
assert result.exit_code == 0
assert output.strip() == expected.strip()


def test_table_not_exists():
runner = CliRunner()
result = runner.invoke(
app,
[
"import",
"--format",
"spark",
"--source",
"table_not_exists",
],
)

assert result.exit_code == 1


def test_prog(spark: SparkSession):
df_user = spark.createDataFrame(
data=[
{
"id": "1",
"name": "John Doe",
"address": {"number": 123, "street": "Maple Street", "city": "Anytown"},
}
],
schema=types.StructType(
[
types.StructField("id", types.StringType()),
types.StructField("name", types.StringType()),
types.StructField(
"address",
types.StructType(
[
types.StructField("number", types.IntegerType()),
types.StructField("street", types.StringType()),
types.StructField("city", types.StringType()),
]
),
),
]
),
)

df_user.createOrReplaceTempView("users")
result = DataContract().import_from_source("spark", "users")
assert yaml.safe_load(result.to_yaml()) == yaml.safe_load(expected)