Skip to content

Commit b9dcd23

Browse files
committed
fixes
1 parent 2942615 commit b9dcd23

File tree

4 files changed

+110
-119
lines changed

4 files changed

+110
-119
lines changed

lib/OrdinaryDiffEqRosenbrock/src/interp_func.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@ function DiffEqBase.interp_summary(::Type{cacheType},
99
"1st order linear"
1010
end
1111

12-
function DiffEqBase.interp_summary(::Type{cacheType},
12+
function DiffEqBase.interp_summary(cache::Type{cacheType},
1313
dense::Bool) where {
1414
cacheType <:
1515
Union{RosenbrockCombinedConstantCache,
1616
RosenbrockCache}}
17-
dense ? "specialized $(cache.interp_order) order \"free\" stiffness-aware interpolation" :
17+
dense ? "specialized ? order \"free\" stiffness-aware interpolation" :
1818
"1st order linear"
1919
end

lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl

Lines changed: 52 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,54 @@ function get_fsalfirstlast(cache::GenericRosenbrockMutableCache, u)
88
(cache.fsalfirst, cache.fsallast)
99
end
1010

11+
mutable struct RosenbrockCache{uType, rateType, uNoUnitsType, JType, WType, TabType,
12+
TFType, UFType, F, JCType, GCType, RTolType, A, AV, StepLimiter, StageLimiter} <: RosenbrockMutableCache
13+
u::uType
14+
uprev::uType
15+
dense::Vector{rateType}
16+
du::rateType
17+
du1::rateType
18+
du2::rateType
19+
ks::Vector{rateType}
20+
fsalfirst::rateType
21+
fsallast::rateType
22+
dT::rateType
23+
J::JType
24+
W::WType
25+
tmp::rateType
26+
atmp::uNoUnitsType
27+
weight::uNoUnitsType
28+
tab::TabType
29+
tf::TFType
30+
uf::UFType
31+
linsolve_tmp::rateType
32+
linsolve::F
33+
jac_config::JCType
34+
grad_config::GCType
35+
reltol::RTolType
36+
alg::A
37+
algebraic_vars::AV
38+
step_limiter!::StepLimiter
39+
stage_limiter!::StageLimiter
40+
interp_order::Int
41+
end
42+
43+
function full_cache(c::RosenbrockCache)
44+
return [c.u, c.uprev, c.dense..., c.du, c.du1, c.du2,
45+
c.ks..., c.fsalfirst, c.fsallast, c.dT, c.tmp, c.atmp, c.weight, c.linsolve_tmp]
46+
end
47+
48+
struct RosenbrockCombinedConstantCache{TF, UF, Tab, JType, WType, F, AD} <: RosenbrockConstantCache
49+
tf::TF
50+
uf::UF
51+
tab::Tab
52+
J::JType
53+
W::WType
54+
linsolve::F
55+
autodiff::AD
56+
interp_order::Int
57+
end
58+
1159
@cache mutable struct Rosenbrock23Cache{uType, rateType, uNoUnitsType, JType, WType,
1260
TabType, TFType, UFType, F, JCType, GCType,
1361
RTolType, A, AV, StepLimiter, StageLimiter} <: RosenbrockMutableCache
@@ -74,6 +122,10 @@ end
74122
stage_limiter!::StageLimiter
75123
end
76124

125+
function get_fsalfirstlast(cache::Union{Rosenbrock23Cache, Rosenbrock32Cache}, u)
126+
(cache.fsalfirst, cache.fsallast)
127+
end
128+
77129
function alg_cache(alg::Rosenbrock23, u, rate_prototype, ::Type{uEltypeNoUnits},
78130
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
79131
dt, reltol, p, calck,
@@ -222,57 +274,6 @@ function alg_cache(alg::Rosenbrock32, u, rate_prototype, ::Type{uEltypeNoUnits},
222274
alg_autodiff(alg))
223275
end
224276

225-
################################################################################
226-
227-
# Shampine's Low-order Rosenbrocks
228-
mutable struct RosenbrockCache{uType, rateType, uNoUnitsType, JType, WType, TabType,
229-
TFType, UFType, F, JCType, GCType, RTolType, A, AV, StepLimiter, StageLimiter} <: RosenbrockMutableCache
230-
u::uType
231-
uprev::uType
232-
dense::Vector{rateType}
233-
du::rateType
234-
du1::rateType
235-
du2::rateType
236-
ks::Vector{rateType}
237-
fsalfirst::rateType
238-
fsallast::rateType
239-
dT::rateType
240-
J::JType
241-
W::WType
242-
tmp::rateType
243-
atmp::uNoUnitsType
244-
weight::uNoUnitsType
245-
tab::TabType
246-
tf::TFType
247-
uf::UFType
248-
linsolve_tmp::rateType
249-
linsolve::F
250-
jac_config::JCType
251-
grad_config::GCType
252-
reltol::RTolType
253-
alg::A
254-
algebraic_vars::AV
255-
step_limiter!::StepLimiter
256-
stage_limiter!::StageLimiter
257-
interp_order::Int
258-
end
259-
260-
function full_cache(c::RosenbrockCache)
261-
return [c.u, c.uprev, c.dense..., c.du, c.du1, c.du2,
262-
c.ks..., c.fsalfirst, c.fsallast, c.dT, c.tmp, c.atmp, c.weight, c.linsolve_tmp]
263-
end
264-
265-
struct RosenbrockCombinedConstantCache{TF, UF, Tab, JType, WType, F, AD} <: RosenbrockConstantCache
266-
tf::TF
267-
uf::UF
268-
tab::Tab
269-
J::JType
270-
W::WType
271-
linsolve::F
272-
autodiff::AD
273-
interp_order::Int
274-
end
275-
276277
@ROS2(:cache)
277278

278279
################################################################################
@@ -296,9 +297,6 @@ jac_cache(c::Rosenbrock4Cache) = (c.J, c.W)
296297

297298
###############################################################################
298299

299-
### Rodas methods
300-
tabtype(::Rosenbrock23) = Rosenbrock23Tableau
301-
tabtype(::Rosenbrock32) = Rosenbrock32Tableau
302300
tabtype(::Rodas23W) = Rodas23WTableau
303301
tabtype(::ROS3P) = ROS3PTableau
304302
tabtype(::Rodas3) = Rodas3Tableau

lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,7 @@ end
434434
@muladd function perform_step!(integrator, cache::RosenbrockCombinedConstantCache, repeat_step = false)
435435
(;t, dt, uprev, u, f, p) = integrator
436436
(;tf, uf) = cache
437-
(;A, C, gamma, c, d, H) = cache.tab
437+
(;A, C, b, btilde, gamma, c, d, H) = cache.tab
438438

439439
# Precalculations
440440
dtC = C ./ dt
@@ -489,10 +489,17 @@ end
489489
integrator.stats.nsolve += 1
490490
end
491491
#@show ks
492-
u = u .+ ks[num_stages]
492+
u = uprev
493+
for i in 1:num_stages
494+
u = @.. u + b[i] * ks[i]
495+
end
493496

494497
if integrator.opts.adaptive
495-
atmp = calculate_residuals(ks[num_stages], uprev, u, integrator.opts.abstol,
498+
utilde = uprev
499+
for i in 1:num_stages
500+
utilde = @.. utilde + btilde[i] * ks[i]
501+
end
502+
atmp = calculate_residuals(utilde, uprev, u, integrator.opts.abstol,
496503
integrator.opts.reltol, integrator.opts.internalnorm, t)
497504
integrator.EEst = integrator.opts.internalnorm(atmp, t)
498505
end
@@ -538,7 +545,7 @@ end
538545
@muladd function perform_step!(integrator, cache::RosenbrockCache, repeat_step = false)
539546
(; t, dt, uprev, u, f, p) = integrator
540547
(; du, du1, du2, dT, J, W, uf, tf, ks, linsolve_tmp, jac_config, atmp, weight, stage_limiter!, step_limiter!) = cache
541-
(; A, C, gamma, c, d, H) = cache.tab
548+
(; A, C, b, btilde, gamma, c, d, H) = cache.tab
542549

543550
# Assignments
544551
sizeu = size(u)
@@ -549,6 +556,7 @@ end
549556
dtC = C .* inv(dt)
550557
dtd = dt .* d
551558
dtgamma = dt * gamma
559+
utilde = du
552560

553561
f(cache.fsalfirst, uprev, p, t)
554562
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
@@ -572,8 +580,8 @@ end
572580

573581
@.. $(_vec(ks[1])) = -linres.u
574582
integrator.stats.nsolve += 1
575-
576-
for stage in 2:length(ks)
583+
num_stages = length(ks)
584+
for stage in 2:num_stages
577585
u .= uprev
578586
for i in 1:(stage - 1)
579587
@.. u += A[stage, i] * ks[i]
@@ -601,19 +609,25 @@ end
601609
@.. $(_vec(ks[stage])) = -linres.u
602610
integrator.stats.nsolve += 1
603611
end
604-
du .= ks[end]
605-
u .+= ks[end]
612+
u .= uprev
613+
for i in 1:num_stages
614+
@.. u += b[i] * ks[i]
615+
end
606616

607617
step_limiter!(u, integrator, p, t + dt)
608618

609619
if integrator.opts.adaptive
620+
utilde .= 0
621+
for i in 1:num_stages
622+
@.. utilde += btilde[i] * ks[i]
623+
end
610624
if (integrator.alg isa Rodas5Pe)
611625
@.. du = 0.2606326497975715 * ks[1] - 0.005158627295444251 * ks[2] +
612626
1.3038988631109731 * ks[3] + 1.235000722062074 * ks[4] +
613627
-0.7931985603795049 * ks[5] - 1.005448461135913 * ks[6] -
614628
0.18044626132120234 * ks[7] + 0.17051519239113755 * ks[8]
615629
end
616-
calculate_residuals!(atmp, ks[end], uprev, u, integrator.opts.abstol,
630+
calculate_residuals!(atmp, utilde, uprev, u, integrator.opts.abstol,
617631
integrator.opts.reltol, integrator.opts.internalnorm, t)
618632
integrator.EEst = integrator.opts.internalnorm(atmp, t)
619633
end

0 commit comments

Comments
 (0)