
Commit cce387b

Rework the run loop (#921)
1 parent 2752420 commit cce387b

33 files changed: +151, -230 lines

docs/homepage/blog/a_practical_introduction_to_RL.jl/index.html

Lines changed: 1 addition & 1 deletion
@@ -14804,7 +14804,7 @@ <h2 id="Understand-the-Trajectories">Understand the <em>Trajectories</em><a clas
  <div class="prompt input_prompt">In&nbsp;[28]:</div>
  <div class="inner_cell">
  <div class="input_area">
- <div class=" highlight hl-julia"><pre><span></span><span class="n">t</span> <span class="o">=</span> <span class="n">Trajectories</span><span class="o">.</span><span class="n">CircularArraySARTTraces</span><span class="p">(;</span><span class="n">capacity</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span>
+ <div class=" highlight hl-julia"><pre><span></span><span class="n">t</span> <span class="o">=</span> <span class="n">Trajectories</span><span class="o">.</span><span class="n">CircularArraySARTSTraces</span><span class="p">(;</span><span class="n">capacity</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span>
  </pre></div>

  </div>

docs/src/How_to_implement_a_new_algorithm.md

Lines changed: 1 addition & 1 deletion
@@ -94,7 +94,7 @@ A `Trajectory` is composed of three elements: a `container`, a `controller`, and
  The container is typically an `AbstractTraces`, an object that stores a set of `Trace`s in a structured manner. You can either define your own (and contribute it to the package if it is likely to be usable for other algorithms), or use a predefined one if it exists.

- The most common `AbstractTraces` object is the `CircularArraySARTTraces`, this is a container of a fixed length that stores the following traces: `:state` (S), `:action` (A), `:reward` (R), `:terminal` (T), which toghether are aliased to `SART = (:state, :action, :reward, :terminal)`. Let us see how it is constructed in this simplified version as an example of how to build a custom trace.
+ The most common `AbstractTraces` object is the `CircularArraySARTSTraces`, a fixed-length container that stores the following traces: `:state` (S), `:action` (A), `:reward` (R), `:terminal` (T), which together are aliased to `SART = (:state, :action, :reward, :terminal)`. Let us see how it is constructed in this simplified version as an example of how to build a custom trace.

  ```julia
  function (capacity, state_size, state_eltype, action_size, action_eltype, reward_eltype)
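
As a quick illustration of the container described above, constructing a `CircularArraySARTSTraces` might look like the sketch below; the element types and sizes are arbitrary assumptions, not values taken from this commit.

```julia
using ReinforcementLearningTrajectories

# A fixed-capacity container of :state, :action, :reward and :terminal traces.
# Each keyword is an `eltype => size` pair; `()` means scalar entries.
t = CircularArraySARTSTraces(;
    capacity = 10,
    state    = Float32 => (4,),
    action   = Int     => (),
    reward   = Float32 => (),
    terminal = Bool    => (),
)
```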

docs/src/Zoo_Algorithms/MPO.md

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ The next step is to wrap this policy into an `Agent`. An agent is a combination

  ```julia
  trajectory = Trajectory(
-     CircularArraySARTTraces(capacity = 1000, state = Float32 => (4,),action = Float32 => (1,)),
+     CircularArraySARTSTraces(capacity = 1000, state = Float32 => (4,), action = Float32 => (1,)),
      MetaSampler(
          actor = MultiBatchSampler(BatchSampler{(:state,)}(32), 10),
          critic = MultiBatchSampler(BatchSampler{SS′ART}(32), 1000)
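
For orientation, a hypothetical completion of this snippet is sketched below; the closing of the `Trajectory` call and the `Agent` wrapping are assumptions based on the surrounding MPO.md text, not part of this diff.

```julia
trajectory = Trajectory(
    CircularArraySARTSTraces(capacity = 1000, state = Float32 => (4,), action = Float32 => (1,)),
    MetaSampler(
        actor = MultiBatchSampler(BatchSampler{(:state,)}(32), 10),
        critic = MultiBatchSampler(BatchSampler{SS′ART}(32), 1000),
    ),
)

# `policy` here stands for the MPO policy built earlier in that document.
agent = Agent(policy = policy, trajectory = trajectory)
```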

src/ReinforcementLearningCore/Project.toml

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
  name = "ReinforcementLearningCore"
  uuid = "de1b191a-4ae0-4afa-a27b-92d07f46b2d6"
- version = "0.11.3"
+ version = "0.12.0"

  [deps]
  AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
@@ -40,7 +40,7 @@ Parsers = "2"
  ProgressMeter = "1"
  Reexport = "1"
  ReinforcementLearningBase = "0.12"
- ReinforcementLearningTrajectories = "^0.1.9"
+ ReinforcementLearningTrajectories = "^0.3.2"
  StatsBase = "0.32, 0.33, 0.34"
  TimerOutputs = "0.5"
  UnicodePlots = "1.3, 2, 3"

src/ReinforcementLearningCore/src/ReinforcementLearningCore.jl

Lines changed: 0 additions & 1 deletion
@@ -3,7 +3,6 @@ module ReinforcementLearningCore
  using TimerOutputs
  using ReinforcementLearningBase
  using Reexport
-
  const RLCore = ReinforcementLearningCore

  export RLCore

src/ReinforcementLearningCore/src/core/run.jl

Lines changed: 2 additions & 6 deletions
@@ -102,21 +102,17 @@ function _run(policy::AbstractPolicy,
      action = @timeit_debug timer "plan!" RLBase.plan!(policy, env)
      @timeit_debug timer "act!" act!(env, action)

-     @timeit_debug timer "push!(policy) PostActStage" push!(policy, PostActStage(), env)
+     @timeit_debug timer "push!(policy) PostActStage" push!(policy, PostActStage(), env, action)
      @timeit_debug timer "optimise! PostActStage" optimise!(policy, PostActStage())
      @timeit_debug timer "push!(hook) PostActStage" push!(hook, PostActStage(), policy, env)

      if check_stop(stop_condition, policy, env)
          is_stop = true
-         @timeit_debug timer "push!(policy) PreActStage" push!(policy, PreActStage(), env)
-         @timeit_debug timer "optimise! PreActStage" optimise!(policy, PreActStage())
-         @timeit_debug timer "push!(hook) PreActStage" push!(hook, PreActStage(), policy, env)
-         @timeit_debug timer "plan!" RLBase.plan!(policy, env) # let the policy see the last observation
          break
      end
  end # end of an episode

- @timeit_debug timer "push!(policy) PostEpisodeStage" push!(policy, PostEpisodeStage(), env) # let the policy see the last observation
+ @timeit_debug timer "push!(policy) PostEpisodeStage" push!(policy, PostEpisodeStage(), env)
  @timeit_debug timer "optimise! PostEpisodeStage" optimise!(policy, PostEpisodeStage())
  @timeit_debug timer "push!(hook) PostEpisodeStage" push!(hook, PostEpisodeStage(), policy, env)
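
Put together, the reworked step loop reduces to roughly the sketch below. It is a simplification of the diff above, assuming `policy`, `env`, `hook`, `stop_condition`, and `is_stop` are in scope as in `_run`; the `@timeit_debug` instrumentation, the loop/reset condition, and the PreActStage handling (which are outside this hunk) are elided.

```julia
# Simplified sketch of the reworked step loop; instrumentation omitted.
while true
    # ... PreActStage handling happens here (not shown in this hunk) ...

    action = RLBase.plan!(policy, env)
    act!(env, action)

    # The executed action is now forwarded to the policy at PostActStage, so an
    # Agent can record (state, action, reward, terminal) with a single push!.
    push!(policy, PostActStage(), env, action)
    optimise!(policy, PostActStage())
    push!(hook, PostActStage(), policy, env)

    # When the stop condition fires, the loop now simply breaks; the extra
    # PreActStage/plan! round trip that used to run here has been removed.
    if check_stop(stop_condition, policy, env)
        is_stop = true
        break
    end
end
```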

src/ReinforcementLearningCore/src/core/stages.jl

Lines changed: 2 additions & 0 deletions
@@ -17,7 +17,9 @@ struct PreActStage <: AbstractStage end
  struct PostActStage <: AbstractStage end

  Base.push!(p::AbstractPolicy, ::AbstractStage, ::AbstractEnv) = nothing
+ Base.push!(p::AbstractPolicy, ::PostActStage, ::AbstractEnv, action) = nothing
  Base.push!(p::AbstractPolicy, ::AbstractStage, ::AbstractEnv, ::Symbol) = nothing
+ Base.push!(p::AbstractPolicy, ::PostActStage, ::AbstractEnv, action, ::Symbol) = nothing

  RLBase.optimise!(policy::P, ::S) where {P<:AbstractPolicy,S<:AbstractStage} = nothing
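
A downstream policy that wants to see the executed action can extend the new fallback; a minimal sketch follows, where the `CountingPolicy` type is purely illustrative and not part of the package.

```julia
# Illustrative only: a toy policy that counts the actions it has executed.
mutable struct CountingPolicy <: AbstractPolicy
    n_actions::Int
end

# Extend the new PostActStage fallback: both the env (after acting) and the
# executed action are available here.
function Base.push!(p::CountingPolicy, ::PostActStage, env::AbstractEnv, action)
    p.n_actions += 1
end
```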

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
- include("base.jl")
+ include("agent_base.jl")
  include("agent_srt_cache.jl")
  include("multi_agent.jl")
Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
+ export Agent
+
+ using Base.Threads: @spawn
+
+ using Functors: @functor
+ import Base.push!
+
+ """
+     Agent(; policy, trajectory) <: AbstractPolicy
+
+ A wrapper of an `AbstractPolicy`. Generally speaking, it does nothing but
+ update the trajectory and policy appropriately at the different stages. An
+ `Agent` is callable, and its call method accepts varargs and keyword arguments
+ that are forwarded to the policy.
+ """
+ mutable struct Agent{P,T} <: AbstractPolicy
+     policy::P
+     trajectory::T
+
+     function Agent(policy::P, trajectory::T) where {P<:AbstractPolicy, T<:Trajectory}
+         agent = new{P,T}(policy, trajectory)
+
+         if TrajectoryStyle(trajectory) === AsyncTrajectoryStyle()
+             bind(trajectory, @spawn(optimise!(policy, trajectory)))
+         end
+         agent
+     end
+ end
+
+ Agent(; policy, trajectory) = Agent(policy, trajectory)
+
+ RLBase.optimise!(agent::Agent, stage::S) where {S<:AbstractStage} = RLBase.optimise!(TrajectoryStyle(agent.trajectory), agent, stage)
+ RLBase.optimise!(::SyncTrajectoryStyle, agent::Agent, stage::S) where {S<:AbstractStage} = RLBase.optimise!(agent.policy, stage, agent.trajectory)
+
+ # A task to optimise the inner policy is already spawned when the agent is initialized.
+ RLBase.optimise!(::AsyncTrajectoryStyle, agent::Agent, stage::S) where {S<:AbstractStage} = nothing
+
+ # By default, optimise! does nothing at any stage.
+ function RLBase.optimise!(policy::AbstractPolicy, stage::AbstractStage, trajectory::Trajectory) end
+
+ @functor Agent (policy,)
+
+ function Base.push!(agent::Agent, ::PreEpisodeStage, env::AbstractEnv)
+     push!(agent.trajectory, (state = state(env),))
+ end
+
+ # !!! TODO: In async scenarios, parameters of the policy may still be updating
+ # (partially), which will result in an incorrect action. This should be addressed
+ # in Oolong.jl with a wrapper.
+ function RLBase.plan!(agent::Agent, env::AbstractEnv)
+     RLBase.plan!(agent.policy, env)
+ end
+
+ function Base.push!(agent::Agent, ::PostActStage, env::AbstractEnv, action)
+     next_state = state(env)
+     push!(agent.trajectory, (state = next_state, action = action, reward = reward(env), terminal = is_terminated(env)))
+ end
+
+ function Base.push!(agent::Agent, ::PostEpisodeStage, env::AbstractEnv)
+     if haskey(agent.trajectory, :next_action)
+         action = RLBase.plan!(agent.policy, env)
+         push!(agent.trajectory, PartialNamedTuple((action = action,)))
+     end
+ end
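
To see how the pieces fit, wiring the new `Agent` to a trajectory might look like the sketch below. The trace sizes, sampler, controller, and the use of `RandomPolicy` are illustrative assumptions, not part of this diff.

```julia
using ReinforcementLearningCore
using ReinforcementLearningTrajectories

# Illustrative wiring only; concrete sizes and sampler settings are assumptions.
trajectory = Trajectory(
    CircularArraySARTSTraces(; capacity = 1_000, state = Float32 => (4,), action = Int => ()),
    BatchSampler(32),
    InsertSampleRatioController(),
)

agent = Agent(policy = RandomPolicy(), trajectory = trajectory)
# During a run, the agent pushes the initial state at PreEpisodeStage and the
# (state, action, reward, terminal) tuple at PostActStage, as defined above.
```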

src/ReinforcementLearningCore/src/policies/agent/agent_srt_cache.jl

Lines changed: 2 additions & 2 deletions
@@ -27,12 +27,12 @@ struct SART{S,A,R,T}
  end

  # This method is used to push a state and action to a trace
- function Base.push!(ts::Union{CircularArraySARTTraces,ElasticArraySARTTraces}, xs::SA)
+ function Base.push!(ts::Union{CircularArraySARTSTraces,ElasticArraySARTTraces}, xs::SA)
      push!(ts.traces[1].trace, xs.state)
      push!(ts.traces[2].trace, xs.action)
  end

- function Base.push!(ts::Union{CircularArraySARTTraces,ElasticArraySARTTraces}, xs::SART)
+ function Base.push!(ts::Union{CircularArraySARTSTraces,ElasticArraySARTTraces}, xs::SART)
      push!(ts.traces[1].trace, xs.state)
      push!(ts.traces[2].trace, xs.action)
      push!(ts.traces[3], xs.reward)

src/ReinforcementLearningCore/src/policies/agent/base.jl

Lines changed: 0 additions & 89 deletions
This file was deleted.

src/ReinforcementLearningCore/src/policies/agent/multi_agent.jl

Lines changed: 33 additions & 16 deletions
@@ -125,18 +125,12 @@ function Base.run(
      action = @timeit_debug timer "plan!" RLBase.plan!(policy, env)
      @timeit_debug timer "act!" act!(env, action)

-
-
-     @timeit_debug timer "push!(policy) PostActStage" push!(policy, PostActStage(), env)
+     @timeit_debug timer "push!(policy) PostActStage" push!(policy, PostActStage(), env, action)
      @timeit_debug timer "optimise! PostActStage" optimise!(policy, PostActStage())
      @timeit_debug timer "push!(hook) PostActStage" push!(hook, PostActStage(), policy, env)

      if check_stop(stop_condition, policy, env)
          is_stop = true
-         @timeit_debug timer "push!(policy) PreActStage" push!(multiagent_policy, PreActStage(), env)
-         @timeit_debug timer "optimise! PreActStage" optimise!(multiagent_policy, PreActStage())
-         @timeit_debug timer "push!(hook) PreActStage" push!(multiagent_hook, PreActStage(), policy, env)
-         @timeit_debug timer "plan!" RLBase.plan!(multiagent_policy, env) # let the policy see the last observation
          break
      end

@@ -191,21 +185,43 @@ function Base.push!(multiagent::MultiAgentPolicy, stage::S, env::E) where {S<:Ab
  end
  end

- # Like in the single-agent case, push! at the PreActStage() calls push! on each player with the state of the environment
- function Base.push!(multiagent::MultiAgentPolicy{names, T}, ::PreActStage, env::E) where {E<:AbstractEnv, names, T <: Agent}
+ # Like in the single-agent case, push! at the PostActStage() calls push! on each player.
+ function Base.push!(agent::Agent, ::PreEpisodeStage, env::AbstractEnv, player::Symbol)
+     push!(agent.trajectory, (state = state(env, player),))
+ end
+
+ function Base.push!(multiagent::MultiAgentPolicy, s::PreEpisodeStage, env::E) where {E<:AbstractEnv}
      for player in players(env)
-         push!(multiagent[player], state(env, player))
+         push!(multiagent[player], s, env, player)
      end
  end

- # Like in the single-agent case, push! at the PostActStage() calls push! on each player with the reward and termination status of the environment
- function Base.push!(multiagent::MultiAgentPolicy{names, T}, ::PostActStage, env::E) where {E<:AbstractEnv, names, T <: Agent}
-     for player in players(env)
-         push!(multiagent[player].cache, reward(env, player), is_terminated(env))
+ function RLBase.plan!(agent::Agent, env::AbstractEnv, player::Symbol)
+     RLBase.plan!(agent.policy, env, player)
+ end
+
+ # Like in the single-agent case, push! at the PostActStage() calls push! on each player to store the action, reward, next_state, and terminal signal.
+ function Base.push!(multiagent::MultiAgentPolicy, ::PostActStage, env::E, actions) where {E<:AbstractEnv}
+     for (player, action) in zip(players(env), actions)
+         next_state = state(env, player)
+         observation = (
+             state = next_state,
+             action = action,
+             reward = reward(env, player),
+             terminal = is_terminated(env)
+         )
+         push!(multiagent[player].trajectory, observation)
+     end
+ end
+
+ function Base.push!(agent::Agent, ::PostEpisodeStage, env::AbstractEnv, p::Symbol)
+     if haskey(agent.trajectory, :next_action)
+         action = RLBase.plan!(agent.policy, env, p)
+         push!(agent.trajectory, PartialNamedTuple((action = action,)))
      end
  end

- function Base.push!(hook::MultiAgentHook, stage::S, multiagent::MultiAgentPolicy, env::E) where {E<:AbstractEnv,S<:AbstractStage}
+ function Base.push!(hook::MultiAgentHook, stage::S, multiagent::MultiAgentPolicy, env::E) where {E<:AbstractEnv, S<:AbstractStage}
      for player in players(env)
          push!(hook[player], stage, multiagent[player], env, player)
      end
227243
_push!(stage, policy, env, player, composed_hook.hooks...)
228244
end
229245

246+
#For simultaneous players, plan! returns a Tuple of actions.
230247
function RLBase.plan!(multiagent::MultiAgentPolicy, env::E) where {E<:AbstractEnv}
231-
return (RLBase.plan!(multiagent[player], env, player) for player in players(env))
248+
return Tuple(RLBase.plan!(multiagent[player], env, player) for player in players(env))
232249
end
233250

234251
function RLBase.optimise!(multiagent::MultiAgentPolicy, stage::S) where {S<:AbstractStage}

0 commit comments

Comments
 (0)