Skip to content

Commit 1c7f6d2

Browse files
committed
add option to pin node positions to sfdp
1 parent 8eca84d commit 1c7f6d2

File tree

4 files changed

+141
-24
lines changed

4 files changed

+141
-24
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
88
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
99
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1010
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
11+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1112

1213
[compat]
1314
GeometryBasics = "0.4"
1415
Requires = "1"
16+
StaticArrays = "1"
1517
julia = "1"
1618

1719
[extras]

src/NetworkLayout.jl

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using GeometryBasics
44
using Requires
55
using LinearAlgebra: norm
66
using Random
7+
using StaticArrays
78

89
export LayoutIterator
910

@@ -137,6 +138,66 @@ function make_symmetric!(A::AbstractMatrix)
137138
return A
138139
end
139140

141+
"""
142+
Initialpos and pin can be given as diffent types (dicts, vectors, ...)
143+
Sanitize and transform them into
144+
145+
_initialpos :: Dict{Int,Point{dim,Ptype}}()
146+
_pin :: Dict{Int,SVector{dim,Bool}}()
147+
"""
148+
function _sanitize_initialpos_pin(dim, Ptype, initialpos, pin)
149+
if !isempty(initialpos)
150+
_initialpos = Dict{Int,Point{dim,Ptype}}(k => Point{dim,Ptype}(v) for (k, v) in pairs(initialpos))
151+
else
152+
_initialpos = Dict{Int,Point{dim,Ptype}}()
153+
end
154+
155+
_pin = Dict{Int,SVector{dim,Bool}}()
156+
for (k, v) in pairs(pin)
157+
if v == nothing
158+
continue
159+
elseif v isa Bool
160+
_pin[k] = SVector{dim,Bool}(v for i in 1:dim)
161+
else # some container
162+
if eltype(v) <: Bool
163+
_pin[k] = v
164+
else
165+
# seems to be an initial position
166+
_initialpos[k] = v
167+
_pin[k] = SVector{dim,Bool}([true for i in 1:dim])
168+
end
169+
end
170+
end
171+
return _initialpos, _pin
172+
end
173+
174+
"""
175+
From an point or a colletion of point like objects try to
176+
infer the PType and the dimension.
177+
178+
i.e.
179+
infer_pointtype([(1,2), (2.3, 4)]) == (2, Float64)
180+
"""
181+
infer_pointtype(::AbstractPoint{dim,t}) where {dim,t} = dim, t
182+
infer_pointtype(::NTuple{dim,t}) where {dim,t} = dim, t
183+
infer_pointtype(t::Tuple) = length(t), promote_type(typeof(t).parameters...)
184+
function infer_pointtype(v)
185+
v = values(v) # needed for broadcast ofer dict
186+
isempty(v) && throw(ArgumentError("Can not infer pointtype of empty container!"))
187+
elt = isconcretetype(eltype(v)) ? eltype(v) : promote_type(typeof.(v)...)
188+
189+
if elt <: Number
190+
return (length(v), elt)
191+
else
192+
ty = infer_pointtype.(v)
193+
dims = getindex.(ty, 1)
194+
if !all(isequal(first(dims)), dims)
195+
throw(ArgumentError("Got container with different point dimesions!"))
196+
end
197+
(dims[1], promote_type(getindex.(ty, 2)...))
198+
end
199+
end
200+
140201
"""
141202
@addcall
142203

src/sfdp.jl

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,15 @@ the nodes.
2020
- `iterations=100`: maximum number of iterations
2121
- `initialpos=Point{dim,Ptype}[]`
2222
23-
Provide list of initial positions. If length does not match Network size the initial
24-
positions will be truncated or filled up with random values between [-1,1] in every coordinate.
23+
Provide `Vector` or `Dict` of initial positions. All positions will be initialized
24+
using random coordinates between [-1,1]. Random positions will be overwritten using
25+
the key-val-pairs provided by this argument.
26+
27+
- `pin=[]`: Pin node positions (won't be updated). Can be given as `Vector` or `Dict`
28+
of node index -> value pairings. Values can be either
29+
- `(12, 4.0)` : overwrite initial position and pin
30+
- `true/false` : pin this position
31+
- `(true, false, false)` : only pin certain coordinates
2532
2633
- `seed=1`: Seed for random initial positions.
2734
"""
@@ -30,44 +37,45 @@ the nodes.
3037
C::T
3138
K::T
3239
iterations::Int
33-
initialpos::Vector{Point{Dim,Ptype}}
40+
initialpos::Dict{Int,Point{Dim,Ptype}}
41+
pin::Dict{Int,SVector{Dim,Bool}}
3442
seed::UInt
3543
end
3644

