Bug description
I want to use a Hessian-Free LM optimizer to replace the PyTorch L-BFGS optimizer. However, the model can't be saved if I use ModelCheckpoint(), while torch.save() and Trainer.save_checkpoint() still work. You can find my test Python file below. Could you give me some suggestions for handling this problem?
Thanks!
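If the failure comes from serializing the custom optimizer's state when the checkpoint callback runs (an assumption, since no traceback is shown), one workaround worth trying is the save_weights_only=True option that already appears commented out in the script below; it tells ModelCheckpoint to skip optimizer and callback state entirely:

    modelck = ModelCheckpoint(
        dirpath=f"./{opt_label}",  # same path as in the script below
        monitor="val_loss",
        save_last=True,
        save_weights_only=True,  # skip optimizer state that HessianFree may fail to serialize
    )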
What version are you seeing the problem on?
v2.2
How to reproduce the bug
import numpy as np
import pandas as pd
import time
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
import lightning as L
from lightning.pytorch import LightningModule
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from hessianfree.optimizer import HessianFree


class LitModel(LightningModule):
    def __init__(self, loss):
        super().__init__()
        self.tanh_linear = nn.Sequential(
            nn.Linear(1, 20),
            nn.Tanh(),
            nn.Linear(20, 20),
            nn.Tanh(),
            nn.Linear(20, 1),
        )
        self.loss_fn = nn.MSELoss()
        # Manual optimization: HessianFree needs a closure passed to step().
        self.automatic_optimization = False

    def forward(self, x):
        out = self.tanh_linear(x)
        return out

    def configure_optimizers(self):
        optimizer = HessianFree(
            self.parameters(),
            cg_tol=1e-6,
            cg_max_iter=1000,
            lr=1e0,
            LS_max_iter=1000,
            LS_c=1e-3,
        )
        return optimizer

    def training_step(self, batch, batch_idx):
        x, y = batch
        opt = self.optimizers()

        def forward_fn():
            y_pred = self(x)
            loss = self.loss_fn(y_pred, y)
            return loss, y_pred

        # Unwrap the LightningOptimizer to reach HessianFree's custom step().
        opt.optimizer.step(forward=forward_fn)
        loss, y_pred = forward_fn()
        self.log("train_loss", loss, on_epoch=True, on_step=False)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        val_loss = self.loss_fn(y_hat, y)
        # Passed to early stopping and ModelCheckpoint.
        self.log("val_loss", val_loss, on_epoch=True, on_step=False)
        return val_loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        return loss


def main():
    input_size = 20000
    train_size = int(input_size * 0.9)
    test_size = input_size - train_size
    batch_size = 1000

    x_total = np.linspace(-1.0, 1.0, input_size, dtype=np.float32)
    x_total = np.random.choice(x_total, size=input_size, replace=False)  # random sampling
    x_train = x_total[0:train_size]
    x_train = x_train.reshape((train_size, 1))
    x_test = x_total[train_size:input_size]
    x_test = x_test.reshape((test_size, 1))
    x_train = torch.from_numpy(x_train)
    x_test = torch.from_numpy(x_test)
    y_train = torch.from_numpy(np.sinc(10.0 * x_train))
    y_test = torch.from_numpy(np.sinc(10.0 * x_test))

    training_data = TensorDataset(x_train, y_train)
    test_data = TensorDataset(x_test, y_test)

    # Create data loaders.
    train_dataloader = DataLoader(training_data, batch_size=batch_size)  # num_workers=2
    test_dataloader = DataLoader(test_data, batch_size=batch_size)  # num_workers=2

    for X, y in test_dataloader:
        print("Shape of X: ", X.shape)
        print("Shape of y: ", y.shape, y.dtype)
        break
    for X, y in train_dataloader:
        print("Shape of X: ", X.shape)
        print("Shape of y: ", y.shape, y.dtype)
        break

    loss_fn = nn.MSELoss()
    model = LitModel(loss_fn)

    # Prepare trainer.
    opt_label = "lm_HF_t20"
    logger = CSVLogger(f"./{opt_label}", name=f"test-{opt_label}", flush_logs_every_n_steps=1)
    epochs = 1e1
    print(f"test for {opt_label}")
    early_stop_callback = EarlyStopping(
        monitor="val_loss",
        min_delta=1e-9,
        patience=10,
        verbose=False,
        mode="min",
        stopping_threshold=1e-8,  # stop once this accuracy is reached
    )
    modelck = ModelCheckpoint(
        dirpath=f"./{opt_label}",
        monitor="val_loss",
        save_last=True,
        # save_top_k=2, mode="min", every_n_epochs=1,
        # save_on_train_epoch_end=True, save_weights_only=True,
    )
    Train_model = Trainer(
        accelerator="cpu",
        max_epochs=int(epochs),
        enable_progress_bar=True,  # use the progress bar
        # callbacks=[modelck, early_stop_callback],  # with early stopping
        callbacks=[modelck],  # without early stopping
        logger=logger,
        # num_processes=16,
    )

    t1 = time.time()
    Train_model.fit(model, train_dataloaders=train_dataloader, val_dataloaders=test_dataloader)
    t2 = time.time()
    print('total time')
    print(t2 - t1)

    # torch.save() and Trainer.save_checkpoint() can save the model, but ModelCheckpoint() can't.
    # torch.save(model.state_dict(), f"model{opt_label}.pth")
    # print(f"Saved PyTorch Model State to model{opt_label}.pth")
    # Train_model.save_checkpoint(f"model{opt_label}.ckpt")
    # print(f"Saved PL Model State to model{opt_label}.ckpt")
    exit()
    return


if __name__ == '__main__':
    main()
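Since Trainer.save_checkpoint() reportedly works, a manual checkpoint callback can stand in for ModelCheckpoint while the underlying problem is investigated. This is a minimal sketch, not the library's recommended path; the ManualCheckpoint class and the file name are made up for illustration:

    class ManualCheckpoint(L.Callback):
        # Hypothetical stand-in for ModelCheckpoint: save after every validation run.
        def on_validation_end(self, trainer, pl_module):
            if trainer.sanity_checking:  # skip the pre-training sanity check
                return
            # weights_only=True avoids serializing the custom optimizer's state.
            trainer.save_checkpoint(
                f"./lm_HF_t20/manual-epoch{trainer.current_epoch}.ckpt",
                weights_only=True,
            )

Passing callbacks=[ManualCheckpoint()] to the Trainer in place of modelck reuses exactly the code path (Trainer.save_checkpoint) that the commented-out block at the end of main() shows working.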