Skip to content

Commit

Permalink
better check for consistent molecule names
Browse files Browse the repository at this point in the history
  • Loading branch information
carlocamilloni committed May 23, 2024
1 parent 5257fe1 commit 163a76a
Showing 1 changed file with 27 additions and 9 deletions.
36 changes: 27 additions & 9 deletions src/multiego/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,9 +260,9 @@ def initialize_molecular_contacts(contact_matrix, path, ensemble_molecules_idx_s
# for intra-domain
if args.multi_epsilon is not None:
temp_epsi_intra = args.multi_epsilon[contact_matrix["molecule_idx_ai_temp"].to_numpy(dtype=int)[0] - 1]
contact_matrix.loc[
(contact_matrix["same_chain"]) & (contact_matrix["intra_domain"]), "epsilon_0"
] = temp_epsi_intra
contact_matrix.loc[(contact_matrix["same_chain"]) & (contact_matrix["intra_domain"]), "epsilon_0"] = (
temp_epsi_intra
)
if name[0] == "intramat":
print(f" -Intra-domain epsilon {temp_epsi_intra}")
else:
Expand All @@ -272,9 +272,9 @@ def initialize_molecular_contacts(contact_matrix, path, ensemble_molecules_idx_s
temp_epsi_inter_dom = args.multi_epsilon_inter_domain[
contact_matrix["molecule_idx_ai_temp"].to_numpy(dtype=int)[0] - 1
]
contact_matrix.loc[
(contact_matrix["same_chain"]) & (~contact_matrix["intra_domain"]), "epsilon_0"
] = temp_epsi_inter_dom
contact_matrix.loc[(contact_matrix["same_chain"]) & (~contact_matrix["intra_domain"]), "epsilon_0"] = (
temp_epsi_inter_dom
)
if name[0] == "intramat":
print(f" -Inter-domain epsilon {temp_epsi_inter_dom}")
else:
Expand All @@ -296,9 +296,9 @@ def initialize_molecular_contacts(contact_matrix, path, ensemble_molecules_idx_s
if name[0] == "intramat":
print(f" -Intra-domain epsilon {args.epsilon}")
# for inter-domain
contact_matrix.loc[
(contact_matrix["same_chain"]) & (~contact_matrix["intra_domain"]), "epsilon_0"
] = args.inter_domain_epsilon
contact_matrix.loc[(contact_matrix["same_chain"]) & (~contact_matrix["intra_domain"]), "epsilon_0"] = (
args.inter_domain_epsilon
)
if name[0] == "intramat":
print(f" -Inter-domain epsilon {args.inter_domain_epsilon}")
# for inter-molecular
Expand Down Expand Up @@ -488,6 +488,7 @@ def init_meGO_ensemble(args):
return ensemble

reference_set = set(ensemble["topology_dataframe"]["name"].to_list())
unique_ref_molecule_names = topology_dataframe["molecule_name"].unique()

# now we process the train contact matrices
train_contact_matrices = {}
Expand All @@ -513,6 +514,10 @@ def init_meGO_ensemble(args):
_,
_,
) = initialize_topology(topology, custom_dict, args)
# check that the molecules defined have a reference
unique_temp_molecule_names = temp_topology_dataframe["molecule_name"].unique()
check_molecule_names(unique_ref_molecule_names, unique_temp_molecule_names)

train_topology_dataframe = pd.concat(
[train_topology_dataframe, temp_topology_dataframe],
axis=0,
Expand Down Expand Up @@ -584,6 +589,9 @@ def init_meGO_ensemble(args):
_,
_,
) = initialize_topology(topology, custom_dict, args)
# check that the molecules defined have a reference
unique_temp_molecule_names = temp_topology_dataframe["molecule_name"].unique()
check_molecule_names(unique_ref_molecule_names, unique_temp_molecule_names)
check_topology_dataframe = pd.concat(
[check_topology_dataframe, temp_topology_dataframe],
axis=0,
Expand Down Expand Up @@ -1323,6 +1331,16 @@ def do_apply_check_rules(meGO_ensemble, meGO_LJ, check_dataset, symmetries, para
return meGO_LJ


def check_molecule_names(ref_list, tmp_list, dataset_label="train"):
    """
    Verify that every molecule name in ``tmp_list`` is also present in ``ref_list``.

    Parameters
    ----------
    ref_list : container of str
        Molecule names defined by the reference dataset. Only needs to support
        the ``in`` operator (callers pass the result of ``.unique()`` on a
        ``molecule_name`` column).
    tmp_list : iterable of str
        Molecule names defined by the dataset being validated.
    dataset_label : str, optional
        Word describing the validated dataset in the error message (for
        example ``"train"`` or ``"check"``). The default, ``"train"``, keeps the
        message identical to the one emitted before this parameter existed.

    Raises
    ------
    ValueError
        If at least one name in ``tmp_list`` is not found in ``ref_list``. The
        message lists the missing names in the order they appear in ``tmp_list``.
    """
    # Membership is tested directly against ref_list (no set conversion) so the
    # comparison semantics stay exactly those of the container the caller passed.
    missing_molecules = [molecule for molecule in tmp_list if molecule not in ref_list]

    if missing_molecules:
        raise ValueError(
            f"The following molecule(s) from a {dataset_label} dataset {missing_molecules} "
            f"are not found in the reference dataset {ref_list}"
        )


def consistency_checks(meGO_LJ):
"""
Perform consistency checks on LJ parameters.
Expand Down

0 comments on commit 163a76a

Please sign in to comment.