@@ -554,29 +554,27 @@ function init_optimization!(mpc::NonLinMPC, model::SimModel, optim::JuMP.Generic
end
end
validate_backends(mpc.gradient, mpc.hessian)
- Jfunc, ∇Jfunc!, ∇²Jfunc!, gfuncs, ∇gfuncs!, geqfuncs, ∇geqfuncs! = get_optim_functions(
- mpc, optim
- )
- Jargs = isnothing(∇²Jfunc!) ? (Jfunc, ∇Jfunc!) : (Jfunc, ∇Jfunc!, ∇²Jfunc!)
- @operator(optim, J, nZ̃, Jargs...)
+ J_args, g_vec_args, geq_vec_args = get_optim_functions(mpc, optim)
+ # display(J_args)
+ @operator(optim, J, nZ̃, J_args...)
@objective(optim, Min, J(Z̃var...))
- init_nonlincon!(mpc, model, transcription, gfuncs, ∇gfuncs!, geqfuncs, ∇geqfuncs!)
+ init_nonlincon!(mpc, model, transcription, g_vec_args, geq_vec_args)
set_nonlincon!(mpc, model, transcription, optim)
return nothing
end

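For readers unfamiliar with the pattern, here is a minimal, hypothetical sketch (not taken from this package) of splatting a callback tuple into `JuMP.@operator`, as done with `J_args...` above: the Hessian callback is appended only when second-order information is available. The Rosenbrock objective, the `use_hessian` flag and the names `f`, `∇f!`, `∇²f!`, `op_f` are illustrative assumptions, and the example presumes the Ipopt solver is installed.

```julia
# Hypothetical sketch: splat a tuple of callbacks into JuMP.@operator, mirroring J_args...
using JuMP, Ipopt

model = Model(Ipopt.Optimizer)
@variable(model, x[1:2])

f(x...) = (1 - x[1])^2 + 100*(x[2] - x[1]^2)^2       # objective value (Rosenbrock)
function ∇f!(g, x...)                                 # in-place gradient
    g[1] = -2*(1 - x[1]) - 400*x[1]*(x[2] - x[1]^2)
    g[2] = 200*(x[2] - x[1]^2)
    return g
end
function ∇²f!(H, x...)                                # in-place Hessian, lower triangle only
    H[1, 1] = 2 - 400*x[2] + 1200*x[1]^2
    H[2, 1] = -400*x[1]
    H[2, 2] = 200
    return H
end

use_hessian = true                                    # plays the role of !isnothing(mpc.hessian)
f_args = use_hessian ? (f, ∇f!, ∇²f!) : (f, ∇f!)      # same conditional tuple as J_args
@operator(model, op_f, 2, f_args...)
@objective(model, Min, op_f(x...))
optimize!(model)
```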
"""
get_optim_functions(
mpc::NonLinMPC, optim::JuMP.GenericModel
- ) -> Jfunc, ∇Jfunc!, ∇J²Jfunc!, gfuncs, ∇gfuncs!, geqfuncs, ∇geqfuncs!
+ ) -> J_args, g_vec_args, geq_vec_args

Return the functions for the nonlinear optimization of `mpc` [`NonLinMPC`](@ref) controller.
-
- Return the nonlinear objective `Jfunc` function, and `∇Jfunc!` and `∇²Jfunc!`, to compute
- its gradient and hessian, respectively. Also return vectors with the nonlinear inequality
- constraint functions `gfuncs`, and `∇gfuncs!`, for the associated gradients. Lastly, also
- return vectors with the nonlinear equality constraint functions `geqfuncs` and gradients
- `∇geqfuncs!`.
+
+ Return the tuple `J_args` containing the functions to compute the objective function
+ value and its derivatives. Also return the tuple `g_vec_args` containing two vectors of
+ functions to compute the nonlinear inequality constraint values and their gradients.
+ Lastly, return the tuple `geq_vec_args` containing two vectors of functions to compute
+ the nonlinear equality constraint values and their gradients.

This method is really intricate and I'm not proud of it. That's because of 3 elements:
@@ -630,35 +628,53 @@ function get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT
end
if !isnothing(hess)
prep_∇²J = prepare_hessian(Jfunc!, hess, Z̃_J, context_J...; strict)
- @warn "Here's the objective Hessian sparsity pattern:"
display(sparsity_pattern(prep_∇²J))
else
prep_∇²J = nothing
end
∇J = Vector{JNT}(undef, nZ̃)
∇²J = init_diffmat(JNT, hess, prep_∇²J, nZ̃, nZ̃)
function Jfunc(Z̃arg::Vararg{T, N}) where {N, T<:Real}
- update_diff_objective!(
+ update_memoized_diff!(
Z̃_J, J, ∇J, ∇²J, prep_∇J, prep_∇²J, context_J, grad, hess, Jfunc!, Z̃arg
)
return J[]::T
end
- ∇Jfunc! = if nZ̃ == 1 # univariate syntax (see JuMP.@operator doc):
+ ∇Jfunc! = if nZ̃ == 1 # univariate syntax (see JuMP.@operator doc):
function (Z̃arg)
- update_diff_objective!(
+ update_memoized_diff!(
Z̃_J, J, ∇J, ∇²J, prep_∇J, prep_∇²J, context_J, grad, hess, Jfunc!, Z̃arg
)
return ∇J[begin]
end
- else # multivariate syntax (see JuMP.@operator doc):
+ else # multivariate syntax (see JuMP.@operator doc):
function (∇Jarg::AbstractVector{T}, Z̃arg::Vararg{T, N}) where {N, T<:Real}
- update_diff_objective!(
+ update_memoized_diff!(
Z̃_J, J, ∇J, ∇²J, prep_∇J, prep_∇²J, context_J, grad, hess, Jfunc!, Z̃arg
)
return ∇Jarg .= ∇J
end
end
- ∇²Jfunc! = nothing
+ ∇²Jfunc! = if nZ̃ == 1 # univariate syntax (see JuMP.@operator doc):
+ function (Z̃arg)
+ update_memoized_diff!(
+ Z̃_J, J, ∇J, ∇²J, prep_∇J, prep_∇²J, context_J, grad, hess, Jfunc!, Z̃arg
+ )
+ return ∇²J[begin, begin]
+ end
+ else # multivariate syntax (see JuMP.@operator doc):
+ function (∇²Jarg::AbstractMatrix{T}, Z̃arg::Vararg{T, N}) where {N, T<:Real}
+ print("d")
+ update_memoized_diff!(
+ Z̃_J, J, ∇J, ∇²J, prep_∇J, prep_∇²J, context_J, grad, hess, Jfunc!, Z̃arg
+ )
+ for i in 1:N, j in 1:i
+ ∇²Jarg[i, j] = ∇²J[i, j]
+ end
+ return ∇²Jarg
+ end
+ end
+ J_args = isnothing(hess) ? (Jfunc, ∇Jfunc!) : (Jfunc, ∇Jfunc!, ∇²Jfunc!)
# --------------------- inequality constraint functions -------------------------------
function gfunc!(g, Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, geq)
update_predictions!(ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq, mpc, Z̃)
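Stepping outside the diff for a moment: the objective closures above (`Jfunc`, `∇Jfunc!`, `∇²Jfunc!`) all route through `update_memoized_diff!`, whose definition is not part of this hunk. Below is a hedged, self-contained sketch of the underlying memoization idea with DifferentiationInterface: one `value_and_gradient!` call fills both the value and the gradient, and nothing is recomputed when the incoming argument matches the cached one. The names (`update!`, `x_cache`, `y_cache`, `g_cache`) are illustrative, and the real helper additionally handles `Cache` contexts and an optional Hessian backend.

```julia
# Illustrative memoized value-and-gradient; not the package's actual update_memoized_diff!.
using DifferentiationInterface
import ForwardDiff

f(x) = sum(abs2, x)                          # toy objective
backend = AutoForwardDiff()
prep    = prepare_gradient(f, backend, zeros(3))

x_cache = fill(NaN, 3)                       # last evaluation point (NaN forces first update)
y_cache = Ref(NaN)                           # memoized objective value
g_cache = zeros(3)                           # memoized gradient

function update!(xarg)
    if any(xarg .!= x_cache)                 # recompute only when the argument changed
        x_cache .= xarg
        y_cache[], _ = value_and_gradient!(f, g_cache, prep, backend, x_cache)
    end
    return nothing
end

Jfunc(xarg...)      = (update!(xarg); y_cache[])      # value callback
∇Jfunc!(g, xarg...) = (update!(xarg); g .= g_cache)   # gradient callback, reuses the same evaluation
```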
@@ -672,19 +688,13 @@ function get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT
)
# temporarily enable all the inequality constraints for sparsity detection:
mpc.con.i_g[1:end-nc] .= true
- ∇g_prep = prepare_jacobian(gfunc!, g, jac, Z̃_g, context_g...; strict)
+ prep_∇g = prepare_jacobian(gfunc!, g, jac, Z̃_g, context_g...; strict)
mpc.con.i_g[1:end-nc] .= false
- ∇g = init_diffmat(JNT, jac, ∇g_prep, nZ̃, ng)
- function update_con!(g, ∇g, Z̃, Z̃arg)
- if isdifferent(Z̃arg, Z̃)
- Z̃ .= Z̃arg
- value_and_jacobian!(gfunc!, g, ∇g, ∇g_prep, jac, Z̃, context_g...)
- end
- end
+ ∇g = init_diffmat(JNT, jac, prep_∇g, nZ̃, ng)
gfuncs = Vector{Function}(undef, ng)
for i in eachindex(gfuncs)
gfunc_i = function (Z̃arg::Vararg{T, N}) where {N, T<:Real}
- update_con!(g, ∇g, Z̃_g, Z̃arg)
+ update_memoized_diff!(Z̃_g, g, ∇g, prep_∇g, context_g, jac, gfunc!, Z̃arg)
return g[i]::T
end
gfuncs[i] = gfunc_i
@@ -693,17 +703,18 @@ function get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT
for i in eachindex(∇gfuncs!)
∇gfuncs_i! = if nZ̃ == 1 # univariate syntax (see JuMP.@operator doc):
function (Z̃arg::T) where T<:Real
- update_con!(g, ∇g, Z̃_g, Z̃arg)
+ update_memoized_diff!(Z̃_g, g, ∇g, prep_∇g, context_g, jac, gfunc!, Z̃arg)
return ∇g[i, begin]
end
- else # multivariate syntax (see JuMP.@operator doc):
+ else # multivariate syntax (see JuMP.@operator doc):
function (∇g_i, Z̃arg::Vararg{T, N}) where {N, T<:Real}
- update_con!(g, ∇g, Z̃_g, Z̃arg)
+ update_memoized_diff!(Z̃_g, g, ∇g, prep_∇g, context_g, jac, gfunc!, Z̃arg)
return ∇g_i .= @views ∇g[i, :]
end
end
∇gfuncs![i] = ∇gfuncs_i!
end
+ g_vec_args = (gfuncs, ∇gfuncs!)
# --------------------- equality constraint functions ---------------------------------
function geqfunc!(geq, Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g)
update_predictions!(ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq, mpc, Z̃)
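The `(gfuncs, ∇gfuncs!)` pair assembled above is consumed later by `init_nonlincon!`, which is outside this section. The following is a hypothetical sketch of how such per-constraint operators can be registered in JuMP: each value callback reads one entry of a shared constraint vector and each gradient callback copies one row of a shared Jacobian, so a single evaluation serves every operator. The toy constraints, the `eval_constraints!` helper and the operator names are assumptions for illustration only.

```julia
# Hypothetical per-constraint operator registration; not the package's init_nonlincon!.
using JuMP

g_shared  = zeros(2)            # shared constraint values (refreshed by a memoized update)
∇g_shared = zeros(2, 2)         # shared constraint Jacobian

function eval_constraints!(x)   # toy stand-in for the memoized update of g and ∇g
    g_shared[1] = x[1]^2 + x[2]^2 - 1.0
    g_shared[2] = x[1] - x[2]
    ∇g_shared  .= [2x[1] 2x[2]; 1.0 -1.0]
    return nothing
end

gfuncs   = [(x...) -> (eval_constraints!(x); g_shared[i]) for i in 1:2]
∇gfuncs! = [(∇g_i, x...) -> (eval_constraints!(x); ∇g_i .= ∇g_shared[i, :]) for i in 1:2]

model = Model()
@variable(model, Z[1:2])
for i in 1:2
    op = add_nonlinear_operator(model, 2, gfuncs[i], ∇gfuncs![i]; name = Symbol("g_$i"))
    @constraint(model, op(Z...) <= 0)
end
```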
@@ -715,18 +726,14 @@ function get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT
Cache(Û0), Cache(K0), Cache(X̂0),
Cache(gc), Cache(g)
)
- ∇geq_prep = prepare_jacobian(geqfunc!, geq, jac, Z̃_geq, context_geq...; strict)
- ∇geq = init_diffmat(JNT, jac, ∇geq_prep, nZ̃, neq)
- function update_con_eq!(geq, ∇geq, Z̃, Z̃arg)
- if isdifferent(Z̃arg, Z̃)
- Z̃ .= Z̃arg
- value_and_jacobian!(geqfunc!, geq, ∇geq, ∇geq_prep, jac, Z̃, context_geq...)
- end
- end
+ prep_∇geq = prepare_jacobian(geqfunc!, geq, jac, Z̃_geq, context_geq...; strict)
+ ∇geq = init_diffmat(JNT, jac, prep_∇geq, nZ̃, neq)
geqfuncs = Vector{Function}(undef, neq)
for i in eachindex(geqfuncs)
geqfunc_i = function (Z̃arg::Vararg{T, N}) where {N, T<:Real}
- update_con_eq!(geq, ∇geq, Z̃_geq, Z̃arg)
+ update_memoized_diff!(
+ Z̃_geq, geq, ∇geq, prep_∇geq, context_geq, jac, geqfunc!, Z̃arg
+ )
return geq[i]::T
end
geqfuncs[i] = geqfunc_i
@@ -737,12 +744,15 @@ function get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT
# constraints imply MultipleShooting, thus input increment ΔU and state X̂0 in Z̃:
∇geqfuncs_i! =
function (∇geq_i, Z̃arg::Vararg{T, N}) where {N, T<:Real}
- update_con_eq!(geq, ∇geq, Z̃_geq, Z̃arg)
+ update_memoized_diff!(
+ Z̃_geq, geq, ∇geq, prep_∇geq, context_geq, jac, geqfunc!, Z̃arg
+ )
return ∇geq_i .= @views ∇geq[i, :]
end
∇geqfuncs![i] = ∇geqfuncs_i!
end
- return Jfunc, ∇Jfunc!, ∇²Jfunc!, gfuncs, ∇gfuncs!, geqfuncs, ∇geqfuncs!
+ geq_vec_args = (geqfuncs, ∇geqfuncs!)
+ return J_args, g_vec_args, geq_vec_args
end

"""
@@ -770,52 +780,6 @@ function update_predictions!(
return nothing
end

- """
- update_diff_objective!(
- Z̃_J, J, ∇J, ∇²J, prep_∇J, prep_∇²J, context_J,
- grad::AbstractADType, hess::Nothing, Jfunc!, Z̃arg
- )
-
- TBW
- """
- function update_diff_objective!(
- Z̃_J, J, ∇J, ∇²J, prep_∇J, _, context_J,
- grad::AbstractADType, hess::Nothing, Jfunc!::F, Z̃arg
- ) where F<:Function
- if isdifferent(Z̃arg, Z̃_J)
- Z̃_J .= Z̃arg
- J[], _ = value_and_gradient!(Jfunc!, ∇J, prep_∇J, grad, Z̃_J, context...)
- end
- return nothing
- end
-
- function update_diff_objective!(
- Z̃_J, J, ∇J, ∇²J, _, prep_∇²J, context_J,
- grad::Nothing, hess::AbstractADType, Jfunc!::F, Z̃arg
- ) where F<:Function
- if isdifferent(Z̃arg, Z̃_J)
- Z̃_J .= Z̃arg
- J[], _ = value_gradient_and_hessian!(
- Jfunc!, ∇J, ∇²J, prep_∇²J, hess, Z̃_J, context_J...
- )
- @warn "Uncomment the following line to print the current Hessian"
- # println(∇²J)
- end
- return nothing
- end
-
- function update_diff_objective!(
- Z̃_J, J, ∇J, ∇²J, prep_∇J, prep_∇²J, context_J,
- grad::AbstractADType, hess::AbstractADType, Jfunc!::F, Z̃arg
- ) where F<:Function
- if isdifferent(Z̃arg, Z̃_J)
- Z̃_J .= Z̃arg # inefficient, as warned by validate_backends(), but still possible:
- hessian!(Jfunc!, ∇²J, prep_∇²J, hess, Z̃_J, context_J...)
- J[], _ = value_and_gradient!(Jfunc!, ∇J, prep_∇J, grad, Z̃_J, context_J...)
- end
- return nothing
- end
-
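The three removed methods above distinguish the differentiation paths: gradient-only, a single Hessian backend (where value, gradient and Hessian come from one `value_gradient_and_hessian!` call), and the mixed-backend case that needs separate `hessian!` and `value_and_gradient!` calls, the inefficiency `validate_backends` warns about. Here is a hedged, standalone sketch of the single-backend path with DifferentiationInterface, on a toy function with assumed names:

```julia
# Sketch of the single-backend path: one call fills value, gradient and Hessian at once.
using DifferentiationInterface
import ForwardDiff

f(x) = 0.5*sum(abs2, x) + x[1]*x[2]          # toy objective
backend = AutoForwardDiff()
x    = [1.0, -2.0]
∇f   = zeros(2)
∇²f  = zeros(2, 2)
prep = prepare_hessian(f, backend, x)

y, _, _ = value_gradient_and_hessian!(f, ∇f, ∇²f, prep, backend, x)
# here y == 0.5, ∇f == [-1.0, -1.0], ∇²f == [1.0 1.0; 1.0 1.0]
```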
@doc raw """
con_custom!(gc, mpc::NonLinMPC, Ue, Ŷe, ϵ) -> gc