Inferenza Variazionale Le Basi

'Variational Inference Basics'

Viviamo nell’era della quantificazione. Ma la quantificazione rigorosa è più facile a dirsi che a farsi. Nei sistemi complessi come la biologia, i dati possono essere difficili ed costosi da raccogliere. Mentre in applicazioni ad alto rischio come in medicina e finanza, è cruciale tenere conto dell’incertezza. L’inferenza variazionale – una metodologia all’avanguardia della ricerca in Intelligenza Artificiale – è un modo per affrontare questi aspetti.

Questo tutorial ti introduce alle basi: quando, perché e come usare l’inferenza variazionale.

Quando è utile l’inferenza variazionale?

L’inferenza variazionale è interessante nei seguenti tre casi d’uso strettamente correlati:

1. se hai pochi dati (cioè un basso numero di osservazioni),

2. ti interessa l’incertezza,

3. per la modellizzazione generativa.

Esamineremo ogni caso d’uso nel nostro esempio pratico.

1. Inferenza variazionale con pochi dati

Fig. 1: L'inferenza variazionale ti consente di scambiare la conoscenza di dominio con le informazioni degli esempi. Immagine dell'autore.

A volte, la raccolta di dati è costosa. Ad esempio, le misurazioni di DNA o RNA possono facilmente costare qualche migliaio di euro per osservazione. In questo caso, puoi codificare la conoscenza di dominio al posto di campioni aggiuntivi. L’inferenza variazionale può aiutare a “ridurre” sistematicamente la conoscenza di dominio man mano che si raccolgono più esempi, e fare affidamento maggiormente sui dati (Fig. 1).

2. Inferenza variazionale per l’incertezza

Per le applicazioni critiche per la sicurezza, come in finanza e sanità, l’incertezza è importante. L’incertezza può influire su tutti gli aspetti del modello, più ovviamente sull’output previsto. Meno ovvi sono i parametri del modello (ad esempio, i pesi e i bias). Invece delle usuali matrici di numeri – i pesi e i bias – puoi dotare i parametri di una distribuzione per renderli sfumati. L’inferenza variazionale ti consente di inferire i range di valori ragionevoli.

3. Inferenza variazionale per la modellizzazione generativa

I modelli generativi forniscono una specifica completa di come i dati sono stati generati. Ad esempio, come generare un’immagine di un gatto o di un cane. Di solito, c’è una rappresentazione latente z che porta un significato semantico (ad esempio, z descrive un gatto siamese). Attraverso una serie di trasformazioni (non lineari) e passaggi di campionamento, z viene trasformato nell’immagine effettiva x (ad esempio, i valori dei pixel del gatto siamese). L’inferenza variazionale è un modo per inferire e campionare lo spazio semantico latente z . Un esempio ben noto è il variational autoencoder.

Cos’è l’inferenza variazionale?

Alla base, l’inferenza variazionale è un’impresa bayesiana [1]. Nella prospettiva bayesiana, permetti ancora alla macchina di imparare dai dati, come al solito. Ciò che è diverso, è che dai al modello un suggerimento (una priorità) e consenti alla soluzione (la posteriorità) di essere più sfumata. Più concretamente, diciamo che hai un set di formazione X = [ x ₁, x ₂,.., x ₘ ]ᵗ di m esempi. Usiamo il teorema di Bayes:

p ( Θ | X ) = p ( X | Θ ) p ( Θ ) / p ( X ),

per inferire un intervallo – una distribuzione – di soluzioni Θ . Confronta questo con l’approccio di apprendimento automatico convenzionale, dove minimizziamo una perdita ℒ( Θ, X ) = ln p ( X | Θ ) per trovare una soluzione specifica Θ . L’inferenza bayesiana ruota attorno alla scoperta di un modo per determinare p ( Θ | X ): la distribuzione posteriore dei parametri Θ dati il set di formazione X . In generale, questo è un problema difficile. In pratica, si usano due modi per risolvere p ( Θ | X ): (i) usando la simulazione (Monte Carlo a catena di Markov) o (ii) attraverso l’ottimizzazione.

L’inferenza variazionale riguarda l’opzione (ii).

Il lower bound dell’evidenza (ELBO)

Fig. 2: Sketch of variational inference. We look for a distribution q(Θ) that is close to p(Θ|X). Image by Author.

