-
Notifications
You must be signed in to change notification settings - Fork 0
/
app.py
192 lines (166 loc) · 6.51 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
from transformers import AutoTokenizer, EsmForProteinFolding
from transformers.models.esm.openfold_utils.protein import to_pdb, Protein as OFProtein
from transformers.models.esm.openfold_utils.feats import atom14_to_atom37
from Bio import SeqIO
import gradio as gr
from gradio_molecule3d import Molecule3D
# Representation settings handed to the Molecule3D output component
# (passed as ``reps=reps`` where the single-protein interface is built).
# A single entry: model 0, "stick" style, "whiteCarbon" colour; the chain,
# residue-name and residue-range fields are left as empty strings.
reps = [
    dict(
        model=0,
        chain="",
        resname="",
        style="stick",
        color="whiteCarbon",
        residue_range="",
        around=0,
        byres=False,
        visible=False,
    ),
]
def read_mol(molpath):
    """Return the entire text content of the file at ``molpath``.

    Args:
        molpath: Path of the text file to read (used here for PDB files).

    Returns:
        The file content as one string, line endings included. An empty
        file yields an empty string.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    # One read() yields the same text as joining readlines(), without the
    # repeated string concatenation in a Python-level loop.
    with open(molpath, "r") as fp:
        return fp.read()
def molecule(input_pdb):
    """Build an ``<iframe>`` snippet that shows a PDB file in a 3Dmol.js viewer.

    The file is read with ``read_mol`` and its text is placed inside a
    JavaScript template literal in a standalone HTML page. That page is then
    embedded as the ``srcdoc`` attribute of the returned iframe.

    Args:
        input_pdb: Path of the PDB file to display.

    Returns:
        A string containing the complete ``<iframe>`` element.
    """
    pdb_text = read_mol(input_pdb)

    # NOTE(review): pdb_text is inserted verbatim into a backtick-delimited JS
    # string and then into a single-quoted HTML attribute; a backtick or a
    # single quote inside the file would break the generated markup.

    # Page markup up to (and including) the opening backtick of the JS string.
    page_head = """<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html; charset=UTF-8" />
<style>
body{
font-family:sans-serif
}
.mol-container {
width: 100%;
height: 600px;
position: relative;
}
.mol-container select{
background-image:None;
}
</style>
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js" integrity="sha512-STof4xm1wgkfm7heWqFJVn58Hm3EtS31XFaagaa8VMReCXAkQnJZ+jEy8PCC/iT18dFy95WcExNHFTqLyp72eQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
<script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
</head>
<body>
<div id="container" class="mol-container"></div>
<script>
let pdb = `"""

    # Closing backtick plus the viewer set-up script and the end of the page.
    page_tail = """`
