Comprensione dell’attenzione sparsa a blocchi di BigBird

'BigBird's sparse attention comprehension'

Introduzione

I modelli basati su trasformatori si sono dimostrati molto utili per molte attività di NLP. Tuttavia, un limite importante dei modelli basati su trasformatori è la complessità temporale e di memoria O(n^2) (dove n è la lunghezza della sequenza). Di conseguenza, è molto costoso dal punto di vista computazionale applicare modelli basati su trasformatori a sequenze lunghe n > 512. Diversi recenti articoli, ad esempio Longformer, Performer, Reformer, Clustered attention, cercano di rimediare a questo problema approssimando la matrice di attenzione completa. Puoi consultare il recente post sul blog di 🤗 nel caso in cui questi modelli ti siano sconosciuti.

BigBird (introdotto in paper) è uno di questi modelli recenti per affrontare questo problema. BigBird si basa sull’attenzione sparsa a blocchi invece dell’attenzione normale (cioè l’attenzione di BERT) e può gestire sequenze fino a una lunghezza di 4096 a un costo computazionale molto inferiore rispetto a BERT. Ha raggiunto il miglior risultato su varie attività che coinvolgono sequenze molto lunghe come la sintesi di documenti lunghi, la risposta a domande con contesti lunghi.

Il modello BigBird simile a RoBERTa è ora disponibile in 🤗Transformers. Lo scopo di questo post è dare al lettore una comprensione approfondita dell’implementazione di BigBird e facilitare l’utilizzo di BigBird con 🤗Transformers. Ma, prima di approfondire ulteriormente, è importante ricordare che l’attenzione di BigBird è un’approssimazione dell’attenzione completa di BERT e quindi non mira a essere migliore dell’attenzione completa di BERT, ma piuttosto a essere più efficiente. Semplicemente consente di applicare modelli basati su trasformatori a sequenze molto più lunghe poiché il requisito di memoria quadratica di BERT diventa rapidamente insopportabile. In poche parole, se avessimo ∞ calcolo e ∞ tempo, l’attenzione di BERT sarebbe preferita rispetto all’attenzione sparsa a blocchi (di cui parleremo in questo post).

Se ti chiedi perché abbiamo bisogno di ulteriori calcoli quando lavoriamo con sequenze più lunghe, questo post sul blog è perfetto per te!


Alcune delle principali domande che potresti avere quando lavori con l’attenzione standard simile a BERT includono:

  • Tutti i token devono davvero partecipare all’attenzione su tutti gli altri token?
  • Perché non calcolare l’attenzione solo sui token importanti?
  • Come decidere quali token sono importanti?
  • Come attirare l’attenzione solo su pochi token in modo molto efficiente?

In questo post sul blog, cercheremo di rispondere a queste domande.

A quali token dovrebbe essere prestata attenzione?

Daremo un esempio pratico di come funziona l’attenzione considerando la frase “BigBird è ora disponibile in HuggingFace per la risposta estrattiva alle domande”. Nell’attenzione simile a BERT, ogni parola semplicemente parteciperebbe all’attenzione su tutti gli altri token. Matematicamente, questo significherebbe che ogni token interrogato query-token ∈ { BigBird , is , now , available , in , HuggingFace , for , extractive , question , answering } , parteciperebbe all’intera lista di token chiave = [ BigBird , is , now , available , in , HuggingFace , for , extractive , question , answering ].

Pensiamo a una scelta sensata di token chiave a cui un token interrogato dovrebbe effettivamente prestare attenzione scrivendo del pseudo-codice. Assumeremo che il token available sia interrogato e costruiremo una lista sensata di token chiave a cui prestare attenzione.

>>> # consideriamo la seguente frase come esempio
>>> example = ['BigBird', 'è', 'ora', 'disponibile', 'in', 'HuggingFace', 'per', 'il', 'question', 'answering']

>>> # supponiamo inoltre che stiamo cercando di capire la rappresentazione di 'disponibile'
>>> query_token = 'disponibile'

>>> # Inizieremo un insieme vuoto e riempiremo i token di nostro interesse man mano che procediamo in questa sezione.
>>> key_tokens = [] # => attualmente il token 'disponibile' non ha nulla a cui prestare attenzione