L’idea dell’inferenza variazionale è quella di cercare una distribuzione q(Θ) che funga da sostituto (surrogato) per p(Θ|X). Successivamente, cerchiamo di rendere q(Θ|Φ) simile a p(Θ|X) modificando i valori di Φ (Fig. 2). Questo viene fatto massimizzando il lower bound dell’evidenza (ELBO):

ℒ(Φ) = E[ln p(X, Θ) — ln q(Θ|Φ)],

dove l’aspettativa E[·] è presa su q(Θ|Φ). A prima vista, sembra che dobbiamo fare attenzione nell’effettuare derivate (rispetto a Φ) a causa della dipendenza di E[·] da q(Θ|Φ). Fortunatamente, i pacchetti autograd come JAX supportano trucchi di riparametrizzazione [2] che ti consentono di prendere direttamente derivate da campioni casuali (ad esempio, della distribuzione gamma) anziché fare affidamento su approcci variabili a scatola nera ad alta varianza [3]. Per farla breve: stimare ∇ℒ(Φ) con un batch ₁, Θ₂,..] ~ q(Θ|Φ) e lasciare al tuo pacchetto autograd il compito di occuparsi dei dettagli.

Inferenza variazionale da zero

Fig. 3: Example image of a handwritten “zero” from sci-kit learn’s digits dataset. Image by Author.

Per solidificare la nostra comprensione, implementiamo l’inferenza variazionale da zero usando JAX . In questo esempio, addestrerai un modello generativo su cifre scritte a mano da sci-kit learn . Puoi seguire l’esempio nel notebook Colab .

Per mantenere la semplicità, analizzeremo solo la cifra “zero”.

from sklearn import datasetsdigits = datasets.load_digits()is_zero = digits.target == 0X_train = digits.images[is_zero]# Flatten image grid to a vector.n_pixels = 64  # 8-by-8.X_train = X_train.reshape((-1, n_pixels))

Ogni immagine è una matrice 8×8 di valori di pixel discreti che variano da 0 a 16. Poiché i pixel sono dati di conteggio, modelliamo i pixel, x, utilizzando la distribuzione di Poisson con una prior gamma per il tasso Θ. Il tasso Θ determina l’intensità media dei pixel. Pertanto, la distribuzione congiunta è data da:

p(x, Θ) = Poisson(x|Θ) Gamma(Θ|a, b),

dove a e b sono la forma e il tasso della distribuzione gamma .

Fig. 4: La conoscenza di dominio della cifra “zero” viene utilizzata come priorità. Immagine dell'autore.

La priorità — in questo caso, Gamma( Θ | a , b ) — è il luogo in cui infondi la tua conoscenza di dominio (caso d’uso 1.). Ad esempio, potresti avere un’idea di come dovrebbe apparire in media la cifra zero (Fig. 4). Puoi utilizzare questa informazione a priori per guidare la scelta di a e b . Per utilizzare la Fig. 4 come informazione a priori — chiamiamola x ₀ — e pesare la sua importanza come due esempi, quindi impostare a = 2 x ₀; b = 2.

Scritto in Python sembra così:

import jax.numpy as jnpimport jax.scipy as jsp# Iperparametri del modello.a = 2. * conoscenza_dominio_zero_b = 2.def log_joint(θ):  log_likelihood = jnp.sum(jsp.stats.gamma.logpdf(θ, a, scale=1./b))  log_likelihood += jnp.sum(jsp.stats.poisson.logpmf(X_train, θ))  return log_likelihood

Si noti che abbiamo utilizzato l’implementazione di numpy e scipy di JAX, in modo da poter derivare.

Successivamente, dobbiamo scegliere una distribuzione surrogata q ( Θ|Φ ). Per ricordarti, il nostro obiettivo è cambiare Φ in modo che la distribuzione surrogata q ( Θ|Φ ) corrisponda a p ( Θ|X) . Quindi, la scelta di q ( Θ ) determina il livello di approssimazione (sopprimiamo la dipendenza da Φ dove il contesto lo consente). A scopo illustrativo, scegliamo una distribuzione variazionale che è composta da (un prodotto di) gamma:

q ( Θ|Φ ) = Gamma( Θ | α , β ),

dove abbiamo usato la notazione abbreviata Φ = { α , β }.

Successivamente, per implementare il bound inferiore della veridicità ℒ ( Φ ) = E[ln p ( X , Θ ) — ln q ( Θ|Φ )], scriviamo innanzi tutto il termine all’interno delle parentesi di aspettativa:

