Il Modello di Diffusione Annotato

Modello di Diffusione Annotato

In questo post del blog, esamineremo più approfonditamente i Denoising Diffusion Probabilistic Models (anche noti come DDPM, modelli di diffusione, modelli generativi basati su punteggio o semplicemente autoencoder) poiché i ricercatori sono riusciti a ottenere risultati notevoli con essi per la generazione di immagini/audio/video (in)condizionale. Esempi popolari (al momento della scrittura) includono GLIDE e DALL-E 2 di OpenAI, Latent Diffusion dell’Università di Heidelberg e ImageGen di Google Brain.

Esamineremo il paper originale DDPM di (Ho et al., 2020), implementandolo passo dopo passo in PyTorch, basandoci sull’implementazione di Phil Wang – che a sua volta si basa sull’implementazione originale di TensorFlow. Si noti che l’idea della diffusione per la modellazione generativa è stata già introdotta in (Sohl-Dickstein et al., 2015). Tuttavia, è stato necessario attendere fino a (Song et al., 2019) (all’Università di Stanford) e poi a (Ho et al., 2020) (a Google Brain) che hanno migliorato l’approccio in modo indipendente.

Si noti che ci sono diverse prospettive sui modelli di diffusione. Qui, utilizziamo la prospettiva del modello a variabile latente a tempo discreto, ma assicurati di esaminare anche le altre prospettive.

Ok, iniziamo!

from IPython.display import Image
Image(filename='assets/78_annotated-diffusion/ddpm_paper.png')

Per prima cosa installeremo e importeremo le librerie necessarie (supponendo che tu abbia già installato PyTorch).

!pip install -q -U einops datasets matplotlib tqdm

import math
from inspect import isfunction
from functools import partial

%matplotlib inline
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from einops import rearrange, reduce
from einops.layers.torch import Rearrange

import torch
from torch import nn, einsum
import torch.nn.functional as F

Cos’è un modello di diffusione?

Un modello di diffusione (denoising) non è così complesso se lo confronti con altri modelli generativi come Normalizing Flows, GAN o VAE: tutti convertono il rumore da una distribuzione semplice in un campione di dati. Questo è anche il caso qui, dove una rete neurale impara a denoizzare gradualmente i dati partendo da un rumore puro.

In modo un po’ più dettagliato per le immagini, la configurazione consiste in 2 processi:

  • un processo di diffusione in avanti fisso (o predefinito) q q q a nostra scelta, che aggiunge gradualmente rumore gaussiano a un’immagine, fino ad ottenere un rumore puro
  • un processo di denoising inverso appreso p θ p_\theta p θ ​ , in cui una rete neurale viene addestrata a denoizzare gradualmente un’immagine partendo da un rumore puro, fino ad ottenere un’immagine effettiva.

Sia il processo in avanti che quello inverso indicizzati da t t t avvengono per un certo numero di passi temporali finiti T T T (gli autori del DDPM utilizzano T = 1000). Si parte con t = 0 in cui si campiona un’immagine reale x 0 \mathbf{x}_0 x 0 ​ dalla distribuzione dei dati (diciamo un’immagine di un gatto da ImageNet) e il processo in avanti campiona del rumore da una distribuzione gaussiana ad ogni passo temporale t t t, che viene aggiunto all’immagine del passo temporale precedente. Dato un T T T sufficientemente grande e un programma ben strutturato per aggiungere rumore ad ogni passo temporale, si ottiene quello che viene chiamato una distribuzione gaussiana isotropa a t = T tramite un processo graduale.

In forma più matematica

Scriviamo tutto in modo più formale, poiché alla fine abbiamo bisogno di una funzione di perdita gestibile che la nostra rete neurale deve ottimizzare.

Sia q ( x 0 ) q(\mathbf{x}_0) q ( x 0 ​ ) la distribuzione dei dati reali, ad esempio di “immagini reali”. Possiamo campionare da questa distribuzione per ottenere un’immagine, x 0 ∼ q ( x 0 ) \mathbf{x}_0 \sim q(\mathbf{x}_0) x 0 ​ ∼ q ( x 0 ​ ) . Definiamo il processo di diffusione in avanti q ( x t ∣ x t − 1 ) q(\mathbf{x}_t | \mathbf{x}_{t-1}) q ( x t ​ ∣ x t − 1 ​ ) che aggiunge rumore gaussiano ad ogni passo temporale t t t, secondo un programma di varianza noto 0 < β 1 < β 2 < . . . < β T < 1 0 < \beta_1 < \beta_2 < … < \beta_T < 1 0 < β 1 ​ < β 2 ​ < . . . < β T ​ < 1 come q ( x t ∣ x t − 1 ) = N ( x t ; 1 − β t x t − 1 , β t I ) . q(\mathbf{x}_t | \mathbf{x}_{t-1}) = \mathcal{N}(\mathbf{x}_t; \sqrt{1 – \beta_t} \mathbf{x}_{t-1}, \beta_t \mathbf{I}). q ( x t ​ ∣ x t − 1 ​ ) = N ( x t ​ ; 1 − β t ​ ​ x t − 1 ​ , β t ​ I ) .

Ricordiamo che una distribuzione normale (chiamata anche distribuzione gaussiana) è definita da 2 parametri: una media μ \mu μ e una varianza σ 2 ≥ 0 \sigma^2 \geq 0 σ 2 ≥ 0 . Fondamentalmente, ogni nuova immagine (leggermente più rumorosa) al passaggio temporale t t t è estratta da una distribuzione gaussiana condizionale con μ t = 1 − β t x t − 1 \mathbf{\mu}_t = \sqrt{1 – \beta_t} \mathbf{x}_{t-1} μ t ​ = 1 − β t ​ ​ x t − 1 ​ e σ t 2 = β t \sigma^2_t = \beta_t σ t 2 ​ = β t ​ , che possiamo ottenere campionando ϵ ∼ N ( 0 , I ) \mathbf{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) ϵ ∼ N ( 0 , I ) e poi impostando x t = 1 − β t x t − 1 + β t ϵ \mathbf{x}_t = \sqrt{1 – \beta_t} \mathbf{x}_{t-1} + \sqrt{\beta_t} \mathbf{\epsilon} x t ​ = 1 − β t ​ ​ x t − 1 ​ + β t ​ ​ ϵ .

