Iniziare con JAX

'JAX inizio'

Alimentare il futuro del calcolo numerico ad alte prestazioni e della ricerca di ML

Foto di Lance Asper su Unsplash

Introduzione

JAX è una libreria Python sviluppata da Google per effettuare calcoli numerici ad alte prestazioni su qualsiasi tipo di dispositivo (CPU, GPU, TPU, ecc…). Una delle principali applicazioni di JAX è lo sviluppo di ricerca di Machine Learning e Deep Learning, anche se la libreria è principalmente progettata per fornire tutte le capacità necessarie per eseguire compiti di calcolo scientifico ad uso generale (operazioni su matrici altamente dimensionali, ecc…).

Considerando il focus specificamente sul calcolo ad alte prestazioni, JAX è stato progettato per essere estremamente veloce essendo basato su XLA (Accelerated Linear Algebra). XLA è infatti un compilatore progettato per velocizzare le operazioni di algebra lineare e può essere utilizzato anche dietro ad altri framework come TensorFlow e Pytorch. Inoltre, gli array di JAX sono stati progettati per seguire gli stessi principi di Numpy, rendendo molto facile migrare vecchio codice Numpy a JAX e sfruttare i miglioramenti delle prestazioni tramite GPU e TPU.

Alcune delle principali caratteristiche di JAX sono:

  • Compilazione Just in Time (JIT) : La compilazione JIT e l’hardware accelerato sono ciò che consente a JAX di essere molto più veloce di Numpy normale. Utilizzando la funzione jit() è possibile compilare e memorizzare nella cache funzioni personalizzate con il kernel XLA. Utilizzando la memorizzazione nella cache aumenteremo il tempo di esecuzione complessivo quando eseguiamo la funzione per la prima volta, per poi ridurre drasticamente il tempo per le esecuzioni successive. Quando si utilizza la memorizzazione nella cache è importante assicurarsi di cancellare le cache quando necessario per evitare risultati obsoleti (ad esempio, variabili globali che cambiano).
  • Parallelizzazione automatica: La Dispatch asincrona consente ai vettori JAX di essere valutati in modo pigro, materializzando il contenuto solo quando viene acceduto (il controllo viene restituito al programma prima del completamento del calcolo). Inoltre, per rendere possibile l’ottimizzazione del grafo, gli array di JAX sono immutabili (concetti simili con lazy evaluation e ottimizzazione del grafo si applicano ad Apache Spark). La funzione pmap() può essere utilizzata per parallelizzare i calcoli su più GPU/TPU.
  • Vectorizzazione automatica : La vectorizzazione automatica per parallelizzare le operazioni può essere eseguita utilizzando la funzione vmap(). Durante la vectorizzazione, un algoritmo viene trasformato da un’elaborazione con un singolo valore a un insieme di valori.
  • Differenziazione automatica : La funzione grad() può essere utilizzata per calcolare automaticamente il gradiente (derivata) delle funzioni. In particolare, la Differenziazione Automatica di JAX consente lo sviluppo di programmi differenziali ad uso generale al di fuori dello spettro del Deep Learning. Rendendo possibile differenziare attraverso ricorsione, rami, cicli, eseguire differenziazione di ordine superiore (ad esempio, Jacobiani ed Hessiani) e utilizzare sia la differenziazione di modo diretto che inverso.

Pertanto, JAX è in grado di fornirci tutte le basi necessarie per costruire modelli avanzati di Deep Learning, ma non fornisce utilità di alto livello pronte all’uso per alcune delle operazioni di Deep Learning più comuni (ad esempio, funzioni di perdita/attivazione, strati, ecc…). Ad esempio, i parametri del modello appresi durante l’addestramento di ML possono essere memorizzati in una struttura Pytree in JAX. Considerando tutti i vantaggi offerti da JAX, sono stati costruiti su di esso diversi framework orientati al DL come Haiku (utilizzato da DeepMind) e Flax (utilizzato da Google Brain).

Dimostrazione

Come parte di questo articolo, vedremo ora come risolvere un semplice problema di classificazione utilizzando JAX e il dataset di Kaggle Mobile Price Classification [1] per predire in quale fascia di prezzo si troverà un telefono. Tutto il codice utilizzato in questo articolo (e altro ancora!) è disponibile sui miei account GitHub e Kaggle .

Prima di tutto, dobbiamo assicurarci di avere JAX installato nel nostro ambiente.

