Skip to content

Commit

Permalink
Merge pull request multi-ego#523 from brunostega/new_input
Browse files Browse the repository at this point in the history
Code cleanup, improved input reading, and fixes to the intra-domain tool and flag handling. Tested on lysozyme (2 domains).
  • Loading branch information
carlocamilloni authored Jan 5, 2025
2 parents 1b6ed26 + 11fd3f0 commit 3d484c1
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 52 deletions.
8 changes: 4 additions & 4 deletions src/multiego/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
"corresponding to the subfolders to process and where the contacts are learned.",
},
"--input_refs": {
"type": dict,
"default": {},
"type": list,
"default": [],
"help": "A list of the training simulations to be included in multi-eGO, "
"corresponding to the subfolders to process and where the contacts are learned.",
},
Expand Down Expand Up @@ -112,8 +112,8 @@
"production: creates a force-field combining random coil simulations and training simulations.",
},
"--input_refs": {
"type": dict,
"default": {},
"type": list,
"default": [],
"help": "A list of the training simulations to be included in multi-eGO, "
"corresponding to the subfolders to process and where the contacts are learned.",
},
Expand Down
4 changes: 2 additions & 2 deletions src/multiego/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def initialize_molecular_contacts(contact_matrix, prior_matrix, args, reference)
"""

# remove un-learned contacts (intra-inter domain)
contact_matrix["learned"] = prior_matrix["rc_learned"]
contact_matrix["learned"] = prior_matrix["rc_learned"].to_numpy()
contact_matrix["reference"] = reference["reference"]
# calculate adaptive rc/md threshold
# sort probabilities, and calculate the normalized cumulative distribution
Expand Down Expand Up @@ -1033,7 +1033,7 @@ def set_sig_epsilon(meGO_LJ, parameters):
meGO_LJ.loc[(meGO_LJ["epsilon"] < 0.0), "sigma"] = (-meGO_LJ["epsilon"]) ** (1.0 / 12.0)

# add a flag to identify learned contacts
meGO_LJ.loc[:, "learned"] = 1
# meGO_LJ.loc[:, "learned"] = 1

return meGO_LJ

Expand Down
48 changes: 13 additions & 35 deletions src/multiego/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,40 +11,26 @@


def read_arguments(args, args_dict, args_dict_global, args_dict_single_reference):
# TODO UGLY, make it nicer
new_input = False

if args.config:

config_yaml, new_input = read_config(args.config, args_dict)
config_yaml = read_config(args.config, args_dict)
# check if yaml file is empty
if not config_yaml:
print("WARNING: Configuration file was parsed, but the dictionary is empty")
elif new_input:
else:
args = combine_configurations(config_yaml, args, args_dict_global)
args = read_new_input(args, args_dict_single_reference)
else:
args = combine_configurations(config_yaml, args, args_dict)
args = convert_command_line_to_new_input(args, args_dict_single_reference)

# if not new_input convert command line to input_ref dictionary format
# if the config does not exist, convert the command line to the input_refs dictionary format
else:
if not new_input:
if args.egos == "production" and not args.reference:
args.reference = ["reference"]
if args.input_refs:
raise ValueError("ERROR: input_refs should be used only with a configuration file")

args = convert_command_line_to_new_input(args, args_dict_single_reference)
if new_input and (args.reference or args.train or args.epsilon):
raise ValueError(
"""--reference, --train and --epsilon should not be used with input_refs. Either use the firsts or use the second via yml config file. e.g.:
--reference ref --train training --epsilon 0.3\n
or in config file use:
- input_refs:
- reference: ref
train: training
matrix: intramat_1_1
epsilon: 0.2
"""
)
if args.egos == "production" and not args.reference:
args.reference = ["reference"]

args = convert_command_line_to_new_input(args, args_dict_single_reference)

return args

Expand All @@ -63,7 +49,6 @@ def read_config(file, args_dict):
args_dict : dict
The content of the YAML file as a dictionary
"""
new_input = False
with open(file, "r") as f:
yml = yaml.safe_load(f)
# check if the keys in the yaml file are valid
Expand All @@ -74,22 +59,19 @@ def read_config(file, args_dict):
key = list(element.keys())[0]
if f"--{key}" not in args_dict:
raise ValueError(f"ERROR: {key} in {file} is not a valid argument.")
if key == "input_refs":
new_input = True
print("\n\n New input format detected. \n\n")
return yml, new_input
return yml