3745
# TODO: check SFDP default parameters
38-
function SFDP(; dim=2, Ptype=Float64, tol=1.0, C=0.2, K=1.0, iterations=100, initialpos=Point{dim,Ptype}[],
46+
function SFDP(; dim=2, Ptype=Float64,
47+
tol=1.0, C=0.2, K=1.0,
48+
iterations=100,
49+
initialpos=[], pin=[],
3950
seed=1)
4051
if !isempty(initialpos)
41-
initialpos = Point.(initialpos)
42-
Ptype = eltype(eltype(initialpos))
43-
# TODO fix initial pos if list has points of multiple types
44-
Ptype == Any && error("Please provide list of Point{N,T} with same T")
45-
dim = length(eltype(initialpos))
52+
dim, Ptype = infer_pointtype(initialpos)
53+
Ptype = promote_type(Float32, Ptype) # make sure to get at least f32 if given as int
4654
end
47-
return SFDP{dim,Ptype,typeof(tol)}(tol, C, K, iterations, initialpos, seed)
55+
_initialpos, _pin = _sanitize_initialpos_pin(dim, Ptype, initialpos, pin)
56+
57+
return SFDP{dim,Ptype,typeof(tol)}(tol, C, K, iterations, _initialpos, _pin, seed)
4858
end
4959

5060
function Base.iterate(iter::LayoutIterator{SFDP{Dim,Ptype,T}}) where {Dim,Ptype,T}
5161
algo, adj_matrix = iter.algorithm, iter.adj_matrix
5262
N = size(adj_matrix, 1)
53-
M = length(algo.initialpos)
5463
rng = MersenneTwister(algo.seed)
55-
startpos = Vector{Point{Dim,Ptype}}(undef, N)
56-
# take the first
57-
for i in 1:min(N, M)
58-
startpos[i] = algo.initialpos[i]
59-
end
60-
# fill the rest with random points
61-
for i in (M + 1):N
62-
startpos[i] = 2 .* rand(rng, Point{Dim,Ptype}) .- 1
64+
startpos = [2 .* rand(rng, Point{Dim,Ptype}) .- 1 for _ in 1:N]
65+
66+
for (k, v) in algo.initialpos
67+
startpos[k] = v
6368
end
64-
# iteratorstate: (#iter, energy, step, progress, old pos, stopflag)
65-
return startpos, (1, typemax(T), one(T), 0, startpos, false)
69+
70+
pin = [get(algo.pin, i, SVector{Dim,Bool}(false for _ in 1:Dim)) for i in 1:N]
71+
72+
# iteratorstate: (#iter, energy, step, progress, old pos, pin, stopflag)
73+
return startpos, (1, typemax(T), one(T), 0, startpos, pin, false)
6674
end
6775

6876
function Base.iterate(iter::LayoutIterator{<:SFDP}, state)
6977
algo, adj_matrix = iter.algorithm, iter.adj_matrix
70-
iter, energy0, step, progress, locs0, stopflag = state
78+
iter, energy0, step, progress, locs0, pin, stopflag = state
7179
K, C, tol = algo.K, algo.C, algo.tol
7280

7381
# stop if stopflag (tol reached) or nr of iterations reached
@@ -93,7 +101,8 @@ function Base.iterate(iter::LayoutIterator{<:SFDP}, state)
93101
((locs[j] .- locs[i]) / norm(locs[j] .- locs[i])))
94102
end
95103
end
96-
locs[i] = locs[i] .+ step .* (force ./ norm(force))
104+
mask = (!).(pin[i]) # where pin=true mask will multiply with 0
105+
locs[i] = locs[i] .+ (step .* (force ./ norm(force))) .* mask
97106
energy = energy + norm(force)^2
98107
end
99108
step, progress = update_step(step, energy, energy0, progress)
@@ -103,7 +112,7 @@ function Base.iterate(iter::LayoutIterator{<:SFDP}, state)
103112
stopflag = true
104113
end
105114

