profile pic
⌘ '
raccourcis clavier

Auto differentiation and XLA

f(x)=e2xx3dfdx=2e2x3x2f(x) = e^{2x} - x^3 \rightarrow \frac{df}{dx} = 2e^{2x} - 3x^2 manual diff

Others:

  • numerical, symbolic
  • autodiff
    • similar to symbolic, but on demand?
    • instead of expression returns numerical value

Forward mode

  • compute the partial diff of each scalar wrt each inputs in a forward pass

  • represented with tuple of original viv_i and primal viov_i^o (tangent) vi(vi,vo˙)v_i \rightarrow (v_i, \dot{v^o})

  • Jax uses operator overloading.

Reverse mode

  • store values and dependencies of intermediate variables in memory
  • After forward pass, compute partial diff output wrt the intermediate adjoint vˉ\bar{v}