Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: OneHotEncoder no longer creates duplicate column names #271

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/safeds/data/tabular/containers/_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1136,10 +1136,10 @@ def transform_table(self, transformer: TableTransformer) -> Table:
>>> table = Table({"col1": [1, 2, 1], "col2": [1, 2, 4]})
>>> fitted_transformer = transformer.fit(table, None)
>>> table.transform_table(fitted_transformer)
col1_1 col1_2 col2_1 col2_2 col2_4
0 1.0 0.0 1.0 0.0 0.0
1 0.0 1.0 0.0 1.0 0.0
2 1.0 0.0 0.0 0.0 1.0
col1__1 col1__2 col2__1 col2__2 col2__4
0 1.0 0.0 1.0 0.0 0.0
1 0.0 1.0 0.0 1.0 0.0
2 1.0 0.0 0.0 0.0 1.0
"""
return transformer.transform(self)

Expand Down
142 changes: 87 additions & 55 deletions src/safeds/data/tabular/transformation/_one_hot_encoder.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from __future__ import annotations

import pandas as pd
from sklearn.preprocessing import OneHotEncoder as sk_OneHotEncoder
from collections import Counter
from typing import Any

from safeds.data.tabular.containers import Table
from safeds.data.tabular.containers import Column, Table
from safeds.data.tabular.transformation._table_transformer import (
InvertibleTableTransformer,
)
from safeds.exceptions import TransformerNotFittedError, UnknownColumnNameError
from safeds.exceptions import TransformerNotFittedError, UnknownColumnNameError, ValueNotPresentWhenFittedError


class OneHotEncoder(InvertibleTableTransformer):
Expand All @@ -27,12 +27,12 @@ class OneHotEncoder(InvertibleTableTransformer):

The one-hot encoding of this table is:

| col1_a | col1_b | col1_c |
|--------|--------|--------|
| 1 | 0 | 0 |
| 0 | 1 | 0 |
| 0 | 0 | 1 |
| 1 | 0 | 0 |
| col1__a | col1__b | col1__c |
|---------|---------|---------|
| 1 | 0 | 0 |
| 0 | 1 | 0 |
| 0 | 0 | 1 |
| 1 | 0 | 0 |

The name "one-hot" comes from the fact that each row has exactly one 1 in it, and the rest of the values are 0s.
One-hot encoding is closely related to dummy variable / indicator variables, which are used in statistics.
Expand All @@ -44,16 +44,18 @@ class OneHotEncoder(InvertibleTableTransformer):
>>> table = Table({"col1": ["a", "b", "c", "a"]})
>>> transformer = OneHotEncoder()
>>> transformer.fit_and_transform(table, ["col1"])
col1_a col1_b col1_c
0 1.0 0.0 0.0
1 0.0 1.0 0.0
2 0.0 0.0 1.0
3 1.0 0.0 0.0
col1__a col1__b col1__c
0 1.0 0.0 0.0
1 0.0 1.0 0.0
2 0.0 0.0 1.0
3 1.0 0.0 0.0
"""

def __init__(self) -> None:
self._wrapped_transformer: sk_OneHotEncoder | None = None
# Maps each old column to (list of) new columns created from it:
self._column_names: dict[str, list[str]] | None = None
# Maps concrete values (tuples of old column and value) to corresponding new column names:
self._value_to_column: dict[tuple[str, Any], str] | None = None

# noinspection PyProtectedMember
def fit(self, table: Table, column_names: list[str] | None) -> OneHotEncoder:
Expand Down Expand Up @@ -84,15 +86,28 @@ def fit(self, table: Table, column_names: list[str] | None) -> OneHotEncoder:
data = table._data.copy()
data.columns = table.column_names

wrapped_transformer = sk_OneHotEncoder()
wrapped_transformer.fit(data[column_names])

result = OneHotEncoder()
result._wrapped_transformer = wrapped_transformer
result._column_names = {
column: [f"{column}_{element}" for element in table.get_column(column).get_unique_values()]
for column in column_names
}

result._column_names = {}
result._value_to_column = {}

# Keep track of number of occurrences of column names;
# initially all old column names appear exactly once:
name_counter = Counter(data.columns)

# Iterate through all columns to-be-changed:
for column in column_names:
result._column_names[column] = []
for element in table.get_column(column).get_unique_values():
base_name = f"{column}__{element}"
name_counter[base_name] += 1
new_column_name = base_name
# Check if newly created name matches some other existing column name:
if name_counter[base_name] > 1:
new_column_name += f"#{name_counter[base_name]}"
# Update dictionary entries:
result._column_names[column] += [new_column_name]
result._value_to_column[(column, element)] = new_column_name

