
Commit ff68d0b

enforce fp32 and torch kernels for triangle attention and multiplicative on V100

1 parent bc9943b

File tree

3 files changed: +47 −9 lines
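
Background for the change: BF16 tensor-core math requires compute capability 8.0 (Ampere) or newer, while a V100 is a Volta GPU that reports capability 7.0; per the commit message, some of the fused triangle kernels are likewise unsupported there. A minimal standalone sketch of the underlying check, using only stock PyTorch calls (the note on is_bf16_supported() behaviour is an assumption, since it has varied across PyTorch releases):

# Minimal sketch: the compute-capability query that the new helper in
# runner/inference.py bases its fallback on.
import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()  # a V100 reports (7, 0)
    print(f"compute capability: {major}.{minor}")
    # BF16 support arrives with compute capability 8.0 (Ampere); the result of
    # torch.cuda.is_bf16_supported() may differ across PyTorch versions.
    print("bf16 supported:", torch.cuda.is_bf16_supported())
else:
    print("no CUDA device visible; the new fallback leaves configs untouched")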

runner/batch_inference.py

Lines changed: 20 additions & 7 deletions
@@ -36,7 +36,12 @@
 from protenix.data.json_parser import lig_file_to_atom_info
 from protenix.data.utils import pdb_to_cif
 from protenix.utils.logger import get_logger
-from runner.inference import InferenceRunner, download_infercence_cache, infer_predict
+from runner.inference import (
+    InferenceRunner,
+    download_infercence_cache,
+    infer_predict,
+    update_gpu_compatible_configs,
+)
 from runner.msa_search import msa_search, update_infer_json
 
 logger = get_logger(__name__)
@@ -166,6 +171,7 @@ def get_default_runner(
     n_cycle: int = 10,
     n_step: int = 200,
     n_sample: int = 5,
+    dtype: str = "bf16",
     model_name: str = "protenix_base_default_v0.5.0",
     use_msa: bool = True,
     trimul_kernel="cuequivariance",
@@ -184,23 +190,28 @@ def get_default_runner(
     configs.seeds = seeds
     model_name = configs.model_name
     _, model_size, model_feature, model_version = model_name.split("_")
-    logger.info(
-        f"Inference by Protenix: model_size: {model_size}, with_feature: {model_feature.replace('-', ',')}, model_version: {model_version}"
-    )
     model_specfics_configs = ConfigDict(model_configs[model_name])
     # update model specific configs
     configs.update(model_specfics_configs)
     # the user input configs has the highest priority
     configs.model.N_cycle = n_cycle
     configs.sample_diffusion.N_sample = n_sample
     configs.sample_diffusion.N_step = n_step
+    configs.dtype = dtype
     configs.use_msa = use_msa
     configs.triangle_multiplicative = trimul_kernel
     configs.triangle_attention = triatt_kernel
     configs.enable_diffusion_shared_vars_cache = enable_cache
     configs.enable_efficient_fusion = enable_fusion
     configs.enable_tf32 = enable_tf32
 
+    configs = update_gpu_compatible_configs(configs)
+    logger.info(
+        f"Inference by Protenix: model_size: {model_size}, with_feature: {model_feature.replace('-', ',')}, model_version: {model_version}, dtype: {configs.dtype}"
+    )
+    logger.info(
+        f"Triangle_multiplicative kernel: {trimul_kernel}, Triangle_attention kernel: {triatt_kernel}"
+    )
     logger.info(
         f"enable_diffusion_shared_vars_cache: {configs.enable_diffusion_shared_vars_cache}, "
         + f"enable_efficient_fusion: {configs.enable_efficient_fusion}, enable_tf32: {configs.enable_tf32}"
@@ -217,6 +228,7 @@ def inference_jsons(
     n_cycle: int = 10,
     n_step: int = 200,
     n_sample: int = 5,
+    dtype: str = "bf16",
     model_name: str = "protenix_base_default_v0.5.0",
     trimul_kernel="cuequivariance",
     triatt_kernel="triattention",
@@ -255,6 +267,7 @@ def inference_jsons(
         n_cycle,
         n_step,
         n_sample,
+        dtype,
         model_name,
         use_msa,
         trimul_kernel,
@@ -290,6 +303,7 @@ def protenix_cli():
 @click.option("-c", "--cycle", type=int, default=10, help="pairformer cycle number")
 @click.option("-p", "--step", type=int, default=200, help="diffusion step")
 @click.option("-e", "--sample", type=int, default=5, help="sample number")
+@click.option("-d", "--dtype", type=str, default="bf16", help="inference dtype")
 @click.option(
     "-n",
     "--model_name",
@@ -349,6 +363,7 @@ def predict(
     cycle,
     step,
     sample,
+    dtype,
     model_name,
     use_msa,
     use_default_params,
@@ -403,9 +418,6 @@ def predict(
         "deepspeed",
         "torch",
     ], "Kernel to use for triangle attention. Options: 'triattention', 'cuequivariance', 'deepspeed', 'torch'."
-    logger.info(
-        f"Triangle_multiplicative kernel: {trimul_kernel}, Triangle_attention kernel: {triatt_kernel}"
-    )
     seeds = list(map(int, seeds.split(",")))
     inference_jsons(
         input,
@@ -415,6 +427,7 @@ def predict(
         n_cycle=cycle,
         n_step=step,
         n_sample=sample,
+        dtype=dtype,
         model_name=model_name,
         trimul_kernel=trimul_kernel,
         triatt_kernel=triatt_kernel,
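
For readers unfamiliar with the dtype strings: "bf16" and "fp32" name torch.bfloat16 and torch.float32. How the runner converts the string into a torch dtype is not part of this diff, so the helper below is only an illustrative assumption, not the project's code:

# Hypothetical helper (not from this commit): one common way to resolve a
# dtype string, such as the value of the new --dtype flag, into a torch dtype.
import torch

_DTYPE_MAP = {"fp32": torch.float32, "bf16": torch.bfloat16, "fp16": torch.float16}

def resolve_dtype(name: str) -> torch.dtype:
    # Unknown names fall back to full precision, in the spirit of the V100
    # override in this commit, which forces "fp32".
    return _DTYPE_MAP.get(name, torch.float32)

print(resolve_dtype("bf16"))  # torch.bfloat16
print(resolve_dtype("fp32"))  # torch.float32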

runner/inference.py

Lines changed: 26 additions & 1 deletion
@@ -372,6 +372,30 @@ def main(configs: Any) -> None:
     infer_predict(runner, configs)
 
 
+def update_gpu_compatible_configs(configs: Any) -> Any:
+    def is_gpu_capability_between_7_and_8():
+        # 7.0 <= device_capability < 8.0
+        if not torch.cuda.is_available():
+            return False
+
+        capability = torch.cuda.get_device_capability()
+        major, minor = capability
+        cc = major + minor / 10.0
+        if 7.0 <= cc < 8.0:
+            return True
+        return False
+
+    if is_gpu_capability_between_7_and_8():
+        # Some kernels and BF16 aren’t supported on V100 — enforce specific configurations to work around it.
+        configs.dtype = "fp32"
+        configs.triangle_attention = "torch"
+        configs.triangle_multiplicative = "torch"
+        logger.info(
+            "GPU capability is between 7.0 and 8.0, enforce fp32 and torch kernels for triangle attention and multiplicative."
+        )
+    return configs
+
+
 def run() -> None:
     LOG_FORMAT = "%(asctime)s,%(msecs)-3d %(levelname)-8s [%(filename)s:%(lineno)s %(funcName)s] %(message)s"
     logging.basicConfig(
@@ -390,11 +414,12 @@ def run() -> None:
     model_name = configs.model_name
     _, model_size, model_feature, model_version = model_name.split("_")
     logger.info(
-        f"Inference by Protenix: model_size: {model_size}, with_feature: {model_feature.replace('-',', ')}, model_version: {model_version}"
+        f"Inference by Protenix: model_size: {model_size}, with_feature: {model_feature.replace('-',', ')}, model_version: {model_version}, dtype: {configs.dtype}"
     )
     model_specfics_configs = ConfigDict(model_configs[model_name])
     # update model specific configs
     configs.update(model_specfics_configs)
+    configs = update_gpu_compatible_configs(configs)
     logger.info(
         f"Triangle_multiplicative kernel: {configs.triangle_multiplicative}, Triangle_attention kernel: {configs.triangle_attention}"
     )
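
To make the helper's effect concrete, here is a minimal behavioural sketch that calls update_gpu_compatible_configs on a stand-in config object; SimpleNamespace is used only for illustration, since the project's own ConfigDict is not needed to show the three fields that change:

# Behavioural sketch of update_gpu_compatible_configs(). SimpleNamespace
# stands in for the project's config object; only the fields the helper
# touches are populated.
from types import SimpleNamespace

from runner.inference import update_gpu_compatible_configs

configs = SimpleNamespace(
    dtype="bf16",
    triangle_attention="triattention",
    triangle_multiplicative="cuequivariance",
)
configs = update_gpu_compatible_configs(configs)

# On a GPU with compute capability 7.x (e.g. a V100): dtype becomes "fp32"
# and both kernel settings become "torch". On Ampere or newer, or with no
# CUDA device, the values are returned unchanged.
print(configs.dtype, configs.triangle_attention, configs.triangle_multiplicative)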

setup.py

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@
 setup(
     name="protenix",
     python_requires=">=3.11",
-    version="0.7.0",
+    version="0.7.1",
     description="A trainable PyTorch reproduction of AlphaFold 3.",
     long_description=long_description,
     long_description_content_type="text/markdown",
