Storia dell’ottimizzazione Inferenza di Bloom

Storia dell'ottimizzazione di Bloom

Questo articolo ti offre il dietro le quinte su come abbiamo creato un efficiente server di inferenza che alimenta bloom. server di inferenza che alimenta https://huggingface.co/bigscience/bloom .

Siamo riusciti a ridurre il tempo di latenza di 5 volte in diverse settimane (e aumentare la velocità di elaborazione di 50 volte). Abbiamo voluto condividere tutte le difficoltà e le grandi vittorie che abbiamo incontrato per ottenere tali miglioramenti di velocità.

Sono state coinvolte molte persone diverse in molte fasi, quindi non tutto sarà trattato qui. E per favore abbiate pazienza, alcune parti potrebbero essere obsolete o completamente sbagliate perché stiamo ancora imparando come ottimizzare modelli estremamente grandi e molte nuove funzionalità hardware e contenuti continuano ad arrivare regolarmente.

Se la tua ottimizzazione preferita non viene discussa o viene rappresentata in modo improprio, ci scusiamo, ti preghiamo di condividerla con noi, siamo più che felici di provare cose nuove e correggere i nostri errori.

Questo va senza dire, ma senza il modello voluminoso che è accessibile in primo luogo, non ci sarebbero ragioni reali per ottimizzare l’inferenza per esso. Questo è stato un incredibile sforzo guidato da molte persone diverse.

Per massimizzare la GPU durante l’addestramento, sono state esplorate diverse soluzioni e alla fine è stato scelto Megatron-Deepspeed per addestrare il modello finale. Ciò significava che il codice così com’era non era necessariamente compatibile con la libreria transformers.

A causa del codice di addestramento originale, ci siamo proposti di fare qualcosa che facciamo regolarmente: portare un modello esistente in transformers . L’obiettivo era estrarre dal codice di addestramento le parti rilevanti e implementarle all’interno di transformers . Questo sforzo è stato affrontato da Younes . Questo non è certo uno sforzo piccolo, poiché ci sono voluti quasi un mese e 200 commit per arrivarci.

Ci sono diverse cose da notare che torneranno più avanti:

Avevamo bisogno di avere modelli più piccoli bigscience/bigscience-small-testing e bigscience/bloom-560m . Questo è estremamente importante perché sono più piccoli, quindi tutto è più veloce quando si lavora con loro.

Innanzitutto, devi abbandonare ogni speranza di avere gli stessi logits alla fine fino ai byte. Le versioni di PyTorch possono cambiare i kernel e introdurre differenze sottili e hardware diverso potrebbe produrre risultati diversi a causa di architetture diverse (e probabilmente non vuoi sviluppare su una GPU A100 tutto il tempo per motivi di costo).

Avere una buona suite di test rigorosi è davvero importante per tutti i modelli

Il miglior test che abbiamo trovato era avere un insieme fisso di prompt. Conosci il prompt, conosci il completamento che deve essere deterministico, quindi greedy. Se due generazioni sono identiche, puoi praticamente ignorare piccole differenze di logits. Ogni volta che vedi una deriva, devi indagare. Potrebbe essere che il tuo codice non stia facendo quello che dovrebbe O che tu sia effettivamente fuori dal dominio per quel modello e quindi il modello è più sensibile al rumore. Se hai diversi prompt e prompt abbastanza lunghi, è meno probabile che si verifichi per tutti i prompt per caso. Più prompt sono meglio, più lunghi sono meglio.

Il primo modello (small-testing) è in bfloat16 come il grande bloom, quindi tutto dovrebbe essere molto simile, ma non è stato addestrato molto o semplicemente non si comporta bene, quindi oscilla molto nelle uscite. Ciò significa che abbiamo avuto problemi con quei test di generazione. Il secondo modello è più stabile ma è stato addestrato e salvato in float16 invece di bfloat16 . Questo lascia più spazio per errori tra i due.

Per essere perfettamente onesti, la conversione bfloat16 -> float16 sembrava essere OK in modalità di inferenza ( bfloat16 esiste principalmente per gestire grandi gradienti, che non esistono nell’inferenza).