Si noti che i β t \beta_t β t ​ non sono costanti ad ogni passaggio temporale t t t (da qui l’indice) — infatti si definisce una cosiddetta “varianza schedule” , che può essere lineare, quadratico, coseno, ecc., come vedremo più avanti (un po’ come un programma di apprendimento).

Quindi partendo da x 0 \mathbf{x}_0 x 0 ​ , otteniamo x 1 , . . . , x t , . . . , x T \mathbf{x}_1, …, \mathbf{x}_t, …, \mathbf{x}_T x 1 ​ , . . . , x t ​ , . . . , x T ​ , dove x T \mathbf{x}_T x T ​ è solo rumore gaussiano se impostiamo adeguatamente lo schedule.

Ora, se conoscessimo la distribuzione condizionale p ( x t − 1 ∣ x t ) p(\mathbf{x}_{t-1} | \mathbf{x}_t) p ( x t − 1 ​ ∣ x t ​ ) , potremmo eseguire il processo al contrario: campionando un po’ di rumore gaussiano casuale x T \mathbf{x}_T x T ​ , e poi gradualmente “denoizzarlo” in modo da ottenere un campione dalla vera distribuzione x 0 \mathbf{x}_0 x 0 ​ .

Tuttavia, non conosciamo p ( x t − 1 ∣ x t ) p(\mathbf{x}_{t-1} | \mathbf{x}_t) p ( x t − 1 ​ ∣ x t ​ ) . È intractable poiché richiede la conoscenza della distribuzione di tutte le possibili immagini per calcolare questa probabilità condizionale. Perciò, useremo una rete neurale per approssimare (apprendere) questa distribuzione di probabilità condizionale , chiamiamola p θ ( x t − 1 ∣ x t ) p_\theta (\mathbf{x}_{t-1} | \mathbf{x}_t) p θ ​ ( x t − 1 ​ ∣ x t ​ ) , con θ \theta θ che rappresenta i parametri della rete neurale, aggiornati tramite discesa del gradiente.

Ok, quindi abbiamo bisogno di una rete neurale per rappresentare una distribuzione di probabilità (condizionale) del processo inverso. Se assumiamo che anche questo processo inverso sia gaussiano, ricordiamo che qualsiasi distribuzione gaussiana è definita da 2 parametri:

  • una media parametrizzata da μ θ \mu_\theta μ θ ​ ;
  • una varianza parametrizzata da Σ θ \Sigma_\theta Σ θ ​ ;

quindi possiamo parametrizzare il processo come p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t , t ) , Σ θ ( x t , t ) ) p_\theta (\mathbf{x}_{t-1} | \mathbf{x}_t) = \mathcal{N}(\mathbf{x}_{t-1}; \mu_\theta(\mathbf{x}_{t},t), \Sigma_\theta (\mathbf{x}_{t},t)) p θ ​ ( x t − 1 ​ ∣ x t ​ ) = N ( x t − 1 ​ ; μ θ ​ ( x t ​ , t ) , Σ θ ​ ( x t ​ , t ) ) dove la media e la varianza sono anche condizionate dal livello di rumore t t t .

Quindi, la nostra rete neurale deve imparare/rappresentare la media e la varianza. Tuttavia, gli autori di DDPM hanno deciso di mantenere la varianza fissa e consentire alla rete neurale di imparare (rappresentare) solo la media μ θ \mu_\theta μ θ ​ di questa distribuzione di probabilità condizionale. Dal paper:

Inizialmente, impostiamo Σ θ ( x t , t ) = σ t 2 I \Sigma_\theta ( \mathbf{x}_t, t) = \sigma^2_t \mathbf{I} Σ θ ​ ( x t ​ , t ) = σ t 2 ​ I come costanti dipendenti dal tempo non addestrate. Sperimentalmente, sia σ t 2 = β t \sigma^2_t = \beta_t σ t 2 ​ = β t ​ sia σ t 2 = β ~ t \sigma^2_t = \tilde{\beta}_t σ t 2 ​ = β ~ ​ t ​ (vedi paper) hanno prodotto risultati simili.

Questo è stato successivamente migliorato nel paper sui modelli di diffusione migliorati, in cui una rete neurale impara anche la varianza di questo processo inverso, oltre alla media.

Proseguiamo quindi, assumendo che la nostra rete neurale debba solo imparare/rappresentare la media di questa distribuzione di probabilità condizionale.

Definizione di una funzione obiettivo (tramite riparametrizzazione della media)

Per derivare una funzione obiettivo per imparare la media del processo inverso, gli autori osservano che la combinazione di q q q e p θ p_\theta p θ ​ può essere vista come un variational auto-encoder (VAE) (Kingma et al., 2013). Pertanto, il limite inferiore variazionale (detto anche ELBO) può essere utilizzato per minimizzare il logaritmo negativo della probabilità congiunta rispetto ai dati di verità fondamentale x 0 \mathbf{x}_0 x 0 ​ (facciamo riferimento al paper VAE per i dettagli riguardanti l’ELBO). Risulta che l’ELBO per questo processo è la somma delle perdite ad ogni passaggio temporale t t t , L = L 0 + L 1 + . . . + L T L = L_0 + L_1 + … + L_T L = L 0 ​ + L 1 ​ + . . . + L T ​ . Attraverso la costruzione dei processi q q q in avanti e inverso, ogni termine (ad eccezione di L 0 L_0 L 0 ​ ) della perdita è effettivamente la divergenza KL tra 2 distribuzioni Gaussiane che può essere scritta esplicitamente come una perdita L2 rispetto alle medie!

