Backpropagation e il problema del gradiente che si annulla in RNN (Parte 2)

Backpropagation e il problema del gradiente in RNN (Parte 2)

Come viene ridotto in LSTM

https://unsplash.com/photos/B22I8wnon34

Nella prima parte di questa serie, abbiamo esaminato la retropropagazione in un modello RNN e spiegato sia con le formule che mostrando numeri il problema del gradiente che si annulla in RNN. In questo articolo, spiegheremo come possiamo risolvere parzialmente il problema del gradiente che si annulla con LSTM, anche se non scompare completamente e con sequenze molto lunghe il problema persiste ancora.

Motivazione

Come abbiamo visto nella prima parte di questa serie, l’RNN vanilla memorizza le informazioni temporali nello stato nascosto che viene aggiornato ad ogni passo temporale quando vengono aggiunte nuove informazioni, ossia quando viene elaborato un nuovo token in una sequenza. Poiché lo stato nascosto viene aggiornato ad ogni passo, le vecchie informazioni vengono sovrascritte e la rete dimentica ciò che ha visto in passato. Per evitare ciò, abbiamo bisogno di una memoria separata e di un meccanismo che decida cosa scrivere al suo interno, dato che nuove informazioni, cosa eliminare dal passato che non sarà utile in futuro e cosa passare allo stato successivo. LSTM fa esattamente questo: aggiunge una cella di memoria che memorizza informazioni a lungo termine e ha un meccanismo di gating che viene utilizzato per decidere cosa dimenticare dal passato, cosa aggiungere dall’input corrente e cosa passare avanti.

Propagazione in avanti

Figura dell'autore (0)

Vediamo come viene eseguita la propagazione in avanti attraverso il tempo in un modello LSTM. Dato una sequenza di N token e assumendo di aver ricevuto una cella di memoria c(t-1) e uno stato nascosto h(t-1) dalla cella precedente, al passo temporale t calcoliamo i gating per decidere cosa fare con le nuove informazioni in arrivo. Innanzitutto, calcoliamo le attivazioni:

Figura dell'autore (1)

Ricordiamo che tutti i pesi sono condivisi tra i passi temporali. La matrice delle attivazioni viene quindi divisa in 4 matrici, ognuna di dimensione H, e applicando una funzione di attivazione sigmoide alle prime 3 e tangente iperbolica all’ultima, calcoliamo i gating:

Figura dell'autore (2)
Figura dell'autore (3)

Notare come tutti i gating siano funzioni dell’input e dello stato nascosto precedente.

Infine, calcoliamo la cella di memoria corrente c(t) e lo stato nascosto h(t) che verranno passati al passo successivo.

Figura dell'autore (4)

I valori dei gating calcolati hanno le seguenti funzionalità:

  • gate f: cosa dimenticare dalla precedente cella di memoria c(t-1). Notare che poiché facciamo una moltiplicazione elemento per elemento (ricordiamo che c(t-1) e h(t-1) sono vettori) e f contiene valori compresi tra 0 e 1 a causa della funzione di attivazione sigmoide, annullerà o ridurrà le informazioni in c(t-1) quando i valori di f saranno uguali o più vicini a 0 e manterrà tutte o quasi tutte le informazioni quando i valori di f saranno uguali o vicini a 1.
  • gate g: può essere interpretato come il vettore di aggiornamento della cella di memoria che viene combinato con la precedente cella di memoria c(t-1) per calcolare la nuova cella di memoria c(t). Diversamente dagli altri gating, alla funzione di attivazione a(g) viene applicata una funzione tangente iperbolica che restituisce un valore compreso tra -1 e 1. Ciò serve per consentire alla cella di memoria di aumentare e diminuire, poiché se avessimo una funzione di attivazione sigmoide, gli elementi della cella di memoria non potrebbero mai diminuire.
  • gate i: cosa scrivere dal vettore di aggiornamento della cella di memoria (gate g) alla precedente cella di memoria c(t-1).
  • gate o: cosa includere nel nuovo stato nascosto h(t)

Questi gate vengono quindi combinati, come illustrato nella Figura 4, per calcolare la nuova cella di memoria c(t) e lo stato nascosto h(t). Queste nuove celle e lo stato nascosto vengono quindi passati alla successiva cella LSTM che ripete lo stesso processo nuovamente. Tutto questo processo può essere illustrato nel diagramma sottostante:

Fonte http://colah.github.io/posts/2015-08-Understanding-LSTMs/ (5)

Dopo di ciò, per ogni stato nascosto, calcoliamo l’output e la perdita:

Figura dell'autore (6)