In quella fase, è stato scoperto e implementato un importante compromesso. Poiché bloom è stato addestrato in un ambiente distribuito, parte del codice faceva parallelismo dei tensori su un layer lineare, il che significa che l’esecuzione della stessa operazione come un’operazione singola su una singola GPU produceva risultati diversi. Ci è voluto un po’ per individuare la causa e abbiamo optato per la piena conformità e il modello era molto più lento, oppure abbiamo accettato una piccola differenza nella generazione ma era molto più veloce da eseguire e aveva un codice più semplice. Abbiamo optato per una flag configurabile.

Nota: Parallelismo del pipeline (PP) significa in questo contesto che ogni GPU avrà
alcuni layer su cui lavorare, quindi ogni GPU lavorerà su un dato chunk
di dati prima di passarlo alla GPU successiva.

Ora abbiamo una versione pulita di transformers con cui lavorare per farlo funzionare.

Bloom è un modello di 352GB (176 miliardi di parametri in bf16), abbiamo bisogno di almeno tanta memoria GPU per farlo stare. Abbiamo esplorato brevemente l’esternalizzazione su CPU su macchine più piccole, ma la velocità di inferenza era molto più lenta quindi l’abbiamo scartata.

Quindi volevamo fondamentalmente utilizzare la pipeline. Quindi è una sorta di autoalimentazione e questo è ciò che l’API utilizza sotto il cofano tutto il tempo.

Tuttavia, le pipeline non sono consapevoli della distribuzione (non è il loro obiettivo). Dopo una breve discussione sulle opzioni, abbiamo finito per utilizzare l’accelerazione creata di recente device_map=”auto” per gestire la suddivisione del modello. Abbiamo dovuto risolvere alcuni bug e correggere leggermente il codice transformers per aiutare accelerate a fare il lavoro giusto.

Funziona suddividendo i vari strati dei transformers e dando una parte del modello a ciascuna GPU. Quindi GPU0 inizia a lavorare, quindi lo passa a GPU1 e così via.

Alla fine, con un piccolo server HTTP in cima, siamo riusciti a iniziare a servire bloom (il grande modello) !!

Ma non abbiamo ancora iniziato a discutere di ottimizzazioni!

In realtà ne abbiamo abbastanza, tutto questo processo è un castello di carte. Durante le ottimizzazioni apporteremo modifiche al codice sottostante, è molto importante essere sicuri di non danneggiare il modello in un modo o nell’altro ed è più facile di quanto si pensi.

Quindi ora siamo al primo passo delle ottimizzazioni e abbiamo bisogno di iniziare a misurare e continuare a misurare le prestazioni. Quindi dobbiamo considerare cosa ci interessa. Per un server di inferenza aperto che supporta molte opzioni, ci aspettiamo che gli utenti inviino molte query con diversi parametri e ciò che ci interessa sono:

Il numero di utenti che possiamo servire contemporaneamente (throughput) Quanto tempo ci vuole per servire un utente medio (latenza)?

Abbiamo creato uno script di test in locust che è esattamente questo:

from locust import HttpUser, between, task
from random import randrange, random


class QuickstartUser(HttpUser):
    wait_time = between(1, 5)

    @task
    def bloom_small(self):
        sentence = "Traduci in cinese. EN: Mi piace la zuppa. CN: "
        self.client.post(
            "/generate",
            json={
                "inputs": sentence[: randrange(1, len(sentence))],
                "parameters": {"max_new_tokens": 20, "seed": random()},
            },
        )

    @task
    def bloom_small(self):
        sentence = "Traduci in cinese. EN: Mi piace la zuppa. CN: "
        self.client.post(
            "/generate",
            json={
                "inputs": sentence[: randrange(1, len(sentence))],
                "parameters": {
                    "max_new_tokens": 20,
                    "do_sample": True,
                    "top_p": 0.9,
                    "seed": random(),
                },
            },
        )

**Nota: questo non è il miglior né l’unico test di carico che abbiamo utilizzato, ma è stato sempre il primo ad essere eseguito in modo da poter confrontare in modo equo le diverse approcci. Essere il migliore in questo benchmark NON significa essere la soluzione migliore. È stato necessario utilizzare scenari più complessi oltre alle prestazioni effettive del mondo reale. **

Volevamo osservare la crescita graduale per diverse implementazioni e assicurarci anche che in condizioni di sovraccarico il server interrompesse correttamente il circuito. La rottura del circuito significa che il server può rispondere (velocemente) che non risponderà alla tua query perché troppi utenti stanno cercando di utilizzarlo contemporaneamente. È estremamente importante evitare il colpo di morte.

