Fine-tuning di Llama 2 70B utilizzando PyTorch FSDP

Fine-tuning di Llama 2 70B con PyTorch FSDP.

Introduzione

In questo post del blog, vedremo come ottimizzare Llama 2 70B utilizzando PyTorch FSDP e le migliori pratiche correlate. Utilizzeremo Hugging Face Transformers, Accelerate e TRL. Impareremo anche come utilizzare Accelerate con SLURM.

La Parallelizzazione dei Dati Completamente Frammentati (FSDP) è un paradigma in cui gli stati dell’ottimizzatore, i gradienti e i parametri sono frammentati tra i dispositivi. Durante il passaggio in avanti, ogni unità FSDP esegue un’operazione di raccolta completa per ottenere i pesi completi, viene eseguita la computazione seguita dall’eliminazione dei frammenti dagli altri dispositivi. Dopo il passaggio in avanti, viene calcolata la perdita seguita dal passaggio all’indietro. Nel passaggio all’indietro, ogni unità FSDP esegue un’operazione di raccolta completa per ottenere i pesi completi, con la computazione eseguita per ottenere i gradienti locali. Questi gradienti locali vengono mediati e frammentati tra i dispositivi tramite un’operazione di riduzione-dispersone in modo che ogni dispositivo possa aggiornare i parametri del suo frammento. Per ulteriori informazioni su cosa è PyTorch FSDP, si prega di fare riferimento a questo post del blog: Accelerare l’addestramento di modelli di grandi dimensioni utilizzando PyTorch Parallelizzazione dei Dati Completamente Frammentati.

(Fonte: link)

Hardware Utilizzato

Numero di nodi: 2. Il minimo richiesto è 1. Numero di GPU per nodo: 8. Tipo di GPU: A100. Memoria delle GPU: 80GB. Connessione intra-nodo: NVLink. RAM per nodo: 1TB. Core CPU per nodo: 96. Connessione inter-nodo: Elastic Fabric Adapter

Sfide con la messa a punto di LLaMa 70B

Abbiamo riscontrato tre sfide principali nel tentativo di ottimizzare LLaMa 70B con FSDP:

  1. FSDP avvolge il modello dopo aver caricato il modello pre-addestrato. Se ogni processo/rank all’interno di un nodo carica il modello Llama-70B, richiederebbe 70*4*8 GB ~ 2TB di RAM CPU, dove 4 è il numero di byte per parametro e 8 è il numero di GPU su ogni nodo. Ciò porterebbe alla memoria RAM della CPU che si esaurisce e ai processi che vengono terminati.

  2. Salvare interi checkpoint intermedi utilizzando FULL_STATE_DICT con il trasferimento CPU su rank 0 richiede molto tempo e spesso porta a errori di timeout NCCL a causa di un blocco indefinito durante la trasmissione. Tuttavia, alla fine dell’addestramento, vogliamo l’intero stato del modello anziché lo stato frammentato compatibile solo con FSDP.

  3. Dobbiamo migliorare la velocità e ridurre l’utilizzo della VRAM per addestrare più velocemente e risparmiare sui costi di calcolo.

Vediamo come risolvere le sfide sopra e ottimizzare un modello 70B!

Prima di iniziare, ecco tutte le risorse necessarie per riprodurre i nostri risultati:

  1. Codebase: https://github.com/pacman100/DHS-LLM-Workshop/tree/main/chat_assistant/training con patch di scimmia flash-attn V2

  2. Configurazione FSDP: https://github.com/pacman100/DHS-LLM-Workshop/blob/main/chat_assistant/training/configs/fsdp_config.yaml

  3. Script SLURM launch.slurm: https://gist.github.com/pacman100/1cb1f17b2f1b3139a63b764263e70b25

  4. Modello: meta-llama/Llama-2-70b-chat-hf

  5. Dataset: smangrul/code-chat-assistant-v1 (miscela di LIMA+GUANACO con formattazione corretta in un formato pronto per l’addestramento)

Prerequisiti

Seguire prima questi passaggi per installare Flash Attention V2: Dao-AILab/flash-attention: Fast and memory-efficient exact attention (github.com). Installare le ultime versioni notturne di PyTorch con CUDA ≥11.8. Installare i requisiti rimanenti come indicato in DHS-LLM-Workshop/code_assistant/training/requirements.txt. Qui, installeremo 🤗 Accelerate e 🤗 Transformers dal ramo principale.

Ottimizzazione

Affrontare la sfida 1

