Skip to content

Commit

Permalink
add prediction mode, refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
markus-eberts committed Feb 4, 2021
1 parent dd34727 commit 5afcedf
Show file tree
Hide file tree
Showing 13 changed files with 509 additions and 374 deletions.
41 changes: 28 additions & 13 deletions args.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,13 @@ def _add_common_args(arg_parser):
help="If true, input is lowercased during preprocessing")
arg_parser.add_argument('--sampling_processes', type=int, default=4,
help="Number of sampling processes. 0 = no multiprocessing for sampling")
arg_parser.add_argument('--sampling_limit', type=int, default=100, help="Maximum number of sample batches in queue")

# Logging
arg_parser.add_argument('--label', type=str, help="Label of run. Used as the directory name of logs/models")
arg_parser.add_argument('--log_path', type=str, help="Path to directory where training/evaluation logs are stored")
arg_parser.add_argument('--store_predictions', action='store_true', default=False,
help="If true, store predictions on disc (in log directory)")
arg_parser.add_argument('--store_examples', action='store_true', default=False,
help="If true, store evaluation examples on disc (in log directory)")
arg_parser.add_argument('--example_count', type=int, default=None,
help="Count of evaluation examples to store (if store_examples == True)")
arg_parser.add_argument('--debug', action='store_true', default=False, help="Debugging mode on/off")

# Model / Training / Evaluation
arg_parser.add_argument('--model_path', type=str, help="Path to directory that contains model checkpoints")
arg_parser.add_argument('--model_type', type=str, default="spert", help="Type of model")
arg_parser.add_argument('--cpu', action='store_true', default=False,
help="If true, train/evaluate on CPU even if a CUDA device is available")
arg_parser.add_argument('--eval_batch_size', type=int, default=1, help="Evaluation batch size")
arg_parser.add_argument('--eval_batch_size', type=int, default=1, help="Evaluation/Prediction batch size")
arg_parser.add_argument('--max_pairs', type=int, default=1000,
help="Maximum entity pairs to process during training/evaluation")
arg_parser.add_argument('--rel_filter_threshold', type=float, default=0.4, help="Filter threshold for relations")
Expand All @@ -47,6 +35,18 @@ def _add_common_args(arg_parser):
arg_parser.add_argument('--seed', type=int, default=None, help="Seed")
arg_parser.add_argument('--cache_path', type=str, default=None,
help="Path to cache transformer models (for HuggingFace transformers library)")
arg_parser.add_argument('--debug', action='store_true', default=False, help="Debugging mode on/off")


def _add_logging_args(arg_parser):
arg_parser.add_argument('--label', type=str, help="Label of run. Used as the directory name of logs/models")
arg_parser.add_argument('--log_path', type=str, help="Path do directory where training/evaluation logs are stored")
arg_parser.add_argument('--store_predictions', action='store_true', default=False,
help="If true, store predictions on disc (in log directory)")
arg_parser.add_argument('--store_examples', action='store_true', default=False,
help="If true, store evaluation examples on disc (in log directory)")
arg_parser.add_argument('--example_count', type=int, default=None,
help="Count of evaluation example to store (if store_examples == True)")


def train_argparser():
Expand Down Expand Up @@ -80,6 +80,7 @@ def train_argparser():
arg_parser.add_argument('--max_grad_norm', type=float, default=1.0, help="Maximum gradient norm")

_add_common_args(arg_parser)
_add_logging_args(arg_parser)

return arg_parser

Expand All @@ -90,6 +91,20 @@ def eval_argparser():
# Input
arg_parser.add_argument('--dataset_path', type=str, help="Path to dataset")

_add_common_args(arg_parser)
_add_logging_args(arg_parser)

return arg_parser


def predict_argparser():
    """Build the command line parser for prediction mode.

    Prediction needs the input dataset, an output location for the
    predictions, and a SpaCy model for tokenization, plus the options
    shared by all modes.
    """
    arg_parser = argparse.ArgumentParser()

    # Input
    io_options = (
        ('--dataset_path', "Path to dataset"),
        ('--predictions_path', "Path to store predictions"),
        ('--spacy_model', "Label of SpaCy model (used for tokenization)"),
    )
    for flag, description in io_options:
        arg_parser.add_argument(flag, type=str, help=description)

    _add_common_args(arg_parser)

    return arg_parser
1 change: 0 additions & 1 deletion configs/example_eval.conf
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,5 @@ max_span_size = 10
store_predictions = true
store_examples = true
sampling_processes = 4
sampling_limit = 100
max_pairs = 1000
log_path = data/log/
15 changes: 15 additions & 0 deletions configs/example_prediction.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
[1]
model_type = spert
model_path = data/models/conll04
tokenizer_path = data/models/conll04
dataset_path = data/datasets/conll04/conll04_prediction_example.json
types_path = data/datasets/conll04/conll04_types.json
predictions_path = data/predictions.json
spacy_model = en_core_web_sm
eval_batch_size = 1
rel_filter_threshold = 0.4
size_embedding = 25
prop_drop = 0.1
max_span_size = 10
sampling_processes = 4
max_pairs = 1000
1 change: 0 additions & 1 deletion configs/example_train.conf
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ max_span_size = 10
store_predictions = true
store_examples = true
sampling_processes = 4
sampling_limit = 100
max_pairs = 1000
final_eval = true
log_path = data/log/
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ tensorboardX==1.6
torch==1.4.0
tqdm==4.55.1
transformers[sentencepiece]==4.1.1
scikit-learn==0.24.0
scikit-learn==0.24.0
spacy==3.0.1
29 changes: 21 additions & 8 deletions spert.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
import argparse

from args import train_argparser, eval_argparser
from args import train_argparser, eval_argparser, predict_argparser
from config_reader import process_configs
from spert import input_reader
from spert.spert_trainer import SpERTTrainer


def _train():
arg_parser = train_argparser()
process_configs(target=__train, arg_parser=arg_parser)


def __train(run_args):
trainer = SpERTTrainer(run_args)
trainer.train(train_path=run_args.train_path, valid_path=run_args.valid_path,
types_path=run_args.types_path, input_reader_cls=input_reader.JsonInputReader)


def _train():
arg_parser = train_argparser()
process_configs(target=__train, arg_parser=arg_parser)
def _eval():
arg_parser = eval_argparser()
process_configs(target=__eval, arg_parser=arg_parser)


def __eval(run_args):
Expand All @@ -23,9 +28,15 @@ def __eval(run_args):
input_reader_cls=input_reader.JsonInputReader)


def _eval():
arg_parser = eval_argparser()
process_configs(target=__eval, arg_parser=arg_parser)
def _predict():
arg_parser = predict_argparser()
process_configs(target=__predict, arg_parser=arg_parser)


def __predict(run_args):
trainer = SpERTTrainer(run_args)
trainer.predict(dataset_path=run_args.dataset_path, types_path=run_args.types_path,
input_reader_cls=input_reader.JsonPredictionInputReader)


if __name__ == '__main__':
Expand All @@ -37,5 +48,7 @@ def _eval():
_train()
elif args.mode == 'eval':
_eval()
elif args.mode == 'predict':
_predict()
else:
raise Exception("Mode not in ['train', 'eval'], e.g. 'python spert.py train ...'")
raise Exception("Mode not in ['train', 'eval', 'predict'], e.g. 'python spert.py train ...'")
Loading

0 comments on commit 5afcedf

Please sign in to comment.