In questo benchmark le prestazioni iniziali sono state (su 16xA100 40Go su GCP, che è la macchina utilizzata in tutto):

Richieste/s: 0,3 (throughput) Latenza: 350ms/token (latenza)

Questi numeri non sono così buoni. Prima di metterci al lavoro, stimiamo il meglio che possiamo immaginare di raggiungere. La formula per la quantità di operazioni è 24Bsh^2 + 4𝐵s^2h24Bsh^2 + 4𝐵s^2h dove B è la dimensione del batch, s è la lunghezza della sequenza e h è la dimensione nascosta.

Facciamo i calcoli e otteniamo 17 TFlop per un singolo passaggio in avanti. Guardando le specifiche di A100, dichiara 312 TFLOPS per una singola scheda. Ciò significa che una singola GPU potrebbe potenzialmente funzionare a 17 / 312 = 54ms/token. Ne stiamo utilizzando 16 quindi 3ms/token sulla macchina complessiva. Prendi tutti questi numeri con una grande dose di scetticismo, non è mai possibile raggiungere quei numeri e le prestazioni reali raramente corrispondono alle specifiche. Inoltre, se il calcolo non è il fattore limitante, allora questo non è il minimo che puoi ottenere. È solo una buona pratica sapere quanto sei lontano dal tuo obiettivo. In questo caso, siamo a 2 ordini di grandezza quindi abbastanza lontani. Inoltre, questa stima mette tutti i flops al servizio della latenza, il che significa che solo una singola richiesta può procedere alla volta (va bene perché stai massimizzando la tua macchina quindi non c’è molto altro da fare, ma possiamo ottenere una latenza più alta e riavere il throughput attraverso il batching in modo molto più facile).

Nota: Il parallelismo dei tensori (TP) significa in questo contesto che ogni GPU avrà una parte dei pesi, quindi TUTTE le GPU sono attive tutto il tempo e fanno meno lavoro. Di solito questo comporta un leggero overhead in cui alcune operazioni vengono duplicate e, cosa più importante, le GPU devono comunicare regolarmente tra di loro i loro risultati per continuare il calcolo.

Ora che abbiamo una buona comprensione di dove ci troviamo, è il momento di mettersi al lavoro.

Abbiamo provato molte cose diverse in base alle persone e alle nostre diverse conoscenze.

TUTTI gli sforzi meritano un proprio post nel blog, quindi li elencherò solo, spiegherò alcune conclusioni finali e approfondirò i dettagli solo di ciò che è stato incluso nel server attuale. Passare dal parallelismo dei pipeline (PP) al parallelismo dei tensori (TP) è un grande cambiamento interessante per la latenza. Ogni GPU avrà una parte dei parametri e tutte lavoreranno contemporaneamente. Quindi la latenza dovrebbe diminuire drasticamente, ma il prezzo da pagare è l’overhead di comunicazione dal momento che devono comunicare regolarmente tra di loro i loro risultati.

Si nota che questa è una gamma molto ampia di approcci e l’intento era deliberatamente quello di imparare di più su ogni strumento e su come potesse essere utilizzato in future iniziative.

Portare il codice di JAX/Flax per eseguirlo su TPU:

  • Si pensava che sarebbe stato più facile scegliere il tipo di parallelismo, quindi TP dovrebbe essere più facile da testare. È uno dei vantaggi del design di Jax.
  • Più vincolato sull’hardware, le prestazioni su TPU sono probabilmente superiori rispetto alle GPU e c’è meno scelta di fornitori per le TPU.
  • Svantaggi: è necessario un altro porting. Ma sarebbe comunque benvenuto nelle nostre librerie.

