4.2. 🚧 Harmonic Oscillator Optimization
4.2.1. Problem Statement
In this example, we demonstrate how to find a parameter of a differential equation such that its solution satisfies a particular property. Consider the following practical problem. We are designing a new musical instrument and have already decided on its shape. We want the instrument to hit a particular note, and we will achieve that by choosing the right material. To do this, we model the instrument as a harmonic oscillator and find the stiffness value that gives a solution with the desired frequency.
Of course, in reality we would optimize the shape rather than the material, but that would be a much harder mathematical problem. We therefore take the easier route: it serves as a demonstration of how to use SciLean, and we can easily check the numerical answer against the analytical solution.
The harmonic oscillator is governed by the following differential equation
$$m \ddot x = - k x $$
Because we are looking for the appropriate \(k\), we denote by \(x(k,t)\) the solution of this equation as a function of time \(t\) and stiffness \(k\).
How can we express the requirement that the solution should have a particular frequency \(\omega\)? One way is to require that the solution returns to the same position after time \(T\). We want to find the stiffness \(k\) such that
$$x(k,T) = x(k,0) $$
where the time \(T\) is related to the given frequency by
$$T = \frac{2 \pi}{\omega} $$
(You might have noticed that this argument contains a flaw which we kindly ask you to ignore for the moment as we will discuss it at the end.)
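For later comparison with the numerical result, this problem can also be solved by hand. Assuming the oscillator starts at rest, \(p_0 = 0\) (as in the Lean specification below), the solution is

$$x(k,t) = x_0 \cos\left(\sqrt{k/m}\; t\right) $$

so for \(x_0 \neq 0\) the condition \(x(k,T) = x(k,0)\) becomes

$$\cos\left(\sqrt{k/m}\; T\right) = 1 \quad\Longleftrightarrow\quad \sqrt{k/m}\; T = 2\pi n, \qquad n = 0, 1, 2, \dots $$

With \(T = 2\pi/\omega\), the smallest positive stiffness satisfying it (\(n = 1\)) is

$$k = m\,\omega^2 $$

For \(m = 1\) and \(\omega = 2\pi\) this gives \(k = 4\pi^2 \approx 39.48\).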
4.2.2. Lean Specification
Let's formulate this in Lean. In the previous example, 'Harmonic Oscillator', we showed how to specify a harmonic oscillator using its energy
def H (m k x p : Float) := (1/(2*m)) * p^2 + k/2 * x^2
The solution of the corresponding differential equation is symbolically represented using the function odeSolve
noncomputable
def solution (m k x₀ p₀ t : Float) : Float×Float :=
  odeSolve (t₀ := 0) (x₀ := (x₀,p₀)) (t := t)
    (fun (t : Float) (x,p) =>
      ( ∇ (p':=p), H m k x p',
       -∇ (x':=x), H m k x' p))
The expression solution m k x₀ p₀ t is the position and momentum at time t of a harmonic oscillator with mass m and stiffness k that starts at time 0 at position x₀ with momentum p₀.
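To make the meaning of solution concrete, here is a small computable stand-in written in plain Lean 4 without SciLean. It is only a sketch, not SciLean's implementation: the gradients of H are worked out by hand, giving Hamilton's equations ẋ = ∂H/∂p = p/m and ṗ = -∂H/∂x = -k x, and odeSolve is replaced by a fixed number of classical Runge-Kutta (RK4) steps. The names piF, hamiltonRHS, rk4Step and solutionRK4 are hypothetical.

```lean
def piF : Float := 3.141592653589793

-- Hamilton's equations for H(x, p) = p²/(2m) + k x²/2, derived by hand:
--   ẋ =  ∂H/∂p = p / m
--   ṗ = -∂H/∂x = -k x
def hamiltonRHS (m k : Float) (s : Float × Float) : Float × Float :=
  (s.2 / m, -k * s.1)

-- One classical fourth-order Runge-Kutta step of size `h`
-- for an autonomous system on `Float × Float`.
def rk4Step (f : Float × Float → Float × Float) (h : Float)
    (s : Float × Float) : Float × Float :=
  -- componentwise w + a • v
  let axpy := fun (a : Float) (v w : Float × Float) => (w.1 + a * v.1, w.2 + a * v.2)
  let d1 := f s
  let d2 := f (axpy (h / 2) d1 s)
  let d3 := f (axpy (h / 2) d2 s)
  let d4 := f (axpy h d3 s)
  (s.1 + h / 6 * (d1.1 + 2 * d2.1 + 2 * d3.1 + d4.1),
   s.2 + h / 6 * (d1.2 + 2 * d2.2 + 2 * d3.2 + d4.2))

-- Computable approximation of `solution m k x₀ p₀ t`:
-- integrate from time 0 to time `t` with `n` equal RK4 steps.
def solutionRK4 (m k x₀ p₀ t : Float) (n : Nat) : Float × Float :=
  let h := t / n.toFloat
  (List.range n).foldl (fun s _ => rk4Step (hamiltonRHS m k) h s) (x₀, p₀)

-- m = 1 and k = 4π² give period 1, so after t = 1 the oscillator
-- should be back at its initial state (x₀, p₀) = (1, 0).
#eval solutionRK4 1 (4 * piF * piF) 1 0 1 1000
```

With 1000 steps the RK4 error is tiny, so the last line should print a pair very close to (1, 0).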
Notice that we had to mark the function solution as noncomputable: it is purely symbolic and cannot be executed because it uses the function odeSolve.
To express that we are looking for the stiffness satisfying a particular property, we can use the notation solve x, f x = 0, which symbolically returns the solution x of the equation f x = 0.
noncomputable
def optimalStiffness (m x₀ ω : Float) : Float :=
  let x := fun k t =>
    let (x,p) := solution m k x₀ 0 t
    x
  let T := 2*π/ω
  solve k, x k T = x k 0
The notation solve x, f x = 0 uses the noncomputable function solveFun, which returns a term satisfying a particular property. This is why we have to mark optimalStiffness as noncomputable.
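The condition inside solve has a property that matters for how it can be approximated. The sketch below, in plain Lean 4, evaluates the residual x k T - x k 0 for m = 1, x₀ = 1 and ω = 2π at several stiffness values. It uses the closed-form solution for p₀ = 0 instead of odeSolve, and the names piF, xExact, residual and samples are hypothetical. Because the residual equals x₀ (cos(√(k/m) T) - 1), it is never positive for x₀ > 0: it only touches zero at k = 4π² instead of changing sign there, so a bracketing root finder such as bisection cannot be applied to it directly. Minimizing the squared residual does not need a sign change, which makes it a convenient reformulation.

```lean
def piF : Float := 3.141592653589793

-- Closed-form position of the oscillator started at rest (p₀ = 0):
--   x(k, t) = x₀ cos(√(k/m) t)
def xExact (m k x₀ t : Float) : Float :=
  x₀ * Float.cos (Float.sqrt (k / m) * t)

-- Residual of the condition `x k T = x k 0` used in `optimalStiffness`.
def residual (m x₀ ω k : Float) : Float :=
  let T := 2 * piF / ω
  xExact m k x₀ T - xExact m k x₀ 0

-- Stiffness samples k = 5, 10, …, 100.
def samples : List Float := (List.range 20).map (fun i => 5 * (i.toFloat + 1))

-- Residuals for m = 1, x₀ = 1, ω = 2π.
#eval samples.map (residual 1 1 (2 * piF))

-- None of them is positive.
#eval samples.all (fun k => decide (residual 1 1 (2 * piF) k ≤ 0))  -- true
```

Since cos never exceeds 1, the second #eval returns true for any choice of samples, not just these.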
Our goal is an executable function that approximately computes the symbolic function optimalStiffness. We therefore define a function findStiffness that approximates the specification. In the next section we will show how to replace skip with a sequence of tactics that turns the specification into an executable approximation and allows us to remove the noncomputable annotation from this function.
noncomputable
approx findStiffness (m x₀ ω : Float) :=
  let T := 2*π/ω
  let y := fun (k : Float) =>
    odeSolve (t₀ := 0) (t:=T) (x₀:=(x₀,0))
      (fun (t : Float) (x,p) =>
        ( ∇ (p':=p), H m k x p',
         -∇ (x':=x), H m k x' p))
  solve k, (y k).1 = x₀
by
  skip
4.2.3. Turning Specification into Implementation
Unfortunately, the documentation tool we are using does not allow mixing a proof with text, so we first write down the whole proof with only brief comments, and only then go over each tactic and explain what it does.
There are two small differences from the initial specification, both mainly of a technical nature. First, we had to add the function holdLet, which is just the identity function and can otherwise be ignored; its purpose is to prevent the tactic autodiff from inlining y. Second, findStiffness takes an additional argument k₀, which is used as the starting point of the optimization.
Please have a brief look at every step. The short comments should give you a rough idea of what each tactic does, and clicking on a command displays the current state.
approx findStiffness (m x₀ ω k₀ : Float) :=
  let T := 2*π/ω
  let y := holdLet <| fun (k : Float) =>
    odeSolve (t₀ := 0) (t:=T) (x₀:=(x₀,0))
      (fun (t : Float) (x,p) =>
        ( ∇ (p':=p), H m k x p',
         -∇ (x':=x), H m k x' p))
  solve k, (y k).1 = x₀
by
  conv =>
    -- focus on the specification
    enter[2]
    -- Unfold Hamiltonian and compute gradients
    /- after unfolding `H`, the right-hand side of the ODE inside `y` reads
         (∇ (p':=p), (1 / (2 * m) * p' ^ 2 + k / 2 * x ^ 2),
          -∇ (x':=x), (1 / (2 * m) * p ^ 2 + k / 2 * x' ^ 2))
       and after computing the gradients it becomes
         let ydf := (2 * m)⁻¹;
         let ydf_1 := x.2;
         let ydf_2 := k / 2;
         let ydf_3 := x.1;
         (ydf * (2 * ydf_1), -(ydf_2 * (2 * ydf_3)))
       i.e. (x.2 / m, -(k * x.1)), where x.1 is the position and x.2 the momentum -/
  conv =>
    -- focus on `solve k, (y k).1 = x₀`
    enter[T,y]
    -- reformulate as minimization problem
    -- side goal: HasUniqueSolution fun x => (y x).1 = x₀
    rw[solve_eq_argmin_norm2 Float (by sorry)]
    -- approximate by gradient descent; the goal becomes
    /- limit opts ∈ Options.filter,
         let f' := holdLet (<∂ x, ‖(y x).1 - x₀‖₂²);
         let r :=
           optimize (ObjectiveFunction.mk (fun x => ‖(y x).1 - x₀‖₂²) ⋯)
             (AbstractOptimizer.setOptions Float default opts) k₀;
         r.minimizer -/
  -- consume limit by `Approx`
  -- side goal: `limit opts ∈ Options.filter` can be moved in front of
  -- the `let T` and `let y` bindings
  approx_limit opts (by sorry)
  conv =>
    -- focus on the specification again
    enter[2]
    -- rewrite reverse mode AD (<∂) as forward mode AD (∂>)
    -- this is possible because we are differentiating scalar function `Float → Float`
    /- afterwards only `f'` has changed:
         let f' :=
           holdLet fun x =>
             let x := ∂> (x:=x;1), ‖(y x).1 - x₀‖₂²;
             let y := x.1;
             let dy := x.2;
             (y, fun dy' => dy * dy'); -/
    -- run forward mode AD
    -- this will formulate a new ODE that solves for `x`, `p`, `dx/dk` and `dp/dk`
    autodiff
    /- afterwards `f'` reads
         let f' :=
           holdLet fun x =>
             let ydf := (2 * m)⁻¹;
             let F :=
               holdLet fun t xdx =>
                 let x_1 := xdx.1;
                 let dx := xdx.2;
                 let zdz := x_1.2;
                 let zdz_1 := dx.2;
                 let zdz_2 := x / 2;
                 let zdz_3 := 2⁻¹;
                 let zdz_4 := x_1.1;
                 let zdz_5 := dx.1;
                 let zdz := 2 * zdz;
                 let zdz_6 := 2 * zdz_1;
                 let zdz := ydf * zdz;
                 let zdz_7 := ydf * zdz_6;
                 let zdz_8 := 2 * zdz_4;
                 let zdz_9 := 2 * zdz_5;
                 let zdz_10 := zdz_2 * zdz_8;
                 let zdz_11 := zdz_2 * zdz_9 + zdz_3 * zdz_8;
                 let zdz_12 := -zdz_10;
                 let zdz_13 := -zdz_11;
                 ((zdz, zdz_12), zdz_7, zdz_13);
             let xdx := odeSolve F 0 T ((x₀, 0), 0);
             let y := xdx.1;
             let dy := xdx.2;
             let zdz := y.1;
             let zdz_1 := dy.1;
             let ydy := zdz - x₀;
             let zdz := ydy ^ 2;
             let zdz_2 := 2 * ⟪zdz_1, ydy⟫;
             (zdz, fun dy' => zdz_2 * dy');
       where `F` is the right-hand side of the augmented ODE for (x, p) and (dx/dk, dp/dk) -/
  -- approximate both ODEs with RK4
  /- afterwards the goal has the form `Approx (Filter.atTop ×ˢ ?m.117588 m x₀ ω k₀) (...)`,
     where both calls to `odeSolve` are replaced by limits of fixed-step RK4 solvers:
       let y :=
         holdLet fun k =>
           limit n → ∞,
             odeSolveFixedStep
               (rungeKutta4 fun t x =>
                 let ydf := (2 * m)⁻¹;
                 let ydf_1 := x.2;
                 let ydf_2 := k / 2;
                 let ydf_3 := x.1;
                 (ydf * (2 * ydf_1), -(ydf_2 * (2 * ydf_3))))
               n 0 T (x₀, 0);
     and, inside `f'`,
       let xdx := limit n → ∞, odeSolveFixedStep (rungeKutta4 F) n 0 T ((x₀, 0), 0); -/
  -- choose the same number of steps for both ODE solvers
  -- and consume the corresponding limit in `Approx`
  -- side goal: the two inner `limit n → ∞` can be replaced by a single outer
  -- `limit n → ∞` that uses the same `n` in both `odeSolveFixedStep` calls
  approx_limit steps (by sorry)
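The script above rewrites reverse-mode AD (<∂) as forward-mode AD (∂>), justified by the fact that the function being differentiated is a scalar function Float → Float. The plain Lean 4 sketch below illustrates why this works, using a minimal dual-number type; the names DualNum, fwdDeriv, revDerivViaFwd and loss are hypothetical and not SciLean's API. For a scalar function, a single forward pass seeded with tangent 1 yields the pair (f x, f' x), and the reverse-mode pullback is then just dy' ↦ f' x · dy'. This is the same shape as f' in the rewritten goal, (y, fun dy' => dy * dy').

```lean
-- Minimal dual numbers for forward-mode AD.
structure DualNum where
  val : Float  -- primal value
  tan : Float  -- tangent: derivative with respect to the input

instance : Add DualNum := ⟨fun a b => ⟨a.val + b.val, a.tan + b.tan⟩⟩
instance : Sub DualNum := ⟨fun a b => ⟨a.val - b.val, a.tan - b.tan⟩⟩
-- product rule
instance : Mul DualNum := ⟨fun a b => ⟨a.val * b.val, a.val * b.tan + a.tan * b.val⟩⟩

def DualNum.const (c : Float) : DualNum := ⟨c, 0⟩

-- Forward mode, analogous to `∂> (x:=x;1), f x`: seed the tangent with 1.
def fwdDeriv (f : DualNum → DualNum) (x : Float) : Float × Float :=
  let r := f ⟨x, 1⟩
  (r.val, r.tan)

-- Reverse mode, analogous to `<∂ x, f x`, for a scalar function:
-- the pullback of a cotangent `dy'` is multiplication by f'(x).
def revDerivViaFwd (f : DualNum → DualNum) (x : Float) : Float × (Float → Float) :=
  let (y, dy) := fwdDeriv f x
  (y, fun dy' => dy * dy')

-- Toy squared-residual loss ℓ(k) = (k·k - 2)².
def loss (k : DualNum) : DualNum :=
  let r := k * k - DualNum.const 2
  r * r

#eval fwdDeriv loss 3                -- (ℓ(3), ℓ'(3)) = (49, 84)
#eval (revDerivViaFwd loss 3).2 1    -- ℓ'(3) · 1 = 84
```

For a function with many inputs and one output, reverse mode would be preferable; with a single input, as here, one forward pass already delivers everything the pullback needs.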
-- analytical answer for m = 1, ω = 2π: k = m ω² = (2π)²
#eval (2*π)^2
#eval findStiffness (m:=1) (ω:=2*π) (x₀:=1) (k₀:=10)
  ({x_abstol := 1e-16, g_abstol := 0, show_trace := true, result_trace := true},200,())
#eval 4*π^2
Original differential equation
$$\begin{align} \dot x &= \frac{p}{m} \\ \dot p &= -k x \\ x(0) &= x_0 \qquad p(0) = p_0 \end{align} $$
The autodiff tactic uses the theorem odeSolve.arg_ft₀tx₀.fwdDeriv_rule, which states what happens when we differentiate the solution of a differential equation w.r.t. a parameter, the initial condition, or the initial and terminal time. Here it produces a derived differential equation that additionally solves for two variables \(x_k\) and \(p_k\), the derivatives of the solution w.r.t. \(k\):
$$\begin{align} \dot x &= \frac{p}{m} \\ \dot p &= -k x \\ \dot x_k &= \frac{p_k}{m} \\ \dot p_k &= -x - k x_k \\ x(0) &= x_0 \qquad p(0) = p_0 \qquad x_k(0) = 0 \qquad p_k(0) = 0 \end{align} $$
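To see what the derived system computes, the sketch below integrates it with RK4 in plain Lean 4, with p₀ = 0 as in findStiffness. The names piF, Aug, augRHS, augStep, augSolve and findStiffnessGN are hypothetical. The component xk at time T is ∂x(k,T)/∂k, the derivative an optimizer needs. Instead of SciLean's optimize, the sketch uses a Gauss-Newton iteration k ← k - r(k)/r'(k) on the residual r(k) = x(k,T) - x₀. This method is swapped in for brevity and is not necessarily what SciLean's optimizer does. Gauss-Newton needs a reasonable starting point: at k₀ = 10, the value used in the #eval above, xk(T) is close to zero, and a step from there overshoots far past 4π². The sketch therefore starts from k₀ = 30.

```lean
def piF : Float := 3.141592653589793

-- State of the augmented system: position, momentum and their k-derivatives.
structure Aug where
  x  : Float
  p  : Float
  xk : Float
  pk : Float

-- Right-hand side of the derived ODE:
--   ẋ = p/m,  ṗ = -k x,  ẋₖ = pₖ/m,  ṗₖ = -x - k xₖ
def augRHS (m k : Float) (s : Aug) : Aug :=
  ⟨s.p / m, -k * s.x, s.pk / m, -s.x - k * s.xk⟩

-- Componentwise s + a • v.
def Aug.axpy (a : Float) (v s : Aug) : Aug :=
  ⟨s.x + a * v.x, s.p + a * v.p, s.xk + a * v.xk, s.pk + a * v.pk⟩

-- One classical RK4 step of size h.
def augStep (m k h : Float) (s : Aug) : Aug :=
  let d1 := augRHS m k s
  let d2 := augRHS m k (Aug.axpy (h / 2) d1 s)
  let d3 := augRHS m k (Aug.axpy (h / 2) d2 s)
  let d4 := augRHS m k (Aug.axpy h d3 s)
  -- s + h/6 (d1 + 2 d2 + 2 d3 + d4)
  Aug.axpy (h / 6) d4 (Aug.axpy (h / 3) d3 (Aug.axpy (h / 3) d2 (Aug.axpy (h / 6) d1 s)))

-- Integrate from t = 0 to t = T with n steps, starting from
-- x(0) = x₀, p(0) = 0, xₖ(0) = 0, pₖ(0) = 0.
def augSolve (m k x₀ T : Float) (n : Nat) : Aug :=
  let h := T / n.toFloat
  (List.range n).foldl (fun s _ => augStep m k h s) ⟨x₀, 0, 0, 0⟩

-- Gauss-Newton on r(k) = x(k,T) - x₀, using r'(k) = xₖ(T).
def findStiffnessGN (m x₀ ω k₀ : Float) (steps iters : Nat) : Float :=
  let T := 2 * piF / ω
  (List.range iters).foldl (fun k _ =>
    let s := augSolve m k x₀ T steps
    let r := s.x - x₀
    -- stop updating once the residual is negligible, since r' → 0 at the root
    if Float.abs r < 1e-10 then k else k - r / s.xk) k₀

-- Sensitivity check: xₖ(T) against a central finite difference in k.
#eval ((augSolve 1 30 1 1 1000).xk,
       ((augSolve 1 30.001 1 1 1000).x - (augSolve 1 29.999 1 1 1000).x) / 0.002)

-- Stiffness for m = 1, x₀ = 1, ω = 2π, starting from k₀ = 30.
#eval findStiffnessGN 1 1 (2 * piF) 30 1000 60
```

The two numbers in the first #eval should agree closely, and the second should be close to the analytical value 4π² ≈ 39.478. Because the residual has a double root, each Gauss-Newton step only roughly halves the distance to it, which is still fast enough here.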