EarlyStopping override disrupts wandb logging frequency #19990

Open
RafiBrent opened this issue Jun 18, 2024 · 0 comments
Labels
bug (Something isn't working), needs triage (Waiting to be triaged by maintainers)

Comments

@RafiBrent

Bug description

When an EarlyStopping callback would halt training before min_epochs has elapsed, the early stop is (correctly) overridden and the Trainer prints the warning message given below. However, starting at the exact step where the warning is printed, WandbLogger begins logging the train metrics on every single batch rather than every log_every_n_steps. This slows training and produces strange output graphs.
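
To illustrate the suspected interaction, here is a standalone toy sketch (not actual Lightning source; the function name and the assumption that the step-modulo check is OR-ed with trainer.should_stop are mine, not verified against 2.2.5):

# Toy sketch of the suspected logging-frequency check. Assumption: Lightning's
# logger connector combines its log_every_n_steps modulo test with
# trainer.should_stop via "or".
def should_update_logs(step: int, log_every_n_steps: int, should_stop: bool) -> bool:
    on_schedule = (step + 1) % log_every_n_steps == 0
    # Once EarlyStopping latches should_stop=True but the fit loop keeps running
    # to satisfy min_epochs, this would return True for every batch.
    return on_schedule or should_stop


if __name__ == "__main__":
    # Normal cadence with log_every_n_steps=8: logs after steps 8 and 16.
    print([s + 1 for s in range(16) if should_update_logs(s, 8, should_stop=False)])  # [8, 16]
    # After the override kicks in: logs after every single step.
    print([s + 1 for s in range(16) if should_update_logs(s, 8, should_stop=True)])   # [1, 2, ..., 16]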

What version are you seeing the problem on?

v2.2

How to reproduce the bug

import torch
from lightning.pytorch import LightningModule, Trainer, seed_everything
from torch.utils.data import DataLoader, Dataset
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import EarlyStopping


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    seed_everything(42, workers=True)
    wandb_logger = WandbLogger(project="bug-report",
                               entity="example-user", name="debug_logging")

    early_stopping_callback = EarlyStopping(monitor="train_loss", patience=2)

    callbacks = [early_stopping_callback]

    kwargs = {
        "log_every_n_steps": 8,  # expected wandb logging cadence
        "logger": wandb_logger,
        "num_sanity_val_steps": 0,
        "callbacks": callbacks,
        "val_check_interval": 0.1,
        "max_epochs": 10,
        "min_epochs": 2,  # forces training to continue past the EarlyStopping signal
        "deterministic": True,
    }

    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(**kwargs)
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)


if __name__ == "__main__":
    run()

Error messages and logs

Epoch 0:  38%|████████████████████████████████████████▉                                                                  | 12/32 [00:00<00:00, 53.14it/s, v_num=yqz5]

Trainer was signaled to stop but the required `min_epochs=2` or `min_steps=None` has not been met. Training will continue...          
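
For context on the numbers above: the repro uses 64 samples at batch_size=2, i.e. 32 training steps per epoch, and val_check_interval=0.1 runs validation roughly every 3 steps. With patience=2, the stop signal can therefore fire around step 12 of epoch 0, well before min_epochs=2 (64 steps) is satisfied; from that point on, wandb receives train metrics on every step instead of every 8th step.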

Environment

Current environment
  • CUDA:
    - GPU: None
    - available: False
    - version: None
  • Lightning:
    - lightning: 2.2.5
    - lightning-utilities: 0.11.2
    - pytorch-lightning: 2.2.2
    - torch: 2.3.0
    - torchmetrics: 1.4.0.post0
  • Packages:
    - appdirs: 1.4.4
    - appnope: 0.1.4
    - asttokens: 2.4.1
    - brotli: 1.1.0
    - certifi: 2024.6.2
    - chardet: 5.2.0
    - charset-normalizer: 3.3.2
    - click: 8.1.7
    - colorama: 0.4.6
    - comm: 0.2.2
    - contourpy: 1.2.1
    - cycler: 0.12.1
    - debugpy: 1.8.1
    - decorator: 5.1.1
    - docker-pycreds: 0.4.0
    - exceptiongroup: 1.2.0
    - executing: 2.0.1
    - filelock: 3.14.0
    - fonttools: 4.53.0
    - freetype-py: 2.3.0
    - fsspec: 2024.6.0
    - gitdb: 4.0.11
    - gitpython: 3.1.43
    - gmpy2: 2.1.5
    - greenlet: 3.0.3
    - idna: 3.7
    - importlib-metadata: 7.1.0
    - ipykernel: 6.29.3
    - ipython: 8.25.0
    - jedi: 0.19.1
    - jinja2: 3.1.4
    - joblib: 1.4.2
    - jupyter-client: 8.6.2
    - jupyter-core: 5.7.2
    - kiwisolver: 1.4.5
    - lightning: 2.2.5
    - lightning-utilities: 0.11.2
    - markupsafe: 2.1.5
    - matplotlib: 3.8.4
    - matplotlib-inline: 0.1.7
    - mpmath: 1.3.0
    - munkres: 1.1.4
    - nest-asyncio: 1.6.0
    - networkx: 3.3
    - numexpr: 2.10.0
    - numpy: 1.26.4
    - packaging: 24.0
    - pandas: 2.2.2
    - parso: 0.8.4
    - pathtools: 0.1.2
    - pexpect: 4.9.0
    - pickleshare: 0.7.5
    - pillow: 10.3.0
    - pip: 24.0
    - platformdirs: 4.2.2
    - prompt-toolkit: 3.0.46
    - protobuf: 4.25.3
    - psutil: 5.9.8
    - ptyprocess: 0.7.0
    - pure-eval: 0.2.2
    - py-cpuinfo: 9.0.0
    - pycairo: 1.26.0
    - pygments: 2.18.0
    - pyparsing: 3.1.2
    - pysocks: 1.7.1
    - python-dateutil: 2.9.0
    - pytorch-lightning: 2.2.2
    - pytz: 2024.1
    - pyyaml: 6.0.1
    - pyzmq: 26.0.3
    - rdkit: 2024.3.3
    - reportlab: 4.1.0
    - requests: 2.32.3
    - rlpycairo: 0.2.0
    - scikit-learn: 1.5.0
    - scipy: 1.13.1
    - sentry-sdk: 2.4.0
    - setproctitle: 1.3.3
    - setuptools: 70.0.0
    - six: 1.16.0
    - smmap: 5.0.0
    - sqlalchemy: 2.0.30
    - stack-data: 0.6.2
    - sympy: 1.12
    - tables: 3.9.2
    - threadpoolctl: 3.5.0
    - torch: 2.3.0
    - torchmetrics: 1.4.0.post0
    - tornado: 6.4.1
    - tqdm: 4.66.4
    - traitlets: 5.14.3
    - typing-extensions: 4.12.1
    - tzdata: 2024.1
    - urllib3: 2.2.1
    - wandb: 0.16.5
    - wcwidth: 0.2.13
    - wheel: 0.43.0
    - zipp: 3.17.0
  • System:
    - OS: Darwin
    - architecture: 64bit
    - processor: arm
    - python: 3.11.9
    - release: 23.5.0
    - version: Darwin Kernel Version 23.5.0: Wed May 1 20:19:05 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T8112

More info

The symptoms of this bug are somewhat similar to those of #16821 and #13525, but based on those threads it seems like the causes may be different.
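
If the cause is indeed a latched trainer.should_stop flag while min_epochs forces training to continue, one possible (untested) workaround is a small callback that clears the flag until min_epochs is reached; since EarlyStopping re-evaluates on every validation run, it should still stop training once min_epochs has elapsed. This sketch assumes the Trainer.should_stop setter accepts being reset from a callback (the hardcoded 2 mirrors the min_epochs used in the repro above):

from lightning.pytorch.callbacks import Callback


class DeferEarlyStop(Callback):
    # Untested workaround sketch: keep should_stop False before min_epochs so the
    # logger falls back to the normal log_every_n_steps cadence.
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if trainer.should_stop and trainer.current_epoch < 2:  # 2 == min_epochs above
            trainer.should_stop = False

It would be added alongside the existing callback, e.g. callbacks = [early_stopping_callback, DeferEarlyStop()].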
