diff --git a/loopy/expression.py b/loopy/expression.py index 5a11b8354..a00e4c7f0 100644 --- a/loopy/expression.py +++ b/loopy/expression.py @@ -162,9 +162,7 @@ def map_constant(self, expr: object) -> bool: def map_variable(self, expr: p.Variable) -> bool: if expr.name == self.vec_iname: - # Technically, this is doable. But we're not going there. - raise UnvectorizableError() - + return True # A single variable is always a scalar. return False diff --git a/loopy/symbolic.py b/loopy/symbolic.py index 20ff55fea..2bdd8db70 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -606,6 +606,21 @@ class Literal(LoopyExpressionBase): s: str +@p.expr_dataclass() +class TypedLiteral(Literal): + """A literal to be used during code generation which we know the type of. + + .. note:: + + Only used in the output of + :mod:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` (and + similar mappers). Not for use in Loopy source representation. + """ + + s: str + dtype: ToLoopyTypeConvertible + + @p.expr_dataclass() class ArrayLiteral(LoopyExpressionBase): """An array literal. 
diff --git a/loopy/target/c/codegen/expression.py b/loopy/target/c/codegen/expression.py index 83c13dfe5..6b52256e9 100644 --- a/loopy/target/c/codegen/expression.py +++ b/loopy/target/c/codegen/expression.py @@ -49,7 +49,7 @@ from loopy.expression import dtype_to_type_context from loopy.target.c import CExpression from loopy.type_inference import TypeInferenceMapper, TypeReader -from loopy.types import LoopyType +from loopy.types import LoopyType, to_loopy_type from loopy.typing import Expression, is_integer @@ -435,7 +435,7 @@ def map_type_cast(self, expr: TypeCast, type_context: str): return self.rec(expr.child, type_context, expr.type) def map_constant(self, expr, type_context): - from loopy.symbolic import Literal + from loopy.symbolic import TypedLiteral if isinstance(expr, (complex, np.complexfloating)): real = self.rec(expr.real, type_context) @@ -462,10 +462,10 @@ def map_constant(self, expr, type_context): # FIXME: This assumes a 32-bit architecture. if isinstance(expr, np.float32): - return Literal(repr(float(expr))+"f") + return TypedLiteral(repr(float(expr))+"f", to_loopy_type(np.float32)) elif isinstance(expr, np.float64): - return Literal(repr(float(expr))) + return TypedLiteral(repr(float(expr)), to_loopy_type(np.float64)) # Disabled for now, possibly should be a subtarget. 
# elif isinstance(expr, np.float128): @@ -478,18 +478,19 @@ def map_constant(self, expr, type_context): suffix += "u" if iinfo.max > (2**31-1): suffix += "l" - return Literal(repr(int(expr))+suffix) + return TypedLiteral(repr(int(expr))+suffix, to_loopy_type(iinfo.dtype)) elif isinstance(expr, np.bool_): - return Literal("true") if expr else Literal("false") + return TypedLiteral("true", to_loopy_type(np.bool_)) if expr \ + else TypedLiteral("false", to_loopy_type(np.bool_)) else: raise LoopyError("do not know how to generate code for " "constant of numpy type '%s'" % type(expr).__name__) elif np.isfinite(expr): if type_context == "f": - return Literal(repr(float(expr))+"f") + return TypedLiteral(repr(float(expr))+"f", to_loopy_type(np.float32)) elif type_context == "d": - return Literal(repr(float(expr))) + return TypedLiteral(repr(float(expr)), to_loopy_type(np.float64)) elif type_context in ["i", "b"]: return int(expr) else: diff --git a/loopy/target/ispc.py b/loopy/target/ispc.py index 34a88328c..bcee0e905 100644 --- a/loopy/target/ispc.py +++ b/loopy/target/ispc.py @@ -44,13 +44,14 @@ CoefficientCollector, CombineMapper, GroupHardwareAxisIndex, - Literal, LocalHardwareAxisIndex, SubstitutionMapper, + TypedLiteral, flatten, ) from loopy.target.c import CFamilyASTBuilder, CFamilyTarget from loopy.target.c.codegen.expression import ExpressionToCExpressionMapper +from loopy.types import to_loopy_type if TYPE_CHECKING: @@ -125,10 +126,10 @@ def map_constant(self, expr, type_context): raise NotImplementedError("complex numbers in ispc") else: if type_context == "f": - return Literal(repr(float(expr))) + return TypedLiteral(repr(float(expr)), to_loopy_type(np.float32)) elif type_context == "d": # Keepin' the good ideas flowin' since '66. 
- return Literal(repr(float(expr))+"d") + return TypedLiteral(repr(float(expr))+"d", to_loopy_type(np.float64)) elif type_context in ["i", "b"]: return expr else: diff --git a/loopy/target/opencl.py b/loopy/target/opencl.py index 07c5b49d0..524ed7e6e 100644 --- a/loopy/target/opencl.py +++ b/loopy/target/opencl.py @@ -24,6 +24,7 @@ THE SOFTWARE. """ +from contextlib import suppress from typing import TYPE_CHECKING, Literal, Sequence import numpy as np @@ -46,6 +47,7 @@ from loopy.codegen import CodeGenerationState from loopy.codegen.result import CodeGenerationResult + from loopy.kernel import LoopKernel # {{{ dtype registry wrappers @@ -456,7 +458,8 @@ def get_opencl_callables(): # {{{ symbol mangler -def opencl_symbol_mangler(kernel, name): +def opencl_symbol_mangler(kernel: LoopKernel, + name: str) -> tuple[NumpyType, str] | None: # FIXME: should be more picky about exact names if name.startswith("FLT_"): return NumpyType(np.dtype(np.float32)), name @@ -540,11 +543,32 @@ def opencl_preamble_generator(preamble_info): class ExpressionToOpenCLCExpressionMapper(ExpressionToCExpressionMapper): def wrap_in_typecast(self, actual_type, needed_dtype, s): + if needed_dtype.dtype.kind == "b" and actual_type.dtype.kind == "f": # CL does not perform implicit conversion from float-type to a bool. from pymbolic.primitives import Comparison return Comparison(s, "!=", 0) + if needed_dtype == actual_type: + return s + + registry = self.codegen_state.ast_builder.target.get_dtype_registry() + if self.codegen_state.target.is_vector_dtype(needed_dtype): + # OpenCL does not let you do explicit vector type casts between vector + # types. Instead you need to call their function which is of the form + # convert_destTypen(src) where n + # is the number of elements in the vector which is the same as in src. + # https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_C.html#explicit-casts + + # We infer the data type of (s) before we recurse down into (s) to convert + # to a CExpression. 
With vectorization, we can change the actual type of (s) + # from a scalar type to a vector type. So we are going to recompute the + # actual type. + type_of_s = self.infer_type(s) + if self.codegen_state.target.is_vector_dtype(type_of_s): + cast = var("convert_%s" % registry.dtype_to_ctype(needed_dtype)) + return cast(s) + return super().wrap_in_typecast(actual_type, needed_dtype, s) def map_group_hw_index(self, expr, type_context): @@ -553,6 +577,74 @@ def map_group_hw_index(self, expr, type_context): def map_local_hw_index(self, expr, type_context): return var("lid")(expr.axis) + def map_variable(self, expr, type_context): + + if self.codegen_state.vectorization_info: + if self.codegen_state.vectorization_info.iname == expr.name: + # This needs to be converted into a vector literal. + from loopy.symbolic import TypedLiteral + vector_length = self.codegen_state.vectorization_info.length + index_type = self.codegen_state.kernel.index_dtype + vector_type = self.codegen_state.target.vector_dtype(index_type, + vector_length) + typename = self.codegen_state.target.dtype_to_typename(vector_type) + vector_literal = f"(({typename})" + " (" + \ + ",".join([f"{i}" for i in range(vector_length)]) + "))" + return TypedLiteral(vector_literal, vector_type) + + # return Literal(vector_literal) + return super().map_variable(expr, type_context) + + def map_if(self, expr, type_context): + from loopy.types import to_loopy_type + result_type = self.infer_type(expr) + conditional_needed_loopy_type = to_loopy_type(np.bool_) + if self.codegen_state.vectorization_info: + from loopy.codegen import UnvectorizableError + from loopy.expression import VectorizabilityChecker + checker = VectorizabilityChecker(self.codegen_state.kernel, + self.codegen_state.vectorization_info.iname, + self.codegen_state.vectorization_info.length) + + with suppress(UnvectorizableError): + # We know there is an expression in codegen which can be vectorized. + # We are checking if this is one of them. 
If it is not, then we can + # just continue with scalar code generation for this expression. + is_vector = checker(expr) + + if is_vector: + """ + We could have a vector literal here which may need to be + converted to an appropriate size. The OpenCL specification states + that for ( c ? a : b) a, b, and c must have the same + number of elements and bits and that c must be an integral type. + https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_C.html#table-builtin-relational + """ + index_type = to_loopy_type(self.codegen_state.kernel.index_dtype) + types = {8: to_loopy_type(np.int64), 4: to_loopy_type(np.int32), + 2: to_loopy_type(np.int16), 1: to_loopy_type(np.int8)} + length = self.codegen_state.vectorization_info.length + if self.codegen_state.target.is_vector_dtype(result_type): + if (index_type.itemsize != result_type.itemsize and + (result_type.itemsize // length) in types): + index_type = types[result_type.itemsize] + else: + raise ValueError("Types incompatible") + else: + # We know result is going to be a vector. 
+ if (index_type.itemsize != result_type.itemsize and + result_type.itemsize in types): + index_type = types[result_type.itemsize] + vector_type = self.codegen_state.target.vector_dtype(index_type, + length) + conditional_needed_loopy_type = to_loopy_type(vector_type) + + return type(expr)( + self.rec(expr.condition, type_context, + conditional_needed_loopy_type), + self.rec(expr.then, type_context, result_type), + self.rec(expr.else_, type_context, result_type), + ) # }}} diff --git a/loopy/type_inference.py b/loopy/type_inference.py index 8894af573..36ce49460 100644 --- a/loopy/type_inference.py +++ b/loopy/type_inference.py @@ -47,6 +47,7 @@ SubArrayRef, SubstitutionRuleExpander, SubstitutionRuleMappingContext, + TypedLiteral, parse_tagged_name, ) from loopy.translation_unit import ( @@ -365,6 +366,9 @@ def map_quotient(self, expr): else: return self.combine([n_dtype_set, d_dtype_set]) + def map_typed_literal(self, expr: TypedLiteral): + return [expr.dtype] + def map_constant(self, expr): if isinstance(expr, np.generic): return [NumpyType(np.dtype(type(expr)))] @@ -540,19 +544,40 @@ def map_lookup(self, expr): dtype = field[0] return [NumpyType(dtype)] + def is_vector_dtype(self, dtype): + target = self.kernel.target + + return target.is_vector_dtype(dtype) + def map_comparison(self, expr): - self(expr.left, return_tuple=False, return_dtype_set=False) - self(expr.right, return_tuple=False, return_dtype_set=False) + left = self(expr.left, return_tuple=False, return_dtype_set=False) + right = self(expr.right, return_tuple=False, return_dtype_set=False) + # We need to return a vector type if either of the sides is a vector. 
+ + vector_output = [] + for dtype in (left, right): + if self.is_vector_dtype(dtype): + vector_output.append(dtype) + if vector_output: + return vector_output return [NumpyType(np.dtype(np.bool_))] def map_logical_not(self, expr): - self.rec(expr.child) + child = self.rec(expr.child) + if self.is_vector_dtype(child): + return child return [NumpyType(np.dtype(np.bool_))] def map_logical_and(self, expr): + output_type = [] for child in expr.children: - self.rec(child) + type_to_check = self.rec(child) + if self.is_vector_dtype(type_to_check): + output_type.append(type_to_check) + + if output_type: + return output_type return [NumpyType(np.dtype(np.bool_))] diff --git a/test/test_target.py b/test/test_target.py index fe2ad1d8a..76a38518d 100644 --- a/test/test_target.py +++ b/test/test_target.py @@ -875,6 +875,35 @@ def test_float3(): assert "float3" in device_code +def test_cl_vectorize_index_variable(ctx_factory): + knl = lp.make_kernel( + "{ [i]: 0<=i0") + + rng = np.random.default_rng(seed=12) + a = rng.normal(size=(16, 4)) + ctx = ctx_factory() + queue = cl.CommandQueue(ctx) + knl = lp.add_and_infer_dtypes(knl, {"a": np.float64, "n": np.int64}) + _evt, (result,) = knl(queue, a=a, n=a.size) + + i = np.arange(16) + j = np.arange(4) + ind = 4*i[:, None] + j + result_ref = np.where(ind < 32, a*3, np.sin(a)) + + assert np.allclose(result, result_ref) + + if __name__ == "__main__": import sys if len(sys.argv) > 1: