Skip to content

Commit 6af1a6b

Browse files
committed
compiler: prevent undefined Temp through global init
1 parent 8ad5ec1 commit 6af1a6b

4 files changed

Lines changed: 34 additions & 3 deletions

File tree

devito/passes/clusters/aliases.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
Interval, IntervalGroup, IterationSpace, LabeledVector, Queue, Vector, extrema,
1313
maximum, minimum, normalize_properties, relax_properties, unbounded, vmax, vmin
1414
)
15+
from devito.ir.support import null_ispace
1516
from devito.passes.clusters.cse import _cse
1617
from devito.symbolics import (
1718
Uxmapper, estimate_cost, retrieve_functions, reuse_if_untouched, search, sympy_dtype,
@@ -860,6 +861,7 @@ def lower_schedule(schedule, meta, sregistry, opt_ftemps, opt_min_dtype,
860861
make = TempFunction if opt_ftemps else TempArray
861862

862863
clusters = []
864+
inits = []
863865
subs = {}
864866
for pivot, writeto, ispace, aliaseds, indicess in schedule:
865867
name = sregistry.make_name()
@@ -928,8 +930,11 @@ def lower_schedule(schedule, meta, sregistry, opt_ftemps, opt_min_dtype,
928930
assert writeto.size == 0
929931

930932
dtype = sympy_dtype(pivot, base=meta.dtype, smin=opt_min_dtype)
931-
obj = Temp(name=name, dtype=dtype)
933+
const = not meta.guards
934+
obj = Temp(name=name, dtype=dtype, is_const=const)
932935
expression = Eq(obj, uxreplace(pivot, subs))
936+
if not const:
937+
inits.append(Eq(obj, 0))
933938

934939
callback = lambda idx: obj # noqa: B023
935940

@@ -959,6 +964,13 @@ def lower_schedule(schedule, meta, sregistry, opt_ftemps, opt_min_dtype,
959964
# Finally, build the alias Cluster
960965
clusters.append(Cluster(expression, ispace, meta.guards, properties))
961966

967+
if inits:
968+
# To avoid undefined variables when an constant (Temp) alias is used
969+
# within different guards/loop, we need to initialize it outside of the loops
970+
# so that it's globally defined.
971+
# See tests/test_operators.py
972+
clusters.insert(0, Cluster(inits, null_ispace))
973+
962974
return clusters, subs
963975

964976

devito/types/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1589,7 +1589,7 @@ def _rebuild(self, *args, **kwargs):
15891589
comps = [f.func(*args, name=f.name.replace(self.name, newname), **kwargs)
15901590
for f in self.flat()]
15911591
# Rebuild the matrix with the new components
1592-
return self._new(comps)
1592+
return self._new(*self.shape, comps)
15931593

15941594
func = _rebuild
15951595

examples/mpi/overview.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,7 @@
460460
" _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);\n",
461461
" _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);\n",
462462
"\n",
463-
" float r0 = 1.0F/h_x;\n",
463+
" const float r0 = 1.0F/h_x;\n",
464464
"\n",
465465
" for (int time = time_m, t0 = (time)%(2), t1 = (time + 1)%(2); time <= time_M; time += 1, t0 = (time)%(2), t1 = (time + 1)%(2))\n",
466466
" {\n",

tests/test_dse.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2685,6 +2685,25 @@ def test_space_and_time_invariant_together(self):
26852685
'tx0_blk0y0_blk0xyzyz'
26862686
)
26872687

2688+
def test_split_cond(self):
2689+
grid = Grid((11, 11))
2690+
time = grid.time_dim
2691+
2692+
u = TimeFunction(name='u', grid=grid, time_order=2, space_order=2)
2693+
2694+
ct = ConditionalDimension(name='ct', parent=time, factor=2)
2695+
ct2 = ConditionalDimension(name='ct2', parent=time, factor=4)
2696+
2697+
eq0 = Eq(u.forward, u + cos(time), implicit_dims=ct)
2698+
eq1 = Eq(u.forward, u.forward + 1, implicit_dims=ct2)
2699+
eq2 = Eq(u.forward, u.forward + cos(time), implicit_dims=ct)
2700+
2701+
op = Operator([eq0, eq1, eq2])
2702+
cond = FindNodes(Conditional).visit(op)
2703+
assert len(cond) == 3
2704+
assert str(cond[0].args['then_body'][0].exprs[0]) == 'r0 = cos(time);'
2705+
assert str(op.body.body[0].body[0].body[0]) == 'float r0 = 0;'
2706+
26882707

26892708
class TestIsoAcoustic:
26902709

0 commit comments

Comments
 (0)