I token vicini dovrebbero essere importanti perché, in una frase (sequenza di parole), la parola corrente dipende fortemente dai token precedenti e successivi. Questa intuizione è l’idea alla base del concetto di sliding attention.

>>> # considerando `window_size = 3`, considereremo 1 token a sinistra e 1 a destra di 'disponibile'
>>> # token a sinistra: 'ora' ; token a destra: 'in'
>>> sliding_tokens = ["ora", "disponibile", "in"]

>>> # aggiorniamo la nostra collezione con i token sopra
>>> key_tokens.append(sliding_tokens)

Dependenze a lungo raggio: Per alcune attività, è cruciale catturare relazioni a lungo raggio tra i token. Ad esempio, nel `question-answering` il modello deve confrontare ogni token del contesto con l’intera domanda per poter capire quale parte del contesto è utile per una risposta corretta. Se la maggior parte dei token del contesto dovesse prestare attenzione solo ad altri token del contesto, ma non alla domanda, diventa molto più difficile per il modello filtrare i token di contesto importanti da quelli meno importanti.

Ora, BigBird propone due modi per consentire dipendenze di attenzione a lungo termine rimanendo computazionalmente efficienti.

  • Token globali: Introdurre alcuni token che presteranno attenzione a ogni token e che saranno prestati da ogni token. Ad esempio: “HuggingFace sta creando ottime librerie per un facile NLP”. Supponiamo ora che ‘creando’ sia definito come un token globale e che il modello debba conoscere la relazione tra ‘NLP’ e ‘HuggingFace’ per qualche compito (nota: questi 2 token sono agli estremi); aver fatto prestare attenzione a ‘creando’ a tutti gli altri token probabilmente aiuterà il modello ad associare ‘NLP’ a ‘HuggingFace’.
>>> # supponiamo che il primo e l'ultimo token siano `global`, quindi
>>> global_tokens = ["BigBird", "answering"]

>>> # riempiamo i token globali nella nostra collezione di token chiave
>>> key_tokens.append(global_tokens)
  • Token casuali: Selezionare alcuni token casualmente che trasferiranno informazioni trasferendosi ad altri token che a loro volta possono trasferirsi ad altri token. Questo può ridurre il costo del trasferimento delle informazioni da un token all’altro.
>>> # ora possiamo scegliere `r` token casualmente dalla nostra frase di esempio
>>> # scegliamo 'è' assumendo `r=1`
>>> random_tokens = ["è"] # Nota: viene scelto completamente casualmente; quindi potrebbe essere qualsiasi altra cosa.

>>> # riempiamo i token casuali nella nostra collezione
>>> key_tokens.append(random_tokens)

>>> # è ora di vedere quali token sono nella nostra lista `key_tokens`
>>> key_tokens
{'ora', 'è', 'in', 'answering', 'disponibile', 'BigBird'}

# In questo modo, il token di interrogazione presta attenzione solo a un sottoinsieme di tutti i possibili token, producendo un'approssimazione buona dell'attenzione completa. Lo stesso approccio viene utilizzato per tutti gli altri token interrogati. Ma ricorda, il punto principale qui è approssimare l'attenzione completa di BERT nel modo più efficiente possibile. Semplicemente fare sì che ogni token interrogato presti attenzione a tutti i token chiave, come viene fatto per BERT, può essere calcolato molto efficacemente come una sequenza di moltiplicazioni matriciali su hardware moderno, come le GPU. Tuttavia, una combinazione di attenzione scorrevole, globale e casuale sembra implicare una moltiplicazione di matrici sparse, che è più difficile da implementare in modo efficiente su hardware moderno. Uno dei principali contributi di BigBird è la proposta di un meccanismo di attenzione block sparse che consente di calcolare l'attenzione scorrevole, globale e casuale in modo efficace. Approfondiamolo!

Capire la necessità delle chiavi globali, scorrevoli e casuali con i Grafi

Prima di tutto, cerchiamo di capire meglio l'attenzione globale, scorrevole e casuale utilizzando i grafici e cerchiamo di capire come la combinazione di questi tre meccanismi di attenzione produce un'ottima approssimazione dell'attenzione standard Bert-like.

