rename files and fix pretraining data collection
kuronosec committed Mar 25, 2024
1 parent 273262a commit f3b442d
Showing 9 changed files with 704 additions and 196 deletions.
@@ -35,6 +35,8 @@
"import math\n",
"import torch\n",
"import warnings\n",
"import pandas as pd\n",
"\n",
"from tqdm.notebook import tqdm\n",
"from ml_things import plot_dict, fix_text\n",
"from transformers import (\n",
@@ -57,6 +59,7 @@
" Trainer,\n",
" set_seed,\n",
" )\n",
"from datasets import load_dataset\n",
"\n",
"# Supress deprecation warnings\n",
"warnings.filterwarnings('ignore', category=DeprecationWarning)\n",
@@ -249,9 +252,9 @@
"source": [
"# Define arguments for data, tokenizer and model arguments.\n",
"model_data_args = ModelDataArguments(\n",
" train_data_file='/data/forta/ethereum/text/pretraining/pretraining_train.txt',\n",
" eval_data_file='/data/forta/ethereum/text/pretraining/pretraining_val.txt',\n",
" line_by_line=True, \n",
" train_data_file='/data/forta/ethereum/text/pretraining/small_pretraining_train.txt',\n",
" eval_data_file='/data/forta/ethereum/text/pretraining/small_pretraining_val.txt',\n",
" line_by_line=False,\n",
" mlm=False,\n",
" whole_word_mask=False,\n",
" plm_probability=float(1/6),\n",
@@ -275,9 +278,9 @@
" do_train=True, \n",
" do_eval=True,\n",
" per_device_train_batch_size=10,\n",
" per_device_eval_batch_size=100,\n",
" per_device_eval_batch_size=10,\n",
" evaluation_strategy='steps',\n",
" logging_steps=700,\n",
" logging_steps=500,\n",
" eval_steps = None,\n",
" prediction_loss_only=True,\n",
" learning_rate = 5e-5,\n",
@@ -463,6 +466,8 @@
"loss_history = {'train_loss':[], 'eval_loss':[]}\n",
"perplexity_history = {'train_perplexity':[], 'eval_perplexity':[]}\n",
"\n",
"print(trainer.state.log_history)\n",
"\n",
"for log_history in trainer.state.log_history:\n",
" if 'loss' in log_history.keys():\n",
" loss_history['train_loss'].append(log_history['loss'])\n",
@@ -499,12 +504,102 @@
"outputs": [],
"source": [
"if training_args.do_eval:\n",
" eval_output = trainer.evaluate()\n",
" perplexity = math.exp(eval_output[\"eval_loss\"])\n",
" print('\\nEvaluate Perplexity: {:10,.2f}'.format(perplexity))\n",
" eval_output = trainer.evaluate()\n",
" print(eval_output[\"eval_loss\"])\n",
" perplexity = math.exp(eval_output[\"eval_loss\"])\n",
" print('\\nEvaluate Perplexity: {:10,.2f}'.format(perplexity))\n",
"else:\n",
" print('No evaluation needed. No evaluation data provided, `do_eval=False`!')"
" print('No evaluation needed. No evaluation data provided, `do_eval=False`!')"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def extract_normal_sequences(data, output_file):\n",
+ " encodings = tokenizer(\"\\n\\n\".join(data), return_tensors=\"pt\")\n",
+ " max_length = model.config.n_positions\n",
+ " stride = 512\n",
+ " seq_len = encodings.input_ids.size(1)\n",
+ " \n",
+ " prev_end_loc = 0\n",
+ " normal_sequences = []\n",
+ " for begin_loc in tqdm(range(0, seq_len, stride)):\n",
+ " end_loc = min(begin_loc + max_length, seq_len)\n",
+ " trg_len = end_loc - prev_end_loc # may differ from stride on the last loop (unused in this extractor)\n",
+ " input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)\n",
+ " normal_sequences.append(input_ids[0].tolist())\n",
+ " prev_end_loc = end_loc\n",
+ " if end_loc == seq_len:\n",
+ " break\n",
+ " normal_data = pd.DataFrame(normal_sequences)\n",
+ " normal_data.to_csv(output_file, sep='\\t', index=False)\n",
+ "\n",
+ "def extract_anomalous_sequences(data, output_file):\n",
+ " encodings = tokenizer(\"\\n\\n\".join(data), return_tensors=\"pt\")\n",
+ " max_length = model.config.n_positions\n",
+ " stride = 512\n",
+ " seq_len = encodings.input_ids.size(1)\n",
+ " anomaly_threshold = 1.2\n",
+ " \n",
+ " nlls = []\n",
+ " prev_end_loc = 0\n",
+ " anomalous_sequences = []\n",
+ " for begin_loc in tqdm(range(0, seq_len, stride)):\n",
+ " end_loc = min(begin_loc + max_length, seq_len)\n",
+ " trg_len = end_loc - prev_end_loc # may be different from stride on last loop\n",
+ " input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)\n",
+ " target_ids = input_ids.clone()\n",
+ " target_ids[:, :-trg_len] = -100\n",
+ " \n",
+ " with torch.no_grad():\n",
+ " outputs = model(input_ids, labels=target_ids)\n",
+ " \n",
+ " # loss is calculated using CrossEntropyLoss, which averages over valid labels\n",
+ " # N.B. the model only calculates loss over trg_len - 1 labels, because it internally shifts the labels\n",
+ " # to the left by 1.\n",
+ " neg_log_likelihood = outputs.loss\n",
+ " local_ppl = torch.exp(neg_log_likelihood)\n",
+ " if local_ppl > anomaly_threshold:\n",
+ " # print(\"local_ppl:\"+str(local_ppl))\n",
+ " # print(\"input_ids.shape:\"+str(input_ids.shape))\n",
+ " # anomalous_sequences.append(tokenizer.decode(input_ids[0]))\n",
+ " anomalous_sequences.append(input_ids[0].tolist())\n",
+ " \n",
+ " nlls.append(neg_log_likelihood)\n",
+ " \n",
+ " prev_end_loc = end_loc\n",
+ " if end_loc == seq_len:\n",
+ " break\n",
+ " \n",
+ " ppl = torch.exp(torch.stack(nlls).mean()) # corpus-level perplexity; computed for reference, not written out\n",
+ " anomalous_data = pd.DataFrame(anomalous_sequences)\n",
+ " anomalous_data.to_csv(output_file, sep='\\t', index=False)\n",
+ "\n",
+ "# Load normal SC opcode files\n",
+ "training = load_dataset(\"text\", data_files={\"train\": \"/data/forta/ethereum/text/finetuning/training/normal/normal.txt\",\n",
+ " \"val\": \"/data/forta/ethereum/text/finetuning/validation/normal/normal.txt\"})\n",
+ "\n",
+ "# Load malicious SC opcode files\n",
+ "test = load_dataset(\"text\", data_files={\"train\": \"/data/forta/ethereum/text/finetuning/training/malicious/malicious.txt\",\n",
+ " \"val\": \"/data/forta/ethereum/text/finetuning/validation/malicious/malicious.txt\"})\n",
+ "# Extract normal SC opcode encodings\n",
+ "extract_normal_sequences(training[\"train\"][\"text\"], \"/data/forta/ethereum/text/finetuning/normal_training.csv\")\n",
+ "extract_normal_sequences(training[\"val\"][\"text\"], \"/data/forta/ethereum/text/finetuning/normal_validation.csv\")\n",
+ "\n",
+ "# Extract malicious SC opcode encodings\n",
+ "extract_anomalous_sequences(test[\"train\"][\"text\"], \"/data/forta/ethereum/text/finetuning/anomalous_training.csv\")\n",
+ "extract_anomalous_sequences(test[\"val\"][\"text\"], \"/data/forta/ethereum/text/finetuning/anomalous_validation.csv\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
],
"metadata": {
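The hunk at @@ -463,6 +466,8 @@ walks trainer.state.log_history to build loss and perplexity histories. A minimal sketch of that loop, assuming the usual Hugging Face log shape in which training steps carry a 'loss' key and evaluation steps an 'eval_loss' key (the toy values below are invented for illustration):

    import math

    # Illustrative shape of trainer.state.log_history after training with
    # evaluation_strategy='steps' (extra keys vary by transformers version).
    log_history = [
        {"loss": 1.83, "step": 500},
        {"eval_loss": 1.65, "step": 500},
        {"loss": 1.41, "step": 1000},
        {"eval_loss": 1.38, "step": 1000},
    ]

    loss_history = {"train_loss": [], "eval_loss": []}
    perplexity_history = {"train_perplexity": [], "eval_perplexity": []}

    for entry in log_history:
        if "loss" in entry:
            loss_history["train_loss"].append(entry["loss"])
            perplexity_history["train_perplexity"].append(math.exp(entry["loss"]))
        elif "eval_loss" in entry:
            loss_history["eval_loss"].append(entry["eval_loss"])
            perplexity_history["eval_perplexity"].append(math.exp(entry["eval_loss"]))

    print(perplexity_history)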
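Both extractors added in the final hunk slice one long token stream into windows of up to model.config.n_positions tokens, advanced by a stride of 512, so consecutive windows overlap. A self-contained sketch of that window arithmetic with toy sizes (the committed code takes max_length from the model config):

    # Windows of up to max_length tokens, advanced by stride, so consecutive
    # windows overlap by max_length - stride tokens.
    seq_len, max_length, stride = 2000, 1024, 512

    prev_end_loc = 0
    for begin_loc in range(0, seq_len, stride):
        end_loc = min(begin_loc + max_length, seq_len)
        trg_len = end_loc - prev_end_loc  # tokens new to this window; < stride on the last loop
        print(f"window [{begin_loc}:{end_loc}], {trg_len} fresh tokens")
        prev_end_loc = end_loc
        if end_loc == seq_len:
            break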

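The N.B. comment in extract_anomalous_sequences is worth unpacking: labeling a prefix of target_ids with -100 excludes those positions from the loss, and the model shifts labels left by one internally before averaging cross-entropy. A sketch that reproduces outputs.loss by hand, using the public gpt2 checkpoint as a stand-in for the notebook's model and tokenizer:

    import torch
    from torch.nn import functional as F
    from transformers import GPT2LMHeadModel, GPT2TokenizerFast

    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")  # stand-in for the notebook's tokenizer
    model = GPT2LMHeadModel.from_pretrained("gpt2")        # stand-in for the notebook's model

    enc = tokenizer("PUSH1 0x60 PUSH1 0x40 MSTORE", return_tensors="pt")
    input_ids = enc.input_ids
    target_ids = input_ids.clone()
    target_ids[:, :2] = -100  # positions labeled -100 are ignored by the loss

    with torch.no_grad():
        outputs = model(input_ids, labels=target_ids)

    # Reproduce outputs.loss by hand: shift logits/labels by one position, then
    # average cross-entropy over the remaining (non-masked) labels.
    logits = outputs.logits[:, :-1, :]
    labels = target_ids[:, 1:]
    manual = F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                             labels.reshape(-1), ignore_index=-100)
    print(outputs.loss.item(), manual.item())  # the two values agree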
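Each output file written by the extractors holds one tokenized window per row, tab-separated; because the windows have unequal lengths, pandas pads shorter rows with empty fields when the ragged list becomes a DataFrame. A sketch of reading one file back and decoding a row (the path matches the commit; tokenizer refers to the one defined earlier in the notebook):

    import pandas as pd

    frame = pd.read_csv("/data/forta/ethereum/text/finetuning/normal_training.csv", sep="\t")

    # Drop the NaN padding before casting token ids back to int for decoding.
    row = frame.iloc[0].dropna().astype(int).tolist()
    print(tokenizer.decode(row))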