Okay, let's break down Variational Inference (VI) step by step, weaving intuition in alongside the mathematical rigor.
The Scenario: Why We Need Approximation
At the heart of many machine learning problems, especially in Bayesian statistics and generative modeling, lies the need to understand the relationship between observed data x and some unobserved latent variables z. These latent variables might represent underlying causes, hidden features, or model parameters we care about. Our goal is to compute the posterior distribution p(z|x), which tells us what we can infer about the hidden variables z after observing the data x.
Bayes’ Theorem gives us the roadmap:
$p(z|x) = \frac{p(x|z) p(z)}{p(x)}$
Where:

- p(x|z) is the likelihood: how probable the observed data is for a given setting of the latent variables.
- p(z) is the prior: what we believe about the latent variables before seeing any data.
- p(x) is the evidence (or marginal likelihood), calculated by integrating (or summing) over all possible latent variables:

$p(x) = \int p(x|z)\, p(z)\, dz$
This integral is often intractable because:

- The latent variables z are typically high-dimensional, so the integral runs over an enormous space.
- For realistic models (e.g., when p(x|z) involves a neural network), the integrand has no closed-form solution.
- Numerical integration does not rescue us: grid-based quadrature scales exponentially with the dimensionality of z, as the sketch below illustrates.
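To make the scaling problem concrete, here is a minimal sketch using a toy conjugate 1D Gaussian model, where the evidence happens to have a closed form we can check against. The model, variable names, and grid sizes are all illustrative assumptions of mine, not part of the original discussion.

```python
# A toy sketch (illustrative assumptions throughout): prior z ~ N(0, 1),
# likelihood x | z ~ N(z, 0.5^2), and a single observation x_obs.
import numpy as np
from scipy.stats import norm

x_obs = 1.3
grid = np.linspace(-6.0, 6.0, 2001)   # dense grid over the 1D latent z
dz = grid[1] - grid[0]

prior = norm.pdf(grid, loc=0.0, scale=1.0)         # p(z)
likelihood = norm.pdf(x_obs, loc=grid, scale=0.5)  # p(x_obs | z)

# p(x) = ∫ p(x|z) p(z) dz, approximated by a Riemann sum over the grid
evidence = np.sum(likelihood * prior) * dz

# This conjugate toy model has a closed-form evidence, x ~ N(0, 1^2 + 0.5^2),
# so we can sanity-check the quadrature:
evidence_exact = norm.pdf(x_obs, loc=0.0, scale=np.sqrt(1.0 + 0.25))
print(evidence, evidence_exact)  # the two values should agree closely

# The catch: with d latent dimensions, the same grid needs 2001**d evaluations,
# which is hopeless for the high-dimensional z used in real models.
```

In one dimension the brute-force sum works fine; the point is that the cost of this approach explodes as soon as z has more than a handful of dimensions.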
Intuition Break: Imagine x is a picture of a cat, and z represents all possible “cat-ness” features (pose, fur color, lighting, background). The true posterior p(z|x) tells us the probability of all combinations of these features given this specific cat picture. Calculating the denominator p(x) requires summing up the probability of this specific picture arising from every single conceivable combination of features, weighted by how likely those features are a priori. This is like trying to calculate the exact probability of finding a specific seashell on a beach by considering every possible wave, tide, and creature interaction that could have brought it there – computationally infeasible!
Since we can’t compute p(x), we can’t compute the true posterior p(z|x). We need an approximation.
Variational Inference: The Core Idea
Instead of calculating the intractable p(z|x) directly, VI cleverly reframes the inference problem as an optimization problem.
Intuition Break: Since we can’t grasp the true, potentially very weirdly shaped posterior p(z|x), we decide to approximate it with a familiar, manageable shape. Think of the true posterior as a complex, jagged mountain range. We choose a simpler shape, like a smooth hill (e.g., a Gaussian), represented by q(z; ϕ). The parameters ϕ control the location (μ) and spread (Σ) of our hill. Our task is to find the best possible hill – the one whose shape most closely matches the true mountain range.
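As a minimal sketch of what "choosing a hill" looks like in code, here is an illustrative diagonal-Gaussian variational family with parameters ϕ = (μ, log σ). The class and method names are my own assumptions for illustration, not from any particular library.

```python
# An illustrative variational family q(z; phi): a diagonal Gaussian whose
# parameters phi = (mu, log_sigma) are the knobs we will later optimize.
import numpy as np


class DiagonalGaussian:
    """q(z; phi) = N(z; mu, diag(sigma^2)) -- the 'smooth hill'."""

    def __init__(self, mu, log_sigma):
        self.mu = np.asarray(mu, dtype=float)                 # location of the hill
        self.log_sigma = np.asarray(log_sigma, dtype=float)   # (log of) its spread

    def sample(self, n, rng=None):
        """Draw n samples z ~ q(z; phi)."""
        rng = np.random.default_rng() if rng is None else rng
        eps = rng.standard_normal((n, self.mu.size))
        return self.mu + np.exp(self.log_sigma) * eps

    def log_prob(self, z):
        """Evaluate log q(z; phi) for each row of z."""
        var = np.exp(2.0 * self.log_sigma)
        return np.sum(
            -0.5 * np.log(2.0 * np.pi * var) - 0.5 * (z - self.mu) ** 2 / var,
            axis=-1,
        )


# phi = (mu, log_sigma) will be tuned so that q(z; phi) matches p(z|x) as
# closely as possible.
q = DiagonalGaussian(mu=[0.0, 0.0], log_sigma=[0.0, 0.0])
z = q.sample(5)
print(q.log_prob(z))
```

Parameterizing the spread as log σ rather than σ is a common convenience: any real value of log σ maps to a valid positive σ, so the optimization over ϕ stays unconstrained.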
Measuring Closeness: The KL Divergence
How do we formally measure the “closeness” between our approximation q(z; ϕ) and the true posterior p(z|x)? We use the Kullback-Leibler (KL) divergence:
$$ KL(q(z; \phi) || p(z|x)) = \int q(z; \phi) \log \frac{q(z; \phi)}{p(z|x)} dz = \mathbb{E}_{q(z; \phi)}\left[ \log q(z; \phi) - \log p(z|x) \right] $$
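Because the KL divergence is an expectation under q, it can be estimated from samples of q. Below is a minimal sketch of that expectation form; both q and p are deliberately tractable 1D Gaussians of my own choosing so the Monte Carlo estimate can be checked against the exact KL. In real VI the target p(z|x) cannot be evaluated, which is precisely what motivates rewriting this objective into a form that does not require it.

```python
# Monte Carlo estimate of KL(q || p) ≈ mean over samples z ~ q of
# [log q(z) - log p(z)], checked against the closed-form Gaussian KL.
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)

mu_q, sigma_q = 0.5, 1.2   # parameters phi of the approximation q(z; phi)
mu_p, sigma_p = 0.0, 1.0   # a stand-in, tractable "target" distribution

z = rng.normal(mu_q, sigma_q, size=200_000)       # z ~ q(z; phi)
kl_mc = np.mean(norm.logpdf(z, mu_q, sigma_q)     # log q(z; phi)
                - norm.logpdf(z, mu_p, sigma_p))  # - log p(z)

# Closed form for the KL between two univariate Gaussians, for comparison:
kl_exact = (np.log(sigma_p / sigma_q)
            + (sigma_q**2 + (mu_q - mu_p)**2) / (2.0 * sigma_p**2)
            - 0.5)

print(kl_mc, kl_exact)  # the two values should be close
```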