Scientific Computing in Lean


Automatic Differentiation

In scientific computing a common requirement is to compute derivatives of a program, and the resulting derivative program should itself be efficient. The problem with symbolic differentiation is that it can produce very inefficient expressions. So called automatic differentiation addresses this issue and computes derivatives that are also efficient to evaluate.

Let's have a look at how symbolic differentiation can produce large expressions.

open SciLean Scalar

#check (∂! (x : ℝ), x * x * x * x * x * x * x * x)
-- fun x =>
--   let y := x * x * x * x * x * x * x;
--   let y_1 := x * x * x * x * x * x;
--   let y_2 := x * x * x * x * x;
--   let y_3 := x * x * x * x;
--   let y_4 := x * x * x;
--   let y_5 := x * x;
--   let dy := x + x;
--   let dy := y_5 + dy * x;
--   let dy := y_4 + dy * x;
--   let dy := y_3 + dy * x;
--   let dy := y_2 + dy * x;
--   let dy := y_1 + dy * x;
--   y + dy * x
-- : ℝ → ℝ

In the example above, symbolic differentiation takes an expression with 7 multiplications and produces an expression with 27 multiplications and 7 additions.

If we instead use the forward mode derivative (one type of automatic differentiation, which we will explain in a moment)

#check (∂>! (x : ℝ), x * x * x * x * x * x * x * x)
-- fun x =>
--   let ydy := x * x;
--   let ydy_1 := ydy * x;
--   let ydy_2 := ydy_1 * x;
--   let ydy_3 := ydy_2 * x;
--   let ydy_4 := ydy_3 * x;
--   let ydy_5 := ydy_4 * x;
--   let zdz := ydy_5 * x;
--   fun dx =>
--     let ydy_6 := dx * x + dx * x;
--     let ydy := dx * ydy + ydy_6 * x;
--     let ydy := dx * ydy_1 + ydy * x;
--     let ydy := dx * ydy_2 + ydy * x;
--     let ydy := dx * ydy_3 + ydy * x;
--     let ydy := dx * ydy_4 + ydy * x;
--     let zdz_1 := dx * ydy_5 + ydy * x;
--     (zdz, zdz_1)
-- : ℝ → ℝ → ℝ × ℝ

we obtain 21 multiplications and 7 additions.

The difference might not seem too big, but if you look at the shape of the results you can see that the expression resulting from ∂ contains a big triangular block of multiplications whose size (in terms of multiplications) grows quadratically with the number of multiplications in the original expression. On the other hand, when we use ∂>, for every multiplication in the original expression we obtain one line with ydy * x and one line with dx * ydy_i + ydy * x.

Add a table with the number of multiplications and additions for growing n.
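A provisional version of such a table, obtained by counting the operations in the two outputs above (the row for general n extrapolates the visible pattern and has not been checked for every n):

  n (multiplications in the input)    symbolic ∂                      forward mode ∂>
  7                                   27 mult, 7 add                  21 mult, 7 add
  n                                   (n−1)(n+2)/2 mult, n add        3n mult, n add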

Forward Mode

The issue with symbolic differentiation is that the chain rule

$$(f\circ g)'(x) = f'(g(x)) g'(x) $$

mentions \(g\) twice on the right hand side, which can lead to doubling the amount of computation every time we apply the chain rule.
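For example, already for a triple composition the symbolic chain rule gives

$$(f \circ g \circ h)'(x) = f'(g(h(x))) \, g'(h(x)) \, h'(x) $$

where the value \(h(x)\) appears twice; for longer chains these duplications nest and the size of the symbolic derivative can blow up.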

The remedy to this problem is to introduce the forward mode derivative \( \overrightarrow{\partial} \)

$$\overrightarrow{\partial} f (x,dx) = (f(x), f'(x) dx) $$
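To make this concrete, here is a small hand-written sketch in plain Float code (the function sinSqFwd is a made-up illustration, not part of SciLean) of the forward mode derivative of fun x => Float.sin (x*x); the value and the directional derivative are computed together in a single pass:

def sinSqFwd (x dx : Float) : Float × Float :=
  let y  := x * x
  let dy := dx * x + x * dx      -- derivative of x * x in the direction dx
  let z  := Float.sin y
  let dz := Float.cos y * dy     -- chain rule; the value y is reused, never recomputed
  (z, dz)

#eval sinSqFwd 1.0 1.0   -- approximately (0.841, 0.540), i.e. (sin 1, cos 1)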

The motivation is that the forward mode derivative computes the value \( f(x) \) and the derivative \( f'(x) dx \) once and at the same time. Both values can then be consumed by subsequent computations without ever evaluating \(f\) again. The chain rule for the forward mode derivative is

$$\overrightarrow{\partial} \left( f \circ g \right) = \overrightarrow{\partial} f \circ \overrightarrow{\partial} g $$

which does not mention \( g \) twice and thus does not risk doubling the computation.

In SciLean, the forward mode derivative is fwdFDeriv, which has the notation ∂>. It is defined as

open SciLean

variable (f : ℝ → ℝ) (x dx : ℝ)

example : (∂> f x dx) = (f x, ∂ f x dx) := by rfl

In Lean notation the chain rule can be written as

open SciLean

variable (f g : ℝ → ℝ) (hf : Differentiable ℝ f) (hg : Differentiable ℝ g) (x dx : ℝ)

example : (∂> (f ∘ g) x dx) = (let (y,dy) := (∂> g x dx); (∂> f y dy)) := by
  -- a plausible proof sketch: unfold the composition and let `fun_trans`
  -- compute both sides
  fun_trans [Function.comp]

Alternatively, using the notation ↿f, which uncurries a function, we can write the chain rule as

open SciLean

variable (f g : ℝ → ℝ) (hf : Differentiable ℝ f) (hg : Differentiable ℝ g)

example : ↿(∂> (f ∘ g)) = ↿(∂> f) ∘ ↿(∂> g) := by
  -- a plausible proof sketch: check the equality pointwise and let `fun_trans`
  -- compute both sides
  ext ⟨x,dx⟩ <;> fun_trans [Function.comp, Function.HasUncurry.uncurry]

In SciLean it is the theorem SciLean.fwdFDeriv.comp_rule.

autodiff vs fun_trans

So far we have been using the tactic fun_trans to compute derivatives. You might have noticed that fun_trans removes all let bindings from the expression. For example

#check (∂ (x : ℝ), let y := x*x; y*x) rewrite_by
  fun_trans
-- fun x => x * x + (x + x) * x : ℝ → ℝ

we can alternatively use the autodiff tactic

#check (∂ (x : ℝ), let y := x*x; y*x) rewrite_by
  autodiff
-- fun x =>
--   let y := x * x;
--   let dy := x + x;
--   let dz := y + dy * x;
--   dz
-- : ℝ → ℝ

The tactic autodiff behaves very similarly to fun_trans but it handles let bindings carefully, which is important for generating efficient code. Internally it uses a slightly modified version of Lean's simplifier together with a carefully configured fun_trans so that let bindings are processed more efficiently.

From now on, the rule of thumb is to use fun_trans if you do not care about generating efficient code and just want to prove something involving derivatives, and to use autodiff when you care about the efficiency of the resulting code.

You might have noticed that let bindings are preserved when using the ∂! notation. This is because ∂! uses the tactic autodiff instead of fun_trans.

Note that to compute efficient derivatives you have to use both autodiff and the forward mode derivative ∂>. Using only one of them will not yield the most efficient code.

Relation to Dual Numbers

One common explanation of the forward mode derivative is through dual numbers \(a + \epsilon b \). Similarly to the complex numbers \( \mathbb{C}\), which extend the real numbers with the imaginary unit \(i\) that squares to negative one, \(i^2 = -1\), the dual numbers \( \overline{\mathbb{R}}\) extend the real numbers with \(\epsilon\), which squares to zero, \(\epsilon^2 = 0\).

We can add and multiply two dual numbers

$$\begin{align} (a + \epsilon b) + (c + \epsilon d) &= ((a + c) + \epsilon (b + d)) \\ (a + \epsilon b) (c + \epsilon d) &= ((a c) + \epsilon (b c + a d)) \end{align} $$
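These rules are easy to try out in code. Below is a minimal sketch (the Dual structure and the cube function are illustrative helpers, not SciLean definitions) of dual numbers over Float with exactly the addition and multiplication above:

structure Dual where
  a : Float  -- value part
  b : Float  -- coefficient of ε
deriving Repr

instance : Add Dual := ⟨fun x y => ⟨x.a + y.a, x.b + y.b⟩⟩
-- the x.b * y.b term is dropped because ε² = 0
instance : Mul Dual := ⟨fun x y => ⟨x.a * y.a, x.b * y.a + x.a * y.b⟩⟩

def cube (x : Dual) : Dual := x * x * x

#eval cube ⟨2.0, 1.0⟩   -- value 2³ = 8 and derivative 3·2² = 12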

Through power series we can also calculate functions like sin, cos or exp

$$\begin{align} \sin(a + \epsilon b) &= \sin(a) + \epsilon b \cos(a) \\ \cos(a + \epsilon b) &= \cos(a) - \epsilon b \sin(a) \\ \exp(a + \epsilon b) &= \exp(a) + \epsilon b \exp(a) \end{align} $$
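The rule for sine can be added to the sketch above in the same way (again, just an illustration):

def Dual.sin (x : Dual) : Dual := ⟨Float.sin x.a, x.b * Float.cos x.a⟩

#eval Dual.sin ⟨0.0, 1.0⟩   -- value sin 0 = 0 and derivative cos 0 = 1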

In general, for an analytic function (f : ℝ → ℝ) we can show that

$$f(a + \epsilon b) = f(a) + \epsilon b f'(a) $$

Every dual number is just a pair of two real numbers, i.e. \( \overline{\mathbb{R}} \cong \mathbb{R} \times \mathbb{R} \), therefore we can think of the forward mode derivative \(\overrightarrow{\partial} f\) as the extension of \(f\) to dual numbers

$$\overrightarrow{\partial} f (x + \epsilon dx) = f(x) + \epsilon f'(x)dx $$

Explain this for a general function f : X → Y and relate it to the complexification of a vector space.

Exercises

  1. In section (??) we defined a function foo

def foo (x : ℝ) := x^2 + x

In order to compute derivatives involving foo you need to generate the corresponding derivative theorems using the def_fun_trans and def_fun_prop macros. Use the def_fun_trans macro to generate the forward mode derivative rule for foo.

  2. Newton's method with forward mode AD. In the previous chapter (??) we talked about Newton's method. At each iteration we need to compute the function value \( f(x_n)\) and its derivative \( f'(x_n)\). Instead of evaluating f x and ∂! f x, use the forward mode ∂>! f x 1 to compute the function value and its derivative at the same time.

  3. Redo the exercises on Newton's method from chapter (??) using the forward mode derivative ∂> instead of the normal derivative ∂.

  4. Prove on paper that the chain rule for the forward mode derivative is equivalent to the standard chain rule.

Reverse Mode

The forward mode derivative ∂> f is designed to efficiently compute the derivative ∂ f. To efficiently compute the gradient ∇ f we have to introduce the reverse mode derivative <∂ f.

Recall that the gradient ∇ f is defined as the adjoint of the derivative

variable (f : ℝ×ℝ → ℝ) (x : ℝ×ℝ)

example : (∇ f x) = adjoint ℝ (∂ f x) 1 := by
  -- a plausible proof sketch: this should hold by definition of the gradient
  rfl

One might naively try the same trick as with the forward mode derivative and define the reverse mode derivative by pairing the function value with the adjoint of the derivative applied to dy, i.e. <∂ f x dy = (f x, adjoint ℝ (∂ f x) dy). However, it is not possible to write down a good chain rule for this operation: for a composition, the vector we would have to feed to the adjoint of g's derivative only becomes available after applying the adjoint of f's derivative, so such an operation cannot be chained in a single pass. The way out is to postpone the argument dy and define the reverse mode derivative as

variable (f : ℝ → ℝ) (x : ℝ)

example : (<∂ f x) = (f x, fun dy => adjoint ℝ (∂ f x) dy) := by rfl

The reverse mode derivative at a point x computes the value f x together with a function that computes the adjoint of the derivative.

With this definition we can write down the chain rule for the reverse mode derivative

variable (f g : ℝ → ℝ) (x : ℝ) (hf : Differentiable ℝ f) (hg : Differentiable ℝ g)

example : (<∂ (fun x => f (g x)) x)
    =
    (let (y,dg') := <∂ g x
     let (z,df') := <∂ f y
     (z, fun dz => dg' (df' dz))) := by
  -- a plausible proof sketch: `fun_trans` computes the reverse mode derivative
  -- of the composition on the left hand side
  fun_trans

The crucial observation is that the derivative part of <∂ (f∘g) composes the adjoints of the derivatives of f and g in the reverse order, hence the name reverse mode.
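Writing \( \overleftarrow{\partial} \) for <∂, the chain rule above can be summarized as

$$\overleftarrow{\partial} (f \circ g)(x) = \left( f(g(x)), \; (g'(x))^* \circ (f'(g(x)))^* \right) $$

so while the derivatives themselves compose as \( f'(g(x)) \circ g'(x) \), their adjoints are applied in the opposite order.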

Talk about

  1. what is forward pass and reverse pass

  2. memory requirements of reverse mode

  3. forward mode vs reverse mode

  4. relation to vjp jvp convention

Exercises

  1. Reimplement gradient descent using revFDeriv and use the value f x it returns to compute the improvement at each step.

  2. Generate revFDeriv rules for foo.

  3. SDF projection

    • sphere

    • pick something from https://iquilezles.org/articles/distfunctions/

Derivatives of Neural Network Layers

In the chapter 'Working with Arrays' we constructed a simple neural network. To train it we have to compute its derivative.

Differentiate neural network layers from the chapter 'Working with Arrays'