Skip to content

Commit

Permalink
Merge pull request #13 from soda-inria/benchmark
Browse files Browse the repository at this point in the history
Benchmark
  • Loading branch information
gaetanbrison authored Dec 17, 2024
2 parents 7662b08 + 9990a01 commit b53a890
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions carte_ai/scripts/evaluate_singletable.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
RESNETRegressor,
RESNETClassifier,
)
from huggingface_hub import hf_hub_download


def _load_data(data_name):
Expand Down Expand Up @@ -76,7 +77,7 @@ def _prepare_carte_gnn(
random_state,
):
"""Preprocess for CARTE (graph construction)."""
from carte_table_to_graph_old import Table2GraphTransformer
from carte_ai.src.carte_table_to_graph import Table2GraphTransformer

data_ = data.copy()
X_train, X_test, y_train, y_test = set_split(
Expand All @@ -85,7 +86,9 @@ def _prepare_carte_gnn(
num_train,
random_state=random_state,
)
preprocessor = Table2GraphTransformer()
fasttext_path = hf_hub_download(repo_id="hi-paris/fastText", filename="cc.en.300.bin")
preprocessor = Table2GraphTransformer(lm_model="fasttext",
fasttext_model_path=fasttext_path)
X_train = preprocessor.fit_transform(X_train, y=y_train)
X_test = preprocessor.transform(X_test)
return X_train, X_test, y_train, y_test
Expand Down Expand Up @@ -147,7 +150,7 @@ def _prepare_tablevectorizer(
categorical_preprocessor = Pipeline(
steps=[
("imputer", SimpleImputer(strategy="constant", fill_value="missing")),
("onehot", OneHotEncoder(handle_unknown="ignore", sparse=False)),
("onehot", OneHotEncoder(handle_unknown="ignore", sparse_output=False)),
]
)

Expand Down

0 comments on commit b53a890

Please sign in to comment.