Can't save models via the ModelCheckpoint() when using custom optimizer #20033

Open · youli-jlu opened this issue on Jul 1, 2024 · 0 comments
Labels: bug (Something isn't working), needs triage (Waiting to be triaged by maintainers)
### Bug description

Dear all,

I want to use a Hessian-Free LM optimizer to replace the PyTorch L-BFGS optimizer. However, the model is not saved when I use ModelCheckpoint(), while torch.save() and Trainer.save_checkpoint() still work. You can find my test Python file below. Could you give me some suggestions on how to handle this problem?

Thanks!
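For context, here is the contrast condensed from the full reproduction script below (the names all come from that script; this is only a summary, not additional code):

```python
# These both produce a file on disk:
torch.save(model.state_dict(), f"model{opt_label}.pth")
Train_model.save_checkpoint(f"model{opt_label}.ckpt")

# This callback, passed to the Trainer, never writes a .ckpt file:
modelck = ModelCheckpoint(dirpath=f"./{opt_label}", monitor="val_loss", save_last=True)
Train_model = Trainer(accelerator="cpu", max_epochs=int(epochs), callbacks=[modelck], logger=logger)
Train_model.fit(model, train_dataloaders=train_dataloader, val_dataloaders=test_dataloader)
```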

### What version are you seeing the problem on?

v2.2

### How to reproduce the bug

```python
import numpy as np
import pandas as pd
import time
import torch
from torch import nn
from torch.utils.data import DataLoader,TensorDataset
import matplotlib.pyplot as plt

import lightning as L
from lightning.pytorch import LightningModule
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from hessianfree.optimizer import HessianFree


class LitModel(LightningModule):
    def __init__(self, loss):
        super().__init__()
        self.tanh_linear = nn.Sequential(
                nn.Linear(1, 20),
                nn.Tanh(),
                nn.Linear(20, 20),
                nn.Tanh(),
                nn.Linear(20, 1),
                )
        self.loss_fn = loss  # use the loss function passed in (MSELoss in main())
        # HessianFree needs a closure, so optimization is done manually in training_step
        self.automatic_optimization = False

    def forward(self, x):
        out = self.tanh_linear(x)
        return out

    def configure_optimizers(self):
        optimizer = HessianFree(
                self.parameters(),
                cg_tol=1e-6,
                cg_max_iter=1000,
                lr=1e0,
                LS_max_iter=1000,
                LS_c=1e-3
                )
        return optimizer

    def training_step(self, batch, batch_idx):
        x, y = batch
        opt = self.optimizers()

        # HessianFree expects a closure that recomputes the forward pass and
        # returns (loss, outputs), so step() is called on the underlying
        # optimizer rather than on the LightningOptimizer wrapper.
        def forward_fn():
            y_pred = self(x)
            loss = self.loss_fn(y_pred, y)
            return loss, y_pred

        opt.optimizer.step(forward=forward_fn)
        loss, y_pred = forward_fn()
        self.log("train_loss", loss, on_epoch=True, on_step=False)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        val_loss = self.loss_fn(y_hat, y)
        # logged so EarlyStopping and ModelCheckpoint can monitor val_loss
        self.log("val_loss", val_loss, on_epoch=True, on_step=False)
        return val_loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        return loss

def main():
    input_size = 20000
    train_size = int(input_size*0.9)
    test_size  = input_size-train_size
    batch_size = 1000

    x_total = np.linspace(-1.0, 1.0, input_size, dtype=np.float32)
    x_total = np.random.choice(x_total, size=input_size, replace=False)  # random sampling
    x_train = x_total[0:train_size].reshape((train_size, 1))
    x_test = x_total[train_size:input_size].reshape((test_size, 1))

    x_train = torch.from_numpy(x_train)
    x_test = torch.from_numpy(x_test)

    y_train = torch.from_numpy(np.sinc(10.0 * x_train))
    y_test = torch.from_numpy(np.sinc(10.0 * x_test))

    training_data = TensorDataset(x_train,y_train)
    test_data = TensorDataset(x_test,y_test)


    # Create data loaders.
    train_dataloader = DataLoader(training_data, batch_size=batch_size
            #,num_workers=2
            )
    test_dataloader = DataLoader(test_data, batch_size=batch_size
            #,num_workers=2
            )

    for X, y in test_dataloader:
        print("Shape of X: ", X.shape)
        print("Shape of y: ", y.shape, y.dtype)
        break
    for X, y in train_dataloader:
        print("Shape of X: ", X.shape)
        print("Shape of y: ", y.shape, y.dtype)
        break




    loss_fn = nn.MSELoss()

    model = LitModel(loss_fn)

    # prepare trainer
    opt_label = "lm_HF_t20"

    logger = CSVLogger(f"./{opt_label}", name=f"test-{opt_label}", flush_logs_every_n_steps=1)
    epochs = 10
    print(f"test for {opt_label}")
    early_stop_callback = EarlyStopping(
            monitor="val_loss"
            , min_delta=1e-9
            , patience=10
            , verbose=False, mode="min"
            , stopping_threshold = 1e-8 #stop if reaching accuracy
            )
    modelck=ModelCheckpoint(
            dirpath = f"./{opt_label}"
            , monitor="val_loss"
            ,save_last = True
            #, save_top_k = 2
            #, mode ='min'
            #, every_n_epochs = 1
            #, save_on_train_epoch_end=True
            #,save_weights_only=True,
            )

    Train_model=Trainer(
            accelerator="cpu"
            , max_epochs = int(epochs)
            , enable_progress_bar = True  # show the progress bar
            #, callbacks=[modelck, early_stop_callback]  # with early stopping
            , callbacks=[modelck]  # without early stopping
            , logger=logger
            #, num_processes = 16
            )

    t1=time.time()
    Train_model.fit(model,train_dataloaders=train_dataloader, val_dataloaders=test_dataloader)
    t2=time.time()

    print('total time')
    print(t2-t1)





    # torch.save() and Trainer.save_checkpoint() can save the model, but ModelCheckpoint() can't.
    #torch.save(model.state_dict(), f"model{opt_label}.pth")
    #print(f"Saved PyTorch Model State to model{opt_label}.pth")
    #Train_model.save_checkpoint(f"model{opt_label}.ckpt")
    #print(f"Saved PL Model State to model{opt_label}.ckpt")
    return

if __name__=='__main__':
    main()
```

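As a temporary workaround I can save checkpoints myself, since Trainer.save_checkpoint() works. The callback below is only my own sketch of that idea (it is not Lightning's ModelCheckpoint, and the class and file names are my own):

```python
from lightning.pytorch.callbacks import Callback


class ManualCheckpoint(Callback):
    """Fallback: save a checkpoint after every validation epoch via Trainer.save_checkpoint()."""

    def __init__(self, dirpath):
        self.dirpath = dirpath

    def on_validation_epoch_end(self, trainer, pl_module):
        if trainer.sanity_checking:  # skip the pre-training sanity-check pass
            return
        # Trainer.save_checkpoint() works in my setup, so call it directly
        # instead of relying on ModelCheckpoint's internal bookkeeping.
        trainer.save_checkpoint(f"{self.dirpath}/manual-epoch{trainer.current_epoch}.ckpt")


# usage: Trainer(..., callbacks=[ManualCheckpoint(f"./{opt_label}")])
```

This does not replace ModelCheckpoint's monitoring of val_loss or its save_last behaviour; it only shows that the manual save path still works.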

### Error messages and logs


The program does not report any error, but ModelCheckpoint() never writes a checkpoint when I use the custom optimizer.
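This is how I check it after Trainer.fit() returns (the names come from my script above; as far as I know, best_model_path stays at its default empty string when nothing was saved):

```python
import os

# run right after Train_model.fit(...) in main()
print("best_model_path:", repr(modelck.best_model_path))   # empty string -> nothing was saved
print("files in dirpath:", os.listdir(f"./{opt_label}"))    # only the CSV logger output, no .ckpt
```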

### Environment

<details>
  <summary>Current environment</summary>

* CUDA:
        - GPU:               None
        - available:         False
        - version:           12.1
* Lightning:
        - backpack-for-pytorch: 1.6.0
        - lightning:         2.2.0
        - lightning-utilities: 0.11.3.post0
        - pytorch-lightning: 2.2.3
        - torch:             2.2.0
        - torchaudio:        2.0.1
        - torchmetrics:      0.11.4
        - torchvision:       0.15.1
* Packages:
        - aiohttp:           3.9.1
        - aiosignal:         1.3.1
        - async-timeout:     4.0.3
        - attrs:             23.2.0
        - backpack-for-pytorch: 1.6.0
        - bottleneck:        1.3.5
        - certifi:           2022.12.7
        - charset-normalizer: 3.1.0
        - cmake:             3.26.0
        - colorama:          0.4.6
        - contourpy:         1.2.1
        - cycler:            0.12.1
        - einops:            0.8.0
        - filelock:          3.10.0
        - fonttools:         4.51.0
        - frozenlist:        1.4.1
        - fsspec:            2023.3.0
        - hessianfree:       0.1
        - idna:              3.4
        - jinja2:            3.1.2
        - kiwisolver:        1.4.5
        - lightning:         2.2.0
        - lightning-utilities: 0.11.3.post0
        - lit:               15.0.7
        - markupsafe:        2.1.2
        - matplotlib:        3.8.4
        - mpmath:            1.3.0
        - multidict:         6.0.4
        - networkx:          3.0
        - numexpr:           2.8.4
        - numpy:             1.24.2
        - nvidia-cublas-cu11: 11.10.3.66
        - nvidia-cublas-cu12: 12.1.3.1
        - nvidia-cuda-cupti-cu11: 11.7.101
        - nvidia-cuda-cupti-cu12: 12.1.105
        - nvidia-cuda-nvrtc-cu11: 11.7.99
        - nvidia-cuda-nvrtc-cu12: 12.1.105
        - nvidia-cuda-runtime-cu11: 11.7.99
        - nvidia-cuda-runtime-cu12: 12.1.105
        - nvidia-cudnn-cu11: 8.5.0.96
        - nvidia-cudnn-cu12: 8.9.2.26
        - nvidia-cufft-cu11: 10.9.0.58
        - nvidia-cufft-cu12: 11.0.2.54
        - nvidia-curand-cu11: 10.2.10.91
        - nvidia-curand-cu12: 10.3.2.106
        - nvidia-cusolver-cu11: 11.4.0.1
        - nvidia-cusolver-cu12: 11.4.5.107
        - nvidia-cusparse-cu11: 11.7.4.91
        - nvidia-cusparse-cu12: 12.1.0.106
        - nvidia-nccl-cu11:  2.14.3
        - nvidia-nccl-cu12:  2.19.3
        - nvidia-nvjitlink-cu12: 12.3.101
        - nvidia-nvtx-cu11:  11.7.91
        - nvidia-nvtx-cu12:  12.1.105
        - packaging:         23.0
        - pandas:            1.5.3
        - pillow:            9.4.0
        - pip:               24.1.1
        - pyparsing:         3.1.2
        - python-dateutil:   2.8.2
        - pytorch-lightning: 2.2.3
        - pytz:              2022.7
        - pyyaml:            6.0
        - requests:          2.28.2
        - setuptools:        67.6.0
        - six:               1.16.0
        - sympy:             1.11.1
        - torch:             2.2.0
        - torchaudio:        2.0.1
        - torchmetrics:      0.11.4
        - torchvision:       0.15.1
        - tqdm:              4.65.0
        - triton:            2.2.0
        - typing-extensions: 4.11.0
        - unfoldnd:          0.2.1
        - urllib3:           1.26.15
        - wheel:             0.40.0
        - yarl:              1.9.4
* System:
        - OS:                Linux
        - architecture:
                - 64bit
                - ELF
        - processor:         x86_64
        - python:            3.10.9
        - release:           3.10.0-862.el7.x86_64
        - version:           #1 SMP Fri Apr 20 16:44:24 UTC 2018

</details>



### More info

_No response_