<p+Nella figura sopra si mostrano rispettivamente le connessioni globali (sinistra), scorrevoli (centro) e casuali (destra) come grafo. Ogni nodo corrisponde a un token e ogni linea rappresenta un punteggio di attenzione. Se non viene stabilita una connessione tra due token, allora si assume un punteggio di attenzione pari a 0.

L'attenzione sparso a blocchi BigBird è una combinazione di connessioni scorrevoli, globali e casuali (10 connessioni in totale) come mostrato nella gif a sinistra. Mentre un grafo di attenzione normale (a destra) avrà tutte e 15 le connessioni (nota: sono presenti in totale 6 nodi). Si può pensare semplicemente all'attenzione normale come tutti i token che partecipano globalmente 1 {}^1 1 .

Attenzione normale: Il modello può trasferire informazioni da un token a un altro direttamente in un singolo strato, poiché ogni token viene interrogato su ogni altro token ed è partecipato da ogni altro token. Consideriamo un esempio simile a quanto mostrato nelle figure precedenti. Se il modello ha bisogno di associare 'andare' con 'ora', può farlo semplicemente in un singolo strato poiché c'è una connessione diretta che unisce entrambi i token.

Attenzione sparso a blocchi: Se il modello ha bisogno di condividere informazioni tra due nodi (o token), le informazioni dovranno viaggiare attraverso vari altri nodi nel percorso per alcuni dei token; poiché tutti i nodi non sono direttamente collegati in un singolo strato. Ad esempio, supponendo che il modello debba associare 'andare' con 'ora', quindi se è presente solo l'attenzione scorrevole, il flusso di informazioni tra quei 2 token è definito dal percorso: andare -> sono -> io -> ora (cioè dovrà viaggiare su 2 altri token). Pertanto, potremmo avere bisogno di più strati per catturare l'intera informazione della sequenza. L'attenzione normale può catturare tutto questo in un singolo strato. In un caso estremo, ciò potrebbe significare che sono necessari tanti strati quanti sono i token di input. Tuttavia, se introduciamo alcuni token globali, le informazioni possono viaggiare tramite il percorso: andare -> io -> ora (che è più breve). Se inoltre introduciamo connessioni casuali, può viaggiare tramite: andare -> sono -> ora . Con l'aiuto delle connessioni casuali e delle connessioni globali, le informazioni possono viaggiare molto rapidamente (con solo pochi strati) da un token al successivo.

Nel caso in cui abbiamo molti token globali, allora potremmo non avere bisogno di connessioni casuali poiché ci saranno percorsi brevi multipli attraverso i quali le informazioni possono viaggiare. Questa è l'idea alla base del mantenimento di num_random_tokens = 0 quando si lavora con una variante di BigBird, chiamata ETC (ne parleremo nelle sezioni successive).

1 {}^1 1 In queste grafiche, si assume che la matrice di attenzione sia simmetrica, ovvero A i j = A j i \mathbf{A}_{ij} = \mathbf{A}_{ji} A i j ​ = A j i ​ poiché in un grafo se un token A partecipa a B , allora B parteciperà anche ad A . Si può vedere dalla figura della matrice di attenzione mostrata nella prossima sezione che questa assunzione vale per la maggior parte dei token in BigBird

original_full rappresenta l'attenzione di BERT mentre block_sparse rappresenta l'attenzione di BigBird. Vi state chiedendo cosa sia block_size? Lo affronteremo nelle sezioni successive. Per ora, consideratelo pari a 1 per semplicità

Attenzione sparsa a blocchi di BigBird

L'attenzione sparsa a blocchi di BigBird è semplicemente un'implementazione efficiente di quanto discusso in precedenza. Ogni token si concentra su alcuni token globali, token scorrevoli e token casuali anziché su tutti gli altri token. Gli autori hanno codificato separatamente la matrice di attenzione per componenti di query multiple; e hanno utilizzato un trucco interessante per velocizzare l'addestramento/inferenza su GPU e TPU.

