L’implementazione del momento di Nesterov di PyTorch è sbagliata?

Implementazione errata del momento di Nesterov in PyTorch?

Il momentum aiuta SGD a attraversare più efficientemente paesaggi di perdita complessi. Foto di Maxim Berg su Unsplash.

Introduzione

Se osservi attentamente la documentazione di SGD di PyTorch, noterai che la loro implementazione del momentum di Nesterov ha alcune differenze rispetto alla formulazione presente nell’articolo originale. In particolare, l’implementazione di PyTorch valuta il gradiente sui parametri correnti, mentre il punto principale del momentum di Nesterov è valutare il gradiente sui parametri spostati. Purtroppo, sembra che la discussione su queste discrepanze su internet sia scarsa. In questo post, esamineremo e spiegheremo le differenze tra l’implementazione di PyTorch e la formulazione originale del momentum di Nesterov. Alla fine, vedremo come l’implementazione di PyTorch non sia sbagliata, ma piuttosto un’approssimazione, e speculeremo sui benefici della loro implementazione.

Le Formulazioni

L’articolo originale descrive il momentum di Nesterov utilizzando le seguenti regole di aggiornamento:

dove v_{t+1} e θ_{t+1} sono rispettivamente il vettore velocità e i parametri del modello al tempo t, μ è il fattore di momentum e ε è il tasso di apprendimento. La nota nella documentazione di SGD di PyTorch afferma che utilizzano le seguenti regole di aggiornamento:

dove g_{t+1} rappresenta il gradiente utilizzato per calcolare v_{t+1}. Possiamo espandere la regola di aggiornamento per θ_{t+1} in questo modo:

Da ciò possiamo dedurre che:

e le regole di aggiornamento diventano:

Queste sono le regole di aggiornamento utilizzate da PyTorch in teoria. Ho accennato in precedenza che PyTorch valuta effettivamente il gradiente sui parametri correnti anziché sui parametri spostati. Questo può essere visto osservando la descrizione dell’algoritmo nella documentazione di SGD di PyTorch. Approfondiremo questo argomento più avanti.

Si noti che per entrambe le formulazioni originali (1, 2) e PyTorch (3, 4), se v_0 = 0, il primo aggiornamento di θ diventa:

Nonostante la nota nella documentazione di SGD di PyTorch affermi che l’algoritmo inizializza il buffer di momentum al gradiente al primo passo, dimostreremo in seguito che ciò implica v_0 = 0.

Differenze Preliminari

Ci sono due differenze immediate quando si passa dalla formulazione originale (1, 2) alla formulazione di PyTorch (3, 4):

  1. Il tasso di apprendimento viene spostato fuori da v_{t+1}.
  2. Nella regola di aggiornamento per v_{t+1}, viene aggiunto invece che sottratto il termine che coinvolge il gradiente, e nella regola di aggiornamento per θ_{t+1}, viene sottratto invece che aggiunto il termine che coinvolge il vettore velocità. La differenza di segno all’interno del termine del gradiente è semplicemente una conseguenza di ciò, come mostrato nella sezione precedente.

Per capire queste differenze, espandiamo prima le regole di aggiornamento. Come suggerito qui, l’effetto della prima differenza è più evidente se consideriamo programmi di tassi di apprendimento. Quindi, consideriamo una generalizzazione delle regole di aggiornamento in cui ε non è più fisso ma può variare nel tempo, e indichiamo ε_t come il tasso di apprendimento al passo temporale t. Per brevità, indichiamo:

Supponendo v_0 = 0, la formulazione originale diventa:

e la formulazione di PyTorch diventa:

Nella formulazione originale (6), se il tasso di apprendimento cambiasse al tempo t, allora solo la magnitudine del termine i = t nella sommatoria sarebbe influenzata, mentre le magnitudini di tutti gli altri termini rimarrebbero uguali. Di conseguenza, l’influenza immediata del cambiamento del tasso di apprendimento è piuttosto limitata, e dovremmo aspettare che il cambiamento del tasso di apprendimento si “propaghi” nei passaggi temporali successivi per avere un’influenza più forte sulla dimensione complessiva del passo. Al contrario, nella formulazione di PyTorch (7), se il tasso di apprendimento cambiasse al tempo t, allora l’intera magnitudine del passo sarebbe influenzata immediatamente.

Per v_0 = 0, è chiaro dalle regole espandenti che la seconda differenza alla fine non ha alcun effetto; in entrambe le formulazioni, il passo si traduce in una somma scontata dei gradienti che viene sottratta dai parametri correnti.

Differenze Principali

Ignorando la riduzione del peso e l’ammorbidimento, analizzando l’algoritmo SGD nella documentazione di PyTorch, possiamo vedere che le regole di aggiornamento implementate sono:

dove θ’_{t+1} sono i parametri del modello al tempo t e

Chiameremo le equazioni 3 e 4 come la formulazione “note” di PyTorch, e le equazioni 8 e 9 come la formulazione “implementata” di PyTorch. Facciamo una distinzione tra θ e θ’ per una ragione che diventerà presto evidente. La differenza più evidente rispetto alla formulazione “note” è che il gradiente viene valutato sui parametri attuali anziché sui parametri spostati. Solo da questo potrebbe sembrare che le regole di aggiornamento implementate dall’algoritmo non siano una corretta implementazione del momento di Nesterov.

