Skip to content

Commit

Permalink
BertModel output dictionary access, specify cache path for model, req…
Browse files Browse the repository at this point in the history
…uirements update
  • Loading branch information
markus-eberts committed Jan 7, 2021
1 parent 0032c8f commit bf04a22
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 12 deletions.
2 changes: 1 addition & 1 deletion config_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

def process_configs(target, arg_parser):
args, _ = arg_parser.parse_known_args()
ctx = mp.get_context('fork')
ctx = mp.get_context('spawn')

for run_args, _run_config, _run_repeat in _yield_configs(arg_parser, args):
p = ctx.Process(target=target, args=(run_args,))
Expand Down
7 changes: 3 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
Jinja2==2.10.3
numpy==1.17.4
tensorboardX==1.6
torch==1.3.1
torchvision==0.4.2
torch==1.4.0
tqdm==4.19.5
transformers==2.2.0
scikit-learn==0.21.3
transformers==4.1.1
scikit-learn==0.24.0
6 changes: 3 additions & 3 deletions spert/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,9 +399,9 @@ def _score(self, gt: List[List[Tuple]], pred: List[List[Tuple]], print_results:

def _compute_metrics(self, gt_all, pred_all, types, print_results: bool = False):
labels = [t.index for t in types]
per_type = prfs(gt_all, pred_all, labels=labels, average=None)
micro = prfs(gt_all, pred_all, labels=labels, average='micro')[:-1]
macro = prfs(gt_all, pred_all, labels=labels, average='macro')[:-1]
per_type = prfs(gt_all, pred_all, labels=labels, average=None, zero_division=0)
micro = prfs(gt_all, pred_all, labels=labels, average='micro', zero_division=0)[:-1]
macro = prfs(gt_all, pred_all, labels=labels, average='macro', zero_division=0)[:-1]
total_support = sum(per_type[-1])

if print_results:
Expand Down
4 changes: 2 additions & 2 deletions spert/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def _forward_train(self, encodings: torch.tensor, context_masks: torch.tensor, e
entity_sizes: torch.tensor, relations: torch.tensor, rel_masks: torch.tensor):
# get contextualized token embeddings from last transformer layer
context_masks = context_masks.float()
h = self.bert(input_ids=encodings, attention_mask=context_masks)[0]
h = self.bert(input_ids=encodings, attention_mask=context_masks)['last_hidden_state']

batch_size = encodings.shape[0]

Expand All @@ -85,7 +85,7 @@ def _forward_eval(self, encodings: torch.tensor, context_masks: torch.tensor, en
entity_sizes: torch.tensor, entity_spans: torch.tensor, entity_sample_masks: torch.tensor):
# get contextualized token embeddings from last transformer layer
context_masks = context_masks.float()
h = self.bert(input_ids=encodings, attention_mask=context_masks)[0]
h = self.bert(input_ids=encodings, attention_mask=context_masks)['last_hidden_state']

batch_size = encodings.shape[0]
ctx_size = context_masks.shape[-1]
Expand Down
6 changes: 4 additions & 2 deletions spert/spert_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ def train(self, train_path: str, valid_path: str, types_path: str, input_reader_
max_pairs=self.args.max_pairs,
prop_drop=self.args.prop_drop,
size_embedding=self.args.size_embedding,
freeze_transformer=self.args.freeze_transformer)
freeze_transformer=self.args.freeze_transformer,
cache_dir=self.args.cache_path)

# SpERT is currently optimized on a single GPU and not thoroughly tested in a multi GPU setup
# If you still want to train SpERT on multiple GPUs, uncomment the following lines
Expand Down Expand Up @@ -161,7 +162,8 @@ def eval(self, dataset_path: str, types_path: str, input_reader_cls: BaseInputRe
max_pairs=self.args.max_pairs,
prop_drop=self.args.prop_drop,
size_embedding=self.args.size_embedding,
freeze_transformer=self.args.freeze_transformer)
freeze_transformer=self.args.freeze_transformer,
cache_dir=self.args.cache_path)

model.to(self._device)

Expand Down

0 comments on commit bf04a22

Please sign in to comment.