$(document).ready(function () {
let element = $("#container");
let config = { backgroundColor: "white" };
let viewer = $3Dmol.createViewer(element, config);
viewer.addModel(pdb, "pdb");
viewer.getModel(0).setStyle({}, { cartoon: { colorscheme:"whiteCarbon" } });
viewer.zoomTo();
viewer.render();
viewer.zoom(0.8, 2000);
})
</script>
</body></html>"""

    page = "".join((page_head, pdb_text, page_tail))

    return f"""<iframe style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera;
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc='{page}'></iframe>"""
def convert_outputs_to_pdb(outputs):
    """Convert a folding-model output mapping into PDB-formatted strings.

    Args:
        outputs: Mapping of model outputs. The keys read here are
            ``"positions"``, ``"atom37_atom_exists"``, ``"aatype"``,
            ``"residue_index"``, ``"plddt"`` and, optionally,
            ``"chain_index"``. Every value must support ``.to("cpu").numpy()``.

    Returns:
        A list with one PDB string per entry along the first axis of
        ``outputs["aatype"]``.
    """
    # atom14_to_atom37 is given the original (non-NumPy) outputs, so it has
    # to run before anything is converted below.
    # NOTE(review): [-1] presumably selects the final refinement step - confirm.
    atom37_positions = atom14_to_atom37(outputs["positions"][-1], outputs)
    atom37_positions = atom37_positions.cpu().numpy()

    arrays = {key: value.to("cpu").numpy() for key, value in outputs.items()}
    atom_masks = arrays["atom37_atom_exists"]
    has_chain_index = "chain_index" in arrays

    return [
        to_pdb(
            OFProtein(
                aatype=arrays["aatype"][idx],
                atom_positions=atom37_positions[idx],
                atom_mask=atom_masks[idx],
                # Residue indices are shifted by one before being written out.
                residue_index=arrays["residue_index"][idx] + 1,
                # The pLDDT values are stored in the B-factor field.
                b_factors=arrays["plddt"][idx],
                chain_index=arrays["chain_index"][idx] if has_chain_index else None,
            )
        )
        for idx in range(arrays["aatype"].shape[0])
    ]
# Module-level model set-up: everything below runs once, at import time.
# Load the tokenizer and the folding model from the "facebook/esmfold_v1" checkpoint.
tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", low_cpu_mem_usage=True)
# Move the whole model to the GPU; the folding functions below also move their inputs with .cuda().
model = model.cuda()
# Only the `esm` submodule is cast to half precision; the rest of the model is left as loaded.
model.esm = model.esm.half()
# NOTE: torch is imported here, mid-file, rather than with the other imports at the top.
import torch
# Allow TF32 for CUDA matrix multiplications.
torch.backends.cuda.matmul.allow_tf32 = True
# NOTE(review): chunk size of the folding trunk - presumably trades speed for lower
# memory use; confirm against the transformers ESMFold documentation.
model.trunk.set_chunk_size(64)
def fold_protein(test_protein):
    """Fold one protein sequence and return its viewer HTML and PDB file path.

    The tokenise / predict / write / render pipeline is exactly the one in
    ``fold_protein_wpdb``, so this function delegates to it with a fixed
    output file instead of repeating those steps.

    Args:
        test_protein: Amino-acid sequence as a string.

    Returns:
        A ``(html, pdb_path)`` tuple: the iframe markup for the predicted
        structure and the path of the PDB file that was written
        (always ``"output_structure.pdb"``, relative to the working
        directory). The two values feed the two outputs of the
        single-protein interface.
    """
    # Fixed file name: each call overwrites the previous prediction.
    pdb_path = "output_structure.pdb"
    html = fold_protein_wpdb(test_protein, pdb_path)
    return html, pdb_path
def fold_protein_wpdb(test_protein, pdb_path):
    """Fold one protein sequence, save it to ``pdb_path`` and return viewer HTML.

    Args:
        test_protein: Amino-acid sequence as a string.
        pdb_path: Path the predicted structure is written to (overwritten
            if it already exists).

    Returns:
        The iframe markup produced by ``molecule`` for the written file.
    """
    # Tokenise without special tokens and move the ids to the GPU, where the
    # module-level model lives.
    encoded = tokenizer([test_protein], return_tensors="pt", add_special_tokens=False)
    input_ids = encoded['input_ids'].cuda()

    # Inference only: no gradients are needed.
    with torch.no_grad():
        prediction = model(input_ids)

    pdb_records = convert_outputs_to_pdb(prediction)
    with open(pdb_path, "w") as handle:
        handle.write("".join(pdb_records))

    return molecule(pdb_path)
def load_protein_sequences(fasta_file):
    """Read a FASTA file into a mapping of record id to sequence string.

    Args:
        fasta_file: FASTA source handed straight to ``SeqIO.parse``.

    Returns:
        A dict keyed by each record's ``id`` with ``str(record.seq)`` as the
        value. If an id occurs more than once, the last record wins.
    """
    return {entry.id: str(entry.seq) for entry in SeqIO.parse(fasta_file, "fasta")}
def show_split(inputfile):
    """Fold every sequence of a FASTA file and return one combined HTML block.

    Each record is folded with ``fold_protein_wpdb``; its structure is saved
    to ``<sanitised id>.pdb`` in the working directory and its viewer markup
    is placed under a heading that shows the record id.

    The record ids come from an uploaded file, so they are treated as
    untrusted: they are HTML-escaped before being put into the heading and
    reduced to a safe character set before being used as a file name.

    Args:
        inputfile: FASTA source, passed on to ``load_protein_sequences``.

    Returns:
        A ``<div>`` string holding one heading + viewer section per record,
        separated by ``<br>`` lines.
    """
    import re
    from html import escape

    sequences = load_protein_sequences(inputfile)
    sections = []
    for seq_id, sequence in sequences.items():
        # Keep only characters that are safe in a file name, so an id that
        # contains path separators cannot point outside the working directory.
        safe_name = re.sub(r"[^A-Za-z0-9._-]", "_", seq_id).lstrip(".") or "sequence"
        pdb_path = f"{safe_name}.pdb"
        viewer_html = fold_protein_wpdb(sequence, pdb_path)
        # The leading ">" mirrors the FASTA header marker; the id itself is
        # escaped so that markup inside it is shown as text, not interpreted.
        heading = f"""<h3>>{escape(seq_id)}</h3>
