Skip to content

Commit da0f290

Browse files
committed
faster hermite_renorm using strides
1 parent bdffe32 commit da0f290

File tree

1 file changed

+69
-16
lines changed

1 file changed

+69
-16
lines changed

thewalrus/_hermite_multidimensional.py

Lines changed: 69 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -244,10 +244,14 @@ def remove(
244244
yield p, dec(pattern, p)
245245

246246

247-
SQRT = np.sqrt(np.arange(1000)) # saving the time to recompute square roots
247+
# saving the time to recompute square roots
248+
SQRT = np.sqrt(np.arange(1000))
249+
_SQRT = np.sqrt(np.arange(1000))
250+
_SQRT[0] = 1.0 # avoid division by zero
251+
SQRT_INV = 1 / _SQRT
248252

249253

250-
@jit(nopython=True)
254+
@jit(nopython=True, fastmath=True)
251255
def _hermite_multidimensional_renorm(R, y, G): # pragma: no cover
252256
r"""Numba-compiled function to fill an array with the Hermite polynomials. It expects an array
253257
initialized with zeros everywhere except at index (0,...,0) (i.e. the seed value).
@@ -260,20 +264,69 @@ def _hermite_multidimensional_renorm(R, y, G): # pragma: no cover
260264
Returns:
261265
array[complex]: the multidimensional Hermite polynomials
262266
"""
263-
indices = np.ndindex(G.shape)
264-
next(indices) # skip the first index (0,...,0)
265-
for idx in indices:
267+
# numba doesn't like tuples
268+
shape_arr = np.array(G.shape)
269+
D = y.shape[-1]
270+
271+
# calculate the strides (e.g. (100,10,1) for shape (10,10,10))
272+
strides = np.ones_like(shape_arr)
273+
for i in range(D - 1, 0, -1):
274+
strides[i - 1] = strides[i] * shape_arr[i]
275+
276+
# flatten output tensor
277+
shape = G.shape
278+
G = G.ravel()
279+
280+
# initialize the n-dim index
281+
nd_index = np.ndindex(shape)
282+
283+
# skip corresponding first index (supposed to be already filled)
284+
next(nd_index)
285+
286+
# Iterate over the indices smaller than max(strides) with pivot bound check.
287+
# The check is needed only if the flat index is smaller than the largest stride.
288+
# Afterwards it will be safe to get the pivot by subtracting the first (largest) stride.
289+
for flat_index in range(1, strides[0]):
290+
index = next(nd_index)
291+
266292
i = 0
267-
for i, val in enumerate(idx):
268-
if val > 0:
293+
# calculate (flat) pivot
294+
for s in strides:
295+
pivot = flat_index - s
296+
if pivot >= 0: # if pivot not outside array
269297
break
270-
ki = dec(idx, i)
271-
u = y[i] * G[ki]
272-
for l, kl in remove(ki):
273-
u -= SQRT[ki[l]] * R[i, l] * G[kl]
274-
G[idx] = u / SQRT[idx[i]]
275-
return G
298+
i += 1
299+
300+
# contribution from pivot
301+
value_at_index = y[i] * G[pivot]
302+
303+
# contributions from pivot's lower neighbours
304+
# note the first is when j=i which needs a -1 in the sqrt from delta_ij
305+
value_at_index -= R[i, i] * SQRT[index[i] - 1] * G[pivot - strides[i]]
306+
for j in range(i + 1, D):
307+
value_at_index -= R[i, j] * SQRT[index[j]] * G[pivot - strides[j]]
308+
G[flat_index] = value_at_index * SQRT_INV[index[i]]
309+
310+
# Iterate over the rest of the indices.
311+
# Now i can always be 0 (largest stride), and we don't need bounds check
312+
for flat_index in range(strides[0], len(G)):
313+
index = next(nd_index)
314+
315+
# pivot can be calculated without bounds check
316+
pivot = flat_index - strides[0]
317+
318+
# contribution from pivot
319+
value_at_index = y[0] * G[pivot]
320+
321+
# contribution from pivot's lower neighbours
322+
# note the first is when j=0 which needs a -1 in the sqrt from delta_0j
323+
value_at_index -= R[0, 0] * SQRT[index[0] - 1] * G[pivot - strides[0]]
324+
for j in range(1, D):
325+
value_at_index -= R[0, j] * SQRT[index[j]] * G[pivot - strides[j]]
326+
G[flat_index] = value_at_index * SQRT_INV[index[0]]
276327

328+
# reshape back to original shape
329+
return G.reshape(shape)
277330

278331
@jit(nopython=True)
279332
def _hermite_multidimensional(R, y, G): # pragma: no cover
@@ -338,7 +391,7 @@ def _interferometer_renorm(R, G): # pragma: no cover
338391
u = 0
339392
for l, kl in remove(ki):
340393
u -= SQRT[ki[l]] * R[i, l] * G[kl]
341-
G[idx] = u / SQRT[idx[i]]
394+
G[idx] = u * SQRT_INV[idx[i]]
342395

343396
return G
344397

@@ -448,8 +501,8 @@ def _grad_hermite_multidimensional_renorm(R, y, G, dG_dR, dG_dy): # pragma: no
448501
dy -= SQRT[ki[l]] * dG_dy[kl] * R[i, l]
449502
dR -= SQRT[ki[l]] * R[i, l] * dG_dR[kl]
450503
dR[i, l] -= SQRT[ki[l]] * G[kl]
451-
dG_dR[idx] = dR / SQRT[idx[i]]
452-
dG_dy[idx] = dy / SQRT[idx[i]]
504+
dG_dR[idx] = dR * SQRT_INV[idx[i]]
505+
dG_dy[idx] = dy * SQRT_INV[idx[i]]
453506
return dG_dR, dG_dy
454507

455508

0 commit comments

Comments
 (0)