Backpropagation e il problema del gradiente che si annulla in RNN (Parte 2)
Backpropagation e il problema del gradiente in RNN (Parte 2)
Come viene ridotto in LSTM
Nella prima parte di questa serie, abbiamo esaminato la retropropagazione in un modello RNN e spiegato sia con le formule che mostrando numeri il problema del gradiente che si annulla in RNN. In questo articolo, spiegheremo come possiamo risolvere parzialmente il problema del gradiente che si annulla con LSTM, anche se non scompare completamente e con sequenze molto lunghe il problema persiste ancora.
Motivazione
Come abbiamo visto nella prima parte di questa serie, l’RNN vanilla memorizza le informazioni temporali nello stato nascosto che viene aggiornato ad ogni passo temporale quando vengono aggiunte nuove informazioni, ossia quando viene elaborato un nuovo token in una sequenza. Poiché lo stato nascosto viene aggiornato ad ogni passo, le vecchie informazioni vengono sovrascritte e la rete dimentica ciò che ha visto in passato. Per evitare ciò, abbiamo bisogno di una memoria separata e di un meccanismo che decida cosa scrivere al suo interno, dato che nuove informazioni, cosa eliminare dal passato che non sarà utile in futuro e cosa passare allo stato successivo. LSTM fa esattamente questo: aggiunge una cella di memoria che memorizza informazioni a lungo termine e ha un meccanismo di gating che viene utilizzato per decidere cosa dimenticare dal passato, cosa aggiungere dall’input corrente e cosa passare avanti.
Propagazione in avanti

Vediamo come viene eseguita la propagazione in avanti attraverso il tempo in un modello LSTM. Dato una sequenza di N token e assumendo di aver ricevuto una cella di memoria c(t-1) e uno stato nascosto h(t-1) dalla cella precedente, al passo temporale t calcoliamo i gating per decidere cosa fare con le nuove informazioni in arrivo. Innanzitutto, calcoliamo le attivazioni:
- Creazione di visualizzazioni interattive dei dati in Python Un’introduzione a Plotly
- Come l’IA generativa può supportare le aziende del settore alimentare
- Python Lists La Guida Definitiva per Lavorare con Collezioni di Dati Ordinate

Ricordiamo che tutti i pesi sono condivisi tra i passi temporali. La matrice delle attivazioni viene quindi divisa in 4 matrici, ognuna di dimensione H, e applicando una funzione di attivazione sigmoide alle prime 3 e tangente iperbolica all’ultima, calcoliamo i gating:


Notare come tutti i gating siano funzioni dell’input e dello stato nascosto precedente.
Infine, calcoliamo la cella di memoria corrente c(t) e lo stato nascosto h(t) che verranno passati al passo successivo.

I valori dei gating calcolati hanno le seguenti funzionalità:
- gate f: cosa dimenticare dalla precedente cella di memoria c(t-1). Notare che poiché facciamo una moltiplicazione elemento per elemento (ricordiamo che c(t-1) e h(t-1) sono vettori) e f contiene valori compresi tra 0 e 1 a causa della funzione di attivazione sigmoide, annullerà o ridurrà le informazioni in c(t-1) quando i valori di f saranno uguali o più vicini a 0 e manterrà tutte o quasi tutte le informazioni quando i valori di f saranno uguali o vicini a 1.
- gate g: può essere interpretato come il vettore di aggiornamento della cella di memoria che viene combinato con la precedente cella di memoria c(t-1) per calcolare la nuova cella di memoria c(t). Diversamente dagli altri gating, alla funzione di attivazione a(g) viene applicata una funzione tangente iperbolica che restituisce un valore compreso tra -1 e 1. Ciò serve per consentire alla cella di memoria di aumentare e diminuire, poiché se avessimo una funzione di attivazione sigmoide, gli elementi della cella di memoria non potrebbero mai diminuire.
- gate i: cosa scrivere dal vettore di aggiornamento della cella di memoria (gate g) alla precedente cella di memoria c(t-1).
- gate o: cosa includere nel nuovo stato nascosto h(t)
Questi gate vengono quindi combinati, come illustrato nella Figura 4, per calcolare la nuova cella di memoria c(t) e lo stato nascosto h(t). Queste nuove celle e lo stato nascosto vengono quindi passati alla successiva cella LSTM che ripete lo stesso processo nuovamente. Tutto questo processo può essere illustrato nel diagramma sottostante:

