Stanford Research presenta FlashAttention-2 un balzo in avanti in termini di velocità ed efficienza per i modelli di linguaggio a lungo contesto
Stanford Research introduces FlashAttention-2, a leap forward in terms of speed and efficiency for long-context language models.
Nell’ultimo anno, l’elaborazione del linguaggio naturale ha visto notevoli progressi con l’emergere di modelli linguistici dotati di contesti significativamente più lunghi. Tra questi modelli ci sono GPT-4 con una lunghezza del contesto di 32k, MPT di MosaicML con 65k di contesto e Claude di Anthropic, che vanta un’impressionante lunghezza del contesto di 100k. Con l’aumento di applicazioni come la ricerca di documenti lunghi e la scrittura di storie, diviene evidente la necessità di modelli linguistici con contesto esteso. Tuttavia, la sfida sta nel aumentare la lunghezza del contesto dei Transformers, poiché il loro strato di attenzione richiede risorse computazionali e di memoria che crescono quadraticamente con la lunghezza della sequenza di input.
Per affrontare questa sfida, FlashAttention, un algoritmo innovativo rilasciato solo un anno fa, ha ottenuto un rapido successo in varie organizzazioni e laboratori di ricerca. Questo algoritmo ha accelerato con successo il calcolo dell’attenzione riducendo l’impronta di memoria senza sacrificare l’accuratezza o approssimare i risultati. Con una velocità 2-4 volte superiore rispetto alle basi ottimizzate al suo rilascio iniziale, FlashAttention si è dimostrato un avanzamento rivoluzionario. Tuttavia, aveva ancora un potenziale inespresso, poiché non raggiungeva le operazioni di moltiplicazione di matrici ottimizzate (GEMM) che raggiungevano fino a 124 TFLOPs/s sulle GPU A100.
Facendo il prossimo salto in avanti, gli sviluppatori di FlashAttention hanno ora introdotto FlashAttention-2, una versione reinventata che supera significativamente il suo predecessore. Sfruttando le librerie CUTLASS 3.x e CuTe di Nvidia, FlashAttention-2 raggiunge un notevole aumento di velocità, arrivando fino a 230 TFLOPs/s sulle GPU A100. Inoltre, nell’addestramento end-to-end dei modelli linguistici di stile GPT, FlashAttention-2 raggiunge una velocità di addestramento fino a 225 TFLOPs/s, con un impressionante utilizzo del 72% delle FLOP del modello.
- Migliori strumenti di intelligenza artificiale per proteggerti nel futuro (2023)
- Guida passo-passo a Word2Vec con Gensim
- La genialità strategica di Meta Llama 2 potrebbe essere il loro nuovo grafo sociale
Le principali migliorie di FlashAttention-2 risiedono nella sua migliore parallelismo e nelle strategie di suddivisione del lavoro. Inizialmente, FlashAttention parallelizzava sulla dimensione della dimensione del batch e sul numero di testate, utilizzando efficacemente le risorse di calcolo sulla GPU. Tuttavia, per sequenze lunghe con dimensioni di batch più piccole o meno testate, FlashAttention-2 ora parallelizza sulla dimensione della lunghezza della sequenza, ottenendo un notevole aumento di velocità in questi scenari.
Un’altra miglioria riguarda la suddivisione efficiente del lavoro tra i diversi warp all’interno di ogni blocco di thread. In FlashAttention, la suddivisione di K e V su quattro warp mantenendo Q accessibile a tutti i warp, chiamata “schema con K suddiviso”, ha comportato letture e scritture di memoria condivisa inutili, rallentando il calcolo. FlashAttention-2 adotta un approccio diverso, suddividendo ora Q su quattro warp mantenendo K e V accessibili a tutti i warp. Ciò elimina la necessità di comunicazione tra i warp e riduce significativamente le letture/scritture di memoria condivisa, aumentando ulteriormente le prestazioni.
FlashAttention-2 introduce diverse nuove funzionalità per ampliare la sua applicabilità e migliorare le sue capacità. Ora supporta dimensioni delle testate fino a 256, consentendo modelli come GPT-J, CodeGen, CodeGen2 e StableDiffusion 1.x, aprendo ulteriori opportunità di aumento di velocità e risparmio di memoria. Inoltre, FlashAttention-2 adotta le varianti di attenzione multi-query (MQA) e grouped-query attention (GQA), in cui più testate della query possono prestare attenzione alla stessa testata di chiave e valore, ottenendo un maggiore throughput di inferenza e migliori prestazioni.
Le prestazioni di FlashAttention-2 sono davvero impressionanti. Valutato su una GPU A100 80GB SXM4, ottiene un aumento di velocità di circa 2x rispetto al suo predecessore e fino a 9x rispetto a un’implementazione standard di attenzione in PyTorch. Inoltre, quando viene utilizzato per l’addestramento end-to-end dei modelli di stile GPT, FlashAttention-2 raggiunge fino a 225 TFLOPs/s sulle GPU A100, rappresentando un aumento di velocità end-to-end del 1.3x rispetto a modelli già altamente ottimizzati con FlashAttention.
Riguardo al futuro, le potenziali applicazioni di FlashAttention-2 sono promettenti. Con la capacità di addestrare modelli con contesti 16k più lunghi allo stesso prezzo dei modelli con contesto 8k precedenti, questa tecnologia può aiutare ad analizzare libri lunghi, report, immagini ad alta risoluzione, audio e video. Sono in corso piani per una maggiore applicabilità su dispositivi come le GPU H100 e le GPU AMD e per l’ottimizzazione per nuovi tipi di dati come fp8. Inoltre, la combinazione delle ottimizzazioni a basso livello di FlashAttention-2 con i cambiamenti algoritmici ad alto livello potrebbe aprire la strada all’addestramento di modelli di intelligenza artificiale con un contesto straordinariamente più lungo. La collaborazione con i ricercatori dei compilatori per migliorare la programmabilità è anche all’orizzonte, promettendo un futuro luminoso per la prossima generazione di modelli di linguaggio.