def _get_eigenvalue_monitor_events(self):
    """Build the monitor events describing the current block eigenvalues.

    Every entry of ``self.block_eigenvalue`` contributes one event tuple
    ``(tag, value, step)``:

    * ``tag`` is ``Train/Eigenvalues/ModelBlockParam_<i>``, where ``<i>`` is
      the entry's position in the mapping's iteration order,
    * ``value`` is element ``0`` of the entry,
    * ``step`` is ``self.global_samples``.

    Returns:
        list[tuple]: One event per entry, in iteration order. The list is
        empty when ``self.block_eigenvalue`` is ``None`` or has no entries.
    """
    block_eigenvalue = self.block_eigenvalue
    # NOTE(review): presumably block_eigenvalue stays None until the first
    # eigenvalue computation has run -- confirm against the engine's
    # initialisation. Reporting nothing is preferable to raising
    # AttributeError from inside the training step.
    if block_eigenvalue is None:
        return []

    samples = self.global_samples
    # Enumerate the values view directly; no intermediate list copy is needed.
    return [(f"Train/Eigenvalues/ModelBlockParam_{i}", value[0], samples)
            for i, value in enumerate(block_eigenvalue.values())]
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from deepspeed.runtime.engine import DeepSpeedEngine


def test_eigenvalue_monitor_events_use_block_eigenvalue_values():
    """Eigenvalue monitor events are derived from ``block_eigenvalue``.

    Checks that one event is produced per entry of ``block_eigenvalue``, that
    the tag index follows the mapping's iteration order, that the reported
    value is element ``0`` of each entry, and that every event carries
    ``global_samples`` as its step.
    """
    # Allocate the engine without running __init__; only the two attributes
    # assigned below are set up for the helper under test.
    engine = DeepSpeedEngine.__new__(DeepSpeedEngine)
    engine.block_eigenvalue = {
        "first_param": (0.25, 0),
        "second_param": (0.5, 1),
    }
    engine.global_samples = 32

    # No debug printing here: on failure pytest already reports both sides
    # of the comparison, and unconditional output that bypasses capture only
    # adds noise to every test run.
    assert engine._get_eigenvalue_monitor_events() == [
        ("Train/Eigenvalues/ModelBlockParam_0", 0.25, 32),
        ("Train/Eigenvalues/ModelBlockParam_1", 0.5, 32),
    ]

    # An empty mapping yields no events.
    engine.block_eigenvalue = {}
    assert engine._get_eigenvalue_monitor_events() == []