Un’introduzione delicata al Deep Reinforcement Learning in JAX

Un'affascinante introduzione al Deep Reinforcement Learning in JAX

Risoluzione dell’ambiente CartPole con DQN in meno di un secondo

Foto di Thomas Despeyroux su Unsplash

Recenti progressi nell’apprendimento per rinforzo (RL), come i taxi autonomi di Waymo o gli agenti di scacchi sovrumani di DeepMind, integrano il RL classico con componenti di deep learning come le reti neurali e i metodi di ottimizzazione del gradiente.

Sulla base dei fondamenti e dei principi di programmazione introdotti in una delle mie storie precedenti, scopriremo e impareremo ad implementare le Reti Q Profonde (DQN) e i buffer di replay per risolvere l’ambiente CartPole dell’OpenAI. Tutto questo in meno di un secondo usando JAX!

Per una introduzione a JAX, agli ambienti vettorializzati e al Q-learning, si prega di fare riferimento al contenuto di questa storia:

Vettorizzare e Parallelizzare Ambienti RL con JAX: Q-learning alla velocità della luce⚡

Impara a vettorizzare un ambiente GridWorld e ad allenare 30 agenti di Q-learning in parallelo su una CPU, a 1.8 milioni di step per…

towardsdatascience.com

Il nostro framework di scelta per il deep learning sarà la libreria Haiku di DeepMind, che ho recentemente presentato nel contesto dei Transformers:

Implementazione di un Trasformatore Encoder da Zero con JAX e Haiku 🤖

Comprensione dei blocchi fondamentali dei Transformers.

towardsdatascience.com

Questo articolo coprirà le seguenti sezioni:

  • Perché abbiamo bisogno del Deep RL?
  • Reti Q Profonde, teoria e pratica
  • Buffer di replay
  • Tradurre l’ambiente CartPole in JAX
  • Il modo JAX di scrivere cicli di allenamento efficienti

Come sempre, tutto il codice presentato in questo articolo è disponibile su GitHub:

GitHub – RPegoud/jym: Implementazione JAX di algoritmi RL e ambienti vettorializzati

Implementazione JAX di algoritmi RL e ambienti vettorializzati – GitHub – RPegoud/jym: Implementazione JAX di RL…

github.com

Perché abbiamo bisogno del Deep RL?

In articoli precedenti, abbiamo introdotto algoritmi di apprendimento a differenza temporale e in particolare il Q-learning.

In parole semplici, il Q-learning è un algoritmo di tipo off-policy (la policy di destinazione non è la policy utilizzata per la presa di decisioni) che mantiene e aggiorna una tabella Q, una mappatura esplicita dei stati ai valori delle azioni corrispondenti.

Mentre il Q-learning è una soluzione pratica per gli ambienti con spazi di azione discreti e spazi di osservazione limitati, fatica a scalare bene verso ambienti più complessi. Infatti, creare una Q-tabella richiede la definizione degli spazi di azione e osservazione.

Consideriamo l’esempio della guida autonoma, lo spazio di osservazione è composto da un’infinità di configurazioni potenziali derivate da telecamere e altri input sensoriali. D’altro canto, lo spazio di azione include una vasta gamma di posizioni del volante e livelli variabili di forza applicata al freno e all’acceleratore.

Anche se teoricamente potremmo discretizzare lo spazio di azione, il volume stesso di stati e azioni possibili porta a una tabella Q impraticabile nelle applicazioni del mondo reale.

Foto di Kirill Tonkikh su Unsplash

Trovare azioni ottimali in spazi di stato-azione grandi e complessi richiede quindi potenti algoritmi di approssimazione delle funzioni, che è esattamente ciò che sono le Reti Neurali. Nel caso dell’Apprendimento per Rinforzo Profondo, le reti neurali vengono utilizzate come sostituto della tabella Q e forniscono una soluzione efficiente alla maledizione della dimensionalità introdotta dai grandi spazi di stato. Inoltre, non è necessario definire esplicitamente lo spazio di osservazione.

Deep Q-Network e Replay Buffers

L’DQN utilizza due tipi di reti neurali in parallelo, a partire dalla rete “in linea” utilizzata per la previsione del valore Q e la decisione. D’altro canto, la rete “target” viene utilizzata per creare target Q stabili per valutare le prestazioni della rete online tramite la funzione di perdita.

Similmente al Q-learning, gli agenti DQN sono definiti da due funzioni: act e update.

Act

