Source code for extorch.nn.optim.lr_scheduler

from typing import List

import numpy as np
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim import Optimizer


[docs]class CosineWithRestarts(_LRScheduler): r""" Cosine annealing with restarts (`Link`_). Args: optimizer (Optimizer): The optimizer. t_max (int): The maximum number of iterations within the first cycle. eta_min (float): The minimum learning rate. Default: 0. last_epoch (int): The index of the last epoch. This is used when restarting. Default: -1. factor (float): The factor by which the cycle length (T_max) increases after each restart. Default: 1. .. _Link: https://arxiv.org/abs/1608.03983 """ def __init__(self, optimizer: Optimizer, T_max: int, eta_min: float = 0., last_epoch: int = -1, factor: float = 1.) -> None: self.T_max = T_max self.eta_min = eta_min self.factor = factor self._last_restart = 0 self._cycle_counter = 0 self._cycle_factor = 1. self._updated_cycle_len = t_0 self._initialized = False super(CosineWithRestarts, self).__init__(optimizer, last_epoch)
[docs] def get_lr(self) -> List[float]: r""" Get updated learning rate. Returns: lrs (List[float]): A list of current learning rates. Note: We need to check if this is the first time ``self.get_lr()`` was called, since ``torch.optim.lr_scheduler._LRScheduler`` will call ``self.get_lr()`` when first initialized, but the learning rate should remain unchanged for the first epoch. """ if not self._initialized: self._initialized = True return self.base_lrs step = self.last_epoch self._cycle_counter = step - self._last_restart lrs = [ self.eta_min + ((lr - self.eta_min) / 2) * ( np.cos( np.pi * (self._cycle_counter % self._updated_cycle_len) / self._updated_cycle_len ) + 1 ) for lr in self.base_lrs ] if self._cycle_counter != 0 and self._cycle_counter % self._updated_cycle_len == 0: # Adjust the cycle length. self._cycle_factor *= self.factor self._cycle_counter = 0 self._updated_cycle_len = int(self._cycle_factor * self.T_max) self._last_restart = step return lrs