Dopo di ciò, per ogni stato nascosto, calcoliamo l’output e la perdita:

Nel codice:
def softmax(x, axis=2): p = np.exp(x - np.max(x, axis=axis,keepdims=True)) return p / np.sum(p, axis=axis, keepdims=True)def lstm_step_forward(x, prev_h, prev_c, Wx, Wh, b): next_h, next_c, cache = None, None, None h = x @ Wx + prev_h @ Wh + b assert h.shape[-1] % 4 == 0 ai, af, ao, ag = np.array_split(h, 4, axis=-1) i = sigmoid(ai) f = sigmoid(af) o = sigmoid(ao) g = np.tanh(ag) next_c = f * prev_c + i * g next_h = o * np.tanh(next_c) cache = (x, next_h, prev_h, prev_c, Wx, Wh, h, np.tanh(next_c), i, f, o ,g) return next_h, next_c, cachenp.random.seed(232)# N - Dimensione batch# D - Dimensione dell'embedding# V - Dimensione del vocabolario# H - Dimensione nascosta# T - timestepsN, D, T, H, V = 2, 5, 3, 4, 4x = np.random.randn(N, T, D)h0 = np.random.randn(N, H)Wx = np.random.randn(D, H)Wh = np.random.randn(H, H)Wy = np.random.randn(H, V)b = np.random.randn(H)y = np.random.randint(V, size=(N, T))mask = np.ones((N, T))all_cache = []h = np.zeros((N, T, H)) next_c = np.zeros((N, H)) for t in range(T): xt = x[:, t , :] if t == 0: next_h, next_c, cache_s = lstm_step_forward(xt, h0, next_c, Wx, Wh, b) all_cache.append(cache_s) else: next_h, next_c, cache_s = lstm_step_forward(xt, next_h, next_c, Wx, Wh, b) all_cache.append(cache_s) h[:, t, :] = next_h ft = h @ Wyout = softmax(ft)
Backpropagation

Le formule per la retropropagazione sono un po’ più complesse rispetto a quelle della RNN standard. In questo tutorial, andremo a derivare i gradienti rispetto a Wx per poi mostrare come LSTM gestisce i gradienti che svaniscono. Le derivate rispetto agli altri parametri possono essere analogamente derivate ed è lasciato come esercizio al lettore. Il codice, tuttavia, contiene le derivate rispetto a tutti i gradienti e è possibile verificare i risultati basandosi sul codice. La derivata della perdita rispetto allo stato nascosto è ancora la stessa della RNN in quanto nulla cambia lì, poiché la perdita prende solo lo stato nascosto come input:

Ora troviamo le derivate rispetto ad altri singoli componenti:

Si noti che per comodità abbiamo separato dct/dat e dht/dat, e ovunque abbiamo dht/dct dct/dat lo scriviamo direttamente come dht/dat. Inoltre, poiché faremo la retropropagazione nella forma matriciale, concateniamo le derivate dei gate nel seguente modo:

La somma in dht/dat deriva dal fatto che abbiamo due direzioni (vedi Figura 7) – una che va nella cella precedente e l’altra che va nello stato nascosto. Con la stessa logica del flusso del gradiente, la derivata di dct/dc(t-1) è la seguente:

Ora deriviamo il gradiente totale rispetto a Wx. Questo è dato dalla somma delle singole perdite rispetto a Wx come descritto nella parte 1 di questa serie:

Concentrandoci sulla perdita individuale, ad esempio dL3/dWx, quando propaghiamo da L3 a Wx, Wx appare in tutti i componenti degli istanti di tempo, quindi dovremo sommare tutti questi componenti per ottenere il gradiente completo di L3 rispetto a Wx. Abusando leggermente la notazione matematica, stiamo facendo qualcosa del genere (ricorda che Wx3 = Wx2 = Wx1):

Il primo componente sarà come segue. Inoltre, sostituiamo dht/dct dct/dat con dht/dat in modo da poter usare direttamente quella derivata

Salto dL3/dWx2 per brevità e passo direttamente al terzo componente. Abbiamo:

Come in precedenza, sostituiamo ovunque abbiamo dht/dct dct/dat con dht/dat in modo da poter usare direttamente quella derivata:

Sommandoli, otteniamo la derivata di dL3/dWx. Per ottenere la derivata di dWx rispetto alla perdita totale, dovremo aggiungere a dL3/dWx, dL2/dWx e dL1/dWx.