Risultati:

  • Il porting non è stato un compito facile poiché alcune condizioni e kernel erano difficili da riprodurre correttamente a sufficienza. Tuttavia, era gestibile.
  • Il parallelismo è stato piuttosto facile da ottenere una volta eseguito il porting. Un plauso a Jax, la rivendicazione è vera.
  • Ray/la comunicazione con i lavoratori TPU si è rivelata un vero problema per noi. Non sappiamo se sia un problema dello strumento, della rete o semplicemente della nostra mancanza di conoscenza, ma ha rallentato gli esperimenti e il lavoro molto più di quanto previsto. Lanciavamo un esperimento che richiedeva 5 minuti per essere eseguito, aspettavamo 5 minuti e non succedeva nulla, 10 minuti dopo ancora nulla, alla fine si scopre che un lavoratore era giù/non rispondeva e dovevamo intervenire manualmente, capire cosa era successo, sistemarlo, riavviare qualcosa e rilanciare, e avevamo appena perso mezz’ora. Ripetere questo processo abbastanza volte fa accumulare rapidamente giorni persi. Sottolineiamo che non è necessariamente una critica agli strumenti che abbiamo usato, ma l’esperienza soggettiva che abbiamo avuto rimane.
  • Nessun controllo sulla compilazione. Una volta fatto funzionare la cosa, abbiamo provato diverse impostazioni per capire quale si adattasse meglio all’inferenza che avevamo in mente, ed è risultato davvero difficile indovinare dalle impostazioni cosa sarebbe successo in termini di latenza/flusso. Ad esempio, abbiamo ottenuto 0,3 rps con batch_size=1 (quindi ogni richiesta/utente è indipendente) con una latenza di 15 ms/token (non confrontare troppo con altri numeri in questo articolo perché si tratta di una macchina diversa con un profilo molto diverso), il che è ottimo, ma il flusso complessivo non è molto migliore rispetto al vecchio codice. Quindi abbiamo deciso di aggiungere il batching e con BS=2 la latenza è aumentata di 5 volte, con solo il doppio del flusso… Dopo ulteriori indagini, si è scoperto che fino a batch_size=16 ogni batch_size aveva lo stesso profilo di latenza. Quindi avremmo potuto ottenere un flusso 16 volte maggiore a un costo di latenza 5 volte superiore. Non male, ma guardando i numeri avremmo preferito un controllo più dettagliato. I numeri che stavamo cercando erano basati sulla regola dei 100 ms, 1 s, 10 s, 1 min.

Utilizzo di ONNX/TRT o di altri approcci compilati

  • Dovrebbero occuparsi della maggior parte del lavoro di ottimizzazione
  • Svantaggio: di solito il parallelismo deve essere gestito manualmente.

Risultati:

  • Si è scoperto che per essere in grado di tracciare/jit/esportare cose dovevamo rielaborare parte di PyTorch, in modo che si fondesse facilmente con l’approccio PyTorch puro. Complessivamente abbiamo capito che potevamo ottenere la maggior parte delle ottimizzazioni desiderate rimanendo nel mondo di PyTorch, consentendoci di mantenere flessibilità senza dover fare troppo sforzo di programmazione. Un’altra cosa da notare, poiché stiamo eseguendo su GPU e la generazione di testo comporta molteplici passaggi in avanti, abbiamo bisogno che i tensori rimangano sulla GPU ed è a volte difficile inviare i tensori a una libreria, ricevere il risultato, eseguire il calcolo dei logit (come argmax o campionamento) e reinviarlo nuovamente. Mettere il ciclo all’interno della libreria esterna significa perdere flessibilità proprio come Jax, quindi non è stato considerato nel nostro caso d’uso.

DeepSpeed

  • Questa è la tecnologia che ha alimentato l’addestramento, sembrava solo giusto usarla per l’elaborazione
  • Svantaggi, non è mai stata utilizzata/preparata per l’elaborazione prima d’ora.

