-
Notifications
You must be signed in to change notification settings - Fork 60
/
Copy pathidealize_backbone.py
executable file
·72 lines (62 loc) · 2.05 KB
/
idealize_backbone.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
import glob
from icecream import ic
from tqdm import tqdm
import fire
import aa_model
import torch
import rf2aa.util
def get_ligands(pdb_lines):
    """Return the set of residue names that appear in HETATM records.

    Args:
        pdb_lines: iterable of PDB-format text lines (e.g. an open file handle).

    Returns:
        Set of residue-name strings sliced from ``line[17:21]`` and stripped.
        Every HETATM residue name is included, e.g. HOH if waters are stored
        as HETATM records.
    """
    # Match on the record name (columns 1-6) rather than anywhere in the line, so
    # text such as a REMARK that merely mentions "HETATM" is not parsed as a ligand.
    return {line[17:21].strip() for line in pdb_lines if line.startswith('HETATM')}
def rewrite(path, outpath):
    """Idealize the protein backbone of one protein-ligand PDB file and write it out.

    The file's HETATM records must carry exactly one distinct residue name (the
    ligand). HETATM waters such as HOH count toward that total. For every row
    of ``indep`` that is not flagged ``is_sm``, the first four atoms are
    replaced as follows:
    - atoms 0-2 come from ``rf2aa.util.idealize_reference_frame``;
    - atom 3 is placed by ``get_o``.
    Rows flagged ``is_sm`` are left unchanged.

    Args:
        path: input PDB file.
        outpath: destination path, written via ``indep.write_pdb``.

    Raises:
        ValueError: if the file does not name exactly one ligand.
    """
    with open(path, 'r') as fh:
        # get_ligands only looks at HETATM records, so the raw line iterator suffices.
        ligands = get_ligands(fh)
    # Check up front so a bad file fails before the expensive featurization below.
    if len(ligands) != 1:
        raise ValueError(
            f'Expected exactly 1 ligand in {path}, found {len(ligands)}: {sorted(ligands)}'
        )
    (lig_name,) = ligands

    indep = aa_model.make_indep(path, lig_name, center=False)
    # Presumably is_sm marks small-molecule (ligand) rows -- confirm in aa_model.
    not_sm = ~indep.is_sm
    xyz = indep.xyz[not_sm]
    idx = indep.idx[not_sm]

    # All-zeros sequence (token 0; presumably alanine in rf2aa's alphabet, as the
    # original name `ala_seq` suggests). Only atoms 0-3 of the result are kept below.
    ala_seq = torch.zeros((xyz.shape[0],))
    xyz = rf2aa.util.idealize_reference_frame(ala_seq[None], xyz[None])[0]
    indep.xyz[not_sm, :4] = get_o(xyz, idx)

    indep.write_pdb(outpath, lig_name=lig_name)
def get_o(xyz, idx):
    """Return backbone coordinates with an idealized oxygen placed on every residue.

    Atom 3 of each residue is obtained by mapping a fixed local offset through a
    rigid frame built by ``rf2aa.util.rigid_from_3_points``:

    * residue bonded to its successor (``idx[i + 1] == idx[i] + 1``): frame from
      ``(xyz[i, 1], xyz[i, 2], xyz[i + 1, 0])``;
    * otherwise (sequence gap, or the last residue): frame from the residue's own
      ``(xyz[i, 0], xyz[i, 1], xyz[i, 2])``.

    Args:
        xyz: coordinates indexed as ``xyz[residue, atom, :]``. Only atoms 0-2 are
            read (presumably N, CA, C -- confirm against rf2aa's atom order).
        idx: 1-D residue numbering of length L, used to detect sequence gaps.

    Returns:
        Tensor of shape (L, 4, 3) with the dtype and device of ``xyz``. Atoms 0-2
        are copied from ``xyz``; atom 3 is the placed oxygen.
    """
    L = xyz.shape[0]

    # is_adj[i]: residue i is sequence-adjacent to residue i + 1. The last residue has
    # no successor, so it always uses the own-frame placement.
    is_adj = torch.zeros(L, dtype=torch.bool, device=xyz.device)
    is_adj[:-1] = (idx[1:] - idx[:-1]).to(xyz.device) == 1

    # Follow xyz's dtype/device so the einsum below never mixes dtypes or devices.
    xyz_ideal = xyz.new_zeros((L, 4, 3))
    xyz_ideal[:, :3] = xyz[:, :3]

    placements = (
        # No bonded successor: offset expressed in the residue's own atom 0/1/2 frame.
        (
            (xyz[:, 0, :], xyz[:, 1, :], xyz[:, 2, :]),
            ~is_adj,
            xyz.new_tensor([2.1428, 0.7350, -0.7413]),
        ),
        # Bonded to residue i + 1: frame from atoms 1, 2 of i and atom 0 of i + 1.
        # These frames have length L - 1, which is safe because is_adj[L - 1] is False.
        (
            (xyz[:-1, 1, :], xyz[:-1, 2, :], xyz[1:, 0, :]),
            is_adj,
            xyz.new_tensor([-0.7247, -1.0032, -0.0003]),
        ),
    )
    for frames, mask, ideal_pos in placements:
        sel = torch.nonzero(mask)[:, 0]
        Rs, Ts = rf2aa.util.rigid_from_3_points(frames[0], frames[1], frames[2])
        Rs = Rs[sel]
        Ts = Ts[sel]
        # Map the local offset into global coordinates: R @ ideal_pos + T.
        xyz_ideal[sel, 3] = torch.einsum('lij,j->li', Rs, ideal_pos) + Ts
    return xyz_ideal
def main(pattern, outdir):
    """Idealize the backbone of every PDB file matching ``pattern``.

    Args:
        pattern: glob pattern expanded with ``glob.glob``. Quote it on the command
            line so the shell does not expand it first.
        outdir: output directory, created if missing. Each output file keeps
            its input's file name.
    """
    # rewrite() writes into outdir and nothing visible here creates it, so do it up front.
    os.makedirs(outdir, exist_ok=True)
    for pdb in tqdm(sorted(glob.glob(pattern))):
        name = os.path.basename(pdb)
        outpath = os.path.join(outdir, name)
        ic(name, pdb, outpath)
        rewrite(pdb, outpath)


if __name__ == '__main__':
    # python-fire exposes main's parameters as command-line arguments.
    fire.Fire(main)