@partial(vmap, in_axes=(0, None, None))def evidence_lower_bound(θ_i, alpha, inv_beta):  elbo = log_joint(θ_i) - jnp.sum(jsp.stats.gamma.logpdf(θ_i, alpha, scale=inv_beta))  return elbo

Qui, abbiamo usato il vmap di JAX per vettorizzare la funzione in modo da poterla eseguire su un batch [ Θ ₁, Θ ₂,.., Θ ₁₂₈]ᵗ.

Per completare l’implementazione di ℒ ( Φ ), media il valore della funzione sopra su realizzazioni della distribuzione variazionale Θ ᵢ ~ q ( Θ ):

def loss(Φ: dict, key):  """Stima stocastica del limite inferiore della veridicità."""  alpha = jnp.exp(Φ['log_alpha'])  inv_beta = jnp.exp(-Φ['log_beta'])  # Seleziona un batch dalla distribuzione variazionale q.  batch_size = 128  batch_shape = [batch_size, n_pixels]  θ_samples = random.gamma(key, alpha , shape=batch_shape) * inv_beta    # Calcola la stima Monte Carlo del limite inferiore della veridicità.  elbo_loss = jnp.mean(evidence_lower_bound(θ_samples, alpha, inv_beta))    # Trasforma elbo in una perdita.  return -elbo_loss

Alcune cose da notare qui riguardo agli argomenti:

  • Abbiamo impacchettato Φ come un dizionario (o tecnicamente, un pytree) contenente ln (α) e ln (β). Questo trucco garantisce che α > 0 e β > 0 – un requisito imposto dalla distribuzione gamma – durante l’ottimizzazione.
  • La perdita è una stima casuale dell’ELBO. In JAX, abbiamo bisogno di una nuova chiave generatore di numeri pseudo-casuali (PRNG) ogni volta che campioniamo. In questo caso, usiamo la chiave per campionare ₁, Θ₂,.., Θ₁₂₈]ᵗ.

Questo completa la specifica del modello p (x, Θ), della distribuzione variazionale q (Θ) e della perdita ℒ (Φ).

Addestramento del modello

Successivamente, minimizziamo la perdita ℒ (Φ) variando Φ = {α, β} in modo che q (Θ|Φ) corrisponda alla posteriore p (Θ|X). Come? Usando la vecchia discesa del gradiente alla moda! Per comodità, usiamo l’ottimizzatore Adam di Optax e inizializziamo i parametri con la prior α = a e β = b [ricorda, la prior era Gamma (Θ|a, b) e ha codificato la nostra conoscenza di dominio].

# Inizializza i parametri usando la prior.Φ = {'log_alpha': jnp.log(a), 'log_beta': jnp.full(fill_value=jnp.log(b), shape=[n_pixels]),}loss_val_grad = jit(jax.value_and_grad(loss))optimiser = optax.adam(learning_rate=0.2)opt_state = optimiser.init(Φ)

Qui, usiamo value_and_grad per valutare simultaneamente l’ELBO e la sua derivata. Comodo per monitorare la convergenza! Quindi compiliamo in tempo reale la funzione risultante (con jit) per renderla veloce.

Infine, addestreremo il modello per 5000 passaggi. Poiché la perdita è casuale, per ogni valutazione dobbiamo fornire una chiave generatore di numeri pseudo-casuali (PRNG). Lo facciamo allocando 5000 chiavi con random.split.

n_iter = 5_000keys = random.split(random.PRNGKey(42), num=n_iter)for i, key in enumerate(keys):  elbo, grads = loss_val_grad(Φ, key)  updates, opt_state = optimiser.update(grads, opt_state)  Φ = optax.apply_updates(Φ, updates)

Congratulazioni! Hai addestrato con successo il tuo primo modello usando l’inferenza variazionale!

Puoi accedere al notebook con il codice completo qui su Colab.

Risultati

Fig. 5: Confronto tra la distribuzione variazionale e la distribuzione posteriore esatta. Immagine dell'autore.

Facciamo un passo indietro e apprezziamo ciò che abbiamo costruito (Fig. 5). Per ogni pixel, la surrogata q (Θ) descrive l’incertezza sulla media dell’intensità del pixel (caso d’uso 2.). In particolare, la nostra scelta di q (Θ) cattura due elementi complementari:

  • L’intensità tipica del pixel.
  • Quanto l’intensità varia da immagine a immagine (la variabilità).