Risultati:

  • Abbiamo avuto risultati davvero impressionanti in modo rapido, che sono approssimativamente gli stessi dell’ultima iterazione che stiamo eseguendo attualmente.
  • Abbiamo dovuto inventare un modo per mettere un server web (quindi gestire la concorrenza) su DeepSpeed che ha anche diversi processi (uno per ogni GPU). Dato che c’è una libreria eccellente Mii . Non si adatta agli obiettivi estremamente flessibili che avevamo in mente, ma probabilmente avremmo iniziato a lavorare su di essa adesso. (La soluzione attuale viene discussa in seguito).
  • Il problema più grande che abbiamo riscontrato con DeepSpeed è stata la mancanza di stabilità. Abbiamo avuto problemi quando l’abbiamo eseguito su CUDA 11.4 dove il codice era stato creato per 11.6 E il problema di lunga data che non siamo mai riusciti a risolvere è che ci sarebbero stati regolari crash del kernel (accesso illegale a Cuda, mismatch delle dimensioni, ecc..). Ne abbiamo risolti parecchi di questi, ma non siamo mai riusciti a ottenere una stabilità completa sotto lo stress del nostro server web. Nonostante ciò, voglio ringraziare le persone di Microsoft che ci hanno aiutato, abbiamo avuto una conversazione molto produttiva che ha migliorato la nostra comprensione di ciò che stava accadendo e ci ha dato spunti concreti per fare ulteriori lavori di follow-up.
  • Uno dei punti critici che sento è che il nostro team si trova principalmente in Europa, mentre Microsoft è in California, quindi la collaborazione è stata complicata dal punto di vista temporale e probabilmente abbiamo perso una grossa fetta di tempo a causa di ciò. Questo non ha nulla a che fare con la parte tecnica, ma è importante riconoscere che la parte organizzativa del lavoro in collaborazione è anche molto importante.
  • Un’altra cosa da notare è che DeepSpeed si basa su transformers per iniettare le sue ottimizzazioni, e dato che stavamo aggiornando il nostro codice in modo molto costante, è stato difficile per il team di DeepSpeed far funzionare le cose sul nostro ramo main. Ci scusiamo per aver reso le cose difficili, suppongo che sia per questo che viene chiamato bleeding edge.

Idee per il server web

  • Dato che stiamo per eseguire un server gratuito in cui gli utenti invieranno testi lunghi, testi brevi, vorranno pochi token o una ricetta intera, con diversi parametri, qui qualcosa doveva essere fatto.

Risultati:

  • Abbiamo riscritto tutto in Rust con gli eccellenti binding tch-rs . Rust non era mirato a ottenere vantaggi di prestazioni, ma solo un controllo molto più fine sul parallelismo (thread/processi) e una maggiore precisione sulla concorrenza del server web e di PyTorch. Python è notoriamente difficile da gestire nei dettagli di basso livello grazie al GIL .
  • Si è scoperto che la maggior parte dei problemi proveniva dal porting, e dopo quello, sperimentare è stato un gioco da ragazzi. E abbiamo capito che con un controllo sufficiente sui cicli avremmo potuto ottenere ottime prestazioni per tutti anche nel contesto di una vasta gamma di richieste con diverse proprietà. Codice per i curiosi, ma non viene fornito alcun supporto o documentazione.
  • È diventato produttivo per alcune settimane perché era più flessibile sul parallelismo, potevamo utilizzare le GPU in modo più efficiente (utilizzando la GPU0 per la richiesta 1 mentre la GPU1 sta elaborando la richiesta 0). e siamo passati da 0.3 RPS a ~2.5 RPS con la stessa latenza. Il caso ottimale sarebbe stato aumentare la capacità di elaborazione di 16 volte, ma i numeri mostrati qui sono misurazioni reali del carico di lavoro, quindi non è così male.

PyTorch puro

  • Modificare completamente il codice esistente per renderlo più veloce rimuovendo operazioni come reshape , utilizzando kernel meglio ottimizzati e così via.
  • Svantaggio, dobbiamo codificare noi stessi TP e abbiamo il vincolo che il codice si adatti ancora alla nostra libreria (principalmente).

Risultati

  • Prossimo capitolo.

Scrivere PyTorch più efficiente

Il primo punto dell’elenco era rimuovere le operazioni non necessarie nelle prime implementazioni. Alcune possono essere individuate semplicemente guardando il codice e individuando errori evidenti:

  • Alibi viene utilizzato in Bloom per aggiungere le posizioni di inserimento ed era calcolato in troppi punti, potevamo calcolarlo solo una volta e in modo più efficiente.

Il vecchio codice: link Il nuovo codice: link

Questo è un incremento di velocità di 10 volte e l’ultima versione include anche il padding! Dal momento che questa fase viene calcolata solo una volta, la velocità effettiva non è importante, ma ridurre il numero di operazioni e la creazione di tensori è una buona direzione complessiva.

Altre parti emergono più chiaramente quando si inizia il profiling e abbiamo utilizzato in modo piuttosto estensivo l’estensione tensorboard

Questo fornisce questo tipo di immagine che offre spunti:

Attenzione richiede molto tempo, attenzione questa è una vista CPU quindi le barre lunghe non significano lunghe, ma significano che la CPU sta aspettando i risultati della GPU del passo precedente. Vediamo molte operazioni `cat` prima di `baddbmm`.