Nota: in alto abbiamo 2 frasi extra. Come puoi notare, ogni token è semplicemente spostato di una posizione in entrambe le frasi. Questo è come è implementata l'attenzione scorrevole. Quando q[i] viene moltiplicato per k[i,0:3], otterremo un punteggio di attenzione scorrevole per q[i] (dove i è l'indice dell'elemento nella sequenza).

Puoi trovare l'implementazione effettiva dell'attenzione block_sparse qui . Questo potrebbe sembrare molto spaventoso 😨😨 ora. Ma questo articolo sicuramente renderà più facile la comprensione del codice.

Attenzione globale

Per l'attenzione globale, ogni query si concentra semplicemente su tutti gli altri token della sequenza ed è presa in considerazione da ogni altro token. Supponiamo che Vasudev (primo token) e them (ultimo token) siano globali (nella figura precedente). Puoi vedere che questi token sono direttamente collegati a tutti gli altri token (scatole blu).

# pseudo codice

Q -> Matrice di query (lunghezza_seq, dim_testa)
K -> Matrice chiave (lunghezza_seq, dim_testa)

# Il primo e l'ultimo token si concentrano su tutti gli altri token
Q[0] x [K[0], K[1], K[2], ......, K[n-1]]
Q[n-1] x [K[0], K[1], K[2], ......, K[n-1]]

# Il primo e l'ultimo token vengono presi in considerazione da tutti gli altri token
K[0] x [Q[0], Q[1], Q[2], ......, Q[n-1]]
K[n-1] x [Q[0], Q[1], Q[2], ......, Q[n-1]]

Attenzione scorrevole

La sequenza di token chiave viene copiata 2 volte, con ogni elemento spostato a destra in una delle copie e a sinistra nell'altra copia. Ora, se moltiplichiamo i vettori di sequenza di query per questi 3 vettori di sequenza, copriremo tutti i token scorrevoli. La complessità computazionale è semplicemente O(3xn) = O(n). Riferendoci all'immagine precedente, le caselle arancioni rappresentano l'attenzione scorrevole. Puoi vedere 3 sequenze nella parte superiore dell'immagine, di cui 2 spostate di un token (1 a sinistra, 1 a destra).

# cosa vogliamo fare
Q[i] x [K[i-1], K[i], K[i+1]] per i = 1:-1

# implementazione efficiente nel codice (supponiamo una moltiplicazione per prodotto scalare 👇)
[Q[0], Q[1], Q[2], ......, Q[n-2], Q[n-1]] x [K[1], K[2], K[3], ......, K[n-1], K[0]]
[Q[0], Q[1], Q[2], ......, Q[n-1]] x [K[n-1], K[0], K[1], ......, K[n-2]]
[Q[0], Q[1], Q[2], ......, Q[n-1]] x [K[0], K[1], K[2], ......, K[n-1]]

# Ogni sequenza viene moltiplicata solo per 3 sequenze per mantenere `window_size = 3`.
# Alcuni calcoli potrebbero mancare; questa è solo un'idea approssimativa.

Attenzione casuale

L'attenzione casuale garantisce che ogni token di query si concentri anche su alcuni token casuali. Per l'implementazione effettiva, ciò significa che il modello raccoglie alcuni token casualmente e calcola il loro punteggio di attenzione.

# r1, r2, r sono alcuni indici casuali; Nota: r1, r2, r3 sono diversi per ogni riga 👇
Q[1] x [Q[r1], Q[r2], ......, Q[r]]
.
.
.
Q[n-2] x [Q[r1], Q[r2], ......, Q[r]]

# lasciando il token 0 e il token (n-1) poiché sono già globali

Nota: L'implementazione attuale divide ulteriormente la sequenza in blocchi e ogni notazione è definita rispetto al blocco anziché ai token. Discutiamone in dettaglio nella prossima sezione.

Implementazione

Ricapitolando: Nell'attenzione regolare di BERT, una sequenza di token, cioè X = x 1 , x 2 , . . . . , x n, viene proiettata attraverso uno strato denso in Q , K , V e il punteggio di attenzione Z viene calcolato come Z = S o f t m a x ( Q K T ) . Nel caso dell'attenzione sparso a blocchi di BigBird, viene utilizzato lo stesso algoritmo, ma solo con alcuni vettori di query e chiave selezionati.

