Automatic Differentiation
In scientific computing, a common requirement is to compute derivatives of a program, and it is important that the resulting derivative program is efficient. The problem with symbolic differentiation is that it can produce very inefficient expressions. So called automatic differentiation addresses this issue and computes derivatives that are also efficient to evaluate.
Let's have a look at how symbolic differentiation can produce large expressions.
open SciLean Scalar
#check (∂! (x : ℝ), x * x * x * x * x * x * x * x)
In the last example, symbolic differentiation takes an expression with 7 multiplications and produces an expression with 27 multiplications and 7 additions.
If we instead use the forward mode derivative (one type of automatic differentiation, which we will explain in a moment)
#check (∂>! (x : ℝ), x * x * x * x * x * x * x * x)
we obtain 21 multiplications and 7 additions.
The difference might not seem too big, but if you look at the shape of the result you can see that the expression resulting from ∂
contains a big triangular block of multiplications whose size (in terms of multiplications) grows quadratically with the number of multiplications in the original expression. On the other hand, when we use ∂>
, for every multiplication in the original expression we obtain one line with ydy * x
and one line with dx * ydy_i + ydy * dx
.
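To make the pattern concrete, here is a hand-written sketch of how forward mode unfolds on the repeated product (this is the idea behind the generated code, not its literal output):
$$\begin{align} (y_1, dy_1) &= (x \cdot x,\; dx \cdot x + x \cdot dx) \\ (y_2, dy_2) &= (y_1 \cdot x,\; dy_1 \cdot x + y_1 \cdot dx) \\ (y_3, dy_3) &= (y_2 \cdot x,\; dy_2 \cdot x + y_2 \cdot dx) \\ &\;\;\vdots \end{align} $$
Every step costs three multiplications and one addition, so for the 7 multiplications of the original expression we get exactly the 21 multiplications and 7 additions reported above; the cost grows linearly instead of quadratically.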
Add table with the number of multiplications and additions for growing n
Forward Mode
The issue with symbolic differentiation is that the chain rule
$$(f\circ g)'(x) = f'(g(x)) g'(x) $$
mentions \(g\) twice on the right-hand side, which can lead to a doubling of computation every time we apply the chain rule.
The remedy to this problem is to introduce the forward mode derivative \( \overrightarrow{\partial} \)
$$\overrightarrow{\partial} f (x,dx) = (f(x), f'(x) dx) $$
The motivation is that the forward mode derivative computes the value \( f(x) \) and the derivative \( f'(x) dx \) together, in one pass. Both values can then be consumed by the subsequent computation without ever evaluating \(f\) again. The chain rule for the forward mode derivative is
$$\overrightarrow{\partial} \left( f \circ g \right) = \overrightarrow{\partial} f \circ \overrightarrow{\partial} g $$
which does not suffer from the problem of mentioning \( g \) twice and thus does not risk doubling the computation.
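To see why this form of the chain rule is valid, unfold the definition of \( \overrightarrow{\partial} \) and apply the ordinary chain rule once:
$$\begin{align} \overrightarrow{\partial} (f \circ g) (x,dx) &= \left( f(g(x)),\; f'(g(x))\, g'(x)\, dx \right) \\ &= \overrightarrow{\partial} f \left( g(x),\; g'(x)\, dx \right) \\ &= \left( \overrightarrow{\partial} f \circ \overrightarrow{\partial} g \right)(x,dx) \end{align} $$
The pair \( (g(x), g'(x)\,dx) \) is computed only once and then passed on to \( \overrightarrow{\partial} f \).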
In SciLean, the forward mode derivative is fwdFDeriv, which has the notation ∂>. It is defined as
open SciLean
variable (f : ℝ → ℝ) (x dx : ℝ)
example : (∂> f x dx) = (f x, ∂ f x dx) := by rfl
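To see the definition in action, we can let SciLean compute the forward mode derivative of squaring; the result is a function of x and dx returning the pair of the value and the derivative:
open SciLean Scalar

#check (∂>! (x : ℝ), x*x)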
In Lean notation the chain rule can be written as
open SciLean
variable (f g : ℝ → ℝ) (hf : Differentiable ℝ f) (hg : Differentiable ℝ g) (x dx : ℝ)
example : (∂> (f ∘ g) x dx) = (let (y,dy) := (∂> g x dx); (∂> f y dy)) := by
  fun_trans
Alternatively, when we use the notation ↿f, which uncurries a function, we can write the chain rule as
open SciLean
variable (f g : ℝ → ℝ) (hf : Differentiable ℝ f) (hg : Differentiable ℝ g)
example : ↿(∂> (f ∘ g)) = ↿(∂> f) ∘ ↿(∂> g) := by
  fun_trans
In SciLean this is the theorem SciLean.fwdFDeriv.comp_rule.
autodiff vs fun_trans
So far we have been using the tactic fun_trans to compute derivatives. You might have noticed that fun_trans removes all let bindings from the expression. For example
#check (∂ (x : ℝ), let y := x*x; y*x) rewrite_by fun_trans
-- result: fun x => x * x + (x + x) * x
we can alternatively use the autodiff tactic
#check (∂ (x : ℝ), let y := x*x; y*x) rewrite_by autodiff
-- result:
--   fun x =>
--     let y := x * x;
--     let dy := x + x;
--     let dz := y + dy * x;
--     dz
The tactic autodiff behaves very similarly to fun_trans, but it handles let bindings carefully, which is important for generating efficient code. Internally it uses a slightly modified version of Lean's simplifier together with a carefully configured fun_trans so that let bindings are processed efficiently.
From now on, the rule of thumb is to use fun_trans if you do not care about generating efficient code and just want to prove something involving derivatives, and to use autodiff when you care about the efficiency of the resulting code.
You might have noticed that let bindings are preserved when using the ∂! notation. This is because ∂! uses the tactic autodiff instead of fun_trans.
Note that when you want to compute efficient derivatives you have to use both autodiff and the forward mode derivative ∂>. Using only one of them will not yield the most efficient code.
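For example, combining the two on the expression from above keeps the let bindings and produces one value line and one derivative line per multiplication (assuming the ∂>! notation elaborates with autodiff, just like ∂! does):
open SciLean Scalar

#check (∂>! (x : ℝ), let y := x*x; y*x)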
Relation to Dual Numbers
One common explanation of the forward mode derivative is through dual numbers \(a + \epsilon b \). Similarly to the complex numbers \( \mathbb{C}\), which extend the real numbers with the imaginary unit \(i\) that squares to negative one, \(i^2 = -1\), the dual numbers \( \overline{\mathbb{R}}\) extend the real numbers with \(\epsilon\), which squares to zero, \(\epsilon^2 = 0\).
We can add and multiply two dual numbers
$$\begin{align} (a + \epsilon b) + (c + \epsilon d) &= ((a + c) + \epsilon (b + d)) \\ (a + \epsilon b) (c + \epsilon d) &= ((a c) + \epsilon (b c + a d)) \end{align} $$
Through power series we can also calculate functions like sin, cos or exp
$$\begin{align} \sin(a + \epsilon b) &= \sin(a) + \epsilon b \cos(a) \\ \cos(a + \epsilon b) &= \cos(a) - \epsilon b \sin(a) \\ \exp(a + \epsilon b) &= \exp(a) + \epsilon b \exp(a) \end{align} $$
In general, for an analytic function (f : ℝ → ℝ) we can show that
$$f(a + \epsilon b) = f(a) + \epsilon b f'(a) $$
Every dual number is just a pair of two real numbers, i.e. \( \overline{\mathbb{R}} \cong \mathbb{R} \times \mathbb{R} \). Therefore we can think of the forward mode derivative \(\overrightarrow{\partial} f\) as the extension of \(f\) to dual numbers
$$\overrightarrow{\partial} f (x + \epsilon dx) = f(x) + \epsilon f'(x)dx $$
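The following is a minimal, self-contained sketch of this idea in Lean; the structure Dual and the function cube are illustrative and not part of SciLean. Multiplication of dual numbers implements exactly the product rule, and seeding the \(\epsilon\) part with 1 recovers the derivative:
structure Dual where
  val : Float  -- the value a
  eps : Float  -- the infinitesimal part b of a + εb

def Dual.add (p q : Dual) : Dual := ⟨p.val + q.val, p.eps + q.eps⟩
def Dual.mul (p q : Dual) : Dual := ⟨p.val * q.val, p.eps * q.val + p.val * q.eps⟩

-- f(x) = x*x*x evaluated on the dual number x + ε·1 returns (f x, f' x)
def cube (x : Dual) : Dual := (x.mul x).mul x

#eval (cube ⟨2.0, 1.0⟩).eps  -- 12.0, i.e. f'(2) = 3 · 2²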
Explain this for a general function f : X → Y
and relate it to the complexification of a vector space.
Exercises
- In section (??) we defined a function foo:
  def foo (x : ℝ) := x^2 + x
  In order to compute derivatives with foo you need to generate derivative theorems using the def_fun_trans and def_fun_prop macros. Use the def_fun_trans macro to generate the forward mode derivative rule for foo.
- Newton's method with forward mode AD. In the previous chapter (??) we talked about Newton's method. At each iteration we need to compute the function value \( f(x_n)\) and its derivative \( f'(x_n)\). Instead of evaluating f x and ∂! f x, use the forward mode ∂>! f x 1 to compute the function value and its derivative at the same time (see the sketch after this list).
- Redo the exercises on Newton's method from chapter (??) using the forward mode derivative ∂> instead of the normal derivative ∂.
- Prove on paper that the chain rule for the forward mode derivative is equivalent to the standard chain rule.
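As a starting point for the Newton's method exercise above, here is a minimal sketch of a single iteration; the concrete function x*x - 2 is only an illustrative choice:
open SciLean Scalar

-- One Newton step for solving x*x - 2 = 0: ∂>! returns the function value
-- and the derivative in a single pass, so the function is evaluated only once.
#check fun (x : ℝ) =>
  let (fx, dfx) := (∂>! (x' : ℝ), x'*x' - 2) x 1
  x - fx / dfx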
Reverse Mode
The forward mode derivative ∂> f is designed to efficiently compute the derivative ∂ f. To efficiently compute the gradient ∇ f we have to introduce the reverse mode derivative <∂ f.
Recall that the gradient ∇ f is defined as the adjoint of the derivative:
variable (f : ℝ×ℝ → ℝ) (x : ℝ×ℝ)
example : (∇ f x) = adjoint ℝ (∂ f x) 1 := by rfl
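For instance, for \( f(x,y) = x y \) the derivative is the linear map \( \partial f(x,y)(dx,dy) = y\,dx + x\,dy \) (writing \(\partial\) for the derivative as in the Lean notation), and its adjoint applied to \(1\) collects the coefficients of \(dx\) and \(dy\), giving the familiar gradient:
$$\nabla f(x,y) = \left( \partial f(x,y) \right)^* 1 = (y, x) $$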
One might naively try the same trick as with the forward mode derivative and define the reverse mode by putting the function value and the derivative together, i.e. <∂ f x dy = (f x, adjoint ℝ (∂ f x) dy). However, it is not possible to write down a good chain rule for this operation. The only way to do this is to postpone the argument dy and define the reverse mode derivative as
variable (f : ℝ → ℝ) (x : ℝ)

example : (<∂ f x) = (f x, fun dy => adjoint ℝ (∂ f x) dy) := by rfl
The reverse mode derivative at point x
computes the value f x
and a function that can compute the adjoint of the derivative.
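For example, for \( f(x) = x^2 \) the reverse mode derivative at \(x\) is the pair \( (x^2,\; dy \mapsto 2x\,dy) \): the first component is the value and the second component is the adjoint of the derivative, waiting for its argument dy.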
With this definition we can write down the chain rule for the reverse mode derivative:
variable (f g : ℝ → ℝ) (x : ℝ) (hf : Differentiable ℝ f) (hg : Differentiable ℝ g)
example :
    (<∂ (fun x => f (g x)) x)
    =
    let (y,dg') := <∂ g x
    let (z,df') := <∂ f y
    (z, fun dz => dg' (df' dz)) := by
  fun_trans
The crucial observation is that the derivative part of <∂ (f∘g) composes the derivatives of f and g in the reverse order, hence the name reverse mode.
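This reversal is exactly how adjoints behave under composition: taking the adjoint of the ordinary chain rule flips the order of the factors, where \((\cdot)^*\) denotes the adjoint,
$$\left( \partial (f \circ g)(x) \right)^* = \left( \partial f(g(x)) \circ \partial g(x) \right)^* = \left( \partial g(x) \right)^* \circ \left( \partial f(g(x)) \right)^* $$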
Talk about
- what is forward pass and reverse pass
- memory requirements of reverse mode
- forward mode vs reverse mode
- relation to vjp/jvp convention
Exercises
- Reimplement gradient descent using revFDeriv and use f x for computing the improvement.
- Generate revFDeriv rules for foo.
- SDF projection
  - sphere
  - pick something from https://iquilezles.org/articles/distfunctions/
- Derivatives of Neural Network Layers. In the chapter 'Working with Arrays' we constructed a simple neural network. To train it we have to compute its derivative. Differentiate the neural network layers from the chapter 'Working with Arrays'.