Nel codice:

def softmax(x, axis=2):    p = np.exp(x - np.max(x, axis=axis,keepdims=True))    return p / np.sum(p, axis=axis, keepdims=True)def lstm_step_forward(x, prev_h, prev_c, Wx, Wh, b):    next_h, next_c, cache = None, None, None        h = x @ Wx + prev_h @ Wh + b    assert h.shape[-1] % 4 == 0    ai, af, ao, ag = np.array_split(h, 4, axis=-1)    i = sigmoid(ai)    f = sigmoid(af)    o = sigmoid(ao)    g = np.tanh(ag)    next_c = f * prev_c + i * g     next_h = o * np.tanh(next_c)        cache = (x, next_h, prev_h, prev_c, Wx, Wh, h, np.tanh(next_c), i, f, o ,g)    return next_h, next_c, cachenp.random.seed(232)# N - Dimensione batch# D - Dimensione dell'embedding# V - Dimensione del vocabolario# H - Dimensione nascosta# T - timestepsN, D, T, H, V = 2, 5, 3, 4, 4x  = np.random.randn(N, T, D)h0 = np.random.randn(N, H)Wx = np.random.randn(D, H)Wh = np.random.randn(H, H)Wy = np.random.randn(H, V)b  = np.random.randn(H)y = np.random.randint(V, size=(N, T))mask = np.ones((N, T))all_cache = []h = np.zeros((N, T, H))    next_c = np.zeros((N, H))    for t in range(T):    xt = x[:, t , :]    if t == 0:        next_h, next_c, cache_s = lstm_step_forward(xt, h0, next_c, Wx, Wh, b)        all_cache.append(cache_s)    else:        next_h, next_c, cache_s = lstm_step_forward(xt, next_h, next_c, Wx, Wh, b)           all_cache.append(cache_s)    h[:, t, :] = next_h ft = h @ Wyout = softmax(ft)

Backpropagation

Fonte https://www.iitg.ac.in/cseweb/osint/neural/slides/L8.pdf (7)

Le formule per la retropropagazione sono un po’ più complesse rispetto a quelle della RNN standard. In questo tutorial, andremo a derivare i gradienti rispetto a Wx per poi mostrare come LSTM gestisce i gradienti che svaniscono. Le derivate rispetto agli altri parametri possono essere analogamente derivate ed è lasciato come esercizio al lettore. Il codice, tuttavia, contiene le derivate rispetto a tutti i gradienti e è possibile verificare i risultati basandosi sul codice. La derivata della perdita rispetto allo stato nascosto è ancora la stessa della RNN in quanto nulla cambia lì, poiché la perdita prende solo lo stato nascosto come input:

Figura dell'autore (8)

Ora troviamo le derivate rispetto ad altri singoli componenti:

Figura dell'autore (9)

Si noti che per comodità abbiamo separato dct/dat e dht/dat, e ovunque abbiamo dht/dct dct/dat lo scriviamo direttamente come dht/dat. Inoltre, poiché faremo la retropropagazione nella forma matriciale, concateniamo le derivate dei gate nel seguente modo:

Figura dell'autore (10)

La somma in dht/dat deriva dal fatto che abbiamo due direzioni (vedi Figura 7) – una che va nella cella precedente e l’altra che va nello stato nascosto. Con la stessa logica del flusso del gradiente, la derivata di dct/dc(t-1) è la seguente:

Figura dell'autore (11)

Ora deriviamo il gradiente totale rispetto a Wx. Questo è dato dalla somma delle singole perdite rispetto a Wx come descritto nella parte 1 di questa serie:

Figura dell'autore (12)

Concentrandoci sulla perdita individuale, ad esempio dL3/dWx, quando propaghiamo da L3 a Wx, Wx appare in tutti i componenti degli istanti di tempo, quindi dovremo sommare tutti questi componenti per ottenere il gradiente completo di L3 rispetto a Wx. Abusando leggermente la notazione matematica, stiamo facendo qualcosa del genere (ricorda che Wx3 = Wx2 = Wx1):

Figura dell'autore (13)

Il primo componente sarà come segue. Inoltre, sostituiamo dht/dct dct/dat con dht/dat in modo da poter usare direttamente quella derivata

Figura dell'autore (14)

Salto dL3/dWx2 per brevità e passo direttamente al terzo componente. Abbiamo:

Figura dell'autore (15)

Come in precedenza, sostituiamo ovunque abbiamo dht/dct dct/dat con dht/dat in modo da poter usare direttamente quella derivata:

Figura dell'autore (16)

