diff --git a/Project.toml b/Project.toml index ce49bf6d7..69163e027 100644 --- a/Project.toml +++ b/Project.toml @@ -13,6 +13,7 @@ Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" MemPool = "f9f48841-c794-520a-933b-121f7ba6ed94" NextLA = "d37ed344-79c4-486d-9307-6d11355a15a3" OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e" @@ -77,6 +78,7 @@ Graphs = "1" JSON3 = "1" KernelAbstractions = "0.9" MacroTools = "0.5" +MPI = "0.20.22" MemPool = "0.4.12" Metal = "1.1" NextLA = "0.2.2" diff --git a/benchmarks/check_comm_asymmetry.jl b/benchmarks/check_comm_asymmetry.jl new file mode 100644 index 000000000..684240ec5 --- /dev/null +++ b/benchmarks/check_comm_asymmetry.jl @@ -0,0 +1,111 @@ +#!/usr/bin/env julia +# Parse MPI+Dagger logs and report communication decision asymmetry per tag. +# Asymmetry: for the same tag, one rank decides to send (local+bcast, sender+communicated, etc.) +# and another rank decides to infer (inferred, uninvolved) and never recv → deadlock. +# +# Usage: julia check_comm_asymmetry.jl < logfile +# Or: mpiexec -n 10 julia ... run_matmul.jl 2>&1 | tee matmul.log; julia check_comm_asymmetry.jl < matmul.log + +const SEND_DECISIONS = Set([ + "local+bcast", "sender+communicated", "sender+inferred", "receiver+bcast", + "aliasing", # when followed by local+bcast we already capture local+bcast +]) +const RECV_DECISIONS = Set([ + "communicated", "receiver", "sender+communicated", # received data +]) +const INFER_DECISIONS = Set([ + "inferred", "uninvolved", # did not recv (uses inferred type) +]) + +function parse_line(line) + # Match [rank X][tag Y] then any [...] 
and capture the last bracket pair before space or end + rank = nothing + tag = nothing + decision = nothing + category = nothing # aliasing, execute!, remotecall_endpoint + for m in eachmatch(r"\[rank\s+(\d+)\]", line) + rank = parse(Int, m.captures[1]) + end + for m in eachmatch(r"\[tag\s+(\d+)\]", line) + tag = parse(Int, m.captures[1]) + end + for m in eachmatch(r"\[(execute!|aliasing|remotecall_endpoint)\]", line) + category = m.captures[1] + end + # Decision is usually in last [...] that looks like [word] or [word+word] + for m in eachmatch(r"\]\[([^\]]+)\]", line) + candidate = m.captures[1] + # Normalize: "communicated" "inferred" "local+bcast" "sender+inferred" "receiver" etc. + if occursin("inferred", candidate) && !occursin("communicated", candidate) + decision = "inferred" + break + elseif occursin("communicated", candidate) + decision = "communicated" + break + elseif occursin("local+bcast", candidate) + decision = "local+bcast" + break + elseif occursin("sender+", candidate) + decision = startswith(candidate, "sender+inferred") ? "sender+inferred" : "sender+communicated" + break + elseif candidate == "receiver" + decision = "receiver" + break + elseif candidate == "receiver+bcast" + decision = "receiver+bcast" + break + elseif candidate == "inplace_move" + decision = "inplace_move" + break + end + end + return rank, tag, category, decision +end + +function main() + # tag => Dict(rank => decision) + by_tag = Dict{Int, Dict{Int, String}}() + for line in eachline(stdin) + rank, tag, category, decision = parse_line(line) + isnothing(rank) && continue + isnothing(tag) && continue + isnothing(decision) && continue + if !haskey(by_tag, tag) + by_tag[tag] = Dict{Int, String}() + end + by_tag[tag][rank] = decision + end + + # For each tag, check: is there at least one sender and one inferrer (non-receiver)? 
+ send_keys = Set(["local+bcast", "sender+communicated", "sender+inferred", "receiver+bcast"]) + infer_keys = Set(["inferred", "sender+inferred"]) # sender+inferred means sender didn't need to recv + recv_keys = Set(["communicated", "receiver", "sender+communicated"]) + + asymmetries = [] + for (tag, ranks) in sort(collect(by_tag), by = first) + senders = [r for (r, d) in ranks if d in send_keys] + inferrers = [r for (r, d) in ranks if d in infer_keys || d == "uninvolved"] + receivers = [r for (r, d) in ranks if d in recv_keys] + # Asymmetry: someone sends (bcast) so will send to ALL other ranks; someone chose infer and won't recv. + if !isempty(senders) && !isempty(inferrers) + push!(asymmetries, (tag, senders, inferrers, receivers, ranks)) + end + end + + if isempty(asymmetries) + println("No communication decision asymmetry found (no tag has both sender and inferrer).") + return + end + + println("=== Communication decision asymmetry (can cause deadlock) ===\n") + for (tag, senders, inferrers, receivers, ranks) in asymmetries + println("Tag $tag:") + println(" Senders (will bcast to all others): $senders") + println(" Inferrers (did not recv): $inferrers") + println(" Receivers: $receivers") + println(" All decisions: $ranks") + println() + end +end + +main() diff --git a/benchmarks/check_comm_asymmetry.py b/benchmarks/check_comm_asymmetry.py new file mode 100644 index 000000000..31a117442 --- /dev/null +++ b/benchmarks/check_comm_asymmetry.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python3 +""" +Parse MPI+Dagger logs and report communication decision asymmetry per tag. +Asymmetry: for the same tag, one rank decides to send (local+bcast, etc.) +and another decides to infer (inferred) and never recv → deadlock. 
+ +Usage: + # Capture full log (all ranks' Core.println from mpi.jl go to stdout): + mpiexec -n 10 julia --project=/path/to/Dagger.jl benchmarks/run_matmul.jl 2>&1 | tee matmul.log + # Then look for asymmetry (same tag: one rank sends, another infers → deadlock): + python3 check_comm_asymmetry.py < matmul.log +""" + +import re +import sys +from collections import defaultdict + +SEND_DECISIONS = {"local+bcast", "sender+communicated", "sender+inferred", "receiver+bcast"} +RECV_DECISIONS = {"communicated", "receiver", "sender+communicated"} +INFER_DECISIONS = {"inferred", "uninvolved", "sender+inferred"} + + +def parse_line(line: str): + rank = tag = category = decision = None + m = re.search(r"\[rank\s+(\d+)\]", line) + if m: + rank = int(m.group(1)) + m = re.search(r"\[tag\s+(\d+)\]", line) + if m: + tag = int(m.group(1)) + m = re.search(r"\[(execute!|aliasing|remotecall_endpoint)\]", line) + if m: + category = m.group(1) + # Capture decision from [...] blocks + for m in re.finditer(r"\]\[([^\]]+)\]", line): + candidate = m.group(1) + if "inferred" in candidate and "communicated" not in candidate: + decision = "inferred" + break + if "communicated" in candidate: + decision = "communicated" + break + if "local+bcast" in candidate: + decision = "local+bcast" + break + if candidate.startswith("sender+"): + decision = "sender+inferred" if "inferred" in candidate else "sender+communicated" + break + if candidate == "receiver": + decision = "receiver" + break + if candidate == "receiver+bcast": + decision = "receiver+bcast" + break + if candidate == "inplace_move": + decision = "inplace_move" + break + return rank, tag, category, decision + + +def main(): + by_tag = defaultdict(dict) # tag -> {rank: decision} + for line in sys.stdin: + rank, tag, category, decision = parse_line(line) + if rank is None or tag is None or decision is None: + continue + by_tag[tag][rank] = decision + + send_keys = {"local+bcast", "sender+communicated", "sender+inferred", "receiver+bcast"} + 
infer_keys = {"inferred", "sender+inferred", "uninvolved"} + recv_keys = {"communicated", "receiver", "sender+communicated"} + + asymmetries = [] + for tag in sorted(by_tag.keys()): + ranks = by_tag[tag] + senders = [r for r, d in ranks.items() if d in send_keys] + inferrers = [r for r, d in ranks.items() if d in infer_keys] + receivers = [r for r, d in ranks.items() if d in recv_keys] + if senders and inferrers: + asymmetries.append((tag, senders, inferrers, receivers, ranks)) + + if not asymmetries: + print("No communication decision asymmetry found (no tag has both sender and inferrer).") + return + + print("=== Communication decision asymmetry (can cause deadlock) ===\n") + for tag, senders, inferrers, receivers, ranks in asymmetries: + print(f"Tag {tag}:") + print(f" Senders (will bcast to all others): {senders}") + print(f" Inferrers (did not recv): {inferrers}") + print(f" Receivers: {receivers}") + print(f" All decisions: {dict(ranks)}") + print() + + +if __name__ == "__main__": + main() diff --git a/benchmarks/run_distribute_fetch.jl b/benchmarks/run_distribute_fetch.jl new file mode 100644 index 000000000..822e1ad2c --- /dev/null +++ b/benchmarks/run_distribute_fetch.jl @@ -0,0 +1,42 @@ +#!/usr/bin/env julia +# Create a matrix with a fixed reproducible pattern, distribute it with an +# MPI procgrid, then on each rank fetch and println the chunk(s) it owns. +# Usage (from repo root, use full path to Dagger.jl): +# mpiexec -n 4 julia --project=/path/to/Dagger.jl benchmarks/run_distribute_fetch.jl + +using MPI +using Dagger + +if !isdefined(Dagger, :accelerate!) + error("Dagger.accelerate! not found. 
Run with the local Dagger project: julia --project=/path/to/Dagger.jl ...") +end +Dagger.accelerate!(:mpi) + +const comm = MPI.COMM_WORLD +const rank = MPI.Comm_rank(comm) +const nranks = MPI.Comm_size(comm) + +# Fixed reproducible pattern: 6×6 matrix, M[i,j] = 10*i + j (same on all ranks) +const N = 6 +const BLOCK = 2 +A = [10 * i + j for i in 1:N, j in 1:N] + +# Procgrid: use Dagger's compatible processors so the procgrid passes validation +availprocs = collect(Dagger.compatible_processors()) +nblocks = (cld(N, BLOCK), cld(N, BLOCK)) +procgrid = reshape( + [availprocs[mod(i - 1, length(availprocs)) + 1] for i in 1:prod(nblocks)], + nblocks, +) + +# Distribute so chunk (i,j) is computed on procgrid[i,j] +D = distribute(A, Blocks(BLOCK, BLOCK), procgrid) +D_fetched = fetch(D) + +# On each rank: fetch and print only the chunk(s) this rank owns +for (idx, ch) in enumerate(D_fetched.chunks) + if ch isa Dagger.Chunk && ch.handle isa Dagger.MPIRef && ch.handle.rank == rank + data = fetch(ch) + println("rank $rank chunk $idx: ", data) + end +end diff --git a/benchmarks/run_matmul.jl b/benchmarks/run_matmul.jl new file mode 100644 index 000000000..0eb4ec0d7 --- /dev/null +++ b/benchmarks/run_matmul.jl @@ -0,0 +1,105 @@ +#!/usr/bin/env julia +# N×N matmul benchmark (Float32); block size scales with number of ranks. +# Usage (use the full path to Dagger.jl, not "..."): +# mpiexec -n 10 julia --project=/home/felipetome/dagger-dev/mpi/Dagger.jl benchmarks/run_matmul.jl +# Set CHECK_CORRECTNESS=true to collect and compare against GPU baseline: +# CHECK_CORRECTNESS=true mpiexec -n 10 julia --project=/home/felipetome/dagger-dev/mpi/Dagger.jl benchmarks/run_matmul.jl + +using MPI +using Dagger +using LinearAlgebra + +if !isdefined(Dagger, :accelerate!) + error("Dagger.accelerate! not found. 
Run with the local Dagger project: julia --project=/path/to/Dagger.jl ...") +end +Dagger.accelerate!(:mpi) + +const N = 2_000 +const comm = MPI.COMM_WORLD +const rank = MPI.Comm_rank(comm) +const nranks = MPI.Comm_size(comm) +# Block size proportional to ranks: ~nranks blocks in 2D => side blocks ≈ √nranks +const BLOCK = max(1, ceil(Int, N / ceil(Int, sqrt(nranks)))) + +const CHECK_CORRECTNESS = parse(Bool, get(ENV, "CHECK_CORRECTNESS", "false")) + +if rank == 0 + println("Benchmark: ", nranks, " ranks, N=", N, ", block size ", BLOCK, "×", BLOCK, " (matmul)") +end + +# Allocate and fill matrices in blocks (Float32) +A = rand(Blocks(BLOCK, BLOCK), Float32, N, N) +B = rand(Blocks(BLOCK, BLOCK), Float32, N, N) + +# Matrix multiply C = A * B +t_matmul = @elapsed begin + C = A * B +end + +if rank == 0 + println("Matmul time: ", round(t_matmul; digits=4), " s") +end + +# Optional: collect via datadeps (root=0). All ranks participate in the datadeps region. +if CHECK_CORRECTNESS + t_collect = @elapsed begin + A_full = Dagger.collect_datadeps(A; root=0) + B_full = Dagger.collect_datadeps(B; root=0) + C_dagger = Dagger.collect_datadeps(C; root=0) + end + if rank == 0 + println("Collecting result and computing baseline for correctness check (GPU)...") + using CUDA + CUDA.functional() || error("CUDA not functional; cannot compute GPU baseline. 
Check CUDA driver and device.") + t_upload = @elapsed begin + A_g = CUDA.cu(A_full) + B_g = CUDA.cu(B_full) + end + println("Collect + upload time: ", round(t_collect + t_upload; digits=4), " s") + + t_baseline = @elapsed begin + C_ref_g = A_g * B_g + end + println("Baseline (GPU/CUDA) time: ", round(t_baseline; digits=4), " s") + + # Require all elements within 100× machine epsilon relative error (componentwise) + C_dagger_cpu = C_dagger + C_ref_cpu = Array(C_ref_g) + eps_f = eps(Float32) + rtol = 50.0f0 * eps_f + diff = C_dagger_cpu .- C_ref_cpu + # rel_ij = |diff|/|C_ref|, denominator at least eps to avoid div by zero + denom = max.(abs.(C_ref_cpu), eps_f) + rel_err = abs.(diff) ./ denom + max_rel_err = Float32(maximum(rel_err)) + ok = max_rel_err <= rtol + if ok + println("Correctness: OK (max rel_err = ", max_rel_err, " <= 100×eps = ", rtol, ")") + else + println("Correctness: FAIL (max rel_err = ", max_rel_err, " > 100×eps = ", rtol, ")") + end + + # Per-block: which blocks have any element with rel_err > 100×eps + n_bi = ceil(Int, N / BLOCK) + n_bj = ceil(Int, N / BLOCK) + bad_blocks = Tuple{Int,Int,Float32}[] + for bi in 1:n_bi, bj in 1:n_bj + ri = (bi - 1) * BLOCK + 1 : min(bi * BLOCK, N) + rj = (bj - 1) * BLOCK + 1 : min(bj * BLOCK, N) + block_rel = Float32(maximum(@view(rel_err[ri, rj]))) + if block_rel > rtol + push!(bad_blocks, (bi, bj, block_rel)) + end + end + if isempty(bad_blocks) + println("Per-block: all ", n_bi * n_bj, " blocks within 100×eps rel_err.") + else + println("Per-block: ", length(bad_blocks), " block(s) exceed 100×eps rel_err (block size ", BLOCK, "×", BLOCK, "):") + sort!(bad_blocks; by = x -> -x[3]) + for (bi, bj, block_rel) in bad_blocks + println(" block [", bi, ",", bj, "] rows ", (bi - 1) * BLOCK + 1, ":", min(bi * BLOCK, N), + ", cols ", (bj - 1) * BLOCK + 1, ":", min(bj * BLOCK, N), " max rel_err = ", block_rel) + end + end + end +end diff --git a/benchmarks/run_qr.jl b/benchmarks/run_qr.jl new file mode 100644 index 
000000000..c5915db2a --- /dev/null +++ b/benchmarks/run_qr.jl @@ -0,0 +1,46 @@ +#!/usr/bin/env julia +# 10k×10k QR + matmul benchmark; block size scales with number of ranks. +# Usage: mpiexec -n 100 julia --project=/path/to/Dagger.jl benchmarks/bench_100rank_qr_matmul.jl +# Or: bash benchmarks/run_100rank_qr_matmul.sh . + +using MPI +using Dagger +using LinearAlgebra + +Dagger.accelerate!(:mpi) + +const N = 10_000 +const comm = MPI.COMM_WORLD +const rank = MPI.Comm_rank(comm) +const nranks = MPI.Comm_size(comm) +# Block size proportional to ranks: ~nranks blocks in 2D => side blocks ≈ √nranks +const BLOCK = max(1, ceil(Int, N / ceil(Int, sqrt(nranks)))) + +if rank == 0 + println("Benchmark: ", nranks, " ranks, N=", N, ", block size ", BLOCK, "×", BLOCK, " (QR + matmul)") +end + +# Allocate and fill 10k×10k matrix in 1k×1k blocks +A = rand(Blocks(BLOCK, BLOCK), Float64, N, N) +MPI.Barrier(comm) + +# QR factorization (computing Q runs the full factorization) +t_qr = @elapsed begin + qr!(A) +end +MPI.Barrier(comm) + +if rank == 0 + println("QR time: ", round(t_qr; digits=4), " s") +end + +# Matrix multiply A * A +t_matmul = @elapsed begin + C = A * A +end +MPI.Barrier(comm) + +if rank == 0 + println("Matmul time: ", round(t_matmul; digits=4), " s") +end + diff --git a/src/Dagger.jl b/src/Dagger.jl index 2e757ebc5..1a5720784 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -53,6 +53,13 @@ import Adapt include("lib/util.jl") include("utils/dagdebug.jl") +# Type definitions (for MPI/acceleration) +include("types/processor.jl") +include("types/scope.jl") +include("types/memory-space.jl") +include("types/chunk.jl") +include("types/acceleration.jl") + # Distributed data include("utils/locked-object.jl") include("utils/tasks.jl") @@ -77,12 +84,14 @@ include("queue.jl") include("thunk.jl") include("utils/fetch.jl") include("utils/chunks.jl") +include("weakchunk.jl") include("utils/logging.jl") include("submission.jl") abstract type MemorySpace end 
include("utils/memory-span.jl") include("utils/interval_tree.jl") include("memory-spaces.jl") +include("acceleration.jl") # Task scheduling include("compute.jl") @@ -90,6 +99,7 @@ include("utils/clock.jl") include("utils/system_uuid.jl") include("utils/caching.jl") include("sch/Sch.jl"); using .Sch +include("tochunk.jl") # Data dependency task queue include("datadeps/aliasing.jl") @@ -157,6 +167,10 @@ function set_distributed_package!(value) @info "Dagger.jl preference has been set, restart your Julia session for this change to take effect!" end +# MPI (mpi.jl loads MPI; mpi_mempool uses it) +include("mpi.jl") +include("mpi_mempool.jl") + # Precompilation import PrecompileTools: @compile_workload include("precompile.jl") diff --git a/src/acceleration.jl b/src/acceleration.jl new file mode 100644 index 000000000..f95236468 --- /dev/null +++ b/src/acceleration.jl @@ -0,0 +1,46 @@ +const ACCELERATION = TaskLocalValue{Acceleration}(() -> DistributedAcceleration()) + +current_acceleration() = ACCELERATION[] + +default_processor(::DistributedAcceleration) = OSProc(myid()) +default_processor(accel::DistributedAcceleration, x) = default_processor(accel) +default_processor() = default_processor(current_acceleration()) + +accelerate!(accel::Symbol) = accelerate!(Val{accel}()) +accelerate!(::Val{:distributed}) = accelerate!(DistributedAcceleration()) + +function _with_default_acceleration(f) + old_accel = ACCELERATION[] + ACCELERATION[] = DistributedAcceleration() + result = try + f() + finally + ACCELERATION[] = old_accel + end + return result +end + +initialize_acceleration!(a::DistributedAcceleration) = nothing +function accelerate!(accel::Acceleration) + initialize_acceleration!(accel) + ACCELERATION[] = accel +end +accelerate!(::Nothing) = nothing + +accel_matches_proc(accel::DistributedAcceleration, proc::OSProc) = true +accel_matches_proc(accel::DistributedAcceleration, proc) = true + +function compatible_processors(accel::Union{Acceleration,Nothing}, 
scope::AbstractScope, procs::Vector{<:Processor}) + comp = compatible_processors(scope, procs) + accel === nothing && return comp + return Set(p for p in comp if accel_matches_proc(accel, p)) +end + +uniform_execution(::DistributedAcceleration) = false +uniform_execution() = uniform_execution(current_acceleration()) + +default_processor(space::CPURAMMemorySpace) = OSProc(space.owner) +default_memory_space(accel::DistributedAcceleration) = CPURAMMemorySpace(myid()) +default_memory_space(accel::DistributedAcceleration, x) = default_memory_space(accel) +default_memory_space(x) = default_memory_space(current_acceleration(), x) +default_memory_space() = default_memory_space(current_acceleration()) diff --git a/src/affinity.jl b/src/affinity.jl new file mode 100644 index 000000000..aab663a51 --- /dev/null +++ b/src/affinity.jl @@ -0,0 +1,32 @@ +export domain, UnitDomain, project, alignfirst, ArrayDomain + +import Base: isempty, getindex, intersect, ==, size, length, ndims + +""" + domain(x::T) + +Returns metadata about `x`. This metadata will be in the `domain` +field of a Chunk object when an object of type `T` is created as +the result of evaluating a Thunk. +""" +function domain end + +""" + UnitDomain + +Default domain -- has no information about the value +""" +struct UnitDomain end + +""" +If no `domain` method is defined on an object, then +we use the `UnitDomain` on it. A `UnitDomain` is indivisible. 
+""" +domain(x::Any) = UnitDomain() + +### ChunkIO +affinity(r::DRef) = OSProc(r.owner)=>r.size +# this previously returned a vector with all machines that had the file cached +# but now only returns the owner and size, for consistency with affinity(::DRef), +# see #295 +affinity(r::FileRef) = OSProc(1)=>r.size diff --git a/src/array/alloc.jl b/src/array/alloc.jl index aa1050210..33de3506d 100644 --- a/src/array/alloc.jl +++ b/src/array/alloc.jl @@ -93,14 +93,31 @@ function stage(ctx, A::AllocateArray) scope = ExactScope(A.procgrid[CartesianIndex(mod1.(Tuple(I), size(A.procgrid))...)]) end + N = ndims(A.domainchunks) + ret_type = Array{A.eltype, N} if A.want_index - task = Dagger.@spawn compute_scope=scope allocate_array(A.f, A.eltype, i, size(x)) + task = Dagger.@spawn compute_scope=scope return_type=ret_type allocate_array(A.f, A.eltype, i, size(x)) else - task = Dagger.@spawn compute_scope=scope allocate_array(A.f, A.eltype, size(x)) + task = Dagger.@spawn compute_scope=scope return_type=ret_type allocate_array(A.f, A.eltype, size(x)) end tasks[i] = task end end + # MPI type propagation: ensure all ranks know the concrete chunk types + accel = Dagger.current_acceleration() + if accel isa Dagger.MPIAcceleration + N = ndims(A.domainchunks) + expected_type = Array{A.eltype, N} + Dagger.mpi_propagate_chunk_types!(tasks, accel, expected_type) + # Log chunk types per rank after array creation + rank = MPI.Comm_rank(accel.comm) + #=chunk_types = Type[chunktype(t) for t in tasks] + if allequal(chunk_types) + @info "[rank $rank] Array creation (alloc): all $(length(chunk_types)) chunk types are uniform: $(first(chunk_types))" + else + @warn "[rank $rank] Array creation (alloc): chunk types are NOT uniform: $chunk_types" + end=# + end return DArray(A.eltype, A.domain, A.domainchunks, tasks, A.partitioning) end diff --git a/src/array/darray.jl b/src/array/darray.jl index 32336f95d..fc99dc75d 100644 --- a/src/array/darray.jl +++ b/src/array/darray.jl @@ -1,7 +1,7 @@ -import 
Base: ==, fetch +import Base: ==, fetch, length, isempty, size export DArray, DVector, DMatrix, DVecOrMat, Blocks, AutoBlocks -export distribute +export distribute, collect_datadeps ###### Array Domains ###### @@ -36,7 +36,7 @@ Base.getindex(arr::AbstractArray{T,0} where T, d::ArrayDomain{0}) = arr Base.getindex(arr::GPUArraysCore.AbstractGPUArray, d::ArrayDomain) = arr[indexes(d)...] Base.getindex(arr::GPUArraysCore.AbstractGPUArray{T,0} where T, d::ArrayDomain{0}) = arr -function intersect(a::ArrayDomain, b::ArrayDomain) +function Base.intersect(a::ArrayDomain, b::ArrayDomain) if a === b return a end @@ -83,7 +83,8 @@ isempty(a::ArrayDomain) = length(a) == 0 The domain of an array is an ArrayDomain. """ domain(x::AbstractArray) = ArrayDomain([1:l for l in size(x)]) - +# Scalar / non-array values (e.g. for Chunk of immediate data) +domain(x::Any) = ArrayDomain(()) abstract type ArrayOp{T, N} <: AbstractArray{T, N} end Base.IndexStyle(::Type{<:ArrayOp}) = IndexCartesian() @@ -174,6 +175,7 @@ domain(d::DArray) = d.domain chunks(d::DArray) = d.chunks domainchunks(d::DArray) = d.subdomains size(x::DArray) = size(domain(x)) +Base.ndims(d::DArray{T,N}) where {T,N} = N stage(ctx, c::DArray) = c function Base.collect(d::DArray{T,N}; tree=false, copyto=false) where {T,N} @@ -200,6 +202,31 @@ function Base.collect(d::DArray{T,N}; tree=false, copyto=false) where {T,N} collect(treereduce_nd(dimcatfuncs, asyncmap(fetch, a.chunks))) end end + +""" + collect_datadeps(d::DArray; root=nothing) + +Collect a DArray to a single array by fetching every chunk on the current rank +and assembling into a full array. No datadeps scheduling or root-only assembly: +each rank that calls this gets the full matrix (useful when correctness matters +more than communication cost). +""" +function collect_datadeps(d::DArray{T,N}; root=nothing) where {T,N} + if isempty(d.chunks) + return Array{eltype(d)}(undef, size(d)...) 
+ end + if N == 0 + return fetch(d.chunks[1]) + end + + chks = d.chunks + doms = domainchunks(d) + out = Array{T,N}(undef, size(d)) + for I in CartesianIndices(chks) + copyto!(view(out, indexes(doms[I])...), fetch(chks[I])) + end + return out +end Array{T,N}(A::DArray{S,N}) where {T,N,S} = convert(Array{T,N}, collect(A)) Base.wait(A::DArray) = foreach(wait, A.chunks) @@ -483,6 +510,21 @@ function stage(ctx::Context, d::Distribute) Dagger.@spawn compute_scope=scope identity(d.data[c]) end end + # MPI type propagation: ensure all ranks know the concrete chunk types + accel = Dagger.current_acceleration() + if accel isa Dagger.MPIAcceleration + N = Base.ndims(d.data) + expected_type = Array{eltype(d.data), N} + Dagger.mpi_propagate_chunk_types!(cs, accel, expected_type) + # Log chunk types per rank after array creation + rank = MPI.Comm_rank(accel.comm) + #=chunk_types = Type[chunktype(t) for t in cs] + if allequal(chunk_types) + @info "[rank $rank] Array creation (distribute): all $(length(chunk_types)) chunk types are uniform: $(first(chunk_types))" + else + @warn "[rank $rank] Array creation (distribute): chunk types are NOT uniform: $chunk_types" + end=# + end return DArray(eltype(d.data), domain(d.data), d.domainchunks, @@ -620,7 +662,7 @@ end mapchunk(f, chunk) = tochunk(f(poolget(chunk.handle))) function mapchunks(f, d::DArray{T,N,F}) where {T,N,F} chunks = map(d.chunks) do chunk - owner = get_parent(chunk.processor).pid + owner = root_worker_id(chunk.processor) remotecall_fetch(mapchunk, owner, f, chunk) end DArray{T,N,F}(d.domain, d.subdomains, chunks, d.concat) diff --git a/src/array/mul.jl b/src/array/mul.jl index 02b207641..5890473da 100644 --- a/src/array/mul.jl +++ b/src/array/mul.jl @@ -41,7 +41,7 @@ function LinearAlgebra.generic_matmatmul!( return gemm_dagger!(C, transA, transB, A, B, alpha, beta) end end -function _repartition_matmatmul(C, A, B, transA::Char, transB::Char) +function _repartition_matmatmul(C, A, B, transA::Char, 
transB::Char)::Tuple{Blocks{2}, Blocks{2}, Blocks{2}} partA = A.partitioning.blocksize partB = B.partitioning.blocksize istransA = transA == 'T' || transA == 'C' @@ -93,6 +93,24 @@ function _repartition_matmatmul(C, A, B, transA::Char, transB::Char) return Blocks(partC...), Blocks(partA...), Blocks(partB...) end +# Typed BLAS wrappers so that every @spawn kernel has an inferable return type +@inline function _gemm!(transA::Char, transB::Char, alpha::T, A, B, mzone, C)::Matrix{T} where {T} + BLAS.gemm!(transA, transB, alpha, A, B, mzone, C) + return C +end +@inline function _syrk!(uplo::AbstractChar, trans::AbstractChar, alpha::T, A, mzone, C)::Matrix{T} where {T} + BLAS.syrk!(uplo, trans, alpha, A, mzone, C) + return C +end +@inline function _herk!(uplo::AbstractChar, trans::AbstractChar, alpha::Real, A, mzone, C)::Matrix{<:Complex} + BLAS.herk!(uplo, trans, alpha, A, mzone, C) + return C +end +@inline function _gemv!(transA::Char, alpha::T, A, x, mzone, y)::Vector{T} where {T} + BLAS.gemv!(transA, alpha, A, x, mzone, y) + return y +end + """ Performs one of the matrix-matrix operations @@ -136,7 +154,7 @@ function gemm_dagger!( # A: NoTrans / B: NoTrans for k in range(1, Ant) mzone = k == 1 ? beta : T(1.0) - Dagger.@spawn BLAS.gemm!( + Dagger.@spawn _gemm!( transA, transB, alpha, @@ -150,7 +168,7 @@ function gemm_dagger!( # A: NoTrans / B: [Conj]Trans for k in range(1, Ant) mzone = k == 1 ? beta : T(1.0) - Dagger.@spawn BLAS.gemm!( + Dagger.@spawn _gemm!( transA, transB, alpha, @@ -166,7 +184,7 @@ function gemm_dagger!( # A: [Conj]Trans / B: NoTrans for k in range(1, Amt) mzone = k == 1 ? beta : T(1.0) - Dagger.@spawn BLAS.gemm!( + Dagger.@spawn _gemm!( transA, transB, alpha, @@ -180,7 +198,7 @@ function gemm_dagger!( # A: [Conj]Trans / B: [Conj]Trans for k in range(1, Amt) mzone = k == 1 ? beta : T(1.0) - Dagger.@spawn BLAS.gemm!( + Dagger.@spawn _gemm!( transA, transB, alpha, @@ -243,7 +261,7 @@ function syrk_dagger!( for k in range(1, Ant) mzone = k == 1 ? 
real(beta) : one(real(T)) if iscomplex - Dagger.@spawn BLAS.herk!( + Dagger.@spawn _herk!( uplo, trans, real(alpha), @@ -252,7 +270,7 @@ function syrk_dagger!( InOut(Cc[n, n]), ) else - Dagger.@spawn BLAS.syrk!( + Dagger.@spawn _syrk!( uplo, trans, alpha, @@ -267,7 +285,7 @@ function syrk_dagger!( for m in range(n + 1, Cmt) for k in range(1, Ant) mzone = k == 1 ? beta : one(T) - Dagger.@spawn BLAS.gemm!( + Dagger.@spawn _gemm!( trans, transs, alpha, @@ -283,7 +301,7 @@ function syrk_dagger!( for m in range(n + 1, Cmt) for k in range(1, Ant) mzone = k == 1 ? beta : one(T) - Dagger.@spawn BLAS.gemm!( + Dagger.@spawn _gemm!( trans, transs, alpha, @@ -300,7 +318,7 @@ function syrk_dagger!( for k in range(1, Amt) mzone = k == 1 ? real(beta) : one(real(T)) if iscomplex - Dagger.@spawn BLAS.herk!( + Dagger.@spawn _herk!( uplo, transs, real(alpha), @@ -309,7 +327,7 @@ function syrk_dagger!( InOut(Cc[n, n]), ) else - Dagger.@spawn BLAS.syrk!( + Dagger.@spawn _syrk!( uplo, trans, alpha, @@ -324,7 +342,7 @@ function syrk_dagger!( for m in range(n + 1, Cmt) for k in range(1, Amt) mzone = k == 1 ? beta : one(T) - Dagger.@spawn BLAS.gemm!( + Dagger.@spawn _gemm!( transs, 'N', alpha, @@ -340,7 +358,7 @@ function syrk_dagger!( for m in range(n + 1, Cmt) for k in range(1, Amt) mzone = k == 1 ? 
beta : one(T) - Dagger.@spawn BLAS.gemm!( + Dagger.@spawn _gemm!( transs, 'N', alpha, @@ -393,16 +411,17 @@ end return A end -@inline function copytile!(A, B) +@inline function copytile!(A::AbstractMatrix{T}, B::AbstractMatrix{T})::Nothing where {T} m, n = size(A) C = B' for i = 1:m, j = 1:n A[i, j] = C[i, j] end + return nothing end -@inline function copydiagtile!(A, uplo) +@inline function copydiagtile!(A::AbstractMatrix{T}, uplo::AbstractChar)::Nothing where {T} m, n = size(A) Acpy = copy(A) @@ -417,6 +436,7 @@ end for i = 1:m, j = 1:n A[i, j] = C[i, j] end + return nothing end function LinearAlgebra.generic_matvecmul!( C::DVector{T}, @@ -440,7 +460,7 @@ function LinearAlgebra.generic_matvecmul!( return gemv_dagger!(C, transA, A, B, _alpha, _beta) end end -function _repartition_matvecmul(C, A, B, transA::Char) +function _repartition_matvecmul(C, A, B, transA::Char)::Tuple{Blocks{1}, Blocks{2}, Blocks{1}} partA = A.partitioning.blocksize partB = B.partitioning.blocksize istransA = transA == 'T' || transA == 'C' @@ -495,7 +515,7 @@ function gemv_dagger!( # A: NoTrans for k in range(1, Ant) mzone = k == 1 ? beta : T(1.0) - Dagger.@spawn BLAS.gemv!( + Dagger.@spawn _gemv!( transA, alpha, In(Ac[m, k]), @@ -508,7 +528,7 @@ function gemv_dagger!( # A: [Conj]Trans for k in range(1, Amt) mzone = k == 1 ? 
beta : T(1.0) - Dagger.@spawn BLAS.gemv!( + Dagger.@spawn _gemv!( transA, alpha, In(Ac[k, m]), diff --git a/src/array/trsm.jl b/src/array/trsm.jl index 65e87c5d5..c0c025468 100644 --- a/src/array/trsm.jl +++ b/src/array/trsm.jl @@ -189,4 +189,4 @@ function trsm!(side::Char, uplo::Char, trans::Char, diag::Char, alpha::T, A::DMa end end -end \ No newline at end of file +end diff --git a/src/chunks.jl b/src/chunks.jl index 03bdfb65d..d5e7b6082 100644 --- a/src/chunks.jl +++ b/src/chunks.jl @@ -1,56 +1,4 @@ -export domain, UnitDomain, project, alignfirst, ArrayDomain - -import Base: isempty, getindex, intersect, ==, size, length, ndims - -""" - domain(x::T) - -Returns metadata about `x`. This metadata will be in the `domain` -field of a Chunk object when an object of type `T` is created as -the result of evaluating a Thunk. -""" -function domain end - -""" - UnitDomain - -Default domain -- has no information about the value -""" -struct UnitDomain end - -""" -If no `domain` method is defined on an object, then -we use the `UnitDomain` on it. A `UnitDomain` is indivisible. -""" -domain(x::Any) = UnitDomain() - -###### Chunk ###### - -""" - Chunk - -A reference to a piece of data located on a remote worker. `Chunk`s are -typically created with `Dagger.tochunk(data)`, and the data can then be -accessed from any worker with `collect(::Chunk)`. `Chunk`s are -serialization-safe, and use distributed refcounting (provided by -`MemPool.DRef`) to ensure that the data referenced by a `Chunk` won't be GC'd, -as long as a reference exists on some worker. - -Each `Chunk` is associated with a given `Dagger.Processor`, which is (in a -sense) the processor that "owns" or contains the data. Calling -`collect(::Chunk)` will perform data movement and conversions defined by that -processor to safely serialize the data to the calling worker. - -## Constructors -See [`tochunk`](@ref). 
-""" -mutable struct Chunk{T, H, P<:Processor, S<:AbstractScope} - chunktype::Type{T} - domain - handle::H - processor::P - scope::S -end +###### Chunk Methods ###### domain(c::Chunk) = c.domain chunktype(c::Chunk) = c.chunktype @@ -72,20 +20,27 @@ function collect(ctx::Context, chunk::Chunk; options=nothing) elseif chunk.handle isa FileRef return poolget(chunk.handle) else - return move(chunk.processor, OSProc(), chunk.handle) + return move(chunk.processor, default_processor(), chunk.handle) end end collect(ctx::Context, ref::DRef; options=nothing) = move(OSProc(ref.owner), OSProc(), ref) collect(ctx::Context, ref::FileRef; options=nothing) = poolget(ref) # FIXME: Do move call -function Base.fetch(chunk::Chunk; raw=false) - if raw - poolget(chunk.handle) - else - collect(chunk) +@warn "Fix semantics of collect" maxlog=1 +function Base.fetch(chunk::Chunk{T}; unwrap::Bool=false, uniform::Bool=uniform_execution(), kwargs...) where T + value = fetch_handle(chunk.handle; uniform)::T + if unwrap && unwrappable(value) + return fetch(value; unwrap, uniform, kwargs...) 
end + return value end +fetch_handle(ref::DRef; uniform::Bool) = poolget(ref) +fetch_handle(ref::FileRef; uniform::Bool) = poolget(ref) +unwrappable(x::Chunk) = true +unwrappable(x::DRef) = true +unwrappable(x::FileRef) = true +unwrappable(x) = false # Unwrap Chunk, DRef, and FileRef by default move(from_proc::Processor, to_proc::Processor, x::Chunk) = @@ -100,32 +55,3 @@ move(to_proc::Processor, d::DRef) = move(OSProc(d.owner), to_proc, d) move(to_proc::Processor, x) = move(OSProc(), to_proc, x) - -### ChunkIO -affinity(r::DRef) = OSProc(r.owner)=>r.size -# this previously returned a vector with all machines that had the file cached -# but now only returns the owner and size, for consistency with affinity(::DRef), -# see #295 -affinity(r::FileRef) = OSProc(1)=>r.size - -struct WeakChunk - wid::Int - id::Int - x::WeakRef - function WeakChunk(c::Chunk) - return new(c.handle.owner, c.handle.id, WeakRef(c)) - end -end -unwrap_weak(c::WeakChunk) = c.x.value -function unwrap_weak_checked(c::WeakChunk) - cw = unwrap_weak(c) - @assert cw !== nothing "WeakChunk expired: ($(c.wid), $(c.id))" - return cw -end -wrap_weak(c::Chunk) = WeakChunk(c) -isweak(c::WeakChunk) = true -isweak(c::Chunk) = false -is_task_or_chunk(c::WeakChunk) = true -Serialization.serialize(io::AbstractSerializer, wc::WeakChunk) = - error("Cannot serialize a WeakChunk") -chunktype(c::WeakChunk) = chunktype(unwrap_weak_checked(c)) diff --git a/src/datadeps/aliasing.jl b/src/datadeps/aliasing.jl index c3e0ed20b..848443e8e 100644 --- a/src/datadeps/aliasing.jl +++ b/src/datadeps/aliasing.jl @@ -226,6 +226,7 @@ struct ArgumentWrapper function ArgumentWrapper(arg, dep_mod) h = hash(dep_mod) h = _identity_hash(arg, h) + check_uniform(h, arg) return new(arg, dep_mod, h) end end @@ -242,6 +243,7 @@ struct HistoryEntry end struct AliasedObjectCacheStore + accel::Acceleration keys::Vector{AbstractAliasing} derived::Dict{AbstractAliasing,AbstractAliasing} stored::Dict{MemorySpace,Set{AbstractAliasing}} @@ -249,7 
+251,8 @@ struct AliasedObjectCacheStore originals::Set{AbstractAliasing} end AliasedObjectCacheStore() = - AliasedObjectCacheStore(Vector{AbstractAliasing}(), + AliasedObjectCacheStore(current_acceleration(), + Vector{AbstractAliasing}(), Dict{AbstractAliasing,AbstractAliasing}(), Dict{MemorySpace,Set{AbstractAliasing}}(), Dict{MemorySpace,Dict{AbstractAliasing,Chunk}}(), @@ -279,7 +282,7 @@ end function set_stored!(cache::AliasedObjectCacheStore, dest_space::MemorySpace, value::Chunk, ainfo::AbstractAliasing) @assert !is_stored(cache, dest_space, ainfo) "Cache already has derived ainfo $ainfo" key = cache.derived[ainfo] - value_ainfo = aliasing(value, identity) + value_ainfo = aliasing(cache.accel, value, identity) cache.derived[value_ainfo] = key push!(get!(Set{AbstractAliasing}, cache.stored, dest_space), key) values_dict = get!(Dict{AbstractAliasing,Chunk}, cache.values, dest_space) @@ -296,6 +299,7 @@ function set_key_stored!(cache::AliasedObjectCacheStore, space::MemorySpace, ain end struct AliasedObjectCache + accel::Acceleration space::MemorySpace chunk::Chunk end @@ -340,7 +344,7 @@ function set_key_stored!(cache::AliasedObjectCache, space::MemorySpace, ainfo::A cache_raw = unwrap(cache.chunk)::AliasedObjectCacheStore set_key_stored!(cache_raw, space, ainfo, value) end -function aliased_object!(f, cache::AliasedObjectCache, x; ainfo=aliasing(x, identity)) +function aliased_object!(f, cache::AliasedObjectCache, x; ainfo=aliasing(cache.accel, x, identity)) x_space = memory_space(x) if !is_key_present(cache, x_space, ainfo) # Preserve the object's memory-space/processor pairing when inserting @@ -356,7 +360,7 @@ function aliased_object!(f, cache::AliasedObjectCache, x; ainfo=aliasing(x, iden @assert y isa Chunk "Didn't get a Chunk from functor" @assert memory_space(y) == cache.space "Space mismatch! $(memory_space(y)) != $(cache.space)" if memory_space(x) != cache.space - @assert ainfo != aliasing(y, identity) "Aliasing mismatch! 
$ainfo == $(aliasing(y, identity))" + @assert ainfo != aliasing(cache.accel, y, identity) "Aliasing mismatch! $ainfo == $(aliasing(cache.accel, y, identity))" end set_stored!(cache, y, ainfo) return y @@ -436,7 +440,9 @@ struct DataDepsState arg_history = Dict{ArgumentWrapper,Vector{HistoryEntry}}() arg_owner = Dict{ArgumentWrapper,MemorySpace}() arg_overlaps = Dict{ArgumentWrapper,Set{ArgumentWrapper}}() - ainfo_backing_chunk = tochunk(AliasedObjectCacheStore()) + ainfo_backing_chunk = _with_default_acceleration() do + tochunk(AliasedObjectCacheStore()) + end supports_inplace_cache = IdDict{Any,Bool}() ainfo_cache = Dict{ArgumentWrapper,AliasingWrapper}() @@ -497,7 +503,9 @@ function populate_task_info!(state::DataDepsState, task_args, spec::DTaskSpec, t arg_chunk = state.raw_arg_to_chunk[arg] else if !(arg isa Chunk) - arg_chunk = tochunk(arg) + arg_chunk = with(MPI_TID=>task.uid) do + tochunk(arg) + end state.raw_arg_to_chunk[arg] = arg_chunk else state.raw_arg_to_chunk[arg] = arg @@ -507,6 +515,7 @@ function populate_task_info!(state::DataDepsState, task_args, spec::DTaskSpec, t # Track the origin space of the argument origin_space = memory_space(arg_chunk) + check_uniform(origin_space) state.arg_origin[arg_chunk] = origin_space state.remote_arg_to_original[arg_chunk] = arg_chunk @@ -568,7 +577,7 @@ function aliasing!(state::DataDepsState, target_space::MemorySpace, arg_w::Argum end # Calculate the ainfo - ainfo = AliasingWrapper(aliasing(remote_arg, arg_w.dep_mod)) + ainfo = AliasingWrapper(aliasing(current_acceleration(), remote_arg, arg_w.dep_mod)) # Cache the result state.ainfo_cache[remote_arg_w] = ainfo @@ -671,7 +680,10 @@ region returns.
""" supports_inplace_move(x) = true supports_inplace_move(t::DTask) = supports_inplace_move(fetch(t; raw=true)) +@warn "Fix supports_inplace_move for MPI" maxlog=1 function supports_inplace_move(c::Chunk) + # FIXME + return true # FIXME: Use MemPool.access_ref pid = root_worker_id(c.processor) if pid == myid() @@ -748,24 +760,38 @@ end isremotehandle(x) = false isremotehandle(x::DTask) = true isremotehandle(x::Chunk) = true +@warn "Properly propagate MPI_TID and uniformity through any remotecalls" maxlog=1 function generate_slot!(state::DataDepsState, dest_space, data) # N.B. We do not perform any sync/copy with the current owner of the data, # because all we want here is to make a copy of some version of the data, # even if the data is not up to date. orig_space = memory_space(data) + check_uniform(orig_space) to_proc = first(processors(dest_space)) + check_uniform(to_proc) from_proc = first(processors(orig_space)) + check_uniform(from_proc) + if MPI.Comm_rank(MPI.COMM_WORLD) == 0 + display(typeof(data)) + end + check_uniform(typeof(data)) dest_space_args = get!(IdDict{Any,Any}, state.remote_args, dest_space) - aliased_object_cache = AliasedObjectCache(dest_space, state.ainfo_backing_chunk) + aliased_object_cache = AliasedObjectCache(current_acceleration(), dest_space, state.ainfo_backing_chunk) ctx = Sch.eager_context() id = rand(Int) @maybelog ctx timespan_start(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data)) - data_chunk = move_rewrap(aliased_object_cache, from_proc, to_proc, orig_space, dest_space, data) + data_chunk = with(MPI_TID=>DATADEPS_CURRENT_TASK[].uid) do + remotecall_endpoint(move_rewrap, current_acceleration(), aliased_object_cache, from_proc, to_proc, orig_space, dest_space, data) + end @maybelog ctx timespan_finish(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data=data_chunk)) @assert memory_space(data_chunk) == dest_space "space mismatch! 
$dest_space (dest) != $(memory_space(data_chunk)) (actual) ($(typeof(data)) (data) vs. $(typeof(data_chunk)) (chunk)), spaces ($orig_space -> $dest_space)" dest_space_args[data] = data_chunk state.remote_arg_to_original[data_chunk] = data + check_uniform(memory_space(dest_space_args[data])) + check_uniform(processor(dest_space_args[data])) + check_uniform(dest_space_args[data].handle) + return dest_space_args[data] end function get_or_generate_slot!(state, dest_space, data) @@ -778,42 +804,43 @@ function get_or_generate_slot!(state, dest_space, data) end return state.remote_args[dest_space][data] end -function remotecall_endpoint(f, from_proc, to_proc, from_space, to_space, data) - to_w = root_worker_id(to_proc) - if to_w == myid() - data_converted = f(move(from_proc, to_proc, data)) - return tochunk(data_converted, to_proc) - end - return remotecall_fetch(to_w, from_proc, to_proc, to_space, data) do from_proc, to_proc, to_space, data - data_converted = f(move(from_proc, to_proc, data)) - return tochunk(data_converted, to_proc) + +function remotecall_fetch_fast(f, wid::Integer, args...; kwargs...) + if wid == myid() + return f(args...; kwargs...) end + return remotecall_fetch(f, wid, args...; kwargs...) 
end -function rewrap_aliased_object!(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x) - return aliased_object!(cache, x) do x - return remotecall_endpoint(identity, from_proc, to_proc, from_space, to_space, x) +function remotecall_endpoint(f, accel::DistributedAcceleration, cache::AliasedObjectCache, from_proc, to_proc, from_space, to_space, data::Chunk) + from_w = root_worker_id(from_proc) + return remotecall_fetch_fast(from_w) do + data_raw = unwrap(data) + return f(accel, cache, from_proc, to_proc, from_space, to_space, data_raw)::Chunk end end -function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, data::Chunk) - # Unwrap so that we hit the right dispatch - wid = root_worker_id(data) - if wid != myid() - return remotecall_fetch(move_rewrap, wid, cache, from_proc, to_proc, from_space, to_space, data) +function remotecall_endpoint_transfer(f, accel::DistributedAcceleration, from_proc, to_proc, from_space, to_space, data) + to_w = root_worker_id(to_proc) + return remotecall_fetch_fast(to_w) do + return f(accel, from_proc, to_proc, from_space, to_space, data) end - data_raw = unwrap(data) - return move_rewrap(cache, from_proc, to_proc, from_space, to_space, data_raw) end -function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, data) - # For generic data +@warn "Replace all remotecall_fetch calls with remotecall_endpoint" maxlog=1 +move_rewrap(accel, cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, data::Chunk) = + remotecall_endpoint(move_rewrap, accel, cache, from_proc, to_proc, from_space, to_space, data) +function move_rewrap(accel, cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, data) + # 
Generic data, do the transfer return aliased_object!(cache, data) do data - return remotecall_endpoint(identity, from_proc, to_proc, from_space, to_space, data) + return remotecall_endpoint_transfer(accel, from_proc, to_proc, from_space, to_space, data) do accel, from_proc, to_proc, from_space, to_space, data + return tochunk(move(from_proc, to_proc, data), to_proc) + end end end -function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::SubArray) +function move_rewrap(accel, cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::SubArray) to_w = root_worker_id(to_proc) - p_chunk = rewrap_aliased_object!(cache, from_proc, to_proc, from_space, to_space, parent(v)) + p_chunk = move_rewrap(accel, cache, from_proc, to_proc, from_space, to_space, parent(v)) + check_uniform(p_chunk.handle) inds = parentindices(v) - return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p_chunk, inds) do from_proc, to_proc, from_space, to_space, p_chunk, inds + return remotecall_endpoint_transfer(accel, from_proc, to_proc, from_space, to_space, p_chunk) do accel, from_proc, to_proc, from_space, to_space, p_chunk p_new = move(from_proc, to_proc, p_chunk) v_new = view(p_new, inds...) 
return tochunk(v_new, to_proc) @@ -821,21 +848,16 @@ function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::P end # FIXME: Do this programmatically via recursive dispatch for wrapper in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular) - @eval function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::$(wrapper)) + @eval function move_rewrap(accel, cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::$(wrapper)) to_w = root_worker_id(to_proc) - p_chunk = rewrap_aliased_object!(cache, from_proc, to_proc, from_space, to_space, parent(v)) - return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p_chunk) do from_proc, to_proc, from_space, to_space, p_chunk + p_chunk = move_rewrap(accel, cache, from_proc, to_proc, from_space, to_space, parent(v)) + return remotecall_fetch_fast(to_w, from_proc, to_proc, from_space, to_space, p_chunk) do from_proc, to_proc, from_space, to_space, p_chunk p_new = move(from_proc, to_proc, p_chunk) v_new = $(wrapper)(p_new) return tochunk(v_new, to_proc) end end end -function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::Base.RefValue) - return aliased_object!(cache, v) do v - return remotecall_endpoint(identity, from_proc, to_proc, from_space, to_space, v) - end -end #= FIXME: Make this work so we can automatically move-rewrap recursive objects function move_rewrap_recursive(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::T) where T if isstructtype(T) diff --git a/src/datadeps/chunkview.jl b/src/datadeps/chunkview.jl index 1c2aa600f..418987124 100644 --- a/src/datadeps/chunkview.jl +++ b/src/datadeps/chunkview.jl @@ -31,11 +31,12 @@ end Base.view(c::DTask, slices...) 
= view(fetch(c; raw=true), slices...) -function aliasing(x::ChunkView{N}) where N +function aliasing(accel::Acceleration, x::ChunkView{N}, dep_mod) where N + @assert dep_mod === identity "Dependency modifiers not yet supported for ChunkView: $dep_mod" return remotecall_fetch(root_worker_id(x.chunk.processor), x.chunk, x.slices) do x, slices x = unwrap(x) v = view(x, slices...) - return aliasing(v) + return aliasing(accel, v, dep_mod) end end memory_space(x::ChunkView) = memory_space(x.chunk) @@ -64,4 +65,4 @@ function move(from_proc::Processor, to_proc::Processor, slice::ChunkView) end end -Base.fetch(slice::ChunkView) = view(fetch(slice.chunk), slice.slices...) \ No newline at end of file +Base.fetch(slice::ChunkView) = view(fetch(slice.chunk), slice.slices...) diff --git a/src/datadeps/queue.jl b/src/datadeps/queue.jl index 96112cf15..3b3ed5185 100644 --- a/src/datadeps/queue.jl +++ b/src/datadeps/queue.jl @@ -25,6 +25,8 @@ function enqueue!(queue::DataDepsTaskQueue, pairs::Vector{DTaskPair}) append!(queue.seen_tasks, pairs) end +const DATADEPS_CURRENT_TASK = TaskLocalValue{Union{DTask,Nothing}}(Returns(nothing)) + """ spawn_datadeps(f::Base.Callable) @@ -88,6 +90,7 @@ end const DATADEPS_SCHEDULER = ScopedValue{Union{DataDepsScheduler,Nothing}}(nothing) const DATADEPS_LAUNCH_WAIT = ScopedValue{Union{Bool,Nothing}}(nothing) +@warn "Add reliable, uniform-safe Processor sorting" maxlog=1 function distribute_tasks!(queue::DataDepsTaskQueue) #= TODO: Improvements to be made: # - Support for copying non-AbstractArray arguments @@ -98,20 +101,25 @@ function distribute_tasks!(queue::DataDepsTaskQueue) =# # Get the set of all processors to be scheduled on - all_procs = Processor[] - scope = get_compute_scope() - for w in procs() - append!(all_procs, get_processors(OSProc(w))) + accel = current_acceleration() + accel_procs = filter(procs(Dagger.Sch.eager_context())) do proc + Dagger.accel_matches_proc(accel, proc) end + all_procs = 
unique(vcat([collect(Dagger.get_processors(gp)) for gp in accel_procs]...)) + # FIXME: This is an unreliable way to ensure processor uniformity + sort!(all_procs, by=short_name) + scope = get_compute_scope() filter!(proc->proc_in_scope(proc, scope), all_procs) if isempty(all_procs) throw(Sch.SchedulingException("No processors available, try widening scope")) end + if uniform_execution(accel) + for proc in all_procs + check_uniform(proc) + end + end all_scope = UnionScope(map(ExactScope, all_procs)) exec_spaces = unique(vcat(map(proc->collect(memory_spaces(proc)), all_procs)...)) - if !all(space->space isa CPURAMMemorySpace, exec_spaces) && !all(space->root_worker_id(space) == myid(), exec_spaces) - @warn "Datadeps support for multi-GPU, multi-worker is currently broken\nPlease be prepared for incorrect results or errors" maxlog=1 - end # Round-robin assign tasks to processors upper_queue = get_options(:task_queue) @@ -128,7 +136,9 @@ function distribute_tasks!(queue::DataDepsTaskQueue) # Copy args from remote to local # N.B. 
We sort the keys to ensure a deterministic order for uniformity + check_uniform(length(state.arg_owner)) for arg_w in sort(collect(keys(state.arg_owner)); by=arg_w->arg_w.hash) + check_uniform(arg_w) arg = arg_w.arg origin_space = state.arg_origin[arg] remainder, _ = compute_remainder_for_arg!(state, origin_space, arg_w, write_num) @@ -199,11 +209,15 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr fargs::Vector{Argument} end + DATADEPS_CURRENT_TASK[] = task + task_scope = @something(spec.options.compute_scope, spec.options.scope, DefaultScope()) scheduler = queue.scheduler our_proc = datadeps_schedule_task(scheduler, state, all_procs, all_scope, task_scope, spec, task) @assert our_proc in all_procs our_space = only(memory_spaces(our_proc)) + check_uniform(our_proc) + check_uniform(our_space) # Find the scope for this task (and its copies) task_scope = @something(spec.options.compute_scope, spec.options.scope, DefaultScope()) @@ -308,6 +322,9 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr if spec.options.syncdeps === nothing spec.options.syncdeps = Set{ThunkSyncdep}() end + if spec.options.tag === nothing + spec.options.tag = to_tag() + end syncdeps = spec.options.syncdeps map_or_ntuple(task_arg_ws) do idx arg_ws = task_arg_ws[idx] @@ -342,7 +359,9 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr new_spec = DTaskSpec(new_fargs, spec.options) new_spec.options.scope = our_scope new_spec.options.exec_scope = our_scope - new_spec.options.occupancy = Dict(Any=>0) + if uniform_execution() + new_spec.options.occupancy = Dict(Any=>0) + end ctx = Sch.eager_context() @maybelog ctx timespan_start(ctx, :datadeps_execute, (;thunk_id=task.uid), (;)) enqueue!(queue.upper_queue, DTaskPair(new_spec, task)) @@ -370,5 +389,7 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr write_num += 1 + DATADEPS_CURRENT_TASK[] = nothing + return write_num 
end diff --git a/src/datadeps/remainders.jl b/src/datadeps/remainders.jl index 2c2c49920..ee1b060db 100644 --- a/src/datadeps/remainders.jl +++ b/src/datadeps/remainders.jl @@ -98,7 +98,7 @@ function compute_remainder_for_arg!(state::DataDepsState, for entry in state.arg_history[arg_w] push!(spaces_set, entry.space) end - spaces = collect(spaces_set) + spaces = sort(collect(spaces_set), by=short_name) N = length(spaces) # Lookup all memory spans for arg_w in these spaces @@ -118,6 +118,8 @@ function compute_remainder_for_arg!(state::DataDepsState, @goto restart end end + check_uniform(spaces) + check_uniform(target_ainfos) # We may only need to schedule a full copy from the origin space to the # target space if this is the first time we've written to `arg_w` @@ -159,6 +161,8 @@ function compute_remainder_for_arg!(state::DataDepsState, other_ainfo = aliasing!(state, owner_space, arg_w) other_space = owner_space end + check_uniform(other_ainfo) + check_uniform(other_space) # Lookup all memory spans for arg_w in these spaces other_remote_arg_w = first(collect(state.ainfo_arg[other_ainfo])) @@ -174,6 +178,7 @@ function compute_remainder_for_arg!(state::DataDepsState, foreach(other_many_spans) do span verify_span(span) end + check_uniform(other_many_spans) if other_space == target_space # Only subtract, this data is already up-to-date in target_space @@ -250,7 +255,9 @@ Enqueues a copy operation to update the remainder regions of an object before a function enqueue_remainder_copy_to!(state::DataDepsState, dest_space::MemorySpace, arg_w::ArgumentWrapper, remainder_aliasing::MultiRemainderAliasing, f, idx, dest_scope, task, write_num::Int) for remainder in remainder_aliasing.remainders + check_uniform(remainder.space) @assert !isempty(remainder.spans) + check_uniform(remainder.spans) enqueue_remainder_copy_to!(state, dest_space, arg_w, remainder, f, idx, dest_scope, task, write_num) end end @@ -304,7 +311,9 @@ Enqueues a copy operation to update the remainder regions of 
an object back to t function enqueue_remainder_copy_from!(state::DataDepsState, dest_space::MemorySpace, arg_w::ArgumentWrapper, remainder_aliasing::MultiRemainderAliasing, dest_scope, write_num::Int) for remainder in remainder_aliasing.remainders + check_uniform(remainder.space) @assert !isempty(remainder.spans) + check_uniform(remainder.spans) enqueue_remainder_copy_from!(state, dest_space, arg_w, remainder, dest_scope, write_num) end end @@ -537,4 +546,4 @@ function find_object_holding_ptr(A::SparseMatrixCSC, ptr::UInt64) span = LocalMemorySpan(pointer(A.rowval), length(A.rowval)*sizeof(eltype(A.rowval))) @assert span_start(span) <= ptr <= span_end(span) "Pointer $ptr not found in SparseMatrixCSC" return A.rowval -end \ No newline at end of file +end diff --git a/src/dtask.jl b/src/dtask.jl index e94803502..c9e9e811f 100644 --- a/src/dtask.jl +++ b/src/dtask.jl @@ -11,14 +11,14 @@ Base.wait(t::ThunkFuture) = Dagger.Sch.thunk_yield() do wait(t.future) return end -function Base.fetch(t::ThunkFuture; proc=OSProc(), raw=false) +function Base.fetch(t::ThunkFuture; proc=OSProc(), raw=false, move_value=!raw, unwrap=!raw, uniform=uniform_execution()) error, value = Dagger.Sch.thunk_yield() do fetch(t.future) end if error throw(value) end - if raw + if !move_value return value else return move(proc, value) @@ -65,11 +65,11 @@ function Base.wait(t::DTask) wait(t.future) return end -function Base.fetch(t::DTask; raw=false) +function Base.fetch(t::DTask; raw=false, move_value=!raw, unwrap=!raw, uniform=false) if !istaskstarted(t) throw(ConcurrencyViolationError("Cannot `fetch` an unlaunched `DTask`")) end - return fetch(t.future; raw) + return fetch(t.future; move_value, unwrap, uniform) end function waitany(tasks::Vector{DTask}) if isempty(tasks) diff --git a/src/lib/domain-blocks.jl b/src/lib/domain-blocks.jl index 2a0854e3b..95e5c360f 100644 --- a/src/lib/domain-blocks.jl +++ b/src/lib/domain-blocks.jl @@ -6,6 +6,8 @@ struct DomainBlocks{N} <: 
AbstractArray{ArrayDomain{N, NTuple{N, UnitRange{Int}} end Base.@deprecate_binding BlockedDomains DomainBlocks +ndims(::DomainBlocks{N}) where N = N + size(x::DomainBlocks) = map(length, x.cumlength) function _getindex(x::DomainBlocks{N}, idx::Tuple) where N starts = map((vec, i) -> i == 0 ? 0 : getindex(vec,i), x.cumlength, map(x->x-1, idx)) diff --git a/src/memory-spaces.jl b/src/memory-spaces.jl index eb4f7ad5b..a531509cf 100644 --- a/src/memory-spaces.jl +++ b/src/memory-spaces.jl @@ -4,18 +4,10 @@ end CPURAMMemorySpace() = CPURAMMemorySpace(myid()) root_worker_id(space::CPURAMMemorySpace) = space.owner -memory_space(x) = CPURAMMemorySpace(myid()) -function memory_space(x::Chunk) - proc = processor(x) - if proc isa OSProc - # TODO: This should probably be programmable - return CPURAMMemorySpace(proc.pid) - else - return only(memory_spaces(proc)) - end -end -memory_space(x::DTask) = - memory_space(fetch(x; raw=true)) +memory_space(x, proc::Processor=default_processor()) = first(memory_spaces(proc)) +memory_space(x::Processor) = first(memory_spaces(x)) +memory_space(x::Chunk) = x.space +memory_space(x::DTask) = memory_space(fetch(x; move_value=false, unwrap=false)) memory_spaces(::P) where {P<:Processor} = throw(ArgumentError("Must define `memory_spaces` for `$P`")) @@ -28,9 +20,10 @@ processors(space::CPURAMMemorySpace) = ### In-place Data Movement -function unwrap(x::Chunk) - @assert x.handle.owner == myid() - MemPool.poolget(x.handle) +unwrap(x::Chunk) = unwrap(x.handle) +function unwrap(handle::DRef) + @assert root_worker_id(handle) == myid() "DRef $handle is not owned by this process: $(root_worker_id(handle)) != $(myid())" + return MemPool.poolget(handle) end move!(dep_mod, to_space::MemorySpace, from_space::MemorySpace, to::T, from::F) where {T,F} = throw(ArgumentError("No `move!` implementation defined for $F -> $T")) @@ -69,6 +62,16 @@ function move!(::Type{<:Tridiagonal}, to_space::MemorySpace, from_space::MemoryS return end +# FIXME: Take MemorySpace
instead +function move_type(from_proc::Processor, to_proc::Processor, ::Type{T}) where T + if from_proc == to_proc + return T + end + return Base._return_type(move, Tuple{typeof(from_proc), typeof(to_proc), T}) +end +move_type(from_proc::Processor, to_proc::Processor, ::Type{<:Chunk{T}}) where T = + move_type(from_proc, to_proc, T) + ### Aliasing and Memory Spans type_may_alias(::Type{String}) = false @@ -355,6 +358,7 @@ function memory_spans(oa::ObjectAliasing{S}) where S return [span] end +aliasing(accel::Acceleration, x, T) = aliasing(x, T) function aliasing(x, dep_mod) if dep_mod isa Symbol return aliasing(getfield(x, dep_mod)) @@ -391,19 +395,25 @@ aliasing(::String) = NoAliasing() # FIXME: Not necessarily true aliasing(::Symbol) = NoAliasing() aliasing(::Type) = NoAliasing() function aliasing(x::Chunk, T) - @assert x.handle isa DRef if root_worker_id(x.processor) == myid() return aliasing(unwrap(x), T) end + @assert x.handle isa DRef return remotecall_fetch(root_worker_id(x.processor), x, T) do x, T aliasing(unwrap(x), T) end end -aliasing(x::Chunk) = remotecall_fetch(root_worker_id(x.processor), x) do x - aliasing(unwrap(x)) +function aliasing(x::Chunk) + if root_worker_id(x.processor) == myid() + return aliasing(unwrap(x)) + end + @assert x.handle isa DRef + return remotecall_fetch(root_worker_id(x.processor), x) do x + aliasing(unwrap(x)) + end end -aliasing(x::DTask, T) = aliasing(fetch(x; raw=true), T) -aliasing(x::DTask) = aliasing(fetch(x; raw=true)) +aliasing(x::DTask, T) = aliasing(fetch(x; move_value=false, unwrap=false), T) +aliasing(x::DTask) = aliasing(fetch(x; move_value=false, unwrap=false)) function aliasing(x::Base.RefValue{T}) where T addr = UInt(Base.pointer_from_objref(x) + fieldoffset(typeof(x), 1)) @@ -611,5 +621,5 @@ unsafe_free!(x::Chunk) = remotecall_fetch(root_worker_id(x), x) do x unsafe_free!(unwrap(x)) return end -unsafe_free!(x::DTask) = unsafe_free!(fetch(x; raw=true)) +unsafe_free!(x::DTask) = unsafe_free!(fetch(x; 
move_value=false, unwrap=false)) unsafe_free!(x) = nothing # Do nothing by default diff --git a/src/mpi.jl b/src/mpi.jl new file mode 100644 index 000000000..b5723e6a2 --- /dev/null +++ b/src/mpi.jl @@ -0,0 +1,1004 @@ +@warn "Move to MPIExt.jl" maxlog=1 + +using MPI + +const CHECK_UNIFORMITY = Ref{Bool}(false) +function check_uniformity!(check::Bool=true) + CHECK_UNIFORMITY[] = check +end +function check_uniform(value::Integer, original=value) + CHECK_UNIFORMITY[] && uniform_execution() || return true + comm = MPI.COMM_WORLD + rank = MPI.Comm_rank(comm) + matched = compare_all(value, comm) + if !matched + if rank == 0 + Core.print("[$rank] Found non-uniform value!\n") + end + Core.print("[$rank] value=$value, original=$original\n") + throw(ArgumentError("Non-uniform value")) + end + MPI.Barrier(comm) + return matched +end +function check_uniform(value, original=value) + CHECK_UNIFORMITY[] && uniform_execution() || return true + return check_uniform(hash(value), original) +end + +function compare_all(value, comm) + rank = MPI.Comm_rank(comm) + size = MPI.Comm_size(comm) + for i in 0:(size-1) + if i != rank + send_yield(value, comm, i, UInt32(0)) + end + end + match = true + for i in 0:(size-1) + if i != rank + other_value = recv_yield(comm, i, UInt32(0)) + if value != other_value + match = false + end + end + end + return match +end + +struct MPIAcceleration <: Acceleration + comm::MPI.Comm +end +MPIAcceleration() = MPIAcceleration(MPI.COMM_WORLD) + +function aliasing(accel::MPIAcceleration, x::Chunk, T) + handle = x.handle::MPIRef + @assert accel.comm == handle.comm "MPIAcceleration comm mismatch" + tag = to_tag() + check_uniform(tag) + rank = MPI.Comm_rank(accel.comm) + if handle.rank == rank + ainfo = _with_default_acceleration() do + aliasing(x, T) + end + #Core.print("[$rank] aliasing: $ainfo, sending\n") + @opcounter :aliasing_bcast_send_yield + bcast_send_yield(ainfo, accel.comm, handle.rank, tag) + else + #Core.print("[$rank] aliasing: receiving from 
$(handle.rank)\n") + ainfo = recv_yield(accel.comm, handle.rank, tag) + #Core.print("[$rank] aliasing: received $ainfo\n") + end + check_uniform(ainfo) + return ainfo +end + +default_processor(accel::MPIAcceleration) = MPIOSProc(accel.comm, 0) +default_processor(accel::MPIAcceleration, x) = MPIOSProc(accel.comm, 0) +default_processor(accel::MPIAcceleration, x::Chunk) = MPIOSProc(x.handle.comm, x.handle.rank) +default_processor(accel::MPIAcceleration, x::Function) = MPIOSProc(accel.comm, MPI.Comm_rank(accel.comm)) +default_processor(accel::MPIAcceleration, T::Type) = MPIOSProc(accel.comm, MPI.Comm_rank(accel.comm)) +uniform_execution(accel::MPIAcceleration) = true + +@warn "Add a lock to MPIClusterProcChildren" maxlog=1 +const MPIClusterProcChildren = Dict{MPI.Comm, Set{Processor}}() + +struct MPIClusterProc <: Processor + comm::MPI.Comm + function MPIClusterProc(comm::MPI.Comm) + populate_children!(comm) + return new(comm) + end +end + +Sch.init_proc(state, proc::MPIClusterProc, log_sink) = Sch.init_proc(state, MPIOSProc(proc.comm), log_sink) + +MPIClusterProc() = MPIClusterProc(MPI.COMM_WORLD) + +function populate_children!(comm::MPI.Comm) + children = get_processors(OSProc()) + MPIClusterProcChildren[comm] = children +end + +struct MPIOSProc <: Processor + comm::MPI.Comm + rank::Int +end + +function MPIOSProc(comm::MPI.Comm) + rank = MPI.Comm_rank(comm) + return MPIOSProc(comm, rank) +end + +function MPIOSProc() + return MPIOSProc(MPI.COMM_WORLD) +end + +ProcessScope(p::MPIOSProc) = ProcessScope(myid()) + +function check_uniform(proc::MPIOSProc, original=proc) + return check_uniform(hash(MPIOSProc), original) && + check_uniform(proc.rank, original) +end + +function memory_spaces(proc::MPIOSProc) + children = get_processors(proc) + spaces = Set{MemorySpace}() + for proc in children + for space in memory_spaces(proc) + push!(spaces, space) + end + end + return spaces +end + +struct MPIProcessScope <: AbstractScope + comm::MPI.Comm + rank::Int +end + 
+Base.isless(::MPIProcessScope, ::MPIProcessScope) = false +Base.isless(::MPIProcessScope, ::NodeScope) = true +Base.isless(::MPIProcessScope, ::UnionScope) = true +Base.isless(::MPIProcessScope, ::TaintScope) = true +Base.isless(::MPIProcessScope, ::AnyScope) = true +constrain(x::MPIProcessScope, y::MPIProcessScope) = + x == y ? y : InvalidScope(x, y) +constrain(x::NodeScope, y::MPIProcessScope) = + x == y.parent ? y : InvalidScope(x, y) + +Base.isless(::ExactScope, ::MPIProcessScope) = true +constrain(x::MPIProcessScope, y::ExactScope) = + x == y.parent ? y : InvalidScope(x, y) + +function enclosing_scope(proc::MPIOSProc) + return MPIProcessScope(proc.comm, proc.rank) +end + +function Dagger.to_scope(::Val{:mpi_rank}, sc::NamedTuple) + if sc.mpi_rank == Colon() + return Dagger.to_scope(Val{:mpi_ranks}(), merge(sc, (;mpi_ranks=Colon()))) + else + @assert sc.mpi_rank isa Integer "Expected a single MPI rank for :mpi_rank, got $(sc.mpi_rank)\nConsider using :mpi_ranks instead." + return Dagger.to_scope(Val{:mpi_ranks}(), merge(sc, (;mpi_ranks=[sc.mpi_rank]))) + end +end +Dagger.scope_key_precedence(::Val{:mpi_rank}) = 2 +function Dagger.to_scope(::Val{:mpi_ranks}, sc::NamedTuple) + comm = get(sc, :mpi_comm, MPI.COMM_WORLD) + if sc.mpi_ranks != Colon() + ranks = sc.mpi_ranks + else + ranks = 0:(MPI.Comm_size(comm)-1) + end + inner_sc = NamedTuple(filter(kv->kv[1] != :mpi_ranks, Base.pairs(sc))) + # FIXME: What to do here?
+ inner_scope = Dagger.to_scope(inner_sc) + scopes = Dagger.ExactScope[] + for rank in ranks + procs = Dagger.get_processors(Dagger.MPIOSProc(comm, rank)) + rank_scope = MPIProcessScope(comm, rank) + for proc in procs + proc_scope = Dagger.ExactScope(proc) + constrain(proc_scope, rank_scope) isa Dagger.InvalidScope && continue + push!(scopes, proc_scope) + end + end + return Dagger.UnionScope(scopes) +end +Dagger.scope_key_precedence(::Val{:mpi_ranks}) = 2 + +struct MPIProcessor{P<:Processor} <: Processor + innerProc::P + comm::MPI.Comm + rank::Int +end +proc_in_scope(proc::Processor, scope::MPIProcessScope) = false +proc_in_scope(proc::MPIProcessor, scope::MPIProcessScope) = + proc.comm == scope.comm && proc.rank == scope.rank + +function check_uniform(proc::MPIProcessor, original=proc) + return check_uniform(hash(MPIProcessor), original) && + check_uniform(proc.rank, original) && + # TODO: Not always valid (if pointer is embedded, say for GPUs) + check_uniform(hash(proc.innerProc), original) +end + +Dagger.iscompatible_func(::MPIProcessor, opts, ::Any) = true +Dagger.iscompatible_arg(::MPIProcessor, opts, ::Any) = true + +default_enabled(proc::MPIProcessor) = default_enabled(proc.innerProc) + +root_worker_id(proc::MPIProcessor) = myid() +root_worker_id(proc::MPIOSProc) = myid() +root_worker_id(proc::MPIClusterProc) = myid() + +get_parent(proc::MPIClusterProc) = proc +get_parent(proc::MPIOSProc) = MPIClusterProc(proc.comm) +get_parent(proc::MPIProcessor) = MPIOSProc(proc.comm, proc.rank) + +short_name(proc::MPIProcessor) = "(MPI: $(proc.rank), $(short_name(proc.innerProc)))" + +function get_processors(mosProc::MPIOSProc) + populate_children!(mosProc.comm) + children = MPIClusterProcChildren[mosProc.comm] + mpiProcs = Set{Processor}() + for proc in children + push!(mpiProcs, MPIProcessor(proc, mosProc.comm, mosProc.rank)) + end + return mpiProcs +end + +#TODO: non-uniform ranking through MPI groups +#TODO: use a lazy iterator +function 
get_processors(proc::MPIClusterProc) + children = Set{Processor}() + for i in 0:(MPI.Comm_size(proc.comm)-1) + for innerProc in MPIClusterProcChildren[proc.comm] + push!(children, MPIProcessor(innerProc, proc.comm, i)) + end + end + return children +end + +struct MPIMemorySpace{S<:MemorySpace} <: MemorySpace + innerSpace::S + comm::MPI.Comm + rank::Int +end + +function check_uniform(space::MPIMemorySpace, original=space) + return check_uniform(space.rank, original) && + # TODO: Not always valid (if pointer is embedded, say for GPUs) + check_uniform(hash(space.innerSpace), original) +end + +default_processor(space::MPIMemorySpace) = MPIOSProc(space.comm, space.rank) +default_memory_space(accel::MPIAcceleration) = MPIMemorySpace(CPURAMMemorySpace(myid()), accel.comm, 0) + +default_memory_space(accel::MPIAcceleration, x) = MPIMemorySpace(CPURAMMemorySpace(myid()), accel.comm, 0) +default_memory_space(accel::MPIAcceleration, x::Chunk) = MPIMemorySpace(CPURAMMemorySpace(myid()), x.handle.comm, x.handle.rank) +default_memory_space(accel::MPIAcceleration, x::Function) = MPIMemorySpace(CPURAMMemorySpace(myid()), accel.comm, MPI.Comm_rank(accel.comm)) +default_memory_space(accel::MPIAcceleration, T::Type) = MPIMemorySpace(CPURAMMemorySpace(myid()), accel.comm, MPI.Comm_rank(accel.comm)) + +function memory_spaces(proc::MPIClusterProc) + rawMemSpace = Set{MemorySpace}() + for rnk in 0:(MPI.Comm_size(proc.comm) - 1) + for innerSpace in memory_spaces(OSProc()) + push!(rawMemSpace, MPIMemorySpace(innerSpace, proc.comm, rnk)) + end + end + return rawMemSpace +end + +function memory_spaces(proc::MPIProcessor) + rawMemSpace = Set{MemorySpace}() + for innerSpace in memory_spaces(proc.innerProc) + push!(rawMemSpace, MPIMemorySpace(innerSpace, proc.comm, proc.rank)) + end + return rawMemSpace +end + +root_worker_id(mem_space::MPIMemorySpace) = myid() + +function processors(memSpace::MPIMemorySpace) + rawProc = Set{Processor}() + for innerProc in processors(memSpace.innerSpace) + 
push!(rawProc, MPIProcessor(innerProc, memSpace.comm, memSpace.rank)) + end + return rawProc +end + +struct MPIRefID + tid::UInt32 + generic::Bool + id::UInt32 + function MPIRefID(tid, generic, id) + @assert tid > 0 || generic "Invalid MPIRefID: tid=$tid, generic=$generic, id=$id" + return new(tid, generic, id) + end +end +Base.hash(id::MPIRefID, h::UInt=UInt(0)) = + hash(id.tid, hash(id.generic, hash(id.id, hash(MPIRefID, h)))) + +function check_uniform(ref::MPIRefID, original=ref) + return check_uniform(ref.tid, original) && + check_uniform(ref.generic, original) && + check_uniform(ref.id, original) +end + +function to_tag() + if Dagger.in_task() + # Tag is already assigned + opts = Dagger.get_tls().task_spec.options + tag = opts.tag + return tag + end + + # Generate a tag based on the TID + @assert !Sch.SCHED_MOVE[] "We should not create a tag during Sch move" + return to_tag(take_ref_id!()) +end +to_tag(id::MPIRefID) = id.generic ? id.id : id.tid + +# Semi-public internal value for passing TID to MPIRefID generation +const MPI_TID = ScopedValue{Int64}(0) +# Private internal value for tracking TID-based ID generations +#const _MPIREF_TID = Dict{Int, Threads.Atomic{Int}}() +# Private internal value for tracking non-TID (uniform) ID generations +#const _MPIREF_GENERIC = Threads.Atomic{Int}(1) + +mutable struct MPIRef + comm::MPI.Comm + rank::Int + size::Int + innerRef::Union{DRef, Nothing} + id::MPIRefID +end +Base.hash(ref::MPIRef, h::UInt=UInt(0)) = hash(ref.id, hash(MPIRef, h)) +root_worker_id(ref::MPIRef) = myid() + +function check_uniform(ref::MPIRef, original=ref) + return check_uniform(ref.rank, original) && + check_uniform(ref.id, original) +end + +function unwrap(handle::MPIRef) + @assert handle.rank == MPI.Comm_rank(handle.comm) "MPIRef $handle is not owned by this rank: $(handle.rank) != $(MPI.Comm_rank(handle.comm))" + return unwrap(handle.innerRef) +end + +to_tag(ref::MPIRef) = to_tag(ref.id) + +move(from_proc::Processor, to_proc::Processor, 
x::MPIRef) = + move(from_proc, to_proc, poolget(x; uniform=uniform_execution())) + +function affinity(x::MPIRef) + if x.innerRef === nothing + return MPIOSProc(x.comm, x.rank)=>0 + else + return MPIOSProc(x.comm, x.rank)=>x.innerRef.size + end +end + +function take_ref_id!() + tid = 0 + generic = 0 + id = 0 + if Dagger.in_task() + tid = sch_handle().thunk_id.id + #counter = get!(_MPIREF_TID, tid, Threads.Atomic{Int}(1)) + #id = Threads.atomic_add!(counter, 1) + id = tid + elseif MPI_TID[] != 0 + tid = MPI_TID[] + #counter = get!(_MPIREF_TID, tid, Threads.Atomic{Int}(1)) + #id = Threads.atomic_add!(counter, 1) + id = tid + else + if current_task() !== Base.roottask + throw(ConcurrencyViolationError("Attempted to generate generic MPIRefID in a multi-threaded context")) + end + generic = true + #id = Threads.atomic_add!(_MPIREF_GENERIC, 1) + id = next_id() # Abuse the TID counter for generic IDs + check_uniform(id) + end + @assert id < MPI.tag_ub() + return MPIRefID(tid, generic, id) +end + +#TODO: partitioned scheduling with comm bifurcation +function tochunk_pset(x, space::MPIMemorySpace; device=nothing, kwargs...) + @assert space.comm == MPI.COMM_WORLD "$(space.comm) != $(MPI.COMM_WORLD)" + local_rank = MPI.Comm_rank(space.comm) + Mid = take_ref_id!() + if local_rank != space.rank + return MPIRef(space.comm, space.rank, 0, nothing, Mid) + else + # type= is for Chunk metadata only; MemPool.poolset does not accept it + pset_kw = (; (k => v for (k, v) in pairs(kwargs) if k !== :type)...) 
+ return MPIRef(space.comm, space.rank, sizeof(x), poolset(x; device, pset_kw...), Mid) + end +end + +const DEADLOCK_DETECT = TaskLocalValue{Bool}(()->true) +const DEADLOCK_WARN_PERIOD = TaskLocalValue{Float64}(()->10.0) +const DEADLOCK_TIMEOUT_PERIOD = TaskLocalValue{Float64}(()->120.0) +const RECV_WAITING = Base.Lockable(Dict{Tuple{MPI.Comm, Int, Int}, Base.Event}()) + +@warn "Rename and make generic these in-place structs" maxlog=1 +struct InplaceInfo + type::DataType + shape::Tuple +end +struct InplaceSparseInfo + type::DataType + m::Int + n::Int + colptr::Int + rowval::Int + nzval::Int +end + +function supports_inplace_mpi(value) + if value isa DenseArray && isbitstype(eltype(value)) + return true + else + return false + end +end +function recv_yield!(buffer, comm, src, tag) + rank = MPI.Comm_rank(comm) + #Core.println("buffer recv: $buffer, type of buffer: $(typeof(buffer)), is in place? $(supports_inplace_mpi(buffer))") + if !supports_inplace_mpi(buffer) + return recv_yield(comm, src, tag), false + end + #Core.println("[rank $(MPI.Comm_rank(comm))][tag $tag] Starting recv! 
from [$src]") + + # Ensure no other receiver is waiting + our_event = Base.Event() + @label retry + other_event = lock(RECV_WAITING) do waiting + if haskey(waiting, (comm, src, tag)) + waiting[(comm, src, tag)] + else + waiting[(comm, src, tag)] = our_event + nothing + end + end + if other_event !== nothing + #Core.println("[rank $(MPI.Comm_rank(comm))][tag $tag] Waiting for other receiver...") + wait(other_event) + @goto retry + end + + buffer = recv_yield_inplace!(buffer, comm, rank, src, tag) + + lock(RECV_WAITING) do waiting + delete!(waiting, (comm, src, tag)) + notify(our_event) + end + + return buffer, true + +end + +function recv_yield(comm, src, tag) + rank = MPI.Comm_rank(comm) + #Core.println("[rank $(MPI.Comm_rank(comm))][tag $tag] Starting recv from [$src]") + + # Ensure no other receiver is waiting + our_event = Base.Event() + @label retry + other_event = lock(RECV_WAITING) do waiting + if haskey(waiting, (comm, src, tag)) + waiting[(comm, src, tag)] + else + waiting[(comm, src, tag)] = our_event + nothing + end + end + if other_event !== nothing + #Core.println("[rank $(MPI.Comm_rank(comm))][tag $tag] Waiting for other receiver...") + wait(other_event) + @goto retry + end + #Core.println("[rank $(MPI.Comm_rank(comm))][tag $tag] Receiving...") + + type = nothing + @label receive + value = recv_yield_serialized(comm, rank, src, tag) + if value isa InplaceInfo || value isa InplaceSparseInfo + value = recv_yield_inplace(value, comm, rank, src, tag) + end + + lock(RECV_WAITING) do waiting + delete!(waiting, (comm, src, tag)) + notify(our_event) + end + return value +end + +function recv_yield_inplace!(array, comm, my_rank, their_rank, tag) + time_start = time_ns() + detect = DEADLOCK_DETECT[] + warn_period = round(UInt64, DEADLOCK_WARN_PERIOD[] * 1e9) + timeout_period = round(UInt64, DEADLOCK_TIMEOUT_PERIOD[] * 1e9) + + while true + (got, msg, stat) = MPI.Improbe(their_rank, tag, comm, MPI.Status) + if got + if MPI.Get_error(stat) != MPI.SUCCESS + 
error("recv_yield failed with error $(MPI.Get_error(stat))") + end + count = MPI.Get_count(stat, UInt8) + @assert count == sizeof(array) "recv_yield_inplace: expected $(sizeof(array)) bytes, got $count" + buf = MPI.Buffer(array) + req = MPI.Imrecv!(buf, msg) + __wait_for_request(req, comm, my_rank, their_rank, tag, "recv_yield", "recv") + return array + end + warn_period = mpi_deadlock_detect(detect, time_start, warn_period, timeout_period, my_rank, tag, "recv", their_rank) + yield() + end +end + +function recv_yield_inplace(_value::InplaceInfo, comm, my_rank, their_rank, tag) + T = _value.type + @assert T <: Array && isbitstype(eltype(T)) "recv_yield_inplace only supports inplace MPI transfers of bitstype dense arrays" + array = Array{eltype(T)}(undef, _value.shape) + return recv_yield_inplace!(array, comm, my_rank, their_rank, tag) +end + +function recv_yield_inplace(_value::InplaceSparseInfo, comm, my_rank, their_rank, tag) + T = _value.type + @assert T <: SparseMatrixCSC "recv_yield_inplace only supports inplace MPI transfers of SparseMatrixCSC" + + colptr = recv_yield_inplace!(Vector{Int64}(undef, _value.colptr), comm, my_rank, their_rank, tag) + rowval = recv_yield_inplace!(Vector{Int64}(undef, _value.rowval), comm, my_rank, their_rank, tag) + nzval = recv_yield_inplace!(Vector{eltype(T)}(undef, _value.nzval), comm, my_rank, their_rank, tag) + + return SparseMatrixCSC{eltype(T), Int64}(_value.m, _value.n, colptr, rowval, nzval) +end + +function recv_yield_serialized(comm, my_rank, their_rank, tag) + time_start = time_ns() + detect = DEADLOCK_DETECT[] + warn_period = round(UInt64, DEADLOCK_WARN_PERIOD[] * 1e9) + timeout_period = round(UInt64, DEADLOCK_TIMEOUT_PERIOD[] * 1e9) + + while true + (got, msg, stat) = MPI.Improbe(their_rank, tag, comm, MPI.Status) + if got + if MPI.Get_error(stat) != MPI.SUCCESS + error("recv_yield failed with error $(MPI.Get_error(stat))") + end + count = MPI.Get_count(stat, UInt8) + buf = Array{UInt8}(undef, count) + req = 
MPI.Imrecv!(MPI.Buffer(buf), msg) + __wait_for_request(req, comm, my_rank, their_rank, tag, "recv_yield", "recv") + return MPI.deserialize(buf) + end + warn_period = mpi_deadlock_detect(detect, time_start, warn_period, timeout_period, my_rank, tag, "recv", their_rank) + yield() + end +end + +const SEEN_TAGS = Dict{Int32, Type}() +send_yield!(value, comm, dest, tag) = + _send_yield(value, comm, dest, tag; inplace=true) +send_yield(value, comm, dest, tag) = + _send_yield(value, comm, dest, tag; inplace=false) +function _send_yield(value, comm, dest, tag; inplace::Bool) + rank = MPI.Comm_rank(comm) + + #= + if CHECK_UNIFORMITY[] && haskey(SEEN_TAGS, tag) && SEEN_TAGS[tag] !== typeof(value) + @error "[rank $(MPI.Comm_rank(comm))][tag $tag] Already seen tag (previous type: $(SEEN_TAGS[tag]), new type: $(typeof(value)))" exception=(InterruptException(),backtrace()) + end + if CHECK_UNIFORMITY[] + SEEN_TAGS[tag] = typeof(value) + end + =# + + #Core.println("[rank $(MPI.Comm_rank(comm))][tag $tag] Starting send to [$dest]: $(typeof(value)), is support inplace? 
$(supports_inplace_mpi(value))") + if inplace && supports_inplace_mpi(value) + send_yield_inplace(value, comm, rank, dest, tag) + else + send_yield_serialized(value, comm, rank, dest, tag) + end +end + +function send_yield_inplace(value, comm, my_rank, their_rank, tag) + @opcounter :send_yield_inplace + req = MPI.Isend(value, comm; dest=their_rank, tag) + __wait_for_request(req, comm, my_rank, their_rank, tag, "send_yield", "send") +end + +function send_yield_serialized(value, comm, my_rank, their_rank, tag) + @opcounter :send_yield_serialized + if value isa Array && isbitstype(eltype(value)) + send_yield_serialized(InplaceInfo(typeof(value), size(value)), comm, my_rank, their_rank, tag) + send_yield_inplace(value, comm, my_rank, their_rank, tag) + elseif value isa SparseMatrixCSC && isbitstype(eltype(value)) + send_yield_serialized(InplaceSparseInfo(typeof(value), value.m, value.n, length(value.colptr), length(value.rowval), length(value.nzval)), comm, my_rank, their_rank, tag) + send_yield_inplace(value.colptr, comm, my_rank, their_rank, tag) + send_yield_inplace(value.rowval, comm, my_rank, their_rank, tag) + send_yield_inplace(value.nzval, comm, my_rank, their_rank, tag) + else + req = MPI.isend(value, comm; dest=their_rank, tag) + __wait_for_request(req, comm, my_rank, their_rank, tag, "send_yield", "send") + end +end + +function __wait_for_request(req, comm, my_rank, their_rank, tag, fn::String, kind::String) + time_start = time_ns() + detect = DEADLOCK_DETECT[] + warn_period = round(UInt64, DEADLOCK_WARN_PERIOD[] * 1e9) + timeout_period = round(UInt64, DEADLOCK_TIMEOUT_PERIOD[] * 1e9) + while true + finish, status = MPI.Test(req, MPI.Status) + if finish + if MPI.Get_error(status) != MPI.SUCCESS + error("$fn failed with error $(MPI.Get_error(status))") + end + return + end + warn_period = mpi_deadlock_detect(detect, time_start, warn_period, timeout_period, my_rank, tag, kind, their_rank) + yield() + end +end + +function bcast_send_yield(value, comm, root, 
tag) + @opcounter :bcast_send_yield + sz = MPI.Comm_size(comm) + rank = MPI.Comm_rank(comm) + for other_rank in 0:(sz-1) + rank == other_rank && continue + send_yield(value, comm, other_rank, tag) + end +end + +#= Maybe can be worth it to implement this +function bcast_send_yield!(value, comm, root, tag) + sz = MPI.Comm_size(comm) + rank = MPI.Comm_rank(comm) + + for other_rank in 0:(sz-1) + rank == other_rank && continue + #println("[rank $rank] Sending to rank $other_rank") + send_yield!(value, comm, other_rank, tag) + end +end + +function bcast_recv_yield!(value, comm, root, tag) + sz = MPI.Comm_size(comm) + rank = MPI.Comm_rank(comm) + #println("[rank $rank] receive from rank $root") + recv_yield!(value, comm, root, tag) +end +=# +function mpi_deadlock_detect(detect, time_start, warn_period, timeout_period, rank, tag, kind, srcdest) + time_elapsed = (time_ns() - time_start) + if detect && time_elapsed > warn_period + @warn "[rank $rank][tag $tag] Hit probable hang on $kind (dest: $srcdest)" + return typemax(UInt64) + end + if detect && time_elapsed > timeout_period + error("[rank $rank][tag $tag] Hit hang on $kind (dest: $srcdest)") + end + return warn_period +end + +#discuss this with julian +@warn "Fix this WeakChunk method" maxlog=1 +WeakChunk(c::Chunk{T,H}) where {T,H<:MPIRef} = WeakChunk(c.handle.rank, c.handle.id.id, WeakRef(c)) + +function MemPool.poolget(ref::MPIRef; uniform::Bool=uniform_execution()) + @assert uniform || ref.rank == MPI.Comm_rank(ref.comm) "MPIRef rank mismatch: $(ref.rank) != $(MPI.Comm_rank(ref.comm))" + if uniform + tag = to_tag() + if ref.rank == MPI.Comm_rank(ref.comm) + value = poolget(ref.innerRef) + bcast_send_yield(value, ref.comm, ref.rank, tag) + return value + else + return recv_yield(ref.comm, ref.rank, tag) + end + else + return poolget(ref.innerRef) + end +end +fetch_handle(ref::MPIRef; uniform::Bool=uniform_execution()) = poolget(ref; uniform) + +function move!(dep_mod, to_space::MPIMemorySpace, 
from_space::MPIMemorySpace, to::Chunk, from::Chunk)
+    @assert to.handle isa MPIRef && from.handle isa MPIRef "MPIRef expected"
+    @assert to.handle.comm == from.handle.comm "MPIRef comm mismatch"
+    @assert to.handle.rank == to_space.rank && from.handle.rank == from_space.rank "MPIRef rank mismatch"
+    local_rank = MPI.Comm_rank(from.handle.comm)
+    tag = to_tag(from.handle)  # hoisted: both @dagdebug logs below interpolate $tag (was used before assignment)
+    if to_space.rank == from_space.rank == local_rank
+        move!(dep_mod, to_space.innerSpace, from_space.innerSpace, to, from)
+    else
+        @dagdebug nothing :mpi "[$local_rank][$tag] Moving from $(from_space.rank) to $(to_space.rank)\n"
+        if local_rank == from_space.rank
+            send_yield!(poolget(from.handle; uniform=false), to_space.comm, to_space.rank, tag)
+        elseif local_rank == to_space.rank
+            #@dagdebug nothing :mpi "[$local_rank][$tag] Receiving from rank $(from_space.rank) with tag $tag, type of buffer: $(typeof(poolget(to.handle; uniform=false)))"
+            to_val = poolget(to.handle; uniform=false)
+            val, inplace = recv_yield!(to_val, from_space.comm, from_space.rank, tag)
+            if !inplace
+                move!(dep_mod, to_space.innerSpace, from_space.innerSpace, to_val, val)
+            end
+        end
+    end
+    @dagdebug nothing :mpi "[$local_rank][$tag] Finished moving from $(from_space.rank) to $(to_space.rank) successfuly\n"
+end
+function move!(dep_mod::RemainderAliasing{<:MPIMemorySpace}, to_space::MPIMemorySpace, from_space::MPIMemorySpace, to::Chunk, from::Chunk)
+    @assert to.handle isa MPIRef && from.handle isa MPIRef "MPIRef expected"
+    @assert to.handle.comm == from.handle.comm "MPIRef comm mismatch"
+    @assert to.handle.rank == to_space.rank && from.handle.rank == from_space.rank "MPIRef rank mismatch"
+    local_rank = MPI.Comm_rank(from.handle.comm)
+    if to_space.rank == from_space.rank == local_rank
+        move!(dep_mod, to_space.innerSpace, from_space.innerSpace, to, from)
+    else
+        tag = to_tag(from.handle)
+        @dagdebug nothing :mpi "[$local_rank][$tag] Moving from $(from_space.rank) to $(to_space.rank)\n"
+        if local_rank == 
from_space.rank + # Get the source data for each span + len = sum(span_tuple->span_len(span_tuple[1]), dep_mod.spans) + copies = Vector{UInt8}(undef, len) + offset = 1 + for (from_span, _) in dep_mod.spans + #GC.@preserve copy begin + from_ptr = Ptr{UInt8}(from_span.ptr) + to_ptr = Ptr{UInt8}(pointer(copies, offset)) + unsafe_copyto!(to_ptr, from_ptr, from_span.len) + offset += from_span.len + #end + end + + # Send the spans + #send_yield(len, to_space.comm, to_space.rank, tag) + send_yield!(copies, to_space.comm, to_space.rank, tag) + #send_yield(copies, to_space.comm, to_space.rank, tag) + elseif local_rank == to_space.rank + # Receive the spans + len = sum(span_tuple->span_len(span_tuple[1]), dep_mod.spans) + copies = Vector{UInt8}(undef, len) + recv_yield!(copies, from_space.comm, from_space.rank, tag) + #copies = recv_yield(from_space.comm, from_space.rank, tag) + + # Copy the data into the destination object + #for (copy, (_, to_span)) in zip(copies, dep_mod.spans) + offset = 1 + for (_, to_span) in dep_mod.spans + #GC.@preserve copy begin + from_ptr = Ptr{UInt8}(pointer(copies, offset)) + to_ptr = Ptr{UInt8}(to_span.ptr) + unsafe_copyto!(to_ptr, from_ptr, to_span.len) + offset += to_span.len + #end + end + + # Ensure that the data is visible + Core.Intrinsics.atomic_fence(:release) + end + end + + return +end + + +move(::MPIOSProc, ::MPIProcessor, x::Union{Function,Type}) = x +move(::MPIOSProc, ::MPIProcessor, x::Chunk{<:Union{Function,Type}}) = poolget(x.handle) + +#TODO: out of place MPI move +function move(src::MPIOSProc, dst::MPIProcessor, x::Chunk) + @assert src.comm == dst.comm "Multi comm move not supported" + if Sch.SCHED_MOVE[] + if dst.rank == MPI.Comm_rank(dst.comm) + return poolget(x.handle) + end + else + @assert src.rank == MPI.Comm_rank(src.comm) "Unwrapping not permited" + @assert src.rank == x.handle.rank == dst.rank + return poolget(x.handle) + end +end + +#= +function remotecall_endpoint(f, accel::MPIAcceleration, from_proc, to_proc, 
from_space, to_space, data::Chunk) + loc_rank = MPI.Comm_rank(accel.comm) + if loc_rank == from_proc.rank + # FIXME: Descend via move_rewrap, and send data to to_proc + elseif loc_rank == to_proc.rank + # FIXME: Listen for data from from_proc to locally wrap as Chunk + while true + value = recv_yield(accel.comm, from_proc.rank, tag) + end + bcast_recv_yield(data_new, accel.comm, to_proc.rank, tag) + else + # Wait for final Chunk + return recv_yield(accel.comm, to_proc.rank, tag) + end +end +function remotecall_endpoint_transfer(f, accel::MPIAcceleration, from_proc, to_proc, from_space, to_space, data) + loc_rank = MPI.Comm_rank(accel.comm) + if loc_rank == from_proc.rank + elseif loc_rank == to_proc.rank + end +end +=# +function remotecall_endpoint(f, accel::MPIAcceleration, from_proc, to_proc, from_space, to_space, data::Chunk) + loc_rank = MPI.Comm_rank(accel.comm) + task = DATADEPS_CURRENT_TASK[] + return with(MPI_UID=>task.uid) do + space = memory_space(data) + tag = to_tag() + T = move_type(from_proc.innerProc, to_proc.innerProc, chunktype(data)) + T_new = f !== identity ? Base._return_type(f, Tuple{T}) : T + need_bcast = !isconcretetype(T_new) || T_new === Union{} || T_new === Nothing || T_new === Any + + if space.rank != from_proc.rank + # Data is already at destination (to_proc.rank) + @assert space.rank == to_proc.rank + if space.rank == loc_rank + value = poolget(data.handle) + data_converted = f(move(from_proc.innerProc, to_proc.innerProc, value)) + T_actual = typeof(data_converted) + if need_bcast + bcast_send_yield(T_actual, accel.comm, to_proc.rank, tag) + end + return tochunk(data_converted, to_proc, to_space; type=T_actual) + else + T_actual = need_bcast ? 
recv_yield(accel.comm, to_proc.rank, tag) : T_new
+                return tochunk(nothing, to_proc, to_space; type=T_actual)
+            end
+        end
+
+        # Data is on the source rank
+        @assert space.rank == from_proc.rank
+        if loc_rank == from_proc.rank == to_proc.rank
+            value = poolget(data.handle)
+            data_converted = f(move(from_proc.innerProc, to_proc.innerProc, value))
+            return tochunk(data_converted, to_proc, to_space; type=typeof(data_converted))
+        end
+
+        if loc_rank == from_proc.rank
+            value = poolget(data.handle)
+            data_moved = move(from_proc.innerProc, to_proc.innerProc, value)
+            Dagger.send_yield(data_moved, accel.comm, to_proc.rank, tag)
+            T_actual = need_bcast ? recv_yield(accel.comm, to_proc.rank, tag) : T_new  # same tag as the data channel (was undefined `type_tag`)
+            return tochunk(nothing, to_proc, to_space; type=T_actual)
+        elseif loc_rank == to_proc.rank
+            data_moved = Dagger.recv_yield(accel.comm, from_space.rank, tag)
+            data_converted = f(move(from_proc.innerProc, to_proc.innerProc, data_moved))
+            T_actual = typeof(data_converted)
+            if need_bcast
+                bcast_send_yield(T_actual, accel.comm, to_proc.rank, tag)
+            end
+            return tochunk(data_converted, to_proc, to_space; type=T_actual)
+        else
+            T_actual = need_bcast ? recv_yield(accel.comm, to_proc.rank, tag) : T_new
+            return tochunk(nothing, to_proc, to_space; type=T_actual)
+        end
+    end
+end
+
+# Chunk may be MPI-backed (MPIRef) but labeled with OSProc; treat source as the owning rank
+function move(src::OSProc, dst::MPIProcessor, x::Chunk)
+    if x.handle isa MPIRef
+        return move(MPIOSProc(x.handle.comm, x.handle.rank), dst, x)
+    end
+    error("MPI move not supported")
+end
+
+move(src::Processor, dst::MPIProcessor, x::Chunk) = error("MPI move not supported")
+move(to_proc::MPIProcessor, chunk::Chunk) =
+    move(chunk.processor, to_proc, chunk)
+move(to_proc::Processor, d::MPIRef) =
+    move(MPIOSProc(d.comm, d.rank), to_proc, d)  # MPIOSProc(d.rank) would pass an Int where a Comm is expected
+move(to_proc::MPIProcessor, x) =
+    move(MPIOSProc(), to_proc, x)
+
+move(::MPIProcessor, ::MPIProcessor, x::Union{Function,Type}) = x
+move(::MPIProcessor, ::MPIProcessor, x::Chunk{<:Union{Function,Type}}) = poolget(x.handle)
+
+@warn "Is this uniform logic valuable to have?" maxlog=1
+function move(src::MPIProcessor, dst::MPIProcessor, x::Chunk)
+    uniform = uniform_execution()
+    @assert uniform || src.rank == dst.rank "Unwrapping not permitted"
+    if Sch.SCHED_MOVE[]
+        # We can either unwrap locally, or return nothing
+        if dst.rank == MPI.Comm_rank(dst.comm)
+            return poolget(x.handle)
+        end
+    else
+        # Either we're uniform (so everyone cooperates), or we're unwrapping locally
+        if !uniform
+            @assert src.rank == MPI.Comm_rank(src.comm) "Unwrapping not permitted"
+            @assert src.rank == x.handle.rank == dst.rank
+        end
+        return poolget(x.handle; uniform)
+    end
+end
+
+
+#FIXME:try to think of a better move! scheme
+function execute!(proc::MPIProcessor, f, args...; kwargs...)
+    local_rank = MPI.Comm_rank(proc.comm)
+    islocal = local_rank == proc.rank
+    inplace_move = f === move!
+    result = nothing
+    tag = to_tag()
+
+    if islocal || inplace_move
+        result = execute!(proc.innerProc, f, args...; kwargs...)
+ end + + if inplace_move + space = memory_space(nothing, proc)::MPIMemorySpace + dest_type = chunktype(args[4]) + return tochunk(nothing, proc, space; type=dest_type) + end + + # Infer return type; only bcast when inference is not concrete + fname = nameof(f) + arg_types = map(chunktype, args) + inferred_type = Base.promote_op(f, arg_types...) + + need_bcast = !isconcretetype(inferred_type) || inferred_type === Union{} || inferred_type === Nothing || inferred_type === Any + + if islocal + T = typeof(result) + space = memory_space(result, proc)::MPIMemorySpace + if need_bcast + @opcounter :execute_bcast_send_yield + bcast_send_yield((T, space.innerSpace), proc.comm, proc.rank, tag) + end + return tochunk(result, proc, space; type=T) + else + if need_bcast + T, innerSpace = recv_yield(proc.comm, proc.rank, tag) + space = MPIMemorySpace(innerSpace, proc.comm, proc.rank) + return tochunk(nothing, proc, space; type=T) + else + space = memory_space(nothing, proc)::MPIMemorySpace + return tochunk(nothing, proc, space; type=inferred_type) + end + end +end + +accelerate!(::Val{:mpi}) = accelerate!(MPIAcceleration()) + +function initialize_acceleration!(a::MPIAcceleration) + if !MPI.Initialized() + MPI.Init(;threadlevel=:multiple) + end + ctx = Dagger.Sch.eager_context() + sz = MPI.Comm_size(a.comm) + for i in 0:(sz-1) + push!(ctx.procs, MPIOSProc(a.comm, i)) + end + unique!(ctx.procs) +end + +""" + mpi_propagate_chunk_types!(tasks, accel::MPIAcceleration, expected_type) + +Ensure all ranks use the same concrete type for the given tasks by setting +each task's options.return_type to expected_type when it is concrete. +This allows chunktype(task) to return the concrete type on every rank +without an MPI allgather of actual result types. 
+""" +function mpi_propagate_chunk_types!(tasks, accel::MPIAcceleration, expected_type) + isconcretetype(expected_type) || return + for t in tasks + if t isa Thunk + if t.options !== nothing + t.options.return_type = expected_type + else + t.options = Options(return_type=expected_type) + end + end + end + return +end + +accel_matches_proc(accel::MPIAcceleration, proc::MPIOSProc) = true +accel_matches_proc(accel::MPIAcceleration, proc::MPIClusterProc) = true +accel_matches_proc(accel::MPIAcceleration, proc::MPIProcessor) = true +accel_matches_proc(accel::MPIAcceleration, proc) = false + +function distribute(accel::MPIAcceleration, A::AbstractArray{T,N}, dist::Blocks{N}) where {T,N} + comm = accel.comm + rank = MPI.Comm_rank(comm) + + DA = view(A, dist) + DB = DArray{T,N}(undef, dist, size(A)) + copyto!(DB, DA) + + return DB +end diff --git a/src/mpi_mempool.jl b/src/mpi_mempool.jl new file mode 100644 index 000000000..149c7900a --- /dev/null +++ b/src/mpi_mempool.jl @@ -0,0 +1,36 @@ +# Mempool for received MPI message data only (no envelopes). +# Key: (comm, source, tag). Used when a message is received but not the one the caller was waiting for. +# Included from mpi.jl; runs in Dagger module scope. 
+ +const MPI_RECV_MEMPOOL = Base.Lockable(Dict{Tuple{MPI.Comm, Int, Int}, Vector{Any}}()) + +function mpi_mempool_put!(comm::MPI.Comm, source::Integer, tag::Integer, data::Any) + key = (comm, Int(source), Int(tag)) + ref = poolset(data) + lock(MPI_RECV_MEMPOOL) do pool + if !haskey(pool, key) + pool[key] = Any[] + end + push!(pool[key], ref) + end + return nothing +end + +function mpi_mempool_take!(comm::MPI.Comm, source::Integer, tag::Integer) + key = (comm, Int(source), Int(tag)) + ref = lock(MPI_RECV_MEMPOOL) do pool + if !haskey(pool, key) || isempty(pool[key]) + return nothing + end + popfirst!(pool[key]) + end + ref === nothing && return nothing + return poolget(ref) +end + +function mpi_mempool_has(comm::MPI.Comm, source::Integer, tag::Integer) + key = (comm, Int(source), Int(tag)) + return lock(MPI_RECV_MEMPOOL) do pool + haskey(pool, key) && !isempty(pool[key]) + end +end diff --git a/src/mutable.jl b/src/mutable.jl new file mode 100644 index 000000000..1f48ead53 --- /dev/null +++ b/src/mutable.jl @@ -0,0 +1,41 @@ +function _mutable_inner(@nospecialize(f), proc, scope) + result = f() + return Ref(Dagger.tochunk(result, proc, scope)) +end + +""" + mutable(f::Base.Callable; worker, processor, scope) -> Chunk + +Calls `f()` on the specified worker or processor, returning a `Chunk` +referencing the result with the specified scope `scope`. +""" +function mutable(@nospecialize(f); worker=nothing, processor=nothing, scope=nothing) + if processor === nothing + if worker === nothing + processor = OSProc() + else + processor = OSProc(worker) + end + else + @assert worker === nothing "mutable: Can't mix worker and processor" + end + if scope === nothing + scope = processor isa OSProc ? ProcessScope(processor) : ExactScope(processor) + end + return fetch(Dagger.@spawn scope=scope _mutable_inner(f, processor, scope))[] +end + +""" + @mutable [worker=1] [processor=OSProc()] [scope=ProcessorScope()] f() + +Helper macro for [`mutable()`](@ref). +""" +macro mutable(exs...) 
+ opts = esc.(exs[1:end-1]) + ex = exs[end] + quote + let f = @noinline ()->$(esc(ex)) + $mutable(f; $(opts...)) + end + end +end diff --git a/src/options.jl b/src/options.jl index eca59fbc9..09067da51 100644 --- a/src/options.jl +++ b/src/options.jl @@ -26,6 +26,7 @@ Stores per-task options to be passed to the scheduler. - `storage_leaf_tag::Union{MemPool.Tag,Nothing}=nothing`: If not `nothing`, specifies the MemPool storage leaf tag to associate with the task's result. This tag can be used by MemPool's storage devices to manipulate their behavior, such as the file name used to store data on disk." - `storage_retain::Union{Bool,Nothing}=nothing`: The value of `retain` to pass to `MemPool.poolset` when constructing the result `Chunk`. `nothing` defaults to `false`. - `name::Union{String,Nothing}=nothing`: If not `nothing`, annotates the task with a name for logging purposes. +- `tag::Union{UInt32,Nothing}=nothing`: (Data-deps/MPI) MPI message tag for this task; assigned automatically if `nothing`. - `stream_input_buffer_amount::Union{Int,Nothing}=nothing`: (Streaming only) Specifies the amount of slots to allocate for the input buffer of the task. Defaults to 1. - `stream_output_buffer_amount::Union{Int,Nothing}=nothing`: (Streaming only) Specifies the amount of slots to allocate for the output buffer of the task. Defaults to 1. - `stream_buffer_type::Union{Type,Nothing}=nothing`: (Streaming only) Specifies the type of buffer to use for the input and output buffers of the task. Defaults to `Dagger.ProcessRingBuffer`. 
@@ -61,10 +62,16 @@ Base.@kwdef mutable struct Options name::Union{String,Nothing} = nothing + tag::Union{UInt32,Nothing} = nothing + stream_input_buffer_amount::Union{Int,Nothing} = nothing stream_output_buffer_amount::Union{Int,Nothing} = nothing stream_buffer_type::Union{Type, Nothing} = nothing stream_max_evals::Union{Int,Nothing} = nothing + + acceleration::Union{Acceleration,Nothing} = nothing + + return_type::Union{Type,Nothing} = nothing end Options(::Nothing) = Options() function Options(old_options::NamedTuple) diff --git a/src/processor.jl b/src/processor.jl index ac2e74f14..4944dc083 100644 --- a/src/processor.jl +++ b/src/processor.jl @@ -2,16 +2,6 @@ export OSProc, Context, addprocs!, rmprocs! import Base: @invokelatest -""" - Processor - -An abstract type representing a processing device and associated memory, where -data can be stored and operated on. Subtypes should be immutable, and -instances should compare equal if they represent the same logical processing -device/memory. Subtype instances should be serializable between different -nodes. Subtype instances may contain a "parent" `Processor` to make it easy to -transfer data to/from other types of `Processor` at runtime. -""" abstract type Processor end const PROCESSOR_CALLBACKS = Dict{Symbol,Any}() @@ -150,3 +140,20 @@ iscompatible_arg(proc::OSProc, opts, args...) = "Returns a very brief `String` representation of `proc`." short_name(proc::Processor) = string(proc) short_name(p::OSProc) = "W: $(p.pid)" + +"Returns true if the processor is on the local worker (for MPI/ordering)." +is_local_processor(proc::Processor) = (root_worker_id(proc) == myid()) + +"Ordering key for task firing (used by MPI to avoid deadlock)." +fire_order_key(proc::Processor) = (root_worker_id(proc), 0) + +@doc """ + Processor + +An abstract type representing a processing device and associated memory, where +data can be stored and operated on. 
Subtypes should be immutable, and +instances should compare equal if they represent the same logical processing +device/memory. Subtype instances should be serializable between different +nodes. Subtype instances may contain a "parent" `Processor` to make it easy to +transfer data to/from other types of `Processor` at runtime. +""" Processor diff --git a/src/queue.jl b/src/queue.jl index 37947a0ac..c1e264c06 100644 --- a/src/queue.jl +++ b/src/queue.jl @@ -125,7 +125,7 @@ function wait_all(f; check_errors::Bool=false) result = with_options(f; task_queue=queue) for task in queue.tasks if check_errors - fetch(task; raw=true) + fetch(task; move_value=false, unwrap=false) else wait(task) end diff --git a/src/sch/Sch.jl b/src/sch/Sch.jl index 58aed6dc5..3c8353c59 100644 --- a/src/sch/Sch.jl +++ b/src/sch/Sch.jl @@ -15,7 +15,7 @@ import Base: @invokelatest import ..Dagger import ..Dagger: Context, Processor, SchedulerOptions, Options, Thunk, WeakThunk, ThunkFuture, ThunkID, DTaskFailedException, Chunk, WeakChunk, OSProc, AnyScope, DefaultScope, InvalidScope, LockedObject, Argument, Signature -import ..Dagger: order, dependents, noffspring, istask, inputs, unwrap_weak_checked, wrap_weak, affinity, tochunk, timespan_start, timespan_finish, procs, move, chunktype, default_enabled, processor, get_processors, get_parent, execute!, rmprocs!, task_processor, constrain, cputhreadtime, maybe_take_or_alloc! +import ..Dagger: order, dependents, noffspring, istask, inputs, unwrap_weak_checked, wrap_weak, affinity, tochunk, timespan_start, timespan_finish, procs, move, chunktype, default_enabled, processor, get_processors, get_parent, root_worker_id, execute!, rmprocs!, task_processor, constrain, cputhreadtime, maybe_take_or_alloc!, is_local_processor, fire_order_key, short_name import ..Dagger: @dagdebug, @safe_lock_spin1, @maybelog, @take_or_alloc! 
import DataStructures: PriorityQueue, enqueue!, dequeue_pair!, peek @@ -25,7 +25,7 @@ import ..Dagger: @reusable, @reusable_dict, @reusable_vector, @reusable_tasks, @ import TimespanLogging import TaskLocalValues: TaskLocalValue -import ScopedValues: @with +import ScopedValues: ScopedValue, @with, with const OneToMany = Dict{Thunk, Set{Thunk}} @@ -56,7 +56,7 @@ Fields: - `cache::WeakKeyDict{Thunk, Any}` - Maps from a finished `Thunk` to it's cached result, often a DRef - `valid::WeakKeyDict{Thunk, Nothing}` - Tracks all `Thunk`s that are in a valid scheduling state - `running::Set{Thunk}` - The set of currently-running `Thunk`s -- `running_on::Dict{Thunk,OSProc}` - Map from `Thunk` to the OS process executing it +- `running_on::Dict{Thunk,Processor}` - Map from `Thunk` to the OS process executing it - `thunk_dict::Dict{Int, WeakThunk}` - Maps from thunk IDs to a `Thunk` - `node_order::Any` - Function that returns the order of a thunk - `equiv_chunks::WeakKeyDict{DRef,Chunk}` - Cache mapping from `DRef` to a `Chunk` which contains it @@ -82,18 +82,18 @@ struct ComputeState ready::Vector{Thunk} valid::Dict{Thunk, Nothing} running::Set{Thunk} - running_on::Dict{Thunk,OSProc} + running_on::Dict{Thunk,Processor} thunk_dict::Dict{Int, WeakThunk} node_order::Any - equiv_chunks::WeakKeyDict{DRef,Chunk} - worker_time_pressure::Dict{Int,Dict{Processor,UInt64}} - worker_storage_pressure::Dict{Int,Dict{Union{StorageResource,Nothing},UInt64}} - worker_storage_capacity::Dict{Int,Dict{Union{StorageResource,Nothing},UInt64}} - worker_loadavg::Dict{Int,NTuple{3,Float64}} - worker_chans::Dict{Int, Tuple{RemoteChannel,RemoteChannel}} + equiv_chunks::WeakKeyDict{Any,Chunk} + worker_time_pressure::Dict{Processor,Dict{Processor,UInt64}} + worker_storage_pressure::Dict{Processor,Dict{Union{StorageResource,Nothing},UInt64}} + worker_storage_capacity::Dict{Processor,Dict{Union{StorageResource,Nothing},UInt64}} + worker_loadavg::Dict{Processor,NTuple{3,Float64}} + 
worker_chans::Dict{Int,Tuple{RemoteChannel,RemoteChannel}} signature_time_cost::Dict{Signature,UInt64} signature_alloc_cost::Dict{Signature,UInt64} - worker_transfer_rate::Dict{Int,Dict{Processor,UInt64}} + worker_transfer_rate::Dict{Processor,Dict{Processor,UInt64}} halt::Base.Event lock::ReentrantLock futures::Dict{Thunk, Vector{ThunkFuture}} @@ -111,18 +111,18 @@ function start_state(deps::Dict, node_order, chan) Vector{Thunk}(undef, 0), Dict{Thunk, Nothing}(), Set{Thunk}(), - Dict{Thunk,OSProc}(), + Dict{Thunk,Processor}(), Dict{Int, WeakThunk}(), node_order, - WeakKeyDict{DRef,Chunk}(), - Dict{Int,Dict{Processor,UInt64}}(), - Dict{Int,Dict{Union{StorageResource,Nothing},UInt64}}(), - Dict{Int,Dict{Union{StorageResource,Nothing},UInt64}}(), - Dict{Int,NTuple{3,Float64}}(), - Dict{Int, Tuple{RemoteChannel,RemoteChannel}}(), + WeakKeyDict{Any,Chunk}(), + Dict{Processor,Dict{Processor,UInt64}}(), + Dict{Processor,Dict{Union{StorageResource,Nothing},UInt64}}(), + Dict{Processor,Dict{Union{StorageResource,Nothing},UInt64}}(), + Dict{Processor,NTuple{3,Float64}}(), + Dict{Processor,Tuple{RemoteChannel,RemoteChannel}}(), Dict{Signature,UInt64}(), Dict{Signature,UInt64}(), - Dict{Int,Dict{Processor,UInt64}}(), + Dict{Processor,Dict{Processor,UInt64}}(), Base.Event(), ReentrantLock(), Dict{Thunk, Vector{ThunkFuture}}(), @@ -152,30 +152,29 @@ const WORKER_MONITOR_TASKS = Dict{Int,Task}() const WORKER_MONITOR_CHANS = Dict{Int,Dict{UInt64,RemoteChannel}}() function init_proc(state, p, log_sink) ctx = Context(Int[]; log_sink) - @maybelog ctx timespan_start(ctx, :init_proc, (;uid=state.uid, worker=p.pid), nothing) + pid = Dagger.root_worker_id(p) + @maybelog ctx timespan_start(ctx, :init_proc, (;uid=state.uid, worker=pid), nothing) # Initialize pressure and capacity - gproc = OSProc(p.pid) lock(state.lock) do - state.worker_time_pressure[p.pid] = Dict{Processor,UInt64}() - state.worker_transfer_rate[p.pid] = Dict{Processor,UInt64}() - - state.worker_storage_pressure[p.pid] = 
Dict{Union{StorageResource,Nothing},UInt64}() - state.worker_storage_capacity[p.pid] = Dict{Union{StorageResource,Nothing},UInt64}() + state.worker_transfer_rate[p] = Dict{Processor,UInt64}() + state.worker_time_pressure[p] = Dict{Processor,UInt64}() + state.worker_storage_pressure[p] = Dict{Union{StorageResource,Nothing},UInt64}() + state.worker_storage_capacity[p] = Dict{Union{StorageResource,Nothing},UInt64}() #= FIXME for storage in get_storage_resources(gproc) - pressure, capacity = remotecall_fetch(gproc.pid, storage) do storage + pressure, capacity = remotecall_fetch(root_worker_id(gproc), storage) do storage storage_pressure(storage), storage_capacity(storage) end - state.worker_storage_pressure[p.pid][storage] = pressure - state.worker_storage_capacity[p.pid][storage] = capacity + state.worker_storage_pressure[p][storage] = pressure + state.worker_storage_capacity[p][storage] = capacity end =# - state.worker_loadavg[p.pid] = (0.0, 0.0, 0.0) + state.worker_loadavg[p] = (0.0, 0.0, 0.0) end - if p.pid != 1 + if pid != 1 lock(WORKER_MONITOR_LOCK) do - wid = p.pid + wid = pid if !haskey(WORKER_MONITOR_TASKS, wid) t = Threads.@spawn begin try @@ -209,16 +208,16 @@ function init_proc(state, p, log_sink) end # Setup worker-to-scheduler channels - inp_chan = RemoteChannel(p.pid) - out_chan = RemoteChannel(p.pid) + inp_chan = RemoteChannel(pid) + out_chan = RemoteChannel(pid) lock(state.lock) do - state.worker_chans[p.pid] = (inp_chan, out_chan) + state.worker_chans[pid] = (inp_chan, out_chan) end # Setup dynamic listener - dynamic_listener!(ctx, state, p.pid) + dynamic_listener!(ctx, state, pid) - @maybelog ctx timespan_finish(ctx, :init_proc, (;uid=state.uid, worker=p.pid), nothing) + @maybelog ctx timespan_finish(ctx, :init_proc, (;uid=state.uid, worker=pid), nothing) end function _cleanup_proc(uid, log_sink) empty!(CHUNK_CACHE) # FIXME: Should be keyed on uid! 
@@ -236,7 +235,7 @@ function _cleanup_proc(uid, log_sink) end function cleanup_proc(state, p, log_sink) ctx = Context(Int[]; log_sink) - wid = p.pid + wid = root_worker_id(p) @maybelog ctx timespan_start(ctx, :cleanup_proc, (;uid=state.uid, worker=wid), nothing) lock(WORKER_MONITOR_LOCK) do if haskey(WORKER_MONITOR_CHANS, wid) @@ -299,7 +298,7 @@ function compute_dag(ctx::Context, d::Thunk, options=SchedulerOptions()) node_order = x -> -get(ord, x, 0) state = start_state(deps, node_order, chan) - master = OSProc(myid()) + master = Dagger.default_processor() @maybelog ctx timespan_start(ctx, :scheduler_init, (;uid=state.uid), master) try @@ -394,8 +393,8 @@ function scheduler_run(ctx, state::ComputeState, d::Thunk, options::SchedulerOpt res = tresult.result @dagdebug thunk_id :take "Got finished task" - gproc = OSProc(pid) safepoint(state) + gproc = proc != nothing ? get_parent(proc) : OSProc(pid) lock(state.lock) do thunk_failed = false if res isa Exception @@ -422,11 +421,11 @@ function scheduler_run(ctx, state::ComputeState, d::Thunk, options::SchedulerOpt node = unwrap_weak_checked(state.thunk_dict[thunk_id])::Thunk metadata = tresult.metadata if metadata !== nothing - state.worker_time_pressure[pid][proc] = metadata.time_pressure + state.worker_time_pressure[gproc][proc] = metadata.time_pressure #to_storage = fetch(node.options.storage) #state.worker_storage_pressure[pid][to_storage] = metadata.storage_pressure #state.worker_storage_capacity[pid][to_storage] = metadata.storage_capacity - #state.worker_loadavg[pid] = metadata.loadavg + #state.worker_loadavg[gproc] = metadata.loadavg sig = signature(state, node) state.signature_time_cost[sig] = (metadata.threadtime + get(state.signature_time_cost, sig, 0)) ÷ 2 state.signature_alloc_cost[sig] = (metadata.gc_allocd + get(state.signature_alloc_cost, sig, 0)) ÷ 2 @@ -440,8 +439,8 @@ function scheduler_run(ctx, state::ComputeState, d::Thunk, options::SchedulerOpt end end if res isa Chunk - if 
!haskey(state.equiv_chunks, res) - state.equiv_chunks[res.handle::DRef] = res + if !haskey(state.equiv_chunks, res.handle) + state.equiv_chunks[res.handle] = res end end store_result!(state, node, res; error=thunk_failed) @@ -528,7 +527,7 @@ end const CHUNK_CACHE = Dict{Chunk,Dict{Processor,Any}}() struct ScheduleTaskLocation - gproc::OSProc + gproc::Processor proc::Processor end struct ScheduleTaskSpec @@ -538,6 +537,25 @@ struct ScheduleTaskSpec est_alloc_util::UInt64 est_occupancy::UInt32 end + +"Ordering key for task locations when using MPI acceleration (deterministic across ranks)." +function _mpi_fire_order_key(loc::ScheduleTaskLocation) + g = loc.gproc + p = loc.proc + g_rank = g isa Union{Dagger.MPIOSProc, Dagger.MPIProcessor} ? g.rank : root_worker_id(g) + p_rank = p isa Union{Dagger.MPIOSProc, Dagger.MPIProcessor} ? p.rank : root_worker_id(p) + return (g_rank, p_rank) +end + +"Ordering key for a single Processor when using MPI acceleration (deterministic across ranks)." +function _mpi_proc_rank(proc::Processor) + g = get_parent(proc) + p = proc + g_rank = g isa Union{Dagger.MPIOSProc, Dagger.MPIProcessor} ? g.rank : root_worker_id(g) + p_rank = p isa Union{Dagger.MPIOSProc, Dagger.MPIProcessor} ? 
p.rank : root_worker_id(p) + return (g_rank, p_rank) +end + @reuse_scope function schedule!(ctx, state, sch_options, procs=procs_to_use(ctx, sch_options)) lock(state.lock) do safepoint(state) @@ -552,6 +570,7 @@ end to_fire_cleanup = @reuse_defer_cleanup empty!(to_fire) failed_scheduling = @reusable_vector :schedule!_failed_scheduling Union{Thunk,Nothing} nothing 32 failed_scheduling_cleanup = @reuse_defer_cleanup empty!(failed_scheduling) + # Select a new task and get its options task = nothing @label pop_task @@ -626,9 +645,9 @@ end end @label scope_computed - input_procs = @reusable_vector :schedule!_input_procs Processor OSProc() 32 + input_procs = @reusable_vector :schedule!_input_procs Union{Processor,Nothing} nothing 32 input_procs_cleanup = @reuse_defer_cleanup empty!(input_procs) - for proc in Dagger.compatible_processors(scope, procs) + for proc in Dagger.compatible_processors(options.acceleration, scope, procs) if !(proc in input_procs) push!(input_procs, proc) end @@ -660,7 +679,7 @@ end can_use, scope = can_use_proc(state, task, gproc, proc, options, scope) if can_use has_cap, est_time_util, est_alloc_util, est_occupancy = - has_capacity(state, proc, gproc.pid, options.time_util, options.alloc_util, options.occupancy, sig) + has_capacity(state, proc, gproc, options.time_util, options.alloc_util, options.occupancy, sig) if has_cap # Schedule task onto proc # FIXME: est_time_util = est_time_util isa MaxUtilization ? 
cap : est_time_util @@ -669,10 +688,10 @@ end Vector{ScheduleTaskSpec}() end push!(proc_tasks, ScheduleTaskSpec(task, scope, est_time_util, est_alloc_util, est_occupancy)) - state.worker_time_pressure[gproc.pid][proc] = - get(state.worker_time_pressure[gproc.pid], proc, 0) + + state.worker_time_pressure[gproc][proc] = + get(state.worker_time_pressure[gproc], proc, 0) + est_time_util - @dagdebug task :schedule "Scheduling to $gproc -> $proc (cost: $(costs[proc]), pressure: $(state.worker_time_pressure[gproc.pid][proc]))" + @dagdebug task :schedule "Scheduling to $gproc -> $proc (cost: $(costs[proc]), pressure: $(state.worker_time_pressure[gproc][proc]))" sorted_procs_cleanup() costs_cleanup() @goto pop_task @@ -739,14 +758,14 @@ function monitor_procs_changed!(ctx, state, options) end function remove_dead_proc!(ctx, state, proc, options) - @assert options.single !== proc.pid "Single worker failed, cannot continue." + @assert options.single !== root_worker_id(proc) "Single worker failed, cannot continue." 
rmprocs!(ctx, [proc]) - delete!(state.worker_time_pressure, proc.pid) - delete!(state.worker_transfer_rate, proc.pid) - delete!(state.worker_storage_pressure, proc.pid) - delete!(state.worker_storage_capacity, proc.pid) - delete!(state.worker_loadavg, proc.pid) - delete!(state.worker_chans, proc.pid) + delete!(state.worker_transfer_rate, proc) + delete!(state.worker_time_pressure, proc) + delete!(state.worker_storage_pressure, proc) + delete!(state.worker_storage_capacity, proc) + delete!(state.worker_loadavg, proc) + delete!(state.worker_chans, root_worker_id(proc)) end function finish_task!(ctx, state, node, thunk_failed) @@ -789,7 +808,7 @@ end function evict_all_chunks!(ctx, options, to_evict) if !isempty(to_evict) - @sync for w in map(p->p.pid, procs_to_use(ctx, options)) + @sync for w in map(p->root_worker_id(p), procs_to_use(ctx, options)) Threads.@spawn remote_do(evict_chunks!, w, ctx.log_sink, to_evict) end end @@ -860,9 +879,10 @@ Base.hash(task::TaskSpec, h::UInt) = hash(task.thunk_id, hash(TaskSpec, h)) end Tf = chunktype(first(args)) - @assert (options.single === nothing) || (gproc.pid == options.single) + pid = root_worker_id(gproc) + @assert (options.single === nothing) || (pid == options.single) # TODO: Set `sch_handle.tid.ref` to the right `DRef` - sch_handle = SchedulerHandle(ThunkID(thunk.id, nothing), state.worker_chans[gproc.pid]...) + sch_handle = SchedulerHandle(ThunkID(thunk.id, nothing), state.worker_chans[pid]...) # TODO: De-dup common fields (log_sink, uid, etc.) push!(to_send, TaskSpec( @@ -874,7 +894,7 @@ Base.hash(task::TaskSpec, h::UInt) = hash(task.thunk_id, hash(TaskSpec, h)) end if !isempty(to_send) - if Dagger.root_worker_id(gproc) == myid() + if root_worker_id(gproc) == myid() @reusable_tasks :fire_tasks!_task_cache 32 _->nothing "fire_tasks!" FireTaskSpec(proc, state.chan, to_send) else # N.B. 
We don't batch these because we might get a deserialization @@ -1080,7 +1100,7 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re proc_occupancy = istate.proc_occupancy time_pressure = istate.time_pressure - wid = get_parent(to_proc).pid + wid = root_worker_id(to_proc) work_to_do = false while isopen(return_queue) # Wait for new tasks @@ -1138,7 +1158,6 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re # Try to steal from local queues randomly # TODO: Prioritize stealing from busiest processors states = proc_states_values(uid) - # TODO: Try to pre-allocate this P = randperm(length(states)) for state in getindex.(Ref(states), P) other_istate = state.state @@ -1155,7 +1174,8 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re end task, occupancy = peek(queue) scope = task.scope - if Dagger.proc_in_scope(to_proc, scope) + accel = something(task.options.acceleration, Dagger.DistributedAcceleration()) + if Dagger.proc_in_scope(to_proc, scope) && Dagger.accel_matches_proc(accel, to_proc) typemax(UInt32) - proc_occupancy_cached >= occupancy # Compatible, steal this task return dequeue_pair!(queue) @@ -1362,6 +1382,8 @@ function do_tasks(to_proc, return_queue, tasks) @dagdebug nothing :processor "Kicked processors" end +const SCHED_MOVE = ScopedValue{Bool}(false) + """ do_task(to_proc, task::TaskSpec) -> Any @@ -1373,13 +1395,15 @@ Executes a single task specified by `task` on `to_proc`. ctx_vars = task.ctx_vars ctx = Context(Processor[]; log_sink=ctx_vars.log_sink, profile=ctx_vars.profile) - from_proc = OSProc() + options = task.options + Dagger.accelerate!(options.acceleration) + + from_proc = Dagger.default_processor() data = task.data Tf = task.Tf f = isdefined(Tf, :instance) ? Tf.instance : nothing # Wait for required resources to become available - options = task.options propagated = get_propagated_options(options) to_storage = options.storage !== nothing ? 
fetch(options.storage) : MemPool.GLOBAL_DEVICE[] #to_storage_name = nameof(typeof(to_storage)) @@ -1447,7 +1471,7 @@ Executes a single task specified by `task` on `to_proc`. @maybelog ctx timespan_finish(ctx, :storage_wait, (;thunk_id, processor=to_proc), (;f, device=typeof(to_storage))) =# - @dagdebug thunk_id :execute "Moving data" + @dagdebug thunk_id :execute "Moving data for $Tf" # Initiate data transfers for function and arguments transfer_time = Threads.Atomic{UInt64}(0) @@ -1470,7 +1494,9 @@ Executes a single task specified by `task` on `to_proc`. Some{Any}(get!(CHUNK_CACHE[x], to_proc) do # Convert from cached value # TODO: Choose "closest" processor of same type first - some_proc = first(keys(CHUNK_CACHE[x])) + cache_procs = keys(CHUNK_CACHE[x]) + some_proc = Dagger.current_acceleration() isa Dagger.MPIAcceleration ? + argmin(_mpi_proc_rank, cache_procs) : first(cache_procs) some_x = CHUNK_CACHE[x][some_proc] @dagdebug thunk_id :move "Cache hit for argument $id at $some_proc: $some_x" @invokelatest move(some_proc, to_proc, some_x) @@ -1505,13 +1531,23 @@ Executes a single task specified by `task` on `to_proc`. end else =# - new_value = @invokelatest move(to_proc, value) + new_value = with(SCHED_MOVE=>true) do + @invokelatest move(to_proc, value) + end #end - if new_value !== value - @dagdebug thunk_id :move "Moved argument @ $position to $to_proc: $(typeof(value)) -> $(typeof(new_value))" + # Preserve Chunk reference when move returns nothing (placeholder on this rank). This keeps + # type information correct at all ranks: chunktype(Chunk) is concrete even when Chunk holds no data. + # So execute! sees correct arg_types. Materializing the value (for the kernel) must happen in + # execute! and may require lazy recv from the executor if this rank has a placeholder. 
+ if new_value === nothing && (value isa Dagger.Chunk || value isa Dagger.WeakChunk) + arg.value = value + else + if new_value !== value + @dagdebug thunk_id :move "Moved argument @ $position to $to_proc: $(typeof(value)) -> $(typeof(new_value))" + end + arg.value = new_value end - @maybelog ctx timespan_finish(ctx, :move, (;thunk_id, position, processor=to_proc), (;f, data=new_value); tasks=[Base.current_task()]) - arg.value = new_value + @maybelog ctx timespan_finish(ctx, :move, (;thunk_id, position, processor=to_proc), (;f, data=Dagger.value(arg)); tasks=[Base.current_task()]) return end end @@ -1550,7 +1586,7 @@ Executes a single task specified by `task` on `to_proc`. # FIXME #gcnum_start = Base.gc_num() - @dagdebug thunk_id :execute "Executing $(typeof(f))" + @dagdebug thunk_id :execute "Executing $Tf" logging_enabled = !(ctx.log_sink isa TimespanLogging.NoOpLog) @@ -1613,7 +1649,7 @@ Executes a single task specified by `task` on `to_proc`. notify(TASK_SYNC) end - @dagdebug thunk_id :execute "Returning" + @dagdebug thunk_id :execute "Returning $Tf with $(typeof(result_meta))" # TODO: debug_storage("Releasing $to_storage_name") metadata = ( diff --git a/src/sch/util.jl b/src/sch/util.jl index 11706382a..164685195 100644 --- a/src/sch/util.jl +++ b/src/sch/util.jl @@ -383,7 +383,7 @@ function signature(f, args) value = Dagger.value(arg) if value isa Dagger.DTask # Only occurs via manual usage of signature - value = fetch(value; raw=true) + value = fetch(value; move_value=false, unwrap=false) end if istask(value) throw(ConcurrencyViolationError("Must call `collect_task_inputs!(state, task)` before calling `signature`")) @@ -440,8 +440,8 @@ function can_use_proc(state, task, gproc, proc, opts, scope) # Check against single if opts.single !== nothing @warn "The `single` option is deprecated, please use scopes instead\nSee https://juliaparallel.org/Dagger.jl/stable/scopes/ for details" maxlog=1 - if gproc.pid != opts.single - @dagdebug task :scope "Rejected $proc: 
gproc.pid ($(gproc.pid)) != single ($(opts.single))" + if root_worker_id(gproc) != opts.single + @dagdebug task :scope "Rejected $proc: gproc root_worker_id ($(root_worker_id(gproc))) != single ($(opts.single))" return false, scope end scope = constrain(scope, Dagger.ProcessScope(opts.single)) @@ -593,9 +593,10 @@ const DEFAULT_TRANSFER_RATE = UInt64(1_000_000) # Add fixed cost for cross-worker task transfer (esimated at 1ms) # TODO: Actually estimate/benchmark this - task_xfer_cost = gproc.pid != myid() ? 1_000_000 : 0 # 1ms + task_xfer_cost = root_worker_id(gproc) != myid() ? 1_000_000 : 0 # 1ms + pid = Dagger.root_worker_id(gproc) - tx_rate = get(get(state.worker_transfer_rate, gproc.pid, Dict{Processor,UInt64}()), proc, DEFAULT_TRANSFER_RATE) + tx_rate = get(get(state.worker_transfer_rate, gproc, Dict{Processor,UInt64}()), proc, DEFAULT_TRANSFER_RATE) costs[proc] = est_time_util + (tx_cost/tx_rate) + task_xfer_cost end chunks_cleanup() diff --git a/src/scopes.jl b/src/scopes.jl index 79190c292..28aa8fa00 100644 --- a/src/scopes.jl +++ b/src/scopes.jl @@ -101,7 +101,7 @@ struct ExactScope <: AbstractScope parent::ProcessScope processor::Processor end -ExactScope(proc) = ExactScope(ProcessScope(get_parent(proc).pid), proc) +ExactScope(proc) = ExactScope(ProcessScope(root_worker_id(get_parent(proc))), proc) proc_in_scope(proc::Processor, scope::ExactScope) = proc == scope.processor "Indicates that the applied scopes `x` and `y` are incompatible." diff --git a/src/shard.jl b/src/shard.jl new file mode 100644 index 000000000..ecd0ee570 --- /dev/null +++ b/src/shard.jl @@ -0,0 +1,89 @@ +""" +Maps a value to one of multiple distributed "mirror" values automatically when +used as a thunk argument. Construct using `@shard` or `shard`. +""" +struct Shard + chunks::Dict{Processor,Chunk} +end + +""" + shard(f; kwargs...) 
-> Chunk{Shard} + +Executes `f` on all workers in `workers`, wrapping the result in a +process-scoped `Chunk`, and constructs a `Chunk{Shard}` containing all of these +`Chunk`s on the current worker. + +Keyword arguments: +- `procs` -- The list of processors to create pieces on. May be any iterable container of `Processor`s. +- `workers` -- The list of workers to create pieces on. May be any iterable container of `Integer`s. +- `per_thread::Bool=false` -- If `true`, creates a piece per each thread, rather than a piece per each worker. +""" +function shard(@nospecialize(f); procs=nothing, workers=nothing, per_thread=false) + if procs === nothing + if workers !== nothing + procs = [OSProc(w) for w in workers] + else + procs = lock(Sch.eager_context()) do + copy(Sch.eager_context().procs) + end + end + if per_thread + _procs = ThreadProc[] + for p in procs + append!(_procs, filter(p->p isa ThreadProc, get_processors(p))) + end + procs = _procs + end + else + if workers !== nothing + throw(ArgumentError("Cannot combine `procs` and `workers`")) + elseif per_thread + throw(ArgumentError("Cannot combine `procs` and `per_thread=true`")) + end + end + isempty(procs) && throw(ArgumentError("Cannot create empty Shard")) + shard_running_dict = Dict{Processor,DTask}() + for proc in procs + scope = proc isa OSProc ? ProcessScope(proc) : ExactScope(proc) + thunk = Dagger.@spawn scope=scope _mutable_inner(f, proc, scope) + shard_running_dict[proc] = thunk + end + shard_dict = Dict{Processor,Chunk}() + for proc in procs + shard_dict[proc] = fetch(shard_running_dict[proc])[] + end + return Shard(shard_dict) +end + +"Creates a `Shard`. See [`Dagger.shard`](@ref) for details." +macro shard(exs...) + opts = esc.(exs[1:end-1]) + ex = exs[end] + quote + let f = @noinline ()->$(esc(ex)) + $shard(f; $(opts...)) + end + end +end + +function move(from_proc::Processor, to_proc::Processor, shard::Shard) + # Match either this proc or some ancestor + # N.B. 
This behavior may bypass the piece's scope restriction + proc = to_proc + if haskey(shard.chunks, proc) + return move(from_proc, to_proc, shard.chunks[proc]) + end + parent = Dagger.get_parent(proc) + while parent != proc + proc = parent + parent = Dagger.get_parent(proc) + if haskey(shard.chunks, proc) + return move(from_proc, to_proc, shard.chunks[proc]) + end + end + + throw(KeyError(to_proc)) +end +Base.iterate(s::Shard) = iterate(values(s.chunks)) +Base.iterate(s::Shard, state) = iterate(values(s.chunks), state) +Base.length(s::Shard) = length(s.chunks) diff --git a/src/submission.jl b/src/submission.jl index 4ff4f2294..d3102eacf 100644 --- a/src/submission.jl +++ b/src/submission.jl @@ -285,7 +285,13 @@ function eager_process_args_submission_to_local(id_map, spec::DTaskSpec{true}) return ntuple(i->eager_process_elem_submission_to_local(id_map, spec.fargs[i]), length(spec.fargs)) end -DTaskMetadata(spec::DTaskSpec) = DTaskMetadata(eager_metadata(spec.fargs)) +function DTaskMetadata(spec::DTaskSpec) + rt = spec.options.return_type + if rt !== nothing && isconcretetype(rt) && rt !== Any + return DTaskMetadata(rt) + end + return DTaskMetadata(eager_metadata(spec.fargs)) +end function eager_metadata(fargs) f = value(fargs[1]) f = f isa StreamingFunction ? 
f.f : f @@ -298,6 +304,10 @@ function eager_spawn(spec::DTaskSpec) uid = eager_next_id() future = ThunkFuture() metadata = DTaskMetadata(spec) + # Propagate inferred return type to options + if isconcretetype(metadata.return_type) + spec.options.return_type = metadata.return_type + end return DTask(uid, future, metadata) end @@ -320,10 +330,16 @@ function eager_launch!(pair::DTaskPair) end end + # Propagate DTask return_type into options so the created Thunk has chunktype for downstream inference + options = spec.options + if isconcretetype(task.metadata.return_type) + options = copy(options) + options.return_type = task.metadata.return_type + end # Submit the task #=FIXME:REALLOC=# thunk_id = eager_submit!(PayloadOne(task.uid, task.future, - fargs, spec.options, true)) + fargs, options, true)) task.thunk_ref = thunk_id.ref end # FIXME: Don't convert Tuple to Vector{Argument} @@ -353,7 +369,13 @@ function eager_launch!(pairs::Vector{DTaskPair}) end end end - all_options = Options[pair.spec.options for pair in pairs] + # Propagate DTask return_type into options so created Thunks have chunktype for downstream inference + all_options = Options[ + let opts = pair.spec.options + isconcretetype(pair.task.metadata.return_type) ? (o = copy(opts); o.return_type = pair.task.metadata.return_type; o) : opts + end + for pair in pairs + ] # Submit the tasks #=FIXME:REALLOC=# diff --git a/src/thunk.jl b/src/thunk.jl index e13e299f0..c24e0c329 100644 --- a/src/thunk.jl +++ b/src/thunk.jl @@ -247,6 +247,14 @@ isweak(t) = false Base.show(io::IO, t::WeakThunk) = (print(io, "~"); Base.show(io, t.x.value)) Base.convert(::Type{WeakThunk}, t::Thunk) = WeakThunk(t) chunktype(t::WeakThunk) = chunktype(unwrap_weak_checked(t)) +# Use options.return_type when set (e.g. from mpi_propagate_chunk_types! or eager_metadata) +# so that Thunk arguments propagate type to downstream eager_metadata/execute! 
+function chunktype(t::Thunk) + if t.options !== nothing && t.options.return_type !== nothing && isconcretetype(t.options.return_type) + return t.options.return_type + end + return typeof(t) +end Base.convert(::Type{ThunkSyncdep}, t::WeakThunk) = ThunkSyncdep(nothing, t) ThunkSyncdep(t::WeakThunk) = ThunkSyncdep(nothing, t) @@ -462,7 +470,7 @@ function _par(mod, ex::Expr; lazy=true, recur=true, opts=()) end args = filter(arg->!Meta.isexpr(arg, :parameters), allargs) kwargs = filter(arg->Meta.isexpr(arg, :parameters), allargs) - if !isempty(kwargs) + if !Base.isempty(kwargs) kwargs = only(kwargs).args end if body !== nothing @@ -530,7 +538,7 @@ function spawn(f, args...; kwargs...) @nospecialize f args kwargs # Merge all passed options - if length(args) >= 1 && first(args) isa Options + if length(args) >= 1 && first(args) isa Options # N.B. Make a defensive copy in case user aliases Options struct task_options = copy(first(args)::Options) args = args[2:end] @@ -545,7 +553,7 @@ function spawn(f, args...; kwargs...) end function typed_spawn(f, args...; kwargs...) # Merge all passed options - if length(args) >= 1 && first(args) isa Options + if length(args) >= 1 && first(args) isa Options # N.B. Make a defensive copy in case user aliases Options struct task_options = copy(first(args)::Options) args = args[2:end] diff --git a/src/tochunk.jl b/src/tochunk.jl new file mode 100644 index 000000000..ff15e426e --- /dev/null +++ b/src/tochunk.jl @@ -0,0 +1,119 @@ +@warn "Update tochunk docstring" maxlog=1 +""" + tochunk(x, proc::Processor, scope::AbstractScope; device=nothing, rewrap=false, kwargs...) -> Chunk + +Create a chunk from data `x` which resides on `proc` and which has scope +`scope`. + +`device` specifies a `MemPool.StorageDevice` (which is itself wrapped in a +`Chunk`) which will be used to manage the reference contained in the `Chunk` +generated by this function. 
If `device` is `nothing` (the default), the data +will be inspected to determine if it's safe to serialize; if so, the default +MemPool storage device will be used; if not, then a `MemPool.CPURAMDevice` will +be used. + +`type` can be specified manually to force the type to be `Chunk{type}`. + +If `rewrap==true` and `x isa Chunk`, then the `Chunk` will be rewrapped in a +new `Chunk`. + +All other kwargs are passed directly to `MemPool.poolset`. +""" +tochunk(x::X, proc::P, space::M; kwargs...) where {X,P<:Processor,M<:MemorySpace} = + tochunk(x, proc, space, AnyScope(); kwargs...) +function tochunk(x::X, proc::P, space::M, scope::S; device=nothing, type=X, rewrap=false, kwargs...) where {X,P<:Processor,S,M<:MemorySpace} + if type === Nothing + throw(ArgumentError("Chunk type cannot be Nothing. Placeholder chunks must be created with an explicit type= (e.g. tochunk(nothing, proc, space; type=Matrix{Float64})). x=$(repr(x))")) + end + if x isa Chunk + check_proc_space(x, proc, space) + return maybe_rewrap(x, proc, space, scope; type, rewrap) + end + if device === nothing + device = if Sch.walk_storage_safe(x) + MemPool.GLOBAL_DEVICE[] + else + MemPool.CPURAMDevice() + end + end + ref = tochunk_pset(x, space; device, type, kwargs...) + return Chunk{type,typeof(ref),P,S,typeof(space)}(type, domain(x), ref, proc, scope, space) +end +# Disambiguate: Chunk-specific 3-arg so kwcall(tochunk, Chunk, Processor, Scope) is not ambiguous with utils/chunks.jl +function tochunk(x::Chunk, proc::P, scope::S; rewrap=false, kwargs...) where {P<:Processor,S} + if rewrap + return remotecall_fetch(x.handle.owner) do + tochunk(MemPool.poolget(x.handle), proc, scope; kwargs...) + end + else + return x + end +end +function tochunk(x::X, proc::P, scope::S; device=nothing, type=X, rewrap=false, kwargs...) where {X,P<:Processor,S} + if type === Nothing + throw(ArgumentError("Chunk type cannot be Nothing. Placeholder chunks must be created with an explicit type= (e.g. 
tochunk(nothing, proc, space; type=Matrix{Float64})). x=$(repr(x))")) + end + if device === nothing + device = if Sch.walk_storage_safe(x) + MemPool.GLOBAL_DEVICE[] + else + MemPool.CPURAMDevice() + end + end + if x isa Chunk + space = x.space + check_proc_space(x, proc, space) + return maybe_rewrap(x, proc, space, scope; type, rewrap) + end + space = default_memory_space(current_acceleration(), x) + ref = tochunk_pset(x, space; device, type, kwargs...) + return Chunk{type,typeof(ref),P,S,typeof(space)}(type, domain(x), ref, proc, scope, space) +end +function tochunk(x::X, space::M, scope::S; device=nothing, type=X, rewrap=false, kwargs...) where {X,M<:MemorySpace,S} + if type === Nothing + throw(ArgumentError("Chunk type cannot be Nothing. Placeholder chunks must be created with an explicit type= (e.g. tochunk(nothing, proc, space; type=Matrix{Float64})). x=$(repr(x))")) + end + if device === nothing + device = if Sch.walk_storage_safe(x) + MemPool.GLOBAL_DEVICE[] + else + MemPool.CPURAMDevice() + end + end + if x isa Chunk + proc = x.processor + check_proc_space(x, proc, space) + return maybe_rewrap(x, proc, space, scope; type, rewrap) + end + proc = default_processor(current_acceleration(), x) + ref = tochunk_pset(x, space; device, type, kwargs...) + return Chunk{type,typeof(ref),typeof(proc),S,M}(type, domain(x), ref, proc, scope, space) +end +# 2-arg: avoid overwriting utils/chunks.jl's tochunk(Any, Any) and tochunk(Any); only add Processor/MemorySpace variants +# Chunk + Processor: disambiguate vs utils/chunks.jl's tochunk(x::Chunk, proc; ...) +tochunk(x::Chunk, proc::Processor; kwargs...) = tochunk(x, proc, AnyScope(); kwargs...) +tochunk(x, proc::Processor; kwargs...) = tochunk(x, proc, AnyScope(); kwargs...) +tochunk(x, space::MemorySpace; kwargs...) = tochunk(x, space, AnyScope(); kwargs...) 
+ +check_proc_space(x, proc, space) = nothing +function check_proc_space(x::Chunk, proc, space) + if x.space !== space + throw(ArgumentError("Memory space mismatch: Chunk=$(x.space) != Requested=$space")) + end +end +function check_proc_space(x::Thunk, proc, space) + # FIXME: Validate +end +function maybe_rewrap(x, proc, space, scope; type, rewrap) + if rewrap + return remotecall_fetch(x.handle.owner) do + tochunk(MemPool.poolget(x.handle), proc, scope; kwargs...) + end + else + return x + end +end + +tochunk_pset(x, space::MemorySpace; device=nothing, type=nothing, kwargs...) = poolset(x; device, kwargs...) + +# savechunk: defined in utils/chunks.jl (fork Chunk has space field; do not duplicate here) diff --git a/src/types/acceleration.jl b/src/types/acceleration.jl new file mode 100644 index 000000000..f9aa1d86f --- /dev/null +++ b/src/types/acceleration.jl @@ -0,0 +1,3 @@ +abstract type Acceleration end + +struct DistributedAcceleration <: Acceleration end diff --git a/src/types/chunk.jl b/src/types/chunk.jl new file mode 100644 index 000000000..9b8102a6d --- /dev/null +++ b/src/types/chunk.jl @@ -0,0 +1,27 @@ +""" + Chunk + +A reference to a piece of data located on a remote worker. `Chunk`s are +typically created with `Dagger.tochunk(data)`, and the data can then be +accessed from any worker with `collect(::Chunk)`. `Chunk`s are +serialization-safe, and use distributed refcounting (provided by +`MemPool.DRef`) to ensure that the data referenced by a `Chunk` won't be GC'd, +as long as a reference exists on some worker. + +Each `Chunk` is associated with a given `Dagger.Processor`, which is (in a +sense) the processor that "owns" or contains the data. Calling +`collect(::Chunk)` will perform data movement and conversions defined by that +processor to safely serialize the data to the calling worker. + +## Constructors +See [`tochunk`](@ref). 
+""" + +mutable struct Chunk{T, H, P<:Processor, S<:AbstractScope, M<:MemorySpace} + chunktype::Type{T} + domain + handle::H + processor::P + scope::S + space::M +end diff --git a/src/types/memory-space.jl b/src/types/memory-space.jl new file mode 100644 index 000000000..247ceccb0 --- /dev/null +++ b/src/types/memory-space.jl @@ -0,0 +1 @@ +abstract type MemorySpace end \ No newline at end of file diff --git a/src/types/processor.jl b/src/types/processor.jl new file mode 100644 index 000000000..1e333413f --- /dev/null +++ b/src/types/processor.jl @@ -0,0 +1,2 @@ +# Docstring for Processor is attached in src/processor.jl after OSProc is defined (avoids "Replacing docs" warning). +abstract type Processor end \ No newline at end of file diff --git a/src/types/scope.jl b/src/types/scope.jl new file mode 100644 index 000000000..0197fddf9 --- /dev/null +++ b/src/types/scope.jl @@ -0,0 +1 @@ +abstract type AbstractScope end \ No newline at end of file diff --git a/src/utils/chunks.jl b/src/utils/chunks.jl index 9f0c3b487..1300a5a1d 100644 --- a/src/utils/chunks.jl +++ b/src/utils/chunks.jl @@ -161,7 +161,8 @@ function tochunk(x::X, proc::P=OSProc(), scope::S=AnyScope(); device=nothing, re end end ref = poolset(x; device, kwargs...) - Chunk{X,typeof(ref),P,S}(X, domain(x), ref, proc, scope) + space = memory_space(proc) + Chunk{X,typeof(ref),P,S,typeof(space)}(X, domain(x), ref, proc, scope, space) end function tochunk(x::Chunk, proc=nothing, scope=nothing; rewrap=false, kwargs...) 
if rewrap @@ -185,5 +186,6 @@ function savechunk(data, dir, f) fr = FileRef(f, sz) proc = OSProc() scope = AnyScope() # FIXME: Scoped to this node - Chunk{typeof(data),typeof(fr),typeof(proc),typeof(scope)}(typeof(data), domain(data), fr, proc, scope, true) + space = memory_space(proc) + Chunk{typeof(data),typeof(fr),typeof(proc),typeof(scope),typeof(space)}(typeof(data), domain(data), fr, proc, scope, space) end diff --git a/src/utils/dagdebug.jl b/src/utils/dagdebug.jl index 873e47e79..678445051 100644 --- a/src/utils/dagdebug.jl +++ b/src/utils/dagdebug.jl @@ -59,4 +59,7 @@ macro opcounter(category, count=1) end end) end -opcounter(mod::Module, category::Symbol) = getfield(mod, Symbol(:OPCOUNTER_, category)).value[] \ No newline at end of file +opcounter(mod::Module, category::Symbol) = getfield(mod, Symbol(:OPCOUNTER_, category)).value[] + +# No-op debug helper for tracking largest values (used alongside @opcounter) +largest_value_update!(::Any) = nothing \ No newline at end of file diff --git a/src/weakchunk.jl b/src/weakchunk.jl new file mode 100644 index 000000000..e31070536 --- /dev/null +++ b/src/weakchunk.jl @@ -0,0 +1,23 @@ +struct WeakChunk + wid::Int + id::Int + x::WeakRef +end + +function WeakChunk(c::Chunk) + return WeakChunk(c.handle.owner, c.handle.id, WeakRef(c)) +end + +unwrap_weak(c::WeakChunk) = c.x.value +function unwrap_weak_checked(c::WeakChunk) + cw = unwrap_weak(c) + @assert cw !== nothing "WeakChunk expired: ($(c.wid), $(c.id))" + return cw +end +wrap_weak(c::Chunk) = WeakChunk(c) +isweak(c::WeakChunk) = true +isweak(c::Chunk) = false +is_task_or_chunk(c::WeakChunk) = true +Serialization.serialize(io::AbstractSerializer, wc::WeakChunk) = + error("Cannot serialize a WeakChunk") +chunktype(c::WeakChunk) = chunktype(unwrap_weak_checked(c)) diff --git a/test/mpi.jl b/test/mpi.jl new file mode 100644 index 000000000..7d71e801e --- /dev/null +++ b/test/mpi.jl @@ -0,0 +1,72 @@ +using Dagger, MPI, LinearAlgebra + +Dagger.accelerate!(:mpi) 
+Dagger.check_uniformity!(true) +comm = MPI.COMM_WORLD +rank = MPI.Comm_rank(comm) +sz = MPI.Comm_size(comm) + +mpidagger_all_results = [] + +# Define constants +# You need to define the MPI workers before running the benchmark +# Example: mpirun -n 4 julia --project benchmarks/DaggerMPI_Weak_scale.jl +datatype = [Float32, Float64] +datasize = 40 +try + for T in datatype + A = rand(T, datasize, datasize) + A = A * A' + A[diagind(A)] .+= size(A, 1) + B = copy(A) + @assert ishermitian(B) + DA = zeros(Blocks(20,20), T, datasize, datasize) + for chunk in DA.chunks + Dagger.check_uniform(fetch(chunk; move_value=false, unwrap=false).space) + end + copyto!(DA, A) + DB = zeros(Blocks(20,20), T, datasize, datasize) + for chunk in DB.chunks + Dagger.check_uniform(fetch(chunk; move_value=false, unwrap=false).space) + end + copyto!(DB, B) + LinearAlgebra._chol!(DA, UpperTriangular) + elapsed_time = @elapsed chol_DB = LinearAlgebra._chol!(DB, UpperTriangular) + + # Store results + result = ( + procs = sz, + dtype = T, + size = datasize, + time = elapsed_time, + gflops = (datasize^3 / 3) / (elapsed_time * 1e9) + ) + push!(mpidagger_all_results, result) + end +catch + if rank == 0 + Core.print("Rank 0:\n") + rethrow() + elseif rank == 1 + Core.print("Rank 1:\n") + sleep(1) + rethrow() + end +finally + MPI.Barrier(comm) +end +if rank == 0 + #= Write results to CSV + mkpath("benchmarks/results") + if !isempty(mpidagger_all_results) + df = DataFrame(mpidagger_all_results) + CSV.write("benchmarks/results/DaggerMPI_Weak_scale_results.csv", df) + + end + =# + # Summary statistics + for result in mpidagger_all_results + println(result.procs, ",", result.dtype, ",", result.size, ",", result.time, ",", result.gflops) + end + #println("\nAll Cholesky tests completed!") +end