Skip to content

Commit 84a22d4

Browse files
committed
wip: cleanup NonLinMPC optim functions
1 parent 8e94d75 commit 84a22d4

File tree

3 files changed

+133
-111
lines changed

3 files changed

+133
-111
lines changed

src/controller/nonlinmpc.jl

Lines changed: 53 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -554,29 +554,27 @@ function init_optimization!(mpc::NonLinMPC, model::SimModel, optim::JuMP.Generic
554554
end
555555
end
556556
validate_backends(mpc.gradient, mpc.hessian)
557-
Jfunc, ∇Jfunc!, ∇²Jfunc!, gfuncs, ∇gfuncs!, geqfuncs, ∇geqfuncs! = get_optim_functions(
558-
mpc, optim
559-
)
560-
Jargs = isnothing(∇²Jfunc!) ? (Jfunc, ∇Jfunc!) : (Jfunc, ∇Jfunc!, ∇²Jfunc!)
561-
@operator(optim, J, nZ̃, Jargs...)
557+
J_args, g_vec_args, geq_vec_args = get_optim_functions(mpc, optim)
558+
#display(J_args)
559+
@operator(optim, J, nZ̃, J_args...)
562560
@objective(optim, Min, J(Z̃var...))
563-
init_nonlincon!(mpc, model, transcription, gfuncs, ∇gfuncs!, geqfuncs, ∇geqfuncs!)
561+
init_nonlincon!(mpc, model, transcription, g_vec_args, geq_vec_args)
564562
set_nonlincon!(mpc, model, transcription, optim)
565563
return nothing
566564
end
567565

568566
"""
569567
get_optim_functions(
570568
mpc::NonLinMPC, optim::JuMP.GenericModel
571-
) -> Jfunc, ∇Jfunc!, ∇J²Jfunc!, gfuncs, ∇gfuncs!, geqfuncs, ∇geqfuncs!
569+
) -> J_args, g_vec_args, geq_vec_args
572570
573571
Return the functions for the nonlinear optimization of `mpc` [`NonLinMPC`](@ref) controller.
574-
575-
Return the nonlinear objective `Jfunc` function, and `∇Jfunc!` and `∇²Jfunc!`, to compute
576-
its gradient and hessian, respectively. Also return vectors with the nonlinear inequality
577-
constraint functions `gfuncs`, and `∇gfuncs!`, for the associated gradients. Lastly, also
578-
return vectors with the nonlinear equality constraint functions `geqfuncs` and gradients
579-
`∇geqfuncs!`.
572+
573+
Return the tuple `J_args` containing the functions to compute the objective function
574+
value and its derivatives. Also return the tuple `g_vec_args` containing 2 vectors of
575+
functions to compute the nonlinear inequality values and associated gradients. Lastly, also
576+
return `geq_vec_args` containing 2 vectors of functions to compute the nonlinear equality
577+
values and associated gradients.
580578
581579
This method is really intricate and I'm not proud of it. That's because of 3 elements:
582580
@@ -630,35 +628,53 @@ function get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT
630628
end
631629
if !isnothing(hess)
632630
prep_∇²J = prepare_hessian(Jfunc!, hess, Z̃_J, context_J...; strict)
633-
@warn "Here's the objective Hessian sparsity pattern:"
634631
display(sparsity_pattern(prep_∇²J))
635632
else
636633
prep_∇²J = nothing
637634
end
638635
∇J = Vector{JNT}(undef, nZ̃)
639636
∇²J = init_diffmat(JNT, hess, prep_∇²J, nZ̃, nZ̃)
640637
function Jfunc(Z̃arg::Vararg{T, N}) where {N, T<:Real}
641-
update_diff_objective!(
638+
update_memoized_diff!(
642639
Z̃_J, J, ∇J, ∇²J, prep_∇J, prep_∇²J, context_J, grad, hess, Jfunc!, Z̃arg
643640
)
644641
return J[]::T
645642
end
646-
∇Jfunc! = if nZ̃ == 1 # univariate syntax (see JuMP.@operator doc):
643+
∇Jfunc! = if nZ̃ == 1 # univariate syntax (see JuMP.@operator doc):
647644
function (Z̃arg)
648-
update_diff_objective!(
645+
update_memoized_diff!(
649646
Z̃_J, J, ∇J, ∇²J, prep_∇J, prep_∇²J, context_J, grad, hess, Jfunc!, Z̃arg
650647
)
651648
return ∇J[begin]
652649
end
653-
else # multivariate syntax (see JuMP.@operator doc):
650+
else # multivariate syntax (see JuMP.@operator doc):
654651
function (∇Jarg::AbstractVector{T}, Z̃arg::Vararg{T, N}) where {N, T<:Real}
655-
update_diff_objective!(
652+
update_memoized_diff!(
656653
Z̃_J, J, ∇J, ∇²J, prep_∇J, prep_∇²J, context_J, grad, hess, Jfunc!, Z̃arg
657654
)
658655
return ∇Jarg .= ∇J
659656
end
660657
end
661-
∇²Jfunc! = nothing
658+
∇²Jfunc! = if nZ̃ == 1 # univariate syntax (see JuMP.@operator doc):
659+
function (Z̃arg)
660+
update_memoized_diff!(
661+
Z̃_J, J, ∇J, ∇²J, prep_∇J, prep_∇²J, context_J, grad, hess, Jfunc!, Z̃arg
662+
)
663+
return ∇²J[begin, begin]
664+
end
665+
else # multivariate syntax (see JuMP.@operator doc):
666+
function (∇²Jarg::AbstractMatrix{T}, Z̃arg::Vararg{T, N}) where {N, T<:Real}
667+
print("d")
668+
update_memoized_diff!(
669+
Z̃_J, J, ∇J, ∇²J, prep_∇J, prep_∇²J, context_J, grad, hess, Jfunc!, Z̃arg
670+
)
671+
for i in 1:N, j in 1:i
672+
∇²Jarg[i, j] = ∇²J[i, j]
673+
end
674+
return ∇²Jarg
675+
end
676+
end
677+
J_args = isnothing(hess) ? (Jfunc, ∇Jfunc!) : (Jfunc, ∇Jfunc!, ∇²Jfunc!)
662678
# --------------------- inequality constraint functions -------------------------------
663679
function gfunc!(g, Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, geq)
664680
update_predictions!(ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq, mpc, Z̃)
@@ -672,19 +688,13 @@ function get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT
672688
)
673689
# temporarily enable all the inequality constraints for sparsity detection:
674690
mpc.con.i_g[1:end-nc] .= true
675-
∇g_prep = prepare_jacobian(gfunc!, g, jac, Z̃_g, context_g...; strict)
691+
prep_∇g = prepare_jacobian(gfunc!, g, jac, Z̃_g, context_g...; strict)
676692
mpc.con.i_g[1:end-nc] .= false
677-
∇g = init_diffmat(JNT, jac, ∇g_prep, nZ̃, ng)
678-
function update_con!(g, ∇g, Z̃, Z̃arg)
679-
if isdifferent(Z̃arg, Z̃)
680-
Z̃ .= Z̃arg
681-
value_and_jacobian!(gfunc!, g, ∇g, ∇g_prep, jac, Z̃, context_g...)
682-
end
683-
end
693+
∇g = init_diffmat(JNT, jac, prep_∇g, nZ̃, ng)
684694
gfuncs = Vector{Function}(undef, ng)
685695
for i in eachindex(gfuncs)
686696
gfunc_i = function (Z̃arg::Vararg{T, N}) where {N, T<:Real}
687-
update_con!(g, ∇g, Z̃_g, Z̃arg)
697+
update_memoized_diff!(Z̃_g, g, ∇g, prep_∇g, context_g, jac, gfunc!, Z̃arg)
688698
return g[i]::T
689699
end
690700
gfuncs[i] = gfunc_i
@@ -693,17 +703,18 @@ function get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT
693703
for i in eachindex(∇gfuncs!)
694704
∇gfuncs_i! = if nZ̃ == 1 # univariate syntax (see JuMP.@operator doc):
695705
function (Z̃arg::T) where T<:Real
696-
update_con!(g, ∇g, Z̃_g, Z̃arg)
706+
update_memoized_diff!(Z̃_g, g, ∇g, prep_∇g, context_g, jac, gfunc!, Z̃arg)
697707
return ∇g[i, begin]
698708
end
699-
else # multivariate syntax (see JuMP.@operator doc):
709+
else # multivariate syntax (see JuMP.@operator doc):
700710
function (∇g_i, Z̃arg::Vararg{T, N}) where {N, T<:Real}
701-
update_con!(g, ∇g, Z̃_g, Z̃arg)
711+
update_memoized_diff!(Z̃_g, g, ∇g, prep_∇g, context_g, jac, gfunc!, Z̃arg)
702712
return ∇g_i .= @views ∇g[i, :]
703713
end
704714
end
705715
∇gfuncs![i] = ∇gfuncs_i!
706716
end
717+
g_vec_args = (gfuncs, ∇gfuncs!)
707718
# --------------------- equality constraint functions ---------------------------------
708719
function geqfunc!(geq, Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g)
709720
update_predictions!(ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq, mpc, Z̃)
@@ -715,18 +726,14 @@ function get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT
715726
Cache(Û0), Cache(K0), Cache(X̂0),
716727
Cache(gc), Cache(g)
717728
)
718-
∇geq_prep = prepare_jacobian(geqfunc!, geq, jac, Z̃_geq, context_geq...; strict)
719-
∇geq = init_diffmat(JNT, jac, ∇geq_prep, nZ̃, neq)
720-
function update_con_eq!(geq, ∇geq, Z̃, Z̃arg)
721-
if isdifferent(Z̃arg, Z̃)
722-
Z̃ .= Z̃arg
723-
value_and_jacobian!(geqfunc!, geq, ∇geq, ∇geq_prep, jac, Z̃, context_geq...)
724-
end
725-
end
729+
prep_∇geq = prepare_jacobian(geqfunc!, geq, jac, Z̃_geq, context_geq...; strict)
730+
∇geq = init_diffmat(JNT, jac, prep_∇geq, nZ̃, neq)
726731
geqfuncs = Vector{Function}(undef, neq)
727732
for i in eachindex(geqfuncs)
728733
geqfunc_i = function (Z̃arg::Vararg{T, N}) where {N, T<:Real}
729-
update_con_eq!(geq, ∇geq, Z̃_geq, Z̃arg)
734+
update_memoized_diff!(
735+
Z̃_geq, geq, ∇geq, prep_∇geq, context_geq, jac, geqfunc!, Z̃arg
736+
)
730737
return geq[i]::T
731738
end
732739
geqfuncs[i] = geqfunc_i
@@ -737,12 +744,15 @@ function get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT
737744
# constraints imply MultipleShooting, thus input increment ΔU and state X̂0 in Z̃:
738745
∇geqfuncs_i! =
739746
function (∇geq_i, Z̃arg::Vararg{T, N}) where {N, T<:Real}
740-
update_con_eq!(geq, ∇geq, Z̃_geq, Z̃arg)
747+
update_memoized_diff!(
748+
Z̃_geq, geq, ∇geq, prep_∇geq, context_geq, jac, geqfunc!, Z̃arg
749+
)
741750
return ∇geq_i .= @views ∇geq[i, :]
742751
end
743752
∇geqfuncs![i] = ∇geqfuncs_i!
744753
end
745-
return Jfunc, ∇Jfunc!, ∇²Jfunc!, gfuncs, ∇gfuncs!, geqfuncs, ∇geqfuncs!
754+
geq_vec_args = (geqfuncs, ∇geqfuncs!)
755+
return J_args, g_vec_args, geq_vec_args
746756
end
747757

748758
"""
@@ -770,52 +780,6 @@ function update_predictions!(
770780
return nothing
771781
end
772782

773-
"""
774-
update_diff_objective!(
775-
Z̃_J, J, ∇J, ∇²J, prep_∇J, prep_∇²J , context_J,
776-
grad::AbstractADType, hess::Nothing, Jfunc!, Z̃arg
777-
)
778-
779-
TBW
780-
"""
781-
function update_diff_objective!(
782-
Z̃_J, J, ∇J, ∇²J, prep_∇J, _ , context_J,
783-
grad::AbstractADType, hess::Nothing, Jfunc!::F, Z̃arg
784-
) where F <: Function
785-
if isdifferent(Z̃arg, Z̃_J)
786-
Z̃_J .= Z̃arg
787-
J[], _ = value_and_gradient!(Jfunc!, ∇J, prep_∇J, grad, Z̃_J, context...)
788-
end
789-
return nothing
790-
end
791-
792-
function update_diff_objective!(
793-
Z̃_J, J, ∇J, ∇²J, _ , prep_∇²J, context_J,
794-
grad::Nothing, hess::AbstractADType, Jfunc!::F, Z̃arg
795-
) where F <: Function
796-
if isdifferent(Z̃arg, Z̃_J)
797-
Z̃_J .= Z̃arg
798-
J[], _ = value_gradient_and_hessian!(
799-
Jfunc!, ∇J, ∇²J, prep_∇²J, hess, Z̃_J, context_J...
800-
)
801-
@warn "Uncomment the following line to print the current Hessian"
802-
# println(∇²J)
803-
end
804-
return nothing
805-
end
806-
807-
function update_diff_objective!(
808-
Z̃_J, J, ∇J, ∇²J, prep_∇J, prep_∇²J, context_J,
809-
grad::AbstractADType, hess::AbstractADType, Jfunc!::F, Z̃arg
810-
) where F<: Function
811-
if isdifferent(Z̃arg, Z̃_J)
812-
Z̃_J .= Z̃arg # inefficient, as warned by validate_backends(), but still possible:
813-
hessian!(Jfunc!, ∇²J, prep_∇²J, hess, Z̃_J, context_J...)
814-
J[], _ = value_and_gradient!(Jfunc!, ∇J, prep_∇J, grad, Z̃_J, context_J...)
815-
end
816-
return nothing
817-
end
818-
819783
@doc raw"""
820784
con_custom!(gc, mpc::NonLinMPC, Ue, Ŷe, ϵ) -> gc
821785

src/controller/transcription.jl

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -604,21 +604,18 @@ end
604604

605605
"""
606606
init_nonlincon!(
607-
mpc::PredictiveController, model::LinModel, transcription::TranscriptionMethod,
608-
gfuncs , ∇gfuncs!,
609-
geqfuncs, ∇geqfuncs!
610-
)
607+
mpc::PredictiveController, ::LinModel, ::TranscriptionMethod, g_vec_args, geq_vec_args
608+
) -> nothing
611609
612610
Init nonlinear constraints for [`LinModel`](@ref) for all [`TranscriptionMethod`](@ref).
613611
614612
The only nonlinear constraints are the custom inequality constraints `gc`.
615613
"""
616614
function init_nonlincon!(
617-
mpc::PredictiveController, ::LinModel, ::TranscriptionMethod,
618-
gfuncs, ∇gfuncs!,
619-
_ , _
615+
mpc::PredictiveController, ::LinModel, ::TranscriptionMethod, g_vec_args, _
620616
)
621617
optim, con = mpc.optim, mpc.con
618+
gfuncs, ∇gfuncs! = g_vec_args
622619
nZ̃ = length(mpc.Z̃)
623620
if length(con.i_g) 0
624621
i_base = 0
@@ -634,22 +631,20 @@ end
634631

635632
"""
636633
init_nonlincon!(
637-
mpc::PredictiveController, model::NonLinModel, transcription::MultipleShooting,
638-
gfuncs, ∇gfuncs!,
639-
geqfuncs, ∇geqfuncs!
640-
)
634+
mpc::PredictiveController, ::NonLinModel, ::MultipleShooting, g_vec_args, geq_vec_args
635+
) -> nothing
641636
642637
Init nonlinear constraints for [`NonLinModel`](@ref) and [`MultipleShooting`](@ref).
643638
644639
The nonlinear constraints are the output prediction `Ŷ` bounds, the custom inequality
645640
constraints `gc` and all the nonlinear equality constraints `geq`.
646641
"""
647642
function init_nonlincon!(
648-
mpc::PredictiveController, ::NonLinModel, ::MultipleShooting,
649-
gfuncs, ∇gfuncs!,
650-
geqfuncs, ∇geqfuncs!
643+
mpc::PredictiveController, ::NonLinModel, ::MultipleShooting, g_vec_args, geq_vec_args
651644
)
652645
optim, con = mpc.optim, mpc.con
646+
gfuncs , ∇gfuncs! = g_vec_args
647+
geqfuncs, ∇geqfuncs! = geq_vec_args
653648
ny, nx̂, Hp, nZ̃ = mpc.estim.model.ny, mpc.estim.nx̂, mpc.Hp, length(mpc.Z̃)
654649
# --- nonlinear inequality constraints ---
655650
if length(con.i_g) 0
@@ -691,20 +686,19 @@ end
691686

692687
"""
693688
init_nonlincon!(
694-
mpc::PredictiveController, model::NonLinModel, ::SingleShooting,
695-
gfuncs, ∇gfuncs!,
696-
geqfuncs, ∇geqfuncs!
697-
)
689+
mpc::PredictiveController, ::NonLinModel, ::SingleShooting, g_vec_args, geq_vec_args
690+
) -> nothing
698691
699692
Init nonlinear constraints for [`NonLinModel`](@ref) and [`SingleShooting`](@ref).
700693
701694
The nonlinear constraints are the custom inequality constraints `gc`, the output
702695
prediction `Ŷ` bounds and the terminal state `x̂end` bounds.
703696
"""
704697
function init_nonlincon!(
705-
mpc::PredictiveController, ::NonLinModel, ::SingleShooting, gfuncs, ∇gfuncs!, _ , _
698+
mpc::PredictiveController, ::NonLinModel, ::SingleShooting, g_vec_args, _
706699
)
707700
optim, con = mpc.optim, mpc.con
701+
gfuncs, ∇gfuncs! = g_vec_args
708702
ny, nx̂, Hp, nZ̃ = mpc.estim.model.ny, mpc.estim.nx̂, mpc.Hp, length(mpc.Z̃)
709703
if length(con.i_g) 0
710704
i_base = 0

0 commit comments

Comments
 (0)