Nel codice:
def lstm_forward(x, h0, Wx, Wh, b, next_c=None): h, cache = None, None cache = [] N, T, _ = x.shape H = h0.shape[-1] h = np.zeros((N, T, H)) if next_c is None: next_c = np.zeros((N, H)) for t in range(x.shape[1]): xt = x[:, t , :] if t == 0: next_h, next_c, cache_s = lstm_step_forward(xt, h0, next_c, Wx, Wh, b) cache.append(cache_s) else: next_h, next_c, cache_s = lstm_step_forward(xt, next_h, next_c, Wx, Wh, b) cache.append(cache_s) h[:, t, :] = next_h return h, cachedef dc_da(h, prev_c, next_c_t, i, f, o, g): dgrad_c = np.zeros((h.shape[0], 4 * h.shape[1])) dgrad_h = np.zeros((h.shape[0], 4 * h.shape[1])) # assert dgrad.shape[1] % 4 == 0 H = dgrad.shape[1] // 4 # compute gradients wrt ai, af, ao and ag from two flows - next_h and next_c dnextc_dai = (i * (1-i)) * g dnextc_daf = (f * (1-f)) * prev_c dnextc_dao = 0 dnextc_dag = (1 - g**2) * i dh_dc = o * (1 - next_c_t**2) dnexth_dai = dh_dc * dnextc_dai dnexth_daf = dh_dc * dnextc_daf dnexth_dao = (o * (1-o) * next_c_t) dnexth_dag = dh_dc * dnextc_dag # join them together in a matrix at this point to conveniently compute # downstream gradients dgrad_c[:, 0:H] = dnextc_dai dgrad_c[:, H:2*H] = dnextc_daf dgrad_c[:, 2*H:3*H] = dnextc_dao dgrad_c[:, 3*H:4*H] = dnextc_dag dgrad_h[:, 0:H] = dnexth_dai dgrad_h[:, H:2*H] = dnexth_daf dgrad_h[:, 2*H:3*H] = dnexth_dao dgrad_h[:, 3*H:4*H] = dnexth_dag return dgrad_c, dgrad_hnp.random.seed(1)N, D, T, H = 1, 3, 3, 1x = np.random.randn(N, T, D)h0 = np.random.randn(N, H)Wx = np.random.randn(D, 4 * H)Wh = np.random.randn(H, 4 * H)b = np.random.randn(4 * H)out, cache = lstm_forward(x, h0, Wx, Wh, b)# let's define the dout instead of deriving them for simplicitydout = np.random.randn(*out.shape) # dL3/dWvxdnext_c2 = np.zeros((h0.shape))dnext_h2 = dout[:, -1, :](x2, next_h, prev_h, prev_c, Wx, Wh, next_h, next_c_t2, i2, f2, o2 ,g2) = cache[2]dgrad_c2, dgrad_h2 = dc_da(h0, cache[2][3], cache[2][-5], cache[2][-4], cache[2][-3], cache[2][-2], cache[2][-1]) dL3_dWx2 = x2.T @ (dgrad_h2 * dnext_h2 + dgrad_c2 * dnext_c2)print(dL3_dWx2)dnext_c1 = dnext_c2 * f2 + dnext_h2 * o2 * (1 - next_c_t2**2) * f2dnext_h1 = (dnext_h2 * dgrad_h2 + dnext_c2 * dgrad_c2) @ Wh.T(x1, next_h, prev_h, prev_c, Wx, Wh, next_h, next_c_t1, i1, f1, o1 ,g1) = cache[1]dgrad_c1, dgrad_h1 = dc_da(h0, cache[1][3], cache[1][-5], cache[1][-4], cache[1][-3], cache[1][-2], cache[1][-1]) dL3_dWx1 = x1.T @ (dnext_c1 * dgrad_c1 + dnext_h1 * dgrad_h1)print(dL3_dWx1)dnext_c0 = dnext_c1 * f1 + dnext_h1 * o1 * (1 - next_c_t1**2) * f1dnext_h0 = (dnext_h1 * dgrad_h1 + dnext_c1 * dgrad_c1) @ Wh.T(x0, next_h, prev_h, prev_c, Wx, Wh, next_h, next_c_t0, i0, f0, o0 ,g0) = cache[0]dgrad_c0, dgrad_h0 = dc_da(h0, cache[0][3], cache[0][-5], cache[0][-4], cache[0][-3], cache[0][-2], cache[0][-1]) dL3_dWx0 = x0.T @ (dnext_c0 * dgrad_c0 + dnext_h0 * dgrad_h0)print(dL3_dWx0)
Output:
[[-0.02349287 0.00135057 -0.11156069 -0.05284914] [ 0.01024921 -0.00058921 0.04867045 0.02305643] [-0.00429567 0.00024695 -0.02039889 -0.00966347]][[-9.83990139e-03 6.78775168e-05 -1.10660923e-03 4.20773125e-04] [ 7.93641636e-03 -5.47469140e-05 8.92540613e-04 -3.39376441e-04] [-2.11067811e-02 1.45598602e-04 -2.37369846e-03 9.02566589e-04]][[-1.95768961e-05 0.00000000e+00 2.77411349e-05 -9.76467796e-03] [ 7.37299593e-06 0.00000000e+00 -1.04477887e-05 3.67754574e-03] [ 6.36561888e-06 0.00000000e+00 -9.02030083e-06 3.17508036e-03]]
losses_dWx = {i : {x_comp : 0 for x_comp in range(i)} for i in range(T)}dWx = np.zeros((D, 4 * H))dWh = np.zeros((H, 4 * H))db = np.zeros((4 * H, ))for idx in range(T-1, -1, -1): print(f"Perdita {idx + 1}") dnext_c = np.zeros((h0.shape)) dnext_h = dout[:, idx, :] for j in range(idx, -1, -1): (x, next_h, prev_h, prev_c, Wx, Wh, next_h, next_c_t, i, f, o ,g) = cache[j] dgrad_c, dgrad_h = dc_da(h0, prev_c, next_c_t, i, f, o, g) dgrad = dnext_c * dgrad_c + dnext_h * dgrad_h losses_dWx[idx][j] = x.T @ dgrad dnext_c = dnext_c * f + dnext_h * o * (1 - next_c_t**2) * f dnext_h = (dnext_h * dgrad_h + dnext_c * dgrad_c) @ Wh.T dnext_h = dgrad @ Wh.T # accumula il gradiente di dWx e altri parametri per ogni perdita dWx += x.T @ dgrad dWh += prev_h.T @ dgrad db += dgrad.sum(0) print(f"componente {j} - ", np.linalg.norm(losses_dWx[idx][j]))
Gradiente che scompare in LSTM
Come nella parte 1 per RNN, vediamo i gradienti per la Perdita L3 per ogni componente:
Perdita 3componente 0 - 0.010906688399113558componente 1 - 0.02478099846737857componente 2 - 0.13901933055672275
Dal precedente, possiamo vedere che X3, che è il più vicino a L3, ha ancora l’aggiornamento più grande, mentre X1 e X2 contribuiscono meno all’aggiornamento di Wx1. Per RNN questa differenza è molto più grande. Infatti, il gradiente che passa attraverso lo stato nascosto subirà il gradiente che scompare per lo stesso motivo di RNN – i termini Wh (dat/dh(t-1)) appaiono ancora nella retropropagazione, ad esempio qui in dL3/dW(x-1):

Tuttavia, il gradiente che fluisce attraverso la cella che è ancora una funzione dell’input e dello stato nascosto non ha i termini Wh ma i termini sigmoide invece (vedere la formula per il gate dimenticato ft in Figura 3):

Ricordiamo che dct/dc(t-1) = ft. Quindi, se il gate di dimenticanza è alto, cioè vicino a 1, allora il problema del gradiente che tende a svanire avviene a un ritmo molto più lento rispetto alla RNN standard, ma avverrà comunque a meno che tutti i gate di dimenticanza siano esattamente 1, cosa che non accade nella pratica.
Conclusioni
Il punto principale di questo articolo era capire, attraverso la derivazione della retropropagazione, che LSTM soffre ancora del problema del gradiente che tende a svanire nella pratica, tuttavia a un ritmo molto più basso rispetto alla RNN standard grazie allo stato della cella, che fa decadere il gradiente al ritmo del gate di dimenticanza anziché al ritmo di Wx. Se trovi degli errori, per favore fammelo sapere nei commenti.
Riferimenti
- https://web.stanford.edu/class/cs224n/slides/cs224n-2021-lecture06-fancy-rnn.pdf
- http://cs231n.stanford.edu/assignments.html