Una conseguenza diretta del processo in avanti q q q costruito, come mostrato da Sohl-Dickstein et al., è che possiamo campionare x t \mathbf{x}_t x t ​ a qualsiasi livello di rumore arbitrario condizionato a x 0 \mathbf{x}_0 x 0 ​ (poiché la somma di Gaussiane è ancora una Gaussiana). Questo è molto comodo: non abbiamo bisogno di applicare q q q ripetutamente per campionare x t \mathbf{x}_t x t ​ . Abbiamo che q ( x t ∣ x 0 ) = N ( x t ; α ˉ t x 0 , ( 1 − α ˉ t ) I ) q(\mathbf{x}_t | \mathbf{x}_0) = \cal{N}(\mathbf{x}_t; \sqrt{\bar{\alpha}_t} \mathbf{x}_0, (1- \bar{\alpha}_t) \mathbf{I}) q ( x t ​ ∣ x 0 ​ ) = N ( x t ​ ; α ˉ t ​ ​ x 0 ​ , ( 1 − α ˉ t ​ ) I )

con α t : = 1 − β t \alpha_t := 1 – \beta_t α t ​ : = 1 − β t ​ e α ˉ t : = Π s = 1 t α s \bar{\alpha}_t := \Pi_{s=1}^{t} \alpha_s α ˉ t ​ : = Π s = 1 t ​ α s ​ . Chiamiamo questa equazione “proprietà interessante”. Ciò significa che possiamo campionare rumore gaussiano, scalare adeguatamente e aggiungerlo a x 0 \mathbf{x}_0 x 0 ​ per ottenere direttamente x t \mathbf{x}_t x t ​. Nota che gli α ˉ t \bar{\alpha}_t α ˉ t ​ sono funzioni della nota programmazione del β t \beta_t β t ​ della varianza e quindi sono anche noti e possono essere precomputati. Ciò ci consente, durante l’addestramento, di ottimizzare termini casuali della funzione di perdita L L L (o, in altre parole, di campionare casualmente t t t durante l’addestramento ed ottimizzare L t L_t L t ​ ).

Un’altra bellezza di questa proprietà, come mostrato in Ho et al., è che si può (dopo un po’ di matematica, per la quale si rimanda il lettore a questo eccellente post nel blog) invece riparametrizzare la media per far apprendere alla rete neurale (prevedere) il rumore aggiunto (tramite una rete ϵ θ ( x t , t ) \mathbf{\epsilon}_\theta(\mathbf{x}_t, t) ϵ θ ​ ( x t ​ , t ) ) per livello di rumore t t t nei termini KL che costituiscono le perdite. Ciò significa che la nostra rete neurale diventa un predittore di rumore, piuttosto che un predittore (diretto) di media. La media può essere calcolata come segue:

μ θ ( x t , t ) = 1 α t ( x t − β t 1 − α ˉ t ϵ θ ( x t , t ) ) \mathbf{\mu}_\theta(\mathbf{x}_t, t) = \frac{1}{\sqrt{\alpha_t}} \left( \mathbf{x}_t – \frac{\beta_t}{\sqrt{1- \bar{\alpha}_t}} \mathbf{\epsilon}_\theta(\mathbf{x}_t, t) \right) μ θ ​ ( x t ​ , t ) = α t ​ ​ 1 ​ ( x t ​ − 1 − α ˉ t ​ ​ β t ​ ​ ϵ θ ​ ( x t ​ , t ) )

La funzione obiettivo finale L t L_t L t ​ appare quindi come segue (per un passo temporale casuale t t t dato ϵ ∼ N ( 0 , I ) \mathbf{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) ϵ ∼ N ( 0 , I ) ):

∥ ϵ − ϵ θ ( x t , t ) ∥ 2 = ∥ ϵ − ϵ θ ( α ˉ t x 0 + ( 1 − α ˉ t ) ϵ , t ) ∥ 2 . \| \mathbf{\epsilon} – \mathbf{\epsilon}_\theta(\mathbf{x}_t, t) \|^2 = \| \mathbf{\epsilon} – \mathbf{\epsilon}_\theta( \sqrt{\bar{\alpha}_t} \mathbf{x}_0 + \sqrt{(1- \bar{\alpha}_t) } \mathbf{\epsilon}, t) \|^2. ∥ ϵ − ϵ θ ​ ( x t ​ , t ) ∥ 2 = ∥ ϵ − ϵ θ ​ ( α ˉ t ​ ​ x 0 ​ + ( 1 − α ˉ t ​ ) ​ ϵ , t ) ∥ 2 .

In queste formule, x 0 \mathbf{x}_0 x 0 ​ è l’immagine iniziale (reale, non corrotta), e vediamo il livello di rumore diretto t t t campionato mediante il processo inoltrato fisso. ϵ \mathbf{\epsilon} ϵ è il rumore puro campionato al passo temporale t t t , e ϵ θ ( x t , t ) \mathbf{\epsilon}_\theta (\mathbf{x}_t, t) ϵ θ ​ ( x t ​ , t ) è la nostra rete neurale. La rete neurale viene ottimizzata utilizzando un semplice errore quadratico medio (MSE) tra il rumore gaussiano vero e quello predetto.

L’algoritmo di addestramento ora appare come segue:

In altre parole:

  • prendiamo un campione casuale x 0 \mathbf{x}_0 x 0 ​ dalla distribuzione di dati sconosciuta e eventualmente complessa q ( x 0 ) q(\mathbf{x}_0) q ( x 0 ​ )
  • campioniamo un livello di rumore t t t in modo uniforme tra 1 1 1 e T T T (ovvero, un passo temporale casuale)
  • campioniamo del rumore da una distribuzione gaussiana e corrompiamo l’input con questo rumore al livello t t t (utilizzando la bella proprietà definita sopra)
  • la rete neurale viene addestrata per prevedere questo rumore basandosi sull’immagine corrotta x t \mathbf{x}_t x t ​ (cioè rumore applicato su x 0 \mathbf{x}_0 x 0 ​ in base a un programma noto β t \beta_t β t ​ )

