Hugging Face su PyTorch / XLA TPUs

Hugging Face su PyTorch/XLA TPUs

Allenare i Tuoi Transformers Preferiti su Cloud TPUs utilizzando PyTorch / XLA

Il progetto PyTorch-TPU è nato come uno sforzo collaborativo tra i team di Facebook PyTorch e Google TPU ed è stato ufficialmente lanciato alla conferenza PyTorch Developer Conference 2019. Da allora, abbiamo lavorato con il team di Hugging Face per portare un supporto di prima classe all’allenamento su Cloud TPUs utilizzando PyTorch / XLA. Questa nuova integrazione consente agli utenti di PyTorch di eseguire e scalare i loro modelli su Cloud TPUs mantenendo la stessa interfaccia dei trainer di Hugging Face.

Questo post del blog fornisce una panoramica delle modifiche apportate nella libreria di Hugging Face, cosa fa la libreria di PyTorch / XLA, un esempio per iniziare ad allenare i tuoi transformers preferiti su Cloud TPUs e alcuni benchmark delle prestazioni. Se non puoi aspettare di iniziare con TPUs, vai direttamente alla sezione “Allenare il tuo Transformer su Cloud TPUs” – gestiamo tutte le meccaniche di PyTorch / XLA per te all’interno del modulo Trainer!

Tipo di Dispositivo XLA:TPU

PyTorch / XLA aggiunge un nuovo tipo di dispositivo xla a PyTorch. Questo tipo di dispositivo funziona proprio come gli altri tipi di dispositivo di PyTorch. Ad esempio, ecco come creare e stampare un tensore XLA:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.randn(2, 2, device=xm.xla_device())
print(t.device)
print(t)

Questo codice dovrebbe sembrarti familiare. PyTorch / XLA utilizza la stessa interfaccia di PyTorch regolare con alcune aggiunte. Importando torch_xla si inizializza PyTorch / XLA e xm.xla_device() restituisce il dispositivo XLA corrente. Questo può essere una CPU, una GPU o una TPU a seconda del tuo ambiente, ma in questo post ci concentreremo principalmente sulla TPU.

Il modulo Trainer utilizza una classe di dati TrainingArguments per definire i dettagli dell’allenamento. Gestisce diversi argomenti, dalle dimensioni dei batch, all’learning rate, all’accumulo del gradiente e altri, fino ai dispositivi utilizzati. Basandoci su quanto detto sopra, in TrainingArguments._setup_devices() quando si utilizzano dispositivi XLA:TPU, restituiamo semplicemente il dispositivo TPU da utilizzare da Trainer:

@dataclass
class TrainingArguments:
    ...
    @cached_property
    @torch_required
    def _setup_devices(self) -> Tuple["torch.device", int]:
        ...
        elif is_torch_tpu_available():
            device = xm.xla_device()
            n_gpu = 0
        ...

        return device, n_gpu

Calcolo del Passo del Dispositivo XLA

In uno scenario tipico di allenamento XLA:TPU, stiamo allenando su più core TPU in parallelo (un singolo dispositivo Cloud TPU include 8 core TPU). Quindi è necessario assicurarsi che tutti i gradienti vengano scambiati tra le repliche parallele dei dati consolidando i gradienti e facendo un passo con l’ottimizzatore. A tal scopo, forniamo il metodo xm.optimizer_step(optimizer) che esegue la consolidazione dei gradienti e il passo dell’ottimizzatore. Nel trainer di Hugging Face, aggiorniamo di conseguenza il train step utilizzando le API di PyTorch / XLA:

class Trainer:
…
   def train(self, *args, **kwargs):
       ...
                    if is_torch_tpu_available():
                        xm.optimizer_step(self.optimizer)

Pipeline di Input PyTorch / XLA

