@@ -244,10 +244,14 @@ def remove(
244244 yield p , dec (pattern , p )
245245
246246
247- SQRT = np .sqrt (np .arange (1000 )) # saving the time to recompute square roots
247+ # precompute the square roots of 0..999 once to avoid recomputing them
248+ SQRT = np .sqrt (np .arange (1000 ))
249+ _SQRT = np .sqrt (np .arange (1000 ))
250+ _SQRT [0 ] = 1.0 # avoid division by zero
251+ SQRT_INV = 1 / _SQRT
248252
249253
250- @jit (nopython = True )
254+ @jit (nopython = True , fastmath = True )
251255def _hermite_multidimensional_renorm (R , y , G ): # pragma: no cover
252256 r"""Numba-compiled function to fill an array with the Hermite polynomials. It expects an array
253257 initialized with zeros everywhere except at index (0,...,0) (i.e. the seed value).
@@ -260,20 +264,69 @@ def _hermite_multidimensional_renorm(R, y, G): # pragma: no cover
260264 Returns:
261265 array[complex]: the multidimensional Hermite polynomials
262266 """
263- indices = np .ndindex (G .shape )
264- next (indices ) # skip the first index (0,...,0)
265- for idx in indices :
267+ # numba doesn't like tuples
268+ shape_arr = np .array (G .shape )
269+ D = y .shape [- 1 ]
270+
271+ # calculate the strides (e.g. (100,10,1) for shape (10,10,10))
272+ strides = np .ones_like (shape_arr )
273+ for i in range (D - 1 , 0 , - 1 ):
274+ strides [i - 1 ] = strides [i ] * shape_arr [i ]
275+
276+ # flatten output tensor
277+ shape = G .shape
278+ G = G .ravel ()
279+
280+ # initialize the n-dim index
281+ nd_index = np .ndindex (shape )
282+
283+ # skip the first index (0,...,0), which is expected to be already filled with the seed value
284+ next (nd_index )
285+
286+ # Iterate over the indices smaller than max(strides) with pivot bound check.
287+ # The check is needed only if the flat index is smaller than the largest stride.
288+ # Afterwards it will be safe to get the pivot by subtracting the first (largest) stride.
289+ for flat_index in range (1 , strides [0 ]):
290+ index = next (nd_index )
291+
266292 i = 0
267- for i , val in enumerate (idx ):
268- if val > 0 :
293+ # calculate (flat) pivot
294+ for s in strides :
295+ pivot = flat_index - s
296+ if pivot >= 0 : # if pivot not outside array
269297 break
270- ki = dec (idx , i )
271- u = y [i ] * G [ki ]
272- for l , kl in remove (ki ):
273- u -= SQRT [ki [l ]] * R [i , l ] * G [kl ]
274- G [idx ] = u / SQRT [idx [i ]]
275- return G
298+ i += 1
299+
300+ # contribution from pivot
301+ value_at_index = y [i ] * G [pivot ]
302+
303+ # contributions from pivot's lower neighbours
304+ # note: the first term is j=i, where delta_ij lowers the sqrt argument by 1
305+ value_at_index -= R [i , i ] * SQRT [index [i ] - 1 ] * G [pivot - strides [i ]]
306+ for j in range (i + 1 , D ):
307+ value_at_index -= R [i , j ] * SQRT [index [j ]] * G [pivot - strides [j ]]
308+ G [flat_index ] = value_at_index * SQRT_INV [index [i ]]
309+
310+ # Iterate over the rest of the indices.
311+ # Now i can always be 0 (largest stride), and we don't need bounds check
312+ for flat_index in range (strides [0 ], len (G )):
313+ index = next (nd_index )
314+
315+ # pivot can be calculated without bounds check
316+ pivot = flat_index - strides [0 ]
317+
318+ # contribution from pivot
319+ value_at_index = y [0 ] * G [pivot ]
320+
321+ # contribution from pivot's lower neighbours
322+ # note: the first term is j=0, where delta_0j lowers the sqrt argument by 1
323+ value_at_index -= R [0 , 0 ] * SQRT [index [0 ] - 1 ] * G [pivot - strides [0 ]]
324+ for j in range (1 , D ):
325+ value_at_index -= R [0 , j ] * SQRT [index [j ]] * G [pivot - strides [j ]]
326+ G [flat_index ] = value_at_index * SQRT_INV [index [0 ]]
276327
328+ # reshape back to original shape
329+ return G .reshape (shape )
277330
278331@jit (nopython = True )
279332def _hermite_multidimensional (R , y , G ): # pragma: no cover
@@ -338,7 +391,7 @@ def _interferometer_renorm(R, G): # pragma: no cover
338391 u = 0
339392 for l , kl in remove (ki ):
340393 u -= SQRT [ki [l ]] * R [i , l ] * G [kl ]
341- G [idx ] = u / SQRT [idx [i ]]
394+ G [idx ] = u * SQRT_INV [idx [i ]]
342395
343396 return G
344397
@@ -448,8 +501,8 @@ def _grad_hermite_multidimensional_renorm(R, y, G, dG_dR, dG_dy): # pragma: no
448501 dy -= SQRT [ki [l ]] * dG_dy [kl ] * R [i , l ]
449502 dR -= SQRT [ki [l ]] * R [i , l ] * dG_dR [kl ]
450503 dR [i , l ] -= SQRT [ki [l ]] * G [kl ]
451- dG_dR [idx ] = dR / SQRT [idx [i ]]
452- dG_dy [idx ] = dy / SQRT [idx [i ]]
504+ dG_dR [idx ] = dR * SQRT_INV [idx [i ]]
505+ dG_dy [idx ] = dy * SQRT_INV [idx [i ]]
453506 return dG_dR , dG_dy
454507
455508
0 commit comments