|
23 | 23 | """ |
24 | 24 |
|
25 | 25 |
|
26 | | -from typing import Any, Dict, Set, Tuple, Mapping |
| 26 | +from typing import Any, Dict, Set, Tuple, Mapping, Optional, TYPE_CHECKING |
| 27 | +from pytools import memoize_method |
| 28 | + |
27 | 29 | from pytato.array import SizeParam, Placeholder, make_placeholder, Axis as PtAxis |
28 | 30 | from pytato.array import Array, DataWrapper, DictOfNamedArrays |
29 | 31 | from pytato.transform import CopyMapper |
30 | 32 | from pytools import UniqueNameGenerator |
31 | 33 | from arraycontext.impl.pyopencl.taggable_cl_array import Axis as ClAxis |
| 34 | +from pytato.target.loopy import LoopyPyOpenCLTarget |
| 35 | + |
| 36 | +if TYPE_CHECKING: |
| 37 | + import loopy as lp |
32 | 38 |
|
33 | 39 |
|
34 | 40 | class _DatawrapperToBoundPlaceholderMapper(CopyMapper): |
@@ -91,3 +97,20 @@ def get_pt_axes_from_cl_axes(axes: Tuple[ClAxis, ...]) -> Tuple[PtAxis, ...]: |
91 | 97 |
|
92 | 98 | def get_cl_axes_from_pt_axes(axes: Tuple[PtAxis, ...]) -> Tuple[ClAxis, ...]: |
93 | 99 | return tuple(ClAxis(axis.tags) for axis in axes) |
| 100 | + |
| 101 | + |
| 102 | +# {{{ arg-size-limiting loopy target |
| 103 | + |
| 104 | +class ArgSizeLimitingPytatoLoopyPyOpenCLTarget(LoopyPyOpenCLTarget): |
| 105 | + def __init__(self, limit_arg_size_nbytes: int) -> None: |
| 106 | + super().__init__() |
| 107 | + self.limit_arg_size_nbytes = limit_arg_size_nbytes |
| 108 | + |
| 109 | + @memoize_method |
| 110 | + def get_loopy_target(self) -> Optional["lp.PyOpenCLTarget"]: |
| 111 | + from loopy import PyOpenCLTarget |
| 112 | + return PyOpenCLTarget(limit_arg_size_nbytes=self.limit_arg_size_nbytes) |
| 113 | + |
| 114 | +# }}} |
| 115 | + |
| 116 | +# vim: foldmethod=marker |
0 commit comments