Analizziamo come è implementata l'attenzione sparso a blocchi di BigBird. Per cominciare, assumiamo che b , r , s , g rappresentino la block_size, la num_random_blocks, la num_sliding_blocks e la num_global_blocks rispettivamente. Visivamente, possiamo illustrare i componenti dell'attenzione sparso a blocchi di BigBird con b = 4 , r = 1 , g = 2 , s = 3 , d = 5 come segue:

I punteggi di attenzione per q 1 , q 2 , q 3 : n − 2 , q n − 1 , q n vengono calcolati separatamente come descritto di seguito:


Il punteggio di attenzione per q 1 rappresentato da a 1 , dove a 1 = S o f t m a x ( q 1 ∗ K T ) , è semplicemente il punteggio di attenzione tra tutti i token nel primo blocco con tutti gli altri token nella sequenza.

q 1 rappresenta il primo blocco, g i rappresenta l'i-esimo blocco. Stiamo semplicemente eseguendo un'operazione di attenzione normale tra q 1 e g (cioè tutte le chiavi).


Per calcolare il punteggio di attenzione per i token nel secondo blocco, raccogliamo i primi tre blocchi, l'ultimo blocco e il quinto blocco. Quindi possiamo calcolare a 2 = S o f t m a x ( q 2 ∗ c o n c a t ( k 1 , k 2 , k 3 , k 5 , k 7 ) .

Rappresento i token con g, r, s solo per mostrare esplicitamente la loro natura (cioè mostrando token globali, casuali, scorrevoli), altrimenti sono solo k.


Per calcolare il punteggio di attenzione per q 3 : n − 2, raccoglieremo le chiavi globali, scorrevoli e casuali e calcoleremo l'operazione di attenzione normale su q 3 : n − 2 e le chiavi raccolte. Nota che le chiavi scorrevoli vengono raccolte utilizzando il trucco speciale dello scorrimento discusso in precedenza nella sezione di attenzione scorrevole.


Per calcolare il punteggio di attenzione per i token nel blocco precedente all'ultimo (cioè q n − 1 {q}_{n-1} q n − 1 ​), stiamo raccogliendo il primo blocco, gli ultimi tre blocchi e il terzo blocco. Quindi possiamo applicare la formula a n − 1 = S o f t m a x ( q n − 1 ∗ c o n c a t ( k 1 , k 3 , k 5 , k 6 , k 7 ) ) {a}_{n-1} = Softmax({q}_{n-1} * concat(k_1, k_3, k_5, k_6, k_7)) a n − 1 ​ = S o f t m a x ( q n − 1 ​ ∗ c o n c a t ( k 1 ​ , k 3 ​ , k 5 ​ , k 6 ​ , k 7 ​ ) ). Questo è molto simile a quello che abbiamo fatto per q 2 q_2 q 2 ​.


Il punteggio di attenzione per q n \mathbf{q}_{n} q n ​ è rappresentato da a n a_n a n ​ dove a n = S o f t m a x ( q n ∗ K T ) a_n=Softmax(q_n * K^T) a n ​ = S o f t m a x ( q n ​ ∗ K T ) , e non è altro che il punteggio di attenzione tra tutti i token nell'ultimo blocco con tutti gli altri token nella sequenza. Questo è molto simile a quello che abbiamo fatto per q 1 q_1 q 1 ​.


Uniamo le matrici sopra per ottenere la matrice di attenzione finale. Questa matrice di attenzione può essere utilizzata per ottenere una rappresentazione di tutti i token.

blu -> blocchi globali, rosso -> blocchi casuali, arancione -> blocchi scorrevoli Questa matrice di attenzione è solo a scopo illustrativo. Durante il passaggio in avanti, non memorizziamo i blocchi bianchi, ma calcoliamo direttamente una matrice di valori ponderata (ovvero la rappresentazione di ogni token) per ciascun componente separato come discusso in precedenza.

