Okay, let’s break down Variational Inference (VI) step-by-step, weaving in intuition alongside the mathematical rigor.

A Masterful Explanation of Variational Inference (with Intuition)

The Scenario: Why We Need Approximation

At the heart of many machine learning problems, especially in Bayesian statistics and generative modeling, lies the need to understand the relationship between observed data x and some unobserved latent variables z. These latent variables might represent underlying causes, hidden features, or model parameters we care about. Our goal is to compute the posterior distribution p(z|x), which tells us what we can infer about the hidden variables z after observing the data x.

Bayes’ Theorem gives us the roadmap:

$p(z|x) = \frac{p(x|z) p(z)}{p(x)}$

Where:

  - p(z|x) is the posterior – our updated beliefs about the latent variables z after observing the data x (the quantity we want).
  - p(x|z) is the likelihood – how probable the observed data is under a particular setting of the latent variables.
  - p(z) is the prior – our beliefs about the latent variables before seeing any data.
  - p(x) is the evidence (also called the marginal likelihood) – the overall probability of the observed data under the model.

The evidence p(x) is calculated by integrating (or summing) over all possible latent variables:

$p(x) = \int p(x|z)\, p(z)\, dz$

This integral is often intractable because:

  1. The latent space z can be extremely high-dimensional (thousands or millions of dimensions).
  2. The integrand p(x|z)p(z) rarely has a closed form we can integrate analytically – for example, when the likelihood p(x|z) is defined by a neural network, as in modern generative models.
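
To make the intractability concrete, here is a minimal numerical sketch of the naive approach: estimate p(x) = ∫ p(x|z)p(z)dz by averaging p(x|z) over samples drawn from the prior. The toy model below (a standard Gaussian prior over a 1000-dimensional z and a simple Gaussian likelihood for one scalar observation) and all names in it (w, x_obs, log_likelihood) are assumptions chosen purely for illustration, not part of any standard recipe.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy model (an assumption for illustration, not from the text):
#   prior:       p(z)   = N(0, I) over d latent dimensions
#   likelihood:  p(x|z) = N(x; w^T z, 1) for a single scalar observation
d = 1000                                  # latent dimensionality
w = rng.normal(size=d) / np.sqrt(d)       # fixed "model weights"
x_obs = 0.7                               # the single observed data point

def log_likelihood(z):
    """log p(x_obs | z): Gaussian with mean w^T z and unit variance."""
    mean = z @ w
    return -0.5 * (x_obs - mean) ** 2 - 0.5 * np.log(2 * np.pi)

# Naive Monte Carlo estimate of the evidence:
#   p(x) = E_{z ~ p(z)}[p(x|z)]  ≈  (1/S) * sum_s p(x | z_s),   z_s ~ p(z)
S = 10_000
z_samples = rng.normal(size=(S, d))       # draws from the prior N(0, I)
evidence_estimate = np.exp(log_likelihood(z_samples)).mean()

print(f"Monte Carlo estimate of p(x): {evidence_estimate:.4f}")
# With one broad Gaussian likelihood this estimate happens to be usable, but for
# realistic models (many observations, neural-network likelihoods) p(x|z) is
# sharply peaked in a region the prior almost never samples, and neither this
# naive estimator nor a closed-form calculation is available in practice.
```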

Intuition Break: Imagine x is a picture of a cat, and z represents all possible “cat-ness” features (pose, fur color, lighting, background). The true posterior p(z|x) tells us the probability of all combinations of these features given this specific cat picture. Calculating the denominator p(x) requires summing up the probability of this specific picture arising from every single conceivable combination of features, weighted by how likely those features are a priori. This is like trying to calculate the exact probability of finding a specific seashell on a beach by considering every possible wave, tide, and creature interaction that could have brought it there – computationally infeasible!

Since we can’t compute p(x), we can’t compute the true posterior p(z|x). We need an approximation.

Variational Inference: The Core Idea

Instead of calculating the intractable p(z|x) directly, VI cleverly reframes the inference problem as an optimization problem.

  1. Choose a Family of Approximating Distributions: We define a family of simpler, tractable distributions over the latent variables, denoted by q(z; ϕ). This family is called the variational family, and it’s governed by parameters ϕ (the variational parameters). A very common choice is the Gaussian family, where ϕ might represent the mean and covariance matrix (or just its diagonal): q(z; ϕ) = 𝒩(z|μ, Σ). The key properties of this family are that we can easily evaluate its density, draw samples from it, and reshape it by adjusting ϕ (see the code sketch after this list).

Intuition Break: Since we can’t grasp the true, potentially very weirdly shaped posterior p(z|x), we decide to approximate it with a familiar, manageable shape. Think of the true posterior as a complex, jagged mountain range. We choose a simpler shape, like a smooth hill (e.g., a Gaussian), represented by q(z; ϕ). The parameters ϕ control the location (μ) and spread (Σ) of our hill. Our task is to find the best possible hill – the one whose shape most closely matches the true mountain range.

  2. Find the Best Approximation: Our goal is to find the specific member of this family, i.e., the setting of ϕ for which q(z; ϕ) is “closest” to the true posterior p(z|x).
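
To ground step 1, here is a minimal sketch of what a member of the variational family can look like in code: a diagonal Gaussian q(z; ϕ) with ϕ = (μ, log σ). The class and method names are my own invention rather than any library’s API; the point is only that the two operations VI needs – drawing samples and evaluating log q(z; ϕ) – are cheap.

```python
import numpy as np

class DiagonalGaussian:
    """Minimal diagonal-Gaussian variational family q(z; phi) = N(mu, diag(sigma^2)).

    phi = (mu, log_sigma); parameterizing the scale on the log scale keeps
    sigma positive while letting an optimizer move freely over the reals.
    """

    def __init__(self, mu, log_sigma):
        self.mu = np.asarray(mu, dtype=float)
        self.log_sigma = np.asarray(log_sigma, dtype=float)

    @property
    def sigma(self):
        return np.exp(self.log_sigma)

    def sample(self, n, rng):
        """Draw n samples via z = mu + sigma * eps with eps ~ N(0, I)."""
        eps = rng.normal(size=(n, self.mu.size))
        return self.mu + self.sigma * eps

    def log_prob(self, z):
        """Evaluate log q(z; phi) for each row of z."""
        var = self.sigma ** 2
        return -0.5 * np.sum(np.log(2 * np.pi * var) + (z - self.mu) ** 2 / var,
                             axis=-1)

# The two operations VI relies on: cheap sampling and cheap density evaluation.
rng = np.random.default_rng(0)
q = DiagonalGaussian(mu=np.zeros(3), log_sigma=np.zeros(3))
z = q.sample(5, rng)
print(q.log_prob(z))    # one log-density per sample
```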

Measuring Closeness: The KL Divergence

How do we formally measure the “closeness” between our approximation q(z; ϕ) and the true posterior p(z|x)? We use the Kullback-Leibler (KL) divergence:

$$ KL(q(z; \phi) || p(z|x)) = \int q(z; \phi) \log \frac{q(z; \phi)}{p(z|x)} dz = \mathbb{E}_{q(z; \phi)}\left[ \log q(z; \phi) - \log p(z|x) \right] $$
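
As a sanity check on what this expectation means, the sketch below estimates the KL divergence by Monte Carlo, i.e., by averaging log q(z) − log p(z) over samples from q. One important caveat: in a real problem p(z|x) is exactly the thing we cannot evaluate, so this toy substitutes a known Gaussian as a stand-in for the posterior; both distributions here are arbitrary choices made purely for illustration.

```python
import numpy as np
from scipy.stats import multivariate_normal

# Stand-ins chosen only for illustration -- in a real problem p(z|x) is exactly
# what we cannot evaluate, so here two known Gaussians play the roles of
# q(z; phi) (our approximation) and p(z|x) (a pretend "true posterior").
d = 2
mu_q, sigma_q = np.zeros(d), 1.0
mu_p, sigma_p = np.ones(d), 2.0
q = multivariate_normal(mean=mu_q, cov=sigma_q**2 * np.eye(d))
p = multivariate_normal(mean=mu_p, cov=sigma_p**2 * np.eye(d))

# Monte Carlo version of  KL(q || p) = E_{z ~ q}[log q(z) - log p(z)]
z = q.rvs(size=50_000, random_state=0)
kl_mc = np.mean(q.logpdf(z) - p.logpdf(z))

# Closed form for two isotropic Gaussians, as a sanity check on the estimate
kl_exact = d * (np.log(sigma_p / sigma_q)
                + (sigma_q**2 + ((mu_q - mu_p) ** 2).mean()) / (2 * sigma_p**2)
                - 0.5)

print(f"Monte Carlo KL: {kl_mc:.4f}   closed form: {kl_exact:.4f}")
```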