I PR huggingface/transformers#25107 e huggingface/accelerate#1777 risolvono la prima sfida e non richiedono modifiche al codice da parte dell’utente. Fanno quanto segue:

  1. Crea il modello senza pesi su tutti i ranghi (utilizzando il dispositivo meta).
  2. Carica lo stato del dizionario solo su rank==0 e imposta i pesi del modello con tale stato del dizionario su rank 0.
  3. Per tutti gli altri ranghi, esegui torch.empty(*param.size(), dtype=dtype) per ogni parametro sul dispositivo meta.
  4. Quindi, rank==0 avrà caricato il modello con il corretto stato del dizionario mentre tutti gli altri ranghi avranno pesi casuali.
  5. Imposta sync_module_states=True in modo che l’oggetto FSDP si occupi di trasmetterli a tutti i ranghi prima dell’inizio dell’addestramento.

Di seguito è riportato uno snippet di output su un modello da 7B su 2 GPU che misura la memoria consumata e i parametri del modello in varie fasi. Possiamo osservare che durante il caricamento del modello preaddestrato il rank 0 e il rank 1 hanno una memoria totale di picco della CPU di 32744 MB e 1506 MB, rispettivamente. Pertanto, solo il rank 0 sta caricando il modello preaddestrato, garantendo un uso efficiente della RAM della CPU. È possibile trovare tutti i log qui

accelerator.process_index=0 Memoria GPU prima dell'ingresso del caricamento: 0
accelerator.process_index=0 Memoria GPU consumata alla fine del caricamento (fine-inizio): 0
accelerator.process_index=0 Memoria di picco GPU consumata durante il caricamento (massimo-inizio): 0
accelerator.process_index=0 Memoria totale di picco GPU consumata durante il caricamento (massimo): 0
accelerator.process_index=0 Memoria CPU prima dell'ingresso del caricamento: 926
accelerator.process_index=0 Memoria CPU consumata alla fine del caricamento (fine-inizio): 26415
accelerator.process_index=0 Memoria di picco CPU consumata durante il caricamento (massimo-inizio): 31818
accelerator.process_index=0 Memoria totale di picco CPU consumata durante il caricamento (massimo): 32744

accelerator.process_index=1 Memoria GPU prima dell'ingresso del caricamento: 0
accelerator.process_index=1 Memoria GPU consumata alla fine del caricamento (fine-inizio): 0
accelerator.process_index=1 Memoria di picco GPU consumata durante il caricamento (massimo-inizio): 0
accelerator.process_index=1 Memoria totale di picco GPU consumata durante il caricamento (massimo): 0
accelerator.process_index=1 Memoria CPU prima dell'ingresso del caricamento: 933
accelerator.process_index=1 Memoria CPU consumata alla fine del caricamento (fine-inizio): 10
accelerator.process_index=1 Memoria di picco CPU consumata durante il caricamento (massimo-inizio): 573
accelerator.process_index=1 Memoria totale di picco CPU consumata durante il caricamento (massimo): 1506

Affrontare la sfida 2

Viene risolta scegliendo il tipo di stato del dizionario SHARDED_STATE_DICT durante la creazione della configurazione FSDP. SHARDED_STATE_DICT salva il frammento per GPU separatamente, il che rende veloce salvare o riprendere l’addestramento da un checkpoint intermedio. Quando viene utilizzato FULL_STATE_DICT, il primo processo (rank 0) raccoglie l’intero modello sulla CPU e quindi lo salva in un formato standard.

Creiamo la configurazione di accelerate tramite il seguente comando:

accelerate config --config_file "fsdp_config.yaml"

La configurazione risultante è disponibile qui: fsdp_config.yaml. Qui, la strategia di sharding è FULL_SHARD. Stiamo utilizzando TRANSFORMER_BASED_WRAP per la politica di avvolgimento automatico e utilizza _no_split_module per trovare il nome del blocco Transformer per il nested FSDP auto wrap. Utilizziamo SHARDED_STATE_DICT per salvare i checkpoint intermedi e gli stati dell’ottimizzatore in questo formato raccomandato dal team PyTorch. Assicurarsi di abilitare la trasmissione dei parametri dei moduli dal rank 0 all’inizio, come indicato nel paragrafo precedente sulla sfida 1. Abilitiamo l’addestramento in precisione mista bf16.

Per l’ultimo checkpoint in cui viene utilizzato l’intero stato del modello, viene utilizzato lo snippet di codice seguente:

if trainer.is_fsdp_enabled:
    trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")

trainer.save_model(script_args.output_dir) # in alternativa, trainer.push_to_hub() se l'intero ckpt è inferiore a 50GB, poiché il limite LFS per file è di 50GB 

