2.2. Automatic Differentiation

In scientific computing, a common requirement is to compute derivatives of a program, and the resulting program that computes the derivative should be efficient. The problem with symbolic differentiation is that it can produce very inefficient expressions. So-called automatic differentiation addresses this issue and computes derivatives that are also efficient to evaluate.

Let's look at how symbolic differentiation can produce large expressions.

open SciLean Scalar

#check ∂! (x : ℝ), x * x * x * x * x * x * x * x
-- fun x =>
--   let y := x * x * x * x * x * x * x;
--   let y_1 := x * x * x * x * x * x;
--   let y_2 := x * x * x * x * x;
--   let y_3 := x * x * x * x;
--   let y_4 := x * x * x;
--   let y_5 := x * x;
--   let dy := x + x;
--   let dy := y_5 + dy * x;
--   let dy := y_4 + dy * x;
--   let dy := y_3 + dy * x;
--   let dy := y_2 + dy * x;
--   let dy := y_1 + dy * x;
--   y + dy * x : ℝ → ℝ

In this example, symbolic differentiation takes an expression with 7 multiplications and produces an expression with 27 multiplications and 7 additions.

If we instead use the forward mode derivative (one type of automatic differentiation, which we explain in a moment)

#check ∂>! (x : ℝ), x * x * x * x * x * x * x * x
-- fun x dx =>
--   let zdz := x * x;
--   let zdz_1 := x * dx + dx * x;
--   let zdz_2 := zdz * x;
--   let zdz := zdz * dx + zdz_1 * x;
--   let zdz_3 := zdz_2 * x;
--   let zdz := zdz_2 * dx + zdz * x;
--   let zdz_4 := zdz_3 * x;
--   let zdz := zdz_3 * dx + zdz * x;
--   let zdz_5 := zdz_4 * x;
--   let zdz := zdz_4 * dx + zdz * x;
--   let zdz_6 := zdz_5 * x;
--   let zdz := zdz_5 * dx + zdz * x;
--   let zdz_7 := zdz_6 * x;
--   let zdz := zdz_6 * dx + zdz * x;
--   (zdz_7, zdz) : ℝ → ℝ → ℝ × ℝ

we obtain 21 multiplications and 7 additions.

The difference might not seem too big, but look at the shape of the results. The expression produced by ∂! contains a large triangular block of multiplications whose size (in number of multiplications) grows quadratically with the number of multiplications in the original expression. With ∂>, on the other hand, every multiplication in the original expression yields exactly two lines: one that multiplies the running value by x, such as zdz_2 := zdz * x, and one that updates the derivative by the product rule, such as zdz := zdz * dx + zdz_1 * x.
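To see why this stays linear, here is a minimal sketch in plain Lean 4 (using Float instead of ℝ, and no SciLean) of what the ∂>! output does: it carries a running pair of value and derivative and updates it once per multiplication. The name fwdRepeatedMul is hypothetical and only serves as an illustration.

```lean
/-- Hypothetical helper (not part of SciLean): forward-mode derivative of
    `x * x * ... * x` with `n + 1` factors, computed in plain Lean over `Float`.
    It carries a running pair (value, derivative) instead of rebuilding powers. -/
def fwdRepeatedMul (n : Nat) (x dx : Float) : Float × Float := Id.run do
  let mut z := x    -- running value, equal to x^(k+1) after k steps
  let mut dz := dx  -- running derivative, equal to (k+1) * x^k * dx after k steps
  for _ in [0:n] do
    -- product rule d(z * x) = z * dx + dz * x, using the old value of z:
    -- two multiplications and one addition per step
    dz := z * dx + dz * x
    -- one multiplication per step for the value
    z := z * x
  return (z, dz)

-- x^8 at x = 2 with dx = 1: value 2^8 = 256, derivative 8 * 2^7 = 1024
#eval fwdRepeatedMul 7 2.0 1.0
```

With 7 steps this performs 7 multiplications for the value and 14 multiplications plus 7 additions for the derivative, the same 21 multiplications and 7 additions counted above.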

2.2.1. Forward Mode

The issue with symbolic differentiation is that the chain rule

$$(f\circ g)'(x) = f'(g(x)) g'(x) $$

mentions \(g\) twice on the right-hand side, once inside \(f'(g(x))\) and once as \(g'(x)\). Evaluated naively, this can double the amount of computation every time we apply the chain rule.
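To see how the repetition compounds, here is a short worked expansion (the names \(f_1, \dots, f_n\) and \(x_i\) are introduced here for illustration). Apply the chain rule repeatedly to a composition of \(n\) functions and name the intermediate values \(x_0 = x\) and \(x_i = f_i(x_{i-1})\):

$$(f_n \circ \dots \circ f_1)'(x) = f_n'(x_{n-1})\, f_{n-1}'(x_{n-2}) \cdots f_1'(x_0) $$

If the intermediate values are not shared, the factor \(f_i'(x_{i-1})\) has to recompute \(x_{i-1}\) from scratch using \(i-1\) evaluations, so the whole expression needs \(0 + 1 + \dots + (n-1) = n(n-1)/2\) evaluations, which is quadratic in \(n\). The ∂! output for x * x * x * x * x * x * x * x shows the same pattern: it recomputes the powers \(x^2, \dots, x^7\) separately, at a cost of \(1 + 2 + \dots + 6 = 21\) multiplications.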

The remedy is to introduce the forward mode derivative \( \overrightarrow{\partial} \)

$$\overrightarrow{\partial} f (x,dx) = (f(x), f'(x) dx) $$

The motivation is that the forward mode derivative computes the value \( f(x) \) and the derivative \( f'(x) dx \) together, in a single pass. Both values can then be consumed by subsequent computations without ever evaluating \(f\) again. The chain rule for the forward mode derivative is

$$\overrightarrow{\partial} \left( f \circ g \right) = \overrightarrow{\partial} f \circ \overrightarrow{\partial} g $$

which mentions \( g \) only once and therefore does not risk doubling the computation.
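The following sketch illustrates this chain rule in plain Lean 4 over Float, without SciLean. It assumes a hypothetical encoding Fwd of a forward-mode derivative as an uncurried function from (x, dx) to (f x, f' x * dx); the names fwdSq, fwdAffine and fwdAffineSq are made up for illustration. With this encoding, the forward derivative of a composition is literally the composition of forward derivatives.

```lean
/-- Hypothetical encoding (not SciLean API): the uncurried forward-mode
    derivative of `f : ℝ → ℝ`, mapping `(x, dx)` to `(f x, f' x * dx)`. -/
abbrev Fwd := Float × Float → Float × Float

/-- Forward-mode derivative of squaring: `(x, dx) ↦ (x * x, 2 * x * dx)`. -/
def fwdSq : Fwd := fun (x, dx) => (x * x, 2 * x * dx)

/-- Forward-mode derivative of `y ↦ 3 * y + 1`. -/
def fwdAffine : Fwd := fun (y, dy) => (3 * y + 1, 3 * dy)

-- Forward chain rule: the forward derivative of `affine ∘ square` is the
-- composition of the two forward derivatives. The inner value `x * x` is
-- computed once and passed on together with its derivative.
def fwdAffineSq : Fwd := fwdAffine ∘ fwdSq

-- 3 * x^2 + 1 at x = 2 with dx = 1: value 13, derivative 6 * 2 = 12
#eval fwdAffineSq (2, 1)
```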

In SciLean, the forward mode derivative is fwdFDeriv, with notation ∂>. It is defined as

variable (f g : ℝ → ℝ) (x dx : ℝ) (hf : Differentiable ℝ f) (hg : Differentiable ℝ g)

example : (∂> f x dx) = (f x, ∂ f x dx) := by
  sorry -- proof omitted

In Lean notation the chain rule can be written as

example : (∂> (f ∘ g) x dx) = (let (y, dy) := (∂> g x dx); (∂> f y dy)) := by
  sorry -- proof omitted

Alternatively, using the notation ↿f, which uncurries any function, we can write the chain rule as

example : ↿(∂> (f ∘ g)) = ↿(∂> f) ∘ ↿(∂> g) := by
  sorry -- proof omitted

In SciLean it is the theorem @SciLean.fwdFDeriv.comp_rule.

2.2.1.1. autodiff vs fun_trans

So far we have been using the tactic fun_trans to compute derivatives. You might have noticed that fun_trans removes all let bindings from the expression. For example

#check (∂ (x : ℝ), let y := x*x; y*x) rewrite_by fun_trans
-- fun x => x * x + (x + x) * x : ℝ → ℝ

Alternatively, we can use the autodiff tactic

#check (∂ (x : ℝ), let y := x*x; y*x) rewrite_by autodiff
-- fun x => let y := x * x; let dy := x + x; let dz := y + dy * x; dz : ℝ → ℝ

The tactic autodiff behaves very similarly to fun_trans, but it handles let bindings carefully, which is important for generating efficient code. Internally it uses a slightly modified version of Lean's simplifier together with a carefully configured fun_trans, so that let bindings are preserved and processed efficiently.

From now on, the rule of thumb is to use fun_trans if you do not care about generating efficient code and just want to prove something involving derivatives, and to use autodiff when you care about the efficiency of the resulting code.

You might have noticed that let bindings are preserved when using the ∂! notation. This is because ∂! uses the autodiff tactic instead of fun_trans.

Note that to compute efficient derivatives you have to use both autodiff and the forward mode derivative ∂>. Using only one of them will not yield the most efficient code.

2.2.1.2. Relation to Dual Numbers

One common explanation of the forward mode derivative is through dual numbers \(a + \epsilon b \). Similarly to the complex numbers \( \mathbb{C}\), which extend the real numbers with an imaginary unit \(i\) that squares to negative one, \(i^2 = -1\), the dual numbers \( \overline{\mathbb{R}}\) extend the real numbers with an element \(\epsilon\) that squares to zero, \(\epsilon^2 = 0\).

We can add and multiply two dual numbers

$$\begin{align} (a + \epsilon b) + (c + \epsilon d) &= ((a + c) + \epsilon (b + d)) \\ (a + \epsilon b) (c + \epsilon d) &= ((a c) + \epsilon (b c + a d)) \end{align} $$
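These two rules are all that is needed to run ordinary arithmetic on dual numbers. Below is a minimal sketch in plain Lean 4 over Float, not SciLean; the structure Dual and its fields a and b are hypothetical names chosen to mirror \(a + \epsilon b\). Seeding the \(\epsilon\)-part of the input with 1 makes the \(\epsilon\)-part of the result equal to the derivative.

```lean
/-- Hypothetical dual number type for illustration (not part of SciLean):
    `⟨a, b⟩` represents `a + ε b` with `ε^2 = 0`. -/
structure Dual where
  a : Float  -- real part
  b : Float  -- ε-part
  deriving Repr

-- (a + ε b) + (c + ε d) = (a + c) + ε (b + d)
instance : Add Dual := ⟨fun p q => ⟨p.a + q.a, p.b + q.b⟩⟩

-- (a + ε b)(c + ε d) = a c + ε (b c + a d); the ε² term vanishes
instance : Mul Dual := ⟨fun p q => ⟨p.a * q.a, p.b * q.a + p.a * q.b⟩⟩

-- x * x + x at x = 3 with ε-part 1: value 12, derivative 2 * 3 + 1 = 7
#eval let x : Dual := ⟨3, 1⟩; x * x + x
```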

Using power series, we can also evaluate functions like sin, cos, or exp on dual numbers

$$\begin{align} \sin(a + \epsilon b) &= \sin(a) + \epsilon b \cos(a) \\ \cos(a + \epsilon b) &= \cos(a) - \epsilon b \sin(a) \\ \exp(a + \epsilon b) &= \exp(a) + \epsilon b \exp(a) \end{align} $$

In general, for an analytic function \(f : \mathbb{R} \to \mathbb{R}\) we can show that

$$f(a + \epsilon b) = f(a) + \epsilon b f'(a) $$
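The identity follows from the Taylor series of \(f\) around \(a\): every term of order two or higher contains \(\epsilon^2 = 0\) as a factor and therefore vanishes.

$$f(a + \epsilon b) = \sum_{n=0}^{\infty} \frac{f^{(n)}(a)}{n!} (\epsilon b)^n = f(a) + \epsilon b f'(a) + \sum_{n=2}^{\infty} \frac{f^{(n)}(a)}{n!} \epsilon^n b^n = f(a) + \epsilon b f'(a) $$

For example, with \(f = \exp\) every derivative is \(\exp\) itself, which reproduces \(\exp(a + \epsilon b) = \exp(a) + \epsilon b \exp(a)\) from above.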

Every dual number is just a pair of real numbers, i.e., \( \overline{\mathbb{R}} \cong \mathbb{R} \times \mathbb{R} \), so we can think of the forward mode derivative \(\overrightarrow{\partial} f\) as the extension of \(f\) to dual numbers

$$\overrightarrow{\partial} f (x + \epsilon dx) = f(x) + \epsilon f'(x)dx $$
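As a sketch of this view in plain Lean 4 (Float instead of ℝ, not SciLean API), a dual number \(x + \epsilon\, dx\) can be stored as the pair (x, dx), and the extensions of sin, cos and exp follow the formulas above. The names dsin, dcos and dexp are hypothetical. Composing the extended functions then computes the pair (f x, f' x * dx), exactly what the forward mode derivative returns.

```lean
/-- Hypothetical helpers (not SciLean API): dual number `a + ε b` stored as
    the pair `(a, b)`, and the extensions of `sin`, `cos`, `exp` to it. -/
-- sin(a + ε b) = sin a + ε b cos a
def dsin (p : Float × Float) : Float × Float := (Float.sin p.1, p.2 * Float.cos p.1)
-- cos(a + ε b) = cos a - ε b sin a
def dcos (p : Float × Float) : Float × Float := (Float.cos p.1, -(p.2 * Float.sin p.1))
-- exp(a + ε b) = exp a + ε b exp a
def dexp (p : Float × Float) : Float × Float := (Float.exp p.1, p.2 * Float.exp p.1)

-- exp(sin x) at x = 0 with dx = 1:
-- value exp(sin 0) = 1, derivative exp(sin 0) * cos 0 = 1
#eval (dexp ∘ dsin) (0, 1)
```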

2.2.1.3. Exercises

  1. In section (??) we defined a function foo

def foo (x : ℝ) := x^2 + x

To compute derivatives involving foo, you need to generate derivative theorems using the def_fun_trans and def_fun_prop macros. Use the def_fun_trans macro to generate the forward mode derivative rule for foo.

  2. Newton's method with forward mode AD. In the previous chapter (??) we talked about Newton's method. At each iteration we need to compute the function value \( f(x_n)\) and its derivative \( f'(x_n)\). Instead of evaluating f x and ∂! f x separately, use the forward mode ∂>! f x 1 to compute the function value and its derivative at the same time.

  3. Redo the exercises on Newton's method from chapter (??) using the forward mode derivative ∂> instead of the normal derivative ∂.

  4. Prove on paper that the chain rule for the forward mode derivative is equivalent to the standard chain rule.

2.2.2. Reverse Mode

The forward mode derivative ∂> f is designed to efficiently compute the derivative ∂ f. To efficiently compute the gradient ∇ f, we have to introduce the reverse mode derivative <∂ f.

Recall that the gradient ∇ f is defined as the adjoint of the derivative applied to 1

variable (f : ℝ × ℝ → ℝ) (x : ℝ × ℝ)

example : (∇ f x) = adjoint ℝ (∂ f x) 1 := by
  sorry -- proof omitted
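As a small worked example (the function is chosen here for illustration), take \(f(x_1, x_2) = x_1 x_2\). Its derivative at \(x\) is the linear map \(dx \mapsto x_2\, dx_1 + x_1\, dx_2\). Its adjoint \((\cdot)^\dagger\) with respect to the standard inner product is found by moving \(dy \in \mathbb{R}\) to the other side of the inner product, and applying it to 1 gives the gradient:

$$\begin{align} \langle \partial f(x)\, dx, dy \rangle &= (x_2\, dx_1 + x_1\, dx_2)\, dy = \langle (dx_1, dx_2), (x_2\, dy, x_1\, dy) \rangle \\ \nabla f(x) &= \left(\partial f(x)\right)^\dagger 1 = (x_2, x_1) \end{align} $$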

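One might naively use the same trick as with the forward mode derivative and define the reverse mode derivative by pairing the function value with the adjoint of the derivative, i.e., <∂ f x dy = (f x, adjoint ℝ (∂ f x) dy). However, it is not possible to write down a good chain rule for this operation: for a composition f ∘ g, values flow forward through g and then f, while the cotangent dy has to flow backward through f first and then through g, so dy cannot be supplied together with x up front. The way out is to postpone the argument dy and define the reverse mode derivative as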

variable (f g : ℝ → ℝ) (x : ℝ) (hf : Differentiable ℝ f) (hg : Differentiable ℝ g)

example : (<∂ f x) = (f x, fun dy => adjoint ℝ (fderiv ℝ f x) dy) := by
  sorry -- proof omitted

The reverse mode derivative at point x computes the value f x and a function that can compute the adjoint of the derivative.
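To make this concrete, here is a minimal sketch in plain Lean 4 over Float (not SciLean API) for the hypothetical example f(x₁, x₂) = x₁ * x₂. The function revMul returns the value together with a closure, the pullback, that applies the adjoint of the derivative to any dy later. The gradient is then obtained by calling that closure with 1. Both names are made up for illustration.

```lean
/-- Hypothetical encoding (not SciLean API): reverse-mode derivative of
    `f(x₁, x₂) = x₁ * x₂`, returning the value and the pullback
    `dy ↦ (∂ f x)† dy`, which is only applied later. -/
def revMul (x : Float × Float) : Float × (Float → Float × Float) :=
  let (x₁, x₂) := x
  (x₁ * x₂, fun dy => (x₂ * dy, x₁ * dy))

/-- The gradient is the pullback applied to `1`. -/
def gradMul (x : Float × Float) : Float × Float :=
  (revMul x).2 1

-- gradient of x₁ * x₂ at (3, 5) is (5, 3)
#eval gradMul (3, 5)
```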

With this definition we can write down the chain rule for the reverse mode derivative

variable (f g : ℝ → ℝ) (x : ℝ) (hf : Differentiable ℝ f) (hg : Differentiable ℝ g)

example : (<∂ (fun x => f (g x)) x)
    =
    let (y, dg') := <∂ g x
    let (z, df') := <∂ f y
    (z, fun dz => dg' (df' dz)) := by
  sorry -- proof omitted

The crucial observation is that the derivative part of <∂ (f∘g) composes the adjoint derivatives of f and g in reverse order, first df' and then dg', hence the name reverse mode.
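The same chain rule can be sketched in plain Lean 4 over Float, mirroring the let structure of the Lean statement for <∂ (fun x => f (g x)) x. This is not SciLean code: Rev, revSq, revAffine and revComp are hypothetical names, and a reverse-mode derivative is encoded as a function returning the value together with a pullback closure.

```lean
/-- Hypothetical encoding (not SciLean API): reverse-mode derivative of a
    function `ℝ → ℝ`, returning the value and a pullback `dy ↦ f' x * dy`. -/
abbrev Rev := Float → Float × (Float → Float)

def revSq : Rev := fun x => (x * x, fun dy => 2 * x * dy)
def revAffine : Rev := fun x => (3 * x + 1, fun dy => 3 * dy)

/-- Reverse-mode chain rule for `f ∘ g`: values flow forward (first `g`,
    then `f`), pullbacks are applied in reverse (first `df'`, then `dg'`). -/
def revComp (f g : Rev) : Rev := fun x =>
  let (y, dg') := g x
  let (z, df') := f y
  (z, fun dz => dg' (df' dz))

-- (3 * x + 1)^2 at x = 1: value 16
#eval (revComp revSq revAffine 1).1
-- derivative 2 * (3 * 1 + 1) * 3 = 24, computed as dg' (df' 1) = 3 * (2 * 4 * 1)
#eval (revComp revSq revAffine 1).2 1
```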

2.2.2.1. Exercises

  1. Reimplement gradient descent using revFDeriv, and use the value f x to compute the improvement.

  2. Generate revFDeriv rules for foo.

  3. SDF projection

    • sphere

    • pick something from https://iquilezles.org/articles/distfunctions/