Skip to content

Commit df68ed7

Browse files
committed
update readme and cli
1 parent 8a0c6eb commit df68ed7

File tree

4 files changed

+63
-9
lines changed

4 files changed

+63
-9
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ Protenix is built for high-accuracy structure prediction. It serves as an initia
3434
- **[Protenix-Dock](https://github.com/bytedance/Protenix-Dock)**: Our implementation of a classical protein-ligand docking framework that leverages empirical scoring functions. Without using deep neural networks, Protenix-Dock delivers competitive performance in rigid docking tasks.
3535

3636
## 🎉 Updates
37+
- 2025-11-05: [**Protenix-v0.7.0**](./assets/inference_time_vs_ntoken.png) is now open-sourced, with new options for faster diffusion inference: shared variable caching, efficient bias fusion, and TF32 acceleration.
3738
- 2025-07-17: **Protenix-Mini released!**: Lightweight model variants with significantly reduced inference cost are now available. Users can choose from multiple configurations to balance speed and accuracy based on deployment needs. See our [paper](https://arxiv.org/abs/2507.11839) and [model configs](./configs/configs_model_type.py) for more information.
3839
- 2025-07-17: [***New constraint feature***](docs/infer_json_format.md#constraint) is released! Now supports **atom-level contact** and **pocket** constraints, significantly improving performance in our evaluations.
3940
- 2025-05-30: **Protenix-v0.5.0** is now available! You may try Protenix-v0.5.0 by accessing the [server](https://protenix-server.com), or upgrade to the latest version using pip.

runner/batch_inference.py

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,9 @@ def get_default_runner(
170170
use_msa: bool = True,
171171
trimul_kernel="cuequivariance",
172172
triatt_kernel="triattention",
173+
enable_cache=True,
174+
enable_fusion=True,
175+
enable_tf32=True,
173176
) -> InferenceRunner:
174177
inference_configs["model_name"] = model_name
175178
configs = {**configs_base, **{"data": data_configs}, **inference_configs}
@@ -194,7 +197,14 @@ def get_default_runner(
194197
configs.use_msa = use_msa
195198
configs.triangle_multiplicative = trimul_kernel
196199
configs.triangle_attention = triatt_kernel
200+
configs.enable_diffusion_shared_vars_cache = enable_cache
201+
configs.enable_efficient_fusion = enable_fusion
202+
configs.enable_tf32 = enable_tf32
197203

204+
logger.info(
205+
f"enable_diffusion_shared_vars_cache: {configs.enable_diffusion_shared_vars_cache}, "
206+
+ f"enable_efficient_fusion: {configs.enable_efficient_fusion}, enable_tf32: {configs.enable_tf32}"
207+
)
198208
download_infercence_cache(configs)
199209
return InferenceRunner(configs)
200210

@@ -210,6 +220,9 @@ def inference_jsons(
210220
model_name: str = "protenix_base_default_v0.5.0",
211221
trimul_kernel="cuequivariance",
212222
triatt_kernel="triattention",
223+
enable_cache=True,
224+
enable_fusion=True,
225+
enable_tf32=True,
213226
msa_server_mode: str = "protenix",
214227
) -> None:
215228
"""
@@ -246,6 +259,9 @@ def inference_jsons(
246259
use_msa,
247260
trimul_kernel,
248261
triatt_kernel,
262+
enable_cache,
263+
enable_fusion,
264+
enable_tf32,
249265
)
250266
configs = runner.configs
251267
for idx, infer_json in enumerate(tqdm.tqdm(infer_jsons)):
@@ -266,15 +282,16 @@ def protenix_cli():
266282

267283

268284
@click.command()
269-
@click.option("--input", type=str, help="json files or dir for inference")
270-
@click.option("--out_dir", default="./output", type=str, help="infer result dir")
285+
@click.option("-i", "--input", type=str, help="json files or dir for inference")
286+
@click.option("-o", "--out_dir", default="./output", type=str, help="infer result dir")
271287
@click.option(
272-
"--seeds", type=str, default="101", help="the inference seed, split by comma"
288+
"-s", "--seeds", type=str, default="101", help="the inference seed, split by comma"
273289
)
274-
@click.option("--cycle", type=int, default=10, help="pairformer cycle number")
275-
@click.option("--step", type=int, default=200, help="diffusion step")
276-
@click.option("--sample", type=int, default=5, help="sample number")
290+
@click.option("-c", "--cycle", type=int, default=10, help="pairformer cycle number")
291+
@click.option("-p", "--step", type=int, default=200, help="diffusion step")
292+
@click.option("-e", "--sample", type=int, default=5, help="sample number")
277293
@click.option(
294+
"-n",
278295
"--model_name",
279296
type=str,
280297
default="protenix_base_default_v0.5.0",
@@ -301,6 +318,24 @@ def protenix_cli():
301318
default="triattention",
302319
help="Kernel to use for triangle attention. Options: 'triattention', 'cuequivariance', 'deepspeed', 'torch'.",
303320
)
321+
@click.option(
322+
"--enable_cache",
323+
type=bool,
324+
default=True,
325+
help="The diffusion module precomputes and caches pair_z, p_lm, and c_l (which are shareable across the N_sample and N_step dimensions)",
326+
)
327+
@click.option(
328+
"--enable_fusion",
329+
type=bool,
330+
default=True,
331+
help="The diffusion transformer consists of 24 transformer blocks, and the biases in these blocks can be pre-transformed in terms of dimensionality and normalization",
332+
)
333+
@click.option(
334+
"--enable_tf32",
335+
type=bool,
336+
default=True,
337+
help="When the diffusion module uses FP32 computation, enabling enable_tf32 reduces the matrix multiplication precision from FP32 to TF32.",
338+
)
304339
@click.option(
305340
"--msa_server_mode",
306341
type=str,
@@ -319,6 +354,9 @@ def predict(
319354
use_default_params,
320355
trimul_kernel,
321356
triatt_kernel,
357+
enable_cache,
358+
enable_fusion,
359+
enable_tf32,
322360
msa_server_mode,
323361
):
324362
"""
@@ -380,6 +418,9 @@ def predict(
380418
model_name=model_name,
381419
trimul_kernel=trimul_kernel,
382420
triatt_kernel=triatt_kernel,
421+
enable_cache=enable_cache,
422+
enable_fusion=enable_fusion,
423+
enable_tf32=enable_tf32,
383424
msa_server_mode=msa_server_mode,
384425
)
385426

runner/inference.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
import logging
1515
import os
16+
import time
1617
import traceback
1718
import urllib.request
1819
from argparse import Namespace
@@ -307,10 +308,13 @@ def infer_predict(runner: InferenceRunner, configs: Any) -> None:
307308
return
308309

309310
num_data = len(dataloader.dataset)
311+
t0_start = time.time()
310312
for seed in configs.seeds:
311313
seed_everything(seed=seed, deterministic=configs.deterministic)
314+
t1_start = time.time()
312315
for batch in dataloader:
313316
try:
317+
t2_start = time.time()
314318
data, atom_array, data_error_message = batch[0]
315319
sample_name = data["sample_name"]
316320

@@ -338,9 +342,9 @@ def infer_predict(runner: InferenceRunner, configs: Any) -> None:
338342
atom_array=atom_array,
339343
entity_poly_type=data["entity_poly_type"],
340344
)
341-
345+
t2_end = time.time()
342346
logger.info(
343-
f"[Rank {DIST_WRAPPER.rank}] {data['sample_name']} succeeded.\n"
347+
f"[Rank {DIST_WRAPPER.rank}] {data['sample_name']} succeeded. Model forward time: {t2_end-t2_start}s.\n"
344348
f"Results saved to {configs.dump_dir}"
345349
)
346350
torch.cuda.empty_cache()
@@ -352,6 +356,14 @@ def infer_predict(runner: InferenceRunner, configs: Any) -> None:
352356
f.write(error_message)
353357
if hasattr(torch.cuda, "empty_cache"):
354358
torch.cuda.empty_cache()
359+
t1_end = time.time()
360+
logger.info(
361+
f"[Rank {DIST_WRAPPER.rank}] seed {seed} succeeded. Total task time: {t1_end-t1_start}s.\n"
362+
)
363+
t0_end = time.time()
364+
logger.info(
365+
f"[Rank {DIST_WRAPPER.rank}] job succeeded. Total job time: {t0_end-t0_start}s.\n"
366+
)
355367

356368

357369
def main(configs: Any) -> None:

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
setup(
3939
name="protenix",
4040
python_requires=">=3.11",
41-
version="0.6.3",
41+
version="0.7.0",
4242
description="A trainable PyTorch reproduction of AlphaFold 3.",
4343
long_description=long_description,
4444
long_description_content_type="text/markdown",

0 commit comments

Comments (0)