Skip to content

Commit 8dbef60

Browse files
committed
Improve bang plots
1 parent 673a22f commit 8dbef60

File tree

1 file changed

+66
-34
lines changed

1 file changed

+66
-34
lines changed

src/plotting.jl

Lines changed: 66 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ Renaming and reexport of Plot.jl function `plotlyjs()` to define PlotlyJS.jl as
1818
"""
1919
plotlyjs_backend = StatsPlots.plotlyjs
2020

21+
const _LAST_PLOTS = Ref{Any}(nothing)
22+
23+
_store_last_plots!(p) = (_LAST_PLOTS[] = p)
24+
2125

2226

2327
"""
@@ -53,7 +57,8 @@ If occasionally binding constraints are present in the model, they are not taken
5357
- $MAX_ELEMENTS_PER_LEGENDS_ROW®
5458
- $EXTRA_LEGEND_SPACE®
5559
- $PLOT_ATTRIBUTES®
56-
- `line_labels` [Optional, `Vector{String}`]: legend labels for the shocks.
60+
- `line_label` [Optional, `String`]: legend label for the lines produced by a
61+
single call. When omitted, no legend entry is added.
5762
- $QME®
5863
- $SYLVESTER®
5964
- $LYAPUNOV®
@@ -407,6 +412,7 @@ function plot_model_estimates(𝓂::ℳ,
407412
end
408413
end
409414

415+
_store_last_plots!(return_plots)
410416
return return_plots
411417
end
412418

@@ -499,7 +505,7 @@ function plot_irf(𝓂::ℳ;
499505
initial_state::Union{Vector{Vector{Float64}},Vector{Float64}} = [0.0],
500506
ignore_obc::Bool = false,
501507
plot_attributes::Dict = Dict(),
502-
line_labels::Union{Nothing,Vector{String}} = nothing,
508+
line_label::Union{Nothing,String} = nothing,
503509
verbose::Bool = false,
504510
tol::Tolerances = Tolerances(),
505511
quadratic_matrix_equation_algorithm::Symbol = :schur,
@@ -555,18 +561,16 @@ function plot_irf(𝓂::ℳ;
555561
obc_shocks_included = stochastic_model && obc_model && (intersect((((shock_idx isa Vector) || (shock_idx isa UnitRange)) && (length(shock_idx) > 0)) ? 𝓂.timings.exo[shock_idx] : [𝓂.timings.exo[shock_idx]], 𝓂.timings.exo[contains.(string.(𝓂.timings.exo),"ᵒᵇᶜ")]) != [])
556562
end
557563

558-
if line_labels === nothing
559-
if shocks == :simulate
560-
line_labels = ["simulate all"]
564+
if line_label === nothing
565+
if (shocks isa Union{Symbol_input,String_input}) && (length(shock_idx) == 1)
566+
line_label = replace_indices_in_symbol(𝓂.timings.exo[shock_idx])
567+
elseif shocks == :simulate
568+
line_label = "simulate all"
561569
elseif shocks == :none
562-
line_labels = ["none"]
563-
elseif shocks isa Union{Symbol_input,String_input}
564-
line_labels = replace_indices_in_symbol.(𝓂.timings.exo[shock_idx])
570+
line_label = "none"
565571
else
566-
line_labels = ["shock" * string(i) for i in 1:length(shock_idx)]
572+
line_label = ""
567573
end
568-
else
569-
@assert length(line_labels) == length(shock_idx) "line_labels must match number of shocks"
570574
end
571575

572576
if shocks isa KeyedArray{Float64} || shocks isa Matrix{Float64}
@@ -786,6 +790,7 @@ function plot_irf(𝓂::ℳ;
786790
pp = []
787791
pane = 1
788792
plot_count = 1
793+
label_done = false
789794
for i in 1:length(var_idx)
790795
if all(isapprox.(Y[i,:,shock], 0, atol = eps(Float32)))
791796
n_subplots -= 1
@@ -798,11 +803,12 @@ function plot_irf(𝓂::ℳ;
798803
can_dual_axis = gr_back && all((Y[i,:,shock] .+ SS) .> eps(Float32)) && (SS > eps(Float32))
799804

800805
if !(all(isapprox.(Y[i,:,shock],0,atol = eps(Float32))))
806+
label_here = (!label_done && line_label != "") ? line_label : ""
801807
push!(pp,begin
802808
StatsPlots.plot(Y[i,:,shock] .+ SS,
803809
title = replace_indices_in_symbol(𝓂.timings.var[var_idx[i]]),
804810
ylabel = "Level",
805-
label = line_labels[shock])
811+
label = label_here)
806812

807813
if can_dual_axis
808814
StatsPlots.plot!(StatsPlots.twinx(),
@@ -850,6 +856,7 @@ function plot_irf(𝓂::ℳ;
850856
pane += 1
851857

852858
pp = []
859+
label_done = false
853860
end
854861
end
855862
end
@@ -880,8 +887,10 @@ function plot_irf(𝓂::ℳ;
880887
if save_plots
881888
StatsPlots.savefig(p, save_plots_path * "/irf__" * 𝓂.model_name * "__" * shock_name * "__" * string(pane) * "." * string(save_plots_format))
882889
end
890+
label_done = false
883891
end
884892
end
893+
_store_last_plots!(return_plots)
885894

886895
return return_plots
887896
end
@@ -1123,6 +1132,7 @@ function plot_conditional_variance_decomposition(𝓂::ℳ;
11231132
StatsPlots.savefig(p, save_plots_path * "/fevd__" * 𝓂.model_name * "__" * string(pane) * "." * string(save_plots_format))
11241133
end
11251134
end
1135+
_store_last_plots!(return_plots)
11261136

11271137
return return_plots
11281138
end
@@ -1494,6 +1504,7 @@ function plot_solution(𝓂::ℳ,
14941504
end
14951505
end
14961506

1507+
_store_last_plots!(return_plots)
14971508
return return_plots
14981509
end
14991510

@@ -1866,6 +1877,7 @@ function plot_conditional_forecast(𝓂::ℳ,
18661877
StatsPlots.savefig(p, save_plots_path * "/conditional_forecast__" * 𝓂.model_name * "__" * string(pane) * "." * string(save_plots_format))
18671878
end
18681879
end
1880+
_store_last_plots!(return_plots)
18691881

18701882
return return_plots
18711883

@@ -1878,8 +1890,8 @@ Add the IRFs produced by [`plot_irf`](@ref) to the existing plot or vector of
18781890
plots `p` using `StatsPlots.plot!`.
18791891
18801892
Calling `plot_irf!(args...; kwargs...)` without providing `p` attempts to add
1881-
the lines to the current plot. Additional pages are appended if required and
1882-
legend labels are taken from the shock names. Subplots are matched by title when
1893+
the lines to the previous plot. Additional pages are appended if required and
1894+
legend labels are derived from the `line_label` keyword. Subplots are matched by title when
18831895
merging so the order of variables does not matter. Titles from both plots are
18841896
collected, combined into a sorted list and used to align the panels.
18851897
"""
@@ -1920,13 +1932,17 @@ function plot_irf!(p::Union{StatsPlots.Plot,AbstractVector}, args...; kwargs...)
19201932
else
19211933
p = _merge_plots_by_title(p, q[1])
19221934
end
1935+
_store_last_plots!(p)
19231936
return p
19241937
end
19251938
function plot_irf!(args...; kwargs...)
1926-
p = try
1927-
StatsPlots.current()
1928-
catch
1929-
nothing
1939+
p = _LAST_PLOTS[]
1940+
if p === nothing
1941+
p = try
1942+
StatsPlots.current()
1943+
catch
1944+
nothing
1945+
end
19301946
end
19311947
if p === nothing
19321948
return plot_irf(args...; kwargs...)
@@ -1953,14 +1969,18 @@ function plot_model_estimates!(p::Union{StatsPlots.Plot,AbstractVector}, args...
19531969
else
19541970
StatsPlots.plot!(p, q[1])
19551971
end
1972+
_store_last_plots!(p)
19561973
return p
19571974
end
19581975

19591976
function plot_model_estimates!(args...; kwargs...)
1960-
p = try
1961-
StatsPlots.current()
1962-
catch
1963-
nothing
1977+
p = _LAST_PLOTS[]
1978+
if p === nothing
1979+
p = try
1980+
StatsPlots.current()
1981+
catch
1982+
nothing
1983+
end
19641984
end
19651985
if p === nothing
19661986
return plot_model_estimates(args...; kwargs...)
@@ -1988,14 +2008,18 @@ function plot_conditional_variance_decomposition!(p::Union{StatsPlots.Plot,Abstr
19882008
else
19892009
StatsPlots.plot!(p, q[1])
19902010
end
2011+
_store_last_plots!(p)
19912012
return p
19922013
end
19932014

19942015
function plot_conditional_variance_decomposition!(args...; kwargs...)
1995-
p = try
1996-
StatsPlots.current()
1997-
catch
1998-
nothing
2016+
p = _LAST_PLOTS[]
2017+
if p === nothing
2018+
p = try
2019+
StatsPlots.current()
2020+
catch
2021+
nothing
2022+
end
19992023
end
20002024
if p === nothing
20012025
return plot_conditional_variance_decomposition(args...; kwargs...)
@@ -2020,14 +2044,18 @@ function plot_solution!(p::Union{StatsPlots.Plot,AbstractVector}, args...; kwarg
20202044
else
20212045
StatsPlots.plot!(p, q[1])
20222046
end
2047+
_store_last_plots!(p)
20232048
return p
20242049
end
20252050

20262051
function plot_solution!(args...; kwargs...)
2027-
p = try
2028-
StatsPlots.current()
2029-
catch
2030-
nothing
2052+
p = _LAST_PLOTS[]
2053+
if p === nothing
2054+
p = try
2055+
StatsPlots.current()
2056+
catch
2057+
nothing
2058+
end
20312059
end
20322060
if p === nothing
20332061
return plot_solution(args...; kwargs...)
@@ -2053,14 +2081,18 @@ function plot_conditional_forecast!(p::Union{StatsPlots.Plot,AbstractVector}, ar
20532081
else
20542082
StatsPlots.plot!(p, q[1])
20552083
end
2084+
_store_last_plots!(p)
20562085
return p
20572086
end
20582087

20592088
function plot_conditional_forecast!(args...; kwargs...)
2060-
p = try
2061-
StatsPlots.current()
2062-
catch
2063-
nothing
2089+
p = _LAST_PLOTS[]
2090+
if p === nothing
2091+
p = try
2092+
StatsPlots.current()
2093+
catch
2094+
nothing
2095+
end
20642096
end
20652097
if p === nothing
20662098
return plot_conditional_forecast(args...; kwargs...)

0 commit comments

Comments
 (0)