Merge pull request #24 from anton-bushuiev/main
Fix testing for de novo, add loss only validation
anton-bushuiev authored Jun 4, 2024
2 parents a88e97b + 17765a6 commit 32572a7
Showing 4 changed files with 32 additions and 4 deletions.
19 changes: 18 additions & 1 deletion massspecgym/models/de_novo/base.py
@@ -20,6 +20,7 @@ def __init__(
         self,
         top_ks: T.Iterable[int] = (1, 10),
         myopic_mces_kwargs: T.Optional[T.Mapping] = None,
+        validate_only_loss: bool = False,
         *args,
         **kwargs
     ):
@@ -34,6 +35,7 @@ def __init__(
             solver_options=dict(msg=0)  # make ILP solver silent
         )
         self.myopic_mces_kwargs |= myopic_mces_kwargs or {}
+        self.validate_only_loss = validate_only_loss
         self.mol_pred_kind: T.Literal["smiles", "rdkit"] = "smiles"

     def on_batch_end(
@@ -56,7 +58,22 @@ def on_validation_batch_end(
         outputs: T.Any,
         batch: dict,
         batch_idx: int,
-        metric_pref: str = ''
+        metric_pref: str = 'val_'
     ) -> None:
         self.on_batch_end(outputs, batch, batch_idx, metric_pref)
+        if not self.validate_only_loss:
+            self.evaluate_de_novo_step(
+                outputs["mols_pred"],  # (bs, k) list of generated rdkit molecules or SMILES strings
+                batch["mol"],  # (bs) list of ground truth SMILES strings
+                metric_pref=metric_pref
+            )
+
+    def on_test_batch_end(
+        self,
+        outputs: T.Any,
+        batch: dict,
+        batch_idx: int,
+        metric_pref: str = 'test_'
+    ) -> None:
+        self.on_batch_end(outputs, batch, batch_idx, metric_pref)
         self.evaluate_de_novo_step(
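The net effect: with validate_only_loss=True, on_validation_batch_end logs only the loss and skips the MCES-based de novo metrics, while on_test_batch_end always runs the full evaluation. Below is a minimal, self-contained sketch of that control flow; DummyModel and its print stand-in are hypothetical, and only the flag and hook names mirror the diff.

class DummyModel:
    def __init__(self, validate_only_loss: bool = False):
        self.validate_only_loss = validate_only_loss

    def evaluate_de_novo_step(self, mols_pred, mols_true, metric_pref=""):
        # Stand-in for the expensive MCES-based evaluation above.
        print(f"{metric_pref}metrics over {len(mols_pred)} predictions")

    def on_validation_batch_end(self, outputs, batch, metric_pref="val_"):
        # Validation: loss is always logged; metrics are now optional.
        if not self.validate_only_loss:
            self.evaluate_de_novo_step(outputs["mols_pred"], batch["mol"], metric_pref)

    def on_test_batch_end(self, outputs, batch, metric_pref="test_"):
        # Testing: the full de novo evaluation always runs.
        self.evaluate_de_novo_step(outputs["mols_pred"], batch["mol"], metric_pref)

model = DummyModel(validate_only_loss=True)
model.on_validation_batch_end({"mols_pred": [["C", "CC"]]}, {"mol": ["C"]})  # prints nothing
model.on_test_batch_end({"mols_pred": [["C", "CC"]]}, {"mol": ["C"]})        # prints test_ metrics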
9 changes: 8 additions & 1 deletion massspecgym/models/de_novo/smiles_tranformer.py
@@ -24,8 +24,10 @@ def __init__(
         k_predictions: int = 1,
         temperature: T.Optional[float] = None,
         pre_norm=False,
+        *args,
+        **kwargs
     ):
-        super().__init__()
+        super().__init__(*args, **kwargs)
         self.smiles_tokenizer = smiles_tokenizer
         self.vocab_size = smiles_tokenizer.get_vocab_size()
         for token in [start_token, end_token, pad_token]:
@@ -107,6 +109,11 @@ def validation_step(self, batch: dict, batch_idx: torch.Tensor) -> tuple:
         outputs = self.step(batch)
         decoded_smiles = self.decode_smiles(batch["spec"].float())
         return dict(loss=outputs["loss"], mols_pred=decoded_smiles)
+
+    def test_step(self, batch: dict, batch_idx: torch.Tensor) -> tuple:
+        outputs = self.step(batch)
+        decoded_smiles = self.decode_smiles(batch["spec"].float())
+        return dict(loss=outputs["loss"], mols_pred=decoded_smiles)

     def generate_src_padding_mask(self, spec):
         return spec.sum(-1) == 0
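Defining test_step alongside validation_step is what makes trainer-driven testing work for the de novo task (the "Fix testing" part of this commit): both return the loss plus decoded SMILES, which the base-class hooks above consume. A hypothetical usage sketch, assuming the project's PyTorch Lightning setup; the model and datamodule construction are elided placeholders, not the project's actual API.

import pytorch_lightning as pl

trainer = pl.Trainer(accelerator="auto", devices=1, logger=False)
# model = SmilesTransformer(...)          # built as in scripts/train.py
# trainer.validate(model, datamodule=dm)  # logs val_loss (+ metrics unless validate_only_loss)
# trainer.test(model, datamodule=dm)      # now runs test_step on each batch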
3 changes: 2 additions & 1 deletion scripts/submit_simple_train_grid.py
@@ -24,6 +24,7 @@ def write_train_sh(p):
     --run_name="lr={p['lr']},bs={p['batch_size']},k={p['k_predictions']},d={p['d_model']},nhead={p['nhead']},nel={p['num_encoder_layers']}" \
     --task=de_novo \
     --model=smiles_transformer \
+    --validate_only_loss \
     --batch_size={p['batch_size']} \
     --lr={p['lr']} \
     --k_predictions={p['k_predictions']} \
@@ -67,7 +68,7 @@ def main():
     grid = {
         'lr': [3e-4, 1e-4, 5e-5],
         'batch_size': [64, 128],  # per GPU
-        'k_predictions': [1, 10],
+        'k_predictions': [1],
         'd_model': [256, 512],
         'nhead': [4, 8],
         'num_encoder_layers': [3, 6],
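For context, the grid above expands into one training job per parameter combination, so narrowing k_predictions from [1, 10] to [1] halves the sweep from 96 to 48 jobs. A minimal sketch of the expansion with itertools.product, where the print stands in for the script's write_train_sh call:

import itertools

grid = {
    'lr': [3e-4, 1e-4, 5e-5],
    'batch_size': [64, 128],  # per GPU
    'k_predictions': [1],
    'd_model': [256, 512],
    'nhead': [4, 8],
    'num_encoder_layers': [3, 6],
}

keys = list(grid)
for values in itertools.product(*grid.values()):
    p = dict(zip(keys, values))
    print(f"lr={p['lr']},bs={p['batch_size']},k={p['k_predictions']},"
          f"d={p['d_model']},nhead={p['nhead']},nel={p['num_encoder_layers']}")  # 48 lines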
5 changes: 4 additions & 1 deletion scripts/train.py
@@ -62,6 +62,8 @@

     # - De novo

+    parser.add_argument('--validate_only_loss', action='store_true')
+
     # 1. SmilesTransformer
     parser.add_argument('--input_dim', type=int, default=2)
     parser.add_argument('--d_model', type=int, default=512)
@@ -148,7 +150,8 @@ def main(args):
             smiles_tokenizer=args.smiles_tokenizer,
             k_predictions=args.k_predictions,
             pre_norm=args.pre_norm,
-            max_smiles_len=args.max_smiles_len
+            max_smiles_len=args.max_smiles_len,
+            validate_only_loss=args.validate_only_loss
         )
     else:
         raise NotImplementedError(f"Model {args.model} not implemented.")
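Since the new flag uses action='store_true', it defaults to False and flips to True only when passed, matching validate_only_loss's default in the model constructor. A quick self-contained check of that parsing behavior (standard argparse semantics, not project code):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--validate_only_loss', action='store_true')

assert parser.parse_args([]).validate_only_loss is False                        # flag absent
assert parser.parse_args(['--validate_only_loss']).validate_only_loss is True   # flag present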
