Skip to content

Commit 167ae7c

Browse files
committed
Move ArgSizeLimitingPytatoLoopyPyOpenCLTarget to pytato.utils, remove hard pytato dep
1 parent dc81090 commit 167ae7c

2 files changed

Lines changed: 27 additions & 17 deletions

File tree

arraycontext/impl/pytato/__init__.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@
5959
if TYPE_CHECKING:
6060
import pytato
6161
import pyopencl as cl
62-
import loopy as lp
6362

6463
if getattr(sys, "_BUILDING_SPHINX_DOCS", False):
6564
import pyopencl as cl # noqa: F811
@@ -219,20 +218,6 @@ def get_target(self):
219218

220219
# {{{ PytatoPyOpenCLArrayContext
221220

222-
from pytato.target.loopy import LoopyPyOpenCLTarget
223-
224-
225-
class _ArgSizeLimitingPytatoLoopyPyOpenCLTarget(LoopyPyOpenCLTarget):
226-
def __init__(self, limit_arg_size_nbytes: int) -> None:
227-
super().__init__()
228-
self.limit_arg_size_nbytes = limit_arg_size_nbytes
229-
230-
@memoize_method
231-
def get_loopy_target(self) -> Optional["lp.PyOpenCLTarget"]:
232-
from loopy import PyOpenCLTarget
233-
return PyOpenCLTarget(limit_arg_size_nbytes=self.limit_arg_size_nbytes)
234-
235-
236221
class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
237222
"""
238223
A :class:`ArrayContext` that uses :mod:`pytato` data types to represent
@@ -408,7 +393,9 @@ def get_target(self):
408393

409394
logger.info(f"limiting argument buffer size for {dev} to {limit} bytes")
410395

411-
return _ArgSizeLimitingPytatoLoopyPyOpenCLTarget(limit)
396+
from arraycontext.impl.pytato.utils import \
397+
ArgSizeLimitingPytatoLoopyPyOpenCLTarget
398+
return ArgSizeLimitingPytatoLoopyPyOpenCLTarget(limit)
412399
else:
413400
return super().get_target()
414401

arraycontext/impl/pytato/utils.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,18 @@
2323
"""
2424

2525

26-
from typing import Any, Dict, Set, Tuple, Mapping
26+
from typing import Any, Dict, Set, Tuple, Mapping, Optional, TYPE_CHECKING
27+
from pytools import memoize_method
28+
2729
from pytato.array import SizeParam, Placeholder, make_placeholder, Axis as PtAxis
2830
from pytato.array import Array, DataWrapper, DictOfNamedArrays
2931
from pytato.transform import CopyMapper
3032
from pytools import UniqueNameGenerator
3133
from arraycontext.impl.pyopencl.taggable_cl_array import Axis as ClAxis
34+
from pytato.target.loopy import LoopyPyOpenCLTarget
35+
36+
if TYPE_CHECKING:
37+
import loopy as lp
3238

3339

3440
class _DatawrapperToBoundPlaceholderMapper(CopyMapper):
@@ -91,3 +97,20 @@ def get_pt_axes_from_cl_axes(axes: Tuple[ClAxis, ...]) -> Tuple[PtAxis, ...]:
9197

9298
def get_cl_axes_from_pt_axes(axes: Tuple[PtAxis, ...]) -> Tuple[ClAxis, ...]:
9399
return tuple(ClAxis(axis.tags) for axis in axes)
100+
101+
102+
# {{{ arg-size-limiting loopy target
103+
104+
class ArgSizeLimitingPytatoLoopyPyOpenCLTarget(LoopyPyOpenCLTarget):
105+
def __init__(self, limit_arg_size_nbytes: int) -> None:
106+
super().__init__()
107+
self.limit_arg_size_nbytes = limit_arg_size_nbytes
108+
109+
@memoize_method
110+
def get_loopy_target(self) -> Optional["lp.PyOpenCLTarget"]:
111+
from loopy import PyOpenCLTarget
112+
return PyOpenCLTarget(limit_arg_size_nbytes=self.limit_arg_size_nbytes)
113+
114+
# }}}
115+
116+
# vim: foldmethod=marker

0 commit comments

Comments
 (0)