diff --git a/trinity/common/config.py b/trinity/common/config.py index 05d46589d5..2a43b2e235 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -1314,6 +1314,29 @@ def check_and_update(self) -> Config: # noqa: C901 for args in model_args: set_if_none(aux_model, args, getattr(self.model, args)) + # check gpu number + rollout_gpu_num = ( + self.explorer.rollout_model.tensor_parallel_size + * self.explorer.rollout_model.engine_num + + sum( + ( + model.tensor_parallel_size * model.engine_num + for model in self.explorer.auxiliary_models + ) + ) + ) + assert self.cluster.node_num is not None + assert self.cluster.gpu_per_node is not None + total_gpu_num = self.cluster.node_num * self.cluster.gpu_per_node + if self.mode in ["explore", "bench", "serve"] and rollout_gpu_num > total_gpu_num: + raise ValueError( + f"Total GPU number ({total_gpu_num}) is less than the number of GPUs required for rollout ({rollout_gpu_num})." + ) + elif self.mode == "both" and rollout_gpu_num >= total_gpu_num: + raise ValueError( + f"Not enough GPUs for trainer in 'both' mode. Explorer requires {rollout_gpu_num} GPUs, but total available GPUs are {total_gpu_num}." + ) + if self.explorer.over_rollout.ratio > 0.0: if not (0.0 <= self.explorer.over_rollout.ratio < 1.0): raise ValueError("over_rollout_ratio should be in [0.0, 1.0)")