106-
return locs, (iter + 1, energy, step, progress, locs, stopflag)
115+
return locs, (iter + 1, energy, step, progress, locs, pin, stopflag)
107116
end
108117

109118
# Calculate Attractive force

test/runtests.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using Graphs
33
using GeometryBasics
44
using DelimitedFiles: readdlm
55
using SparseArrays: sparse
6+
using StaticArrays
67
using Test
78

89
function jagmesh()
@@ -357,4 +358,48 @@ jagmesh_adj = jagmesh()
357358
6 2 4;
358359
2 4 3]
359360
end
361+
362+
@testset "infer ptype" begin
363+
using NetworkLayout: infer_pointtype
364+
@test infer_pointtype(Point2f(1,2)) == (2, Float32)
365+
@test infer_pointtype(Point2(1.0,2)) == (2, Float64)
366+
@test infer_pointtype(Point(1,2,4)) == (3, Int64)
367+
@test_throws ArgumentError infer_pointtype([(1,2), (2,3.2,1)])
368+
@test infer_pointtype([(1,2), (2,3.2)]) == (2, Float64)
369+
370+
dany = Dict(1=>Point2(1,1), 4=>[1.0,2.0], 7=>(1.0, 4.0))
371+
@test infer_pointtype(dany) == (2, Float64)
372+
373+
dany[2] = (1,2,3)
374+
@test_throws ArgumentError infer_pointtype(dany)
375+
end
376+
377+
@testset "Sanitize initialpos pin" begin
378+
using NetworkLayout: _sanitize_initialpos_pin
379+
pos = [(0,0),(1,1),(2,2)]
380+
pin = []
381+
_pos, _pin = _sanitize_initialpos_pin(2, Float64, pos, pin)
382+
@test _pos == Dict(pairs(Point2f.(pos)))
383+
@test _pin == Dict{Int, SVector{2, Bool}}()
384+
385+
pos = [(0,0),(1,1),(2,2)]
386+
pin = Dict(1=>(true,false), 3=>true, 2=>(3.0, 3.0))
387+
_pos, _pin = _sanitize_initialpos_pin(2, Float64, pos, pin)
388+
@test _pos == Dict(1=>Point2f(0,0), 2=>Point2f(3.0,3.0), 3=>Point2f(2.0,2.0))
389+
@test _pin == Dict(1=>SVector(true,false), 2=>SVector(true,true), 3=>SVector(true,true))
390+
end
391+
392+
@testset "test pin" begin
393+
for algo in [sfdp]
394+
g = complete_graph(10)
395+
ep = algo(g; pin=[(0,0), (0,0)])
396+
@test ep[1] == [0,0]
397+
@test ep[2] == [0,0]
398+
399+
ep = algo(g; initialpos=Dict(4=>(0,0,0), 5=>(1,2,3)), pin=Dict(4=>true, 5=>(true, false, true)))
400+
@test ep[4] == [0,0,0]
401+
@test ep[5][[1,3]] == [1,3]
402+
@test ep[5][2] != [2]
403+
end
404+
end
360405
end

0 commit comments

Comments
 (0)