pip install jax

A questo punto, siamo pronti per importare le librerie e i dataset necessari (Figura 1). Per semplificare la nostra analisi, anziché utilizzare tutte le classi nel nostro etichetta, filtriamo i dati per utilizzare solo 2 classi e riduciamo il numero di caratteristiche.

import pandas as pdimport jax.numpy as jnpfrom jax import gradfrom sklearn.preprocessing import StandardScalerfrom sklearn.model_selection import train_test_splitfrom sklearn.metrics import classification_reportimport matplotlib.pyplot as pltdf = pd.read_csv('/kaggle/input/mobile-price-classification/train.csv')df = df.iloc[:, 10:]df = df.loc[df['price_range'] <= 1]df.head()
Figura 1: Dataset di classificazione dei prezzi dei cellulari (Immagine dell'autore).

Una volta pulito il dataset, possiamo ora dividerlo in sottoinsiemi di addestramento e test e standardizzare le caratteristiche di input in modo che siano tutte comprese negli stessi intervalli. A questo punto, i dati di input vengono anche convertiti in array JAX.

X = df.iloc[:, :-1]y = df.iloc[:, -1]X_train, X_test, y_train, y_test = train_test_split(X, y,                                                     test_size=0.20,                                                     stratify=y)X_train, X_test, y_train, Y_test = jnp.array(X_train), jnp.array(X_test), \                                   jnp.array(y_train), jnp.array(y_test)scaler = StandardScaler()scaler.fit(X_train)X_train = scaler.transform(X_train)X_test = scaler.transform(X_test)

Per prevedere l’intervallo di prezzo dei telefoni, creeremo un modello di regressione logistica da zero. Per farlo, dobbiamo prima creare un paio di funzioni di supporto (una per creare la funzione di attivazione Sigmoid e un’altra per la funzione di perdita binaria).

def activation(r):    return 1 / (1 + jnp.exp(-r))def loss(c, w, X, y, lmbd=0.1):    p = activation(jnp.dot(X, w) + c)    loss = jnp.sum(y * jnp.log(p) + (1 - y) * jnp.log(1 - p)) / y.size    reg = 0.5 * lmbd * (jnp.dot(w, w) + c * c)     return - loss + reg 

Siamo ora pronti a creare il nostro ciclo di addestramento e a tracciare i risultati (Figura 2).

n_iter, eta = 100, 1e-1w = 1.0e-5 * jnp.ones(X.shape[1])c = 1.0history = [float(loss(c, w, X_train, y_train))]for i in range(n_iter):    c_current = c    c -= eta * grad(loss, argnums=0)(c_current, w, X_train, y_train)    w -= eta * grad(loss, argnums=1)(c_current, w, X_train, y_train)    history.append(float(loss(c, w, X_train, y_train)))
Figura 2: Storia dell'addestramento della regressione logistica (Immagine dell'autore).

Una volta soddisfatti dei risultati, possiamo quindi testare il modello sul nostro set di test (Figura 3).

y_pred = jnp.array(activation(jnp.dot(X_test, w) + c))y_pred = jnp.where(y_pred > 0.5, 1, 0) print(classification_report(y_test, y_pred))
Figura 3: Rapporto di classificazione sui dati di test (Immagine dell'autore).

Conclusioni

Come dimostrato in questo breve esempio, JAX ha un’API molto intuitiva che segue da vicino le convenzioni di Numpy, consentendo di utilizzare lo stesso codice per l’uso di CPU/GPU/TPU. Utilizzando questi blocchi di costruzione, è possibile creare modelli di Deep Learning altamente personalizzabili ottimizzati in modo nativo per le prestazioni.

Contatti

Se desideri essere aggiornato con i miei ultimi articoli e progetti, seguimi su VoAGI e iscriviti alla mia mailing list. Questi sono alcuni dei miei dettagli di contatto:

  • Linkedin
  • Sito web personale
  • Profilo VoAGI
  • GitHub
  • Kaggle

Bibliografia

[1] “Classificazione del prezzo dei dispositivi mobili” (ABHISHEK SHARMA). Accessibile a: https://thecleverprogrammer.com/2021/03/05/classificazione-del-prezzo-dei-dispositivi-mobili-con-machine-learning/ (Licenza MIT: https://github.com/alifrmf/Analisi-di-classificazione-e-previsione-del-prezzo-dei-dispositivi-mobili)