Rimuovendo molte reshape/traspose, ad esempio, abbiamo scoperto che: – L’attention è il percorso critico (è previsto ma è sempre bene verificare). – Nell’attention, molti kernel erano effettivamente copie dovute alla massiccia quantità di reshape – Abbiamo potuto rimuovere le reshape rielaborando i pesi stessi e il passato. Questo è un cambiamento che richiede una modifica al codice ma ha migliorato notevolmente le prestazioni!

Supporto per TP

Ok, abbiamo rimosso la maggior parte dei miglioramenti immediati, siamo passati da una latenza di 350ms/token a 300ms/token in PP. Questo è un riduzione del 15% nella latenza, ma in realtà ha fornito più di quello, ma non siamo stati estremamente rigorosi nella nostra misurazione iniziale quindi atteniamoci a quella cifra.

Poi abbiamo fornito un’implementazione TP. Si è rivelata molto più veloce di quanto ci aspettassimo, l’implementazione ha richiesto mezza giornata di un singolo sviluppatore esperto. Il risultato è qui . Siamo stati in grado anche di riutilizzare codice da altri progetti che ha aiutato.

La latenza è passata direttamente da 300ms/token a 91ms/token, il che rappresenta un enorme miglioramento dell’esperienza utente. Una semplice richiesta di 20 token è passata da 6s a 2s, passando da un’esperienza “lenta” a una leggermente ritardata.

Inoltre, il throughput è aumentato molto a 10RPS. Il throughput deriva dal fatto che eseguire una query con batch_size=1 richiede lo stesso tempo di batch_size=32 e il throughput diventa essenzialmente gratuito in termini di latenza.

Miglioramenti immediati

Ora che avevamo un’implementazione TP, potevamo iniziare di nuovo a profilare e ottimizzare. È un cambiamento abbastanza significativo da dover ricominciare da capo.

La prima cosa che si evidenzia è che la sincronizzazione (ncclAllReduce) inizia a diventare una parte preponderante del carico, il che è previsto, questa è la parte di sincronizzazione ed effettivamente richiede del tempo. Non abbiamo provato a cercare e ottimizzare questo aspetto in quanto già utilizza nccl ma potrebbe comunque esserci margine di miglioramento. Abbiamo supposto che sarebbe stato difficile fare molto meglio.

La seconda cosa è che l’operatore Gelu stava lanciando molti kernel elementwise e nel complesso stava prendendo una quota di calcolo più grande di quanto ci aspettassimo.

Abbiamo apportato il cambiamento da:

def bloom_gelu_forward(x):
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))

a

@torch.jit.script
def bloom_gelu_forward(x):
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))

Questo trasforma le operazioni da molti piccoli kernel element-wise (e quindi copie di tensori) a un’unica operazione di kernel!

Questo ha fornito un miglioramento del 10% nella latenza, passando da 91ms/token a 81ms/token, giusto lì!

Attenzione però, questo non è una scatola nera magica che si può semplicemente applicare ovunque, la fusione dei kernel potrebbe non necessariamente accadere o le operazioni precedentemente utilizzate potrebbero già essere estremamente efficienti.

I luoghi in cui abbiamo scoperto che funziona bene sono:

  • Hai molte operazioni piccole/elementwise
  • Hai un punto critico con poche reshape difficili da rimuovere, copie in generale
  • Quando avviene la fusione.

Fallimento epico

Inoltre, durante i nostri periodi di test, abbiamo riscontrato alcuni punti in cui il server Rust aveva una latenza costantemente inferiore del 25% rispetto a quello Python. Questo era piuttosto strano, ma poiché era costantemente misurato e poiché la rimozione dei kernel forniva un aumento di velocità, eravamo dell’impressione che eliminare il sovraccarico di Python potesse fornire una bella spinta.

Abbiamo iniziato un lavoro di 3 giorni per reimplementare le parti necessarie di torch.distributed per farlo funzionare nel mondo Rust con nccl-rs. Avevamo la versione funzionante ma c’era qualcosa di sbagliato nelle generazioni rispetto alla controparte Python. Durante l’indagine dei problemi, abbiamo capito… che avevamo dimenticato di rimuovere il profiler nelle misurazioni di Pytorch

