From f20cbc154a005f7f4a763a30f971dec9d4717aae Mon Sep 17 00:00:00 2001
From: Tony
Date: Wed, 5 Nov 2025 09:25:51 +0100
Subject: [PATCH 1/2] initial commit:

- Added val-interval argument. Eval and checkpointing are only applied every
  val-interval epochs.
- Changed `float` to `Optional[float]` in the type hint of the Scheduler `step`
  function parameter `metric`.
- Skipped stepping the base scheduler in the plateau scheduler when no metric
  is given, to avoid a `TypeError` when converting `None` to `float`.
- Added `or last_batch` to the training logging condition, to be consistent
  with validation.
---
 timm/scheduler/plateau_lr.py |  8 +++++---
 timm/scheduler/scheduler.py  |  4 ++--
 train.py                     | 14 +++++++++++++-
 3 files changed, 20 insertions(+), 6 deletions(-)

diff --git a/timm/scheduler/plateau_lr.py b/timm/scheduler/plateau_lr.py
index e868bd5e58..e8beead1c0 100644
--- a/timm/scheduler/plateau_lr.py
+++ b/timm/scheduler/plateau_lr.py
@@ -5,7 +5,7 @@
 Hacked together by / Copyright 2020 Ross Wightman
 """
 import torch
-from typing import List
+from typing import List, Optional
 
 from .scheduler import Scheduler
 
@@ -86,12 +86,14 @@ def step(self, epoch, metric=None):
                 param_group['lr'] = self.restore_lr[i]
             self.restore_lr = None
 
-        self.lr_scheduler.step(metric, epoch)  # step the base scheduler
+        # step the base scheduler if metric given
+        if metric is not None:
+            self.lr_scheduler.step(metric, epoch)
 
         if self._is_apply_noise(epoch):
             self._apply_noise(epoch)
 
-    def step_update(self, num_updates: int, metric: float = None):
+    def step_update(self, num_updates: int, metric: Optional[float] = None):
         return None
 
     def _apply_noise(self, epoch):
diff --git a/timm/scheduler/scheduler.py b/timm/scheduler/scheduler.py
index 583357f7c5..f4eb8ab0fa 100644
--- a/timm/scheduler/scheduler.py
+++ b/timm/scheduler/scheduler.py
@@ -74,14 +74,14 @@ def _get_values(self, t: int, on_epoch: bool = True) -> Optional[List[float]]:
             return None
         return self._get_lr(t)
 
-    def step(self, epoch: int, metric: float = None) -> None:
+    def step(self, epoch: int, metric: Optional[float] = None) -> None:
         self.metric = metric
         values = self._get_values(epoch, on_epoch=True)
         if values is not None:
             values = self._add_noise(values, epoch)
             self.update_groups(values)
 
-    def step_update(self, num_updates: int, metric: float = None):
+    def step_update(self, num_updates: int, metric: Optional[float] = None):
         self.metric = metric
         values = self._get_values(num_updates, on_epoch=False)
         if values is not None:
diff --git a/train.py b/train.py
index 131260dca4..eca09fb7f5 100755
--- a/train.py
+++ b/train.py
@@ -370,6 +370,8 @@
                    help='worker seed mode (default: all)')
 group.add_argument('--log-interval', type=int, default=50, metavar='N',
                    help='how many batches to wait before logging training status')
+group.add_argument('--val-interval', type=int, default=1, metavar='N',
+                   help='how many epochs between validation and checkpointing')
 group.add_argument('--recovery-interval', type=int, default=0, metavar='N',
                    help='how many batches to wait before writing recovery checkpoint')
 group.add_argument('--checkpoint-hist', type=int, default=10, metavar='N',
@@ -1013,6 +1015,16 @@ def main():
                     _logger.info("Distributing BatchNorm running means and vars")
                 utils.distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
 
+            if (epoch + 1) % args.val_interval != 0:
+                if utils.is_primary(args):
+                    _logger.info("Skipping eval and checkpointing ")
+                if lr_scheduler is not None:
+                    # step LR for next epoch
+                    # careful when using metric dependent lr_scheduler
+                    lr_scheduler.step(epoch + 1, metric=None)
+                # skip validation and metric logic
+                continue
+
             if loader_eval is not None:
                 eval_metrics = validate(
                     model,
@@ -1252,7 +1264,7 @@ def _backward(_loss):
             update_time_m.update(time.time() - update_start_time)
             update_start_time = time_now
 
-            if update_idx % args.log_interval == 0:
+            if update_idx % args.log_interval == 0 or last_batch:
                 lrl = [param_group['lr'] for param_group in optimizer.param_groups]
                 lr = sum(lrl) / len(lrl)
 
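For context on the plateau_lr.py hunk above, here is a minimal, self-contained sketch of the
guard pattern it introduces (illustrative only, not timm's actual PlateauLRScheduler):
PyTorch's ReduceLROnPlateau.step() converts its metric to float, so forwarding metric=None on
a skipped-validation epoch would raise the TypeError mentioned in the commit message, and the
step is therefore only forwarded when a metric is present.

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Minimal sketch of the metric guard from the plateau_lr.py hunk above
# (illustrative only, not timm's PlateauLRScheduler).
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
base_scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=2)

def step(epoch, metric=None):
    # step the base scheduler only if a metric is given;
    # ReduceLROnPlateau.step() calls float(metric) and fails on None
    if metric is not None:
        base_scheduler.step(metric)

step(0, metric=0.75)   # epoch with an eval metric: plateau logic runs
step(1, metric=None)   # skipped-eval epoch: no step, no TypeError
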
From 800054e341aff3b6396814561af9e16a5ba8b9bb Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 5 Nov 2025 16:03:06 -0800
Subject: [PATCH 2/2] Ensure final epoch always gets validated even if it
 doesn't line up with val interval. Add a few comments.

---
 train.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/train.py b/train.py
index eca09fb7f5..0dacdc1dc2 100755
--- a/train.py
+++ b/train.py
@@ -1015,14 +1015,16 @@ def main():
                     _logger.info("Distributing BatchNorm running means and vars")
                 utils.distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
 
-            if (epoch + 1) % args.val_interval != 0:
+            epoch_p_1 = epoch + 1
+            if epoch_p_1 % args.val_interval != 0 and epoch_p_1 != num_epochs:
                 if utils.is_primary(args):
                     _logger.info("Skipping eval and checkpointing ")
                 if lr_scheduler is not None:
-                    # step LR for next epoch
-                    # careful when using metric dependent lr_scheduler
-                    lr_scheduler.step(epoch + 1, metric=None)
-                # skip validation and metric logic
+                    # step LR for next epoch, take care when using metric dependent lr_scheduler
+                    lr_scheduler.step(epoch_p_1, metric=None)
+                # Skip validation and metric logic
+                # FIXME we could make the logic below able to handle no eval metrics more gracefully,
+                # but for simplicity opting to just skip for now.
                 continue
 
             if loader_eval is not None:
@@ -1076,7 +1078,7 @@ def main():
 
             if lr_scheduler is not None:
                 # step LR for next epoch
-                lr_scheduler.step(epoch + 1, latest_metric)
+                lr_scheduler.step(epoch_p_1, latest_metric)
 
             latest_results = {
                 'epoch': epoch,
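Taken together, the two patches give the epoch loop roughly the shape sketched below. This is
a simplified, self-contained illustration with dummy stand-ins for timm's scheduler and
validation metrics: eval and checkpointing run every --val-interval epochs and always on the
final epoch, while skipped epochs still advance the LR scheduler with metric=None.

# Simplified illustration of the val-interval flow after both patches;
# DummyScheduler and eval_loss are stand-ins, not timm's objects.
num_epochs = 7
val_interval = 3

class DummyScheduler:
    def step(self, epoch, metric=None):
        print(f"scheduler.step(epoch={epoch}, metric={metric})")

lr_scheduler = DummyScheduler()

for epoch in range(num_epochs):
    # ... training for this epoch would happen here ...
    epoch_p_1 = epoch + 1
    if epoch_p_1 % val_interval != 0 and epoch_p_1 != num_epochs:
        # Skipped epoch: no eval, no checkpoint, but still step the LR scheduler.
        lr_scheduler.step(epoch_p_1, metric=None)
        continue
    eval_loss = 1.0 / epoch_p_1              # stand-in for validate() output
    lr_scheduler.step(epoch_p_1, eval_loss)  # normal epoch: step with the metric
    print(f"epoch {epoch}: evaluated and checkpointed (loss={eval_loss:.3f})")

With val_interval=3 and num_epochs=7, epochs 2 and 5 are evaluated on the interval, and
epoch 6 (the last one) is evaluated because of the final-epoch check added in the second patch.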