Skip to content

Commit 43f732a

Browse files
interface improvements
1 parent 675af4a commit 43f732a

File tree

2 files changed

+18
-37
lines changed

2 files changed

+18
-37
lines changed

src/sfdp.jl

Lines changed: 10 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ the nodes.
2424
positions will be truncated or filled up with random values between [-1,1] in every coordinate.
2525
2626
- `seed=1`: Seed for random initial positions.
27-
- `pin::Union{Nothing, Vector{Point{Dim,Bool}}}=nothing` : Anchors positions of nodes to initial position
27+
- `pin=Dict{Int,Bool}()` : Anchors positions of nodes to initial position
2828
"""
2929
@addcall struct SFDP{Dim,Ptype,T<:AbstractFloat} <: IterativeLayout{Dim,Ptype}
3030
tol::T
@@ -33,11 +33,11 @@ the nodes.
3333
iterations::Int
3434
initialpos::Dict{Int, Point{Dim,Ptype}}
3535
seed::UInt
36-
pin::Vector{Point{Dim,Bool}}
37-
function SFDP(tol, C, K, iterations, initialpos, seed, pin)
38-
for (ix, p) in enumerate(pin)
39-
for (id, c) in enumerate(p)
40-
c && !haskey(initialpos, ix) && error("Please provide coordinate for every pinned position")
36+
pin::Dict{Int, Bool}
37+
function SFDP(tol::T, C::T, K::T, iterations::Int, initialpos::Dict, seed::UInt, pin::Dict) where T<:AbstractFloat
38+
for (ix, p) in pin
39+
if !haskey(initialpos, ix) && p
40+
@warn "No coordinate provided for pinned position $ix"
4141
end
4242
end
4343
dim = get_pt_dim(initialpos)
@@ -48,38 +48,20 @@ end
4848

4949
# TODO: check SFDP default parameters
5050
function SFDP(; dim=2, Ptype=Float64, tol=1.0, C=0.2, K=1.0, iterations=100, initialpos=Dict{Int, Point{dim,Ptype}}(),
51-
seed::UInt=UInt(1), pin = Vector{Bool}())
52-
@show initialpos
51+
seed::UInt=UInt(1), pin = Dict{Int, Bool}())
5352
return SFDP(tol, C, K, iterations, initialpos, seed, pin)
5453
end
5554

5655
function SFDP(tol::T, C::T, K::T, iterations::Int, initialpos::Vector, seed::UInt, pin) where T<:AbstractFloat
57-
@info "sfdp with ip vec"
58-
dim = get_pt_dim(initialpos)
5956
initialpos = Dict(zip(1:length(initialpos), Point.(initialpos)))
60-
Ptype = get_pt_ptype(initialpos)
6157
# TODO fix initial pos if list has points of multiple types
6258
return SFDP(tol, C, K, iterations, initialpos, seed, pin)
6359
end
6460

65-
function SFDP(tol::T, C::T, K::T, iterations::Int, initialpos::Dict{Int, <:Point}, seed::UInt, pin::Dict{Int, <:Point}) where {T<:AbstractFloat}
66-
@info "sfdp with ip dict and pin dict"
67-
fixed = falses(maximum(keys(pin)))
68-
for (i, p) in pin
69-
haskey(initialpos, i) && @warn "overwriting initial position of node $i with pin position"
70-
initialpos[i] = p
71-
fixed[i] = true
72-
end
73-
dim = get_pt_dim(initialpos)
74-
Ptype = get_pt_ptype(initialpos)
75-
return SFDP(tol, C, K, iterations, initialpos, seed, Point{dim,Bool}.(fixed))
76-
end
77-
7861
function SFDP(tol::T, C::T, K::T, iterations::Int, initialpos::Dict{Int, <:Point}, seed::UInt, pin::Vector{Bool}) where T<:AbstractFloat
79-
@info "sfdp with ip dict and pin vec"
8062
dim = get_pt_dim(initialpos)
81-
Ptype = get_pt_ptype(initialpos)
82-
return SFDP(tol, C, K, iterations, initialpos, seed, Point{dim, Bool}.(pin))
63+
fixed = Dict(zip(1:length(pin), pin))
64+
return SFDP(tol, C, K, iterations, initialpos, seed, fixed)
8365
end
8466

8567
function get_pt_ptype(ip::Dict{Int, <:Point})
@@ -141,7 +123,6 @@ function Base.iterate(iter::LayoutIterator{<:SFDP}, state)
141123
energy = zero(energy0)
142124
Ftype = eltype(locs)
143125
N = size(adj_matrix, 1)
144-
pin = N > length(algo.pin) ? vcat(algo.pin, falses(N-length(algo.pin))) : algo.pin
145126
for i in 1:N
146127
force = zero(Ftype)
147128
for j in 1:N
@@ -156,7 +137,7 @@ function Base.iterate(iter::LayoutIterator{<:SFDP}, state)
156137
((locs[j] .- locs[i]) / norm(locs[j] .- locs[i])))
157138
end
158139
end
159-
if !pin[i]
140+
if !get(algo.pin, i, false)
160141
locs[i] = locs[i] .+ step .* (force ./ norm(force))
161142
end
162143
energy = energy + norm(force)^2

test/runtests.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ jagmesh_adj = jagmesh()
3232
ip = Dict(1=>Point(1.0,3.0))
3333
algo = SFDP(; initialpos = ip)
3434
@test algo isa SFDP{2, Float64}
35-
p = [Point((true, true))]
35+
p = [true, true]
3636
algo = SFDP(; initialpos = ip, pin = p)
3737
end
3838

@@ -51,24 +51,24 @@ jagmesh_adj = jagmesh()
5151

5252
@testset "Testing Jagmesh1 graph" begin
5353
println("SFDP Jagmesh1")
54-
positions = @time SFDP(; dim=2, Ptype=Float32, tol=0.9, K=1, iterations=10)(jagmesh_adj)
54+
positions = @time SFDP(; dim=2, Ptype=Float32, tol=0.9, K=1.0, iterations=10)(jagmesh_adj)
5555
@test typeof(positions) == Vector{Point2f}
56-
positions = @time SFDP(; dim=3, Ptype=Float32, tol=0.9, K=1, iterations=10)(jagmesh_adj)
56+
positions = @time SFDP(; dim=3, Ptype=Float32, tol=0.9, K=1.0, iterations=10)(jagmesh_adj)
5757
@test typeof(positions) == Vector{Point3f}
5858
end
5959

6060
@testset "Testing wheel_graph" begin
6161
println("SFDP Wheelgraph")
6262
g = wheel_graph(10)
6363
adj_matrix = adjacency_matrix(g)
64-
positions = @time SFDP(; dim=2, Ptype=Float32, tol=0.1, K=1)(adj_matrix)
64+
positions = @time SFDP(; dim=2, Ptype=Float32, tol=0.1, K=1.0)(adj_matrix)
6565
@test typeof(positions) == Vector{Point2f}
66-
@test positions == sfdp(adj_matrix; dim=2, Ptype=Float32, tol=0.1, K=1)
67-
positions = @time SFDP(; dim=3, Ptype=Float32, tol=0.1, K=1)(adj_matrix)
66+
@test positions == sfdp(adj_matrix; dim=2, Ptype=Float32, tol=0.1, K=1.0)
67+
positions = @time SFDP(; dim=3, Ptype=Float32, tol=0.1, K=1.0)(adj_matrix)
6868
@test typeof(positions) == Vector{Point3f}
69-
@test positions == sfdp(adj_matrix; dim=3, Ptype=Float32, tol=0.1, K=1)
69+
@test positions == sfdp(adj_matrix; dim=3, Ptype=Float32, tol=0.1, K=1.0)
7070
ip = [Point2f(3.0, 1.0)]
71-
@test ip[1] == sfdp(adj_matrix; dim=3, Ptype=Float32, tol=0.1, K=1, initialpos = ip, pin = [true])[1]
71+
@test ip[1] == sfdp(adj_matrix; dim=3, Ptype=Float32, tol=0.1, K=1.0, initialpos = ip, pin = [true])[1]
7272
end
7373
end
7474

0 commit comments

Comments
 (0)