Ci sono due parti principali nell’esecuzione di un modello PyTorch / XLA: (1) tracciare ed eseguire il grafo del tuo modello in modo pigro (fare riferimento alla sezione “Libreria PyTorch / XLA” in basso per una spiegazione più approfondita) e (2) alimentare il tuo modello. Senza alcuna ottimizzazione, il tracciamento / esecuzione del tuo modello e l’alimentazione dell’input verrebbero eseguiti in modo seriale, lasciando delle pause durante le quali la CPU dell’host e gli acceleratori TPU sarebbero inattivi rispettivamente. Per evitare questo, forniamo un’API che permette di sovrapporre il tracciamento del passo n+1 mentre il passo n è ancora in esecuzione.

import torch_xla.distributed.parallel_loader as pl
...
  dataloader = pl.MpDeviceLoader(dataloader, device)

Scrittura e Caricamento dei Checkpoint

Quando un tensore viene checkpointed da un dispositivo XLA e poi caricato nuovamente dal checkpoint, verrà caricato nuovamente nel dispositivo originale. Prima di checkpointing dei tensori nel tuo modello, assicurati che tutti i tuoi tensori siano su dispositivi CPU invece che su dispositivi XLA. In questo modo, quando carichi nuovamente i tensori, li caricherai tramite dispositivi CPU e poi avrai l’opportunità di posizionarli su qualsiasi dispositivo XLA desideri. Forniamo l’API xm.save() per questo, che si occupa già di scrivere solo in una posizione di archiviazione da un solo processo su ogni host (o uno globalmente se si utilizza un sistema di file condiviso tra gli host).

class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
…
    def save_pretrained(self, save_directory):
        ...
        if getattr(self.config, "xla_device", False):
            import torch_xla.core.xla_model as xm

            if xm.is_master_ordinal():
                # Salva il file di configurazione
                model_to_save.config.save_pretrained(save_directory)
            # xm.save si occupa di salvare solo dal master
            xm.save(state_dict, output_model_file)

class Trainer:
…
   def train(self, *args, **kwargs):
       ...
       if is_torch_tpu_available():
           xm.rendezvous("saving_optimizer_states")
           xm.save(self.optimizer.state_dict(),
                   os.path.join(output_dir, "optimizer.pt"))
           xm.save(self.lr_scheduler.state_dict(),
                   os.path.join(output_dir, "scheduler.pt"))

Libreria PyTorch / XLA

PyTorch / XLA è un pacchetto Python che utilizza il compilatore di algebra lineare XLA per collegare il framework di deep learning PyTorch con i dispositivi XLA, che includono CPU, GPU e Cloud TPU. Parte dei contenuti seguenti è disponibile anche nella nostra API_GUIDE.md .

I tensori di PyTorch / XLA sono Lazy

Utilizzare i tensori e i dispositivi XLA richiede di cambiare solo poche righe di codice. Tuttavia, anche se i tensori XLA si comportano molto come i tensori CPU e CUDA, le loro parti interne sono diverse. I tensori CPU e CUDA avviano le operazioni immediatamente o in modo diligente. I tensori XLA, d’altra parte, sono lazy. Registrano le operazioni in un grafo fino a quando i risultati non sono necessari. Ritardando l’esecuzione in questo modo, XLA ottimizza il tutto. Un grafo di operazioni separate potrebbe essere fuso in un’unica operazione ottimizzata.

L’esecuzione lazy è in generale invisibile per il chiamante. PyTorch / XLA costruisce automaticamente i grafi, li invia ai dispositivi XLA e si sincronizza quando si copiano i dati tra un dispositivo XLA e la CPU. Inserendo una barriera durante l’esecuzione di un passo di ottimizzazione, si sincronizza esplicitamente la CPU con il dispositivo XLA.

Ciò significa che quando chiami model(input) forward pass, calcoli la tua loss loss.backward() e fai un passo di ottimizzazione xm.optimizer_step(optimizer), il grafo di tutte le operazioni viene costruito in background. Solo quando valuti esplicitamente il tensore (ad esempio, stampando il tensore o spostandolo su un dispositivo CPU) o segni un passo (questo verrà fatto dal MpDeviceLoader ogni volta che lo attraversi), l’intero passo viene eseguito.

Traccia, Compila, Esegui e Ripeti