<br>
"""
        sections.append(heading + viewer_html)
    final = "\n<br>\n".join(sections)
    return "<div>\n" + final + "\n</div>"
# First tab: fold a single sequence typed or pasted by the user.
# `fold_protein` returns (html, pdb_path), which map onto the two outputs
# below: the HTML viewer and the Molecule3D component.
iface = gr.Interface(
    title="SingleProteinviz",
    fn=fold_protein,
    inputs=gr.Textbox(
        label="Protein Sequence",
        info="Find sequences examples below, and complete examples with images at: https://github.com/AstraBert/proteinviz/tree/main/examples.md; if you input a sequence, you're gonna get the static image and the 3D model to explore and play with",
        lines=5,
        # Shown as a hint only. As `value=` this sentence was real input text
        # and would have been sent to the model if submitted unchanged.
        placeholder="Paste or write amino-acidic sequence here",
    ),
    outputs=[gr.HTML(label="Protein 3D model"), Molecule3D(label="Molecular 3D model", reps=reps)],
    examples=[
        "MVHLTPEEKSAVTALWGKVNVDEVGGEALGRLLVVYPWTQRFFESFGDLSTPDAVMGNPKVKAHGKKVLGAFSDGLAHLDNLKGTFATLSELHCDKLHVDPENFRLLGNVLVCVLAHHFGKEFTPPVQAAYQKVVAGVANALAHKYH",
        "MTEYKLVVVGAGGVGKSALTIQLIQNHFVDEYDPTIEDSYRKQVVIDGETCLLDILDTAGQEEYSAMRDQYMRTGEGFLCVFAINNTKSFEDIHQYREQIKRVKDSDDVPMVLVGNKCDLAARTVESRQAQDLARSYGIPYIETSAKTRQGVEDAFYTLVREIRQHKLRKLNPPDESGPGCMSCKCVLS",
        "MQIFVKTLTGKTITLEVEPSDTIENVKAKIQDKEGIPPDQQRLIFAGKQLEDGRTLSDYNIQKESTLHLVLRLRGG",
    ],
)
# Second tab: bulk mode. `show_split` receives the uploaded FASTA file and
# returns a single HTML string, shown in one gr.HTML output.
demo1 = gr.Interface(
title="BulkProteinviz",
fn=show_split,
inputs=gr.File(
label="FASTA File With Protein Sequences",
),
outputs= gr.HTML(label="Protein 3D models"),
# NOTE(review): the example is referenced by relative path - presumably
# "proteins.fasta" must be present in the working directory; confirm.
examples = ["proteins.fasta"]
)
# Combine the single-sequence and bulk interfaces as two tabs of one app.
demo = gr.TabbedInterface([iface, demo1], ["Single Protein Structure Prediction", "Bulk Protein Structure Prediction"])
# Serve on all network interfaces, port 7860. There is no __main__ guard, so
# the server starts whenever this module is executed or imported.
demo.launch(server_name="0.0.0.0", server_port=7860)