Ora abbiamo affrontato la parte più difficile dell'attenzione sparsa a blocchi, ovvero la sua implementazione. Speriamo che ora tu abbia una migliore comprensione del codice effettivo. Sentiti libero di analizzarlo e collegare ogni parte del codice a uno dei componenti sopra menzionati.

Complessità temporale e di memoria

Confronto della complessità temporale e spaziale dell'attenzione di BERT e dell'attenzione sparsa a blocchi di BigBird.

Espandi questo snippet nel caso in cui desideri vedere i calcoli

Complessità temporale di BigBird = O(w x n + r x n + g x n)
Complessità temporale di BERT = O(n^2)

Ipotesi:
    w = 3 x 64
    r = 3 x 64
    g = 2 x 64

Quando seqlen = 512
=> complessità temporale di BERT = 512^2

Quando seqlen = 1024
=> complessità temporale di BERT = (2 x 512)^2
=> complessità temporale di BERT = 4 x 512^2

=> complessità temporale di BigBird = (8 x 64) x (2 x 512)
=> complessità temporale di BigBird = 2 x 512^2

Quando seqlen = 4096
=> complessità temporale di BERT = (8 x 512)^2
=> complessità temporale di BERT = 64 x 512^2

=> complessità di calcolo in BigBird = (8 x 64) x (8 x 512)
=> complessità di calcolo in BigBird = 8 x (512 x 512)
=> complessità temporale di BigBird = 8 x 512^2

ITC vs ETC

Il modello BigBird può essere addestrato utilizzando 2 diverse strategie: ITC e ETC. ITC (internal transformer construction) è semplicemente ciò di cui abbiamo discusso in precedenza. In ETC (extended transformer construction), alcuni token aggiuntivi vengono resi globali in modo che partecipino a / siano coinvolti da tutti i token.

ITC richiede meno calcolo poiché pochi token sono globali, ma allo stesso tempo il modello può catturare informazioni globali sufficienti (anche con l'aiuto dell'attenzione casuale). D'altra parte, ETC può essere molto utile per compiti in cui sono necessari molti token globali, come la `domanda-risposta, per la quale l'intera domanda deve essere considerata globalmente dal contesto per poter correlare correttamente il contesto alla domanda.

Nota: Nel paper Big Bird è mostrato che in molti esperimenti ETC, il numero di blocchi casuali è impostato su 0. Questo è ragionevole date le nostre discussioni sopra nella sezione del grafo.

La tabella qui sotto riassume ITC & ETC:

Usare BigBird con 🤗Transformers

Puoi utilizzare BigBirdModel come qualsiasi altro modello 🤗. Vediamo del codice di seguito:

from transformers import BigBirdModel

# caricamento di bigbird dal suo checkpoint preaddestrato
model = BigBirdModel.from_pretrained("google/bigbird-roberta-base")
# Questo inizializzerà il modello con la configurazione predefinita, ovvero attention_type = "block_sparse", num_random_blocks = 3, block_size = 64.
# Ma puoi liberamente cambiare questi argomenti con qualsiasi checkpoint. Questi 3 argomenti cambieranno solo il numero di token a cui ogni token di query parteciperà.
model = BigBirdModel.from_pretrained("google/bigbird-roberta-base", num_random_blocks=2, block_size=16)

# Impostando attention_type su `original_full`, BigBird si baserà sull'attenzione completa di complessità n^2. In questo modo, BigBird sarà simile al 99,9% a BERT.
model = BigBirdModel.from_pretrained("google/bigbird-roberta-base", attention_type="original_full")

Ci sono in totale 3 checkpoint disponibili in 🤗Hub (al momento della stesura di questo articolo): bigbird-roberta-base, bigbird-roberta-large, bigbird-base-trivia-itc. I primi due checkpoint provengono dal preaddestramento di BigBirdForPretraining con perdita di masked_lm, mentre l'ultimo corrisponde al checkpoint dopo il fine-tuning di BigBirdForQuestionAnswering sul dataset trivia-qa.

Diamo un'occhiata al codice minimo che puoi scrivere (nel caso in cui desideri utilizzare il tuo trainer PyTorch) per utilizzare il modello BigBird di 🤗 per il fine-tuning dei tuoi compiti.

