added first attempt at fasttext filter creator
TristanThrush committed Sep 28, 2024
1 parent 024446a commit 64697de
Showing 3 changed files with 158 additions and 3 deletions.
6 changes: 3 additions & 3 deletions examples/get_error_and_bpb/get_error_and_bpb.py
@@ -30,7 +30,7 @@
parser.add_argument("--raw_job_output_path", required=False)

parser.add_argument("--hf_llm_revision", default="main")
parser.add_argument("--loss_shards", type=int, default=50)
parser.add_argument("--num_loss_shards", type=int, default=50)
parser.add_argument("--resume", action="store_true")
parser.add_argument("--save_model_info", action="store_true")
parser.add_argument("--device", default="cuda")
@@ -181,12 +181,12 @@ def get_loss_hf(examples):
 shards = []

 # Shard the dataset and add each shard to the list
-for i in range(args.loss_shards):
+for i in range(args.num_loss_shards):
     if args.resume and os.path.exists(f"{args.raw_job_output_path}/loss_shards/{i}"):
         shard = load_from_disk(f"{args.raw_job_output_path}/loss_shards/{i}")
         shards.append(shard)
     else:
-        shard = ds.shard(num_shards=args.loss_shards, index=i)
+        shard = ds.shard(num_shards=args.num_loss_shards, index=i)

         shard = shard.map(
             lambda example: get_loss_hf(example),
9 changes: 9 additions & 0 deletions examples/get_fasttext_filter/config.yml
@@ -0,0 +1,9 @@
bpb_csv: ../get_error_and_bpb/rpjv2_sample_bpb_matrix.csv
error_csv: ../get_error_and_bpb/error_matrix.csv
target_benchmarks:
- sciq
- piqa
aggregation: domain
estimator: spearmanr
chunked_pretraining_data_sample: ../get_error_and_bpb/chunked_rpjv2_sample
fasttext_model_output_name: sciq_piqa_target_fasttext
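
For reference, the script in this commit implies a particular layout for the two CSVs: bpb_csv has one row per text chunk, with id, chunk, and (for domain aggregation) domain columns plus one bits-per-byte column per model, while error_csv has one row per benchmark, with a benchmark column plus one error column per model. An illustrative sketch (model_a and model_b are made-up column names):

bpb_csv:

id,chunk,domain,model_a,model_b
doc-0,0,arxiv.org,1.02,0.97
doc-0,1,arxiv.org,0.99,0.95

error_csv:

benchmark,model_a,model_b
sciq,0.12,0.08
piqa,0.25,0.21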
146 changes: 146 additions & 0 deletions examples/get_fasttext_filter/get_fasttext_filter.py
@@ -0,0 +1,146 @@
import yaml
from types import SimpleNamespace
import argparse
from datasets import load_from_disk
import pandas as pd
import fasttext
from perplexity_correlations.estimation import (
    product,
    sign,
    sign_cdf,
    sign_sign,
    spearmanr,
)
import numpy as np

estimators = {
    "product": product,
    "sign": sign,
    "sign_cdf": sign_cdf,
    "sign_sign": sign_sign,
    "spearmanr": spearmanr,
}
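
# Each of these estimators follows the same interface from the
# perplexity-correlations package: given a matrix X of shape
# (num_models, num_texts) holding bits-per-byte values and a vector y of
# shape (num_models,) holding benchmark errors, return one weight per text.
# (Shapes are inferred from how X and y are constructed below.)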


def get_X_no_aggregation(df):
    if "domain" in df.columns:
        df = df.drop(columns=["domain"])
    ordered_ids_and_chunks = df[["id", "chunk"]]
    df = df.drop(columns=["id", "chunk"])
    df = df.sort_index(axis=1)
    X = df.to_numpy().T
    return X, ordered_ids_and_chunks


def get_X_id_aggregation(df):
    df = df.groupby("id", as_index=False)
    df = df.mean(numeric_only=True)
    ordered_ids = df[["id"]]
    df = df.drop(columns=["id", "chunk"])
    df = df.sort_index(axis=1)
    X = df.to_numpy().T
    return X, ordered_ids


def get_X_domain_aggregation(df):
    df = df.groupby("domain", as_index=False)
    df = df.mean(numeric_only=True)
    ordered_domains = df[["domain"]]
    df = df.drop(columns=["domain", "chunk"])
    df = df.sort_index(axis=1)
    X = df.to_numpy().T
    return X, ordered_domains


get_X_functions = {
    None: get_X_no_aggregation,
    "id": get_X_id_aggregation,
    "domain": get_X_domain_aggregation,
}
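
# Whichever aggregation is used, the second return value identifies what each
# column of X refers to ((id, chunk) pairs, ids, or domains), so the estimated
# weights can be joined back onto the pretraining text below.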


def get_y(df, target_benchmarks):
    # Average each model's error over the target benchmarks from the config
    df = df[df["benchmark"].isin(target_benchmarks)]
    df = df.sort_index(axis=1)
    df = df.mean(numeric_only=True)
    y = df.to_numpy()
    return y


parser = argparse.ArgumentParser()
parser.add_argument("--config")
args = parser.parse_args()

with open(args.config, "r") as file:
    config = SimpleNamespace(**yaml.safe_load(file))

estimator = estimators[config.estimator]
get_X_function = get_X_functions[config.aggregation]
X, labels_df = get_X_function(pd.read_csv(config.bpb_csv))
y = get_y(pd.read_csv(config.error_csv), config.target_benchmarks)

estimate = estimator(X, y)

# Assume the sample used to generate the BPB data comes from the same
# distribution as the data we want to pretrain on. Now, assume that we want to
# pretrain on the best half of the chunks/pages/domains. Because the linear
# projection is not sensitive to the particular values of the estimate
# (only their ranks), we can just take the half of the text with the top
# values in estimate as our pretraining data. We can also train a fastText
# model to distinguish this good pretraining data from other data, which is
# what we do here. You can then use this fastText model as a pretraining
# data filter.

labels = np.array(["__label__exclude"] * len(estimate))
labels[np.argsort(estimate)[int(len(estimate) / 2) :]] = "__label__include"
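
# Worked example: if estimate were [0.3, -0.1, 0.9, 0.2], np.argsort would
# give [1, 3, 0, 2], so the top-half indices [0, 2] get "__label__include"
# and the rest keep "__label__exclude".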

labels_df["label"] = labels

# Load the chunked pretraining data sample and split off a small test set
ds = load_from_disk(config.chunked_pretraining_data_sample)

ds = ds.train_test_split(test_size=0.01)

train_df = ds["train"].to_pandas()
train_df = pd.merge(
    train_df,
    labels_df,
    on=[column for column in ["id", "chunk", "domain"] if column in labels_df.columns],
    how="inner",
)
for column in ["id", "chunk", "domain"]:
    if column in train_df.columns:
        train_df.drop(columns=[column], inplace=True)

test_df = ds["test"].to_pandas()
test_df = pd.merge(
    test_df,
    labels_df,
    on=[column for column in ["id", "chunk", "domain"] if column in labels_df.columns],
    how="inner",
)
for column in ["id", "chunk", "domain"]:
    if column in test_df.columns:
        test_df.drop(columns=[column], inplace=True)

# Save the processed data to fastText-format text files
train_df.to_csv(
    f"{config.fasttext_model_output_name}_train.txt", index=False, sep=" ", header=False
)
test_df.to_csv(
    f"{config.fasttext_model_output_name}_test.txt", index=False, sep=" ", header=False
)
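
# Note: fastText treats any whitespace-separated token that starts with
# "__label__" as a label, so the label landing in the last column of each
# line still works. pandas quotes the text field (it contains the space
# separator), and the quote characters simply become part of the tokens.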

# Train the FastText model
model = fasttext.train_supervised(
    input=f"{config.fasttext_model_output_name}_train.txt", wordNgrams=2
)

# Evaluate the model
result = model.test(f"{config.fasttext_model_output_name}_test.txt")
print(f"Number of samples: {result[0]}")
print(f"Precision@1: {result[1]}")
print(f"Recall@1: {result[2]}")

# Save the model
model.save_model(f"{config.fasttext_model_output_name}.bin")
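
Once saved, the model can score new text as a pretraining data filter. A minimal sketch, assuming the config above (the threshold and example text are illustrative, not part of this commit):

import fasttext

model = fasttext.load_model("sciq_piqa_target_fasttext.bin")

def keep(text, threshold=0.5):
    # fastText predicts on single lines, so strip newlines before predicting
    labels, probs = model.predict(text.replace("\n", " "))
    return labels[0] == "__label__include" and probs[0] >= threshold

print(keep("The mitochondria is the powerhouse of the cell."))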
