Utilizzando JAX per accelerare la nostra ricerca
'Usando JAX per accelerare la ricerca'
Gli ingegneri di DeepMind accelerano la nostra ricerca costruendo strumenti, scalando algoritmi e creando mondi virtuali e fisici impegnativi per l’addestramento e il test dei sistemi di intelligenza artificiale (AI). Come parte di questo lavoro, valutiamo costantemente nuove librerie e framework di apprendimento automatico.
Recentemente, abbiamo scoperto che un numero crescente di progetti viene ben servito da JAX, un framework di apprendimento automatico sviluppato dai team di ricerca di Google. JAX risuona bene con la nostra filosofia ingegneristica ed è stato ampiamente adottato dalla nostra comunità di ricerca nell’ultimo anno. Qui condividiamo la nostra esperienza di lavoro con JAX, illustrando perché lo troviamo utile per la nostra ricerca di intelligenza artificiale e fornendo una panoramica dell’ecosistema che stiamo creando per supportare i ricercatori ovunque.
Perché JAX?
JAX è una libreria Python progettata per il calcolo numerico ad alte prestazioni, in particolare per la ricerca sull’apprendimento automatico. La sua API per le funzioni numeriche si basa su NumPy, una collezione di funzioni utilizzate nel calcolo scientifico. Sia Python che NumPy sono ampiamente utilizzati e familiari, rendendo JAX semplice, flessibile e facile da adottare.
- Imitando l’intelligenza interattiva
- MuZero Padronanza di Go, scacchi, shogi e Atari senza regole
- Dati, Architettura o Perdite Cosa Contribuisce di Più al Successo del Transformer Multimodale?
Oltre alla sua API NumPy, JAX include un sistema estensibile di trasformazioni di funzioni componibili che aiutano la ricerca sull’apprendimento automatico, tra cui:
- Differenziazione: L’ottimizzazione basata sul gradiente è fondamentale per l’IA. JAX supporta nativamente sia la differenziazione automatica in modalità diretta che inversa di funzioni numeriche arbitrarie, tramite trasformazioni di funzioni come grad, hessian, jacfwd e jacrev.
- Vectorizzazione: Nella ricerca sull’IA spesso applichiamo una singola funzione a molti dati, ad esempio calcolando la perdita su un batch o valutando i gradienti per l’apprendimento privato differenziato. JAX fornisce una vectorizzazione automatica tramite la trasformazione vmap che semplifica questa forma di programmazione. Ad esempio, i ricercatori non devono ragionare sul batching durante l’implementazione di nuovi algoritmi. JAX supporta anche il data parallelism su larga scala tramite la trasformazione correlata pmap, distribuendo in modo elegante i dati che sono troppo grandi per la memoria di un singolo acceleratore.
- JIT-compilazione: XLA viene utilizzato per compilare e eseguire programmi JAX in tempo reale (JIT) su acceleratori GPU e Cloud TPU. La JIT-compilazione, insieme all’API consistente con NumPy di JAX, consente ai ricercatori senza precedenti esperienze nell’informatica ad alte prestazioni di scalare facilmente su uno o più acceleratori.
Abbiamo scoperto che JAX ha permesso una sperimentazione rapida con algoritmi e architetture innovative e ora è alla base di molte delle nostre pubblicazioni recenti. Per saperne di più, si prega di considerare di partecipare al nostro Roundtable su JAX, mercoledì 9 dicembre alle 19:00 GMT, alla conferenza virtuale NeurIPS.
JAX presso DeepMind
Supportare la ricerca di AI all’avanguardia significa bilanciare la prototipazione rapida e la rapida iterazione con la capacità di distribuire esperimenti a una scala tradizionalmente associata ai sistemi di produzione. Ciò che rende questi tipi di progetti particolarmente sfidanti è che il panorama della ricerca evolve rapidamente ed è difficile da prevedere. In qualsiasi momento, una nuova scoperta di ricerca può, e regolarmente lo fa, cambiare la traiettoria e le esigenze di interi team. In questo panorama in continua evoluzione, una responsabilità fondamentale del nostro team di ingegneria è assicurarsi che le lezioni apprese e il codice sviluppato per un progetto di ricerca vengano riutilizzati in modo efficace nel successivo.
Un approccio che si è dimostrato efficace è la modularizzazione: estraiamo i blocchi di costruzione più importanti e critici sviluppati in ciascun progetto di ricerca in componenti ben testati ed efficienti. Questo consente ai ricercatori di concentrarsi sulla loro ricerca beneficiando anche del riutilizzo del codice, delle correzioni degli errori e dei miglioramenti delle prestazioni negli ingredienti algoritmici implementati dalle nostre librerie principali. Abbiamo anche scoperto che è importante assicurarsi che ogni libreria abbia uno scopo chiaramente definito e garantire che siano interoperabili ma indipendenti. L’acquisizione incrementale, la possibilità di scegliere le funzionalità senza essere vincolati ad altre, è fondamentale per fornire la massima flessibilità ai ricercatori e supportarli sempre nella scelta dello strumento giusto per il lavoro.
Altre considerazioni che sono state prese in considerazione nello sviluppo del nostro ecosistema JAX includono assicurarsi che rimanga coerente (quando possibile) con la progettazione delle nostre librerie TensorFlow esistenti (ad esempio Sonnet e TRFL). Ci siamo anche proposti di costruire componenti che (quando rilevante) corrispondano il più possibile alla matematica sottostante, per essere autoesplicativi e ridurre al minimo i passaggi mentali “dalla carta al codice”. Infine, abbiamo scelto di rendere open source le nostre librerie per facilitare la condivisione dei risultati della ricerca e per incoraggiare la comunità più ampia a esplorare l’ecosistema JAX.
Il nostro Ecosistema oggi
Haiku
Il modello di programmazione JAX di trasformazioni di funzioni componibili può rendere complicato gestire oggetti con stato, ad esempio reti neurali con parametri addestrabili. Haiku è una libreria per reti neurali che consente agli utenti di utilizzare modelli di programmazione orientati agli oggetti familiari sfruttando al contempo la potenza e la semplicità del paradigma funzionale puro di JAX.
Haiku è attualmente utilizzato da centinaia di ricercatori di DeepMind e Google, ed è già stato adottato in diversi progetti esterni (ad esempio Coax, DeepChem, NumPyro). Si basa sull’API per Sonnet, il nostro modello di programmazione basato su moduli per reti neurali in TensorFlow, e abbiamo cercato di rendere il più semplice possibile il passaggio da Sonnet a Haiku.
Scopri di più su GitHub
Optax
L’ottimizzazione basata sui gradienti è fondamentale per l’apprendimento automatico. Optax fornisce una libreria di trasformazioni del gradiente, insieme agli operatori di composizione (ad esempio chain) che consentono di implementare molti ottimizzatori standard (ad esempio RMSProp o Adam) con una sola riga di codice.
La natura compositiva di Optax supporta naturalmente la ricombinazione degli stessi ingredienti di base in ottimizzatori personalizzati. Inoltre, offre una serie di strumenti per la stima del gradiente stocastico e l’ottimizzazione di secondo ordine.
Molti utenti di Optax hanno adottato Haiku ma, in linea con la nostra filosofia di progressivo coinvolgimento, viene supportata qualsiasi libreria che rappresenti i parametri come strutture ad albero JAX (ad esempio Elegy, Flax e Stax). Per ulteriori informazioni su questo ricco ecosistema di librerie JAX, consulta qui.
Scopri di più su GitHub
RLax
Molti dei nostri progetti di maggior successo si trovano all’intersezione dell’apprendimento profondo e dell’apprendimento per rinforzo (RL), noto anche come apprendimento profondo per rinforzo . RLax è una libreria che fornisce blocchi di costruzione utili per la creazione di agenti RL.
I componenti in RLax coprono un’ampia gamma di algoritmi e idee: apprendimento TD, gradienti di politica, attori critici, MAP, ottimizzazione di politiche prossimali, trasformazione non lineare del valore, funzioni di valore generale e diversi metodi di esplorazione.
Anche se vengono forniti alcuni agenti di esempio introduttivi, RLax non è destinato a essere un framework per la creazione e il rilascio di sistemi completi di agenti RL. Un esempio di framework di agenti completamente funzionante che si basa su componenti RLax è Acme .
Scopri di più su GitHub
Chex
Il testing è fondamentale per la affidabilità del software e il codice di ricerca non fa eccezione. Trarre conclusioni scientifiche dagli esperimenti di ricerca richiede la certezza della correttezza del proprio codice. Chex è una raccolta di utility di testing utilizzate dagli autori di librerie per verificare che i blocchi di costruzione comuni siano corretti e robusti e dagli utenti finali per verificare il proprio codice sperimentale.
Chex fornisce una varietà di utility, tra cui testing unitario consapevole di JAX, asserzioni sulle proprietà dei tipi di dati JAX, mock e falsi e ambienti di test multi-dispositivo. Chex viene utilizzato in tutto l’Ecosistema JAX di DeepMind e da progetti esterni come Coax e MineRL .
Scopri di più su GitHub
Jraph
Le reti neurali a grafo (GNN) sono un’area di ricerca affascinante con molte applicazioni promettenti. Vedi, ad esempio, il nostro recente lavoro sulla previsione del traffico in Google Maps e il nostro lavoro sulla simulazione fisica . Jraph (pronunciato “giraffe”) è una libreria leggera per supportare il lavoro con GNN in JAX.
Jraph fornisce una struttura dati standardizzata per i grafi, un insieme di utility per lavorare con i grafi e una “zoo” di modelli di reti neurali grafiche facilmente forkabili ed estensibili. Altre caratteristiche chiave includono: batching di GraphTuples che sfruttano efficientemente gli acceleratori hardware, supporto di compilazione JIT di grafi di forma variabile tramite padding e mascheramento e perdite definite su partizioni di input. Come Optax e le nostre altre librerie, Jraph non impone vincoli sulla scelta della libreria di reti neurali da parte dell’utente.
Scopri di più sull’utilizzo della libreria dalla nostra ricca collezione di esempi.
Scopri di più su GitHub
Il nostro Ecosistema JAX è in costante evoluzione e incoraggiamo la comunità di ricerca di ML ad esplorare le nostre librerie e il potenziale di JAX per accelerare la propria ricerca.