- Input: code computing a function
- Output: code to compute one or more derivatives of the function.
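AD writes a function as a sequence of simple compositions,
$$f(x) = f_k(f_{k-1}(\cdots f_2(f_1(x))\cdots)),$$
and writes its derivative using the chain rule:
$$f'(x) = f_k'(\alpha_{k-1})\, f_{k-1}'(\alpha_{k-2}) \cdots f_2'(\alpha_1)\, f_1'(x),$$
where $\alpha_1 = f_1(x)$ and $\alpha_i = f_i(\alpha_{i-1})$ are the intermediate values.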
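The composition and its chain-rule product can be sketched in a few lines of Python. The function $f(x) = \exp(\sin(x^2))$ and the names `fs`, `dfs`, and `derivative_by_chain_rule` are hypothetical choices for illustration: the loop keeps the current intermediate value $\alpha_{i-1}$ (with $\alpha_0 = x$) and multiplies in $f_i'(\alpha_{i-1})$ before advancing to $\alpha_i$.

```python
import math

# Hypothetical example: f(x) = exp(sin(x**2)) written as f3(f2(f1(x))),
# with f1(x) = x**2, f2(u) = sin(u), f3(u) = exp(u).
fs = [lambda x: x**2, math.sin, math.exp]
dfs = [lambda x: 2 * x, math.cos, math.exp]  # derivative of each piece


def derivative_by_chain_rule(x):
    """Return (f(x), f'(x)), computing f'(x) as a product of the pieces' derivatives."""
    value, grad = x, 1.0  # value holds alpha_{i-1}; alpha_0 = x
    for f, df in zip(fs, dfs):
        grad *= df(value)  # multiply in f_i'(alpha_{i-1})
        value = f(value)   # advance to alpha_i = f_i(alpha_{i-1})
    return value, grad
```

For this $f$, the returned derivative should match the hand-derived $2x\cos(x^2)\exp(\sin(x^2))$.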
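Decomposing the code with the chain rule yields derivative code, but done naively it repeats work: the same intermediate values appear in several factors and are recomputed each time. Dynamic programming avoids this by computing each intermediate value once, storing it, and reusing it.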
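To make the redundancy concrete, here is a minimal sketch assuming a hypothetical chain of $K$ identical pieces $f_i(u) = \sin(u)$. The helper names (`Counter`, `grad_naive`, `grad_dp`) are illustrative only. The naive version rebuilds each intermediate value from $x$ for every factor; the dynamic-programming version keeps each intermediate value and reuses it for the next factor.

```python
import math

K = 10  # number of pieces in the hypothetical chain f = sin(sin(...sin(x)))


class Counter:
    """Wraps the forward piece f_i(u) = sin(u) and counts its evaluations."""

    def __init__(self):
        self.calls = 0

    def f(self, u):
        self.calls += 1
        return math.sin(u)


def df(u):
    return math.cos(u)  # derivative of each piece f_i(u) = sin(u)


def grad_naive(x, c):
    """Each factor f_i'(alpha_{i-1}) recomputes alpha_{i-1} from x: O(K^2) evaluations."""
    g = 1.0
    for i in range(1, K + 1):
        u = x
        for _ in range(i - 1):  # rebuild alpha_{i-1} = f_{i-1}(...f_1(x)) from scratch
            u = c.f(u)
        g *= df(u)
    return g


def grad_dp(x, c):
    """Dynamic programming: each alpha_i is computed once and reused: O(K) evaluations."""
    g, u = 1.0, x
    for _ in range(K):
        g *= df(u)  # u holds alpha_{i-1}
        u = c.f(u)  # store alpha_i for the next factor (the last one is f(x) itself)
    return g
```

Both functions return the same derivative, but for $K = 10$ the naive version evaluates the pieces 45 times while the stored-values version evaluates them 10 times, and one of those 10 evaluations produces $f(x)$ itself.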
We represent the computation as a directed acyclic graph (DAG), called the computation graph:
- Root nodes are the parameters (and inputs).
- Branch nodes are computed intermediate values (the $\alpha$ values).
- Leaf node is the function value.
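One way to hold such a graph in code is a dictionary from each computed node to its operation and input nodes, evaluated in topological order so every input is ready before it is used. This is a sketch under stated assumptions: the example $f(x, w, b) = (xw + b)^2$ and the names `GRAPH` and `evaluate` are hypothetical, and it relies on the standard-library `graphlib.TopologicalSorter` (Python 3.9+).

```python
import operator
from graphlib import TopologicalSorter  # standard library, Python 3.9+

# Hypothetical example graph for f(x, w, b) = (x * w + b) ** 2.
# Root nodes (x, w, b) are inputs/parameters and need no entry here.
# Each branch or leaf node maps to (operation, names of its input nodes).
GRAPH = {
    "alpha1": (operator.mul, ["x", "w"]),       # branch: alpha1 = x * w
    "alpha2": (operator.add, ["alpha1", "b"]),  # branch: alpha2 = alpha1 + b
    "f": (lambda u: u * u, ["alpha2"]),         # leaf: the function value
}


def evaluate(graph, root_values):
    """Evaluate each node once, visiting inputs before the nodes that use them."""
    predecessors = {node: inputs for node, (_, inputs) in graph.items()}
    values = dict(root_values)
    for node in TopologicalSorter(predecessors).static_order():
        if node in values:  # root node: value supplied by the caller
            continue
        op, inputs = graph[node]
        values[node] = op(*(values[name] for name in inputs))
    return values
```

For example, `evaluate(GRAPH, {"x": 2.0, "w": 3.0, "b": 1.0})["f"]` gives `49.0`, since $2 \cdot 3 + 1 = 7$ and $7^2 = 49$.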
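Reverse-mode AD runs in two stages. As an example, take one node of the graph that receives inputs $a$ and $b$ and computes $c$, inside a function whose final value is $f$:
- The forward AD pass is called forward propagation:
  - Computes $c$ from $a$ and $b$ and passes it to its outputs.
  - Stores the intermediate calculations $\frac{\partial c}{\partial a}$ and $\frac{\partial c}{\partial b}$.
- The backward AD pass is called backpropagation:
  - Starts from the end with $\frac{\partial f}{\partial f}$, which is just 1.
  - Computes $\frac{\partial f}{\partial a} = \frac{\partial f}{\partial c}\frac{\partial c}{\partial a}$ and $\frac{\partial f}{\partial b} = \frac{\partial f}{\partial c}\frac{\partial c}{\partial b}$ from $\frac{\partial f}{\partial c}$.
  - Uses the intermediate calculations stored during the forward pass.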
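The two stages can be sketched for a small graph listed in topological order. In this minimal standard-library sketch (not any particular library's implementation), each hypothetical op (`mul`, `add`, `square`) returns its value together with its local partials. For instance, `mul` for $c = ab$ returns $\frac{\partial c}{\partial a} = b$ and $\frac{\partial c}{\partial b} = a$. The forward pass stores these partials. The backward pass starts from $\frac{\partial f}{\partial f} = 1$ and applies the chain rule in reverse order, summing contributions when a node feeds several outputs. The example $f(x, w, b) = (xw + b)^2$ is also an assumption.

```python
# Each op returns (output value, local partials w.r.t. each of its inputs).
def mul(a, b):
    return a * b, (b, a)  # dc/da = b, dc/db = a


def add(a, b):
    return a + b, (1.0, 1.0)


def square(a):
    return a * a, (2.0 * a,)


# Hypothetical example f(x, w, b) = (x*w + b)**2, nodes listed in topological order.
NODES = [
    ("alpha1", mul, ("x", "w")),
    ("alpha2", add, ("alpha1", "b")),
    ("f", square, ("alpha2",)),
]


def forward(nodes, roots):
    """Forward propagation: compute each value and store its local partials."""
    values, local = dict(roots), {}
    for name, op, inputs in nodes:
        values[name], local[name] = op(*(values[i] for i in inputs))
    return values, local


def backward(nodes, local, output="f"):
    """Backpropagation: start from df/df = 1 and push gradients toward the roots."""
    grad = {output: 1.0}
    for name, _, inputs in reversed(nodes):
        g = grad.get(name, 0.0)  # complete: all consumers of `name` were processed first
        for inp, partial in zip(inputs, local[name]):
            grad[inp] = grad.get(inp, 0.0) + g * partial  # chain rule, summed over outputs
    return grad
```

With $x = 2$, $w = 3$, $b = 1$, the forward pass gives $f = 49$. The backward pass gives $\frac{\partial f}{\partial x} = 2(xw + b)w = 42$, $\frac{\partial f}{\partial w} = 28$, and $\frac{\partial f}{\partial b} = 14$, reusing the stored partials rather than re-deriving them.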