Import_spark #326 (Merged)

13 commits (all by pierre-monnet):
- 81901b4 feat: html improvment for arrays
- e15fa47 reset formating
- dd6ffc1 reset formating
- 63ab147 Merge branch 'main' of github.com:pierre-monnet/datacontract-cli
- a122c78 fix: readd missing code
- 934e497 Changelog update
- 5b151e3 Merge branch 'main' into main
- 5ab22d2 Merge branch 'datacontract:main' into main
- 4a745fc Merge branch 'datacontract:main' into main
- 20e4b07 Merge branch 'datacontract:main' into main
- a992216 Merge branch 'datacontract:main' into main
- 1095a2a feat: add spark importer
- 9ec5dce edit changelog: add PR number
New file (+134 lines): the Spark importer module.
from pyspark.sql import DataFrame, SparkSession, types

from datacontract.imports.importer import Importer
from datacontract.model.data_contract_specification import (
    DataContractSpecification,
    Model,
    Field,
    Server,
)


class SparkImporter(Importer):
    def import_source(
        self,
        data_contract_specification: DataContractSpecification,
        source: str,
        import_args: dict,
    ) -> DataContractSpecification:
        """
        Imports data from a Spark source into the data contract specification.

        Args:
            data_contract_specification: The data contract specification object.
            source: The source string indicating the Spark tables to read.
            import_args: Additional arguments for the import process.

        Returns:
            DataContractSpecification: The updated data contract specification.
        """
        return import_spark(data_contract_specification, source)


def import_spark(data_contract_specification: DataContractSpecification, source: str) -> DataContractSpecification:
    """
    Reads Spark tables and updates the data contract specification with their schemas.

    Args:
        data_contract_specification: The data contract specification to update.
        source: A comma-separated string of Spark temporary views to read.

    Returns:
        DataContractSpecification: The updated data contract specification.
    """
    spark = SparkSession.builder.getOrCreate()
    data_contract_specification.servers["local"] = Server(type="dataframe")
    for temp_view in source.split(","):
        temp_view = temp_view.strip()
        df = spark.read.table(temp_view)
        data_contract_specification.models[temp_view] = import_from_spark_df(df)
    return data_contract_specification


def import_from_spark_df(df: DataFrame) -> Model:
    """
    Converts a Spark DataFrame into a Model.

    Args:
        df: The Spark DataFrame to convert.

    Returns:
        Model: The generated data contract model.
    """
    model = Model()
    schema = df.schema

    for field in schema:
        model.fields[field.name] = _field_from_spark(field)

    return model


def _field_from_spark(spark_field: types.StructField) -> Field:
    """
    Converts a Spark StructField into a Field object for the data contract.

    Args:
        spark_field: The Spark StructField to convert.

    Returns:
        Field: The corresponding Field object.
    """
    field_type = _data_type_from_spark(spark_field.dataType)
    field = Field()
    field.type = field_type
    field.required = not spark_field.nullable

    if field_type == "array":
        # ArrayType.elementType is a bare DataType, not a StructField, so wrap
        # it in a synthetic StructField before recursing; otherwise the
        # .dataType/.nullable accesses below would fail for array fields.
        array_type = spark_field.dataType
        field.items = _field_from_spark(
            types.StructField("element", array_type.elementType, array_type.containsNull)
        )

    if field_type == "struct":
        field.fields = {sf.name: _field_from_spark(sf) for sf in spark_field.dataType.fields}

    return field


def _data_type_from_spark(spark_type: types.DataType) -> str:
    """
    Maps Spark data types to the Data Contract type system.

    Args:
        spark_type: The Spark data type to map.

    Returns:
        str: The corresponding Data Contract type.
    """
    if isinstance(spark_type, types.StringType):
        return "string"
    elif isinstance(spark_type, types.IntegerType):
        return "integer"
    elif isinstance(spark_type, types.LongType):
        return "long"
    elif isinstance(spark_type, types.FloatType):
        return "float"
    elif isinstance(spark_type, types.DoubleType):
        return "double"
    elif isinstance(spark_type, types.StructType):
        return "struct"
    elif isinstance(spark_type, types.ArrayType):
        return "array"
    elif isinstance(spark_type, types.TimestampType):
        return "timestamp"
    elif isinstance(spark_type, types.TimestampNTZType):
        return "timestamp_ntz"
    elif isinstance(spark_type, types.DateType):
        return "date"
    elif isinstance(spark_type, types.BooleanType):
        return "boolean"
    elif isinstance(spark_type, types.BinaryType):
        return "bytes"
    elif isinstance(spark_type, types.DecimalType):
        return "decimal"
    elif isinstance(spark_type, types.NullType):
        return "null"
    else:
        raise ValueError(f"Unsupported Spark type: {spark_type}")
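
For orientation, a minimal usage sketch (not part of the diff, mirroring test_prog in the tests below); the DataFrame contents and the view name users are illustrative:

from pyspark.sql import SparkSession

from datacontract.data_contract import DataContract

# Any DataFrame registered as a temp view can be imported by name.
spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("1", "John Doe")], ["id", "name"])
df.createOrReplaceTempView("users")

# "spark" selects the new importer; the source is a comma-separated
# list of temp view names (here a single view).
data_contract = DataContract().import_from_source("spark", "users")
print(data_contract.to_yaml())

The equivalent CLI invocation, exercised in test_cli below, is: datacontract import --format spark --source users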
New file (+162 lines): tests for the Spark importer.
import pytest
import yaml
from pyspark.sql import SparkSession, types
from typer.testing import CliRunner

from datacontract.cli import app
from datacontract.data_contract import DataContract

expected = """
dataContractSpecification: 0.9.3
id: my-data-contract-id
info:
  title: My Data Contract
  version: 0.0.1
servers:
  local:
    type: dataframe
models:
  users:
    fields:
      id:
        type: string
        required: false
      name:
        type: string
        required: false
      address:
        type: struct
        required: false
        fields:
          number:
            type: integer
            required: false
          street:
            type: string
            required: false
          city:
            type: string
            required: false
"""


@pytest.fixture(scope="session")
def spark(tmp_path_factory) -> SparkSession:
    """Create and configure a Spark session."""
    spark = (
        SparkSession.builder.appName("datacontract-dataframe-unittest")
        .config(
            "spark.sql.warehouse.dir",
            f"{tmp_path_factory.mktemp('spark')}/spark-warehouse",
        )
        .config("spark.streaming.stopGracefullyOnShutdown", "true")
        .config(
            "spark.jars.packages",
            "org.apache.spark:spark-sql-kafka-0-10_2.12:3.5.0,org.apache.spark:spark-avro_2.12:3.5.0",
        )
        .getOrCreate()
    )
    spark.sparkContext.setLogLevel("WARN")
    print(f"Using PySpark version {spark.version}")
    return spark


def test_cli(spark: SparkSession):
    df_user = spark.createDataFrame(
        data=[
            {
                "id": "1",
                "name": "John Doe",
                "address": {
                    "number": 123,
                    "street": "Maple Street",
                    "city": "Anytown",
                },
            }
        ],
        schema=types.StructType(
            [
                types.StructField("id", types.StringType()),
                types.StructField("name", types.StringType()),
                types.StructField(
                    "address",
                    types.StructType(
                        [
                            types.StructField("number", types.IntegerType()),
                            types.StructField("street", types.StringType()),
                            types.StructField("city", types.StringType()),
                        ]
                    ),
                ),
            ]
        ),
    )

    df_user.createOrReplaceTempView("users")
    runner = CliRunner()
    result = runner.invoke(
        app,
        [
            "import",
            "--format",
            "spark",
            "--source",
            "users",
        ],
    )

    output = result.stdout
    assert result.exit_code == 0
    assert output.strip() == expected.strip()


def test_table_not_exists():
    runner = CliRunner()
    result = runner.invoke(
        app,
        [
            "import",
            "--format",
            "spark",
            "--source",
            "table_not_exists",
        ],
    )

    assert result.exit_code == 1


def test_prog(spark: SparkSession):
    df_user = spark.createDataFrame(
        data=[
            {
                "id": "1",
                "name": "John Doe",
                "address": {"number": 123, "street": "Maple Street", "city": "Anytown"},
            }
        ],
        schema=types.StructType(
            [
                types.StructField("id", types.StringType()),
                types.StructField("name", types.StringType()),
                types.StructField(
                    "address",
                    types.StructType(
                        [
                            types.StructField("number", types.IntegerType()),
                            types.StructField("street", types.StringType()),
                            types.StructField("city", types.StringType()),
                        ]
                    ),
                ),
            ]
        ),
    )

    df_user.createOrReplaceTempView("users")
    result = DataContract().import_from_source("spark", "users")
    assert yaml.safe_load(result.to_yaml()) == yaml.safe_load(expected)
Review conversation:

> I think we should not add this here, as there is no server type dataframe at the moment at datacontract.com

> Yes, there is. At least in the schema.
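
For reference, the server entry under discussion is the one the importer emits (it appears in the tests' expected output above); as a data contract YAML snippet:

servers:
  local:
    type: dataframe

Whether dataframe is a documented server type at datacontract.com is the point of contention; per the reply above, it is at least present in the specification's schema.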