
Smoothing in tqdm progress bar has no effect #20003

Open · heth27 opened this issue Jun 21, 2024 · 1 comment
Labels: docs (Documentation related), good first issue (Good for newcomers), help wanted (Open to be worked on)

Comments

heth27 commented Jun 21, 2024

Bug description

The `smoothing` option passed when creating progress bars in `TQDMProgressBar` has no effect in the default implementation, because `_update_n` only calls `bar.refresh()` and never the bar's `update()` method. tqdm computes its moving-average rate inside `update()`, so with `refresh()` alone only the global average speed is ever displayed. Either the progress bar's `update()` method should be used, or the documentation should state that `smoothing` has no effect, if that is the desired behavior (exposing a default that does nothing is misleading).
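
For context, a minimal tqdm-only sketch of the difference (my illustration, based on tqdm's documented behavior: `update()` feeds the rate estimate that `smoothing` controls, while setting `n` and calling `refresh()` only redraws the bar):

import time

from tqdm import tqdm

# Case 1: update() records per-step timings, so `smoothing` weights recent
# steps when computing the displayed rate (smoothing=1 -> current speed).
bar = tqdm(total=100, smoothing=1.0)
for _ in range(100):
    time.sleep(0.01)
    bar.update(1)
bar.close()

# Case 2: setting n and calling refresh() (what Lightning's _update_n does)
# records no per-step timing, so tqdm falls back to the global average rate
# and `smoothing` has no visible effect.
bar = tqdm(total=100, smoothing=1.0)
for i in range(100):
    time.sleep(0.01)
    bar.n = i + 1
    bar.refresh()
bar.close()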

What version are you seeing the problem on?

master

How to reproduce the bug

import hashlib
import sys
import time
from typing import Any

import lightning.pytorch as pl
import numpy as np
import torch
from lightning.pytorch.callbacks import TQDMProgressBar
from lightning.pytorch.callbacks.progress.tqdm_progress import Tqdm
from lightning.pytorch.utilities.types import STEP_OUTPUT
from torch import nn
from torch.utils.data import DataLoader, Dataset, Sampler
from typing_extensions import override


def create_rng_from_string(seed_string: str) -> np.random.Generator:
    """Stand-in for a project-internal helper: deterministic RNG seeded from a string."""
    digest = hashlib.sha256(seed_string.encode()).digest()
    return np.random.default_rng(int.from_bytes(digest[:8], "little"))


class LitProgressBar(TQDMProgressBar):
    """
    different smoothing factor than default lightning TQDMProgressBar, where smoothing=0 (average),
     instead of smoothing=1 (current speed) is taken

     See also:
     https://tqdm.github.io/docs/tqdm/
    """

    def init_train_tqdm(self) -> Tqdm:
        """Override this to customize the tqdm bar for training."""
        return Tqdm(
            desc=self.train_description,
            position=(2 * self.process_position),
            disable=self.is_disabled,
            leave=True,
            dynamic_ncols=True,
            file=sys.stdout,
            smoothing=1.0,
            bar_format=self.BAR_FORMAT,
        )

    # Default implementation (kept for reference): it only sets n via _update_n and refreshes.

    # @override
    # def on_train_batch_end(
    #     self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
    # ) -> None:
    #     n = batch_idx + 1
    #     if self._should_update(n, self.train_progress_bar.total):
    #         _update_n(self.train_progress_bar, n)
    #         self.train_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))


    # Custom override that calls the bar's update() method so smoothing takes effect.
    @override
    def on_train_batch_end(
            self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any,
            batch_idx: int
    ) -> None:
        n = batch_idx + 1
        if self._should_update(n, self.train_progress_bar.total):
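            # update() (unlike refresh()) feeds tqdm's per-step rate estimate, so smoothing applies.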
            self.train_progress_bar.update(self.refresh_rate)
            self.train_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))


class TestModule(nn.Module):
    def __init__(self, in_dim=512, out_dim=16):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.simple_layer = nn.Linear(self.in_dim, self.out_dim, bias=True)

    def forward(self, x):
        return self.simple_layer(x)


class TestBatchSampler(Sampler):
    def __init__(self, step=0):
        super().__init__()
        self.step = step

    def __len__(self) -> int:
        # Effectively infinite; len() must return an int that fits in Py_ssize_t,
        # so returning 1e100 (a float) would raise here.
        return sys.maxsize

    def __iter__(self):
        return self

    def __next__(self):
        return_value = self.step
        self.step += 1
        return [return_value]


class TestDataset(Dataset):
    def __init__(self, in_dim):
        super().__init__()
        self.in_dim = in_dim
        self.total_len = 512

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        rng = create_rng_from_string(str(idx) + "_random_choice_sampler")
        return torch.tensor(rng.random(self.in_dim), dtype=torch.float32)


class TestDataModule(pl.LightningDataModule):
    def __init__(self, start_step=0):
        super().__init__()
        self.in_dim = 512
        self.val_batch_size = 1
        self.start_step = start_step

    def train_dataloader(self):
        train_ds = TestDataset(self.in_dim)
        train_dl = DataLoader(train_ds, batch_sampler=TestBatchSampler(step=self.start_step), num_workers=4)
        return train_dl


class TestLitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.test_module_obj = TestModule(in_dim=512, out_dim=16)
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        if batch_idx == 0:
            time.sleep(5)

        time.sleep(0.5)

        optimizer = self.optimizers()

        output = self.test_module_obj(batch)

        loss = output.sum()

        self.manual_backward(loss)

        optimizer.step()

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.test_module_obj.parameters()
        )
        return optimizer


if __name__ == '__main__':
    test_data_loader = TestDataModule()
    test_lit_model = TestLitModel()

    bar = LitProgressBar(refresh_rate=5)
    trainer = pl.Trainer(
        log_every_n_steps=1,
        callbacks=[bar],
        max_epochs=-1,
        max_steps=400000,
    )

    trainer.fit(test_lit_model,
                datamodule=test_data_loader)


### Error messages and logs

_No response_



### Environment

<details>
  <summary>Current environment</summary>

#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(conda, pip, source):
#- Running environment of LightningApp (e.g. local, cloud):


</details>


### More info

_No response_

cc @borda
heth27 added the bug (Something isn't working) and needs triage (Waiting to be triaged by maintainers) labels on Jun 21, 2024
heth27 changed the title from "Smoothing in tqdm progress bar" to "Smoothing in tqdm progress bar has no effect" on Jun 21, 2024
@awaelchli (Member) commented:

> as _update_n only calls bar.refresh() and not the update method of the progress bar

This change is needed to give us exact control over updates in all use cases. The progress bar is deeply tied to the loops, so we need that precise control and can't remove the _update_n behavior. But you can definitely override everything you wish in TQDMProgressBar and make it your own if that smoothing option is important.

> or it should be added to the documentation if smoothing having no effect is the desired behavior (overriding a default that has no effect is a bit misleading)

I'm fine with adding a note to the documentation of the TQDMProgressBar.init_*_tqdm methods 👍
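
A possible wording for such a note (my sketch, not an accepted patch), attached to the existing docstring:

def init_train_tqdm(self) -> Tqdm:
    """Override this to customize the tqdm bar for training.

    Note: Lightning advances the bar by setting ``n`` directly and calling
    ``refresh()`` rather than ``update()``, so tqdm's ``smoothing`` argument
    has no effect on the displayed rate. Override ``on_*_batch_end`` and call
    ``update()`` yourself if you need a smoothed rate.
    """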

awaelchli added the help wanted (Open to be worked on), good first issue (Good for newcomers), and docs (Documentation related) labels and removed the bug (Something isn't working), needs triage (Waiting to be triaged by maintainers), and ver: 2.2.x labels on Jul 14, 2024