In realtà, tutto ciò viene fatto su batch di dati, poiché si utilizza la discesa del gradiente stocastica per ottimizzare le reti neurali.

La rete neurale

La rete neurale deve prendere un’immagine rumorosa in un determinato passo temporale e restituire il rumore previsto. Si noti che il rumore previsto è un tensore che ha la stessa dimensione/risoluzione dell’immagine di input. Quindi tecnicamente, la rete prende in input e restituisce tensori della stessa forma. Quale tipo di rete neurale possiamo usare per questo?

Ciò che viene tipicamente utilizzato qui è molto simile a quello di un Autoencoder, che potresti ricordare dai tutorial tipici di “introduzione al deep learning”. Gli Autoencoder hanno uno strato chiamato “bottleneck” tra l’encoder e il decoder. L’encoder codifica prima un’immagine in una rappresentazione nascosta più piccola chiamata “bottleneck”, e il decoder decodifica quindi quella rappresentazione nascosta in un’immagine effettiva. Ciò costringe la rete a mantenere solo le informazioni più importanti nello strato del bottleneck.

In termini di architettura, gli autori del DDPM hanno scelto un U-Net, introdotto da (Ronneberger et al., 2015) (che, all’epoca, ha ottenuto risultati di primo piano per la segmentazione delle immagini mediche). Questa rete, come qualsiasi autoencoder, è composta da uno strato del bottleneck nel mezzo che si assicura che la rete apprenda solo le informazioni più importanti. In modo importante, ha introdotto connessioni residue tra l’encoder e il decoder, migliorando notevolmente il flusso del gradiente (ispirato a ResNet in He et al., 2015).

Come si può vedere, un modello U-Net ridimensiona prima l’input (cioè riduce la risoluzione spaziale dell’input), dopodiché esegue un upsampling.

Di seguito, implementiamo questa rete, passo dopo passo.

Assistenti di rete

Innanzitutto, definiamo alcune funzioni e classi di supporto che verranno utilizzate durante l’implementazione della rete neurale. In modo importante, definiamo un modulo Residual, che semplicemente aggiunge l’input all’output di una particolare funzione (in altre parole, aggiunge una connessione residua a una particolare funzione).

Definiamo anche alias per le operazioni di up- e downsampling.

def exists(x):
    return x is not None

def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr


class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x


def Upsample(dim, dim_out=None):
    return nn.Sequential(
        nn.Upsample(scale_factor=2, mode="nearest"),
        nn.Conv2d(dim, default(dim_out, dim), 3, padding=1),
    )


def Downsample(dim, dim_out=None):
    # No More Strided Convolutions or Pooling
    return nn.Sequential(
        Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2),
        nn.Conv2d(dim * 4, default(dim_out, dim), 1),
    )

Incorporazioni di posizione

Poiché i parametri della rete neurale sono condivisi nel tempo (livello di rumore), gli autori utilizzano incorporazioni di posizione sinusoidali per codificare t t t, ispirate al Transformer (Vaswani et al., 2017). Ciò consente alla rete neurale di “sapere” in quale specifico passo temporale (livello di rumore) sta operando, per ogni immagine in un batch.

Il modulo SinusoidalPositionEmbeddings prende un tensore di forma (batch_size, 1) in input (cioè i livelli di rumore di diverse immagini rumorose in un batch) e lo trasforma in un tensore di forma (batch_size, dim), con dim che rappresenta la dimensionalità delle incorporazioni di posizione. Questo viene quindi aggiunto a ogni blocco residuo, come vedremo in seguito.

class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings

Blocco ResNet

Successivamente, definiamo il blocco fondamentale del modello U-Net. Gli autori di DDPM hanno utilizzato un blocco Wide ResNet (Zagoruyko et al., 2016), ma Phil Wang ha sostituito il livello convoluzionale standard con una versione “weight standardized”, che funziona meglio in combinazione con la normalizzazione di gruppo (vedi (Kolesnikov et al., 2019) per i dettagli).

class WeightStandardizedConv2d(nn.Conv2d):
    """
    https://arxiv.org/abs/1903.10520
    La standardizzazione dei pesi funziona sinergicamente con la normalizzazione di gruppo
    """

    def forward(self, x):
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3

        weight = self.weight
        mean = reduce(weight, "o ... -> o 1 1 1", "mean")
        var = reduce(weight, "o ... -> o 1 1 1", partial(torch.var, unbiased=False))
        normalized_weight = (weight - mean) * (var + eps).rsqrt()

        return F.conv2d(
            x,
            normalized_weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )


class Block(nn.Module):
    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        x = self.proj(x)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.act(x)
        return x


class ResnetBlock(nn.Module):
    """https://arxiv.org/abs/1512.03385"""

    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2))
            if exists(time_emb_dim)
            else None
        )

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        scale_shift = None
        if exists(self.mlp) and exists(time_emb):
            time_emb = self.mlp(time_emb)
            time_emb = rearrange(time_emb, "b c -> b c 1 1")
            scale_shift = time_emb.chunk(2, dim=1)

        h = self.block1(x, scale_shift=scale_shift)
        h = self.block2(h)
        return h + self.res_conv(x)

Modulo di attenzione