Sommandoli, otteniamo la derivata di dL3/dWx. Per ottenere la derivata di dWx rispetto alla perdita totale, dovremo aggiungere a dL3/dWx, dL2/dWx e dL1/dWx.

Figura dell'autore (17)

Nel codice:

def lstm_forward(x, h0, Wx, Wh, b, next_c=None):    h, cache = None, None       cache = []    N, T, _ = x.shape    H = h0.shape[-1]    h = np.zeros((N, T, H))    if next_c is None:        next_c = np.zeros((N, H))    for t in range(x.shape[1]):        xt = x[:, t , :]        if t == 0:            next_h, next_c, cache_s = lstm_step_forward(xt, h0, next_c, Wx, Wh, b)            cache.append(cache_s)        else:            next_h, next_c, cache_s = lstm_step_forward(xt, next_h, next_c, Wx, Wh, b)               cache.append(cache_s)            h[:, t, :] = next_h     return h, cachedef dc_da(h, prev_c, next_c_t, i, f, o, g):    dgrad_c = np.zeros((h.shape[0], 4 * h.shape[1]))    dgrad_h = np.zeros((h.shape[0], 4 * h.shape[1]))    # assert dgrad.shape[1] % 4 == 0    H = dgrad.shape[1] // 4        # compute gradients wrt ai, af, ao and ag from two flows - next_h and next_c    dnextc_dai = (i * (1-i)) * g    dnextc_daf = (f * (1-f)) * prev_c    dnextc_dao = 0    dnextc_dag = (1 - g**2) * i        dh_dc = o * (1 - next_c_t**2)    dnexth_dai = dh_dc * dnextc_dai    dnexth_daf = dh_dc * dnextc_daf    dnexth_dao = (o * (1-o) * next_c_t)    dnexth_dag = dh_dc * dnextc_dag    # join them together in a matrix at this point to conveniently compute    # downstream gradients     dgrad_c[:, 0:H] = dnextc_dai     dgrad_c[:, H:2*H] = dnextc_daf     dgrad_c[:, 2*H:3*H] = dnextc_dao     dgrad_c[:, 3*H:4*H] = dnextc_dag         dgrad_h[:, 0:H] =  dnexth_dai    dgrad_h[:, H:2*H] = dnexth_daf    dgrad_h[:, 2*H:3*H] = dnexth_dao    dgrad_h[:, 3*H:4*H] = dnexth_dag    return dgrad_c, dgrad_hnp.random.seed(1)N, D, T, H = 1, 3, 3, 1x = np.random.randn(N, T, D)h0 = np.random.randn(N, H)Wx = np.random.randn(D, 4 * H)Wh = np.random.randn(H, 4 * H)b = np.random.randn(4 * H)out, cache = lstm_forward(x, h0, Wx, Wh, b)# let's define the dout instead of deriving them for simplicitydout = np.random.randn(*out.shape)    # dL3/dWvxdnext_c2 = np.zeros((h0.shape))dnext_h2 = dout[:, -1, :](x2, next_h, prev_h, prev_c, Wx, Wh, next_h, next_c_t2, i2, f2, o2 ,g2) = cache[2]dgrad_c2, dgrad_h2 = dc_da(h0, cache[2][3], cache[2][-5], cache[2][-4],  cache[2][-3], cache[2][-2], cache[2][-1]) dL3_dWx2 = x2.T @ (dgrad_h2 * dnext_h2 + dgrad_c2 * dnext_c2)print(dL3_dWx2)dnext_c1 = dnext_c2 * f2 + dnext_h2 * o2 * (1 - next_c_t2**2) * f2dnext_h1 = (dnext_h2 * dgrad_h2 +  dnext_c2 * dgrad_c2) @ Wh.T(x1, next_h, prev_h, prev_c, Wx, Wh, next_h, next_c_t1, i1, f1, o1 ,g1) = cache[1]dgrad_c1, dgrad_h1 = dc_da(h0, cache[1][3], cache[1][-5], cache[1][-4],  cache[1][-3], cache[1][-2], cache[1][-1])    dL3_dWx1 = x1.T @ (dnext_c1 * dgrad_c1 + dnext_h1 * dgrad_h1)print(dL3_dWx1)dnext_c0 = dnext_c1 * f1 + dnext_h1 * o1 * (1 - next_c_t1**2) * f1dnext_h0 = (dnext_h1 * dgrad_h1 + dnext_c1 * dgrad_c1) @ Wh.T(x0, next_h, prev_h, prev_c, Wx, Wh, next_h, next_c_t0, i0, f0, o0 ,g0) = cache[0]dgrad_c0, dgrad_h0 = dc_da(h0, cache[0][3], cache[0][-5], cache[0][-4],  cache[0][-3], cache[0][-2], cache[0][-1])    dL3_dWx0 = x0.T @ (dnext_c0 * dgrad_c0 + dnext_h0 * dgrad_h0)print(dL3_dWx0)