Affrontare la sfida 3

Sono richiesti Flash Attention e l’abilitazione del checkpointing del gradiente per un addestramento più veloce e una riduzione dell’uso di VRAM per consentire il fine-tuning e risparmiare sui costi di calcolo. Attualmente, il codice utilizza il monkey patching e l’implementazione si trova in chat_assistant/training/llama_flash_attn_monkey_patch.py.

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness introduce un modo per calcolare l’attenzione esatta in modo più rapido ed efficiente in termini di memoria, sfruttando la conoscenza della gerarchia di memoria dell’hardware/GPU sottostante: maggiore è la larghezza di banda/velocità della memoria, minore è la sua capacità in quanto diventa più costosa.

Se seguiamo il blog Making Deep Learning Go Brrrr From First Principles, possiamo capire che il modulo Attention sull’hardware attuale è limitato dalla memoria/limitato dalla larghezza di banda. Il motivo è che l’Attention consiste principalmente in operazioni elemento per elemento come mostrato di seguito a sinistra. Possiamo osservare che le operazioni di mascheramento, softmax e dropout occupano la maggior parte del tempo anziché le moltiplicazioni tra matrici, che costituiscono la maggior parte delle FLOP.

(Fonte: link)

Questo è esattamente il problema che Flash Attention risolve. L’idea è rimuovere le letture/scritture ridondanti sulla HBM. Lo fa mantenendo tutto nella SRAM, eseguendo tutti i passaggi intermedi e scrivendo solo il risultato finale sulla HBM, anche conosciuto come Kernel Fusion. Di seguito è illustrato come questo supera il collo di bottiglia legato alla memoria.

(Fonte: link)

Tiling viene utilizzato durante i passaggi in avanti e all’indietro per suddividere il calcolo dei softmax/punteggi NxN in blocchi per superare la limitazione della dimensione della memoria SRAM. Per abilitare il tiling, viene utilizzato l’algoritmo softmax online. Viene utilizzato il ricomputo durante il passaggio all’indietro per evitare di memorizzare l’intera matrice softmax/punteggio NxN durante il passaggio in avanti. Ciò riduce notevolmente il consumo di memoria.

Per una comprensione semplificata e approfondita di Flash Attention, consultare i post del blog ELI5: FlashAttention e Making Deep Learning Go Brrrr From First Principles insieme all’articolo originale FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.

Mettere tutto insieme

Per eseguire l’addestramento utilizzando il lanciatore Accelerate con SLURM, fare riferimento a questo gist launch.slurm. Di seguito è riportato un comando equivalente che mostra come utilizzare il lanciatore Accelerate per eseguire l’addestramento. Notare che stiamo sovrascrivendo i valori di main_process_ip, main_process_port, machine_rank, num_processes e num_machines del file fsdp_config.yaml. Inoltre, un altro punto importante da notare è che lo storage viene condiviso tra tutti i nodi.

accelerate launch \
    --config_file configs/fsdp_config.yaml \
    --main_process_ip $MASTER_ADDR \
    --main_process_port $MASTER_PORT \
    --machine_rank \$MACHINE_RANK \
    --num_processes 16 \
    --num_machines 2 \
    train.py \
    --model_name "meta-llama/Llama-2-70b-chat-hf" \
    --dataset_name "smangrul/code-chat-assistant-v1" \
    --max_seq_len 2048 \
    --max_steps 500 \
    --logging_steps 25 \
    --eval_steps 100 \
    --save_steps 250 \
    --bf16 True \
    --packing True \
    --output_dir "/shared_storage/sourab/experiments/full-finetune-llama-chat-asst" \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --dataset_text_field "content" \
    --use_gradient_checkpointing True \
    --learning_rate 5e-5  \
    --lr_scheduler_type "cosine" \
    --weight_decay 0.01 \
    --warmup_ratio 0.03 \
    --use_flash_attn True

