timm/scheduler/plateau_lr.py (8 changes: 5 additions & 3 deletions)

@@ -5,7 +5,7 @@
 Hacked together by / Copyright 2020 Ross Wightman
 """
 import torch
-from typing import List
+from typing import List, Optional
 
 from .scheduler import Scheduler
 
@@ -86,12 +86,14 @@ def step(self, epoch, metric=None):
                 param_group['lr'] = self.restore_lr[i]
             self.restore_lr = None
 
-        self.lr_scheduler.step(metric, epoch)  # step the base scheduler
+        # step the base scheduler if metric given
+        if metric is not None:
+            self.lr_scheduler.step(metric, epoch)
 
         if self._is_apply_noise(epoch):
             self._apply_noise(epoch)
 
-    def step_update(self, num_updates: int, metric: float = None):
+    def step_update(self, num_updates: int, metric: Optional[float] = None):
        return None
 
     def _apply_noise(self, epoch):
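Why the new guard matters: PlateauLRScheduler wraps torch's ReduceLROnPlateau, whose step() coerces its argument with float(metrics), so forwarding metric=None raises a TypeError. A minimal standalone sketch (not part of the PR; the parameter and optimizer are placeholders):

    import torch

    # placeholder parameter so the optimizer has something to manage
    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.SGD(params, lr=0.1)
    base = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)

    base.step(0.5)       # fine: a real metric value
    try:
        base.step(None)  # what the unguarded call forwarded on metric-less epochs
    except TypeError as err:
        print(f"step(None) fails: {err}")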
timm/scheduler/scheduler.py (4 changes: 2 additions & 2 deletions)

@@ -74,14 +74,14 @@ def _get_values(self, t: int, on_epoch: bool = True) -> Optional[List[float]]:
             return None
         return self._get_lr(t)
 
-    def step(self, epoch: int, metric: float = None) -> None:
+    def step(self, epoch: int, metric: Optional[float] = None) -> None:
         self.metric = metric
         values = self._get_values(epoch, on_epoch=True)
         if values is not None:
             values = self._add_noise(values, epoch)
             self.update_groups(values)
 
-    def step_update(self, num_updates: int, metric: float = None):
+    def step_update(self, num_updates: int, metric: Optional[float] = None):
         self.metric = metric
         values = self._get_values(num_updates, on_epoch=False)
         if values is not None:
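The annotation change is the standard "no implicit Optional" cleanup: a default of None does not by itself make a parameter Optional under PEP 484, and recent mypy versions reject the implicit form by default. An illustrative snippet (function names are hypothetical):

    from typing import Optional

    def step_old(epoch: int, metric: float = None): ...             # flagged by mypy defaults
    def step_new(epoch: int, metric: Optional[float] = None): ...   # explicit and accepted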
train.py (18 changes: 16 additions & 2 deletions)

@@ -370,6 +370,8 @@
                    help='worker seed mode (default: all)')
 group.add_argument('--log-interval', type=int, default=50, metavar='N',
                    help='how many batches to wait before logging training status')
+group.add_argument('--val-interval', type=int, default=1, metavar='N',
+                   help='how many epochs between validation and checkpointing')
 group.add_argument('--recovery-interval', type=int, default=0, metavar='N',
                    help='how many batches to wait before writing recovery checkpoint')
 group.add_argument('--checkpoint-hist', type=int, default=10, metavar='N',
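A hypothetical invocation using the new flag (dataset path, model, and epoch count are placeholders), validating and checkpointing only every 5 epochs instead of every epoch:

    python train.py /data/imagenet --model resnet50 --epochs 100 --val-interval 5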
@@ -1013,6 +1015,18 @@ def main():
                     _logger.info("Distributing BatchNorm running means and vars")
                 utils.distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
 
+            epoch_p_1 = epoch + 1
+            if epoch_p_1 % args.val_interval != 0 and epoch_p_1 != num_epochs:
+                if utils.is_primary(args):
+                    _logger.info("Skipping eval and checkpointing")
+                if lr_scheduler is not None:
+                    # step LR for next epoch, take care when using metric dependent lr_scheduler
+                    lr_scheduler.step(epoch_p_1, metric=None)
+                # Skip validation and metric logic
+                # FIXME we could make the logic below able to handle no eval metrics more gracefully,
+                # but for simplicity opting to just skip for now.
+                continue
+
             if loader_eval is not None:
                 eval_metrics = validate(
                     model,
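A standalone sketch of the new skip condition (values are assumed examples, not from the PR), showing which epochs still run validation with --val-interval 5; note the final epoch always validates so summary metrics and a checkpoint exist:

    num_epochs, val_interval = 12, 5
    for epoch in range(num_epochs):
        epoch_p_1 = epoch + 1
        skip = epoch_p_1 % val_interval != 0 and epoch_p_1 != num_epochs
        print(f"epoch {epoch}: {'skip' if skip else 'validate'}")
    # validates at epochs 4, 9, and 11 (epoch_p_1 of 5, 10, and 12)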
@@ -1064,7 +1078,7 @@
 
         if lr_scheduler is not None:
             # step LR for next epoch
-            lr_scheduler.step(epoch + 1, latest_metric)
+            lr_scheduler.step(epoch_p_1, latest_metric)
 
         latest_results = {
             'epoch': epoch,
@@ -1252,7 +1266,7 @@ def _backward(_loss):
             update_time_m.update(time.time() - update_start_time)
             update_start_time = time_now
 
-            if update_idx % args.log_interval == 0:
+            if update_idx % args.log_interval == 0 or last_batch:
                 lrl = [param_group['lr'] for param_group in optimizer.param_groups]
                 lr = sum(lrl) / len(lrl)
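The `or last_batch` clause guarantees a log line for the final batch of each epoch even when the batch count is not a multiple of --log-interval. A minimal sketch (the batch count is an assumed example; in train.py, last_batch compares the batch index against the loader's last index):

    log_interval, num_batches = 50, 120
    for update_idx in range(num_batches):
        last_batch = update_idx == num_batches - 1
        if update_idx % log_interval == 0 or last_batch:
            print(f"log at update {update_idx}")
    # logs at 0, 50, 100, and now also 119 (the last batch)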