Skip to content

Commit 2c48da4

Browse files
authored
Merge pull request #258 from devmotion/dw/remotecall_wait
Don't fetch results that are discarded
2 parents 1faf00f + 08ae695 commit 2c48da4

File tree

7 files changed, with 68 additions and 59 deletions.

src/broadcast.jl

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -67,18 +67,18 @@ end
6767
# This will turn local AbstractArrays into DArrays
6868
dbc = bcdistribute(bc)
6969

70-
asyncmap(procs(dest)) do p
71-
remotecall_fetch(p) do
70+
@sync for p in procs(dest)
71+
@async remotecall_wait(p) do
7272
# get the indices for the localpart
7373
lpidx = localpartindex(dest)
7474
@assert lpidx != 0
7575
# create a local version of the broadcast, by constructing views
7676
# Note: creates copies of the argument
7777
lbc = bclocal(dbc, dest.indices[lpidx])
7878
copyto!(localpart(dest), lbc)
79-
return nothing
8079
end
8180
end
81+
8282
return dest
8383
end
8484

src/core.jl

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@ function close_by_id(id, pids)
2121
global refs
2222
@sync begin
2323
for p in pids
24-
@async remotecall_fetch(release_localpart, p, id)
24+
@async remotecall_wait(release_localpart, p, id)
2525
end
2626
if !(myid() in pids)
2727
release_localpart(id)

src/darray.jl

Lines changed: 18 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -97,7 +97,7 @@ function DArray(id, init, dims, pids, idxs, cuts)
9797

