Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions ext/CUDAExt/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,6 @@ function check_sz(arr, maxshape)
end
end

# Wrap the handle returned by `cuNumeric.get_store(arr)` in a typed
# `Legate.LogicalArray` carrying the same element type, rank and size as `arr`.
function nda_to_logical_array(arr::NDArray{T,N}) where {T,N}
    handle = cuNumeric.get_store(arr)
    dims = size(arr)
    return Legate.LogicalArray{T,N}(handle, dims)
end

function Launch(kernel::cuNumeric.CUDATask, inputs::Tuple{Vararg{NDArray}},
outputs::Tuple{Vararg{NDArray}}, scalars::Tuple{Vararg{Any}}; blocks, threads)

Expand Down
1 change: 1 addition & 0 deletions lib/cunumeric_jl_wrapper/src/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,4 +166,5 @@ void wrap_binary_ops(jlcxx::Module& mod) {
// Register the linear-algebra task IDs as constants on the Julia module.
// Each CuPyNumericOpCode entry is wrapped in a legate::LocalTaskID and stored
// under the given name (SOLVE, MP_SOLVE, SVD).
void wrap_linalg_ops(jlcxx::Module& mod) {
mod.set_const("SOLVE", legate::LocalTaskID{CuPyNumericOpCode::CUPYNUMERIC_SOLVE});
mod.set_const("MP_SOLVE", legate::LocalTaskID{CuPyNumericOpCode::CUPYNUMERIC_MP_SOLVE});
mod.set_const("SVD", legate::LocalTaskID{CuPyNumericOpCode::CUPYNUMERIC_SVD});
}
9 changes: 7 additions & 2 deletions src/cuNumeric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,13 @@ const SUPPORTED_NUMERIC_TYPES = Union{
SUPPORTED_INT_TYPES,SUPPORTED_FLOAT_TYPES,SUPPORTED_COMPLEX_TYPES
}

# solve has no integer backend kernel
const SUPPORTED_SOLVE_TYPES = Union{SUPPORTED_FLOAT_TYPES,SUPPORTED_COMPLEX_TYPES}
# Element types grouped for linear-algebra routines: the integer types,
# Float32/Float64 and the complex types (Float16 is not part of this set).
const SUPPORTED_LINALG_TYPES = Union{
SUPPORTED_INT_TYPES,Float32,Float64,SUPPORTED_COMPLEX_TYPES
}

# solve has no integer/Float16 backend kernel — float/complex only.
const SUPPORTED_SOLVE_TYPES = Union{Float32,Float64,SUPPORTED_COMPLEX_TYPES}
# svd is declared with the same element-type set as solve.
const SUPPORTED_SVD_TYPES = Union{Float32,Float64,SUPPORTED_COMPLEX_TYPES}
# Array element types: every numeric type plus Bool.
const SUPPORTED_ARRAY_TYPES = Union{Bool,SUPPORTED_NUMERIC_TYPES}
# All supported types: the array element types plus String.
const SUPPORTED_TYPES = Union{SUPPORTED_ARRAY_TYPES,String}

Expand Down Expand Up @@ -148,6 +152,7 @@ include("ndarray/broadcast.jl")
include("ndarray/ndarray.jl")
include("ndarray/unary.jl")
include("ndarray/binary.jl")
include("ndarray/detail/linalg.jl")
include("ndarray/linalg.jl")

# scoping macro
Expand Down
87 changes: 87 additions & 0 deletions src/ndarray/detail/linalg.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Choose how many tiles ("colors") to use along each dimension of `shape`.
#
# For N <= 2 every dimension gets a single color. For N > 2 all
# `Legate.num_procs()` colors start on dimension 1 and are then rebalanced by
# factors of two towards whichever of the first N-2 dimensions currently has
# the most elements per color. The last two dimensions always keep one color.
#
# Returns an `NTuple{N,Int}`.
function choose_nd_color_shape(shape::NTuple{N,Int}) where {N}
    color_shape = Base.ones(Int, N)
    if N > 2
        color_shape[1] = Legate.num_procs()
        # Only powers of two are moved around; stop as soon as the count on
        # dimension 1 is odd or no dimension is more than twice as loaded.
        while iseven(color_shape[1])
            weight_per_dim = [shape[i] / color_shape[i] for i in 1:(N - 2)]
            max_weight, idx = findmax(weight_per_dim)
            max_weight > 2 * weight_per_dim[1] || break
            color_shape[1] ÷= 2
            color_shape[idx] *= 2
        end
    end
    # `ntuple` with `Val(N)` keeps the tuple length inferable, unlike
    # `Tuple(::Vector)`.
    return ntuple(i -> color_shape[i], Val(N))
end

# Compute the tiling used to launch a manual task over an array of shape
# `full_shape`.
#
# Returns `(tilesize, color_shape)`: `tilesize[i]` is the extent of one tile
# along dimension `i` and `color_shape[i]` is the number of such tiles needed
# to cover `full_shape[i]`.
#
# A zero-length dimension gives a zero tile size and therefore a `DivideError`
# in the second division, so empty arrays must be handled by the caller.
function prepare_manual_task_for_batched_matrices(full_shape::NTuple{N,Int}) where {N}
    initial_color_shape = choose_nd_color_shape(full_shape)
    # Ceiling division: smallest tile that covers each dimension with at most
    # the initially requested number of colors.
    tilesize = map(cld, full_shape, initial_color_shape)
    # Rounding the tile size up can leave fewer tiles than initially requested,
    # so the color count is recomputed from the final tile size.
    color_shape = map(cld, full_shape, tilesize)
    return tilesize, color_shape
end

# Launch the `cuNumeric.SOLVE` task manually over a tiling of the operands,
# with `a` and `b` registered as inputs and `x` as the output.
#
# `b` and `x` share one tile shape: the tile shape of `a` with its last extent
# replaced by the last extent of `b`.
function solve_batched(a::NDArray{T,N}, b::NDArray, x::NDArray) where {T,N}
    rhs_cols = last(size(b))
    a_tile, colors = prepare_manual_task_for_batched_matrices(size(a))
    rhs_tile = (a_tile[1:(end - 1)]..., rhs_cols)

    lstore_a, lstore_b, lstore_x = map(nda_to_logical_store, (a, b, x))

    part_a = Legate.partition_by_tiling(lstore_a, collect(a_tile))
    part_b = Legate.partition_by_tiling(lstore_b, collect(rhs_tile))
    part_x = Legate.partition_by_tiling(lstore_x, collect(rhs_tile))

    runtime = Legate.get_runtime()
    # One point task per color of the tiling.
    launch_domain = Legate.domain_from_shape(Legate.Shape(Legate.to_cxx_vector(colors)))
    task = Legate.create_manual_task(
        runtime, cuNumeric.get_lib(), cuNumeric.SOLVE, launch_domain
    )

    for part in (part_a, part_b)
        Legate.add_input(task, part)
    end
    Legate.add_output(task, part_x)

    return Legate.submit_manual_task(runtime, task)
end

# Submit a single `cuNumeric.SVD` auto task with `a` as input and `u`, `s`,
# `vh` as outputs.
function svd_single(a::NDArray{T,N}, u::NDArray, s::NDArray, vh::NDArray) where {T,N}
    runtime = Legate.get_runtime()
    task = Legate.create_auto_task(runtime, cuNumeric.get_lib(), cuNumeric.SVD)

    input = nda_to_logical_array(a)
    outputs = map(nda_to_logical_array, (u, s, vh))

    Legate.add_input(task, input)
    for out in outputs
        Legate.add_output(task, out)
    end

    # Every operand gets a broadcast constraint.
    # NOTE(review): presumably this keeps each array whole on one task; confirm
    # against the Legate constraint documentation.
    for operand in (input, outputs...)
        Legate.add_broadcast(task, operand)
    end

    return Legate.submit_auto_task(runtime, task)
end

# Allocate the outputs for the SVD of the (m, n) matrix `a` and run the task.
#
# With k = min(m, n), returns `(u, s, vh)` where
#   - `u`  is (m, m) if `full_matrices`, else (m, k),
#   - `s`  is a length-k vector,
#   - `vh` is (n, n) if `full_matrices`, else (k, n).
function _svd(a::NDArray{T,2}, full_matrices::Bool) where {T}
    m, n = size(a)
    k = min(m, n)
    u = zeros(T, m, full_matrices ? m : k)
    # Singular values are real even for complex input, so `s` uses the real
    # counterpart of `T` (for real `T` this is `T` itself).
    # NOTE(review): assumes the backend SVD task writes real-typed singular
    # values, as NumPy/cuPyNumeric do; confirm against the backend.
    s = zeros(Base.real(T), k)
    vh = zeros(T, full_matrices ? n : k, n)
    svd_single(a, u, s, vh)
    return u, s, vh
end
6 changes: 6 additions & 0 deletions src/ndarray/detail/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -511,3 +511,9 @@ function nda_to_logical_store(arr::NDArray{T,N}) where {T,N}
st_handle = Legate.data(Legate.LogicalArray{T,N}(la_handle, size(arr)))
return Legate.LogicalStore{T,N}(st_handle, size(arr))
end

# Wrap the handle returned by `cuNumeric.get_store(arr)` in a typed
# `Legate.LogicalArray` carrying the same element type, rank and size as `arr`.
function nda_to_logical_array(arr::NDArray{T,N}) where {T,N}
    handle = cuNumeric.get_store(arr)
    dims = size(arr)
    return Legate.LogicalArray{T,N}(handle, dims)
end

147 changes: 57 additions & 90 deletions src/ndarray/linalg.jl
Original file line number Diff line number Diff line change
@@ -1,102 +1,54 @@
# Choose how many tiles ("colors") to use along each dimension of `shape`.
#
# For N <= 2 every dimension gets a single color. For N > 2 all
# `Legate.num_procs()` colors start on dimension 1 and are then rebalanced by
# factors of two towards whichever of the first N-2 dimensions currently has
# the most elements per color. The last two dimensions always keep one color.
#
# Returns an `NTuple{N,Int}`.
function choose_nd_color_shape(shape::NTuple{N,Int}) where {N}
    color_shape = Base.ones(Int, N)
    if N > 2
        color_shape[1] = Legate.num_procs()
        # Only powers of two are moved around; stop as soon as the count on
        # dimension 1 is odd or no dimension is more than twice as loaded.
        while iseven(color_shape[1])
            weight_per_dim = [shape[i] / color_shape[i] for i in 1:(N - 2)]
            max_weight, idx = findmax(weight_per_dim)
            max_weight > 2 * weight_per_dim[1] || break
            color_shape[1] ÷= 2
            color_shape[idx] *= 2
        end
    end
    # `ntuple` with `Val(N)` keeps the tuple length inferable, unlike
    # `Tuple(::Vector)`.
    return ntuple(i -> color_shape[i], Val(N))
end

# Compute the tiling used to launch a manual task over an array of shape
# `full_shape`.
#
# Returns `(tilesize, color_shape)`: `tilesize[i]` is the extent of one tile
# along dimension `i` and `color_shape[i]` is the number of such tiles needed
# to cover `full_shape[i]`.
#
# A zero-length dimension gives a zero tile size and therefore a `DivideError`
# in the second division, so empty arrays must be handled by the caller.
function prepare_manual_task_for_batched_matrices(full_shape::NTuple{N,Int}) where {N}
    initial_color_shape = choose_nd_color_shape(full_shape)
    # Ceiling division: smallest tile that covers each dimension with at most
    # the initially requested number of colors.
    tilesize = map(cld, full_shape, initial_color_shape)
    # Rounding the tile size up can leave fewer tiles than initially requested,
    # so the color count is recomputed from the final tile size.
    color_shape = map(cld, full_shape, tilesize)
    return tilesize, color_shape
end

# Launch the `cuNumeric.SOLVE` task manually over a tiling of the operands,
# with `a` and `b` registered as inputs and `x` as the output.
#
# `b` and `x` share one tile shape: the tile shape of `a` with its last extent
# replaced by the last extent of `b`.
function solve_batched(a::NDArray{T,N}, b::NDArray, x::NDArray) where {T,N}
    rhs_cols = last(size(b))
    a_tile, colors = prepare_manual_task_for_batched_matrices(size(a))
    rhs_tile = (a_tile[1:(end - 1)]..., rhs_cols)

    lstore_a, lstore_b, lstore_x = map(nda_to_logical_store, (a, b, x))

    part_a = Legate.partition_by_tiling(lstore_a, collect(a_tile))
    part_b = Legate.partition_by_tiling(lstore_b, collect(rhs_tile))
    part_x = Legate.partition_by_tiling(lstore_x, collect(rhs_tile))

    runtime = Legate.get_runtime()
    # One point task per color of the tiling.
    launch_domain = Legate.domain_from_shape(Legate.Shape(Legate.to_cxx_vector(colors)))
    task = Legate.create_manual_task(
        runtime, cuNumeric.get_lib(), cuNumeric.SOLVE, launch_domain
    )

    for part in (part_a, part_b)
        Legate.add_input(task, part)
    end
    Legate.add_output(task, part_x)

    return Legate.submit_manual_task(runtime, task)
end

# solve runs in floating point: integer and Bool inputs are promoted to
# Float64 (matching cupynumeric), while element types already in
# SUPPORTED_SOLVE_TYPES are used unchanged.
const _SOLVE_PROMOTABLE = Union{SUPPORTED_INT_TYPES,Bool}
const _SOLVE_ACCEPTED = Union{SUPPORTED_SOLVE_TYPES,_SOLVE_PROMOTABLE}

# Map an accepted input element type to the element type the solve runs in.
function _solve_eltype(::Type{T}) where {T<:_SOLVE_PROMOTABLE}
    return Float64
end
function _solve_eltype(::Type{T}) where {T<:SUPPORTED_SOLVE_TYPES}
    return T
end

# Entry point for accepted element types. Type and dimension guards dispatch
# on one argument at a time and finally forward to `_solve`.
function solve(a::NDArray{<:_SOLVE_ACCEPTED}, b::NDArray{<:_SOLVE_ACCEPTED})
    Ta = eltype(a)
    Tb = eltype(b)
    Tout = promote_type(_solve_eltype(Ta), _solve_eltype(Tb))

    # int/bool -> float is an implicit promotion, disallowed unless
    # `allowpromotion`.
    if Ta <: _SOLVE_PROMOTABLE
        assertpromotion(solve, Ta, Tout)
    end
    if Tb <: _SOLVE_PROMOTABLE
        assertpromotion(solve, Tb, Tout)
    end

    promoted_a = unchecked_promote_arr(a, Tout)
    promoted_b = unchecked_promote_arr(b, Tout)
    return _solve_check_a_dims(promoted_a, promoted_b)
end

function solve(a::NDArray, b::NDArray)
bad = eltype(a) <: _SOLVE_ACCEPTED ? eltype(b) : eltype(a)
throw(ArgumentError("array type $bad is unsupported in solve"))
# Dimension guards
# `a` must be at least two-dimensional: a 1-D `a` is rejected whatever `b` is.
function solve(a::NDArray{T,1}, b::NDArray{S,M}) where {T,S,M}
    throw(ArgumentError("1-dimensional array given. Array must be at least two-dimensional"))
end

# Disambiguation: a call with 1-D `a` and 0-D `b` matches both the guard above
# and the 0-D `b` guard, neither of which is more specific, so without this
# method it fails with an ambiguity MethodError instead of an ArgumentError.
function solve(a::NDArray{T,1}, b::NDArray{S,0}) where {T,S}
    throw(ArgumentError("1-dimensional array given. Array must be at least two-dimensional"))
end

# `a` must be at least 2D, `b` at least 1D.
function _solve_check_a_dims(a::NDArray{<:Any,0}, b::NDArray)
# `a` must be at least two-dimensional: a 0-D `a` is rejected whatever `b` is.
function solve(a::NDArray{T,0}, b::NDArray{S,M}) where {T,S,M}
    throw(ArgumentError("0-dimensional array given. Array must be at least two-dimensional"))
end

# Disambiguation: a call with 0-D `a` and 0-D `b` matches both the guard above
# and the 0-D `b` guard, neither of which is more specific, so without this
# method it fails with an ambiguity MethodError instead of an ArgumentError.
function solve(a::NDArray{T,0}, b::NDArray{S,0}) where {T,S}
    throw(ArgumentError("0-dimensional array given. Array must be at least two-dimensional"))
end
# A 1-D `a` is rejected.
_solve_check_a_dims(a::NDArray{<:Any,1}, b::NDArray) = throw(
    ArgumentError("1-dimensional array given. Array must be at least two-dimensional")
)

# Fallback: `a` passed its dimension checks, continue with the checks on `b`.
function _solve_check_a_dims(a::NDArray, b::NDArray)
    return _solve_check_b_dims(a, b)
end

function _solve_check_b_dims(a::NDArray, b::NDArray{<:Any,0})
# `b` must be at least one-dimensional: a 0-D `b` is rejected.
solve(a::NDArray{T,N}, b::NDArray{S,0}) where {T,N,S} = throw(
    ArgumentError("0-dimensional array given. Array must be at least one-dimensional")
)
# Fallback: `b` passed its dimension check, hand both operands to `_solve`.
function _solve_check_b_dims(a::NDArray, b::NDArray)
    return _solve(a, b)
end

# 2D case: (m,m),(m)->(m).
# Backend needs rhs "b" to be 2D. We reshape b from (n,) to (n,1)
function _solve(a::NDArray{T,2}, b::NDArray{S,1}) where {T,S}
m = size(b)[1]
return reshape(_solve(a, reshape(b, (m, 1))), (m,))
# 2D case: (m,m),(m)->(m)
#
# Throws `ArgumentError` if `a` is not square or if its size does not match
# the length of `b`. Empty input returns a zero-filled array shaped like `b`
# without launching a task.
function solve(a::NDArray{T,2}, b::NDArray{S,1}) where {T,S}
    size(a)[end - 1] != size(a)[end] &&
        throw(ArgumentError("Last 2 dimensions of the array must be square"))
    size(a)[2] != size(b)[1] &&
        throw(
            ArgumentError(
                "Input operand 1 has a mismatch in its dimension 0, " *
                "with signature (m,m),(m)->(m) (size $(size(b)[1]) " *
                "is different from $(size(a)[2]))",
            ),
        )
    # Parenthesised on purpose: `&&` binds tighter than `||`, so without the
    # parentheses an empty `a` would skip the early return.
    (prod(size(a)) == 0 || prod(size(b)) == 0) && return zeros(T, size(b)...)

    # `solve_batched` tiles the right-hand side with the same rank as `a`, so
    # the vector is viewed as an (m, 1) matrix and the result flattened again.
    m = size(b)[1]
    x = zeros(T, m, 1)
    solve_batched(a, reshape(b, (m, 1)), x)
    return reshape(x, (m,))
end

# 2D (m,m),(m,n)->(m,n) and batched (...,m,m),(...,m,n)->(...,m,n)
function _solve(a::NDArray{T,N}, b::NDArray{S,N}) where {T,S,N}
# 2D case: (m,m),(m,n)->(m,n)
#
# Throws `ArgumentError` if `a` is not square or if its size does not match
# the first dimension of `b`. Empty input returns a zero-filled array shaped
# like `b` without launching a task.
function solve(a::NDArray{T,2}, b::NDArray{S,2}) where {T,S}
    size(a)[end - 1] != size(a)[end] &&
        throw(ArgumentError("Last 2 dimensions of the array must be square"))
    size(a)[2] != size(b)[1] &&
        throw(
            ArgumentError(
                "Input operand 1 has a mismatch in its dimension 0, " *
                "with signature (m,m),(m,n)->(m,n) (size $(size(b)[1]) " *
                "is different from $(size(a)[2]))",
            ),
        )
    # Parenthesised on purpose: `&&` binds tighter than `||`, so without the
    # parentheses an empty `a` would skip the early return and reach the
    # tiling code, which divides by a zero tile size.
    (prod(size(a)) == 0 || prod(size(b)) == 0) && return zeros(T, size(b)...)
    x = zeros(T, size(b)...)
    solve_batched(a, b, x)
    return x
end

# Batched case: (...,m,m),(...,m,n)->(...,m,n)
function solve(a::NDArray{T,N}, b::NDArray{S,N}) where {T,S,N}
size(a)[end - 1] != size(a)[end] &&
throw(ArgumentError("Last 2 dimensions of the array must be square"))
size(a)[end] != size(b)[end - 1] &&
Expand All @@ -114,6 +66,21 @@ function _solve(a::NDArray{T,N}, b::NDArray{S,N}) where {T,S,N}
end

# Mismatched batch dimensions
function _solve(a::NDArray{T,N}, b::NDArray{S,M}) where {T,N,S,M}
# Catch-all for operands whose ranks differ and that no other method handles.
solve(a::NDArray{T,N}, b::NDArray{S,M}) where {T,N,S,M} = throw(
    ArgumentError("Batched matrices require signature (...,m,m),(...,m,n)->(...,m,n)")
)

# Singular value decomposition of the (M, N) matrix `a`, with M >= N.
#
# Returns `(u, s, vh)` as produced by `_svd`; `full_matrices` is forwarded to
# `_svd`, which uses it to choose the shapes of `u` and `vh`.
#
# Throws `ArgumentError` if the element type is not in `SUPPORTED_SVD_TYPES`
# or if M < N.
function svd(a::NDArray{T,2}, full_matrices::Bool=true) where {T}
    # Reject unsupported element types up front rather than submitting a task
    # for them.
    T <: SUPPORTED_SVD_TYPES ||
        throw(ArgumentError("array type $T is unsupported in svd"))
    m, n = size(a)
    m < n && throw(ArgumentError("cuNumeric only supports M >= N"))
    return _svd(a, full_matrices)
end

# A 1-D input is rejected: svd needs at least a matrix.
svd(a::NDArray{T,1}, full_matrices::Bool=true) where {T} = throw(
    ArgumentError("1-dimensional array given. Array must be at least two-dimensional")
)

# Catch-all for ranks without a dedicated method: batched ("stacked") SVD is
# not implemented yet.
function svd(a::NDArray{T,N}, full_matrices::Bool=true) where {T,N}
    throw(ArgumentError("cuNumeric does not yet support stacked 2d arrays"))
end

# A 0-D input is not a stack of matrices; report it with the same wording the
# `solve` dimension guards use instead of the "stacked" message.
function svd(a::NDArray{T,0}, full_matrices::Bool=true) where {T}
    throw(ArgumentError("0-dimensional array given. Array must be at least two-dimensional"))
end
Loading
Loading