La messa a punto è stata completata in circa 13,5 ore e di seguito è riportato il grafico della perdita di addestramento. Calcoliamo l’utilizzo dei flop del modello (MFU) per l’esecuzione dell’addestramento.

  1. Le GPU A100 eseguono circa 3,12e14 FLOPS al secondo (in float32 o bfloat16)
  2. Numero di token addestrati negli esperimenti precedenti = lunghezza della sequenza * dimensione del batch * numero di passaggi di addestramento = (2048 * 16 *
    1. = 16.384.000 = 1.64e7
  3. Calcolo approssimativo per gli esperimenti precedenti = 6 * P (num_params) * D (num_tokens) = 6 * 7e10 * 1.64e7 = 6.89e18 FLOPS
  4. FLOPS di addestramento al secondo = Calcolo approssimativo / tempo di addestramento = 6.89e17 / (13.5 * 3600) = 6.89e18 / 4.86e4 = 1.42e14 FLOPS al secondo
  5. MFU = FLOPS di addestramento al secondo / Prestazioni di picco A100 = 1.42e14 / 3.12e14 = 0.4551 = 45,51% delle prestazioni di picco

Di seguito è riportata una conversazione di esempio utilizzando il modello sopra:

Prompt di sistema: Sei un assistente disponibile, rispettoso e onesto. Rispondi sempre nel modo più utile possibile, garantendo la sicurezza. Le tue risposte non devono includere contenuti dannosi, non etici, razzisti, sessisti, tossici, pericolosi o illegali. Assicurati che le tue risposte siano socialmente imparziali e positive.

Se una domanda non ha senso o non è coerente dal punto di vista fattuale, spiega il motivo anziché fornire una risposta non corretta. Se non conosci la risposta a una domanda, per favore non fornire informazioni false.

- Persona: Cos'è il Deep Learning? Spiegalo come un Pirata.

+ Assistente: Arrr Grumete! Cerchi una spiegazione adeguata del Deep Learning? 
+ Bene, allora ascolta attentamente, mio caro! Il Deep Learning è un tipo di apprendimento automatico che utilizza reti neurali artificiali per analizzare i dati. 
+ Aye, hai capito bene, amico! Le reti neurali artificiali sono modelli del cervello umano, ma invece di utilizzare neuroni, utilizzano nodi interconnessi o "neuroni artificiali" per elaborare le informazioni. 
+ Questi nodi lavorano insieme per formare strati di unità di elaborazione che possono imparare da grandi quantità di dati. 
+ Più dati ricevono questi sistemi, migliori diventano nel riconoscere modelli e prendere decisioni in autonomia. 
+ Quindi alza le vele e imbarcati in una grande avventura con il deep learning, amici miei! È un tesoro di conoscenza che attende solo di essere scoperto. 
+ Buoni venti e mari calmi!

- Persona: Ora spiegalo come uno chef.

+ Assistente: Certamente! Ecco una spiegazione del deep learning da un punto di vista culinario:

+ Pensate a una ricetta come una sequenza di passaggi utilizzati per trasformare ingredienti grezzi in un delizioso piatto. 
+ Allo stesso modo, nel deep learning, ci sono più strati di "ingredienti" (o caratteristiche) che vengono combinati e trasformati attraverso varie operazioni per produrre un risultato finale o una previsione. 
+ Ogni strato si basa su quello precedente, affinando e migliorando la rappresentazione dei dati di input fino a raggiungere l'obiettivo desiderato. 
+ Proprio come le diverse spezie esaltano i sapori del cibo, ogni strato aggiunge profondità e complessità alla comprensione del modello dell'input, consentendogli di effettuare previsioni più accurate. 
+ E proprio come gli chef devono bilanciare con cura la quantità di sale, zucchero e altre spezie in una ricetta, il processo di addestramento nel deep learning prevede l'aggiustamento dei pesi e dei bias della rete per ottenere la perfetta combinazione di precisione ed efficienza. 
+ Con sufficiente pratica e sperimentazione, chiunque può diventare un vero chef... intendo dire, un maestro del deep learning! Buon appetito!

L’intera conversazione è formattata come segue:

<|system|> messaggio di sistema <|endoftext|> <|prompter|> Q1 <|endoftext|> <|assistant|> A1 <|endoftext|> ...

Conclusioni

Siamo riusciti a ottimizzare con successo il modello Llama da 70B utilizzando PyTorch FSDP in un ambiente multi-nodo multi-gpu, affrontando vari sfide. Abbiamo visto come 🤗 Transformers e 🤗 Accelerates ora supportano un modo efficiente di inizializzare modelli di grandi dimensioni quando si utilizza FSDP per superare i problemi di memoria RAM della CPU. Ciò è stato seguito dalle pratiche consigliate per il salvataggio/caricamento di checkpoint intermedi e come salvare il modello finale in modo da poterlo utilizzare facilmente. Per consentire un addestramento più rapido e ridurre l’utilizzo della memoria GPU, abbiamo sottolineato l’importanza di Flash Attention e Gradient Checkpointing. Nel complesso, possiamo vedere come una semplice configurazione utilizzando 🤗 Accelerate consente di ottimizzare tali modelli di grandi dimensioni in un ambiente multi-nodo multi-gpu.