9898
if length(unique(localtypes)) != 1
9999
@sync for p in pids
100-
@async remotecall_fetch(release_localpart, p, id)
100+
@async remotecall_wait(release_localpart, p, id)
101101
end
102102
throw(ErrorException("Constructed localparts have different `eltype`: $(localtypes)"))
103103
end
@@ -147,8 +147,8 @@ function ddata(;T::Type=Any, init::Function=I->nothing, pids=workers(), data::Ve
147147
end
148148
end
149149

150-
@sync for i = 1:length(pids)
151-
@async remotecall_fetch(construct_localparts, pids[i], init, id, (npids,), pids, idxs, cuts; T=T, A=T)
150+
@sync for p in pids
151+
@async remotecall_wait(construct_localparts, p, init, id, (npids,), pids, idxs, cuts; T=T, A=T)
152152
end
153153

154154
if myid() in pids
@@ -161,9 +161,10 @@ function ddata(;T::Type=Any, init::Function=I->nothing, pids=workers(), data::Ve
161161
end
162162

163163
function gather(d::DArray{T,1,T}) where T
164-
a=Array{T}(undef, length(procs(d)))
165-
@sync for (i,p) in enumerate(procs(d))
166-
@async a[i] = remotecall_fetch(localpart, p, d)
164+
pids = procs(d)
165+
a = Vector{T}(undef, length(pids))
166+
asyncmap!(a, pids) do p
167+
remotecall_fetch(localpart, p, d)
167168
end
168169
a
169170
end
@@ -195,12 +196,9 @@ function DArray(refs)
195196
dimdist = size(refs)
196197
id = next_did()
197198

198-
npids = [r.where for r in refs]
199199
nsizes = Array{Tuple}(undef, dimdist)
200-
@sync for i in 1:length(refs)
201-
let i=i
202-
@async nsizes[i] = remotecall_fetch(sz_localpart_ref, npids[i], refs[i], id)
203-
end
200+
asyncmap!(nsizes, refs) do r
201+
remotecall_fetch(sz_localpart_ref, r.where, r, id)
204202
end
205203

206204
nindices = Array{NTuple{length(dimdist),UnitRange{Int}}}(undef, dimdist...)
@@ -223,7 +221,7 @@ function DArray(refs)
223221
ncuts = Array{Int,1}[pushfirst!(sort(unique(lastidxs[x,:])), 1) for x in 1:length(dimdist)]
224222
ndims = tuple([sort(unique(lastidxs[x,:]))[end]-1 for x in 1:length(dimdist)]...)
225223

226-
DArray(id, refs, ndims, reshape(npids, dimdist), nindices, ncuts)
224+
DArray(id, refs, ndims, map(r -> r.where, refs), nindices, ncuts)
227225
end
228226

229227
macro DArray(ex0::Expr)
@@ -673,8 +671,8 @@ Base.copy(d::SubDArray) = copyto!(similar(d), d)
673671
Base.copy(d::SubDArray{<:Any,2}) = copyto!(similar(d), d)
674672

675673
function Base.copyto!(dest::SubOrDArray, src::AbstractArray)
676-
asyncmap(procs(dest)) do p
677-
remotecall_fetch(p) do
674+
@sync for p in procs(dest)
675+
@async remotecall_wait(p) do
678676
ldest = localpart(dest)
679677
copyto!(ldest, view(src, localindices(dest)...))
680678
end
@@ -684,8 +682,8 @@ end
684682

685683
function Base.deepcopy(src::DArray)
686684
dest = similar(src)
687-
asyncmap(procs(src)) do p
688-
remotecall_fetch(p) do
685+
@sync for p in procs(src)
686+
@async remotecall_wait(p) do
689687
dest[:L] = deepcopy(src[:L])
690688
end
691689
end
@@ -835,14 +833,15 @@ end
835833

836834
function Base.fill!(A::DArray, x)
837835
@sync for p in procs(A)
838-
@async remotecall_fetch((A,x)->(fill!(localpart(A), x); nothing), p, A, x)
836+
@async remotecall_wait((A,x)->fill!(localpart(A), x), p, A, x)
839837
end
840838
return A
841839
end
842840

843841
function Random.rand!(A::DArray, ::Type{T}) where T
844-
asyncmap(procs(A)) do p
845-
remotecall_wait((A, T)->rand!(localpart(A), T), p, A, T)
842+
@sync for p in procs(A)
843+
@async remotecall_wait((A, T)->rand!(localpart(A), T), p, A, T)
846844
end
845+
return A
847846
end
848847

src/linalg.jl

Lines changed: 16 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -25,10 +25,9 @@ function LinearAlgebra.axpy!(α, x::DArray, y::DArray)
2525
if length(x) != length(y)
2626
throw(DimensionMismatch("vectors must have same length"))
2727
end
28-
asyncmap(procs(y)) do p
29-
@async remotecall_fetch(p) do
28+
@sync for p in procs(y)
29+
@async remotecall_wait(p) do
3030
axpy!(α, localpart(x), localpart(y))
31-
return nothing
3231
end
3332
end
3433
return y
@@ -39,26 +38,22 @@ function LinearAlgebra.dot(x::DVector, y::DVector)
3938
throw(DimensionMismatch(""))
4039
end
4140

42-
results=Any[]
43-
asyncmap(procs(x)) do p
44-
push!(results, remotecall_fetch((x, y) -> dot(localpart(x), makelocal(y, localindices(x)...)), p, x, y))
41+
results = asyncmap(procs(x)) do p
42+
remotecall_fetch((x, y) -> dot(localpart(x), makelocal(y, localindices(x)...)), p, x, y)
4543
end
4644
return reduce(+, results)
4745
end
4846

4947
function LinearAlgebra.norm(x::DArray, p::Real = 2)
50-
results = []
51-
@sync begin
52-
for pp in procs(x)
53-
@async push!(results, remotecall_fetch(() -> norm(localpart(x), p), pp))
54-
end
48+
results = asyncmap(procs(x)) do pp
49+
remotecall_fetch(() -> norm(localpart(x), p), pp)
5550
end
5651
return norm(results, p)
5752
end
5853

5954
function LinearAlgebra.rmul!(A::DArray, x::Number)
6055
@sync for p in procs(A)
61-
@async remotecall_fetch((A,x)->(rmul!(localpart(A), x); nothing), p, A, x)
56+
@async remotecall_wait((A,x)->rmul!(localpart(A), x), p, A, x)
6257
end
6358
return A
6459
end
@@ -104,13 +99,12 @@ function LinearAlgebra.mul!(y::DVector, A::DMatrix, x::AbstractVector, α::Numbe
10499
# Scale y if necessary
105100
if β != one(β)
106101
asyncmap(procs(y)) do p
107-
remotecall_fetch(p) do
102+
remotecall_wait(p) do
108103
if !iszero(β)
109104
rmul!(localpart(y), β)
110105
else
111106
fill!(localpart(y), 0)
112107
end
113-
return nothing
114108
end
115109
end
116110
end
@@ -120,7 +114,7 @@ function LinearAlgebra.mul!(y::DVector, A::DMatrix, x::AbstractVector, α::Numbe
120114
p = y.pids[i]
121115
for j = 1:size(R, 2)
122116
rij = R[i,j]
123-
@async remotecall_fetch(() -> (add!(localpart(y), fetch(rij), α); nothing), p)
117+
@async remotecall_wait(() -> add!(localpart(y), fetch(rij), α), p)
124118
end
125119
end
126120

@@ -150,14 +144,13 @@ function LinearAlgebra.mul!(y::DVector, adjA::Adjoint{<:Number,<:DMatrix}, x::Ab
150144

151145
# Scale y if necessary
152146
if β != one(β)
153-
asyncmap(procs(y)) do p
154-
remotecall_fetch(p) do
147+
@sync for p in procs(y)
148+
@async remotecall_wait(p) do
155149
if !iszero(β)
156150
rmul!(localpart(y), β)
157151
else
158152
fill!(localpart(y), 0)
159153
end
160-
return nothing
161154
end
162155
end
163156
end
@@ -167,7 +160,7 @@ function LinearAlgebra.mul!(y::DVector, adjA::Adjoint{<:Number,<:DMatrix}, x::Ab
167160
p = y.pids[i]
168161
for j = 1:size(R, 2)
169162
rij = R[i,j]
170-
@async remotecall_fetch(() -> (add!(localpart(y), fetch(rij), α); nothing), p)
163+
@async remotecall_wait(() -> add!(localpart(y), fetch(rij), α), p)
171164
end
172165
end
173166
return y
@@ -238,10 +231,10 @@ function _matmatmul!(C::DMatrix, A::DMatrix, B::AbstractMatrix, α::Number, β::
238231
# Scale C if necessary
239232
if β != one(β)
240233
@sync for p in C.pids
241-
if β != zero(β)
242-
@async remotecall_fetch(() -> (rmul!(localpart(C), β); nothing), p)
234+
if iszero(β)
235+
@async remotecall_wait(() -> fill!(localpart(C), 0), p)
243236
else
244-
@async remotecall_fetch(() -> (fill!(localpart(C), 0); nothing), p)
237+
@async remotecall_wait(() -> rmul!(localpart(C), β), p)
245238
end
246239
end
247240
end
@@ -252,7 +245,7 @@ function _matmatmul!(C::DMatrix, A::DMatrix, B::AbstractMatrix, α::Number, β::
252245
p = C.pids[i,k]
253246
for j = 1:size(R, 2)
254247
rijk = R[i,j,k]
255-
@async remotecall_fetch(d -> (add!(localpart(d), fetch(rijk), α); nothing), p, C)
248+
@async remotecall_wait(d -> add!(localpart(d), fetch(rijk), α), p, C)
256249
end
257250
end
258251
end

src/mapreduce.jl

Lines changed: 8 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -3,10 +3,9 @@
33
Base.map(f, d0::DArray, ds::AbstractArray...) = broadcast(f, d0, ds...)
44

55
function Base.map!(f::F, dest::DArray, src::DArray{<:Any,<:Any,A}) where {F,A}
6-
asyncmap(procs(dest)) do p
7-
remotecall_fetch(p) do
6+
@sync for p in procs(dest)
7+
@async remotecall_wait(p) do
88
map!(f, localpart(dest), makelocal(src, localindices(dest)...))
9-
return nothing
109
end
1110
end
1211
return dest
@@ -38,8 +37,8 @@ function Base.reducedim_initarray(A::DArray, region, v0, ::Type{R}) where {R}
3837
# Store reduction on lowest pids
3938
pids = A.pids[ntuple(i -> i in region ? (1:1) : (:), ndims(A))...]
4039
chunks = similar(pids, Future)
41-
@sync for i in eachindex(pids)
42-
@async chunks[i...] = remotecall_wait(() -> Base.reducedim_initarray(localpart(A), region, v0, R), pids[i...])
40+
asyncmap!(chunks, pids) do p
41+
remotecall_wait(() -> Base.reducedim_initarray(localpart(A), region, v0, R), p)
4342
end
4443
return DArray(chunks)
4544
end
@@ -64,13 +63,12 @@ end
6463
# has been run on each localpart with mapreducedim_within. Eventually, we might
6564
# want to write mapreducedim_between! as a binary reduction.
6665
function mapreducedim_between!(f, op, R::DArray, A::DArray, region)
67-
asyncmap(procs(R)) do p
68-
remotecall_fetch(p, f, op, R, A, region) do f, op, R, A, region
66+
@sync for p in procs(R)
67+
@async remotecall_wait(p, f, op, R, A, region) do f, op, R, A, region
6968
localind = [r for r = localindices(A)]
7069
localind[[region...]] = [1:n for n = size(A)[[region...]]]
7170
B = convert(Array, A[localind...])
7271
Base.mapreducedim!(f, op, localpart(R), B)
73-
nothing
7472
end
7573
end
7674
return R
@@ -163,8 +161,8 @@ function map_localparts(f::Callable, A::Array, DA::DArray)
163161
end
164162

165163
function map_localparts!(f::Callable, d::DArray)
166-
asyncmap(procs(d)) do p
167-
remotecall_fetch((f,d)->(f(localpart(d)); nothing), p, f, d)
164+
@sync for p in procs(d)
165+
@async remotecall_wait((f,d)->f(localpart(d)), p, f, d)
168166
end
169167
return d
170168
end

src/spmd.jl

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
module SPMD
22

3-
using Distributed: RemoteChannel, myid, procs, remote_do, remotecall_fetch
3+
using Distributed: RemoteChannel, myid, procs, remote_do, remotecall_fetch, remotecall_wait
44
using ..DistributedArrays: DistributedArrays, gather, next_did
55

66
export sendto, recvfrom, recvfrom_any, barrier, bcast, scatter, gather
@@ -243,7 +243,7 @@ function spmd(f, args...; pids=procs(), context=nothing)
243243
ctxt_id = context.id
244244
end
245245
@sync for p in pids
246-
@async remotecall_fetch(spmd_local, p, f_noarg, ctxt_id, clear_ctxt)
246+
@async remotecall_wait(spmd_local, p, f_noarg, ctxt_id, clear_ctxt)
247247
end
248248
nothing
249249
end

test/darray.jl

Lines changed: 20 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -358,7 +358,8 @@ unpack(ex) = ex
358358
A = randn(100,100)
359359
DA = distribute(A)
360360

361-
# sum either throws an ArgumentError or a CompositeException of ArgumentErrors
361+
# sum either throws an ArgumentError, a CompositeException of ArgumentErrors,
362+
# or a RemoteException wrapping an ArgumentError
362363
try
363364
sum(DA, dims=-1)
364365
catch err
@@ -369,6 +370,9 @@ unpack(ex) = ex
369370
orig_err = unpack(excep)
370371
@test isa(orig_err, ArgumentError)
371372
end
373+
elseif isa(err, RemoteException)
374+
@test err.captured isa CapturedException
375+
@test err.captured.ex isa ArgumentError
372376
else
373377
@test isa(err, ArgumentError)
374378
end
@@ -383,6 +387,9 @@ unpack(ex) = ex
383387
orig_err = unpack(excep)
384388
@test isa(orig_err, ArgumentError)
385389
end
390+
elseif isa(err, RemoteException)
391+
@test err.captured isa CapturedException
392+
@test err.captured.ex isa ArgumentError
386393
else
387394
@test isa(err, ArgumentError)
388395
end
@@ -1039,6 +1046,8 @@ check_leaks()
10391046
close(d)
10401047
end
10411048

1049+
check_leaks()
1050+
10421051
@testset "rand!" begin
10431052
d = dzeros(30, 30)
10441053
rand!(d)
@@ -1048,6 +1057,16 @@ end
10481057

10491058
check_leaks()
10501059

1060+
@testset "fill!" begin
1061+
d = dzeros(30, 30)
1062+
fill!(d, 3.14)
1063+
@test all(x-> x == 3.14, d)
1064+
1065+
close(d)
1066+
end
1067+
1068+
check_leaks()
1069+
10511070
d_closeall()
10521071

10531072
@testset "test for any leaks" begin

0 commit comments

Comments
 (0)