Skip to content
This repository has been archived by the owner on Jan 12, 2024. It is now read-only.

Commit

Permalink
Went back to old setting with only roberta models
Browse files Browse the repository at this point in the history
  • Loading branch information
davidguzmanr committed Nov 30, 2022
1 parent c9ae719 commit bad27b9
Show file tree
Hide file tree
Showing 17 changed files with 6,021 additions and 5,200 deletions.
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,12 @@ The attacks we will use are listed in the table below; the metrics were obtained

## Setup

| **Domain** | **Substitute model** | **Target model** |
|:-----------------------------------:|:-----------------------------------------:|:---------------------------------------:|
| **similar_domain_same_task** | textattack/roberta-base-imdb | textattack/roberta-base-rotten-tomatoes |
| **similar_domain_different_task** | cardiffnlp/bertweet-base-irony | cardiffnlp/bertweet-base-offensive |
| **different_domain_same_task** | cardiffnlp/twitter-roberta-base-sentiment | textattack/roberta-base-rotten-tomatoes |
| **different_domain_different_task** | cardiffnlp/twitter-roberta-base-irony | cardiffnlp/twitter-roberta-base-irony |
| **Domain** | **Substitute model** | **Target model** |
|:-----------------------------------:|:-------------------------------------------------:|:-----------------------------------------------:|
| **similar_domain_same_task** | textattack/roberta-base-imdb | textattack/roberta-base-rotten-tomatoes |
| **similar_domain_different_task** | cardiffnlp/twitter-roberta-base-irony | cardiffnlp/twitter-roberta-base-sentiment |
| **different_domain_same_task** | cardiffnlp/twitter-roberta-base-sentiment | textattack/roberta-base-rotten-tomatoes |
| **different_domain_different_task** | cardiffnlp/twitter-roberta-base-irony | cardiffnlp/twitter-roberta-base-sentiment |

## Attacks

Expand Down
29 changes: 6 additions & 23 deletions evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,6 @@
from os.path import isfile, join
import pandas as pd

from textattack.attack_recipes import (
A2TYoo2021,
BAEGarg2019,
DeepWordBugGao2018,
PWWSRen2019,
TextBuggerLi2018,
TextFoolerJin2019,
)

from src.models.attack_model import AttackModel
from src.models.target_model import TargetModel

DOMAINS = {
Expand All @@ -28,12 +18,12 @@
# David
"similar_domain_different_task": {
"attack_model": "MDL_TWIT_IRONY", # output_labels: 2
"target_model": "MDL_TWIT_OFFENSIVE", # output_labels: 3
"target_model": "MDL_TWIT_SENTIMENT", # output_labels: 3
"target_dataset": "tweet_eval"
},
# Jean
"different_domain_same_task": {
"attack_model": "MDL_TWIT_OFFENSIVE", # output_labels: 3
"attack_model": "MDL_TWIT_SENTIMENT", # output_labels: 3
"target_model": "MDL_RT_SENTIMENT", # output_labels: 2
"target_dataset": "rotten_tomatoes"
},
Expand All @@ -45,15 +35,6 @@
}
}

ATTACKS_RECIPES = [
A2TYoo2021,
BAEGarg2019,
DeepWordBugGao2018,
PWWSRen2019,
TextBuggerLi2018,
TextFoolerJin2019
]

def main():
results = pd.DataFrame()
for domain_name in tqdm(DOMAINS):
Expand All @@ -64,18 +45,20 @@ def main():
use_cuda=True
)

path = f"logs/{domain_name}"
path = f"logs/attacks/{domain_name}"
logs = sorted([file for file in listdir(path) if isfile(join(path, file))])


new_row = {'Domain': domain_name}
for log_csv in logs:
original_accuracy, perturbed_accuracy = target_model.evaluate_attack(join(path, log_csv))
original_accuracy, perturbed_accuracy = target_model.evaluate_attack(join(path, log_csv), save_csv=True)

new_row['Original accuracy'] = original_accuracy
new_row[f'{log_csv.split("-")[0]} accuracy'] = perturbed_accuracy

results = results.append(new_row, ignore_index=True)

return results

if __name__ == "__main__":
main()

This file was deleted.

Large diffs are not rendered by default.

This file was deleted.

Large diffs are not rendered by default.

This file was deleted.

Large diffs are not rendered by default.

This file was deleted.

Large diffs are not rendered by default.

This file was deleted.

Large diffs are not rendered by default.

This file was deleted.

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion run.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
# David
"similar_domain_different_task": {
"attack_model": "MDL_TWIT_IRONY", # output_labels: 2
"target_model": "MDL_TWIT_OFFENSIVE", # output_labels: 3
"target_model": "MDL_TWIT_SENTIMENT", # output_labels: 3
"target_dataset": "tweet_eval"
},
# Jean
Expand Down
4 changes: 1 addition & 3 deletions src/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,8 @@ class MODEL_PATHS(ExtendedEnum):
MDL_RT_SENTIMENT = 'textattack/roberta-base-rotten-tomatoes' # Output: LABEL_0 (negative) LABEL_1 (positive)

# TWITTER DOMAIN
MDL_TWIT_IRONY = 'cardiffnlp/bertweet-base-irony' # Output: LABEL_0 (non-ironic) LABEL_1 (ironic)
MDL_TWIT_OFFENSIVE = 'cardiffnlp/bertweet-base-offensive' # Output: LABEL_0 (Not-offensive), LABEL_1 (Offensive)
MDL_TWIT_IRONY = 'cardiffnlp/twitter-roberta-base-irony' # Output: LABEL_0 (non-ironic) LABEL_1 (ironic)
MDL_TWIT_SENTIMENT = 'cardiffnlp/twitter-roberta-base-sentiment' # Output: LABEL_0 (Negative), LABEL_1 (Neutral), LABEL_2 (Positive)
# MDL_TWIT_IRONY = 'cardiffnlp/twitter-roberta-base-irony' # Output: LABEL_0 (non-ironic) LABEL_1 (ironic)


# Datasets:
Expand Down
2 changes: 1 addition & 1 deletion src/models/attack_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(
self.attack_recipe = attack_recipe
self.attack = attack_recipe.build(self.model_wrapped)
# For the Twitter models we need the "sentiment" subset of the dataset
self.subset = "offensive" if target_dataset == "tweet_eval" else None
self.subset = "sentiment" if target_dataset == "tweet_eval" else None
self.target_dataset = self.set_target_dataset(target_dataset)
self.attack_dataset = self.set_attack_dataset()

Expand Down

0 comments on commit bad27b9

Please sign in to comment.