La funzione act implementa una politica epsilon-greedy rispetto ai valori Q, che vengono stimati dalla rete neurale online. In altre parole, l’agente seleziona l’azione corrispondente al valore Q predetto massimo per uno stato dato, con una probabilità prefissata di agire casualmente.

Potreste ricordare che il Q-learning aggiorna la sua tabella Q dopo ogni passo, tuttavia, nell’Apprendimento Profondo è pratica comune calcolare gli aggiornamenti utilizzando la discesa del gradiente su un gruppo di input.

Per questo motivo, l’DQN memorizza esperienze (tuple contenenti stato, azione, ricompensa, stato_successivo, flag_fine) in un buffer di riproduzione. Per addestrare la rete, prenderemo un gruppo di esperienze da questo buffer anziché utilizzare solo l’ultima esperienza (più dettagli nella sezione Replay Buffer).

Rappresentazione visuale del processo di selezione delle azioni di DQN (creato dall'autore)

Ecco un’implementazione JAX della parte di selezione delle azioni di DQN:

L’unicità di questo frammento è che l’attributo model non contiene parametri interni come di solito accade nei framework come PyTorch o TensorFlow.

Qui, il modello è una funzione che rappresenta un passaggio in avanti attraverso la nostra architettura, ma i pesi mutabili sono memorizzati esternamente e passati come argomenti. Questo spiega perché possiamo usare jit mentre passiamo l’argomento self come statico (essendo il modello senza stato come altri attributi della classe).

Aggiornamento

La funzione update è responsabile dell’addestramento della rete. Calcola una perdita di errore quadratico medio (MSE) basata sull’errore di differenza temporale (TD):

Errore quadratico medio usato in DQN

In questa funzione di perdita, θ rappresenta i parametri della rete online, mentre θ− rappresenta i parametri della rete target. I parametri della rete target vengono impostati sui parametri della rete online ogni N passi, simile a un checkpoint (N è un iperparametro).

Questa separazione dei parametri (con θ per i valori Q attuali e θ− per i valori Q target) è fondamentale per stabilizzare l’addestramento.

Utilizzare gli stessi parametri per entrambi sarebbe simile ad avere come obiettivo un bersaglio in movimento, poiché gli aggiornamenti alla rete sposterebbero immediatamente i valori target. Aggiornando periodicamente θ− (ovvero congelando questi parametri per un numero di passi prefissato), garantiamo Q-target stabili mentre la rete online continua a imparare.

Infine, il termine (1-done) aggiusta il target per gli stati terminali. In effetti, quando un episodio termina (cioè ‘done’ è uguale a 1), non c’è stato successivo. Pertanto, il valore Q per lo stato successivo viene impostato a 0.

Rappresentazione visiva del processo di aggiornamento dei parametri di DQN (realizzata dall'autore)

L’implementazione della funzione di aggiornamento per DQN è leggermente più complessa, analizziamola:

  • Prima di tutto, la funzione _loss_fn implementa l’errore quadratico descritto in precedenza per una singola esperienza.
  • Successivamente, _batch_loss_fn funge da wrapper per _loss_fn e lo decora con vmap, applicando la funzione di perdita a un batch di esperienze. Restituiamo quindi l’errore medio per questo batch.
  • Infine, update funge da ultimo strato per la nostra funzione di perdita, calcolando il suo gradiente rispetto ai parametri della rete online, ai parametri della rete target e a un batch di esperienze. Utilizziamo poi Optax (una libreria JAX comunemente utilizzata per l’ottimizzazione) per eseguire un passo dell’ottimizzatore e aggiornare i parametri della rete online.

Si noti che, analogamente al buffer di riproduzione, il modello e l’ottimizzatore sono funzioni pure che modificano uno stato esterno. La seguente riga serve come buona illustrazione di questo principio:

updates, optimizer_state = optimizer.update(grads, optimizer_state)

Questo spiega anche perché possiamo utilizzare un singolo modello sia per le reti online che per quelle target, poiché i parametri vengono memorizzati e aggiornati esternamente.

# target network predictionsself.model.apply(target_net_params, None, state)# online network predictionsself.model.apply(online_net_params, None, state)

Per contesto, il modello che utilizziamo in questo articolo è un perceptron multi-strato definito come segue:

N_ACTIONS = 2NEURONS_PER_LAYER = [64, 64, 64, N_ACTIONS]online_key, target_key = vmap(random.PRNGKey)(jnp.arange(2) + RANDOM_SEED)@hk.transformdef model(x):    # simple multi-layer perceptron    mlp = hk.nets.MLP(output_sizes=NEURONS_PER_LAYER)    return mlp(x)online_net_params = model.init(online_key, jnp.zeros((STATE_SHAPE,)))target_net_params = model.init(target_key, jnp.zeros((STATE_SHAPE,)))prediction = model.apply(online_net_params, None, state)

Buffer di riproduzione

Ora diamo un passo indietro e osserviamo da vicino il buffer di riproduzione. Sono ampiamente utilizzati nell’apprendimento per rinforzo per una varietà di motivi:

  • Generalizzazione: Campionando dal buffer di riproduzione, rompiamo la correlazione tra le esperienze consecutive mescolando il loro ordine. In questo modo, evitiamo l’overfitting a sequenze specifiche di esperienze.
  • Diversità: Poiché il campionamento non è limitato alle esperienze recenti, osserviamo in generale una varianza inferiore negli aggiornamenti e preveniamo l’overfitting alle esperienze più recenti.
  • Aumento dell’efficienza del campione: Ogni esperienza può essere campionata più volte dal buffer, consentendo al modello di apprendere di più dalle singole esperienze.

Infine, possiamo utilizzare diversi schemi di campionamento per il nostro buffer di riproduzione:

  • Campionamento uniforme: Le esperienze vengono campionate in modo uniforme in modo casuale. Questo tipo di campionamento è semplice da implementare e consente al modello di apprendere dalle esperienze in modo indipendente dal momento in cui sono state raccolte.
  • Campionamento prioritario: Questa categoria include diversi algoritmi come il “Replay di esperienze prioritizzato” (“PER”, Schaul et al., 2015) o il “Replay di esperienze del gradiente” (“GER”, Lahire et al., 2022). Questi metodi cercano di dare priorità alla selezione di esperienze in base a una metrica legata al loro “potenziale di apprendimento” (l’ampiezza dell’errore TD per PER e la norma del gradiente dell’esperienza per GER).

Per semplicità, implementeremo un buffer di riproduzione uniforme in questo articolo. Tuttavia, ho intenzione di approfondire ampiamente il campionamento prioritario in futuro.

Come promesso, il buffer di riproduzione uniforme è abbastanza facile da implementare, tuttavia vi sono alcune complessità legate all’uso di JAX e alla programmazione funzionale. Come sempre, dobbiamo lavorare con funzioni pure che sono senza effetti collaterali. In altre parole, non ci è consentito definire il buffer come un’istanza di classe con uno stato interno variabile.

Al contrario, iniziamo con un dizionario buffer_state che mappa chiavi ad array vuoti con forme predefinite, poiché JAX richiede array di dimensione costante quando compila il codice in XLA.

buffer_state = {    "states": jnp.empty((BUFFER_SIZE, STATE_SHAPE), dtype=jnp.float32),    "actions": jnp.empty((BUFFER_SIZE,), dtype=jnp.int32),    "rewards": jnp.empty((BUFFER_SIZE,), dtype=jnp.int32),    "next_states": jnp.empty((BUFFER_SIZE, STATE_SHAPE), dtype=jnp.float32),    "dones": jnp.empty((BUFFER_SIZE,), dtype=jnp.bool_),}

Utilizzeremo una classe UniformReplayBuffer per interagire con lo stato del buffer. Questa classe ha due metodi:

  • add: Scompone una tupla di esperienze e mappa i suoi componenti in un indice specifico. idx = idx % self.buffer_size garantisce che quando il buffer è pieno, l’aggiunta di nuove esperienze sovrascriva quelle più vecchie.
  • sample: Esegue il campionamento di una sequenza di indici casuali dalla distribuzione casuale uniforme. La lunghezza della sequenza è definita da batch_size, mentre l’intervallo degli indici è [0, current_buffer_size-1]. Questo garantisce che non campioniamo array vuoti quando il buffer non è ancora pieno. Infine, utilizziamo il vmap di JAX in combinazione con tree_map per restituire un batch di esperienze.

Traduzione dell’ambiente di CartPole in JAX

Ora che il nostro agente DQN è pronto per l’addestramento, implementeremo rapidamente un ambiente di CartPole vettorizzato utilizzando lo stesso framework presentato in un articolo precedente. CartPole è un ambiente di controllo con uno spazio osservativo continuo ampio, il che lo rende rilevante per testare il nostro DQN.

Rappresentazione visiva dell'ambiente di CartPole (crediti e documentazione: OpenAI Gymnasium, licenza MIT)

Il processo è abbastanza semplice, riutilizziamo la maggior parte dell’implementazione di Gymnasium di OpenAI facendo attenzione ad utilizzare le matrici JAX e il flusso di controllo di lax invece delle alternative Python o Numpy, ad esempio:

# Implementazione Python
force = self.force_mag if action == 1 else -self.force_mag
# Implementazione Jax
force = lax.select(jnp.all(action) == 1, self.force_mag, -self.force_mag)

# Implementazione Python
costheta, sintheta = math.cos(theta), math.sin(theta)
# Implementazione Jax
cos_theta, sin_theta = jnp.cos(theta), jnp.sin(theta)

# Implementazione Python
if not terminated:
    reward = 1.0
else:
    reward = 0.0
# Implementazione Jax
reward = jnp.float32(jnp.invert(done))

Per motivi di brevità, il codice completo dell’ambiente è disponibile qui:

jym/src/envs/control/cartpole.py su main · RPegoud/jym

Implementazione JAX degli algoritmi di RL e degli ambienti vettorializzati – jym/src/envs/control/cartpole.py su main ·…

github.com

Il modo JAX per scrivere loop di addestramento efficienti

La parte finale della nostra implementazione di DQN è il loop di addestramento (detto anche rollout). Come menzionato in precedenti articoli, dobbiamo rispettare un formato specifico per sfruttare la velocità di JAX.

La funzione di rollout potrebbe sembrare intimidatoria all’inizio, ma la maggior parte della sua complessità è puramente sintattica poiché abbiamo già coperto la maggior parte dei blocchi di base. Ecco una panoramica in pseudo-codice:

1. Inizializzazione:
   * Creare array vuoti che conterranno gli stati, le azioni, le ricompense e i flag di done per ogni istante di tempo. Inizializzare le reti e l'ottimizzatore con array vuoti.
   * Incapsulare tutti gli oggetti inizializzati in una tupla val
2. Loop di addestramento (ripetere per i passi i):
   * Scompattare la tupla val
   * (Opzionale) Decadimento di epsilon utilizzando una funzione di decadimento
   * Eseguire un'azione in base allo stato e ai parametri del modello
   * Eseguire un passo dell'ambiente e osservare lo stato successivo, la ricompensa e il flag di done
   * Creare una tupla di esperienza (stato, azione, ricompensa, nuovo_stato, done) e aggiungerla al buffer di replay
   * Estrarre un batch di esperienze in base alla dimensione del buffer corrente (cioè estrarre solo da esperienze che hanno valori diversi da zero)
   * Aggiornare i parametri del modello utilizzando il batch di esperienze
   * Ogni N passi, aggiornare i pesi della rete target (impostare target_params = online_params)
   * Archiviare i valori delle esperienze per l'episodio corrente e restituire la tupla `val` aggiornata

Ora possiamo eseguire DQN per 20.000 passi e osservare le performance. Dopo circa 45 episodi, l’agente riesce ad ottenere performance decenti, bilanciando il paletto per oltre 100 passi in modo coerente.

Le barre verdi indicano che l’agente è riuscito a bilanciare il paletto per più di 200 passi, risolvendo l’ambiente. In particolare, l’agente ha stabilito il suo record nell’episodio 51, con 393 passi.

Rapporto sulle performance per DQN (realizzato dall’autore)

I 20.000 passi di addestramento sono stati eseguiti in poco più di un secondo, a un tasso di 15.807 passi al secondo (su una singola CPU)!

Queste performance suggeriscono le impressionanti capacità di scalabilità di JAX, che consentono ai professionisti di eseguire esperimenti parallelizzati su larga scala con requisiti hardware minimi.

Running for 20,000 iterations: 100%|██████████| 20000/20000 [00:01<00:00, 15807.81it/s]

Esamineremo più da vicino le procedure di rollout parallelizzate per eseguire esperimenti statisticamente significativi e ricerche di iperparametri in un futuro articolo!

Nel frattempo, sentitevi liberi di riprodurre l’esperimento e sperimentare con gli iperparametri utilizzando questo notebook:

jym/notebooks/control/cartpole/dqn_cartpole.ipynb su main · RPegoud/jym

Implementazione JAX degli algoritmi di RL e degli ambienti vettorializzati – jym/notebooks/control/cartpole/dqn_cartpole.ipynb su…

github.com

Conclusione

Come sempre, grazie per aver letto fino a qui! Spero che questo articolo abbia fornito una decente introduzione al Deep RL in JAX. Se hai domande o feedback relativi al contenuto di questo articolo, assicurati di farmelo sapere, sono sempre felice di fare una piccola chiacchierata 😉

Fino alla prossima volta 👋

Credits: