This post was inspired by Your Classifier is Secretly an Energy Based Model and You Should Treat it Like One. It made me rethink how image classifiers work—it’s much more nuanced than you might suspect.
Let’s say you have an MNIST classifier $f$. For some wild reason, instead of classifying an image, you want to use it to generate one. For example, a nice, clean image of a zero. How would you do it?
Naive approach: direct input optimization on class probabilities
Here’s the obvious first attempt: you start with a random image $x = \text{torch.randn()}$ and a target class label $y = [1, 0, \dots, 0]$. Feed $x$ into $f$, calculate the loss (e.g., cross-entropy) between the predicted label $\hat y$ and $y$, and backpropagate to update $x$ so it looks more “zero-like.” Repeat this process until $f$ confidently says that $x$ looks like a zero.
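In PyTorch, that loop might look roughly like this. It’s a minimal sketch: `f` stands for your trained classifier (any model mapping a `(1, 1, 28, 28)` tensor to 10 logits will do), and the learning rate and step count are arbitrary.

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 28, 28, requires_grad=True)  # start from pure noise
target = torch.tensor([0])                          # we want a "zero"
optimizer = torch.optim.Adam([x], lr=0.05)

for step in range(500):
    optimizer.zero_grad()
    logits = f(x)                            # f: your trained MNIST classifier
    loss = F.cross_entropy(logits, target)   # loss on the softmax probabilities
    loss.backward()                          # gradients flow all the way into the pixels of x
    optimizer.step()
```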
Sounds good, right? The problem is: it probably won’t look like a zero to you!
The disappointment of every ML student: classified as a zero with 99.9% certainty…
Why the naive approach fails: softmax normalization
There’s only one reason this doesn’t work: the softmax operation, which normalizes logits into probabilities. Have you ever thought about what softmax does exactly? Here’s the breakdown:
Right before the final linear layer, the model has transformed the input image $x$ into a so-called image embedding. The weights $W$ of the linear layer act as a collection of class centroids (like in K-means clustering), where each column vector $W_i$ represents such a centroid or “prototype” of class $i$ in feature space. The dot product between the image embedding $e$ and $W_i$ gives a similarity score that tells you how likely it is that $x$ belongs to class $i$.1 But you probably know this similarity vector by the name logits.
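To make this concrete, here’s a toy sketch of that final layer (the embedding size of 128 is just an illustrative choice):

```python
import torch

e = torch.randn(128)      # image embedding produced by the backbone
W = torch.randn(128, 10)  # column W[:, i] is the learned prototype of class i
b = torch.zeros(10)       # per-class bias

logits = e @ W + b        # logits[i] = similarity between e and prototype i
```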
Next, softmax bundles all these independent expert opinions by means of normalization. In a way, this is cheating: it assumes that $x$ must belong to one of the known classes. Normalization simply finds which class is the most likely match for the given image. For out-of-distribution images, this assumption no longer holds, and the model seems to fail entirely.
For example: a rubbish image may get classified as a zero despite the 0-expert saying there’s only a 1% chance it’s a zero, simply because all other experts said there’s less than a 1% chance it’s their class. Going back to the K-means clustering analogy, it’s the same as finding the nearest centroid without considering the actual distance to it.
Normalization effectively ignores whether any expert thinks the image is realistic. This is why optimization on class probabilities fails: softmax cares about relative confidence, not absolute plausibility.
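A quick numerical sanity check of this effect, with made-up logits where every expert gives the image a low score:

```python
import torch

# Every expert scores the image poorly; the 0-expert is merely the least unconvinced.
logits = torch.tensor([-2.0] + [-6.0] * 9)
probs = torch.softmax(logits, dim=0)
print(probs[0])  # ~0.86: a confident "zero", even though nobody liked the image
```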
A classifier without softmax is just an energy-based model
To fix this, we need to stop relying on softmax probabilities and instead look at the raw logits. These logits give us two crucial pieces of information:
- How strongly the target expert thinks $x$ belongs to its class.
- How realistic the image is overall, regardless of class.
For point 2, we can use the LogSumExp of the logits to get an unnormalized estimate of $\log p(x)$, the data distribution. This works because every expert can be considered an energy-based model with energy $E_{x,y_i} = -\text{logits}_i$, so that $p(x, y = i) \propto \exp(-E_{x,y_i})$. Marginalizing over the classes, $p(x) = \sum_i p(x, y = i)$, we find that $\log p(x) = \log \sum_i \exp(-E_{x,y_i})$ up to an additive constant; in other words, the LogSumExp of the logits.
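In code, this estimate is a one-liner (assuming `logits` is the classifier’s raw, pre-softmax output):

```python
import torch

def log_px(logits):
    # Unnormalized estimate of log p(x): the LogSumExp of the logits.
    # Higher values mean the classifier finds x more plausible overall.
    return torch.logsumexp(logits, dim=-1)
```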
But wait, can a classifier do this out-of-the-box? It might seem that it’s not trained to do this, but that’s not true. You see, the final matrix multiplication computes every expert’s score independently: each logit depends only on its own prototype $W_i$. And if all experts agree that $x$ is not in their class, we must conclude that $x$ lies outside the data distribution.
Naive approach revisited: direct input optimization on logits
Let’s take another look at our naive approach from the perspective of logits. Suppose we use a cross-entropy loss with a one-hot label at class $i$. Writing our loss in terms of logits, we find:
\[\text{loss} = -\text{logits}_i + \text{LogSumExp}(\text{logits})\]
Minimizing this loss means
- Maximizing the logit from our desired class
- Minimizing the LogSumExp(logits)
Notice that the second term actively discourages realistic images! During training, this term makes sense: we want to push away all non-target experts. But for generation, this is clearly the wrong loss to optimize.
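You can verify this decomposition numerically with random logits:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(1, 10)
i = 0  # target class
ce = F.cross_entropy(logits, torch.tensor([i]))
decomposed = -logits[0, i] + torch.logsumexp(logits[0], dim=0)
print(torch.allclose(ce, decomposed))  # True
```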
Let’s make the simple modification of encouraging realistic images:
\[\text{loss} = -\text{logits}_i - \text{LogSumExp}(\text{logits})\]
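In code, this is a one-line change to the earlier optimization sketch (reusing the same hypothetical `f`, `x`, and optimizer):

```python
for step in range(500):
    optimizer.zero_grad()
    logits = f(x)                                              # shape (1, 10)
    loss = -logits[0, 0] - torch.logsumexp(logits[0], dim=0)   # target class 0
    loss.backward()
    optimizer.step()
```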
Using this loss, we can finally generate a somewhat recognizable zero!2
An insanely realistic image that’s 100% a zero—according to $f$, at least
New problem: sampling from an energy-based model
Even after all this trouble, there’s still one major challenge left: with an estimate of $\log p(x)$ in hand, optimizing $x$ for maximum likelihood doesn’t necessarily give us realistic samples. Instead, we need to sample from the typical set, the region of high probability mass. What does that mean? How do we do this exactly? That story leads us into the territory of denoising diffusion models and their sampler schedules, so I’ll leave it at that for now. But for the curious minds, I can already highly recommend this blog post by Sander Dieleman on typicality.
Conclusion
Your classifier isn’t just a probability predictor; it’s secretly a collection of energy-based models, buried under a softmax. By treating it as such, we gain a much deeper understanding of what the model is actually doing. With a dead-simple modification to the loss, we can turn the classifier into a generator that produces images that are not just classified correctly but also look realistic. A humble reminder that solutions come easy to those who truly understand the problem.
Food for thought: I suspect adversarial attacks exploit the exact same mechanism. I wonder what they look like under the LogSumExp lens!
1. In this dot product, the vectors $e$ and $W_i$ are not necessarily normalized. In fact, the norm of $W_i$ can be used to model the prior probability of the class, together with the bias term $b_i$. This is actually true for all linear operations, and led to techniques like Weight Normalization, which model this weight norm as a separate parameter, thereby improving training stability. ↩
2. This will only work well for models with continuous gradients. Replace operations with discontinuous gradients (e.g., ReLU, MaxPooling) with smooth alternatives (e.g., SoftPlus, AveragePooling). Since we’re dealing with inverse optimization here (trying to find an image that matches a class label), it’s also best to use bijective operations. ↩