def read_new_input(args, args_dict_single_input):
"""
converts new inputs structure into the correct dataframe structure and combines it with the other args
Checks the input_ref dictionary has the correct keys, and combines it with the non-specified default arguments for each reference.
Parameters
----------
yml : dict
The configuration from the YAML file
args : dict
The command-line arguments
The command-line arguments with default values
Returns
-------
Expand Down Expand Up @@ -158,7 +140,6 @@ def convert_command_line_to_new_input(args, args_dict_single_input):
dict_input_ref[appo].update({var: getattr(args, var)})
appo += 1
args.input_refs = dict_input_ref

return args


Expand All @@ -184,9 +165,6 @@ def combine_configurations(yml, args, args_dict):
for element in yml:
if type(element) is dict:
key, value = list(element.items())[0]
if key == "input_refs":
setattr(args, key, value)
continue
value = args_dict[f"--{key}"]["type"](value)
parse_key = f"--{key}"
default_value = args_dict[parse_key]["default"] if "default" in args_dict[parse_key] else None
Expand Down
33 changes: 22 additions & 11 deletions tools/domain_sectioner/domains.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@ def dom_range(ranges_str):
doms = [(int(r.split("-")[0]), int(r.split("-")[1])) for r in ranges_str]

if not all([x[0] <= x[1] for x in doms]):
raise ValueError("Elements in each range should be non-decreasing e.g. dom_res 1-10 11-20 ...")
print("WARNING: Elements in each range should be non-decreasing e.g. dom_res 1-10 11-20 ...")

if not all([x1[1] < x2[0] for x1, x2 in zip(doms[:-1], doms[1:])]):
raise ValueError("Ranges should not overlap e.g. dom_res 1-10 11-20 ...")
print("WARNING: Ranges should not overlap e.g. dom_res 1-10 11-20 ...")

return doms

Expand Down Expand Up @@ -140,18 +140,29 @@ def read_topologies(top):
raise ValueError(f"ERROR: number of atoms in intramat ({dim}) does not correspond to that of topology ({n_atoms})")

# define domain mask
domain_mask = np.full(dim, False)
domain_mask_linear = np.full(dim**2, False)
for r in ranges:
start = find_atom_start(topology_mego, r[0])
end = find_atom_end(topology_mego, r[1])
print(f" Domain range: {r[0]}-{r[1]}")
print(f" Atom index range start-end: {start+1} - {end+1}")
print(f" Number of atoms in domain range: {end+1 - (start)}")
print(f" Atom and Residue of start-end {topology_mego.atoms[start]} - {topology_mego.atoms[end]}")
print("\n")
map_appo = np.array([True if x >= start and x <= end else False for x in range(dim)])
domain_mask = np.logical_or(domain_mask, map_appo)
domain_mask_linear = (domain_mask * domain_mask[:, np.newaxis]).reshape(dim**2)
if start >= end:
appo_end = end
end = start
start = appo_end
print(f" Domain range: {r[0]}-{r[1]} INVERTED")
print(f" Atom index range start-end: {start+1} - {end+1}")
print(f" Number of atoms in domain range: {end+1 - (start)}")
print(f" Atom and Residue of start-end {topology_mego.atoms[start]} - {topology_mego.atoms[end]}")
print("\n")
map_appo = np.invert(np.array([True if x >= start and x <= end else False for x in range(dim)]))
else:
print(f" Domain range: {r[0]}-{r[1]}")
print(f" Atom index range start-end: {start+1} - {end+1}")
print(f" Number of atoms in domain range: {end+1 - (start)}")
print(f" Atom and Residue of start-end {topology_mego.atoms[start]} - {topology_mego.atoms[end]}")
print("\n")
map_appo = np.array([True if x >= start and x <= end else False for x in range(dim)])
domain_mask_linear = np.logical_or(domain_mask_linear, (map_appo * map_appo[:, np.newaxis]).reshape(dim**2))

if args.invert:
domain_mask_linear = np.logical_not(domain_mask_linear)
print(domain_mask_linear)
Expand Down

0 comments on commit 3d484c1

Please sign in to comment.