return result

Expand All @@ -119,37 +134,49 @@ def transform(self, table: Table) -> Table:
If the transformer has not been fitted yet.
"""
# Transformer has not been fitted yet
if self._wrapped_transformer is None or self._column_names is None:
if self._column_names is None or self._value_to_column is None:
raise TransformerNotFittedError

# Input table does not contain all columns used to fit the transformer
missing_columns = set(self._column_names.keys()) - set(table.column_names)
if len(missing_columns) > 0:
raise UnknownColumnNameError(list(missing_columns))

original = table._data.copy()
original.columns = table.schema.column_names

one_hot_encoded = pd.DataFrame(
self._wrapped_transformer.transform(original[self._column_names.keys()]).toarray(),
)
one_hot_encoded.columns = self._wrapped_transformer.get_feature_names_out()

unchanged = original.drop(self._column_names.keys(), axis=1)

res = Table._from_pandas_dataframe(pd.concat([unchanged, one_hot_encoded], axis=1))
encoded_values = {}
for new_column_name in self._value_to_column.values():
encoded_values[new_column_name] = [0.0 for _ in range(table.number_of_rows)]

for old_column_name in self._column_names:
for i in range(table.number_of_rows):
value = table.get_column(old_column_name).get_value(i)
try:
new_column_name = self._value_to_column[(old_column_name, value)]
except KeyError:
# This happens when a column in the to-be-transformed table contains a new value that was not
# already present in the table the OneHotEncoder was fitted on.
raise ValueNotPresentWhenFittedError(value, old_column_name) from None
encoded_values[new_column_name][i] = 1.0

for new_column in self._column_names[old_column_name]:
table = table.add_column(Column(new_column, encoded_values[new_column]))

# New columns may not be sorted:
column_names = []

for name in table.column_names:
if name not in self._column_names.keys():
column_names.append(name)
else:
column_names.extend(
[f_name for f_name in self._wrapped_transformer.get_feature_names_out() if f_name.startswith(name)],
[f_name for f_name in self._value_to_column.values() if f_name.startswith(name)],
)
res = res.sort_columns(lambda col1, col2: column_names.index(col1.name) - column_names.index(col2.name))

return res
# Drop old, non-encoded columns:
# (Don't do this earlier - we need the old column names for sorting,
# plus we need to prevent the table from possibly having 0 columns temporarily.)
table = table.remove_columns(list(self._column_names.keys()))

# Apply sorting and return:
return table.sort_columns(lambda col1, col2: column_names.index(col1.name) - column_names.index(col2.name))

# noinspection PyProtectedMember
def inverse_transform(self, transformed_table: Table) -> Table:
Expand All @@ -174,21 +201,24 @@ def inverse_transform(self, transformed_table: Table) -> Table:
If the transformer has not been fitted yet.
"""
# Transformer has not been fitted yet
if self._wrapped_transformer is None or self._column_names is None:
if self._column_names is None or self._value_to_column is None:
raise TransformerNotFittedError

data = transformed_table._data.copy()
data.columns = transformed_table.column_names
original_columns = {}
for original_column_name in self._column_names:
original_columns[original_column_name] = [None for _ in range(transformed_table.number_of_rows)]

for original_column_name, value in self._value_to_column:
constructed_column = self._value_to_column[(original_column_name, value)]
for i in range(transformed_table.number_of_rows):
if transformed_table.get_column(constructed_column)[i] == 1.0:
original_columns[original_column_name][i] = value

decoded = pd.DataFrame(
self._wrapped_transformer.inverse_transform(
transformed_table.keep_only_columns(self._wrapped_transformer.get_feature_names_out())._data,
),
columns=list(self._column_names.keys()),
)
unchanged = data.drop(self._wrapped_transformer.get_feature_names_out(), axis=1)
table = transformed_table

for column_name, encoded_column in original_columns.items():
table = table.add_column(Column(column_name, encoded_column))

res = Table._from_pandas_dataframe(pd.concat([unchanged, decoded], axis=1))
column_names = [
(
name
Expand All @@ -201,11 +231,13 @@ def inverse_transform(self, transformed_table: Table) -> Table:
][0]
]
)
for name in transformed_table.column_names
for name in table.column_names
]
res = res.sort_columns(lambda col1, col2: column_names.index(col1.name) - column_names.index(col2.name))

