diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index d4cf249fe..931c8b483 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1033,7 +1033,10 @@ def postprocess(module, name): original_device = weight.device original_dtype = weight.dtype weight_f64 = weight.to(dtype=torch.float64, device=original_device) - u, s, vt = torch.linalg.svd(weight_f64, full_matrices=False) + if original_device.type == "cuda": + u, s, vt = torch.linalg.svd(weight_f64, driver="gesvd", full_matrices=False) + else: + u, s, vt = torch.linalg.svd(weight_f64, full_matrices=False) if u.shape[1] < lowrank or vt.shape[0] < lowrank: warnings.warn( "The low-rank dimensions do not match the layer dimensions. "