3636from protenix .data .json_parser import lig_file_to_atom_info
3737from protenix .data .utils import pdb_to_cif
3838from protenix .utils .logger import get_logger
39- from runner .inference import InferenceRunner , download_infercence_cache , infer_predict
39+ from runner .inference import (
40+ InferenceRunner ,
41+ download_infercence_cache ,
42+ infer_predict ,
43+ update_gpu_compatible_configs ,
44+ )
4045from runner .msa_search import msa_search , update_infer_json
4146
4247logger = get_logger (__name__ )
@@ -166,6 +171,7 @@ def get_default_runner(
166171 n_cycle : int = 10 ,
167172 n_step : int = 200 ,
168173 n_sample : int = 5 ,
174+ dtype : str = "bf16" ,
169175 model_name : str = "protenix_base_default_v0.5.0" ,
170176 use_msa : bool = True ,
171177 trimul_kernel = "cuequivariance" ,
@@ -184,23 +190,28 @@ def get_default_runner(
184190 configs .seeds = seeds
185191 model_name = configs .model_name
186192 _ , model_size , model_feature , model_version = model_name .split ("_" )
187- logger .info (
188- f"Inference by Protenix: model_size: { model_size } , with_feature: { model_feature .replace ('-' , ',' )} , model_version: { model_version } "
189- )
190193 model_specfics_configs = ConfigDict (model_configs [model_name ])
191194 # update model specific configs
192195 configs .update (model_specfics_configs )
193196 # the user input configs has the highest priority
194197 configs .model .N_cycle = n_cycle
195198 configs .sample_diffusion .N_sample = n_sample
196199 configs .sample_diffusion .N_step = n_step
200+ configs .dtype = dtype
197201 configs .use_msa = use_msa
198202 configs .triangle_multiplicative = trimul_kernel
199203 configs .triangle_attention = triatt_kernel
200204 configs .enable_diffusion_shared_vars_cache = enable_cache
201205 configs .enable_efficient_fusion = enable_fusion
202206 configs .enable_tf32 = enable_tf32
203207
208+ configs = update_gpu_compatible_configs (configs )
209+ logger .info (
210+ f"Inference by Protenix: model_size: { model_size } , with_feature: { model_feature .replace ('-' , ',' )} , model_version: { model_version } , dtype: { configs .dtype } "
211+ )
212+ logger .info (
213+ f"Triangle_multiplicative kernel: { trimul_kernel } , Triangle_attention kernel: { triatt_kernel } "
214+ )
204215 logger .info (
205216 f"enable_diffusion_shared_vars_cache: { configs .enable_diffusion_shared_vars_cache } , "
206217 + f"enable_efficient_fusion: { configs .enable_efficient_fusion } , enable_tf32: { configs .enable_tf32 } "
@@ -217,6 +228,7 @@ def inference_jsons(
217228 n_cycle : int = 10 ,
218229 n_step : int = 200 ,
219230 n_sample : int = 5 ,
231+ dtype : str = "bf16" ,
220232 model_name : str = "protenix_base_default_v0.5.0" ,
221233 trimul_kernel = "cuequivariance" ,
222234 triatt_kernel = "triattention" ,
@@ -255,6 +267,7 @@ def inference_jsons(
255267 n_cycle ,
256268 n_step ,
257269 n_sample ,
270+ dtype ,
258271 model_name ,
259272 use_msa ,
260273 trimul_kernel ,
@@ -290,6 +303,7 @@ def protenix_cli():
290303@click .option ("-c" , "--cycle" , type = int , default = 10 , help = "pairformer cycle number" )
291304@click .option ("-p" , "--step" , type = int , default = 200 , help = "diffusion step" )
292305@click .option ("-e" , "--sample" , type = int , default = 5 , help = "sample number" )
306+ @click .option ("-d" , "--dtype" , type = str , default = "bf16" , help = "sample number" )
293307@click .option (
294308 "-n" ,
295309 "--model_name" ,
@@ -349,6 +363,7 @@ def predict(
349363 cycle ,
350364 step ,
351365 sample ,
366+ dtype ,
352367 model_name ,
353368 use_msa ,
354369 use_default_params ,
@@ -403,9 +418,6 @@ def predict(
403418 "deepspeed" ,
404419 "torch" ,
405420 ], "Kernel to use for triangle attention. Options: 'triattention', 'cuequivariance', 'deepspeed', 'torch'."
406- logger .info (
407- f"Triangle_multiplicative kernel: { trimul_kernel } , Triangle_attention kernel: { triatt_kernel } "
408- )
409421 seeds = list (map (int , seeds .split ("," )))
410422 inference_jsons (
411423 input ,
@@ -415,6 +427,7 @@ def predict(
415427 n_cycle = cycle ,
416428 n_step = step ,
417429 n_sample = sample ,
430+ dtype = dtype ,
418431 model_name = model_name ,
419432 trimul_kernel = trimul_kernel ,
420433 triatt_kernel = triatt_kernel ,
0 commit comments