Questo è stato un epic fail perché rimuoverlo ci ha restituito il 25% e poi entrambi i codici sono stati eseguiti alla stessa velocità. Questo è quello che ci aspettavamo inizialmente, che Python non dovesse avere un impatto sulle prestazioni, dato che sta principalmente eseguendo il codice cpp di torch. Alla fine, 3 giorni non sono la fine del mondo e potrebbe diventare utile in futuro, ma è comunque piuttosto brutto. Questo è abbastanza comune quando si fanno ottimizzazioni per fare misurazioni sbagliate o fuorvianti che finiscono per essere deludenti o addirittura dannose per il prodotto complessivo. Ecco perché farlo a piccoli passi e avere aspettative sul risultato il prima possibile aiuta a contenere quel rischio.

Un altro punto in cui abbiamo dovuto fare molta attenzione è stato il passaggio iniziale in avanti (senza passato) e i passaggi successivi in avanti (con passato). Se ottimizzi il primo, è molto probabile che rallenti i successivi che sono molto più importanti e rappresentano la maggior parte del tempo di esecuzione. Un altro colpevole abbastanza comune è rappresentare i tempi che sono tempi CPU e non tempi CUDA effettivi, quindi è necessario utilizzare torch.cuda.synchronize() durante le esecuzioni per essere sicuri che i kernel siano completi.

Kernel personalizzato

Fino ad ora, abbiamo raggiunto prestazioni simili a DeepSpeed senza alcun codice personalizzato al di fuori di PyTorch! Piuttosto interessante. Inoltre, non abbiamo dovuto fare compromessi sulla flessibilità della dimensione del batch durante l’esecuzione!

Ma data l’esperienza con DeepSpeed, volevamo provare a scrivere un kernel personalizzato per fondere alcune operazioni nel percorso più critico dove torch.jit.script non è stato in grado di farlo per noi. Fondamentalmente le seguenti due righe:

attn_weights = attention_scores.masked_fill_(attention_mask, torch.finfo(attention_scores.dtype).min)
attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)

Il primo masked fill sta creando un nuovo tensore, che serve solo a dire all’operatore softmax di ignorare quei valori. Inoltre, il softmax deve essere calcolato su float32 (per la stabilità), ma all’interno di un kernel personalizzato potremmo limitare la quantità di upcasting necessario in modo da limitarli alle somme e agli accumuli effettivi.

Il codice può essere trovato qui . Tieni presente che avevamo un’architettura GPU singola su cui concentrarci, quindi abbiamo potuto focalizzarci su questo e non siamo esperti (ancora) nel scrivere kernel, quindi potrebbero esserci modi migliori per farlo.

Questo kernel personalizzato ha fornito un ulteriore aumento del 10% della latenza, passando da una latenza di 81ms/token a 71ms/token. Tutto ciò mantenendo la nostra flessibilità.

Dopo ciò, abbiamo indagato ed esplorato altre cose come la fusione di più operatori, la rimozione di altre ridimensionamenti o il loro inserimento in altri punti. Ma nessun tentativo ha mai avuto un impatto significativo sufficiente da arrivare alle versioni finali.

Parte del webserver

Come nel caso di Rust, anche noi abbiamo dovuto implementare la suddivisione delle richieste con diversi parametri. Dato che ci trovavamo nel mondo di PyTorch, avevamo praticamente il pieno controllo su ciò che stava accadendo. Tuttavia, essendo in Python, abbiamo il fattore limitante che torch.distributed deve essere eseguito su diversi processi anziché thread, il che significa che è leggermente più difficile comunicare tra i processi. Alla fine, abbiamo optato per comunicare stringhe grezze tramite un canale di pubblicazione/sottoscrizione Redis per distribuire le richieste a tutti i processi contemporaneamente. Essendo in processi diversi, è più facile farlo in questo modo anziché comunicare tensori (che sono molto più grandi), ad esempio.

Poi abbiamo dovuto eliminare l’uso di generate poiché questo applica i parametri a tutti i membri del batch, e in realtà vogliamo applicare un diverso set di parametri. Fortunatamente, possiamo riutilizzare elementi di livello inferiore come il LogitsProcessor per risparmiarci molto lavoro.

Quindi abbiamo ricostruito una funzione generate che prende una lista di parametri e li applica a ciascun membro del batch.

Un altro aspetto davvero importante dell’esperienza utente finale è la latenza. Dato che abbiamo diversi set di parametri per diverse richieste, potremmo avere

