Skip to content

Commit

Permalink
Support the ability to pass in None for both get_column_plot and …
Browse files Browse the repository at this point in the history
…`get_column_pair_plot` (#2344)
  • Loading branch information
R-Palazzo authored Jan 16, 2025
1 parent 6b47f9d commit 075cdff
Show file tree
Hide file tree
Showing 5 changed files with 203 additions and 39 deletions.
35 changes: 35 additions & 0 deletions sdv/evaluation/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import pandas as pd


def _prepare_data_visualization(data, metadata, column_names, sample_size):
    """Return a plot-ready copy of ``data``.

    The input frame is never modified: a copy is made, the requested columns
    whose metadata sdtype is ``'datetime'`` are converted with
    ``pd.to_datetime``, and the copy is optionally subsampled.

    Args:
        data (pd.DataFrame or None):
            The table to prepare. ``None`` is passed straight through.
        metadata (Metadata):
            Object whose ``columns`` mapping holds, for each column name, a
            dict with an ``'sdtype'`` entry and optionally a
            ``'datetime_format'`` entry.
        column_names (str or list[str]):
            The column (or columns) that will be plotted.
        sample_size (int or None):
            Maximum number of rows to keep. A falsy value, or a value that is
            not smaller than the number of rows, keeps the whole table.

    Returns:
        pd.DataFrame or None:
            The prepared copy, or ``None`` when ``data`` is ``None``.
    """
    if data is None:
        return None

    # A single column name is wrapped so both call styles share one loop.
    if isinstance(column_names, list):
        selected = column_names
    else:
        selected = [column_names]

    prepared = data.copy()
    for name in selected:
        column_meta = metadata.columns[name]
        if column_meta['sdtype'] != 'datetime':
            continue

        # ``format`` is None when the metadata does not declare one.
        prepared[name] = pd.to_datetime(
            prepared[name], format=column_meta.get('datetime_format')
        )

    needs_subsample = sample_size and sample_size < len(prepared)
    if not needs_subsample:
        return prepared

    return prepared.sample(n=sample_size)
4 changes: 2 additions & 2 deletions sdv/evaluation/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,9 @@ def get_column_pair_plot(
"""Get a plot of the real and synthetic data for a given column pair.
Args:
real_data (dict):
real_data (dict or None):
Dictionary containing the real table data.
synthetic_column (dict):
synthetic_column (dict or None):
Dictionary containing the synthetic table data.
metadata (Metadata):
Metadata describing the data.
Expand Down
35 changes: 9 additions & 26 deletions sdv/evaluation/single_table.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
"""Methods to compare the real and synthetic data for single-table."""

import pandas as pd
from sdmetrics import visualization
from sdmetrics.reports.single_table.diagnostic_report import DiagnosticReport
from sdmetrics.reports.single_table.quality_report import QualityReport

from sdv.errors import VisualizationUnavailableError
from sdv.evaluation._utils import _prepare_data_visualization
from sdv.metadata.metadata import Metadata


Expand Down Expand Up @@ -68,9 +68,9 @@ def get_column_plot(real_data, synthetic_data, metadata, column_name, plot_type=
"""Get a plot of the real and synthetic data for a given column.
Args:
real_data (pandas.DataFrame):
real_data (pandas.DataFrame or None):
The real table data.
synthetic_data (pandas.DataFrame):
synthetic_data (pandas.DataFrame or None):
The synthetic table data.
metadata (Metadata):
The table metadata.
Expand Down Expand Up @@ -103,14 +103,8 @@ def get_column_plot(real_data, synthetic_data, metadata, column_name, plot_type=
"'plot_type'."
)

if sdtype == 'datetime':
datetime_format = metadata.columns.get(column_name).get('datetime_format')
real_data = pd.DataFrame({
column_name: pd.to_datetime(real_data[column_name], format=datetime_format)
})
synthetic_data = pd.DataFrame({
column_name: pd.to_datetime(synthetic_data[column_name], format=datetime_format)
})
real_data = _prepare_data_visualization(real_data, metadata, column_name, None)
synthetic_data = _prepare_data_visualization(synthetic_data, metadata, column_name, None)

return visualization.get_column_plot(
real_data, synthetic_data, column_name, plot_type=plot_type
Expand Down Expand Up @@ -147,8 +141,6 @@ def get_column_pair_plot(
if isinstance(metadata, Metadata):
metadata = metadata._convert_to_single_table()

real_data = real_data.copy()
synthetic_data = synthetic_data.copy()
if plot_type is None:
plot_type = []
for column_name in column_names:
Expand All @@ -169,18 +161,9 @@ def get_column_pair_plot(
else:
plot_type = plot_type.pop()

for column_name in column_names:
sdtype = metadata.columns.get(column_name)['sdtype']
if sdtype == 'datetime':
datetime_format = metadata.columns.get(column_name).get('datetime_format')
real_data[column_name] = pd.to_datetime(real_data[column_name], format=datetime_format)
synthetic_data[column_name] = pd.to_datetime(
synthetic_data[column_name], format=datetime_format
)

require_subsample = sample_size and sample_size < min(len(real_data), len(synthetic_data))
if require_subsample:
real_data = real_data.sample(n=sample_size)
synthetic_data = synthetic_data.sample(n=sample_size)
real_data = _prepare_data_visualization(real_data, metadata, column_names, sample_size)
synthetic_data = _prepare_data_visualization(
synthetic_data, metadata, column_names, sample_size
)

return visualization.get_column_pair_plot(real_data, synthetic_data, column_names, plot_type)
36 changes: 36 additions & 0 deletions tests/unit/evaluation/test__utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import numpy as np
import pandas as pd

from sdv.evaluation._utils import _prepare_data_visualization
from sdv.metadata import SingleTableMetadata


def test__prepare_data_visualization():
    """Check datetime conversion and subsampling in ``_prepare_data_visualization``."""
    # Setup
    # Seed NumPy's global generator first so the sampled rows are reproducible.
    np.random.seed(0)
    table_metadata = SingleTableMetadata.load_from_dict({
        'columns': {
            'col1': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'},
            'col2': {'sdtype': 'numerical'},
        }
    })
    raw_table = pd.DataFrame({
        'col1': ['2021-01-01', '2021-02-01', '2021-03-01'],
        'col2': [4, 5, 6],
    })

    # Run
    prepared = _prepare_data_visualization(raw_table, table_metadata, ['col1', 'col2'], 2)

    # Assert
    # Two of the three rows are kept, with ``col1`` converted to datetimes.
    expected_columns = {
        'col1': pd.to_datetime(['2021-03-01', '2021-02-01']),
        'col2': [6, 5],
    }
    expected = pd.DataFrame(expected_columns, index=[2, 1])
    pd.testing.assert_frame_equal(prepared, expected)
Loading

0 comments on commit 075cdff

Please sign in to comment.