Vectorizzare e Parallelizzare gli Ambienti di RL con JAX Q-learning alla Velocità della Luce⚡

Vectorizzare e Parallelizzare gli Ambienti di RL con JAX Q-learning alla Velocità della Luce⚡

In questo articolo impariamo a vettorizzare un ambiente RL e addestrare 30 agenti di apprendimento Q in parallelo su una CPU, a 1,8 milioni di iterazioni al secondo.

Immagine di Google DeepMind su Unsplash

Nella storia precedente, abbiamo introdotto il Temporal-Difference Learning, in particolare il Q-learning, nel contesto di un GridWorld.

Temporal-Difference Learning e l’importanza dell’esplorazione: Una guida illustrata

Confronto tra metodi TD senza modello (Q-learning) e basati su modello (Dyna-Q e Dyna-Q+) su un mondo a griglia dinamica.

towardsdatascience.com

Sebbene questa implementazione servisse a dimostrare le differenze nelle prestazioni e nei meccanismi di esplorazione di questi algoritmi, era estremamente lenta.

Infatti, l’ambiente e gli agenti erano principalmente codificati in Numpy, che non è affatto uno standard in RL, anche se rende il codice facile da comprendere e debuggare.

In questo articolo, vedremo come scalare gli esperimenti RL attraverso la vettorizzazione degli ambienti e la parallelizzazione senza soluzione di continuità dell’addestramento di dozzine di agenti utilizzando JAX. In particolare, questo articolo copre:

  • Fondamenti di JAX e funzionalità utili per RL
  • Ambienti vettorizzati e perché sono così veloci
  • Implementazione di un ambiente, una politica e un agente di Q-learning in JAX
  • Addestramento di un singolo agente
  • Come parallelizzare l’addestramento degli agenti e quanto sia facile!

Tutto il codice presente in questo articolo è disponibile su GitHub:

GitHub – RPegoud/jax_rl: Implementazione di algoritmi RL e ambienti vettorizzati in JAX

Implementazione di algoritmi RL e ambienti vettorizzati in JAX – GitHub – RPegoud/jax_rl: Implementazione di RL…

github.com

Principi fondamentali di JAX

JAX è un altro framework di Deep Learning Python sviluppato da Google e ampiamente utilizzato da aziende come DeepMind.

“JAX è Autograd (differenziazione automatica) e XLA (Accelerated Linear Algebra, un compilatore TensorFlow), uniti per il calcolo numerico ad alte prestazioni.” — Documentazione ufficiale

A differenza di quanto a cui sono abituati la maggior parte degli sviluppatori Python, JAX non adotta il paradigma della programmazione orientata agli oggetti (OOP), ma piuttosto la programmazione funzionale (FP)[1].

In parole semplici, si basa su funzioni pure (deterministiche e senza effetti collaterali) e strutture dati immutabili (invece di modificare i dati direttamente, vengono create nuove strutture dati con le modifiche desiderate) come blocchi di costruzione primari. Di conseguenza, la FP incoraggia un approccio più funzionale e matematico alla programmazione, rendendola adatta a compiti come il calcolo numerico e l’apprendimento automatico.

Illustreremo le differenze tra questi due paradigmi analizzando il pseudocodice di una funzione di aggiornamento Q:

  • Il approccio orientato agli oggetti si basa su un’istanza di classe che contiene diverse variabili di stato (come i valori Q). La funzione di aggiornamento è definita come un metodo di classe che aggiorna lo stato interno dell’istanza.
  • Il approccio della programmazione funzionale si basa su una funzione pura. Infatti, questo aggiornamento Q è deterministico poiché i valori Q vengono passati come argomento. Pertanto, qualsiasi chiamata a questa funzione con gli stessi input produrrà gli stessi output, mentre gli output di un metodo di classe possono dipendere dallo stato interno dell’istanza. Inoltre, le strutture dati, come gli array, sono definite e modificate nello scope globale.
