Release code for v2.3.0
PiperOrigin-RevId: 494507694
Augustin-Zidek committed Dec 11, 2022
1 parent 4494af8 commit 9b18d6a
Showing 30 changed files with 894 additions and 498 deletions.
242 changes: 130 additions & 112 deletions README.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion afdb/README.md
@@ -304,9 +304,9 @@ fractionPlddtVeryHigh | `FLOAT64` | Fraction of the residues in the predi
fractionPlddtVeryLow | `FLOAT64` | Fraction of the residues in the prediction with pLDDT less than 50
gene | `STRING` | The name of the gene if known, e.g. "COII"
geneSynonyms | `ARRAY<STRING>` | Additional synonyms for the gene
-globalMetricValue | `FLOAT64` | The mean pLDDT of this prediction
isReferenceProteome | `BOOL` | Is this protein part of the reference proteome?
isReviewed | `BOOL` | Has this protein been reviewed, i.e. is it part of SwissProt?
+globalMetricValue | `FLOAT64` | The mean pLDDT of this prediction
latestVersion | `INT64` | The latest AFDB version for this prediction
modelCreatedDate | `DATE` | The date of creation for this entry, e.g. "2022-06-01"
organismCommonNames | `ARRAY<STRING>` | List of common organism names
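Note: the schema change above only moves the `globalMetricValue` row within the table. For readers who work with this metadata, a hedged sketch of querying the documented columns follows; the fully qualified table name is taken from the AFDB BigQuery documentation rather than from this hunk, so treat it as an assumption.

```python
# Sketch: query the AFDB metadata columns documented above.
# Assumption: the public table is `bigquery-public-data.deepmind_alphafold.metadata`.
from google.cloud import bigquery

client = bigquery.Client()
query = """
    SELECT gene, globalMetricValue, isReviewed
    FROM `bigquery-public-data.deepmind_alphafold.metadata`
    WHERE isReferenceProteome AND globalMetricValue > 90
    LIMIT 10
"""
for row in client.query(query).result():
    print(row.gene, row.globalMetricValue, row.isReviewed)
```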
14 changes: 7 additions & 7 deletions alphafold/data/pipeline.py
@@ -117,7 +117,7 @@ def __init__(self,
                uniref90_database_path: str,
                mgnify_database_path: str,
                bfd_database_path: Optional[str],
-               uniclust30_database_path: Optional[str],
+               uniref30_database_path: Optional[str],
                small_bfd_database_path: Optional[str],
                template_searcher: TemplateSearcher,
                template_featurizer: templates.TemplateHitFeaturizer,
@@ -135,9 +135,9 @@ def __init__(self,
           binary_path=jackhmmer_binary_path,
           database_path=small_bfd_database_path)
     else:
-      self.hhblits_bfd_uniclust_runner = hhblits.HHBlits(
+      self.hhblits_bfd_uniref_runner = hhblits.HHBlits(
           binary_path=hhblits_binary_path,
-          databases=[bfd_database_path, uniclust30_database_path])
+          databases=[bfd_database_path, uniref30_database_path])
     self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(
         binary_path=jackhmmer_binary_path,
         database_path=mgnify_database_path)
@@ -211,14 +211,14 @@ def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict:
           use_precomputed_msas=self.use_precomputed_msas)
       bfd_msa = parsers.parse_stockholm(jackhmmer_small_bfd_result['sto'])
     else:
-      bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniclust_hits.a3m')
-      hhblits_bfd_uniclust_result = run_msa_tool(
-          msa_runner=self.hhblits_bfd_uniclust_runner,
+      bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniref_hits.a3m')
+      hhblits_bfd_uniref_result = run_msa_tool(
+          msa_runner=self.hhblits_bfd_uniref_runner,
           input_fasta_path=input_fasta_path,
           msa_out_path=bfd_out_path,
           msa_format='a3m',
           use_precomputed_msas=self.use_precomputed_msas)
-      bfd_msa = parsers.parse_a3m(hhblits_bfd_uniclust_result['a3m'])
+      bfd_msa = parsers.parse_a3m(hhblits_bfd_uniref_result['a3m'])

templates_result = self.template_featurizer.get_templates(
query_sequence=input_sequence,
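Note: the only API change in this file is the rename of `uniclust30_database_path` to `uniref30_database_path` (UniRef30 replaces UniClust30 as the HHBlits companion database). A minimal construction sketch follows; the paths are placeholders, and `use_small_bfd`, `template_searcher`, and `template_featurizer` are assumed to keep their pre-existing signatures since they are not shown in full here.

```python
from alphafold.data import pipeline

# All paths are placeholders; only the renamed argument is the point.
data_pipeline = pipeline.DataPipeline(
    jackhmmer_binary_path='/usr/bin/jackhmmer',
    hhblits_binary_path='/usr/bin/hhblits',
    uniref90_database_path='/data/uniref90/uniref90.fasta',
    mgnify_database_path='/data/mgnify/mgy_clusters.fa',
    bfd_database_path='/data/bfd/bfd_metaclust_clu_complete',   # placeholder
    uniref30_database_path='/data/uniref30/UniRef30_2021_03',   # was uniclust30_database_path
    small_bfd_database_path=None,
    use_small_bfd=False,                      # assumed unchanged from v2.2
    template_searcher=template_searcher,      # constructed elsewhere
    template_featurizer=template_featurizer,  # constructed elsewhere
    use_precomputed_msas=False)
features = data_pipeline.process(
    input_fasta_path='query.fasta', msa_output_dir='msas')
```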
61 changes: 61 additions & 0 deletions alphafold/model/common_modules.py
@@ -128,3 +128,64 @@ def __call__(self, inputs):
 
   return output
 
 
+class LayerNorm(hk.LayerNorm):
+  """LayerNorm module.
+
+  Equivalent to hk.LayerNorm but with different parameter shapes: they are
+  always vectors rather than possibly higher-rank tensors. This makes it
+  easier to change the layout whilst keeping the model weight-compatible.
+  """
+
+  def __init__(self,
+               axis,
+               create_scale: bool,
+               create_offset: bool,
+               eps: float = 1e-5,
+               scale_init=None,
+               offset_init=None,
+               use_fast_variance: bool = False,
+               name=None,
+               param_axis=None):
+    super().__init__(
+        axis=axis,
+        create_scale=False,
+        create_offset=False,
+        eps=eps,
+        scale_init=None,
+        offset_init=None,
+        use_fast_variance=use_fast_variance,
+        name=name,
+        param_axis=param_axis)
+    self._temp_create_scale = create_scale
+    self._temp_create_offset = create_offset
+
+  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
+    is_bf16 = (x.dtype == jnp.bfloat16)
+    if is_bf16:
+      x = x.astype(jnp.float32)
+
+    param_axis = self.param_axis[0] if self.param_axis else -1
+    param_shape = (x.shape[param_axis],)
+
+    param_broadcast_shape = [1] * x.ndim
+    param_broadcast_shape[param_axis] = x.shape[param_axis]
+    scale = None
+    offset = None
+    if self._temp_create_scale:
+      scale = hk.get_parameter(
+          'scale', param_shape, x.dtype, init=self.scale_init)
+      scale = scale.reshape(param_broadcast_shape)
+
+    if self._temp_create_offset:
+      offset = hk.get_parameter(
+          'offset', param_shape, x.dtype, init=self.offset_init)
+      offset = offset.reshape(param_broadcast_shape)
+
+    out = super().__call__(x, scale=scale, offset=offset)
+
+    if is_bf16:
+      out = out.astype(jnp.bfloat16)
+
+    return out

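Note: a short usage sketch of the new `LayerNorm`, showing the two behaviours it adds on top of `hk.LayerNorm`: bfloat16 inputs are normalised in float32 and cast back, and the scale/offset parameters are created as plain vectors. The module name and shapes here are illustrative only.

```python
import haiku as hk
import jax
import jax.numpy as jnp

from alphafold.model import common_modules


def forward(x):
  ln = common_modules.LayerNorm(
      axis=[-1], create_scale=True, create_offset=True, name='example_ln')
  return ln(x)


f = hk.without_apply_rng(hk.transform(forward))
x = jnp.ones((4, 8), dtype=jnp.bfloat16)
params = f.init(jax.random.PRNGKey(0), x)
out = f.apply(params, x)
print(out.dtype)  # bfloat16 -- normalisation ran in float32, then cast back
# Parameters are vectors of shape (8,), not higher-rank broadcast tensors.
print(jax.tree_util.tree_map(lambda p: p.shape, params))
```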
88 changes: 64 additions & 24 deletions alphafold/model/config.py
@@ -26,12 +26,12 @@
 def model_config(name: str) -> ml_collections.ConfigDict:
   """Get the ConfigDict of a CASP14 model."""
 
-  if 'multimer' in name:
-    return CONFIG_MULTIMER
-
   if name not in CONFIG_DIFFS:
     raise ValueError(f'Invalid model name {name}.')
-  cfg = copy.deepcopy(CONFIG)
+  if 'multimer' in name:
+    cfg = copy.deepcopy(CONFIG_MULTIMER)
+  else:
+    cfg = copy.deepcopy(CONFIG)
   cfg.update_from_flattened_dict(CONFIG_DIFFS[name])
   return cfg

@@ -52,11 +52,11 @@ def model_config(name: str) -> ml_collections.ConfigDict:
         'model_5_ptm',
     ),
     'multimer': (
-        'model_1_multimer_v2',
-        'model_2_multimer_v2',
-        'model_3_multimer_v2',
-        'model_4_multimer_v2',
-        'model_5_multimer_v2',
+        'model_1_multimer_v3',
+        'model_2_multimer_v3',
+        'model_3_multimer_v3',
+        'model_4_multimer_v3',
+        'model_5_multimer_v3',
     ),
}
MODEL_PRESETS['monomer_casp14'] = MODEL_PRESETS['monomer']
@@ -118,8 +118,32 @@ def model_config(name: str) -> ml_collections.ConfigDict:
     },
     'model_5_ptm': {
         'model.heads.predicted_aligned_error.weight': 0.1
-    }
+    },
+    'model_1_multimer_v3': {},
+    'model_2_multimer_v3': {},
+    'model_3_multimer_v3': {},
+    'model_4_multimer_v3': {
+        'model.embeddings_and_evoformer.num_extra_msa': 1152
+    },
+    'model_5_multimer_v3': {
+        'model.embeddings_and_evoformer.num_extra_msa': 1152
+    },
 }
+# Key differences between multimer v1/v2 and v3, mostly due to numerical
+# optimisations in the TriangleMultiplication module.
+common_updates = {
+    'model.embeddings_and_evoformer.num_msa': 252,
+    'model.embeddings_and_evoformer.num_extra_msa': 1152,
+    'model.embeddings_and_evoformer.evoformer.triangle_multiplication_incoming.fuse_projection_weights': False,
+    'model.embeddings_and_evoformer.evoformer.triangle_multiplication_outgoing.fuse_projection_weights': False,
+    'model.embeddings_and_evoformer.template.template_pair_stack.triangle_multiplication_incoming.fuse_projection_weights': False,
+    'model.embeddings_and_evoformer.template.template_pair_stack.triangle_multiplication_outgoing.fuse_projection_weights': False,
+}
+CONFIG_DIFFS.update(
+    {f'model_{i}_multimer': common_updates for i in range(1, 6)})
+CONFIG_DIFFS.update(
+    {f'model_{i}_multimer_v2': common_updates for i in range(1, 6)})

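Note: the per-model entries above are flattened key paths; `model_config` applies them through `ml_collections`' `update_from_flattened_dict`, which the function already uses. A toy illustration (the nested dict here is a stand-in, not the full config):

```python
import ml_collections

cfg = ml_collections.ConfigDict(
    {'model': {'embeddings_and_evoformer': {'num_extra_msa': 2048}}})
cfg.update_from_flattened_dict(
    {'model.embeddings_and_evoformer.num_extra_msa': 1152})
print(cfg.model.embeddings_and_evoformer.num_extra_msa)  # 1152
```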

CONFIG = ml_collections.ConfigDict({
'data': {
@@ -260,14 +284,16 @@ def model_config(name: str) -> ml_collections.ConfigDict:
                     'equation': 'ikc,jkc->ijc',
                     'num_intermediate_channel': 128,
                     'orientation': 'per_row',
-                    'shared_dropout': True
+                    'shared_dropout': True,
+                    'fuse_projection_weights': False,
                 },
                 'triangle_multiplication_incoming': {
                     'dropout_rate': 0.25,
                     'equation': 'kjc,kic->ijc',
                     'num_intermediate_channel': 128,
                     'orientation': 'per_row',
-                    'shared_dropout': True
+                    'shared_dropout': True,
+                    'fuse_projection_weights': False,
                 },
'pair_transition': {
'dropout_rate': 0.0,
@@ -328,14 +354,16 @@ def model_config(name: str) -> ml_collections.ConfigDict:
                         'equation': 'ikc,jkc->ijc',
                         'num_intermediate_channel': 64,
                         'orientation': 'per_row',
-                        'shared_dropout': True
+                        'shared_dropout': True,
+                        'fuse_projection_weights': False,
                     },
                     'triangle_multiplication_incoming': {
                         'dropout_rate': 0.25,
                         'equation': 'kjc,kic->ijc',
                         'num_intermediate_channel': 64,
                         'orientation': 'per_row',
-                        'shared_dropout': True
+                        'shared_dropout': True,
+                        'fuse_projection_weights': False,
                     },
'pair_transition': {
'dropout_rate': 0.0,
@@ -354,7 +382,7 @@ def model_config(name: str) -> ml_collections.ConfigDict:
             'multimer_mode': False,
             'subbatch_size': 4,
             'use_remat': False,
-            'zero_init': True
+            'zero_init': True,
         },
'heads': {
'distogram': {
@@ -483,27 +511,29 @@ def model_config(name: str) -> ml_collections.ConfigDict:
                     'gating': True,
                     'num_head': 4,
                     'orientation': 'per_row',
-                    'shared_dropout': True
+                    'shared_dropout': True,
                 },
                 'triangle_multiplication_incoming': {
                     'dropout_rate': 0.25,
                     'equation': 'kjc,kic->ijc',
                     'num_intermediate_channel': 128,
                     'orientation': 'per_row',
-                    'shared_dropout': True
+                    'shared_dropout': True,
+                    'fuse_projection_weights': True,
                 },
                 'triangle_multiplication_outgoing': {
                     'dropout_rate': 0.25,
                     'equation': 'ikc,jkc->ijc',
                     'num_intermediate_channel': 128,
                     'orientation': 'per_row',
-                    'shared_dropout': True
+                    'shared_dropout': True,
+                    'fuse_projection_weights': True,
                 }
             },
'extra_msa_channel': 64,
'extra_msa_stack_num_block': 4,
-            'num_msa': 252,
-            'num_extra_msa': 1152,
+            'num_msa': 508,
+            'num_extra_msa': 2048,
'masked_msa': {
'profile_prob': 0.1,
'replace_fraction': 0.15,
@@ -564,24 +594,28 @@ def model_config(name: str) -> ml_collections.ConfigDict:
                         'equation': 'kjc,kic->ijc',
                         'num_intermediate_channel': 64,
                         'orientation': 'per_row',
-                        'shared_dropout': True
+                        'shared_dropout': True,
+                        'fuse_projection_weights': True,
                     },
                     'triangle_multiplication_outgoing': {
                         'dropout_rate': 0.25,
                         'equation': 'ikc,jkc->ijc',
                         'num_intermediate_channel': 64,
                         'orientation': 'per_row',
-                        'shared_dropout': True
+                        'shared_dropout': True,
+                        'fuse_projection_weights': True,
                     }
}
},
},
         'global_config': {
+            'bfloat16': True,
+            'bfloat16_output': False,
             'deterministic': False,
             'multimer_mode': True,
             'subbatch_size': 4,
             'use_remat': False,
-            'zero_init': True
+            'zero_init': True,
         },
'heads': {
'distogram': {
@@ -651,7 +685,13 @@ def model_config(name: str) -> ml_collections.ConfigDict:
             }
         },
         'num_ensemble_eval': 1,
-        'num_recycle': 3,
+        'num_recycle': 20,
+        # A negative value indicates that no early stopping will occur, i.e.
+        # the model will always run `num_recycle` number of recycling
+        # iterations. A positive value will enable early stopping if the
+        # difference in pairwise distances is less than the tolerance between
+        # recycling steps.
+        'recycle_early_stop_tolerance': 0.5,
         'resample_msa_in_recycling': True
     }
})
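Note: putting the config changes together, a quick sanity check of the new presets. The expected values in the comments are read off the diff above, not re-verified against the released weights.

```python
from alphafold.model import config

cfg = config.model_config('model_1_multimer_v3')
print(cfg.model.num_recycle)                       # 20 (was 3)
print(cfg.model.recycle_early_stop_tolerance)      # 0.5
print(cfg.model.embeddings_and_evoformer.num_msa)  # 508

# The v1/v2 multimer names remain valid: they resolve to the same base config
# with the `common_updates` compatibility overrides applied.
cfg_v2 = config.model_config('model_1_multimer_v2')
print(cfg_v2.model.embeddings_and_evoformer.num_msa)  # 252
```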
8 changes: 4 additions & 4 deletions alphafold/model/folding.py
@@ -331,7 +331,7 @@ def safe_dropout_fn(tensor, safe_key):
     safe_key, *sub_keys = safe_key.split(3)
     sub_keys = iter(sub_keys)
     act = safe_dropout_fn(act, next(sub_keys))
-    act = hk.LayerNorm(
+    act = common_modules.LayerNorm(
         axis=[-1],
         create_scale=True,
         create_offset=True,
@@ -353,7 +353,7 @@ def safe_dropout_fn(tensor, safe_key):
     act = jax.nn.relu(act)
     act += input_act
     act = safe_dropout_fn(act, next(sub_keys))
-    act = hk.LayerNorm(
+    act = common_modules.LayerNorm(
         axis=[-1],
         create_scale=True,
         create_offset=True,
@@ -410,7 +410,7 @@ def generate_affines(representations, batch, config, global_config,
   c = config
   sequence_mask = batch['seq_mask'][:, None]
 
-  act = hk.LayerNorm(
+  act = common_modules.LayerNorm(
       axis=[-1],
       create_scale=True,
       create_offset=True,
@@ -433,7 +433,7 @@ def generate_affines(representations, batch, config, global_config,
       'affine': affine.to_tensor(),
   }
 
-  act_2d = hk.LayerNorm(
+  act_2d = common_modules.LayerNorm(
       axis=[-1],
       create_scale=True,
       create_offset=True,
8 changes: 4 additions & 4 deletions alphafold/model/folding_multimer.py
@@ -427,7 +427,7 @@ def safe_dropout_fn(tensor, safe_key):
     safe_key, *sub_keys = safe_key.split(3)
     sub_keys = iter(sub_keys)
     act = safe_dropout_fn(act, next(sub_keys))
-    act = hk.LayerNorm(
+    act = common_modules.LayerNorm(
         axis=-1,
         create_scale=True,
         create_offset=True,
@@ -448,7 +448,7 @@ def safe_dropout_fn(tensor, safe_key):
     act = jax.nn.relu(act)
     act += input_act
     act = safe_dropout_fn(act, next(sub_keys))
-    act = hk.LayerNorm(
+    act = common_modules.LayerNorm(
         axis=-1,
         create_scale=True,
         create_offset=True,
@@ -500,7 +500,7 @@ def generate_monomer_rigids(representations: Mapping[str, jnp.ndarray],
"""
c = config
sequence_mask = batch['seq_mask'][:, None]
act = hk.LayerNorm(
act = common_modules.LayerNorm(
axis=-1, create_scale=True, create_offset=True, name='single_layer_norm')(
representations['single'])

@@ -523,7 +523,7 @@ def generate_monomer_rigids(representations: Mapping[str, jnp.ndarray],
           rigid
   }
 
-  act_2d = hk.LayerNorm(
+  act_2d = common_modules.LayerNorm(
       axis=-1,
       create_scale=True,
       create_offset=True,
