Three complementary/equivalent formulations of Diffusion models#

The field of generative modelling using Deep Learning (DL) is one of the most active and ebullient areas in Machine Learning (ML) these days. Apart from the models used in the context of Natural Language Processing, which are mainly based on Transformers and Recurrent Networks, most theoretical and applied work focuses on one of the following four architectures/formulations:

  • Variational Auto-encoders (VAEs)

  • Generative Adversarial Networks (GANs)

  • Normalizing Flows (NF)

  • Diffusion Models (DM)

Image taken from [1]

VAEs and GANs have been part of the state of the art for many years, and although they remain at the core of new approaches, there is a consensus in the academic community that their development has reached a plateau. Nevertheless, VAEs and their many variants are still widely used, alone or as part of more complex systems for image/audio generation, as tools to embed the input objects into a latent space where the generation process can be performed more efficiently, and then decode the generated latent vector into a final object in the input space (see Audioldm for an example). They are also a subject of study in the very interesting and vibrant field of physics-informed Neural Networks [2] and others [3].

On the other hand, although NF (and Neural Network-based variants such as Autoregressive Flows) and DM have also been part of the state of the art for many years, their applicability and potential for solving new theoretical and practical problems are at the core of much ongoing research.

This notebook is intended to review Diffusion Models in detail, covering basic formulations and implementations, in order to understand the basics of DM and how the multiple available formulations connect to each other, and to open the door to the advanced topics listed at the end of the lecture.

Before starting with the basic definitions of DM, we are going to review fundamental concepts about Bayesian learning and variational inference, required to understand the mathematical formulation of DMs.

1. Evidence Lower Bound#

In many ML problems, we can assume that the observed data is generated by an associated unseen latent variable \(z\). Therefore, the observed variable \(x\) and the latent variable can be modeled by the joint distribution \(p(x,z)\). There are two ways we can manipulate this joint distribution to recover the likelihood of the observed data alone, \(p(x)\):

  • Marginalize out the latent variable \(z\): \(p(x) = \int p(x,z)dz\)

  • Use the chain rule of probability: \(p(x) = \frac{p(x,z)}{p(z|x)}\)

In order to maximize the likelihood \(p(x)\) (the most common aim in ML training), each of the two options has its own challenges:

  • The first option is difficult because it involves integrating out all latent variables \(z\), which is intractable for complex models.

  • The second option involves having access to a ground truth latent encoder \(p(z|x)\), which is not available.

Combining the two former equations we can derive a proxy objective, called Evidence Lower Bound (ELBO), with which to optimize a latent variable model [4].

Let \(q_{\phi}(z|x)\) be a flexible approximate variational distribution with parameters \(\phi\) that we seek to optimize; the ELBO can be derived as follows (the inequality in the last step is Jensen's inequality applied to the concave logarithm):

\[\begin{split} \begin{split} \log p(x) & = \log \int p(x,z) dz\\ & = \log \int \frac{p(x,z)q_{\phi}(z|x)}{q_{\phi}(z|x)} dz\\ & = \log \mathbb{E}_{q_{\phi}(z|x)} \left[\frac{p(x,z)}{q_{\phi}(z|x)}\right]\\ & \geq \mathbb{E}_{q_{\phi}(z|x)} \left[\log \frac{p(x,z)}{q_{\phi}(z|x)}\right] \end{split} \end{split}\]

This derivation of the ELBO does not give us a clear idea of why we want to maximize it as an objective. Let’s look at an alternative derivation:

\[\begin{split} \begin{split} \log p(x) & = \log p(x) \int q_{\phi}(z|x)dz\\ & = \int q_{\phi}(z|x)(\log p(x))dz\\ & = \mathbb{E}_{q_{\phi}(z|x)} \left[\log p(x)\right]\\ & = \mathbb{E}_{q_{\phi}(z|x)} \left[\log \frac{p(x,z)}{p(z|x)}\right]\\ & = \mathbb{E}_{q_{\phi}(z|x)} \left[\log \frac{p(x,z)}{p(z|x)}\frac{q_{\phi}(z|x)}{q_{\phi}(z|x)}\right]\\ & = \mathbb{E}_{q_{\phi}(z|x)} \left[\log \frac{p(x,z)}{q_{\phi}(z|x)}\right] + \mathbb{E}_{q_{\phi}(z|x)} \left[\log \frac{q_{\phi}(z|x)}{p(z|x)}\right]\\ & = \mathbb{E}_{q_{\phi}(z|x)} \left[\log \frac{p(x,z)}{q_{\phi}(z|x)}\right] + D_{KL}(q_{\phi}(z|x)\|p(z|x))\\ & \geq \mathbb{E}_{q_{\phi}(z|x)} \left[\log \frac{p(x,z)}{q_{\phi}(z|x)}\right] \end{split} \end{split}\]

From this derivation it is easy to see that the ELBO is a lower bound on the evidence: since the KL term is non-negative, the ELBO can never exceed the evidence.

In order to obtain a model that represents the latent structure of the data, the goal is to learn the parameters of the variational posterior \(q_{\phi}(z|x)\) such that it matches the true posterior distribution \(p(z|x)\), which is achieved by minimizing the KL divergence term. Unfortunately, this minimization cannot be done directly because we do not have access to the ground truth distribution \(p(z|x)\). However, since the left-hand side of the former equation (the evidence) does not depend on \(\phi\), the two terms on the right-hand side always sum to a constant with respect to \(\phi\). Therefore, any maximization of the ELBO with respect to \(\phi\) necessarily invokes an equal minimization of the KL divergence term [4].
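The identity \(\log p(x) = \text{ELBO} + D_{KL}\) can be checked numerically on a toy model. The sketch below assumes a hypothetical latent variable with three discrete states and made-up probability tables (none of these numbers come from the text); it only illustrates the decomposition.

```python
import math

# Hypothetical toy model: z takes 3 discrete values, x is one fixed observation.
p_z = [0.5, 0.3, 0.2]             # prior p(z) (assumed values)
p_x_given_z = [0.10, 0.40, 0.70]  # likelihood p(x|z) of the observed x (assumed values)
q_z = [0.2, 0.3, 0.5]             # an arbitrary variational distribution q(z|x)

# Joint p(x,z), evidence p(x) and true posterior p(z|x)
p_xz = [pz * px for pz, px in zip(p_z, p_x_given_z)]
p_x = sum(p_xz)
post = [v / p_x for v in p_xz]

# ELBO = E_q[log p(x,z) - log q(z|x)]
elbo = sum(q * (math.log(j) - math.log(q)) for q, j in zip(q_z, p_xz))
# KL(q(z|x) || p(z|x))
kl = sum(q * (math.log(q) - math.log(p)) for q, p in zip(q_z, post))

log_evidence = math.log(p_x)
print(f"log p(x) = {log_evidence:.4f}, ELBO = {elbo:.4f}, KL = {kl:.4f}")
```

Setting `q_z` equal to `post` drives the KL term to zero, at which point the ELBO coincides with the log-evidence.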

Variational Autoencoders#

As an example, we are going to use the former ELBO definition to derive the objective function used in VAEs, which is a well-known model.

\[\begin{split} \begin{split} \mathbb{E}_{q_{\phi}({\bf{z}}|{\bf{x}})} \left[\log \frac{p({\bf{x}},{\bf{z}})}{q_{\phi}({\bf{z}}|{\bf{x}})}\right] & = \mathbb{E}_{q_{\phi}({\bf{z}}|{\bf{x}})} \left[\log \frac{p_{\theta}({\bf{x}}|{\bf{z}})p(\bf{z})}{q_{\phi}({\bf{z}}|{\bf{x}})}\right]\\ & = \mathbb{E}_{q_{\phi}({\bf{z}}|{\bf{x}})} \left[\log p_{\theta}({\bf{x}}|{\bf{z}})\right] + \mathbb{E}_{q_{\phi}({\bf{z}}|{\bf{x}})} \left[\log \frac{p({\bf{z}})}{q_{\phi}({\bf{z}}|{\bf{x}})}\right]\\ & = \mathbb{E}_{q_{\phi}({\bf{z}}|{\bf{x}})} \left[\log p_{\theta}({\bf{x}}|{\bf{z}})\right] - D_{KL}(q_{\phi}({\bf{z}}|{\bf{x}})\|p({\bf{z}})) \end{split} \end{split}\]

The first term is the reconstruction term (the expected log-likelihood of the data under the decoder) and the second term is the KL divergence between the variational distribution and the prior distribution, which is typically assumed to be \(\mathcal{N}({\bf{0}},{\bf{I}})\). During training, the model simultaneously learns the intermediate bottlenecking distribution \(q_{\phi}({\bf{z}}|{\bf{x}})\), which can be treated as an encoder, and the parameters \(\theta\) of the function \(p_{\theta}({\bf{x}}|{\bf{z}})\) that converts a given latent vector \(\bf{z}\) into an observation \(\bf{x}\) (called the decoder).
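For a diagonal Gaussian encoder and a standard normal prior, the KL term has a closed form, and the reconstruction term is commonly estimated from a single decoded sample. A minimal sketch of the resulting negative ELBO follows. It assumes a Gaussian decoder with identity covariance (so the log-likelihood reduces to a squared error up to a constant), and the tensors `mu`, `logvar` and `x_recon` are hypothetical stand-ins for the outputs of some encoder and decoder.

```python
import torch

def negative_elbo(x, x_recon, mu, logvar):
    """Negative ELBO averaged over the batch (up to an additive constant).

    Assumes q(z|x) = N(mu, diag(exp(logvar))), p(z) = N(0, I), and a
    Gaussian decoder with identity covariance (an assumption of this sketch).
    """
    # -E_q[log p(x|z)] up to a constant: squared reconstruction error
    recon = 0.5 * ((x - x_recon) ** 2).flatten(1).sum(dim=1)
    # Closed-form KL(N(mu, sigma^2) || N(0, I))
    kl = 0.5 * (mu ** 2 + logvar.exp() - 1.0 - logvar).sum(dim=1)
    return (recon + kl).mean()

# Toy usage with random tensors standing in for real encoder/decoder outputs
x = torch.randn(4, 3, 8, 8)
mu, logvar = torch.zeros(4, 16), torch.zeros(4, 16)
loss = negative_elbo(x, x.clone(), mu, logvar)
print(loss.item())
```

With a perfect reconstruction and an encoder that outputs exactly the prior, both terms vanish; moving `mu` away from zero increases only the KL term.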

2. Denoising Diffusion Probabilistic Models (DDPM)#

Image taken from [5]

A diffusion probabilistic model is composed of two Markov chains:

  • A forward (noising) process, which is a Markov chain that gradually adds noise to the data until the signal is destroyed.

  • A reverse (denoising) process, which is a parameterized Markov chain whose transitions are learned to reverse the diffusion process.

A DDPM is a simplified Hierarchical Variational Autoencoder (or Recursive VAE) [4], where the latent dimension is exactly equal to the data dimension, the structure of the latent encoder is not learned and the Gaussian parameters of the latent encoders vary over time.

Let \({\bf{x}}_t\) be a latent variable, where \(t=0\) represents true data samples and \(t\in[1,T]\) represents a corresponding latent with hierarchy indexed by \(t\). By considering the Markov condition, the posterior distribution of the forward process can be written as:

\[ q({\bf{x}}_{1:T}|{\bf{x}}_0) = \prod_{t=1}^T q({\bf{x}}_t|{\bf{x}}_{t-1}) \]

As pointed out before, the structure of the latent encoder at each timestep \(t\) is not learned; it is fixed as a linear Gaussian model, where the mean and standard deviation are typically set to predefined values that change deterministically over time (they can also be treated as parameters and learned during training [6]). Thus, the Gaussian encoder is defined to preserve the variance during the forward process:

\[ q({\bf{x}}_t|{\bf{x}}_{t-1}) = \mathcal{N}({\bf{x}}_{t};\sqrt{\alpha_t}{\bf{x}}_{t-1},(1-\alpha_t){\bf{I}}) \]

The joint distribution of the DDPM is given by:

\[ p({\bf{x}}_{0:T}) = p({\bf{x}}_T)\prod_{t=1}^T p_{\theta}({\bf{x}}_{t-1}|{\bf{x}}_{t}) \]

Here we assume that the evolution of \(\alpha_t\) over time produces a final latent distribution \(p({\bf{x}}_T)\) equal to a standard Gaussian \(\mathcal{N}({\bf{x}}_T;{\bf{0}},{\bf{I}})\).

A notable property of the forward process is that it admits sampling \({\bf{x}}_{t}\) at an arbitrary timestep \(t\) in closed form. Let’s see how.

Using the reparameterization trick, it is easy to see that sampling \({\bf{x}}_t \sim q({\bf{x}}_t|{\bf{x}}_{t-1})\) is equivalent to:

\[{\bf{x}}_t = \sqrt{\alpha_t}{\bf{x}}_{t-1} + \sqrt{1-\alpha_t}\mathbf{\epsilon}_{t-1}; \;\; \mathbf{\epsilon}_{t-1} \sim \mathcal{N}({\bf{0}},{\bf{I}})\]

Now, we would like to get \({\bf{x}}_t\) samples without having to compute \({\bf{x}}_{t-1}\) first, using the true sample \({\bf{x}}_0\) directly. Therefore,

\[\begin{split} \begin{split} {\bf{x}}_t & = \sqrt{\alpha_t}{\bf{x}}_{t-1} + \sqrt{1-\alpha_t}{\mathbf{\epsilon}}_{t-1}\\ & = \sqrt{\alpha_t}\left( \sqrt{\alpha_{t-1}}{\bf{x}}_{t-2} + \sqrt{1-\alpha_{t-1}}\mathbf{\epsilon}_{t-2} \right) + \sqrt{1-\alpha_t}\mathbf{\epsilon}_{t-1}\\ & = \sqrt{\alpha_t\alpha_{t-1}}{\bf{x}}_{t-2} + \underbrace{\sqrt{\alpha_t-\alpha_t\alpha_{t-1}}\mathbf{\epsilon}_{t-2} + \sqrt{1-\alpha_t}\mathbf{\epsilon}_{t-1}}_{\text{Equivalent to the sum of two Gaussians with mean = $\bf{0}$ and different variance}}\\ & = \sqrt{\alpha_t\alpha_{t-1}}{\bf{x}}_{t-2} + \sqrt{\sqrt{\alpha_t-\alpha_t\alpha_{t-1}}^2 + \sqrt{1-\alpha_t}^2}\mathbf{\epsilon}_{t-2}\\ & =\sqrt{\alpha_t\alpha_{t-1}}{\bf{x}}_{t-2} + \sqrt{1-\alpha_t\alpha_{t-1}}\mathbf{\epsilon}_{t-2}\\ & = \cdots\\ & = \sqrt{\prod_{i=1}^t \alpha_i} {\bf{x}}_{0} + \sqrt{1 - \prod_{i=1}^t \alpha_i}\mathbf{\epsilon}_{0}\\ & = \sqrt{\bar{\alpha}_t}{\bf{x}}_{0} + \sqrt{1 -\bar{\alpha}_t}\mathbf{\epsilon}_{0}\\ &\sim \mathcal{N}({\bf{x}}_{t};\sqrt{\bar{\alpha}_t}{\bf{x}}_{0},(1 -\bar{\alpha}_t){\bf{I}}) \end{split} \end{split}\]
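The recursion behind this closed form can be verified with a few lines of arithmetic. The sketch below uses a hypothetical four-step schedule (the values are illustrative only) and tracks the coefficient of \({\bf{x}}_0\) and the accumulated noise variance step by step, then compares them with \(\sqrt{\bar{\alpha}_t}\) and \(1-\bar{\alpha}_t\).

```python
import math

# Hypothetical schedule: any alphas in (0, 1) would do for this check.
alphas = [1.0 - b for b in (0.01, 0.02, 0.05, 0.10)]

# Track x_t = c * x_0 + sqrt(v) * eps through the single-step recursion
c, v = 1.0, 0.0
for a in alphas:
    c = math.sqrt(a) * c   # the signal coefficient is scaled by sqrt(alpha_t)
    v = a * v + (1.0 - a)  # old noise variance scaled by alpha_t, plus new noise 1 - alpha_t

alpha_bar = math.prod(alphas)
print(c, math.sqrt(alpha_bar))  # both are the coefficient of x_0
print(v, 1.0 - alpha_bar)       # both are the noise variance
```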

Now, let’s try to apply the ELBO to DDPMs in order to find a loss function to optimize [4]:

\[\begin{split} \begin{split} \log p({\bf{x}}_{0}) & = \log \mathbb{E}_{q({\bf{x}}_{1:T}|{\bf{x}}_0)}\left[\frac{p({\bf{x}}_{0:T})}{q({\bf{x}}_{1:T}|{\bf{x}}_0)}\right]\\ & \geq \mathbb{E}_{q({\bf{x}}_{1:T}|{\bf{x}}_0)}\left[\log \frac{p({\bf{x}}_{0:T})}{q({\bf{x}}_{1:T}|{\bf{x}}_0)}\right]\\ & = \mathbb{E}_{q({\bf{x}}_{1:T}|{\bf{x}}_0)}\left[\log \frac{p({\bf{x}}_T)\prod_{t=1}^T p_{\theta}({\bf{x}}_{t-1}|{\bf{x}}_{t})}{\prod_{t=1}^T q({\bf{x}}_t|{\bf{x}}_{t-1})}\right]\\ & = \mathbb{E}_{q({\bf{x}}_{1:T}|{\bf{x}}_0)}\left[\log \frac{p({\bf{x}}_T)p_{\theta}({\bf{x}}_0|{\bf{x}}_1)\prod_{t=2}^{T} p_{\theta}({\bf{x}}_{t-1}|{\bf{x}}_{t})}{q({\bf{x}}_T|{\bf{x}}_{T-1})\prod_{t=1}^{T-1} q({\bf{x}}_t|{\bf{x}}_{t-1})}\right]\\ & = \mathbb{E}_{q({\bf{x}}_{1:T}|{\bf{x}}_0)}\left[\log \frac{p({\bf{x}}_T)p_{\theta}({\bf{x}}_0|{\bf{x}}_1)\prod_{t=1}^{T-1} p_{\theta}({\bf{x}}_{t}|{\bf{x}}_{t+1})}{q({\bf{x}}_T|{\bf{x}}_{T-1})\prod_{t=1}^{T-1} q({\bf{x}}_t|{\bf{x}}_{t-1})}\right]\\ &=\mathbb{E}_{q({\bf{x}}_{1:T}|{\bf{x}}_0)}\left[\log \frac{p({\bf{x}}_T)p_{\theta}({\bf{x}}_0|{\bf{x}}_1)}{q({\bf{x}}_T|{\bf{x}}_{T-1})}\right] + \mathbb{E}_{q({\bf{x}}_{1:T}|{\bf{x}}_0)}\left[\log \prod_{t=1}^{T-1} \frac{p_{\theta}({\bf{x}}_{t}|{\bf{x}}_{t+1})}{q({\bf{x}}_t|{\bf{x}}_{t-1})}\right]\\ &= \mathbb{E}_{q({\bf{x}}_{1:T}|{\bf{x}}_0)}\left[\log p_{\theta}({\bf{x}}_0|{\bf{x}}_1)\right] + \mathbb{E}_{q({\bf{x}}_{1:T}|{\bf{x}}_0)}\left[\log \frac{p({\bf{x}}_{T})}{q({\bf{x}}_T|{\bf{x}}_{T-1})}\right] + \mathbb{E}_{q({\bf{x}}_{1:T}|{\bf{x}}_0)}\left[ \sum_{t=1}^{T-1} \log \frac{p_{\theta}({\bf{x}}_{t}|{\bf{x}}_{t+1})}{q({\bf{x}}_t|{\bf{x}}_{t-1})} \right]\\ &=\mathbb{E}_{q({\bf{x}}_{1}|{\bf{x}}_0)}\left[\log p_{\theta}({\bf{x}}_0|{\bf{x}}_1)\right] + \mathbb{E}_{q({\bf{x}}_{T-1},{\bf{x}}_{T}|{\bf{x}}_0)}\left[\log \frac{p({\bf{x}}_{T})}{q({\bf{x}}_T|{\bf{x}}_{T-1})}\right] + \sum_{t=1}^{T-1} \mathbb{E}_{q({\bf{x}}_{t-1},{\bf{x}}_{t},{\bf{x}}_{t+1}|{\bf{x}}_0)}\left[\log 
\frac{p_{\theta}({\bf{x}}_{t}|{\bf{x}}_{t+1})}{q({\bf{x}}_t|{\bf{x}}_{t-1})} \right]\\ & = \underbrace{\mathbb{E}_{q({\bf{x}}_{1}|{\bf{x}}_0)}\left[\log p_{\theta}({\bf{x}}_0|{\bf{x}}_1)\right]}_{\text{reconstruction term}} - \underbrace{\mathbb{E}_{q({\bf{x}}_{T-1}|{\bf{x}}_0)}\left[D_{KL}(q({\bf{x}}_{T}|{\bf{x}}_{T-1})\|p({\bf{x}}_{T}))\right]}_{\text{prior matching term}} - \sum_{t=1}^{T-1} \underbrace{\mathbb{E}_{q({\bf{x}}_{t-1},{\bf{x}}_{t+1}|{\bf{x}}_0)}\left[D_{KL}(q({\bf{x}}_t|{\bf{x}}_{t-1})\|p_{\theta}({\bf{x}}_{t}|{\bf{x}}_{t+1}))\right]}_{\text{consistency term}} \end{split} \end{split}\]

The terms that constitute the former ELBO can be interpreted as [4]:

  • The reconstruction term predicts the log probability of the original data sample given the first-step latent. This is equivalent to the reconstruction term in the VAE formulation.

  • The prior matching term is also similar to that of the VAE and is minimized when the final latent distribution is Gaussian. But, unlike in the VAE, this term requires no optimization, as it has no trainable parameters.

  • The consistency term attempts to make the distribution at \({\bf{x}}_{t}\) consistent, from both forward and backward processes. “That is, a denoising step from a noisier image should match the corresponding noising step from a cleaner image, for every intermediate timestep; this is reflected mathematically by the KL Divergence”.

\[\require{cancel}\]

The consistency term in the former formulation, although interpretable, is hard to compute since it requires computing an expectation over two random variables \(\{{\bf{x}}_{t-1},{\bf{x}}_{t+1}\}\) for every time step. Moreover, the variance of the ELBO may be large, as it is computed by summing up \(T-1\) terms. However, there is an alternative formulation that uses a mathematical trick based on the Markov property of the process. The key insight is that the encoder transitions satisfy \(q({\bf{x}}_{t}|{\bf{x}}_{t-1}) = q({\bf{x}}_{t}|{\bf{x}}_{t-1}, {\bf{x}}_{0})\), where the extra conditioning variable is redundant due to the Markov property.

According to the Bayes rule:

\[ q({\bf{x}}_{t}|{\bf{x}}_{t-1}, {\bf{x}}_{0}) = \frac{q({\bf{x}}_{t-1}|{\bf{x}}_{t}, {\bf{x}}_{0})q({\bf{x}}_{t}|{\bf{x}}_{0})}{q({\bf{x}}_{t-1}|{\bf{x}}_{0})} \]

Now, let’s derive the ELBO again [4]:

\[\begin{split} \begin{split} \log p({\bf{x}}_{0}) & \geq \mathbb{E}_{q({\bf{x}}_{1:T}|{\bf{x}}_0)}\left[\log \frac{p({\bf{x}}_{0:T})}{q({\bf{x}}_{1:T}|{\bf{x}}_0)}\right]\\ & = \mathbb{E}_{q({\bf{x}}_{1:T}|{\bf{x}}_0)}\left[\log \frac{p({\bf{x}}_T)\prod_{t=1}^T p_{\theta}({\bf{x}}_{t-1}|{\bf{x}}_{t})}{\prod_{t=1}^T q({\bf{x}}_t|{\bf{x}}_{t-1})}\right]\\ & = \mathbb{E}_{q({\bf{x}}_{1:T}|{\bf{x}}_0)}\left[\log \frac{p({\bf{x}}_T)p_{\theta}({\bf{x}}_0|{\bf{x}}_1)\prod_{t=2}^{T} p_{\theta}({\bf{x}}_{t-1}|{\bf{x}}_{t})}{q({\bf{x}}_1|{\bf{x}}_{0})\prod_{t=2}^{T} q({\bf{x}}_t|{\bf{x}}_{t-1})}\right]\\ & = \mathbb{E}_{q({\bf{x}}_{1:T}|{\bf{x}}_0)}\left[\log \frac{p({\bf{x}}_T)p_{\theta}({\bf{x}}_0|{\bf{x}}_1)\prod_{t=2}^{T} p_{\theta}({\bf{x}}_{t-1}|{\bf{x}}_{t})}{q({\bf{x}}_1|{\bf{x}}_{0})\prod_{t=2}^{T} q({\bf{x}}_t|{\bf{x}}_{t-1}, {\bf{x}}_{0})}\right]\\ & = \mathbb{E}_{q({\bf{x}}_{1:T}|{\bf{x}}_0)}\left[\log \frac{p({\bf{x}}_T)p_{\theta}({\bf{x}}_0|{\bf{x}}_1)}{q({\bf{x}}_1|{\bf{x}}_{0})} + \log \prod_{t=2}^T \frac{p_{\theta}({\bf{x}}_{t-1}|{\bf{x}}_{t})}{q({\bf{x}}_t|{\bf{x}}_{t-1}, {\bf{x}}_{0})}\right]\\ & = \mathbb{E}_{q({\bf{x}}_{1:T}|{\bf{x}}_0)}\left[\log \frac{p({\bf{x}}_T)p_{\theta}({\bf{x}}_0|{\bf{x}}_1)}{q({\bf{x}}_1|{\bf{x}}_{0})} + \log \prod_{t=2}^T \frac{p_{\theta}({\bf{x}}_{t-1}|{\bf{x}}_{t})}{\frac{q({\bf{x}}_{t-1}|{\bf{x}}_{t},{\bf{x}}_{0})q({\bf{x}}_{t}|{\bf{x}}_{0})}{q({\bf{x}}_{t-1}|{\bf{x}}_{0})}}\right]\\ &= \mathbb{E}_{q({\bf{x}}_{1:T}|{\bf{x}}_0)}\left[\log \frac{p({\bf{x}}_T)p_{\theta}({\bf{x}}_0|{\bf{x}}_1)}{q({\bf{x}}_1|{\bf{x}}_{0})} + \log \prod_{t=2}^T \frac{q({\bf{x}}_{t-1}|{\bf{x}}_{0})}{q({\bf{x}}_{t}|{\bf{x}}_{0})} + \log \prod_{t=2}^T \frac{p_{\theta}({\bf{x}}_{t-1}|{\bf{x}}_{t})}{q({\bf{x}}_{t-1}|{\bf{x}}_{t},{\bf{x}}_{0})} \right]\\ &= \mathbb{E}_{q({\bf{x}}_{1:T}|{\bf{x}}_0)}\left[\log \frac{p({\bf{x}}_T)p_{\theta}({\bf{x}}_0|{\bf{x}}_1)}{q({\bf{x}}_1|{\bf{x}}_{0})} + \underbrace{\sum_{t=2}^T \log 
\frac{q({\bf{x}}_{t-1}|{\bf{x}}_{0})}{q({\bf{x}}_{t}|{\bf{x}}_{0})}}_{\text{telescopic summation}} + \sum_{t=2}^T \log \frac{p_{\theta}({\bf{x}}_{t-1}|{\bf{x}}_{t})}{q({\bf{x}}_{t-1}|{\bf{x}}_{t},{\bf{x}}_{0})} \right]\\ &= \mathbb{E}_{q({\bf{x}}_{1:T}|{\bf{x}}_0)}\left[\log \frac{p({\bf{x}}_T)p_{\theta}({\bf{x}}_0|{\bf{x}}_1)}{\cancel{q({\bf{x}}_1|{\bf{x}}_{0})}} + \log \frac{\cancel{q({\bf{x}}_{1}|{\bf{x}}_0)}}{q({\bf{x}}_{T}|{\bf{x}}_0)} + \sum_{t=2}^T \log \frac{p_{\theta}({\bf{x}}_{t-1}|{\bf{x}}_{t})}{q({\bf{x}}_{t-1}|{\bf{x}}_{t},{\bf{x}}_{0})} \right]\\ &= \mathbb{E}_{q({\bf{x}}_{1:T}|{\bf{x}}_0)}\left[\log \frac{p({\bf{x}}_T)p_{\theta}({\bf{x}}_0|{\bf{x}}_1)}{q({\bf{x}}_T|{\bf{x}}_{0})} + \sum_{t=2}^T \log \frac{p_{\theta}({\bf{x}}_{t-1}|{\bf{x}}_{t})}{q({\bf{x}}_{t-1}|{\bf{x}}_{t},{\bf{x}}_{0})} \right]\\ &= \mathbb{E}_{q({\bf{x}}_{1:T}|{\bf{x}}_0)}\left[\log p_{\theta}({\bf{x}}_0|{\bf{x}}_1)\right] + \mathbb{E}_{q({\bf{x}}_{1:T}|{\bf{x}}_0)}\left[\log \frac{p({\bf{x}}_{T})}{q({\bf{x}}_T|{\bf{x}}_{0})}\right] + \sum_{t=2}^T \mathbb{E}_{q({\bf{x}}_{1:T}|{\bf{x}}_0)}\left[\log \frac{p_{\theta}({\bf{x}}_{t-1}|{\bf{x}}_{t})}{q({\bf{x}}_{t-1}|{\bf{x}}_{t},{\bf{x}}_{0})} \right]\\ &= \mathbb{E}_{q({\bf{x}}_{1}|{\bf{x}}_0)}\left[\log p_{\theta}({\bf{x}}_0|{\bf{x}}_1)\right] + \mathbb{E}_{q({\bf{x}}_{T}|{\bf{x}}_0)}\left[\log \frac{p({\bf{x}}_{T})}{q({\bf{x}}_T|{\bf{x}}_{0})}\right] + \sum_{t=2}^T \mathbb{E}_{q({\bf{x}}_{t},{\bf{x}}_{t-1}|{\bf{x}}_0)}\left[\log \frac{p_{\theta}({\bf{x}}_{t-1}|{\bf{x}}_{t})}{q({\bf{x}}_{t-1}|{\bf{x}}_{t},{\bf{x}}_{0})} \right]\\ & = \underbrace{\mathbb{E}_{q({\bf{x}}_{1}|{\bf{x}}_0)}\left[\log p_{\theta}({\bf{x}}_0|{\bf{x}}_1)\right]}_{\text{reconstruction term}} - \underbrace{D_{KL}(q({\bf{x}}_{T}|{\bf{x}}_{0})\|p({\bf{x}}_{T}))}_{\text{prior matching term}} - \sum_{t=2}^{T} 
\underbrace{\mathbb{E}_{q({\bf{x}}_{t}|{\bf{x}}_0)}\left[D_{KL}(q({\bf{x}}_{t-1}|{\bf{x}}_{t},{\bf{x}}_{0})\|p_{\theta}({\bf{x}}_{t-1}|{\bf{x}}_{t}))\right]}_{\text{denoising matching term}} \end{split} \end{split}\]

These terms can also be interpreted as before [4]:

  • The reconstruction term is exactly the same as before. Since the last term is a sum of \(T-1\) terms and dominates the objective, the reconstruction term can be neglected.

  • The prior matching term represents how close the distribution of the final noisified input is to the standard Gaussian prior. It has no trainable parameters, and is also equal to zero under our assumptions.

  • The denoising matching term uses \(q({\bf{x}}_{t-1}|{\bf{x}}_{t},{\bf{x}}_{0})\) as a ground truth to learn \(p_{\theta}({\bf{x}}_{t-1}|{\bf{x}}_{t})\). The ground truth “defines how to denoise a noisy sample \({\bf{x}}_{t}\) with access to what the final, completely denoised sample \({\bf{x}}_{0}\) should be” [4].

Now, let’s dig deeper into the denoising matching term. Remember that \(q({\bf{x}}_{t}|{\bf{x}}_{0}) = \mathcal{N}({\bf{x}}_{t};\sqrt{\bar{\alpha}_t}{\bf{x}}_{0},(1 -\bar{\alpha}_t){\bf{I}})\), so using the Bayes rule

\[\begin{split} \begin{split} q({\bf{x}}_{t-1}|{\bf{x}}_{t}, {\bf{x}}_{0}) &= \frac{q({\bf{x}}_{t}|{\bf{x}}_{t-1}, {\bf{x}}_{0})q({\bf{x}}_{t-1}|{\bf{x}}_{0})}{q({\bf{x}}_{t}|{\bf{x}}_{0})}\\ &= \frac{\mathcal{N}({\bf{x}}_{t};\sqrt{\alpha_t}{\bf{x}}_{t-1},(1 -\alpha_t){\bf{I}})\mathcal{N}({\bf{x}}_{t-1};\sqrt{\bar{\alpha}_{t-1}}{\bf{x}}_{0},(1 -\bar{\alpha}_{t-1}){\bf{I}})}{\mathcal{N}({\bf{x}}_{t};\sqrt{\bar{\alpha}_t}{\bf{x}}_{0},(1 -\bar{\alpha}_t){\bf{I}})}\\ &=... \\ &\propto \mathcal{N}({\bf{x}}_{t-1}; \tilde{\mathbf{\mu}}_t({\bf{x}}_{t},{\bf{x}}_{0}),\tilde{\mathbf{\Sigma}}_t) \end{split} \end{split}\]

where [4],

\[ \tilde{\mathbf{\mu}}_t({\bf{x}}_{t},{\bf{x}}_{0}) = \frac{\sqrt{\alpha_t}(1 - \bar{\alpha}_{t-1}){\bf{x}}_t + \sqrt{\bar{\alpha}_{t-1}}(1-\alpha_{t}){\bf{x}}_0}{1-\bar{\alpha}_t} \]

\[ \tilde{\mathbf{\Sigma}}_t = \frac{(1-\alpha_t)(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}{\bf{I}} = \tilde{\sigma}_t^2{\bf{I}} \]

Reversibility of Gaussian Distributions: One of the essential properties of Gaussian distributions is that they are closed under linear transformations and convolution with other Gaussians. Since the forward process involves adding Gaussian noise (which is a convolution operation), the reverse process essentially “removes” that noise. Because Gaussians are closed under these operations, the reverse conditional \(q({\bf{x}}_{t-1}|{\bf{x}}_{t}, {\bf{x}}_{0})\) derived above remains Gaussian (see Wikipedia for a proof).

Therefore, we can assume that \(p_{\theta}({\bf{x}}_{t-1}|{\bf{x}}_{t}) = \mathcal{N}(\mu_{\theta}({\bf{x}}_{t},t),\sigma_{\theta}^2{\bf{I}})\). Furthermore, as all \(\alpha\) terms are known to be frozen at each timestep we can immediately construct the variance of the approximate denoising transition step to also be \(\sigma_{\theta}^2{\bf{I}} = \tilde{\sigma}_t^2{\bf{I}} = \Sigma_t\).

The KL divergence between two multivariate Gaussian distributions is given by:

\[ D_{KL}\left(\mathcal{N}({\bf{x}};\mu_{\bf{x}},\Sigma_{\bf{x}})\|\mathcal{N}({\bf{y}};\mu_{\bf{y}},\Sigma_{\bf{y}}) \right) = \frac{1}{2}\left[ \log \frac{|\Sigma_{\bf{y}}|}{|\Sigma_{\bf{x}}|} - d + \text{tr}\left( \Sigma_{\bf{y}}^{-1}\Sigma_{\bf{x}}\right) + (\mu_{\bf{y}} - \mu_{\bf{x}})^\top \Sigma_{\bf{y}}^{-1} (\mu_{\bf{y}} - \mu_{\bf{x}})\right] \]

So,

\[\begin{split} \begin{split} &\underset{\theta}{\arg\min} D_{KL}(q({\bf{x}}_{t-1}|{\bf{x}}_{t},{\bf{x}}_{0})\|p_{\theta}({\bf{x}}_{t-1}|{\bf{x}}_{t}))\\ = & \underset{\theta}{\arg\min} D_{KL}(\mathcal{N}({\bf{x}}_{t-1};\tilde{\mathbf{\mu}}_t({\bf{x}}_{t},{\bf{x}}_{0}),\Sigma_{t})\|\mathcal{N}({\bf{x}}_{t-1};\mu_{\theta}({\bf{x}}_{t},t),\Sigma_{t}))\\ = & ...\\ = & \underset{\theta}{\arg\min} \frac{1}{2\tilde{\sigma}_t^2} \left[ \|\mu_{\theta}({\bf{x}}_{t},t) - \tilde{\mathbf{\mu}}_t({\bf{x}}_{t},{\bf{x}}_{0})\|_2^2\right] \end{split} \end{split}\]

We can simplify the former expression even more. Considering that \(\mu_{\theta}({\bf{x}}_{t},t)\) conditions on \({\bf{x}}_{t}\), as \(\tilde{\mathbf{\mu}}_t({\bf{x}}_{t},{\bf{x}}_{0})\) does, we can set \(\mu_{\theta}({\bf{x}}_{t},t)\) to closely match the definition of \(\tilde{\mathbf{\mu}}_t({\bf{x}}_{t},{\bf{x}}_{0})\) using the following form:
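The intermediate steps omitted here follow from the general KL formula: when both Gaussians share the covariance \(\tilde{\sigma}_t^2{\bf{I}}\), the log-determinant ratio vanishes and the trace term cancels with \(-d\), leaving only the scaled squared distance between the means. A small numerical check, using NumPy and arbitrary illustrative values for the dimension, variance and means:

```python
import numpy as np

def kl_gaussians(mu_x, cov_x, mu_y, cov_y):
    """KL(N(mu_x, cov_x) || N(mu_y, cov_y)) from the general formula."""
    d = mu_x.shape[0]
    cov_y_inv = np.linalg.inv(cov_y)
    diff = mu_y - mu_x
    _, logdet_x = np.linalg.slogdet(cov_x)
    _, logdet_y = np.linalg.slogdet(cov_y)
    return 0.5 * (logdet_y - logdet_x - d
                  + np.trace(cov_y_inv @ cov_x)
                  + diff @ cov_y_inv @ diff)

d, sigma2 = 5, 0.3  # illustrative dimension and shared variance
rng = np.random.default_rng(0)
mu_q, mu_p = rng.normal(size=d), rng.normal(size=d)
cov = sigma2 * np.eye(d)

kl_general = kl_gaussians(mu_q, cov, mu_p, cov)
kl_reduced = np.sum((mu_p - mu_q) ** 2) / (2.0 * sigma2)
print(kl_general, kl_reduced)
```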

\[ \mu_{\theta}({\bf{x}}_{t},t) = \frac{\sqrt{\alpha_t}(1 - \bar{\alpha}_{t-1}){\bf{x}}_t + \sqrt{\bar{\alpha}_{t-1}}(1-\alpha_{t}){\hat{\bf{x}}}_{\theta}({\bf{x}}_{t},t)}{1-\bar{\alpha}_t} \]

So now,

\[ \begin{split} &\underset{\theta}{\arg\min} D_{KL}(q({\bf{x}}_{t-1}|{\bf{x}}_{t},{\bf{x}}_{0})\|p_{\theta}({\bf{x}}_{t-1}|{\bf{x}}_{t}))\\ = & \underset{\theta}{\arg\min} \frac{1}{2\tilde{\sigma}_t^2} \frac{\bar{\alpha}_{t-1}(1 - \alpha_t)^2}{(1 - \bar{\alpha}_{t})^2}\left[\| {\hat{\bf{x}}}_{\theta}({\bf{x}}_{t},t) - {\bf{x}}_{0}\|_2^2\right] \end{split} \]

where \({\hat{\bf{x}}}_{\theta}({\bf{x}}_{t},t)\) is a function approximator (i.e. an artificial neural network).

Equivalently, recall that \({\bf{x}}_{t} = \sqrt{\bar{\alpha}_t}{\bf{x}}_{0} + \sqrt{1 -\bar{\alpha}_t}\mathbf{\epsilon}\). Thus, solving for \({\bf{x}}_{0}\):

\[ {\bf{x}}_{0} = \frac{1}{\sqrt{\bar{\alpha}_t}}({\bf{x}}_{t} - \sqrt{1 - \bar{\alpha}_t}\epsilon) \]

Plugging this expression into the posterior mean:

\[ \begin{split} \tilde{\mathbf{\mu}}_t({\bf{x}}_{t},{\bf{x}}_{0}) & = \frac{\sqrt{\alpha_t}(1 - \bar{\alpha}_{t-1}){\bf{x}}_t + \sqrt{\bar{\alpha}_{t-1}}(1-\alpha_{t}){\bf{x}}_0}{1-\bar{\alpha}_t}\\ & = \frac{1}{\sqrt{\alpha_t}}\left({\bf{x}}_{t} - \frac{1- \alpha_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon \right) \end{split} \]

Therefore, similar to the previous analysis, we could assume that \(\mu_{\theta}({\bf{x}}_{t},t)\) takes the form:
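The equivalence between the \({\bf{x}}_0\) form and the noise form of the posterior mean can be checked numerically. The sketch below uses scalar stand-ins for images and a hypothetical four-step schedule. It weights \({\bf{x}}_0\) by \(\sqrt{\bar{\alpha}_{t-1}}(1-\alpha_t)\), the coefficient consistent with the factor \(\bar{\alpha}_{t-1}(1-\alpha_t)^2\) that appears in the \({\bf{x}}_0\)-prediction loss.

```python
import math
import random

random.seed(0)
# Hypothetical schedule values and scalar "images" for the check
alphas = [0.99, 0.98, 0.95, 0.90]
t = 3                                  # 1-based timestep, t >= 2
alpha_t = alphas[t - 1]
abar_t = math.prod(alphas[:t])         # alpha_bar_t
abar_prev = math.prod(alphas[:t - 1])  # alpha_bar_{t-1}

x0, eps = random.gauss(0, 1), random.gauss(0, 1)
xt = math.sqrt(abar_t) * x0 + math.sqrt(1 - abar_t) * eps

# Posterior mean written in terms of x_0
mu_x0 = (math.sqrt(alpha_t) * (1 - abar_prev) * xt
         + math.sqrt(abar_prev) * (1 - alpha_t) * x0) / (1 - abar_t)
# Posterior mean written in terms of the noise
mu_eps = (xt - (1 - alpha_t) / math.sqrt(1 - abar_t) * eps) / math.sqrt(alpha_t)
print(mu_x0, mu_eps)
```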

\[ \mu_{\theta}({\bf{x}}_{t},t) = \frac{1}{\sqrt{\alpha_t}}\left({\bf{x}}_{t} - \frac{1- \alpha_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_{\theta}({\bf{x}}_{t},t) \right) \]

and instead of predicting \({\bf{x}}_{0}\) from a noisy version corrupted by \(t\) steps of the forward process, we try to predict the noise. The loss function can then be written as:

\[\begin{split} \begin{split} &\underset{\theta}{\arg\min} D_{KL}(q({\bf{x}}_{t-1}|{\bf{x}}_{t},{\bf{x}}_{0})\|p_{\theta}({\bf{x}}_{t-1}|{\bf{x}}_{t}))\\ = & \underset{\theta}{\arg\min} \frac{1}{2\tilde{\sigma}_t^2} \frac{(1 - \alpha_t)^2}{(1 - \bar{\alpha}_{t})\alpha_t}\left[\| \epsilon - \epsilon_{\theta}({\bf{x}}_{t},t)\|_2^2\right]\\ = & \underset{\theta}{\arg\min} \frac{1}{2\tilde{\sigma}_t^2} \frac{(1 - \alpha_t)^2}{(1 - \bar{\alpha}_{t})\alpha_t}\left[\| \epsilon - \epsilon_{\theta}(\sqrt{\bar{\alpha}_t}{\bf{x}}_{0} + \sqrt{1 -\bar{\alpha}_t}\mathbf{\epsilon},t)\|_2^2\right] \end{split} \end{split}\]

According to [5], neglecting the weighting term produces better results.
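With the weighting dropped, one training step reduces to a regression on the noise. A minimal sketch under stated assumptions: `model(x_t, t)` is a hypothetical noise predictor, timesteps are drawn uniformly, and the error is averaged with a mean squared error. The placeholder model below simply returns zeros so the example runs on its own.

```python
import torch
import torch.nn.functional as F

def simple_ddpm_loss(model, x_0, alphas_cumprod):
    """Unweighted noise-prediction loss for one batch.

    `model(x_t, t)` is a hypothetical noise predictor returning a tensor
    shaped like x_t; `alphas_cumprod` holds alpha_bar for indices 0..T-1.
    """
    batch = x_0.shape[0]
    T = alphas_cumprod.shape[0]
    # One uniformly sampled timestep per example (an assumption of this sketch)
    t = torch.randint(0, T, (batch,), device=x_0.device)
    noise = torch.randn_like(x_0)
    # Broadcast alpha_bar_t over the non-batch dimensions
    abar = alphas_cumprod.to(x_0.device)[t].view(batch, *([1] * (x_0.dim() - 1)))
    # Closed-form forward sample x_t from x_0
    x_t = abar.sqrt() * x_0 + (1.0 - abar).sqrt() * noise
    return F.mse_loss(model(x_t, t), noise)

# Toy usage: a placeholder "model" that always predicts zero noise
betas = torch.linspace(1e-4, 0.02, 300)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
dummy_model = lambda x_t, t: torch.zeros_like(x_t)
loss = simple_ddpm_loss(dummy_model, torch.randn(8, 3, 16, 16), alphas_cumprod)
print(loss.item())
```

A predictor that always outputs zero yields a loss close to the variance of the noise; a trained network should drive it below that baseline.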

Sampling#

\[ {\bf{x}}_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left({\bf{x}}_{t} - \frac{1- \alpha_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_{\theta}({\bf{x}}_{t},t) \right) + \sigma_t \epsilon, \quad \epsilon \sim \mathcal{N}({\bf 0}, {\bf I}) \]

The first term is the mean of \(p_{\theta}({\bf{x}}_{t-1}|{\bf{x}}_{t})\); it is a deterministic function of \({\bf{x}}_{t}\), so it contributes no randomness. Adding \(\sigma_t \epsilon\) is therefore equivalent to sampling from \(\mathcal{N}(\mu_{\theta}({\bf{x}}_{t},t),\sigma_t^2 {\bf I})\).
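A single reverse step can be sketched as follows. This is an illustration rather than the notebook's own sampler: `eps_model` is a hypothetical noise predictor, \(\sigma_t^2\) is set to the posterior variance \(\tilde{\sigma}_t^2\) from the derivation, and no noise is added at the last step (a common convention, assumed here).

```python
import torch

@torch.no_grad()
def reverse_step(eps_model, x_t, t, betas, alphas_cumprod):
    """One ancestral sampling step x_t -> x_{t-1} (t is a 0-based integer index)."""
    beta_t = betas[t]
    alpha_t = 1.0 - beta_t
    abar_t = alphas_cumprod[t]
    abar_prev = alphas_cumprod[t - 1] if t > 0 else torch.tensor(1.0)
    t_batch = torch.full((x_t.shape[0],), t, dtype=torch.long)
    # Mean of p_theta(x_{t-1} | x_t) computed from the predicted noise
    mean = (x_t - beta_t / (1.0 - abar_t).sqrt() * eps_model(x_t, t_batch)) / alpha_t.sqrt()
    if t == 0:
        return mean  # assumption: no noise is added at the final step
    # sigma_t^2 is set to the posterior variance
    sigma2 = beta_t * (1.0 - abar_prev) / (1.0 - abar_t)
    return mean + sigma2.sqrt() * torch.randn_like(x_t)

betas = torch.linspace(1e-4, 0.02, 300)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
dummy_eps = lambda x, t: torch.zeros_like(x)  # placeholder for a trained network
x = torch.randn(2, 3, 8, 8)
x_prev = reverse_step(dummy_eps, x, 299, betas, alphas_cumprod)
x_last = reverse_step(dummy_eps, x, 0, betas, alphas_cumprod)
```

Generating a sample amounts to starting from pure Gaussian noise and applying this step for every index from `T - 1` down to `0`.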

Basic implementation of a DDPM#

As a dataset we use the Stanford Cars dataset, which consists of around 8000 images in the train set. Let’s see if this is enough to get good results ;-)

#!pip install opendatasets --upgrade --quiet

import opendatasets as od

dataset_url = 'https://www.kaggle.com/jutrera/stanford-car-dataset-by-classes-folder'
od.download(dataset_url)

import os
import torch
import torchvision
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import Dataset
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor
import torch.nn.functional as F
from torch import nn
from torch.optim import Adam
import math
DATA_DIR_TRAIN = './stanford-car-dataset-by-classes-folder/car_data/car_data/train'
train_classes = os.listdir(DATA_DIR_TRAIN)

DATA_DIR_TEST = './stanford-car-dataset-by-classes-folder/car_data/car_data/test'
test_classes = os.listdir(DATA_DIR_TEST)
train_classes[:5], test_classes[:5]
(['Ford Expedition EL SUV 2009',
  'BMW M5 Sedan 2010',
  'Chevrolet Silverado 1500 Classic Extended Cab 2007',
  'MINI Cooper Roadster Convertible 2012',
  'Aston Martin V8 Vantage Convertible 2012'],
 ['Ford Expedition EL SUV 2009',
  'BMW M5 Sedan 2010',
  'Chevrolet Silverado 1500 Classic Extended Cab 2007',
  'MINI Cooper Roadster Convertible 2012',
  'Aston Martin V8 Vantage Convertible 2012'])
train_dataset = ImageFolder(DATA_DIR_TRAIN, transform = ToTensor())
Ns = len(train_dataset)

def show_images(dataset, num_samples=20, cols=4):
    """ Plots a random selection of samples from the dataset """
    plt.figure(figsize=(15,15))
    # Draw distinct random indices once so every index stays within range
    indices = np.random.permutation(len(dataset))[:num_samples]
    for i, idx in enumerate(indices):
        img, _ = dataset[int(idx)]
        plt.subplot(int(num_samples/cols) + 1, cols, i + 1)
        # CHW tensor to HWC array for matplotlib
        plt.imshow(np.moveaxis(img.numpy(),0,-1))

show_images(train_dataset)

Step 1: The forward process = Noise scheduler#

We first need to build the inputs for our model, which are more and more noisy images. Instead of doing this sequentially, we can use the closed form provided before to calculate the image for any of the timesteps individually.

Key Takeaways:

  • The noise-levels/variances can be pre-computed

  • There are different types of variance schedules

  • We can sample each timestep image independently (a sum of independent Gaussians is also Gaussian)

  • No model is needed in this forward step
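As an illustration of the point about different variance schedules, the sketch below computes betas from a squared-cosine \(\bar{\alpha}_t\) curve, commonly referred to as the cosine schedule. This is not the schedule used in this notebook; the offset `s` and the clipping value `max_beta` are conventional defaults assumed here.

```python
import math

def cosine_beta_schedule(timesteps, s=0.008, max_beta=0.999):
    """Betas derived from a squared-cosine alpha_bar curve (alternative to the linear schedule)."""
    def f(step):
        return math.cos((step / timesteps + s) / (1 + s) * math.pi / 2) ** 2

    betas = []
    for i in range(timesteps):
        # beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}, clipped for numerical stability
        betas.append(min(1.0 - f(i + 1) / f(i), max_beta))
    return betas

cos_betas = cosine_beta_schedule(300)
print(cos_betas[0], cos_betas[-1])
```

Compared with a linear schedule, this curve destroys information more slowly in the early steps and reaches nearly pure noise only at the very end.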

def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
    return torch.linspace(start, end, timesteps)

def get_index_from_list(vals, t, x_shape):
    """
    Returns a specific index t of a passed list of values vals
    while considering the batch dimension.
    """
    batch_size = t.shape[0]
    out = vals.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

def forward_diffusion_sample(x_0, t, device="cpu"):
    """
    Takes an image and a timestep as input and
    returns the noisy version of it
    """
    noise = torch.randn_like(x_0)
    sqrt_alphas_cumprod_t = get_index_from_list(sqrt_alphas_cumprod, t, x_0.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x_0.shape
    )
    # Scaled signal (mean) plus scaled noise; the noise is returned as the training target
    return sqrt_alphas_cumprod_t.to(device) * x_0.to(device) \
    + sqrt_one_minus_alphas_cumprod_t.to(device) * noise.to(device), noise.to(device)


# Define beta schedule
T = 300
betas = linear_beta_schedule(timesteps=T)

# Pre-calculate different terms for closed form
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
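The derivation assumed that \({\bf{x}}_T\) follows a standard Gaussian, which requires \(\bar{\alpha}_T \approx 0\). A quick check of how closely the linear schedule with \(T = 300\) meets that assumption (the schedule is recomputed here so the snippet runs on its own):

```python
import torch

T = 300
betas = torch.linspace(0.0001, 0.02, T)  # same linear schedule as above
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

signal_T = alphas_cumprod[-1].sqrt().item()    # coefficient of x_0 in x_T
noise_var_T = (1.0 - alphas_cumprod[-1]).item()  # variance of the noise in x_T
print(f"sqrt(alpha_bar_T) = {signal_T:.3f}, 1 - alpha_bar_T = {noise_var_T:.3f}")
```

With these settings some of the signal still survives at the last step, so the prior matching term is small but not exactly zero. A longer schedule or larger betas would bring \(\bar{\alpha}_T\) closer to zero.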

Let’s test it on our dataset …

IMG_SIZE = 128
BATCH_SIZE = 128

data_transforms = [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(), # Scales data into [0,1]
        transforms.Lambda(lambda t: (t * 2) - 1) # Scale between [-1, 1]
    ]
data_transform = transforms.Compose(data_transforms)

train_dataset = ImageFolder(DATA_DIR_TRAIN, transform = data_transform)
test_dataset = ImageFolder(DATA_DIR_TEST, transform = data_transform)

data = torch.utils.data.ConcatDataset([train_dataset, test_dataset])

dataloader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

Simulate forward diffusion…

def show_tensor_image(image):
    reverse_transforms = transforms.Compose([
        transforms.Lambda(lambda t: (t + 1) / 2),
        transforms.Lambda(lambda t: t.permute(1, 2, 0)), # CHW to HWC
        transforms.Lambda(lambda t: t * 255.),
        transforms.Lambda(lambda t: t.numpy().astype(np.uint8)),
        transforms.ToPILImage(),
    ])

    # Take first image of batch
    if len(image.shape) == 4:
        image = image[0, :, :, :]
    plt.imshow(reverse_transforms(image))
    
image = next(iter(dataloader))[0]

plt.figure(figsize=(15,2))
plt.axis('off')
num_images = 10
stepsize = int(T/num_images)

for idx in range(0, T, stepsize):
    t = torch.Tensor([idx]).type(torch.int64)
    plt.subplot(1, num_images+1, int(idx/stepsize) + 1)
    img, noise = forward_diffusion_sample(image, t)
    show_tensor_image(img)

Step 2: The backward process = U-Net#

For a great introduction to UNets, have a look at this post: https://amaarora.github.io/2020/09/13/unet.html.

Key Takeaways:

  • We use a simple form of a U-Net to predict the noise in the image

  • The input is a noisy image; the output is the noise in the image

  • Because the parameters are shared across time, we need to tell the network which timestep we are in

  • The timestep is encoded with the sinusoidal position embedding used in Transformers

  • We output one single value (mean), because the variance is fixed

class Block(nn.Module):
    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        self.time_mlp =  nn.Linear(time_emb_dim, out_ch)
        if up:
            self.conv1 = nn.Conv2d(2*in_ch, out_ch, 3, padding=1)
            self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)
        else:
            self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
            self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu  = nn.ReLU()

    def forward(self, t, x):
        # First Conv
        h = self.bnorm1(self.relu(self.conv1(x)))
        # Time embedding
        time_emb = self.relu(self.time_mlp(t))
        # Extend last 2 dimensions
        time_emb = time_emb[(..., ) + (None, ) * 2]
        # Add time channel
        h = h + time_emb
        # Second Conv
        h = self.bnorm2(self.relu(self.conv2(h)))
        # Down or Upsample
        return self.transform(h)


class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        # TODO: Double check the ordering here
        return embeddings


class SimpleUnet(nn.Module):
    """
    A simplified variant of the Unet architecture.
    """
    def __init__(self):
        super().__init__()
        image_channels = 3
        down_channels = (64, 128, 256, 512, 1024)
        up_channels = (1024, 512, 256, 128, 64)
        out_dim = 3
        time_emb_dim = 32

        # Time embedding
        self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(time_emb_dim),
                nn.Linear(time_emb_dim, time_emb_dim),
                nn.ReLU()
            )

        # Initial projection
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)

        # Downsample
        self.downs = nn.ModuleList([Block(down_channels[i], down_channels[i+1], \
                                    time_emb_dim) \
                    for i in range(len(down_channels)-1)])
        # Upsample
        self.ups = nn.ModuleList([Block(up_channels[i], up_channels[i+1], \
                                        time_emb_dim, up=True) \
                    for i in range(len(up_channels)-1)])

        # Edit: Corrected a bug found by Jakub C (see YouTube comment)
        self.output = nn.Conv2d(up_channels[-1], out_dim, 1)

    def forward(self, timestep, x):
        # Embed time
        t = self.time_mlp(timestep)
        # Initial conv
        x = self.conv0(x)
        # Unet
        residual_inputs = []
        for down in self.downs:
            x = down(t, x)
            residual_inputs.append(x)
        for up in self.ups:
            residual_x = residual_inputs.pop()
            # Add residual x as additional channels
            x = torch.cat((x, residual_x), dim=1)
            x = up(t, x)
        return self.output(x)

model = SimpleUnet()
print("Num params: ", sum(p.numel() for p in model.parameters()))
model
Num params:  62438883
SimpleUnet(
  (time_mlp): Sequential(
    (0): SinusoidalPositionEmbeddings()
    (1): Linear(in_features=32, out_features=32, bias=True)
    (2): ReLU()
  )
  (conv0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (downs): ModuleList(
    (0): Block(
      (time_mlp): Linear(in_features=32, out_features=128, bias=True)
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (transform): Conv2d(128, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bnorm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bnorm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
    )
    (1): Block(
      (time_mlp): Linear(in_features=32, out_features=256, bias=True)
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (transform): Conv2d(256, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bnorm1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bnorm2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
    )
    (2): Block(
      (time_mlp): Linear(in_features=32, out_features=512, bias=True)
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (transform): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bnorm1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bnorm2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
    )
    (3): Block(
      (time_mlp): Linear(in_features=32, out_features=1024, bias=True)
      (conv1): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (transform): Conv2d(1024, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bnorm1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bnorm2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
    )
  )
  (ups): ModuleList(
    (0): Block(
      (time_mlp): Linear(in_features=32, out_features=512, bias=True)
      (conv1): Conv2d(2048, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (transform): ConvTranspose2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bnorm1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bnorm2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
    )
    (1): Block(
      (time_mlp): Linear(in_features=32, out_features=256, bias=True)
      (conv1): Conv2d(1024, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (transform): ConvTranspose2d(256, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bnorm1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bnorm2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
    )
    (2): Block(
      (time_mlp): Linear(in_features=32, out_features=128, bias=True)
      (conv1): Conv2d(512, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (transform): ConvTranspose2d(128, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bnorm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bnorm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
    )
    (3): Block(
      (time_mlp): Linear(in_features=32, out_features=64, bias=True)
      (conv1): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (transform): ConvTranspose2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bnorm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bnorm2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
    )
  )
  (output): Conv2d(64, 3, kernel_size=(1, 1), stride=(1, 1))
)
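
To see why the skip connections can be concatenated in `SimpleUnet.forward`, the sketch below traces channel counts and spatial sizes through the down and up paths using the standard output-size formulas for `nn.Conv2d` and `nn.ConvTranspose2d`. It is only an illustration in plain Python: the helper names are hypothetical, a 128 x 128 input is assumed (matching `IMG_SIZE`), and the 3 x 3 convolutions are ignored because with `padding=1` they preserve the spatial size.

```python
def conv_out(size, kernel, stride, padding):
    # Output size of nn.Conv2d along one spatial dimension (dilation 1).
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose_out(size, kernel, stride, padding):
    # Output size of nn.ConvTranspose2d (no output_padding, dilation 1).
    return (size - 1) * stride - 2 * padding + kernel

def trace_shapes(img_size=128,
                 down_channels=(64, 128, 256, 512, 1024),
                 up_channels=(1024, 512, 256, 128, 64)):
    shapes, skips = [], []
    size, channels = img_size, down_channels[0]      # state after conv0
    for out_ch in down_channels[1:]:
        size = conv_out(size, kernel=4, stride=2, padding=1)
        channels = out_ch
        skips.append((channels, size))               # stored as residual input
        shapes.append(("down", channels, size))
    for out_ch in up_channels[1:]:
        skip_ch, skip_size = skips.pop()
        assert skip_size == size                     # torch.cat needs equal H and W
        in_ch = channels + skip_ch                   # channels after concatenation
        size = conv_transpose_out(size, kernel=4, stride=2, padding=1)
        channels = out_ch
        shapes.append(("up", in_ch, channels, size))
    return shapes

for entry in trace_shapes():
    print(entry)
```

The concatenated channel counts produced by the trace (2048, 1024, 512, 256) agree with the `conv1` input channels of the up blocks in the printed model.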

Step 3: The loss#

def get_loss(model, x_0, t):
    x_noisy, noise = forward_diffusion_sample(x_0, t, device)
    noise_pred = model(t, x_noisy)
    return F.l1_loss(noise, noise_pred)
    #return F.mse_loss(noise, noise_pred)

Sampling#

  • Without @torch.no_grad() we quickly run out of memory, because PyTorch keeps track of all the intermediate images for gradient calculation

  • Because we pre-calculated the noise variances for the forward pass, we also have to use them when we sequentially perform the backward process

@torch.no_grad()
def sample_timestep(x, t):
    """
    Calls the model to predict the noise in the image and returns
    the denoised image.
    Applies noise to this image, if we are not in the last step yet.
    """
    betas_t = get_index_from_list(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = get_index_from_list(sqrt_recip_alphas, t, x.shape)

    # Call model (current image - noise prediction)
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(t, x) / sqrt_one_minus_alphas_cumprod_t
    )
    posterior_variance_t = get_index_from_list(posterior_variance, t, x.shape)

    if t == 0:
        # As pointed out by Luis Pereira (see YouTube comment)
        # The t's are offset from the t's in the paper
        return model_mean
    else:
        noise = torch.randn_like(x)
        return model_mean + torch.sqrt(posterior_variance_t) * noise

@torch.no_grad()
def sample_plot_image():
    # Sample noise
    img_size = IMG_SIZE
    img = torch.randn((1, 3, img_size, img_size), device=device)
    plt.figure(figsize=(15,2))
    plt.axis('off')
    num_images = 10
    stepsize = int(T/num_images)

    for i in range(0,T)[::-1]:
        t = torch.full((1,), i, device=device, dtype=torch.long)
        img = sample_timestep(img, t)
        # Edit: This is to maintain the natural range of the distribution
        img = torch.clamp(img, -1.0, 1.0)
        if i % stepsize == 0:
            plt.subplot(1, num_images, int(i/stepsize)+1)
            show_tensor_image(img.detach().cpu())
    plt.show()
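
One way to sanity-check the mean computed in `sample_timestep` is to apply it to a scalar with a perfect noise prediction. The sketch below assumes the usual closed-form forward process, x_t = sqrt(alpha_bar_t) x_0 + sqrt(1 - alpha_bar_t) eps, and a first beta of 1e-4; the function `reverse_mean` and the numbers are hypothetical. Under these assumptions, the step at t = 0 should return x_0 up to floating-point error.

```python
import math

def reverse_mean(x_t, eps_pred, beta_t, alpha_bar_t):
    # Scalar version of `model_mean` in `sample_timestep`.
    sqrt_recip_alpha_t = 1.0 / math.sqrt(1.0 - beta_t)
    return sqrt_recip_alpha_t * (
        x_t - beta_t * eps_pred / math.sqrt(1.0 - alpha_bar_t)
    )

x0, eps = 0.7, -1.3            # arbitrary clean value and noise draw
beta_0 = 1e-4                  # assumed first beta of the schedule
alpha_bar_0 = 1.0 - beta_0     # cumulative product at t = 0

# Forward closed form for the noisy sample at t = 0.
x_t = math.sqrt(alpha_bar_0) * x0 + math.sqrt(1.0 - alpha_bar_0) * eps

# Reverse step with the exact noise as "prediction".
recovered = reverse_mean(x_t, eps, beta_0, alpha_bar_0)
print(recovered)
```

With a trained network the prediction is only approximate, so the last step returns an estimate of x_0 rather than the exact value.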

Training#

#model = torch.load('Diff_U_net.pt',weights_only=False)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
optimizer = Adam(model.parameters(), lr=0.001)
epochs = 100 # Try more!

for epoch in range(epochs):
    for step, batch in enumerate(dataloader):
        optimizer.zero_grad()

        t = torch.randint(0, T, (BATCH_SIZE,), device=device).long()
        loss = get_loss(model, batch[0], t)
        loss.backward()
        optimizer.step()

        if epoch % 5 == 0 and step == 0:
            print(f"Epoch {epoch} | step {step:03d} Loss: {loss.item()} ")
            sample_plot_image()
Epoch 0 | step 000 Loss: 0.10153281688690186
Epoch 5 | step 000 Loss: 0.10914735496044159
Epoch 10 | step 000 Loss: 0.1037554144859314
Epoch 15 | step 000 Loss: 0.11673183739185333
Epoch 20 | step 000 Loss: 0.10023071616888046
Epoch 25 | step 000 Loss: 0.11450882256031036
Epoch 30 | step 000 Loss: 0.10585831105709076
Epoch 35 | step 000 Loss: 0.10274531692266464
Epoch 40 | step 000 Loss: 0.10607850551605225
Epoch 45 | step 000 Loss: 0.10293224453926086
Epoch 50 | step 000 Loss: 0.09736713021993637
Epoch 55 | step 000 Loss: 0.10590964555740356
Epoch 60 | step 000 Loss: 0.11159863322973251
Epoch 65 | step 000 Loss: 0.10891596972942352
Epoch 70 | step 000 Loss: 0.10659800469875336
Epoch 75 | step 000 Loss: 0.10499440133571625
Epoch 80 | step 000 Loss: 0.10662849247455597
Epoch 85 | step 000 Loss: 0.10670851171016693
Epoch 90 | step 000 Loss: 0.10444175451993942
Epoch 95 | step 000 Loss: 0.1138913482427597
torch.save(model, 'Diff_U_net.pt')

3. Score-based Generative Models#

The (Stein) score function is defined as:

\[ {\bf{s}}_{\theta}({\bf{x}}) = \nabla_{\bf{x}} \log p_{\theta}({\bf{x}}) \]
\[\require{cancel}\]

There are two reasons why optimizing a score function makes sense.

  • First, suppose we estimate a pdf with an ANN. The network output \(f_{\theta}({\bf{x}})\) can be understood as an unnormalized log-density (a negative energy function), so the probability distribution can be written as:

\[ p_{\theta}({\bf{x}}) = \frac{\exp(f_{\theta}({\bf{x}}))}{Z_{\theta}} \]

where \(Z_{\theta}\) is a normalizing constant that ensures \(\int p_{\theta}({\bf{x}}) d{\bf{x}} = 1\). Maximum likelihood requires a tractable normalizing constant, which may not be available for a complex \(f_{\theta}(\cdot)\). The score function, however, does not depend on \(Z_{\theta}\), because \(Z_{\theta}\) is constant with respect to \({\bf{x}}\):

\[\begin{split} \nabla_{\bf{x}} \log p_{\theta}({\bf{x}}) & = \nabla_{\bf{x}} \log \left( \frac{\exp(f_{\theta}({\bf{x}}))}{Z_{\theta}} \right)\\ &= \nabla_{\bf{x}} f_{\theta}({\bf{x}}) - \cancel{\nabla_{\bf{x}} \log Z_{\theta}}\\ &= \nabla_{\bf{x}} f_{\theta}({\bf{x}})\\ &= {\bf{s}}_{\theta}({\bf{x}}) \end{split}\]

The score can therefore be represented freely by an ANN without involving any normalization constant.

  • Second, thanks to Langevin dynamics, we can sample from a distribution using only its score function. Following the score alone,

\[ {\bf{x}}_{t+1} \leftarrow {\bf{x}}_{t} + \frac{\eta}{2}{\bf{s}}_{\theta}({\bf{x}}_t) \]

is gradient ascent on the log-density and collapses onto its modes. To avoid this collapse, Langevin dynamics injects Gaussian noise at every step:

\[ {\bf{x}}_{t+1} \leftarrow {\bf{x}}_{t} + \frac{\eta}{2}{\bf{s}}_{\theta}({\bf{x}}_t) + \sqrt{\eta}\epsilon \]

where \(\epsilon \sim \mathcal{N}({\bf{0}},{\bf{I}})\).
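
The difference between the noise-free update and the noisy update can be seen on a toy example. The sketch below is only an illustration under stated assumptions: the target is a one-dimensional Gaussian with mean 2 and unit variance, its exact score stands in for a trained \({\bf{s}}_{\theta}\), and the step size, number of steps and function names are arbitrary choices.

```python
import random
import statistics

def gaussian_score(x, mu=2.0, sigma=1.0):
    # Exact score of N(mu, sigma^2); stands in for a trained s_theta.
    return -(x - mu) / sigma**2

def run_chains(n_chains=1000, n_steps=200, eta=0.1, add_noise=True, seed=0):
    rng = random.Random(seed)
    finals = []
    for _ in range(n_chains):
        x = rng.uniform(-5.0, 5.0)                 # arbitrary initialization
        for _ in range(n_steps):
            x = x + 0.5 * eta * gaussian_score(x)  # drift along the score
            if add_noise:
                x = x + eta**0.5 * rng.gauss(0.0, 1.0)
        finals.append(x)
    return finals

noisy = run_chains(add_noise=True)
noiseless = run_chains(add_noise=False)

# With noise: mean and variance close to those of the target.
print(statistics.mean(noisy), statistics.pvariance(noisy))
# Without noise: every chain ends at the mode, so the variance vanishes.
print(statistics.mean(noiseless), statistics.pvariance(noiseless))
```

Because the step size is finite, the noisy chains target a slightly wider distribution than the true one; the bias shrinks as \(\eta\) decreases.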

The score model can be optimized by minimizing the Fisher Divergence with the ground truth score function:

\[ \frac{1}{2}\mathbb{E}_{p({\bf{x}})} \left[ \|{\bf{s}}_{\theta}({\bf{x}}) - \underbrace{\nabla_{\bf{x}} \log p({\bf{x}})}_{\text{unknown}}\|_2^2\right] \]

Fortunately, there is a technique called score matching [7] that only requires access to the model score function \({\bf{s}}_{\theta}\), not to the ground-truth score. Integration by parts yields the following expression, which is equivalent to the Fisher divergence up to an additive constant that does not depend on \(\theta\):

\[ \mathbb{E}_{p({\bf{x}})} \left[ \frac{1}{2} \|{\bf{s}}_{\theta}({\bf{x}})\|_2^2 + \underbrace{\text{tr}\left( \underbrace{\nabla_{\bf{x}} {\bf{s}}_{\theta}}_{\text{Jacobian of} {\,\bf{s}}_{\theta}}\right)}_{\text{div}({\bf{s}}_{\theta})} \right] \]

The problem that arises is that the trace of the Jacobian \(\nabla_{\bf{x}} {\bf{s}}_{\theta}\) is expensive to compute for images, videos and high-dimensional data in general, since it requires one backward pass per input dimension.
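
In one dimension the Jacobian trace is just a derivative, so the score matching objective can be evaluated directly. The following sketch is a toy illustration, not part of the original notebook: it assumes a Gaussian model with score \(s(x) = -(x - m)/v\), hypothetical names `score_matching_loss` and `toy_data`, and checks that the objective is smallest at the sample mean and variance, without ever using the true score.

```python
import random
import statistics

def score_matching_loss(samples, m, v):
    # Model score s(x) = -(x - m) / v, hence ds/dx = -1 / v (the 1-D "trace").
    terms = [0.5 * ((x - m) / v) ** 2 - 1.0 / v for x in samples]
    return sum(terms) / len(terms)

rng = random.Random(0)
toy_data = [rng.gauss(1.5, 2.0) for _ in range(5000)]   # toy samples

m_hat = statistics.fmean(toy_data)       # sample mean
v_hat = statistics.pvariance(toy_data)   # sample variance

best = score_matching_loss(toy_data, m_hat, v_hat)
perturbations = [(0.3, 0.0), (-0.3, 0.0), (0.0, 0.5), (0.0, -0.5)]
for dm, dv in perturbations:
    # Increase of the loss when moving away from (m_hat, v_hat).
    print(dm, dv, score_matching_loss(toy_data, m_hat + dm, v_hat + dv) - best)
```

Every perturbation increases the loss, which is what we expect from an objective that is equivalent to the Fisher divergence up to a constant.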

There are two main approaches:

  • Sliced score matching: Idea [8]: project onto random directions to reduce the dimensionality of the Jacobian and add an additional expectation (which is parallelizable). The loss function takes the form:

\[ \frac{1}{2}\mathbb{E}_{p_{\bf{v}}} \mathbb{E}_{p({\bf{x}})} \left[ \left({\bf{v}}^\top{\bf{s}}_{\theta}({\bf{x}}) - {\bf{v}}^\top\nabla_{\bf{x}}\log p({\bf{x}})\right)^2\right] \]

and, by applying integration by parts again:

\[ \mathbb{E}_{p_{\bf{v}}} \mathbb{E}_{p({\bf{x}})} \left[ {\bf{v}}^\top \nabla_{{\bf{x}}}{\bf{s}}_{\theta}({\bf{x}}) {\bf{v}} + \frac{1}{2} \left( {\bf{v}}^\top {\bf{s}}_{\theta}({\bf{x}})\right)^2\right] \]

Note that \({\bf{v}}^\top \nabla_{{\bf{x}}}{\bf{s}}_{\theta}({\bf{x}}) {\bf{v}} = {\bf{v}}^\top \nabla_{{\bf{x}}} \left( {\bf{v}}^\top {\bf{s}}_{\theta}({\bf{x}}) \right)^\top\), so the quadratic form can be obtained with a single backward pass, without building the full Jacobian.

  • Denoising score matching (DSM): Idea [9]: avoid the Jacobian by replacing the data distribution in the Fisher divergence with a noise-perturbed conditional distribution \(q_{\sigma}(\tilde{{\bf x}}|{\bf x})\) and taking expectations over the two random variables. Thus

\[ \frac{1}{2}\mathbb{E}_{p({\bf{x}})} \mathbb{E}_{q_{\sigma}(\tilde{{\bf x}}|{\bf x})} \left[ \|{\bf{s}}_{\theta}(\tilde{\bf{x}}) - \nabla_{\tilde{\bf{x}}} \log q_{\sigma}(\tilde{{\bf x}}|{\bf x}) \|_2^2\right] \]

Note that, assuming \(q_{\sigma}(\tilde{{\bf x}}|{\bf x}) = \mathcal{N}(\tilde{\bf x};{\bf x},\sigma^2{\bf{I}})\), we have \(\nabla_{\tilde{\bf{x}}} \log q_{\sigma}(\tilde{{\bf x}}|{\bf x}) = -\frac{1}{\sigma^2} (\tilde{\bf{x}} - {{\bf x}})\).

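The main drawback is that the method learns the score of the noise-perturbed distribution, so it cannot model the noise-free data distribution exactly; the two agree only when \(\sigma\) is small.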

Let’s rewrite DSM using \(t \sim \mathcal{U}[1,T]\), \({\bf x}_0 \sim q({\bf x}_0)\) and \({\bf x}_t \sim q({\bf x}_t|{\bf x}_0)\), where \(q({\bf x}_t|{\bf x}_0) = \mathcal{N}({\bf x}_t;{\bf x}_0,\sigma_t^2{\bf{I}})\). The loss function can be expressed by:

\[ \frac{1}{2}\mathbb{E}_{t \sim \mathcal{U}[1,T],q({\bf{x}}_0), q({\bf x}_t|{\bf x}_0)}\left[\lambda(t)\sigma_t^2 \left\|{\bf{s}}_{\theta}({\bf{x}}_t,t) + \frac{{\bf x}_t - {\bf x}_0}{\sigma_t^2} \right\|_2^2\right] \]
\[ \frac{1}{2}\mathbb{E}_{t \sim \mathcal{U}[1,T],q({\bf{x}}_0), q({\bf x}_t|{\bf x}_0)}\left[\lambda(t) \left\|\sigma_t{\bf{s}}_{\theta}({\bf{x}}_t,t) + \frac{{\bf x}_t - {\bf x}_0}{\sigma_t} \right\|_2^2\right] \]
\[ \frac{1}{2}\mathbb{E}_{t \sim \mathcal{U}[1,T],q({\bf{x}}_0), q({\bf x}_t|{\bf x}_0)}\left[\lambda(t) \left\|\epsilon + \sigma_t{\bf{s}}_{\theta}({\bf{x}}_t,t) \right\|_2^2\right] \]

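Since \({\bf{x}}_t = {\bf{x}}_0 + \sigma_t\epsilon\), if we take \(\epsilon_{\theta}({\bf{x}}_t,t) = -\sigma_t {\bf{s}}_{\theta}({\bf{x}}_t,t)\), the last expression becomes \(\lambda(t)\left\|\epsilon - \epsilon_{\theta}({\bf{x}}_t,t)\right\|_2^2\), which has the same form as the noise-prediction objective of DDPM.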
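
The identities used in this derivation can be checked numerically on a scalar. The sketch below is a minimal illustration with \(\lambda(t) = 1\) and arbitrary numbers: it compares the gradient of the Gaussian log-density with a finite-difference estimate, and then evaluates the loss in its score form and in its noise-prediction form for an arbitrary candidate score value (`s_theta` here is just a random stand-in for a network output).

```python
import math
import random

def log_q(x_tilde, x, sigma):
    # Log-density of N(x_tilde; x, sigma^2) in one dimension.
    return (-0.5 * ((x_tilde - x) / sigma) ** 2
            - math.log(sigma * math.sqrt(2.0 * math.pi)))

rng = random.Random(0)
x0, sigma_t = 0.4, 0.8
eps = rng.gauss(0.0, 1.0)
x_t = x0 + sigma_t * eps                 # perturbed sample

# Gradient of log q with respect to the noisy sample.
h = 1e-5
numeric_grad = (log_q(x_t + h, x0, sigma_t) - log_q(x_t - h, x0, sigma_t)) / (2 * h)
analytic_grad = -(x_t - x0) / sigma_t**2

# Same loss value in both parameterizations.
s_theta = rng.gauss(0.0, 1.0)            # stand-in for a network output
eps_theta = -sigma_t * s_theta           # noise-prediction reparameterization
score_form = sigma_t**2 * (s_theta + (x_t - x0) / sigma_t**2) ** 2
noise_form = (eps - eps_theta) ** 2

print(numeric_grad, analytic_grad)
print(score_form, noise_form)
```

Both pairs of numbers agree up to floating-point error, whatever value is drawn for `s_theta`.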

This result is important because it means we can use Langevin dynamics to sample from the diffusion process.

As pointed out before, we have to add noise to the sampler to avoid the collapse of the sampling paths. However, since the score function is learnt from samples of \(p({\bf{x}})\), it is poorly estimated in low-density regions, where few samples are available. The most widely accepted solution is annealed Langevin dynamics: generate samples using a sequence of \(\sigma_t\) values, from large to small.
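
A minimal sketch of this large-to-small noise strategy on a toy problem is given below. Everything in it is an assumption made for illustration: the data distribution is an equal mixture of two narrow one-dimensional Gaussians at -4 and +4, the exact score of the perturbed mixture replaces a trained \({\bf{s}}_{\theta}({\bf{x}}_t, t)\), and the geometric noise schedule and the step-size rule \(\eta = 0.1\,\sigma^2\) are arbitrary choices.

```python
import math
import random

MODE, DATA_STD = 4.0, 0.1   # assumed toy data: 0.5 N(-4, 0.1^2) + 0.5 N(+4, 0.1^2)

def perturbed_score(x, sigma):
    # Exact score of the data convolved with N(0, sigma^2): a symmetric
    # two-component mixture with per-component variance DATA_STD^2 + sigma^2.
    var = DATA_STD**2 + sigma**2
    return (-x + MODE * math.tanh(MODE * x / var)) / var

def annealed_langevin(n_chains=200, n_levels=10, steps_per_level=100,
                      sigma_max=5.0, sigma_min=0.1, rel_step=0.1, seed=0):
    rng = random.Random(seed)
    ratio = (sigma_min / sigma_max) ** (1.0 / (n_levels - 1))
    sigmas = [sigma_max * ratio**i for i in range(n_levels)]   # large to small
    xs = [rng.gauss(0.0, sigma_max) for _ in range(n_chains)]
    for sigma in sigmas:
        eta = rel_step * sigma**2        # step size shrinks with the noise level
        for _ in range(steps_per_level):
            xs = [x + 0.5 * eta * perturbed_score(x, sigma)
                  + math.sqrt(eta) * rng.gauss(0.0, 1.0) for x in xs]
    return xs

samples = annealed_langevin()
near_mode = sum(min(abs(x - MODE), abs(x + MODE)) < 1.0 for x in samples)
print(near_mode, "of", len(samples), "chains end within 1.0 of a mode")
```

At the largest noise level the two modes are merged into a single broad bump, so chains can move freely; as the noise decreases, the chains settle into one mode or the other.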

4. Stochastic Differential Equations#

DDPM and SGM can be further generalized to the case of infinitely many time steps or noise levels, where the perturbation and denoising processes are solutions to stochastic differential equations (SDEs) [10].

Score SDEs perturb data into noise with a diffusion process governed by the following stochastic differential equation:

\[ d{\bf{x}} = {\bf{f}}({\bf{x}},t)dt + g(t)d{\bf{w}} \]

where \({\bf{f}}({\bf{x}},t)\) and \(g(t)\) are the drift and diffusion functions of the SDE, respectively, and \({\bf{w}}\) is a standard Wiener process.

A remarkable result from Anderson (1982) states that the reverse of a diffusion process is also a diffusion process, running backwards in time and given by the reverse-time SDE:

\[ d{\bf{x}} = \left[ {\bf{f}}({\bf{x}},t) - g(t)^2 \nabla_{{\bf{x}}} \log p_t({\bf{x}})\right]dt + g(t)d\bar{\bf{w}} \]

where \(\bar{\bf{w}}\) is a standard Wiener process when time flows backwards from \(T\) to 0.

In [10] we can find the expressions for \({\bf{f}}(\cdot)\) and \(g(t)\) that make the SDE formulation equivalent to DDPM and SGM. The relevant part is that, once the score function \({\bf{s}}_{\theta}({\bf{x}}_t,t)\) has been trained, we can use SDE solvers (e.g. Euler–Maruyama) to sample from the reverse diffusion process. [10] also derives a deterministic counterpart, the probability flow ODE, which shares the same marginal distributions.

Image taken from [10]
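
Before moving to the image model, the reverse-time SDE above can be integrated with Euler–Maruyama on a toy problem where the score is known in closed form. This is a sketch under explicit assumptions, not the setup of [10]: the data distribution is a one-dimensional Gaussian, the forward SDE has zero drift and a constant diffusion coefficient, and the exact score replaces a trained \({\bf{s}}_{\theta}({\bf{x}}_t,t)\). All names (`MU`, `DATA_STD`, `G`, `T_END`, `exact_score`, `reverse_sde_sample`) are hypothetical.

```python
import math
import random
import statistics

MU, DATA_STD = 2.0, 0.5   # assumed toy data distribution N(MU, DATA_STD^2)
G, T_END = 5.0, 1.0       # constant diffusion g(t) = G, zero drift f = 0

def exact_score(x, t):
    # Score of p_t = N(MU, DATA_STD^2 + G^2 t); stands in for s_theta(x, t).
    return -(x - MU) / (DATA_STD**2 + G**2 * t)

def reverse_sde_sample(rng, n_steps=500):
    dt = T_END / n_steps
    x = rng.gauss(0.0, G * math.sqrt(T_END))     # prior, approximately p_T
    for k in range(n_steps, 0, -1):
        t = k * dt
        drift = -G**2 * exact_score(x, t)        # f - g^2 * score, with f = 0
        # Euler-Maruyama step backwards in time (the time increment is -dt).
        x = x - drift * dt + G * math.sqrt(dt) * rng.gauss(0.0, 1.0)
    return x

rng = random.Random(0)
sde_samples = [reverse_sde_sample(rng) for _ in range(1000)]
print(statistics.mean(sde_samples), statistics.pstdev(sde_samples))
```

The sample mean and standard deviation should land close to `MU` and `DATA_STD`; the remaining gap comes from the finite step size and from starting at a zero-mean prior instead of the exact \(p_T\).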

def sample_timestep_iter(x,t):
    betas_t = get_index_from_list(betas, t, x.shape)
    g = -model(t,x)
    model_mean = (2 - torch.sqrt(1 - betas_t))*x + 0.5*betas_t*g
    r = 0.05
    if t == 0:
        # As pointed out by Luis Pereira (see YouTube comment)
        # The t's are offset from the t's in the paper
        return model_mean
    else:
        noise = torch.randn_like(x)
        e = 2*torch.sqrt(betas_t)*(r*torch.norm(noise)/torch.norm(g))**2
        return  model_mean +  e*noise

@torch.no_grad()
def sample_plot_image_sde():
    # Sample noise
    img_size = IMG_SIZE
    img = torch.randn((1, 3, img_size, img_size), device=device)
    plt.figure(figsize=(15,2))
    plt.axis('off')
    num_images = 10
    stepsize = int(T/num_images)

    for i in range(0,T)[::-1]:
        t = torch.full((1,), i, device=device, dtype=torch.long)
        img = sample_timestep_iter(img, t)
        # Edit: This is to maintain the natural range of the distribution
        img = torch.clamp(img, -1.0, 1.0)
        if i % stepsize == 0:
            plt.subplot(1, num_images, int(i/stepsize)+1)
            show_tensor_image(img.detach().cpu())
    plt.show()

# Define beta schedule
T = 300
betas = linear_beta_schedule(timesteps=T)
sample_plot_image_sde()
\[\require{cancel}\]

5. Conditional generation#

So far, we have focused on modeling a complex data distribution \(p(x)\) and on how to sample from it. However, we are typically interested in learning conditional distributions \(p(x|y)\), so that we can control the data we generate through conditioning information \(y\).

Classifier guidance#

By applying Bayes' rule we know that:

\[ p(x|y) = \frac{p(y|x)p(x)}{\color{red}{p(y)}} \]

Nevertheless, if we use a score-based approach, where the goal is to learn \(\nabla_{x_t} \log p(x_t|y)\) at an arbitrary noise level \(t\), we can derive the following equivalent form:

\[\begin{split} \nabla_{x_t} \log p(x_t|y) &= \nabla_{x_t} \log \left(\frac{p(x_t)p(y|x_t)}{\color{red}{p(y)}}\right)\\ &= \nabla_{x_t} \log p(x_t) + \nabla_{x_t} \log p(y|x_t) - \cancel{\nabla_{x_t} \log \color{red}{p(y)}} \\ &= \underbrace{\nabla_{x_t} \log p(x_t)}_{\text{unconditional score}} + \underbrace{\nabla_{x_t} \log p(y|x_t)}_{\text{adversarial gradient}} \end{split}\]

Therefore, the score function for conditional generation is just the sum of the score of an unconditional generator and the gradient of a classifier trained on noisy inputs \(x_t\), which is easy to train and can be changed depending on the task. We can add a constant \(\gamma\) to balance the importance of each term during generation:

\[ \nabla_{x_t} \log p(x_t|y) = \nabla_{x_t} \log p(x_t) + \gamma \nabla_{x_t} \log p(y|x_t) \]
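
The decomposition can be verified in a case where all three scores are available in closed form. The sketch below assumes a toy model that is not part of the original notebook: a scalar \(x\) with prior \(\mathcal{N}(0, 1)\) and a Gaussian "classifier" \(p(y|x) = \mathcal{N}(y; x, \tau^2)\), so that the posterior \(p(x|y)\) is Gaussian as well. The function names are hypothetical.

```python
def unconditional_score(x):
    # Score of the prior p(x) = N(0, 1).
    return -x

def classifier_score(x, y, tau):
    # Gradient with respect to x of log p(y | x), with p(y | x) = N(y; x, tau^2).
    return (y - x) / tau**2

def posterior_score(x, y, tau):
    # p(x | y) is Gaussian with the mean and variance below.
    mean = y / (1.0 + tau**2)
    var = tau**2 / (1.0 + tau**2)
    return -(x - mean) / var

def guided_score(x, y, tau, gamma=1.0):
    # Unconditional score plus gamma times the classifier gradient.
    return unconditional_score(x) + gamma * classifier_score(x, y, tau)

x, y, tau = 0.3, 1.2, 0.7
print(guided_score(x, y, tau), posterior_score(x, y, tau))

# At the posterior mean the gamma = 1 score vanishes, while gamma = 2
# still pushes the sample further towards y.
posterior_mean = y / (1.0 + tau**2)
print(guided_score(posterior_mean, y, tau, gamma=2.0))
```

With \(\gamma = 1\) the guided score matches the posterior score exactly; larger values of \(\gamma\) move the samples closer to the conditioning value than the true posterior would.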

Classifier-free guidance#

In [11], the authors ditch the training of a separate classifier model in favor of an unconditional diffusion model and a conditional diffusion model. By rearranging a previous result:

\[ \nabla_{x_t} \log p(y|x_t) = \nabla_{x_t} \log p(x_t|y) - \nabla_{x_t} \log p(x_t) \]

And plugging it into the classifier guidance score function:

\[\begin{split} \begin{split} \nabla_{x_t} \log p(x_t|y) &= \nabla_{x_t} \log p(x_t) + \gamma (\nabla_{x_t} \log p(x_t|y) - \nabla_{x_t} \log p(x_t))\\ &= \nabla_{x_t} \log p(x_t) + \gamma \nabla_{x_t} \log p(x_t|y) - \gamma \nabla_{x_t} \log p(x_t)\\ &= \underbrace{\gamma \nabla_{x_t} \log p(x_t|y)}_{\text{conditional score}} + \underbrace{(1 - \gamma)\nabla_{x_t} \log p(x_t)}_{\text{unconditional score}} \end{split} \end{split}\]
  • When \(\gamma=0\), the learned conditional model completely ignores the conditioner

  • When \(\gamma=1\), the model explicitly learns the vanilla conditional distribution

  • When \(\gamma > 1\), the diffusion model moves further away from the unconditional score function in the direction of the conditional one, which reduces sample diversity but makes the samples follow the conditioning information more closely.

The good news is that this technique gives us greater control over the conditional generation procedure and requires training only one model, since we can learn both the conditional and unconditional diffusion models together as a single conditional model; the unconditional model can be queried by replacing the conditioning information with a fixed value.
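
At sampling time the guided score is just a weighted combination of two evaluations of the same network. The sketch below illustrates the combination rule only; `cfg_score` is a hypothetical helper and the score values are made-up numbers standing in for the conditional and unconditional outputs of a model.

```python
def cfg_score(cond_score, uncond_score, gamma):
    # gamma * conditional + (1 - gamma) * unconditional, element by element.
    return [gamma * c + (1.0 - gamma) * u
            for c, u in zip(cond_score, uncond_score)]

cond = [0.8, -0.2, 1.5]      # hypothetical conditional score values
uncond = [0.5, 0.1, 1.0]     # hypothetical unconditional score values

for gamma in (0.0, 1.0, 3.0):
    print(gamma, cfg_score(cond, uncond, gamma))
```

With \(\gamma = 0\) the rule returns the unconditional score, with \(\gamma = 1\) the conditional one, and with \(\gamma > 1\) it extrapolates beyond the conditional score along the direction that separates the two.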

6. Using Diffusion models for unsupervised explainability#

Image taken from [12]

The general idea is as follows:

  • Take an image and corrupt it using \(L\) steps of a forward diffusion model

  • Take the corrupted image and denoise it using a classifier-guided model. The classifier is trained to discriminate normal vs. pathological samples; therefore, during denoising, the conditional variable is set to the normal class.

  • Subtract the denoised image from the original. The resulting difference will highlight the features that classify the original image as pathological.

This can also be interpreted as a counterfactual explainability approach [18].

Further readings:#

  • Diffusion Schrödinger Bridge, for transporting samples from an arbitrary distribution A to a target distribution [13].

    • Schrödinger Bridge for Generative Speech Enhancement [17].

  • Flow Matching for Generative Modeling: Advanced alternatives for training DM with better results [14].

  • Physics-Informed Diffusion Models. Conditioning the generation process on physics-based forces [15], [16].

References#

[1] Yang, Y., Jin, M., Wen, H., Zhang, C., Liang, Y., Ma, L., … & Wen, Q. (2024). A survey on diffusion models for time series and spatio-temporal data. arXiv preprint arXiv:2404.18886.

[2] Brunton, S. L., & Kutz, J. N. (2022). Data-driven science and engineering: Machine learning, dynamical systems, and control. Cambridge University Press.

[3] Chen, N., Klushyn, A., Ferroni, F., Bayer, J., & Van Der Smagt, P. (2020). Learning flat latent manifolds with vaes. arXiv preprint arXiv:2002.04881.

[4] Luo, C. (2022). Understanding diffusion models: A unified perspective. arXiv preprint arXiv:2208.11970.

[5] Ho, J., Jain, A., & Abbeel, P. (2020). Denoising diffusion probabilistic models. Advances in neural information processing systems, 33, 6840-6851.

[6] Kingma, D., Salimans, T., Poole, B., & Ho, J. (2021). Variational diffusion models. Advances in neural information processing systems, 34, 21696-21707.

[7] Hyvärinen, A. (2005). Estimation of non-normalized statistical models by score matching. Journal of Machine Learning Research, 6(4).

[8] Song, Y., Garg, S., Shi, J., & Ermon, S. (2020, August). Sliced score matching: A scalable approach to density and score estimation. In Uncertainty in Artificial Intelligence (pp. 574-584). PMLR.

[9] Vincent, P. (2011). A connection between score matching and denoising autoencoders. Neural computation, 23(7), 1661-1674.

[10] Song, Y., Sohl-Dickstein, J., Kingma, D. P., Kumar, A., Ermon, S., & Poole, B. (2020). Score-based generative modeling through stochastic differential equations. arXiv preprint arXiv:2011.13456.

[11] Ho, J., & Salimans, T. (2022). Classifier-free diffusion guidance. arXiv preprint arXiv:2207.12598.

[12] Wolleb, J., Bieder, F., Sandkühler, R., & Cattin, P. C. (2022, September). Diffusion models for medical anomaly detection. In International Conference on Medical image computing and computer-assisted intervention (pp. 35-45). Cham: Springer Nature Switzerland.

[13] De Bortoli, V., Thornton, J., Heng, J., & Doucet, A. (2021). Diffusion schrödinger bridge with applications to score-based generative modeling. Advances in Neural Information Processing Systems, 34, 17695-17709.

[14] Lipman, Y., Chen, R. T., Ben-Hamu, H., Nickel, M., & Le, M. (2022). Flow matching for generative modeling. arXiv preprint arXiv:2210.02747.

[15] Shu, D., Li, Z., & Farimani, A. B. (2023). A physics-informed diffusion model for high-fidelity flow field reconstruction. Journal of Computational Physics, 478, 111972.

[16] Bastek, J. H., Sun, W., & Kochmann, D. M. (2024). Physics-Informed Diffusion Models. arXiv preprint arXiv:2403.14404.

[17] Jukić, A., Korostik, R., Balam, J., & Ginsburg, B. (2024). Schrödinger Bridge for Generative Speech Enhancement. arXiv preprint arXiv:2407.16074.

[18] Jeanneret, G., Simon, L., & Jurie, F. (2022). Diffusion models for counterfactual explanations. In Proceedings of the Asian Conference on Computer Vision (pp. 858-876).