Implementazione di un aggiornamento Q in programmazione orientata agli oggetti e programmazione funzionale (realizzata dall'autore)

Pertanto, JAX offre una varietà di decoratori di funzioni che sono particolarmente utili nel contesto dell’apprendimento automatico:

  • vmap (mappa vettorializzata): Permette di applicare una funzione che agisce su un singolo campione a un lotto di campioni. Ad esempio, se env.step() è una funzione che esegue un passo in un singolo ambiente, vmap(env.step)() è una funzione che esegue un passo in più ambienti. In altre parole, vmap aggiunge una dimensione di batch a una funzione.
Illustrazione di una funzione step vettorializzata usando vmap (realizzata dall'autore)
  • jit (compilazione “Just In Time”): Consente a JAX di eseguire una “compilazione Just In Time di una funzione Python JAX”, rendendola compatibile con XLA. In sostanza, utilizzando jit possiamo compilare le funzioni e ottenere significativi miglioramenti di velocità (a discapito di un po’ di overhead aggiuntivo durante la compilazione iniziale della funzione).
  • pmap (mappa parallela): Similmente a vmap, pmap consente una facile parallelizzazione. Tuttavia, invece di aggiungere una dimensione di batch a una funzione, replica la funzione e la esegue su diversi dispositivi XLA. Nota: quando si applica pmap, jit viene applicato automaticamente.
Illustrazione di una funzione step parallelizzata usando pmap (realizzata dall'autore)

Ora che abbiamo delineato le basi di JAX, vedremo come ottenere un miglioramento significativo delle prestazioni vettorizzando gli ambienti.

Ambienti Vettorizzati:

Prima di tutto, cos’è un ambiente vettorizzato e quali problemi risolve la vettorizzazione?

Nella maggior parte dei casi, gli esperimenti di apprendimento automatico vengono ritardati dai trasferimenti di dati tra CPU e GPU. Gli algoritmi di apprendimento automatico basati sul deep learning, come l’ottimizzazione della politica proximal (PPO), utilizzano le reti neurali per approssimare la politica.

Come sempre nel Deep Learning, le Reti Neurali utilizzano GPU durante l’addestramento e l’inference. Tuttavia, nella maggior parte dei casi, gli ambienti vengono eseguiti sulla CPU (anche nel caso di utilizzo parallelo di più ambienti).

Ciò significa che la solita iterazione RL di selezione delle azioni tramite la policy (Reti Neurali) e di ricezione di osservazioni e ricompense dall’ambiente richiede uno scambio costante tra GPU e CPU, il che danneggia le prestazioni.

Inoltre, l’utilizzo di framework come PyTorch senza “jitting” potrebbe causare qualche overhead, poiché la GPU potrebbe dover attendere che Python invii nuove osservazioni e ricompense dalla CPU.

Solita configurazione di addestramento RL batched in PyTorch (realizzata dall'autore)

D’altra parte, JAX ci consente di eseguire facilmente ambienti batched sulla GPU, eliminando l’attrito causato dal trasferimento di dati tra GPU e CPU.

Inoltre, grazie al jit, il nostro codice JAX viene compilato in XLA, il che rende l’esecuzione meno (o almeno meno) influenzata dall’inefficienza di Python.

Configurazione di addestramento RL batched in JAX (realizzata dall'autore)

Per ulteriori dettagli e applicazioni entusiasmanti nella ricerca di meta-apprendimento RL, consiglio vivamente questo articolo del blog di Chris Lu.

Implementazione dell’ambiente, dell’agente e della policy:

Diamo un’occhiata all’implementazione delle diverse parti del nostro esperimento RL. Ecco una panoramica ad alto livello delle funzioni di base di cui avremo bisogno:

Metodi di classe richiesti per una semplice configurazione RL (realizzata dall'autore)

L’ambiente

Questa implementazione segue lo schema fornito da Nikolaj Goodger nel suo ottimo articolo sulla scrittura di ambienti in JAX.

Scrittura di un ambiente RL in JAX

Come eseguire CartPole a 1,25 miliardi di passi al secondo

VoAGI.com

Iniziamo con una vista ad alto livello dell’ambiente e dei suoi metodi. Questo è un piano generale per l’implementazione di un ambiente in JAX:

Esaminiamo più da vicino i metodi di classe (come promemoria, le funzioni che iniziano con il simbolo “_” sono private e non devono essere chiamate al di fuori del contesto della classe):

  • _get_obs: Questo metodo converte lo stato dell’ambiente in un’osservazione per l’agente. In un ambiente parzialmente osservabile o stocastico, qui verrebbero applicate le funzioni di elaborazione dello stato.
  • _reset: Poiché eseguiremo più agenti in parallelo, abbiamo bisogno di un metodo per ripristini individuali al termine di un episodio.
  • _reset_if_done: Questo metodo verrà chiamato ad ogni passo e attiverà _reset se il flag “done” viene impostato su True.
  • reset: Questo metodo viene chiamato all’inizio dell’esperimento per ottenere lo stato iniziale di ogni agente, nonché le chiavi casuali associate.
  • step: Dato uno stato e un’azione, l’ambiente restituisce un’osservazione (nuovo stato), una ricompensa e il flag “done” aggiornato.

Nella pratica, una implementazione generica di un ambiente GridWorld sarebbe così:

Si noti che, come già menzionato, tutti i metodi di classe seguono il paradigma della programmazione funzionale. Infatti, non aggiorniamo mai lo stato interno dell’istanza della classe. Inoltre, gli attributi di classe sono tutti costanti che non verranno modificati dopo l’istanziazione.

Analizziamo più da vicino:

  • __init__: Nel contesto del nostro GridWorld, le azioni disponibili sono [0, 1, 2, 3]. Queste azioni vengono tradotte in una matrice bidimensionale usando self.movements e aggiunte allo stato nella funzione step.
  • _get_obs: Il nostro ambiente è deterministico e completamente osservabile, pertanto l’agente riceve direttamente lo stato anziché un’osservazione elaborata.
  • _reset_if_done: L’argomento env_state corrisponde alla tupla (state, key) in cui key è un jax.random.PRNGKey. Questa funzione restituisce semplicemente lo stato iniziale se il flag done è impostato su True, tuttavia non possiamo utilizzare il normale flusso di controllo di Python all’interno delle funzioni jitted di JAX. Utilizzando jax.lax.cond otteniamo essenzialmente un’espressione equivalente a:
def cond(condition, true_fun, false_fun, operand):  if condition: # se done flag == True    return true_fun(operand)  # restituisci self._reset(key)  else:    return false_fun(operand) # restituisci env_state
  • step: Convertiamo l’azione in un movimento e lo aggiungiamo allo stato corrente (jax.numpy.clip garantisce che l’agente rimanga all’interno della griglia). Aggiorniamo quindi la tupla env_state prima di verificare se l’ambiente deve essere ripristinato. Poiché la funzione step viene utilizzata frequentemente durante l’addestramento, il jitting consente notevoli miglioramenti delle prestazioni. Il decoratore @partial(jit, static_argnums=(0, ) segnala che l’argomento “self” del metodo di classe dovrebbe essere considerato statico. In altre parole, le proprietà della classe sono costanti e non cambieranno durante le chiamate successive alla funzione step.

Agente di Q-Learning

L’agente di Q-Learning è definito dalla funzione update, nonché da un tasso di apprendimento e un fattore di sconto statici.

Di nuovo, quando jittiamo la funzione di aggiornamento, passiamo l’argomento “self” come statico. Inoltre, si noti che la matrice q_values viene modificata direttamente usando set() e il suo valore non viene memorizzato come attributo di classe.

Politica Epsilon-Greedy

Infine, la politica utilizzata in questo esperimento è la standard politica epsilon-greedy. Un dettaglio importante è che utilizza scontri casuali, il che significa che se il valore Q massimale non è unico, l’azione verrà campionata in modo uniforme dai valori Q massimali (usando argmax restituirebbe sempre la prima azione con valore Q massimale). Questo è particolarmente importante se i valori Q vengono inizializzati come una matrice di zeri, poiché l’azione 0 (spostamento a destra) sarebbe sempre selezionata.

Altrimenti, la politica può essere riassunta da questo frammento di codice:

action = lax.cond(            explore, # se p < epsilon            _random_action_fn, # seleziona un'azione casuale dato il key            _greedy_action_fn, # seleziona l'azione avarata rispetto ai valori Q            operand=subkey, # utilizza subkey come argomento per le funzioni precedenti        )return action, subkey

Si noti che quando si utilizza una chiave in JAX (ad esempio qui abbiamo selezionato un numero casuale e utilizzato random.choice) è pratica comune suddividere la chiave successivamente (ovvero “passare a uno stato casuale successivo”, ulteriori dettagli qui).

Ciclo di addestramento di un singolo agente:

Ora che abbiamo tutti i componenti necessari, addestriamo un singolo agente.

Ecco un ciclo di addestramento Pythonico, come si può vedere stiamo essenzialmente selezionando un’azione utilizzando la politica, eseguendo un passo nell’ambiente e aggiornando i valori Q, fino alla fine di un episodio. Quindi ripetiamo il processo per N episodi. Come vedremo tra un minuto, questo modo di addestrare un agente è piuttosto inefficiente, tuttavia riassume in modo chiaro i passaggi fondamentali dell’algoritmo:

Su una singola CPU, completiamo 10.000 episodi in 11 secondi, ad una velocità di 881 episodi e 21.680 passi al secondo.

100%|██████████| 10000/10000 [00:11<00:00, 881.86it/s]Numero totale di passi: 238.488 Numero di passi al secondo: 21.680

Ora, ripetiamo lo stesso ciclo di addestramento utilizzando la sintassi di JAX. Ecco una descrizione generale della funzione rollout:

Training rollout function using JAX syntax (made by the author)

Per riassumere, la funzione rollout:

  1. Inizializza le variabili osservazioni, ricompense e done come array vuoti con una dimensione pari al numero di passaggi temporali utilizzando jax.numpy.zeros. I Q-values sono inizializzati come una matrice vuota con forma [timesteps+1, grid_dimension_x, grid_dimension_y, n_actions].
  2. Chiama la funzione env.reset() per ottenere lo stato iniziale
  3. Utilizza la funzione jax.lax.fori_loop() per chiamare N volte una funzione fori_body(), dove N è il parametro timestep
  4. La funzione fori_body() si comporta in modo simile al ciclo Python precedente. Dopo aver selezionato un’azione, eseguito un passaggio e calcolato l’aggiornamento Q, aggiorniamo gli array obs, rewards, done e q_values nello stesso posto (l’aggiornamento Q riguarda il passaggio temporale t+1).

Questa complessità aggiuntiva porta ad un aumento di velocità di 85x, ora addestriamo il nostro agente a circa 1,83 milioni di passi al secondo. Si noti che qui l’addestramento viene eseguito su una singola CPU poiché l’ambiente è semplice.

Tuttavia, la vectorizzazione end-to-end scala ancora meglio quando viene applicata a ambienti complessi e algoritmi che beneficiano di più GPU (l’articolo di Chris Lu riporta un aumento di velocità enorme di 4000x tra un’implementazione di PPO CleanRL PyTorch e una riproduzione JAX).

100%|██████████| 1000000/1000000 [00:00<00:00, 1837563.94it/s]Numero totale di passi: 1 000 000 Numero di passi al secondo: 1 837 563

Dopo aver addestrato il nostro agente, tracciamo il valore Q massimale per ogni cella (ossia stato) del GridWorld e osserviamo che ha effettivamente imparato ad andare dallo stato iniziale (angolo in basso a destra) all’obiettivo (angolo in alto a sinistra).

Rappresentazione heatmap del valore Q massimale per ogni cella del GridWorld (creato dall’autore)

Ciclo di addestramento di agenti paralleli:

Come promesso, ora che abbiamo scritto le funzioni necessarie per addestrare un agente singolo, ci resta poco o niente da fare per addestrare agenti multipli in parallelo su ambienti raggruppati!

Grazie a vmap possiamo trasformare rapidamente le nostre funzioni precedenti per lavorare su gruppi di dati. Dobbiamo solo specificare le forme di input e output attese, ad esempio per env.step:

  • in_axes = ((0,0), 0) rappresenta la forma di input, composta dalla tupla env_state (dimensione (0, 0)) e un’osservazione (dimensione 0).
  • out_axes = ((0, 0), 0, 0, 0) rappresenta la forma di output, con output ((env_state), obs, ricompensa, done).
  • Ora, possiamo chiamare v_step su un array di env_state e azioni e ricevere un array di env_state, osservazioni, ricompense e flag done processati.
  • Si noti che anche jit tutte le funzioni raggruppate per prestazioni (discutibilmente, jitting env.reset() è superfluo dato che viene chiamato solo una volta nella nostra funzione di addestramento).

L’ultimo aggiustamento che dobbiamo fare è aggiungere una dimensione di batch ai nostri array per tener conto dei dati di ciascun agente.

Facendo ciò, otteniamo una funzione che ci permette di allenare diversi agenti in parallelo, con minime modifiche rispetto alla funzione per un singolo agente:

Otteniamo prestazioni simili con questa versione della nostra funzione di allenamento:

100%|██████████| 100000/100000 [00:02<00:00, 49036.11it/s]Numero totale di passi: 100 000 * 30 = 3 000 000Numero di passi al secondo: 49 036 * 30 = 1 471 080

E questo è tutto! Grazie per aver letto fino a qui, spero che questo articolo abbia fornito un’introduzione utile per implementare ambienti vettorializzati in JAX.

Se ti è piaciuta la lettura, considera di condividere questo articolo e di segnalare con una stella il mio repository GitHub, grazie per il tuo supporto! 🙏

GitHub – RPegoud/jax_rl: Implementazione di JAX di algoritmi di RL e ambienti vettorializzati

Implementazione di JAX di algoritmi di RL e ambienti vettorializzati – GitHub – RPegoud/jax_rl: Implementazione di JAX di RL…

github.com

Infine, per coloro interessati ad approfondire un po’ di più, ecco una lista di risorse utili che mi hanno aiutato ad iniziare con JAX e a redigere questo articolo:

Una lista selezionata di fantastici articoli e risorse su JAX:

[1] Coderized, (programmazione funzionale) Lo stile di codifica più puro, in cui i bug sono quasi impossibili, YouTube

[2] Aleksa Gordić, Playlist YouTube JAX Da Zero a Eroe (2022), The AI Epiphany

[3] Nikolaj Goodger, Scrivere un Ambiente RL in JAX (2021)

[4] Chris Lu, Raggiungere Accelerazioni di 4000x e Scoperte Meta-Evolutive con PureJaxRL (2023), Università di Oxford, Foerster Lab for AI Research

[5] Nicholas Vadivelu, Awesome-JAX (2020), una lista di librerie, progetti e risorse JAX

[6] Documentazione Ufficiale di JAX, Allenamento di una Semplice Rete Neurale, con Caricamento Dati PyTorch