# prendiamo in considerazione il nostro compito di domanda-risposta come esempio

from transformers import BigBirdForQuestionAnswering, BigBirdTokenizer
import torch

device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")

# inizializziamo il modello bigbird dai pesi preaddestrati con una testa inizializzata casualmente in cima
model = BigBirdForQuestionAnswering.from_pretrained("google/bigbird-roberta-base", block_size=64, num_random_blocks=3)
tokenizer = BigBirdTokenizer.from_pretrained("google/bigbird-roberta-base")
model.to(device)

dataset = "oggetto torch.utils.data.DataLoader"
optimizer = "oggetto torch.optim"
epochs = ...

# ciclo di addestramento molto minimale
for e in range(epochs):
    for batch in dataset:
        model.train()
        batch = {k: batch[k].to(device) for k in batch}

        # passaggio in avanti
        output = model(**batch)

        # retropropagazione
        output["loss"].backward()
        optimizer.step()
        optimizer.zero_grad()

# salviamo i pesi finali in una directory locale
model.save_pretrained("<TUA-CART-DEI-PESI>")

# pubblichiamo i nostri pesi su 🤗Hub
from huggingface_hub import ModelHubMixin
ModelHubMixin.push_to_hub("<TUA-CART-DEI-PESI>", model_id="<TUA-ID-FINE-TUNING>")

# utilizzando il modello fine-tuned per l'infereza
domanda = ["Come stai?", "Come va la vita?"]
contesto = ["<un grande contesto con risposta-1>", "<un grande contesto con risposta-2>"]
batch = tokenizer(domanda, contesto, return_tensors="pt")
batch = {k: batch[k].to(device) for k in batch}

model = BigBirdForQuestionAnswering.from_pretrained("<TUA-ID-FINE-TUNING>")
model.to(device)
with torch.no_grad():
    start_logits, end_logits = model(**batch).to_tuple()
    # decodifica start_logits, end_logits con la strategia che preferisci.

# Nota:
# Questo era un codice molto minimale (nel caso in cui si desideri utilizzare PyTorch grezzo) solo per mostrare come BigBird può essere utilizzato molto facilmente
# Suggerirei di utilizzare 🤗Trainer per avere accesso a molte funzionalità

È importante tenere presente i seguenti punti durante il lavoro con Big Bird:

  • La lunghezza della sequenza deve essere un multiplo della dimensione del blocco, ovvero seqlen % block_size = 0. Non devi preoccuparti perché 🤗Transformers eseguirà automaticamente il <pad> (al multiplo più piccolo della dimensione del blocco che è maggiore della lunghezza della sequenza) se la lunghezza della sequenza del batch non è un multiplo di block_size.
  • Attualmente, la versione di HuggingFace non supporta ETC e quindi solo il primo e l'ultimo blocco saranno globali.
  • L'implementazione attuale non supporta num_random_blocks = 0.
  • Gli autori raccomandano di impostare attention_type = "original_full" quando la lunghezza della sequenza < 1024.
  • Deve valere: seq_length > global_token + random_tokens + sliding_tokens + buffer_tokens dove global_tokens = 2 x block_size, sliding_tokens = 3 x block_size, random_tokens = num_random_blocks x block_size e buffer_tokens = num_random_blocks x block_size. Nel caso in cui non riesci a farlo, 🤗Transformers passerà automaticamente attention_type a original_full con un avviso.
  • Quando si utilizza big bird come decoder (o si utilizza BigBirdForCasualLM), attention_type dovrebbe essere original_full. Ma non devi preoccuparti, 🤗Transformers passerà automaticamente attention_type a original_full nel caso in cui dimentichi di farlo.

Qual è il prossimo passo?

@patrickvonplaten ha realizzato un notebook davvero interessante su come valutare BigBirdForQuestionAnswering sul dataset trivia-qa. Sentiti libero di utilizzare BigBird utilizzando quel notebook.

Presto troverai il modello BigBird simile a Pegasus nella libreria per la riassunzione di documenti lunghi 💥.

Note finali

L'implementazione originale della matrice di attenzione sparsa a blocchi può essere trovata qui. Puoi trovare la versione di 🤗 qui.