Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -93,45 +93,10 @@ def forward(self, t, xt):
print(' qVectorField: vt.shape', vt.shape)
print(' qVectorField: vt', vt)

# sum over arguments of exponential, that is,
# over the d-dimensions of each element in x0,
# so that we get the product of d normal densities.

'''
v2 = (vt*vt).sum(dim=-1)
vv = torch.where(v2 < 207, v2, 207) ## why not use the logsumexp for greater stability
if torch.isnan(vv).any():
raise ValueError("vv contains at least one NAN")

if debug:
print(' qVectorField: vv.shape', vv.shape)

# compute unnormalized probability densities.
pt = torch.exp(-vv/2)
if torch.isnan(pt).any():
raise ValueError("pt contains at least one NAN")

# pt.shape: (N, M)
if debug:
print(' qVectorField: pt.shape', pt.shape)

# sum over the M d-dimensional Gaussian densities
ptsum = pt.sum(dim=-1)
# ptsum.shape: (N, )
if torch.isnan(ptsum).any():
raise ValueError("ptsum contains at least one NAN")

# protect sum against divide by zero
pt_sum = torch.where(ptsum < 1.e-44, 1, ptsum).unsqueeze(-1)
# pt_sum.shape: (N, 1)
if torch.isnan(ptsum).any():
raise ValueError("ptsum contains at least one NAN")

# compute weights
wt = pt / pt_sum
'''


# Compute squared distances from each collocation point xt to every
# MC sample x0, summed over the d feature dimensions.
# Uses the log-sum-exp trick for numerical stability (avoids exp overflow
# that occurred in the old implementation which used torch.where clamping).
v2 = (vt * vt).sum(dim=-1)
log_weights = -0.5 * v2
log_weights = log_weights - torch.logsumexp(log_weights, dim=-1, keepdim=True)
Expand All @@ -140,11 +105,11 @@ def forward(self, t, xt):

# wt.shape: (N, M)
if torch.isnan(wt).any():
print('pt')
print(pt)
print('v2')
print(v2)
print()
print('pt_sum')
print(pt_sum)
print('log_weights')
print(log_weights)
print()
print('wt')
print(wt)
Expand Down