Dal punto di vista dell’utente, un tipico regime di addestramento per un modello che viene eseguito su PyTorch / XLA comporta l’esecuzione di un forward pass, backward pass e un passo di ottimizzazione. Dal punto di vista della libreria PyTorch / XLA, le cose appaiono un po’ diverse.

Mentre un utente esegue i passaggi di forward e backward, viene tracciato un grafo di rappresentazione intermedia (IR) in tempo reale. Il grafo IR che porta a ciascun tensore di radice/output può essere ispezionato come segue:

>>> import torch
>>> import torch_xla
>>> import torch_xla.core.xla_model as xm
>>> t = torch.tensor(1, device=xm.xla_device())
>>> s = t*t
>>> print(torch_xla._XLAC._get_xla_tensors_text([s]))
IR {
  %0 = s64[] prim::Constant(), value=1
  %1 = s64[] prim::Constant(), value=0
  %2 = s64[] xla::as_strided_view_update(%1, %0), size=(), stride=(), storage_offset=0
  %3 = s64[] aten::as_strided(%2), size=(), stride=(), storage_offset=0
  %4 = s64[] aten::mul(%3, %3), ROOT=0
}

Questo grafico in tempo reale si accumula mentre vengono eseguiti i passaggi in avanti e all’indietro sul programma dell’utente e una volta chiamato xm.mark_step() (indirettamente tramite pl.MpDeviceLoader), il grafico dei tensori in tempo reale viene tagliato. Questa troncatura segna il completamento di un passo e successivamente abbassiamo il grafico IR in Operazioni di Livello Superiore XLA (HLO), che è il linguaggio IR per XLA.

Questo grafico HLO viene quindi compilato in un binario TPU e successivamente eseguito sui dispositivi TPU. Tuttavia, questa fase di compilazione può essere costosa, richiedendo di solito più tempo di un singolo passo, quindi se compilassimo il programma dell’utente ad ogni singolo passo, l’overhead sarebbe elevato. Per evitare ciò, abbiamo delle cache che memorizzano i binari TPU compilati in base ai loro identificatori di hash unici dei grafici HLO. Quindi, una volta che questa cache dei binari TPU è stata popolata al primo passo, i passi successivi di solito non dovranno ricompilare nuovi binari TPU; invece, possono semplicemente cercare i binari necessari nella cache.

Dato che le compilazioni TPU sono di solito molto più lente rispetto al tempo di esecuzione del passo, ciò significa che se il grafico cambia continuamente di forma, avremo dei cache miss e compileremo troppo frequentemente. Per ridurre al minimo i costi di compilazione, consigliamo di mantenere le forme del tensore statiche quando possibile. Le forme della libreria Hugging Face sono già statiche per la maggior parte, con i token di input che vengono riempiti in modo appropriato, quindi durante l’addestramento la cache dovrebbe essere colpita costantemente. Ciò può essere verificato utilizzando gli strumenti di debug forniti da PyTorch / XLA. Nell’esempio sottostante, puoi vedere che la compilazione è avvenuta solo 5 volte (CompileTime), mentre l’esecuzione è avvenuta in ogni singolo dei 1220 passaggi (ExecuteTime):

>>> import torch_xla.debug.metrics as met
>>> print(met.metrics_report())
Metrica: CompileTime
  TotaleCampioni: 5
  Accumulatore: 28s920ms153.731us
  VelocitàValore: 092ms152.037us / secondo
  Velocità: 0.0165028 / secondo
  Percentili: 1%=428ms053.505us; 5%=428ms053.505us; 10%=428ms053.505us; 20%=03s640ms888.060us; 50%=03s650ms126.150us; 80%=11s110ms545.595us; 90%=11s110ms545.595us; 95%=11s110ms545.595us; 99%=11s110ms545.595us
Metrica: DeviceLockWait
  TotaleCampioni: 1281
  Accumulatore: 38s195ms476.007us
  VelocitàValore: 151ms051.277us / secondo
  Velocità: 4.54374 / secondo
  Percentili: 1%=002.895us; 5%=002.989us; 10%=003.094us; 20%=003.243us; 50%=003.654us; 80%=038ms978.659us; 90%=192ms495.718us; 95%=208ms893.403us; 99%=221ms394.520us