Successivamente, definiamo il modulo di attenzione, che gli autori di DDPM hanno aggiunto tra i blocchi convoluzionali. L’attenzione è il blocco fondamentale della famosa architettura Transformer (Vaswani et al., 2017), che ha mostrato grande successo in vari domini dell’AI, dall’NLP alla visione al folding delle proteine. Phil Wang utilizza 2 varianti di attenzione: una è l’auto-attenzione multi-head regolare (come utilizzata nel Transformer), l’altra è una variante lineare dell’attenzione (Shen et al., 2018), i cui requisiti di tempo e memoria scalano linearmente con la lunghezza della sequenza, a differenza dell’attenzione regolare che scala quadraticamente.

Per una spiegazione approfondita del meccanismo di attenzione, si rimanda al meraviglioso post del blog di Jay Allamar.

class Attention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )
        q = q * self.scale

        sim = einsum("b h d i, b h d j -> b h i j", q, k)
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        out = einsum("b h i j, b h d j -> b h i d", attn, v)
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
        return self.to_out(out)

class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)

        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), 
                                    nn.GroupNorm(1, dim))

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )

        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)

        q = q * self.scale
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)

        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
        return self.to_out(out)

Normalizzazione di gruppo

Gli autori di DDPM alternano i livelli di convoluzione/attenzione dell’U-Net con la normalizzazione di gruppo (Wu et al., 2018). Di seguito, definiamo una classe PreNorm, che verrà utilizzata per applicare la groupnorm prima del livello di attenzione, come vedremo in seguito. È importante notare che c’è stato un dibattito su se applicare la normalizzazione prima o dopo l’attenzione nei Transformers.

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)

U-Net condizionale

Ora che abbiamo definito tutti i blocchi di costruzione (embedding di posizione, blocchi ResNet, attenzione e normalizzazione di gruppo), è il momento di definire l’intera rete neurale. Ricordiamo che il compito della rete ϵ θ ( x t , t ) \mathbf{\epsilon}_\theta(\mathbf{x}_t, t) ϵ θ ​ ( x t ​ , t ) è quello di prendere in input un batch di immagini rumorose e i rispettivi livelli di rumore, e restituire il rumore aggiunto all’input. Più formalmente:

  • la rete riceve in input un batch di immagini rumorose di forma (batch_size, num_channels, height, width) e un batch di livelli di rumore di forma (batch_size, 1), e restituisce un tensore di forma (batch_size, num_channels, height, width)

La rete è costruita nel seguente modo:

  • prima di tutto, viene applicato un livello convoluzionale sul batch di immagini rumorose, e vengono calcolati gli embedding di posizione per i livelli di rumore
  • successivamente, vengono applicate una serie di fasi di downsampling. Ogni fase di downsampling consiste di 2 blocchi ResNet + groupnorm + attenzione + connessione residuale + un’operazione di downsampling
  • a metà della rete, vengono nuovamente applicati blocchi ResNet, alternati all’attenzione
  • successivamente, vengono applicate una serie di fasi di upsampling. Ogni fase di upsampling consiste di 2 blocchi ResNet + groupnorm + attenzione + connessione residuale + un’operazione di upsampling
  • infine, viene applicato un blocco ResNet seguito da un livello convoluzionale.

In definitiva, le reti neurali impilano strati come se fossero blocchi Lego (ma è importante capire come funzionano).

class Unet(nn.Module):
    def __init__(
        self,
        dim,
        init_dim=None,
        out_dim=None,
        dim_mults=(1, 2, 4, 8),
        channels=3,
        self_condition=False,
        resnet_block_groups=4,
    ):
        super().__init__()

        # determina le dimensioni
        self.channels = channels
        self.self_condition = self_condition
        input_channels = channels * (2 if self_condition else 1)

        init_dim = default(init_dim, dim)
        self.init_conv = nn.Conv2d(input_channels, init_dim, 1, padding=0) # cambiato da 7,3 a 1 e 0

        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        block_klass = partial(ResnetBlock, groups=resnet_block_groups)

        # embedding temporali
        time_dim = dim * 4

        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(dim),
            nn.Linear(dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim),
        )

        # layers
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(
                nn.ModuleList(
                    [
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                        Downsample(dim_in, dim_out)
                        if not is_last
                        else nn.Conv2d(dim_in, dim_out, 3, padding=1),
                    ]
                )
            )

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            is_last = ind == (len(in_out) - 1)

            self.ups.append(
                nn.ModuleList(
                    [
                        block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
                        block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Upsample(dim_out, dim_in)
                        if not is_last
                        else nn.Conv2d(dim_out, dim_in, 3, padding=1),
                    ]
                )
            )

        self.out_dim = default(out_dim, channels)

        self.final_res_block = block_klass(dim * 2, dim, time_emb_dim=time_dim)
        self.final_conv = nn.Conv2d(dim, self.out_dim, 1)

    def forward(self, x, time, x_self_cond=None):
        if self.self_condition:
            x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
            x = torch.cat((x_self_cond, x), dim=1)

        x = self.init_conv(x)
        r = x.clone()

        t = self.time_mlp(time)

        h = []

        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            h.append(x)

            x = block2(x, t)
            x = attn(x)
            h.append(x)

            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)

            x = torch.cat((x, h.pop()), dim=1)
            x = block2(x, t)
            x = attn(x)

            x = upsample(x)

        x = torch.cat((x, r), dim=1)

        x = self.final_res_block(x, t)
        return self.final_conv(x)

Definizione del processo di diffusione in avanti

Il processo di diffusione in avanti aggiunge gradualmente rumore a un’immagine dalla distribuzione reale, in un certo numero di passaggi temporali T T T. Ciò avviene secondo uno schema di varianza. Gli autori originali di DDPM hanno utilizzato uno schema lineare:

Impostiamo le varianze del processo in avanti a costanti che aumentano linearmente da β 1 = 1 0 − 4 \beta_1 = 10^{−4} β 1 ​ = 1 0 − 4 a β T = 0.02 \beta_T = 0.02 β T ​ = 0 . 0 2 .

Tuttavia, è stato dimostrato in (Nichol et al., 2021) che si possono ottenere risultati migliori utilizzando uno schema coseno.