Output:

[[-0.02349287  0.00135057 -0.11156069 -0.05284914] [ 0.01024921 -0.00058921  0.04867045  0.02305643] [-0.00429567  0.00024695 -0.02039889 -0.00966347]][[-9.83990139e-03  6.78775168e-05 -1.10660923e-03  4.20773125e-04] [ 7.93641636e-03 -5.47469140e-05  8.92540613e-04 -3.39376441e-04] [-2.11067811e-02  1.45598602e-04 -2.37369846e-03  9.02566589e-04]][[-1.95768961e-05  0.00000000e+00  2.77411349e-05 -9.76467796e-03] [ 7.37299593e-06  0.00000000e+00 -1.04477887e-05  3.67754574e-03] [ 6.36561888e-06  0.00000000e+00 -9.02030083e-06  3.17508036e-03]]

losses_dWx = {i : {x_comp : 0 for x_comp in range(i)} for i in range(T)}dWx = np.zeros((D, 4 * H))dWh = np.zeros((H, 4 * H))db = np.zeros((4 * H, ))for idx in range(T-1, -1, -1):    print(f"Perdita {idx + 1}")    dnext_c = np.zeros((h0.shape))    dnext_h =  dout[:, idx, :]    for j in range(idx, -1, -1):        (x, next_h, prev_h, prev_c, Wx, Wh, next_h, next_c_t, i, f, o ,g) = cache[j]                dgrad_c, dgrad_h = dc_da(h0, prev_c, next_c_t, i, f, o, g)                 dgrad = dnext_c * dgrad_c + dnext_h * dgrad_h        losses_dWx[idx][j] = x.T @ dgrad                dnext_c = dnext_c * f + dnext_h * o * (1 - next_c_t**2) * f        dnext_h = (dnext_h * dgrad_h +  dnext_c * dgrad_c) @ Wh.T        dnext_h = dgrad @ Wh.T          # accumula il gradiente di dWx e altri parametri per ogni perdita        dWx += x.T @ dgrad        dWh += prev_h.T @ dgrad        db += dgrad.sum(0)        print(f"componente {j} - ", np.linalg.norm(losses_dWx[idx][j]))

Gradiente che scompare in LSTM

Come nella parte 1 per RNN, vediamo i gradienti per la Perdita L3 per ogni componente:

Perdita 3componente 0 - 0.010906688399113558componente 1 - 0.02478099846737857componente 2 - 0.13901933055672275

Dal precedente, possiamo vedere che X3, che è il più vicino a L3, ha ancora l’aggiornamento più grande, mentre X1 e X2 contribuiscono meno all’aggiornamento di Wx1. Per RNN questa differenza è molto più grande. Infatti, il gradiente che passa attraverso lo stato nascosto subirà il gradiente che scompare per lo stesso motivo di RNN – i termini Wh (dat/dh(t-1)) appaiono ancora nella retropropagazione, ad esempio qui in dL3/dW(x-1):

Figura dell'autore da Figura 15 (18)

Tuttavia, il gradiente che fluisce attraverso la cella che è ancora una funzione dell’input e dello stato nascosto non ha i termini Wh ma i termini sigmoide invece (vedere la formula per il gate dimenticato ft in Figura 3):

Figura dell'autore da Figura 15 (18)

Ricordiamo che dct/dc(t-1) = ft. Quindi, se il gate di dimenticanza è alto, cioè vicino a 1, allora il problema del gradiente che tende a svanire avviene a un ritmo molto più lento rispetto alla RNN standard, ma avverrà comunque a meno che tutti i gate di dimenticanza siano esattamente 1, cosa che non accade nella pratica.

Conclusioni

Il punto principale di questo articolo era capire, attraverso la derivazione della retropropagazione, che LSTM soffre ancora del problema del gradiente che tende a svanire nella pratica, tuttavia a un ritmo molto più basso rispetto alla RNN standard grazie allo stato della cella, che fa decadere il gradiente al ritmo del gate di dimenticanza anziché al ritmo di Wx. Se trovi degli errori, per favore fammelo sapere nei commenti.

Riferimenti

  • https://web.stanford.edu/class/cs224n/slides/cs224n-2021-lecture06-fancy-rnn.pdf
  • http://cs231n.stanford.edu/assignments.html