diff --git a/Physics_Informed_Neural_Network_Diffusion_Equation_Sijil_Jose/flow_de/flow_de.py b/Physics_Informed_Neural_Network_Diffusion_Equation_Sijil_Jose/flow_de/flow_de.py index f4abab4..0d2d5ff 100644 --- a/Physics_Informed_Neural_Network_Diffusion_Equation_Sijil_Jose/flow_de/flow_de.py +++ b/Physics_Informed_Neural_Network_Diffusion_Equation_Sijil_Jose/flow_de/flow_de.py @@ -93,45 +93,10 @@ def forward(self, t, xt): print(' qVectorField: vt.shape', vt.shape) print(' qVectorField: vt', vt) - # sum over arguments of exponential, that is, - # over the d-dimensions of each element in x0, - # so that we get the product of d normal densities. - - ''' - v2 = (vt*vt).sum(dim=-1) - vv = torch.where(v2 < 207, v2, 207) ## why not use the logsumexp for greater stability - if torch.isnan(vv).any(): - raise ValueError("vv contains at least one NAN") - - if debug: - print(' qVectorField: vv.shape', vv.shape) - - # compute unnormalized probability densities. - pt = torch.exp(-vv/2) - if torch.isnan(pt).any(): - raise ValueError("pt contains at least one NAN") - - # pt.shape: (N, M) - if debug: - print(' qVectorField: pt.shape', pt.shape) - - # sum over the M d-dimensional Gaussian densities - ptsum = pt.sum(dim=-1) - # ptsum.shape: (N, ) - if torch.isnan(ptsum).any(): - raise ValueError("ptsum contains at least one NAN") - - # protect sum against divide by zero - pt_sum = torch.where(ptsum < 1.e-44, 1, ptsum).unsqueeze(-1) - # pt_sum.shape: (N, 1) - if torch.isnan(ptsum).any(): - raise ValueError("ptsum contains at least one NAN") - - # compute weights - wt = pt / pt_sum - ''' - - + # Compute squared distances from each collocation point xt to every + # MC sample x0, summed over the d feature dimensions. + # Uses the log-sum-exp trick for numerical stability (avoids exp overflow + # that occurred in the old implementation which used torch.where clamping). v2 = (vt * vt).sum(dim=-1) log_weights = -0.5 * v2 log_weights = log_weights - torch.logsumexp(log_weights, dim=-1, keepdim=True) @@ -140,11 +105,11 @@ def forward(self, t, xt): # wt.shape: (N, M) if torch.isnan(wt).any(): - print('pt') - print(pt) + print('v2') + print(v2) print() - print('pt_sum') - print(pt_sum) + print('log_weights') + print(log_weights) print() print('wt') print(wt)