Di seguito, definiamo vari schemi per i passaggi temporali T T T (ne sceglieremo uno in seguito).

def cosine_beta_schedule(timesteps, s=0.008):
    """
    schema coseno come proposto in https://arxiv.org/abs/2102.09672
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0.0001, 0.9999)

def linear_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start, beta_end, timesteps)

def quadratic_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2

def sigmoid_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    betas = torch.linspace(-6, 6, timesteps)
    return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start

Per iniziare, utilizziamo lo schema lineare per T = 300 T=300 T = 3 0 0 passaggi temporali e definiamo le varie variabili dai β t \beta_t β t ​ di cui avremo bisogno, come il prodotto cumulativo delle varianze α ˉ t \bar{\alpha}_t α ˉ t ​ . Ognuna delle variabili sottostanti è solo un tensore unidimensionale, che memorizza i valori da t t t a T T T . In modo importante, definiamo anche una funzione extract, che ci permetterà di estrarre l’indice t t t appropriato per un batch di indici.

timesteps = 300

# definisci lo schema beta
betas = linear_beta_schedule(timesteps=timesteps)

# definisci le alphas
alphas = 1 - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)

# calcoli per la diffusione q(x_t | x_{t-1}) e altri
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - alphas_cumprod)

# calcoli per il posteriore q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1 - alphas_cumprod_prev) / (1 - alphas_cumprod)

def extract(a, t, x_shape):
    batch_size = t.shape[0]
    out = a.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

Illustreremo con un’immagine di gatti come viene aggiunto rumore ad ogni passaggio temporale del processo di diffusione.

from PIL import Image
import requests

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw) # immagine PIL di forma HWC
image

Il rumore viene aggiunto ai tensori PyTorch, piuttosto che alle immagini di Pillow. Innanzitutto definiremo le trasformazioni dell’immagine che ci permettono di passare da un’immagine PIL a un tensore PyTorch (su cui possiamo aggiungere il rumore) e viceversa.

Queste trasformazioni sono piuttosto semplici: prima normalizziamo le immagini dividendo per 255 255 2 5 5 (in modo che siano nell’intervallo [ 0 , 1 ] [0,1] [ 0 , 1 ]), e poi ci assicuriamo che siano nell’intervallo [ − 1 , 1 ] [-1, 1] [ − 1 , 1 ]. Secondo il paper DPPM:

Assumiamo che i dati dell’immagine siano costituiti da interi in { 0 , 1 , . . . , 255 } \{0, 1, … , 255\} { 0 , 1 , . . . , 2 5 5 } scalati linearmente in [ − 1 , 1 ] [−1, 1] [ − 1 , 1 ] . Questo assicura che il processo di retroazione della rete neurale operi su input scalati in modo coerente a partire dalla prior normale standard p ( x T ) p(\mathbf{x}_T ) p ( x T ​ ) .

from torchvision.transforms import Compose, ToTensor, Lambda, ToPILImage, CenterCrop, Resize

dimensione_immagine = 128
trasformazione = Compose([
    Resize(dimensione_immagine),
    CenterCrop(dimensione_immagine),
    ToTensor(), # convertire in un tensore PyTorch di forma CHW, dividere per 255
    Lambda(lambda t: (t * 2) - 1),
    
])

x_iniziale = trasformazione(immagine).unsqueeze(0)
x_iniziale.shape

Definiamo anche la trasformazione inversa, che prende in input un tensore PyTorch contenente valori in [ − 1 , 1 ] [-1, 1] [ − 1 , 1 ] e li trasforma nuovamente in un’immagine PIL:

import numpy as np

trasformazione_inversa = Compose([
     Lambda(lambda t: (t + 1) / 2),
     Lambda(lambda t: t.permute(1, 2, 0)), # CHW a HWC
     Lambda(lambda t: t * 255.),
     Lambda(lambda t: t.numpy().astype(np.uint8)),
     ToPILImage(),
])

Verifichiamolo:

trasformazione_inversa(x_iniziale.squeeze())

Ora possiamo definire il processo di diffusione in avanti come nel paper:

# diffusione in avanti (utilizzando la bella proprietà)
def campione_q(x_iniziale, t, rumore=None):
    if rumore is None:
        rumore = torch.randn_like(x_iniziale)

    sqrt_alphas_cumprod_t = estrai(sqrt_alphas_cumprod, t, x_iniziale.shape)
    sqrt_one_minus_alphas_cumprod_t = estrai(
        sqrt_one_minus_alphas_cumprod, t, x_iniziale.shape
    )

    return sqrt_alphas_cumprod_t * x_iniziale + sqrt_one_minus_alphas_cumprod_t * rumore

Testiamolo su un particolare passo temporale:

def ottieni_immagine_rumore(x_iniziale, t):
  # aggiungi rumore
  x_rumoroso = campione_q(x_iniziale, t=t)

  # trasforma nuovamente in un'immagine PIL
  immagine_rumorosa = trasformazione_inversa(x_rumoroso.squeeze())

  return immagine_rumorosa

# prendi un passo temporale
t = torch.tensor([40])

ottieni_immagine_rumore(x_iniziale, t)

Visualizziamolo per vari passi temporali:

import matplotlib.pyplot as plt

# utilizziamo un seed per la riproducibilità
torch.manual_seed(0)

# fonte: https://pytorch.org/vision/stable/auto_examples/plot_transforms.html#sphx-glr-auto-examples-plot-transforms-py
def plot(imgs, with_orig=False, row_title=None, **imshow_kwargs):
    if not isinstance(imgs[0], list):
        # Crea una griglia 2D anche se c'è solo 1 riga
        imgs = [imgs]

    num_rows = len(imgs)
    num_cols = len(imgs[0]) + with_orig
    fig, axs = plt.subplots(figsize=(200,200), nrows=num_rows, ncols=num_cols, squeeze=False)
    for row_idx, row in enumerate(imgs):
        row = [immagine] + row if with_orig else row
        for col_idx, img in enumerate(row):
            ax = axs[row_idx, col_idx]
            ax.imshow(np.asarray(img), **imshow_kwargs)
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    if with_orig:
        axs[0, 0].set(title='Immagine originale')
        axs[0, 0].title.set_size(8)
    if row_title is not None:
        for row_idx in range(num_rows):
            axs[row_idx, 0].set(ylabel=row_title[row_idx])

    plt.tight_layout()

plot([ottieni_immagine_rumore(x_iniziale, torch.tensor([t])) for t in [0, 50, 100, 150, 199]])

Ciò significa che possiamo ora definire la funzione di perdita data il modello come segue:

def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"):
    if noise is None:
        noise = torch.randn_like(x_start)

    x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
    predicted_noise = denoise_model(x_noisy, t)

    if loss_type == 'l1':
        loss = F.l1_loss(noise, predicted_noise)
    elif loss_type == 'l2':
        loss = F.mse_loss(noise, predicted_noise)
    elif loss_type == "huber":
        loss = F.smooth_l1_loss(noise, predicted_noise)
    else:
        raise NotImplementedError()

    return loss

Il denoise_model sarà il nostro U-Net definito in precedenza. Utilizzeremo la perdita di Huber tra il rumore vero e il rumore previsto.

Definire un dataset PyTorch + DataLoader

Qui definiamo un normale dataset PyTorch . Il dataset consiste semplicemente di immagini provenienti da un dataset reale, come Fashion-MNIST, CIFAR-10 o ImageNet, scalate linearmente a [ − 1 , 1 ] [−1, 1] [ − 1 , 1 ] .

Ogni immagine viene ridimensionata alla stessa dimensione. Interessante notare che le immagini sono anche ribaltate orizzontalmente in modo casuale. Dal paper:

Abbiamo utilizzato ribaltamenti orizzontali casuali durante l’addestramento per CIFAR10; abbiamo provato ad addestrare sia con che senza ribaltamenti, e abbiamo riscontrato che i ribaltamenti migliorano leggermente la qualità del campione.

Qui utilizziamo la libreria 🤗 Datasets per caricare facilmente il dataset Fashion MNIST dall’hub . Questo dataset consiste di immagini che hanno già la stessa risoluzione, ovvero 28×28.

from datasets import load_dataset

# carica il dataset dall'hub
dataset = load_dataset("fashion_mnist")
dimensione_immagine = 28
canali = 1
dimensione_batch = 128

Successivamente, definiamo una funzione che applicheremo in tempo reale all’intero dataset. Utilizziamo la funzionalità with_transform per questo. La funzione applica semplicemente una pre-elaborazione di base delle immagini: ribaltamenti orizzontali casuali, ridimensionamento e infine le rende avere valori nell’intervallo [ − 1 , 1 ] [-1,1] [ − 1 , 1 ] .

from torchvision import transforms
from torch.utils.data import DataLoader

# definisce le trasformazioni delle immagini (ad esempio, utilizzando torchvision)
trasformazione = Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Lambda(lambda t: (t * 2) - 1)
])

# definisce la funzione
def trasforma(esempi):
   esempi["valori_pixel"] = [trasformazione(immagine.convert("L")) for immagine in esempi["immagine"]]
   del esempi["immagine"]

   return esempi

dataset_trasformato = dataset.with_transform(trasforma).remove_columns("etichetta")

# crea il dataloader
dataloader = DataLoader(dataset_trasformato["train"], batch_size=dimensione_batch, shuffle=True)

batch = next(iter(dataloader))
print(batch.keys())

Campionamento

Dato che campioneremo dal modello durante l’addestramento (per monitorare i progressi), definiamo il codice per questo di seguito. Il campionamento è riassunto nel paper come Algoritmo 2:

La generazione di nuove immagini da un modello di diffusione avviene invertendo il processo di diffusione: partiamo da T T T , dove campioniamo rumore puro da una distribuzione gaussiana, e quindi utilizziamo la nostra rete neurale per denoizzarlo gradualmente (utilizzando la probabilità condizionata che ha appreso), fino ad arrivare al passaggio di tempo t = 0 t = 0 t = 0 . Come mostrato sopra, possiamo ottenere un’immagine leggermente meno denoizzata x t − 1 \mathbf{x}_{t-1 } x t − 1 ​ collegando la riparametrizzazione della media, utilizzando il nostro predittore del rumore. Ricordiamo che la varianza è nota in anticipo.

Idealemente, otteniamo un’immagine che sembra provenire dalla distribuzione dei dati reali.

Il codice di seguito implementa ciò.

@torch.no_grad()
def p_sample(modello, x, t, t_index):
    betas_t = estrai(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = estrai(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = estrai(sqrt_recip_alphas, t, x.shape)
    
    # Equazione 11 nel paper
    # Utilizza il nostro modello (predittore del rumore) per prevedere la media
    media_modello = sqrt_recip_alphas_t * (
        x - betas_t * modello(x, t) / sqrt_one_minus_alphas_cumprod_t
    )

    if t_index == 0:
        return media_modello
    else:
        devianza_posteriore_t = estrai(devianza_posteriore, t, x.shape)
        rumore = torch.randn_like(x)
        # Algoritmo 2 riga 4:
        return media_modello + torch.sqrt(devianza_posteriore_t) * rumore 

# Algoritmo 2 (incluso il ritorno di tutte le immagini)
@torch.no_grad()
def p_sample_loop(modello, forma):
    dispositivo = next(modello.parameters()).device

    b = forma[0]
    # partiamo da rumore puro (per ogni esempio nel batch)
    img = torch.randn(forma, device=dispositivo)
    imgs = []

    for i in tqdm(reversed(range(0, timesteps)), desc='campionamento ciclo passo temporale', total=timesteps):
        img = p_sample(modello, img, torch.full((b,), i, device=dispositivo, dtype=torch.long), i)
        imgs.append(img.cpu().numpy())
    return imgs

@torch.no_grad()
def campione(modello, dimensione_immagine, dimensione_batch=16, canali=3):
    return p_sample_loop(modello, shape=(dimensione_batch, canali, dimensione_immagine, dimensione_immagine))

Si noti che il codice sopra è una versione semplificata dell’implementazione originale. Abbiamo trovato che la nostra semplificazione (che è in linea con l’Algoritmo 2 nel paper) funziona altrettanto bene dell’implementazione originale, più complessa, che utilizza il clipping.

Allenare il modello

Successivamente, alleniamo il modello nel modo standard di PyTorch. Definiamo anche una logica per salvare periodicamente le immagini generate, utilizzando il metodo sample definito in precedenza.

from pathlib import Path

def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr

results_folder = Path("./results")
results_folder.mkdir(exist_ok = True)
save_and_sample_every = 1000

Di seguito, definiamo il modello e lo spostiamo sulla GPU. Definiamo anche un ottimizzatore standard (Adam).

from torch.optim import Adam

device = "cuda" if torch.cuda.is_available() else "cpu"

model = Unet(
    dim=image_size,
    channels=channels,
    dim_mults=(1, 2, 4,)
)
model.to(device)

optimizer = Adam(model.parameters(), lr=1e-3)

Iniziamo l’allenamento!

from torchvision.utils import save_image

epochs = 6

for epoch in range(epochs):
    for step, batch in enumerate(dataloader):
      optimizer.zero_grad()

      batch_size = batch["pixel_values"].shape[0]
      batch = batch["pixel_values"].to(device)

      # Algoritmo 1 riga 3: campiona t in modo uniforme per ogni esempio nel batch
      t = torch.randint(0, timesteps, (batch_size,), device=device).long()

      loss = p_losses(model, batch, t, loss_type="huber")

      if step % 100 == 0:
        print("Loss:", loss.item())

      loss.backward()
      optimizer.step()

      # salva le immagini generate
      if step != 0 and step % save_and_sample_every == 0:
        milestone = step // save_and_sample_every
        batches = num_to_groups(4, batch_size)
        all_images_list = list(map(lambda n: sample(model, batch_size=n, channels=channels), batches))
        all_images = torch.cat(all_images_list, dim=0)
        all_images = (all_images + 1) * 0.5
        save_image(all_images, str(results_folder / f'sample-{milestone}.png'), nrow = 6)

Campionamento (inferenza)

Per campionare dal modello, possiamo semplicemente utilizzare la nostra funzione di campionamento definita in precedenza:

# campiona 64 immagini
samples = sample(model, image_size=image_size, batch_size=64, channels=channels)

# mostra una casuale
random_index = 5
plt.imshow(samples[-1][random_index].reshape(image_size, image_size, channels), cmap="gray")

Sembra che il modello sia in grado di generare una bella maglietta! Tieni presente che il dataset su cui abbiamo allenato è di risoluzione piuttosto bassa (28×28).

Possiamo anche creare un gif del processo di denoising:

import matplotlib.animation as animation

random_index = 53

fig = plt.figure()
ims = []
for i in range(timesteps):
    im = plt.imshow(samples[i][random_index].reshape(image_size, image_size, channels), cmap="gray", animated=True)
    ims.append([im])

animate = animation.ArtistAnimation(fig, ims, interval=50, blit=True, repeat_delay=1000)
animate.save('diffusion.gif')
plt.show()

Si noti che il paper DDPM ha dimostrato che i modelli di diffusione sono una direzione promettente per la generazione di immagini (non)condizionali. Questo è stato migliorato enormemente, soprattutto per la generazione di immagini condizionate dal testo. Di seguito, elenchiamo alcuni lavori successivi importanti (ma non esaustivi):

  • Improved Denoising Diffusion Probabilistic Models (Nichol et al., 2021): trova che imparare la varianza della distribuzione condizionale (oltre alla media) aiuta a migliorare le prestazioni
  • Cascaded Diffusion Models for High Fidelity Image Generation (Ho et al., 2021): introduce la diffusione a cascata, che comprende una pipeline di più modelli di diffusione che generano immagini di risoluzione crescente per la sintesi di immagini ad alta fedeltà
  • Diffusion Models Beat GANs on Image Synthesis (Dhariwal et al., 2021): dimostra che i modelli di diffusione possono raggiungere una qualità di campionamento delle immagini superiore rispetto ai modelli generativi attuali di ultima generazione migliorando l’architettura U-Net e introducendo una guida al classificatore
  • Classifier-Free Diffusion Guidance (Ho et al., 2021): dimostra che non è necessario un classificatore per guidare un modello di diffusione addestrando congiuntamente un modello di diffusione condizionale e uno incondizionato con una singola rete neurale
  • Hierarchical Text-Conditional Image Generation with CLIP Latents (DALL-E 2) (Ramesh et al., 2022): utilizza un prior per trasformare una didascalia di testo in un embedding di immagine CLIP, dopodiché un modello di diffusione lo decodifica in un’immagine
  • Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding (ImageGen) (Saharia et al., 2022): dimostra che combinare un grande modello di linguaggio pre-addestrato (ad es. T5) con la diffusione a cascata funziona bene per la sintesi di testo-immagine

Si tenga presente che questa lista include solo opere importanti fino al momento della scrittura, che è il 7 giugno 2022.

Al momento, sembra che l’unico svantaggio principale dei modelli di diffusione sia che richiedono più passaggi in avanti per generare un’immagine (cosa che non accade per i modelli generativi come le GAN). Tuttavia, ci sono ricerche in corso che consentono la generazione ad alta fedeltà con pochi passaggi di denoising, fino a un massimo di 10.