Ora esamineremo come l’algoritmo di PyTorch approssima il momento di Nesterov. Le derivazioni per una versione precedente di PyTorch possono essere trovate qui da Ivo Danihelka, citate in questo problema di GitHub. Le derivazioni per la versione corrente di PyTorch possono essere trovate qui, che è un aggiustamento relativamente semplice rispetto alle derivazioni precedenti. Forniamo una rappresentazione LaTeX di queste derivazioni (riderivate) qui per comodità del lettore. La formulazione implementata è derivata da un semplice cambio di variabili. In particolare, facciamo:

Diventa immediatamente chiaro che la regola di aggiornamento delle note per v_{t+1} (3) diventa equivalente alla regola di aggiornamento implementata per v_{t+1} (8) dopo il cambio di variabili. Vogliamo ora derivare una regola di aggiornamento per θ’_{t+1} in termini di θ’_t:

Questa è esattamente la regola di aggiornamento che abbiamo visto implementata in PyTorch (9). In generale, l’implementazione di PyTorch assume che i parametri attuali θ’_t siano già la versione spostata dei parametri “effettivi” θ_t. Pertanto, ad ogni passo temporale, i parametri “effettivi” θ_t sono legati ai parametri attuali θ’_t da:

Tuttavia, sembra dal codice sorgente che l’implementazione SGD di PyTorch non apporti alcuna correzione alla fine dell’algoritmo per recuperare i parametri “effettivi” finali, quindi l’output finale è tecnicamente un’approssimazione dei parametri “effettivi”.

Infine, mostriamo ora che v_0 deve essere 0:

Inoltre, possiamo confermare che il primo aggiornamento dei parametri “effettivi” è lo stesso primo aggiornamento effettuato nella formulazione originale quando v_0 = 0:

Possiamo vedere che questo è equivalente all’equazione 5.

Il Beneficio della Formulazione Implementata

Ovviamente, la grande domanda rimanente è: Perché PyTorch si preoccupa di riformulare il momento di Nesterov dalle equazioni 3 e 4 alle equazioni 8 e 9? Una possibile spiegazione è che la riformulazione potrebbe fornire dei risparmi nel numero di operazioni aritmetiche richieste. Per valutare questa possibile spiegazione, contiamo il numero di operazioni aritmetiche. Per la formulazione delle note (3, 4) abbiamo:

Qui, ci sono un totale di sette operazioni. Per la formulazione implementata (8, 9) abbiamo:

Qui, ci sono un totale di sei operazioni. Il secondo gradiente nell’implementazione di PyTorch utilizza semplicemente il risultato salvato dal primo calcolo del gradiente, quindi viene eseguito solo un calcolo del gradiente ad ogni passo temporale. Quindi, un beneficio evidente è che l’implementazione di PyTorch riduce di una moltiplicazione aggiuntiva ad ogni passo.

Conclusione

In sintesi:

  1. Le regole di aggiornamento indicate nelle note della documentazione SGD di PyTorch (3, 4) hanno una posizione diversa per il tasso di apprendimento rispetto alle regole di aggiornamento originarie del momento di Nesterov (1, 2). Ciò consente ai programmi di tasso di apprendimento di avere un effetto immediato sulla dimensione complessiva del passo, mentre la formulazione originale avrebbe l’effetto che i cambiamenti del tasso di apprendimento “scorrono” nei successivi passi temporali.
  2. Le regole di aggiornamento implementate nell’algoritmo SGD di PyTorch (8, 9) sono un’approssimazione delle regole di aggiornamento indicate nelle note della documentazione (3, 4) dopo un semplice cambio di variabili. Anche se i parametri “effettivi” sono facilmente recuperabili dai parametri attuali ad ogni passo temporale, l’implementazione di PyTorch non apporta alcuna correzione alla fine dell’algoritmo, quindi i parametri finali rimangono tecnicamente un’approssimazione dei parametri “effettivi” finali.
  3. Un beneficio evidente dell’implementazione di PyTorch è che evita una moltiplicazione aggiuntiva ad ogni passo temporale.

Riferimenti

  1. “SGD.” SGD — Documentazione PyTorch 2.0, pytorch.org/docs/stable/generated/torch.optim.SGD.html. Consultato il 2 settembre 2023.
  2. Sutskever, Ilya, et al. “Sull’importanza dell’inizializzazione e del momentum nell’apprendimento profondo.” Conferenza Internazionale sul Machine Learning. PMLR, 2013.
  3. Danihelka, Ivo. “Momentum di Nesterov reso semplice.” 25 agosto 2012.
  4. Chintala, Soumith. “Il momentum di Nesterov è sbagliato in sgd · Issue #27 · torch/optim.” GitHub, 13 ottobre 2014, github.com/torch/optim/issues/27.
  5. Gross, Sam. “Aggiungere una nota nella documentazione sulla formulazione del momentum utilizzata in optim · Issue #1099 · pytorch/pytorch.” GitHub, 25 marzo 2017, github.com/pytorch/pytorch/issues/1099#issuecomment-289190614.
  6. Zhao, Yilong. “Correggere il bug del momentum di Nesterov · Issue #5920 · pytorch/pytorch.” GitHub, 21 marzo 2018, https://github.com/pytorch/pytorch/pull/5920#issuecomment-375181908.