Metrica: ExecuteTime
  TotaleCampioni: 1220
  Accumulatore: 04m22s555ms668.071us
  VelocitàValore: 923ms872.877us / secondo
  Velocità: 4.33049 / secondo
  Percentili: 1%=045ms041.018us; 5%=213ms379.757us; 10%=215ms434.912us; 20%=217ms036.764us; 50%=219ms206.894us; 80%=222ms335.146us; 90%=227ms592.924us; 95%=231ms814.500us; 99%=239ms691.472us
Contatore: CachedCompile
  Valore: 1215
Contatore: CreateCompileHandles
  Valore: 5
...

Allenare il tuo Transformer su TPUs cloud

Per configurare la tua VM e le TPUs cloud, segui le sezioni “Configurare un’istanza di Compute Engine” e “Lanciare una risorsa TPU cloud” (versione pytorch-1.7 al momento della scrittura). Una volta creati la tua VM e il TPU cloud, utilizzarli è semplice come accedere tramite SSH alla tua VM GCE ed eseguire i seguenti comandi per avviare l’addestramento di bert-large-uncased (la dimensione del batch è per il dispositivo v3-8, potrebbe causare OOM su v2-8):

conda activate torch-xla-1.7
export TPU_IP_ADDRESS="INSERISCI_IL_TUO_INDIRIZZO_IP_TPU"  # es. 10.0.0.2
export XRT_TPU_CONFIG="tpu_worker;0;$TPU_IP_ADDRESS:8470"
git clone -b v4.2.2 https://github.com/huggingface/transformers.git
cd transformers && pip install .
pip install datasets==1.2.1
python examples/xla_spawn.py \
  --num_cores 8 \
  examples/language-modeling/run_mlm.py \
  --dataset_name wikitext \
  --dataset_config_name wikitext-103-raw-v1 \
  --max_seq_length 512 \
  --pad_to_max_length \
  --logging_dir ./tensorboard-metrics \
  --cache_dir ./cache_dir \
  --do_train \
  --do_eval \
  --overwrite_output_dir \
  --output_dir language-modeling \
  --overwrite_cache \
  --tpu_metrics_debug \
  --model_name_or_path bert-large-uncased \
  --num_train_epochs 3 \
  --per_device_train_batch_size 8 \
  --per_device_eval_batch_size 8 \
  --save_steps 500000

Quanto sopra dovrebbe completare l’addestramento in circa meno di 200 minuti con una perplexity di valutazione di ~3.25.

Valutazione delle prestazioni

La seguente tabella mostra le prestazioni dell’addestramento di bert-large-uncased su un sistema Cloud TPU v3-8 (contenente 4 chip TPU v3) in esecuzione su PyTorch / XLA. Il dataset utilizzato per tutte le misurazioni di benchmarking è il dataset WikiText103, e utilizziamo lo script run_mlm.py fornito negli esempi di Hugging Face. Per garantire che i carichi di lavoro non siano limitati dalla CPU dell’host, utilizziamo la configurazione CPU n1-standard-96 per questi test, ma potresti essere in grado di utilizzare configurazioni più piccole senza influire sulle prestazioni.

Inizia con PyTorch / XLA su TPUs

Vedi la sezione “Esecuzione su TPUs” negli esempi di Hugging Face per iniziare. Per una descrizione più dettagliata delle nostre API, consulta la nostra API_GUIDE, e per le migliori pratiche di prestazioni, dai un’occhiata alla nostra guida TROUBLESHOOTING. Per esempi generici di PyTorch / XLA, esegui i seguenti notebook di Colab che offriamo con accesso gratuito a Cloud TPU. Per eseguire direttamente su GCP, consulta i nostri tutorial contrassegnati come “PyTorch” sul nostro sito di documentazione.

Hai altre domande o problemi? Apri una segnalazione o una domanda su https://github.com/huggingface/transformers/issues o direttamente su https://github.com/pytorch/xla/issues.