return res
# Drop old column names:
table = table.remove_columns(list(self._value_to_column.values()))

return table.sort_columns(lambda col1, col2: column_names.index(col1.name) - column_names.index(col2.name))

def is_fitted(self) -> bool:
"""
Expand All @@ -216,4 +248,4 @@ def is_fitted(self) -> bool:
is_fitted : bool
Whether the transformer is fitted.
"""
return self._wrapped_transformer is not None
return self._column_names is not None and self._value_to_column is not None
2 changes: 2 additions & 0 deletions src/safeds/exceptions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
SchemaMismatchError,
TransformerNotFittedError,
UnknownColumnNameError,
ValueNotPresentWhenFittedError,
)
from safeds.exceptions._ml import (
DatasetContainsTargetError,
Expand All @@ -29,6 +30,7 @@
"SchemaMismatchError",
"TransformerNotFittedError",
"UnknownColumnNameError",
"ValueNotPresentWhenFittedError",
# ML exceptions
"DatasetContainsTargetError",
"DatasetMissesFeaturesError",
Expand Down
7 changes: 7 additions & 0 deletions src/safeds/exceptions/_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,10 @@ class TransformerNotFittedError(Exception):

def __init__(self) -> None:
super().__init__("The transformer has not been fitted yet.")


class ValueNotPresentWhenFittedError(Exception):
    """
    Exception raised when attempting to one-hot-encode a table containing values not present in the fitting phase.

    Parameters
    ----------
    value : object
        The offending value. It is formatted into the error message, so it does not have to be a string.
    column : str
        The name of the column in which the value was encountered.
    """

    def __init__(self, value: object, column: str) -> None:
        # `value` is annotated as `object` rather than `str`: callers may pass any cell value
        # (numbers, None, ...), and the f-string below formats all of them.
        super().__init__(f"Value not present in the table the transformer was fitted on: \n{value} in column {column}.")
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,9 @@ def test_should_not_change_transformed_table() -> None:

expected = Table(
{
"col1_a": [1.0, 0.0, 0.0, 0.0],
"col1_b": [0.0, 1.0, 1.0, 0.0],
"col1_c": [0.0, 0.0, 0.0, 1.0],
"col1__a": [1.0, 0.0, 0.0, 0.0],
"col1__b": [0.0, 1.0, 1.0, 0.0],
"col1__c": [0.0, 0.0, 0.0, 1.0],
},
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
None,
Table(
{
"col1_a": [1.0, 0.0, 0.0, 0.0],
"col1_b": [0.0, 1.0, 1.0, 0.0],
"col1_c": [0.0, 0.0, 0.0, 1.0],
"col1__a": [1.0, 0.0, 0.0, 0.0],
"col1__b": [0.0, 1.0, 1.0, 0.0],
"col1__c": [0.0, 0.0, 0.0, 1.0],
},
),
),
Expand All @@ -32,9 +32,9 @@
["col1"],
Table(
{
"col1_a": [1.0, 0.0, 0.0, 0.0],
"col1_b": [0.0, 1.0, 1.0, 0.0],
"col1_c": [0.0, 0.0, 0.0, 1.0],
"col1__a": [1.0, 0.0, 0.0, 0.0],
"col1__b": [0.0, 1.0, 1.0, 0.0],
"col1__c": [0.0, 0.0, 0.0, 1.0],
"col2": ["a", "b", "b", "c"],
},
),
Expand All @@ -49,12 +49,12 @@
["col1", "col2"],
Table(
{
"col1_a": [1.0, 0.0, 0.0, 0.0],
"col1_b": [0.0, 1.0, 1.0, 0.0],
"col1_c": [0.0, 0.0, 0.0, 1.0],
"col2_a": [1.0, 0.0, 0.0, 0.0],
"col2_b": [0.0, 1.0, 1.0, 0.0],
"col2_c": [0.0, 0.0, 0.0, 1.0],
"col1__a": [1.0, 0.0, 0.0, 0.0],
"col1__b": [0.0, 1.0, 1.0, 0.0],
"col1__c": [0.0, 0.0, 0.0, 1.0],
"col2__a": [1.0, 0.0, 0.0, 0.0],
"col2__b": [0.0, 1.0, 1.0, 0.0],
"col2__c": [0.0, 0.0, 0.0, 1.0],
},
),
),
Expand Down
Loading