1 richiesta per 20 token e l’altra per 250 token. Dato che ci vogliono 75ms/token di latenza, una richiesta richiede 1,5 s e l’altra 18 s. Se facessimo il batching fino in fondo, costringeremmo l’utente che ha fatto la richiesta a aspettare 18 s e gli sembrerebbe che stiamo eseguendo a 900ms/token, il che è piuttosto lento!

Dato che ci troviamo in un mondo PyTorch con estrema flessibilità, ciò che possiamo fare invece è estrarre dal batch la prima richiesta non appena generiamo i primi 20 token e restituirli a quell’utente entro l’1,5 secondi richiesti! Riusciamo anche a risparmiare il calcolo di 230 token.

Quindi la flessibilità è importante per ottenere la migliore latenza possibile.

L’ottimizzazione è un lavoro senza fine e, come qualsiasi altro progetto, il 20% del lavoro di solito produce l’80% dei risultati. Ad un certo punto, abbiamo iniziato ad avere una piccola strategia di test per individuare i possibili risultati di alcune idee che avevamo e se i test non producevano risultati significativi, allora scartavamo l’idea. Un giorno per un aumento del 10% è abbastanza prezioso, due settimane per un aumento del 10 volte è abbastanza prezioso. Due settimane per un aumento del 10% non è così interessante.

Hai provato…?

Roba che sappiamo esistere e non abbiamo usato per vari motivi. Potrebbe essere che sembrava non adattato al nostro caso d’uso, che richiedeva troppo lavoro, i risultati non sembravano abbastanza promettenti o semplicemente avevamo troppe opzioni da provare e ne scartavamo alcune senza un motivo particolare, solo per mancanza di tempo. Di seguito non c’è un ordine particolare:

  • Grafici Cuda
  • nvFuser (Questo è ciò che alimenta torch.jit.script quindi l’abbiamo usato.)
  • FasterTransformer
  • Triton di Nvidia
  • XLA (Anche Jax sta usando xla!)
  • torch.fx
  • TensorRT

Non esitate a contattarci se il vostro strumento preferito manca qui o se pensate che abbiamo trascurato qualcosa di importante che potrebbe risultare utile!

Attenzione flash

Abbiamo brevemente esaminato l’integrazione dell’attenzione flash e, sebbene funzioni estremamente bene nella prima passata in avanti (senza past_key_values), non ha portato miglioramenti significativi quando viene eseguita utilizzando past_key_values. Poiché dovevamo adattarla per includere il tensore alibi nel calcolo, abbiamo deciso di non farlo (almeno non ancora).

OpenAI Triton

Triton è un ottimo framework per la creazione di kernel personalizzati in Python. Vogliamo usarlo di più, ma finora non l’abbiamo fatto. Saremmo entusiasti di vedere se si comporta meglio del nostro kernel Cuda. Scrivere direttamente in Cuda sembrava il percorso più breve per il nostro obiettivo quando abbiamo considerato le opzioni per quella parte.

Padding e Reshape

Come menzionato in tutto questo articolo, ogni copia del tensore ha un costo e un altro costo nascosto dell’esecuzione di produzione è il padding. Quando arrivano due query con lunghezze molto diverse, è necessario eseguire il padding (usare un token fittizio) per farle rientrare in un quadrato. Ciò porta a molti calcoli inutili. Maggiori informazioni .

Idealemente, potremmo evitare del tutto questi calcoli e non avere mai dei reshape. Tensorflow ha il concetto di RaggedTensor e Pytorch ha i tensori nidificati . Entrambi sembrano meno efficienti rispetto ai tensori regolari, ma potrebbero consentirci di eseguire meno calcoli, il che è sempre un vantaggio.

In un mondo ideale, l’interferenza completa verrebbe scritta in CUDA o implementazione pura su GPU. Considerando i miglioramenti delle prestazioni ottenuti quando si possono fonderà le operazioni, sembra desiderabile. Ma in che misura ciò sarebbe vantaggioso, non ne abbiamo idea. Se le persone più esperte in GPU hanno idee, siamo interessati ad ascoltare!

Tutto questo lavoro è il risultato della collaborazione di molti membri del team di HF. In nessun ordine particolare, @ThomasWang @stas @Nouamane @Suraj @Sanchit @Patrick @Younes @Sylvain @Jeff (Microsoft) @Reza E tutta l’organizzazione di BigScience.