Si scopre che la distribuzione congiunta p (x, Θ) che abbiamo scelto ha una soluzione esatta:

p (Θ|X) = Gamma (Θ|a + Σ xᵢ, m + b),

dove m è il numero di campioni nel set di addestramento X. Qui vediamo esplicitamente come la conoscenza del dominio – codificata in a e b – diminuisce man mano che raccogliamo più esempi x ᵢ.

Possiamo facilmente confrontare la forma appresa α e il tasso β con i valori reali a + Σ x ᵢ e m + b. Nella Fig. 4 confrontiamo le distribuzioni – q (Θ|Φ) rispetto a p (Θ|X) — per due pixel specifici. Ecco che, una perfetta corrispondenza!

Bonus: generare immagini sintetiche

Fig. 6: Immagini sintetiche generate utilizzando l'inferenza variazionale. Immagine dell'autore.

L’inferenza variazionale è ottima per la modellazione generativa (caso d’uso 3.). Con la postcondizione di stand-in q (Θ) a portata di mano, la generazione di nuove immagini sintetiche è banale. I due passaggi sono:

  • Pixel di intensità campionaria Θ ~ q (Θ).
# Estrarre i parametri di q. Alpha = jnp.exp(Φ['log_alpha'])inv_beta = jnp.exp(-Φ['log_beta'])# 1) Generare intensità a livello di pixel per 10 immagini. Key_θ, key_x = random.split(key)m_new_images = 10new_batch_shape = [m_new_images, n_pixels]θ_samples = random.gamma(key_θ, alpha , shape=new_batch_shape) * inv_beta
  • Campione di immagini utilizzando x ~ Poisson (x | Θ).
# 2) Campione di immagini dalle intensità.X_synthetic = random.poisson(key_x, θ_samples)

Puoi vedere il risultato nella Fig. 6. Si noti che il carattere “zero” è leggermente meno nitido del previsto. Ciò faceva parte delle nostre ipotesi di modellizzazione: abbiamo modellato i pixel come mutuamente indipendenti anziché correlati. Per tener conto delle correlazioni tra i pixel, è possibile espandere il modello per raggruppare le intensità dei pixel: questo è chiamato fattorizzazione di Poisson [4].

Sommario

In questo tutorial, abbiamo introdotto le basi dell’inferenza variazionale e l’abbiamo applicata ad un esempio giocattolo: imparare il numero zero scritto a mano. Grazie a autograd, l’implementazione dell’inferenza variazionale da zero richiede solo poche righe di Python.

L’inferenza variazionale è particolarmente potente se hai pochi dati. Abbiamo visto come infondere e scambiare la conoscenza del dominio con le informazioni dai dati. La distribuzione surrogata inferita q (Θ) fornisce una rappresentazione “sfocata” dei parametri del modello, invece di un valore fisso. Questo è ideale se si è in un’applicazione ad alta posta in gioco in cui l’incertezza è importante! Infine, abbiamo dimostrato la modellazione generativa. Generare campioni sintetici è facile una volta che puoi campionare da q (Θ).

In sintesi, questo lo rende un componente fondamentale della cassetta degli attrezzi della scienza dei dati.

sfruttando il potere dell’inferenza variazionale, possiamo affrontare problemi complessi, consentendoci di prendere decisioni informate, quantificare le incertezze e, in definitiva, sbloccare il vero potenziale della scienza dei dati.

Riconoscimenti

Vorrei ringraziare Dorien Neijzen e Martin Banchero per la revisione.

Riferimenti:

[1] Blei, David M., Alp Kucukelbir, e Jon D. McAuliffe. “Inferenza variazionale: una revisione per gli statisticisti.” Giornale dell’Associazione statistica americana 112.518 (2017): 859-877.

[2] Figurnov, Mikhail, Shakir Mohamed e Andriy Mnih. “Gradienti impliciti di riparametrizzazione.” Progressi nei sistemi di informazione neurale 31 (2018).

[3] Ranganath, Rajesh, Sean Gerrish e David Blei. “Inferenza variazionale a scatola nera.” Intelligenza artificiale e statistica. PMLR, 2014.

[4] Gopalan, Prem, Jake M. Hofman e David M. Blei. “Raccomandazione scalabile con la fattorizzazione di Poisson.” Preprint di arXiv arXiv:1311.1704 (2013).