
Commit ec34889

Merge branch 'main' into refactor-multiAgent-proposal
2 parents 798a5dd + 1f7f347 commit ec34889

5 files changed: 120 additions & 72 deletions


src/ReinforcementLearningCore/src/utils/networks.jl

Lines changed: 32 additions & 13 deletions
@@ -249,23 +249,42 @@ function (model::CovGaussianNetwork)(state::AbstractMatrix, action::AbstractMatr
     return dropdims(output, dims=2)
 end
 
+"""
+    cholesky_matrix_to_vector_index(i, j, da)
+
+Return the position in a `cholesky_vec` (of length `da(da+1)/2`) of the element of the lower triangular matrix at coordinates (i, j).
+
+For example, if `cholesky_vec = [1,2,3,4,5,6]`, the corresponding lower triangular matrix is
+```
+L = [1 0 0
+     2 4 0
+     3 5 6]
+```
+and `cholesky_matrix_to_vector_index(3, 2, 3) == 5`.
+"""
+cholesky_matrix_to_vector_index(i, j, da) = ((2da - j) * (j - 1)) ÷ 2 + i
+softplusbeta(x, beta = 10f0) = log(exp(x/beta) + 1)*beta # a softer softplus to avoid vanishing values
+
+function cholesky_columns(cholesky_vec, j, batch_size, da) # return a slice (da x 1 x batch_size) containing the jth column of the lower triangular cholesky decomposition of the covariance
+    diag_idx = cholesky_matrix_to_vector_index(j, j, da)
+    tc_diag = softplusbeta.(cholesky_vec[diag_idx:diag_idx, :, :]) .+ 1f-5
+    other_idxs = cholesky_matrix_to_vector_index(j, j, da)+1:cholesky_matrix_to_vector_index(j + 1, j + 1, da)-1 # indices of elements between two diagonal elements
+    tc_other = cholesky_vec[other_idxs, :, :]
+    zs = ignore_derivatives() do
+        zs = similar(cholesky_vec, da - size(tc_other, 1) - 1, 1, batch_size)
+        zs .= zero(eltype(cholesky_vec))
+        return zs
+    end
+    [zs; tc_diag; tc_other]
+end
+
 """
 Transform a vector containing the non-zero elements of a lower triangular da x da matrix into that matrix.
 """
 function vec_to_tril(cholesky_vec, da)
-    batch_size = size(cholesky_vec, 3)
-    c2idx(i, j) = ((2da - j) * (j - 1)) ÷ 2 + i # return the position in cholesky_vec of the element of the triangular matrix at coordinates (i,j)
-    function f(j) # return a slice (da x 1 x batchsize) containing the jth column of the lower triangular cholesky decomposition of the covariance
-        tc_diag = softplus.(cholesky_vec[c2idx(j, j):c2idx(j, j), :, :])
-        tc_other = cholesky_vec[c2idx(j, j)+1:c2idx(j + 1, j + 1)-1, :, :]
-        zs = ignore_derivatives() do
-            zs = similar(cholesky_vec, da - size(tc_other, 1) - 1, 1, batch_size)
-            zs .= zero(eltype(cholesky_vec))
-            return zs
-        end
-        [zs; tc_diag; tc_other]
-    end
-    return mapreduce(f, hcat, 1:da)
+    batch_size = size(cholesky_vec, 3)
+    return mapreduce(j -> cholesky_columns(cholesky_vec, j, batch_size, da), hcat, 1:da)
 end
 
 #####
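
The `cholesky_matrix_to_vector_index` helper now lives at top level (it was previously the closure `c2idx` inside `vec_to_tril`). A quick standalone check of the packing convention, not part of the package code: pack a small lower-triangular matrix column by column and map each coordinate back with the same formula.

```julia
# Standalone check of the packing convention (throwaway code, not the package's):
cholesky_matrix_to_vector_index(i, j, da) = ((2da - j) * (j - 1)) ÷ 2 + i

da = 3
L = [1 0 0; 2 4 0; 3 5 6]                              # the example from the docstring
cholesky_vec = [L[i, j] for j in 1:da for i in j:da]   # packs column by column to [1, 2, 3, 4, 5, 6]

# every lower-triangular coordinate maps back to its position in the packed vector
@assert all(cholesky_vec[cholesky_matrix_to_vector_index(i, j, da)] == L[i, j]
            for j in 1:da for i in j:da)
```
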

src/ReinforcementLearningCore/test/utils/networks.jl

Lines changed: 24 additions & 0 deletions
@@ -171,6 +171,30 @@ using Flux: params, gradient, unsqueeze
         end
     end
     @testset "CovGaussianNetwork" begin
+        @testset "utility functions" begin
+            cholesky_vec = [1:6;]
+            cholesky_mat = [RLCore.softplusbeta(1) 0 0; 2 RLCore.softplusbeta(4) 0; 3 5 RLCore.softplusbeta(6)]
+            @test RLCore.vec_to_tril(cholesky_vec, 3) ≈ cholesky_mat
+            for i in 1:3, j in 1:i
+                inds_mat = [1 0 0; 2 4 0; 3 5 6]
+                @test RLCore.cholesky_matrix_to_vector_index(i, j, 3) == inds_mat[i,j]
+            end
+            for x in -10:10
+                @test RLCore.softplusbeta(x,1) ≈ softplus(x) ≈ log(exp(x) +1)
+            end
+            for x in -10:10
+                @test RLCore.softplusbeta(x,2) ≈ log(exp(x/2) +1)*2 >= softplus(x)
+            end
+            for x in -10:10
+                @test RLCore.softplusbeta(x,0.5) ≈ log(exp(x/0.5) +1)*0.5 <= softplus(x)
+            end
+            cholesky_mats = stack([cholesky_mat for _ in 1:5], dims = 3)
+            cholesky_vecs = stack([reshape(cholesky_vec, :, 1) for _ in 1:5], dims = 3)
+            @test RLCore.vec_to_tril(cholesky_vecs, 3) ≈ cholesky_mats
+            for i in 1:3
+                @test RLCore.cholesky_columns(cholesky_vecs, i, 5, 3) ≈ reshape(cholesky_mats[:, i, :], 3, 1, :)
+            end
+        end
     @testset "identity normalizer" begin
         pre = Dense(20,15)
         μ = Dense(15,10)
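
The β-scaled softplus exercised in these tests equals the ordinary softplus at β = 1 and is monotone in β, so larger β values keep the output from vanishing for negative inputs. A standalone sketch of that property, with throwaway definitions mirroring the diff rather than imports from RLCore:

```julia
# Throwaway definitions mirroring the diff, to illustrate the beta scaling.
softplus(x) = log(exp(x) + 1)
softplusbeta(x, beta = 10f0) = log(exp(x / beta) + 1) * beta

@assert softplusbeta(0.0, 1) == softplus(0.0)      # identical at beta = 1
@assert softplusbeta(-5.0, 10) > softplus(-5.0)    # larger beta keeps values from vanishing
@assert softplusbeta(-5.0, 0.5) < softplus(-5.0)   # sharper than softplus for beta < 1
```
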
Lines changed: 50 additions & 51 deletions
@@ -1,17 +1,18 @@
 using .PyCall
 
-
 np = pyimport("numpy")
 
-export PettingzooEnv
+export PettingZooEnv
+
 
 """
-    PettingzooEnv(;kwargs...)
+    PettingZooEnv(;kwargs...)
 
-`PettingzooEnv` is an interface of the python library pettingzoo for multi agent reinforcement learning environments. It can be used to test multi
+`PettingZooEnv` is an interface to the Python library PettingZoo for multi agent reinforcement learning environments. It can be used to test multi
 agent reinforcement learning algorithms implemented in Julia ReinforcementLearning.
 """
-function PettingzooEnv(name::String; seed=123, args...)
+
+function PettingZooEnv(name::String; seed=123, args...)
     if !PyCall.pyexists("pettingzoo.$name")
         error("Cannot import pettingzoo.$name")
     end
@@ -20,7 +21,7 @@ function PettingzooEnv(name::String; seed=123, args...)
     pyenv.reset(seed=seed)
     obs_space = space_transform(pyenv.observation_space(pyenv.agents[1]))
     act_space = space_transform(pyenv.action_space(pyenv.agents[1]))
-    env = PettingzooEnv{typeof(act_space),typeof(obs_space),typeof(pyenv)}(
+    env = PettingZooEnv{typeof(act_space),typeof(obs_space),typeof(pyenv)}(
         pyenv,
         obs_space,
         act_space,
@@ -33,13 +34,12 @@ end
 
 # basic function needed for simulation ========================================================================
 
-function RLBase.reset!(env::PettingzooEnv)
+function RLBase.reset!(env::PettingZooEnv)
     pycall!(env.state, env.pyenv.reset, PyObject, env.seed)
-    env.ts = 1
     nothing
 end
 
-function RLBase.is_terminated(env::PettingzooEnv)
+function RLBase.is_terminated(env::PettingZooEnv)
     _, _, t, d, _ = pycall(env.pyenv.last, PyObject)
     t || d
 end
@@ -48,96 +48,96 @@ end
 
 ## State / observation implementations ========================================================================
 
-RLBase.state(env::PettingzooEnv, ::Observation{Any}, players::Tuple) = Dict(p => state(env, p) for p in players)
+RLBase.state(env::PettingZooEnv, ::Observation{Any}, players::Tuple) = Dict(p => state(env, p) for p in players)
 
 
 # partial observability is default for pettingzoo
-function RLBase.state(env::PettingzooEnv, ::Observation{Any}, player)
+function RLBase.state(env::PettingZooEnv, ::Observation{Any}, player)
     env.pyenv.observe(player)
 end
 
 
 ## state space =========================================================================================================================================
 
-RLBase.state_space(env::PettingzooEnv, ::Observation{Any}, players) = Space(Dict(player => state_space(env, player) for player in players))
+RLBase.state_space(env::PettingZooEnv, ::Observation{Any}, players) = Space(Dict(player => state_space(env, player) for player in players))
 
 # partial observability
-RLBase.state_space(env::PettingzooEnv, ::Observation{Any}, player::String) = space_transform(env.pyenv.observation_space(player))
+RLBase.state_space(env::PettingZooEnv, ::Observation{Any}, player::Symbol) = space_transform(env.pyenv.observation_space(String(player)))
 
 # for full observability. Be careful: action_space has also to be adjusted
-# RLBase.state_space(env::PettingzooEnv, ::Observation{Any}, player::String) = space_transform(env.pyenv.state_space)
+# RLBase.state_space(env::PettingZooEnv, ::Observation{Any}, player::String) = space_transform(env.pyenv.state_space)
 
 
 ## action space implementations ====================================================================================
 
-RLBase.action_space(env::PettingzooEnv, players::Tuple{String}) =
+RLBase.action_space(env::PettingZooEnv, players::Tuple{Symbol}) =
     Space(Dict(p => action_space(env, p) for p in players))
 
-RLBase.action_space(env::PettingzooEnv, player::String) = space_transform(env.pyenv.action_space(player))
+RLBase.action_space(env::PettingZooEnv, player::Symbol) = space_transform(env.pyenv.action_space(String(player)))
 
-RLBase.action_space(env::PettingzooEnv, player::Integer) = space_transform(env.pyenv.action_space(env.pyenv.agents[player]))
+RLBase.action_space(env::PettingZooEnv, player::Integer) = space_transform(env.pyenv.action_space(env.pyenv.agents[player]))
 
-RLBase.action_space(env::PettingzooEnv, player::DefaultPlayer) = env.action_space
+RLBase.action_space(env::PettingZooEnv, player::DefaultPlayer) = env.action_space
 
 ## action functions ========================================================================================================================
 
-function RLBase.act!(env::PettingzooEnv, actions::Dict, players::Tuple)
-    @assert length(actions) == length(players)
-    env.ts += 1
-    for p in players
-        env(actions[p])
+function RLBase.act!(env::PettingZooEnv, actions::Dict{Symbol, Int})
+    @assert length(actions) == length(players(env))
+    for p in env.pyenv.agents
+        pycall(env.pyenv.step, PyObject, actions[p])
     end
 end
 
-function RLBase.act!(env::PettingzooEnv, actions::Dict, player)
-    @assert length(actions) == length(players(env))
-    for p in players(env)
-        env(actions[p])
+function RLBase.act!(env::PettingZooEnv, actions::Dict{Symbol, Real})
+    @assert length(actions) == length(env.pyenv.agents)
+    for p in env.pyenv.agents
+        pycall(env.pyenv.step, PyObject, np.array(actions[p]; dtype=np.float32))
     end
 end
 
-function RLBase.act!(env::PettingzooEnv, actions::Dict{String, Int})
-    @assert length(actions) == length(players(env))
+function RLBase.act!(env::PettingZooEnv, actions::Dict{Symbol, Vector})
+    @assert length(actions) == length(env.pyenv.agents)
     for p in env.pyenv.agents
-        pycall(env.pyenv.step, PyObject, actions[p])
+        RLBase.act!(env, p)
     end
 end
 
-function RLBase.act!(env::PettingzooEnv, actions::Dict{String, Real})
-    @assert length(actions) == length(players(env))
-    env.ts += 1
-    for p in env.pyenv.agents
-        pycall(env.pyenv.step, PyObject, np.array(actions[p]; dtype=np.float32))
+function RLBase.act!(env::PettingZooEnv, actions::NamedTuple)
+    @assert length(actions) == length(env.pyenv.agents)
    for player ∈ players(env)
+        RLBase.act!(env, actions[player])
     end
 end
 
-function RLBase.act!(env::PettingzooEnv, action::Vector)
+# for vectors, pettingzoo needs them to be in the proper numpy type
+function RLBase.act!(env::PettingZooEnv, action::Vector)
     pycall(env.pyenv.step, PyObject, np.array(action; dtype=np.float32))
 end
 
-function RLBase.act!(env::PettingzooEnv, action::Integer)
-    env.ts += 1
+function RLBase.act!(env::PettingZooEnv, action)
     pycall(env.pyenv.step, PyObject, action)
 end
 
 # reward of player ======================================================================================================================
-function RLBase.reward(env::PettingzooEnv, player::String)
-    env.pyenv.rewards[player]
+function RLBase.reward(env::PettingZooEnv, player::Symbol)
+    env.pyenv.rewards[String(player)]
 end
 
 
 # Multi agent part =========================================================================================================================================
 
 
-RLBase.players(env::PettingzooEnv) = env.pyenv.agents
+RLBase.players(env::PettingZooEnv) = Symbol.(env.pyenv.agents)
+
+function RLBase.current_player(env::PettingZooEnv)
+    return Symbol(env.pyenv.agents[env.current_player])
+end
 
-function RLBase.current_player(env::PettingzooEnv, post_action=false)
-    cur_id = env.ts % length(env.pyenv.agents) == 0 ? length(env.pyenv.agents) : env.ts % length(env.pyenv.agents)
-    cur_id = post_action ? (cur_id - 1 == 0 ? length(env.pyenv.agents) : cur_id - 1) : cur_id
-    return env.pyenv.agents[cur_id]
+function RLBase.next_player!(env::PettingZooEnv)
+    env.current_player = env.current_player < length(env.pyenv.agents) ? env.current_player + 1 : 1
 end
 
-function RLBase.NumAgentStyle(env::PettingzooEnv)
+function RLBase.NumAgentStyle(env::PettingZooEnv)
     n = length(env.pyenv.agents)
     if n == 1
         SingleAgent()
@@ -146,9 +146,8 @@ function RLBase.NumAgentStyle(env::PettingzooEnv)
     end
 end
 
-
-RLBase.DynamicStyle(::PettingzooEnv) = SEQUENTIAL
-RLBase.ActionStyle(::PettingzooEnv) = MINIMAL_ACTION_SET
-RLBase.InformationStyle(::PettingzooEnv) = IMPERFECT_INFORMATION
-RLBase.ChanceStyle(::PettingzooEnv) = EXPLICIT_STOCHASTIC
+RLBase.DynamicStyle(::PettingZooEnv) = SIMULTANEOUS
+RLBase.ActionStyle(::PettingZooEnv) = MINIMAL_ACTION_SET
+RLBase.InformationStyle(::PettingZooEnv) = IMPERFECT_INFORMATION
+RLBase.ChanceStyle(::PettingZooEnv) = EXPLICIT_STOCHASTIC
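
After this change players are `Symbol`s on the Julia side and are converted back to `String` only at the PyCall boundary, and the joint-action `act!` methods dispatch on `Dict{Symbol, …}` key types. A hypothetical usage sketch (the environment name and action values are placeholders; it assumes PyCall plus the Python package pettingzoo are installed and that the RLBase verbs shown in the diff are exported):

```julia
# Hypothetical usage; the environment name and actions are placeholders.
using PyCall                            # the PettingZoo wrapper is only loaded when PyCall is present
using ReinforcementLearningBase, ReinforcementLearningEnvironments

env = PettingZooEnv("mpe.simple_spread_v3"; seed = 1)   # placeholder name; needs pettingzoo installed
ps  = players(env)                      # now a Vector{Symbol}, e.g. [:agent_0, :agent_1, :agent_2]

act!(env, Dict(p => 0 for p in ps))     # Dict{Symbol, Int}: one discrete action per agent
reward(env, first(ps))                  # rewards are looked up via String(player)
current_player(env)                     # Symbol of the agent at env.current_player
```
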

src/ReinforcementLearningEnvironments/src/environments/3rd_party/structs.jl

Lines changed: 13 additions & 7 deletions
@@ -1,11 +1,17 @@
-mutable struct PettingzooEnv{Ta,To,P} <: AbstractEnv
-    pyenv::P
-    observation_space::To
-    action_space::Ta
-    state::P
-    seed::Union{Int, Nothing}
-    ts::Int
+# Parametrization:
+# Ta : Type of action_space
+# To : Type of observation_space
+# P  : Type of environment, most commonly PyObject
+
+mutable struct PettingZooEnv{Ta,To,P} <: AbstractEnv
+    pyenv::P
+    observation_space::To
+    action_space::Ta
+    state::P
+    seed::Union{Int, Nothing}
+    current_player::Int
 end
+
 export PettingzooEnv
 
 struct GymEnv{T,Ta,To,P} <: AbstractEnv
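
The `ts` step counter is replaced by an explicit `current_player` index; `next_player!` (added in the environment file above) increments it and wraps back to the first agent. A toy sketch of that wrap-around, using a stand-in struct rather than the real `PettingZooEnv`:

```julia
# Toy illustration of the current_player cycling that replaces the old ts-based bookkeeping.
mutable struct ToyTurnEnv
    agents::Vector{Symbol}
    current_player::Int
end

next_player!(env::ToyTurnEnv) =
    env.current_player = env.current_player < length(env.agents) ? env.current_player + 1 : 1

env = ToyTurnEnv([:agent_0, :agent_1, :agent_2], 1)
for _ in 1:4
    next_player!(env)
end
@assert env.current_player == 2   # 1 -> 2 -> 3 -> 1 -> 2
```
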

src/ReinforcementLearningZoo/src/algorithms/policy_gradient/mpo.jl

Lines changed: 1 addition & 1 deletion
@@ -224,7 +224,7 @@ end
 
 function solve_mpodual(Q::AbstractArray, ϵ)
     g(η) = η * ϵ + η * mean(logsumexp( Q ./η .- Float32(log(size(Q, 2))), dims = 2))
-    Optim.minimizer(optimize(g, eps(ϵ), 10f0))
+    Optim.minimizer(optimize(g, eps(ϵ), maximum(abs.(Q))))
 end
 
 #For CovGaussianNetwork
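
The only change widens the upper bracket of the η search from a fixed `10f0` to `maximum(abs.(Q))`, so the bounded univariate minimization of the dual can reach larger temperatures when the advantage estimates are large. A standalone sketch of that dual solve under assumed inputs, with a naive `logsumexp` inlined so the snippet does not depend on the package's imports:

```julia
# Standalone sketch with an inlined logsumexp; Q is a made-up (batch x actions) advantage matrix.
using Optim, Statistics

logsumexp(x; dims) = log.(sum(exp.(x .- maximum(x; dims)); dims)) .+ maximum(x; dims)

function solve_mpodual_sketch(Q::AbstractArray, ϵ)
    g(η) = η * ϵ + η * mean(logsumexp(Q ./ η .- Float32(log(size(Q, 2))), dims = 2))
    Optim.minimizer(optimize(g, eps(ϵ), maximum(abs.(Q))))   # bracket scales with the advantages
end

Q = randn(Float32, 8, 32)
η = solve_mpodual_sketch(Q, 0.1f0)
```
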
