diff --git a/autoregressive_prova_generic_condition.py b/autoregressive_prova_generic_condition.py
new file mode 100644
index 000000000..3c0796bbc
--- /dev/null
+++ b/autoregressive_prova_generic_condition.py
@@ -0,0 +1,149 @@
+import torch
+import matplotlib.pyplot as plt
+
+from pina import Trainer
+from pina.optim import TorchOptimizer
+from pina.problem import AbstractProblem
+from pina.condition.data_condition import DataCondition
+from pina.solver import AutoregressiveSolver
+
+NUM_TIMESTEPS = 100
+NUM_FEATURES = 15
+USE_TEST_MODEL = False
+
+# ============================================================================
+# DATA
+# ============================================================================
+
+torch.manual_seed(42)
+
+y = torch.zeros(NUM_TIMESTEPS, NUM_FEATURES)
+y[0] = torch.rand(NUM_FEATURES)  # Random initial state
+
+for t in range(NUM_TIMESTEPS - 1):
+    y[t + 1] = 0.95 * y[t]  # + 0.05 * torch.sin(y[t].sum())
+
+# ============================================================================
+# TRAINING
+# ============================================================================
+
+class SimpleModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.layers = torch.nn.Sequential(
+            torch.nn.Linear(y.shape[1], 20),
+            torch.nn.ReLU(),
+            torch.nn.Dropout(0.2),
+            torch.nn.Linear(20, y.shape[1]),
+        )
+
+    def forward(self, x):
+        return x + self.layers(x)
+
+
+class TestModel(torch.nn.Module):
+    """
+    Debug model that implements the EXACT transformation rule
+    y[t+1] = 0.95 * y[t], so the expected loss is zero.
+    """
+
+    def __init__(self, data_series=None):
+        super().__init__()
+        self.dummy_param = torch.nn.Parameter(torch.zeros(1))
+
+    def forward(self, x):
+        next_state = 0.95 * x  # + 0.05 * torch.sin(x.sum(dim=1, keepdim=True))
+        return next_state + 0.0 * self.dummy_param
+
+
+class Problem(AbstractProblem):
+    output_variables = None
+    input_variables = None
+    conditions = {
+        "data_condition_0": DataCondition(input=y),
+        "data_condition_1": DataCondition(input=y),
+    }
+
+problem = Problem()
+
+# For each condition, define unroll instructions with these keys:
+# - unroll_length: length of each unroll window
+# - num_unrolls: number of unroll windows to create (if None, use all possible)
+# - randomize: whether to randomize the starting indices of the unroll windows
+# - eps: decay rate of the adaptive step weighting (optional; if omitted,
+#   all unroll steps are weighted equally)
+unroll_instructions = {
+    "data_condition_0": {
+        "unroll_length": 10,
+        "num_unrolls": 89,
+        "randomize": True,
+        "eps": 5.0,
+    },
+    "data_condition_1": {
+        "unroll_length": 20,
+        "num_unrolls": 79,
+        "randomize": True,
+        "eps": 10.0,
+    },
+}
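+
+# For intuition (a hand check, not executed): a window starting at index t
+# targets y[t + 1 : t + 1 + unroll_length]. With NUM_TIMESTEPS = 100 and
+# unroll_length = 10 the valid starts are 0..89 (90 windows), so
+# num_unrolls = 89 uses all but one of them; with unroll_length = 20 the
+# valid starts are 0..79, of which num_unrolls = 79 are used.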
+
+solver = AutoregressiveSolver(
+    unroll_instructions=unroll_instructions,
+    problem=problem,
+    model=TestModel() if USE_TEST_MODEL else SimpleModel(),
+    optimizer=TorchOptimizer(torch.optim.AdamW, lr=0.01),
+    eps=10.0,
+)
+
+trainer = Trainer(
+    solver, max_epochs=2000, accelerator="cpu", enable_model_summary=False, shuffle=False
+)
+trainer.train()
+
+# ============================================================================
+# VISUALIZATION
+# ============================================================================
+
+test_start_idx = 50
+num_prediction_steps = 30
+
+initial_state = y[test_start_idx]  # Shape: [features]
+predictions = solver.predict(initial_state, num_prediction_steps)
+actual = y[test_start_idx : test_start_idx + num_prediction_steps + 1]
+
+total_mse = torch.nn.functional.mse_loss(predictions[1:], actual[1:])
+print(f"\nOverall MSE (all {num_prediction_steps} steps): {total_mse:.6f}")
+
+# Visualize a few individual degrees of freedom
+dof_to_plot = [0, 3, 6, 9, 12]
+colors = ["r", "g", "b", "c", "m", "y", "k"]
+plt.figure(figsize=(10, 6))
+for dof, color in zip(dof_to_plot, colors):
+    plt.plot(
+        range(test_start_idx, test_start_idx + num_prediction_steps + 1),
+        actual[:, dof].numpy(),
+        label=f"Actual (dof {dof})",
+        marker="o",
+        color=color,
+        markerfacecolor="none",
+    )
+    plt.plot(
+        range(test_start_idx, test_start_idx + num_prediction_steps + 1),
+        predictions[:, dof].numpy(),
+        label=f"Predicted (dof {dof})",
+        marker="x",
+        color=color,
+    )
+
+plt.title(f"Autoregressive Predictions vs Actual, MSE: {total_mse:.6f}")
+plt.legend()
+plt.xlabel("Timestep")
+plt.savefig("autoregressive_predictions.png")
+plt.close()
diff --git a/pina/solver/__init__.py b/pina/solver/__init__.py
index 43f18078f..e7d48e2b3 100644
--- a/pina/solver/__init__.py
+++ b/pina/solver/__init__.py
@@ -18,6 +18,7 @@
     "DeepEnsembleSupervisedSolver",
     "DeepEnsemblePINN",
     "GAROM",
+    "AutoregressiveSolver",
 ]
 
 from .solver import SolverInterface, SingleSolverInterface, MultiSolverInterface
@@ -41,3 +42,7 @@
     DeepEnsemblePINN,
 )
 from .garom import GAROM
+from .autoregressive_solver import (
+    AutoregressiveSolver,
+    AutoregressiveSolverInterface,
+)
diff --git a/pina/solver/autoregressive_solver/__init__.py b/pina/solver/autoregressive_solver/__init__.py
new file mode 100644
index 000000000..9ef7c43e1
--- /dev/null
+++ b/pina/solver/autoregressive_solver/__init__.py
@@ -0,0 +1,4 @@
+__all__ = ["AutoregressiveSolver", "AutoregressiveSolverInterface"]
+
+from .autoregressive_solver import AutoregressiveSolver
+from .autoregressive_solver_interface import AutoregressiveSolverInterface
diff --git a/pina/solver/autoregressive_solver/autoregressive_solver.py b/pina/solver/autoregressive_solver/autoregressive_solver.py
new file mode 100644
index 000000000..0606a3fd6
--- /dev/null
+++ b/pina/solver/autoregressive_solver/autoregressive_solver.py
@@ -0,0 +1,172 @@
+"""Module for the Autoregressive solver."""
+
+import torch
+from pina.solver.solver import SingleSolverInterface
+from pina.condition import DataCondition
+from .autoregressive_solver_interface import AutoregressiveSolverInterface
+
+
+class AutoregressiveSolver(
+    AutoregressiveSolverInterface, SingleSolverInterface
+):
+    """
+    Autoregressive Solver class. Trains a one-step model by unrolling it
+    recursively over windows of a time series and penalizing the deviation
+    from the data at every unroll step.
+    """
+
+    accepted_conditions_types = DataCondition
+
+    def __init__(
+        self,
+        unroll_instructions,
+        problem,
+        model,
+        eps=None,
+        loss=None,
+        optimizer=None,
+        scheduler=None,
+        weighting=None,
+        use_lt=False,
+    ):
+        """
+        Initialization of the :class:`AutoregressiveSolver` class.
+
+        :param dict unroll_instructions: A dictionary mapping condition names
+            to dictionaries of unroll instructions.
+        :param AbstractProblem problem: The problem to be solved.
+        :param torch.nn.Module model: The model to be trained.
+        :param float or None eps: Default decay rate of the adaptive step
+            weighting; a per-condition ``eps`` key in ``unroll_instructions``
+            takes precedence. If no value is given at all, unroll steps are
+            weighted equally.
+        :param torch.nn.Module or LossInterface or None loss: The loss
+            function to be minimized. If ``None``, defaults to
+            :class:`torch.nn.MSELoss`.
+        :param TorchOptimizer or None optimizer: The optimizer to be used.
+            If ``None``, the framework default is used.
+        :param TorchScheduler or None scheduler: The learning rate scheduler
+            to be used. If ``None``, the framework default is used.
+        :param Weighting or None weighting: The weighting scheme for combining
+            losses from different conditions. If ``None``, equal weighting is
+            applied.
+        :param bool use_lt: Whether to use LabelTensors as input.
+        """
+        super().__init__(
+            unroll_instructions=unroll_instructions,
+            problem=problem,
+            model=model,
+            loss=loss,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            weighting=weighting,
+            use_lt=use_lt,
+        )
+        # Solver-level fallback for the adaptive weighting decay rate.
+        self._eps = eps
+
+    def loss_data(self, data, condition_unroll_instructions):
+        """
+        Compute the data loss for the recursive autoregressive solver.
+        This is applied to each condition individually.
+
+        :param torch.Tensor data: All training data for this condition.
+        :param dict condition_unroll_instructions: Instructions on how to
+            unroll the model for this condition.
+        :return: Computed loss value.
+        :rtype: torch.Tensor
+        """
+        initial_data, unroll_data = self.create_unroll_windows(
+            data, condition_unroll_instructions
+        )
+
+        unroll_length = condition_unroll_instructions["unroll_length"]
+        current_state = initial_data  # [num_unrolls, features]
+
+        losses = []
+        for step in range(unroll_length):
+            predicted_state = self.forward(current_state)  # [num_unrolls, features]
+            target_state = unroll_data[:, step, :]  # [num_unrolls, features]
+            step_loss = self._loss_fn(predicted_state, target_state)
+            losses.append(step_loss)
+            current_state = predicted_state
+
+        step_losses = torch.stack(losses)  # [unroll_length]
+
+        with torch.no_grad():
+            weights = self.compute_adaptive_weights(
+                step_losses.detach(), condition_unroll_instructions
+            )
+
+        weighted_loss = (step_losses * weights).sum()
+        return weighted_loss
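+
+    # Shape walkthrough (for intuition, not executed): with N = num_unrolls,
+    # F = features and L = unroll_length, ``initial_data`` is [N, F], each
+    # forward call maps [N, F] -> [N, F], ``unroll_data[:, step, :]`` is the
+    # [N, F] target at that step, and ``step_losses`` stacks L scalar losses
+    # into an [L] tensor before the weighted reduction to a scalar.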
+ """ + + super().__init__( + unroll_instructions=unroll_instructions, + problem=problem, + model=model, + loss=loss, + optimizer=optimizer, + scheduler=scheduler, + weighting=weighting, + use_lt=use_lt, + ) + + def loss_data(self, data, condition_unroll_instructions): + """ + Compute the data loss for the recursive autoregressive solver. + This will be applied to each condition individually. + :param torch.Tensor data: all training data. + :param dict condition_unroll_instructions: instructions on how to unroll the model for this condition. + :return: Computed loss value. + :rtype: torch.Tensor + """ + + initial_data, unroll_data = self.create_unroll_windows( + data, condition_unroll_instructions + ) + + unroll_length = condition_unroll_instructions["unroll_length"] + current_state = initial_data # [num_unrolls, features] + + losses = [] + for step in range(unroll_length): + + predicted_state = self.forward(current_state) # [num_unrolls, features] + target_state = unroll_data[:, step, :] # [num_unrolls, features] + step_loss = self._loss_fn(predicted_state, target_state) + losses.append(step_loss) + current_state = predicted_state + + step_losses = torch.stack(losses) # [unroll_length] + + with torch.no_grad(): + weights = self.compute_adaptive_weights(step_losses.detach(), condition_unroll_instructions) + + weighted_loss = (step_losses * weights).sum() + return weighted_loss + + def create_unroll_windows(self, data, condition_unroll_instructions): + """ + Create unroll windows for each condition from the data based on the provided instructions. + :param torch.Tensor data: The full data tensor. + :param dict condition_unroll_instructions: Instructions on how to unroll the model for this condition. + :return: Tuple of initial data and unroll data tensors. + :rtype: (torch.Tensor, torch.Tensor) + """ + + unroll_length = condition_unroll_instructions["unroll_length"] + + start_list = [] + unroll_list = [] + for starting_index in self.decide_starting_indices( + data, condition_unroll_instructions + ): + idx = starting_index.item() + start = data[idx] + target_start = idx + 1 + unroll = data[target_start : target_start + unroll_length, :] + start_list.append(start) + unroll_list.append(unroll) + initial_data = torch.stack(start_list) # [num_unrolls, features] + unroll_data = torch.stack(unroll_list) # [num_unrolls, unroll_length, features] + return initial_data, unroll_data + + def decide_starting_indices(self, data, condition_unroll_instructions): + """ + Decide the starting indices for unrolling based on the provided instructions. + :param torch.Tensor data: The full data tensor. + :param dict condition_unroll_instructions: Instructions on how to unroll the model for this condition. + :return: Tensor of starting indices. + :rtype: torch.Tensor + """ + n_step, n_features = data.shape + num_unrolls = condition_unroll_instructions.get("num_unrolls", None) + unroll_length = condition_unroll_instructions["unroll_length"] + randomize = condition_unroll_instructions.get("randomize", True) + + max_start = n_step - unroll_length + indices = torch.arange(max_start, device=data.device) + + if num_unrolls is not None and num_unrolls < len(indices): + indices = indices[:num_unrolls] + + if randomize: + indices = indices[torch.randperm(len(indices), device=data.device)] + + return indices + + def compute_adaptive_weights(self, step_losses, condition_unroll_instructions): + """ + Compute adaptive weights for each time step based on cumulative losses. 
+
+        :param torch.Tensor step_losses: Tensor of shape [unroll_length]
+            containing losses at each time step.
+        :param dict condition_unroll_instructions: Instructions for this
+            condition; the optional ``eps`` key overrides the solver-level
+            default.
+        :return: Tensor of shape [unroll_length] containing normalized weights.
+        :rtype: torch.Tensor
+        """
+        eps = condition_unroll_instructions.get("eps", self._eps)
+        if eps is None:
+            weights = torch.ones_like(step_losses)
+        else:
+            weights = torch.exp(-eps * torch.cumsum(step_losses, dim=0))
+
+        return weights / weights.sum()
+
+    def predict(self, initial_state, num_steps):
+        """
+        Make recursive predictions starting from an initial state.
+
+        :param torch.Tensor initial_state: Initial state tensor.
+        :param int num_steps: Number of steps to predict ahead.
+        :return: Tensor of predictions, including the initial state as its
+            first entry.
+        :rtype: torch.Tensor
+        """
+        self.eval()  # Set model to evaluation mode
+
+        current_state = initial_state
+        predictions = [current_state]
+
+        with torch.no_grad():
+            for step in range(num_steps):
+                next_state = self.forward(current_state)
+                predictions.append(next_state)
+                current_state = next_state
+
+        return torch.stack(predictions)
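+
+
+# Usage sketch (illustrative, mirroring the example script): starting from a
+# single state of shape [features], ``solver.predict(y[50], 30)`` returns a
+# [31, features] tensor whose first row is the initial state itself.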
diff --git a/pina/solver/autoregressive_solver/autoregressive_solver_interface.py b/pina/solver/autoregressive_solver/autoregressive_solver_interface.py
new file mode 100644
index 000000000..d0a6f919a
--- /dev/null
+++ b/pina/solver/autoregressive_solver/autoregressive_solver_interface.py
@@ -0,0 +1,85 @@
+"""Module for the Autoregressive solver interface."""
+
+from abc import abstractmethod
+import torch
+from torch.nn.modules.loss import _Loss
+
+from ..solver import SolverInterface
+from ...utils import check_consistency
+from ...loss.loss_interface import LossInterface
+
+
+class AutoregressiveSolverInterface(SolverInterface):
+    """
+    Base interface for autoregressive solvers. Concrete subclasses implement
+    :meth:`loss_data` and :meth:`predict`.
+    """
+
+    def __init__(self, unroll_instructions, loss=None, **kwargs):
+        """
+        Initialization of the :class:`AutoregressiveSolverInterface` class.
+
+        :param dict unroll_instructions: A dictionary mapping condition names
+            to dictionaries of unroll instructions.
+        :param loss: The loss function to be minimized. If ``None``, defaults
+            to :class:`torch.nn.MSELoss`.
+        :type loss: torch.nn.Module or LossInterface, optional
+        """
+        super().__init__(**kwargs)
+
+        if loss is None:
+            loss = torch.nn.MSELoss()
+
+        check_consistency(loss, (LossInterface, _Loss), subclass=False)
+        self._loss_fn = loss
+        self._unroll_instructions = unroll_instructions
+
+    def optimization_cycle(self, batch):
+        """
+        Optimization cycle for this family of solvers. Iterates over the
+        conditions and applies the specialized :meth:`loss_data` to each.
+
+        :param list[tuple[str, dict]] batch: A batch of data; each element is
+            a tuple of a condition name and a dictionary of points.
+        :return: A dictionary mapping condition names to computed loss values.
+        :rtype: dict
+        """
+        condition_loss = {}
+        for condition_name, points in batch:
+            condition_unroll_instructions = self._unroll_instructions[
+                condition_name
+            ]
+            condition_loss[condition_name] = self.loss_data(
+                points["input"],
+                condition_unroll_instructions,
+            )
+        return condition_loss
+
+    @abstractmethod
+    def loss_data(self, data, condition_unroll_instructions):
+        """
+        Compute the data loss for a single condition. N.B.: implementations
+        must use ``condition_unroll_instructions`` to know how to unroll the
+        model.
+
+        :param torch.Tensor data: All training data for this condition.
+        :param dict condition_unroll_instructions: Instructions on how to
+            unroll the model for this condition.
+        :return: Computed loss value.
+        :rtype: torch.Tensor
+        """
+        pass
+
+    @abstractmethod
+    def predict(self, initial_state, num_steps):
+        """
+        Make recursive predictions starting from an initial state.
+
+        :param torch.Tensor initial_state: Initial state tensor.
+        :param int num_steps: Number of steps to predict ahead.
+        :return: Tensor of predictions.
+        :rtype: torch.Tensor
+        """
+        pass
+
+    @property
+    def loss(self):
+        """
+        The loss function to be minimized.
+
+        :return: The loss function to be minimized.
+        :rtype: torch.nn.Module
+        """
+        return self._loss_fn
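+
+
+# Subclassing sketch (illustrative, not part of the public API): a concrete
+# solver only needs to provide the two abstract methods, e.g.
+#
+#     class MySolver(AutoregressiveSolverInterface, SingleSolverInterface):
+#         def loss_data(self, data, condition_unroll_instructions): ...
+#         def predict(self, initial_state, num_steps): ...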