
Can I nest LightningModules inside child modules? #20053

Open
jackdent opened this issue Jul 5, 2024 · 4 comments
Labels
lightningmodule pl.LightningModule question Further information is requested ver: 2.2.x

Comments

@jackdent

jackdent commented Jul 5, 2024

Bug description

Suppose I have a LightningModule (parent) that contains an nn.Module (child), which in turn contains another LightningModule (grandchild). Calling .log inside the grandchild LightningModule results in the following warning:

You are trying to self.log() but the self.trainer reference is not registered on the model yet. This is most likely because the model hasn't been passed to the Trainer

The trainer is only set on the direct children of the parent LightningModule, not on all of its descendants, since the trainer setter uses self.children() rather than self.modules():

@trainer.setter
def trainer(self, trainer: Optional["pl.Trainer"]) -> None:
    for v in self.children():
        if isinstance(v, LightningModule):
            v.trainer = trainer  # type: ignore[assignment]
    self._trainer = trainer
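
For illustration, here is a minimal sketch in plain PyTorch (the module names are made up) showing why the grandchild is skipped: children() yields only direct submodules, while modules() walks the entire tree:

from torch import nn

parent = nn.Module()
parent.child = nn.Module()
parent.child.grandchild = nn.Linear(1, 1)

# children() stops at the first level, so the grandchild is never visited
print([type(m).__name__ for m in parent.children()])  # ['Module']

# modules() recurses through every descendant (and includes parent itself)
print([type(m).__name__ for m in parent.modules()])   # ['Module', 'Module', 'Linear']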

What version are you seeing the problem on?

master

How to reproduce the bug

# %%

import lightning as L
import torch
from torch import nn


class GrandChild(L.LightningModule):
    def dummy_log(self):
        self.log("foo", 1)


class Child(nn.Module):
    def __init__(self):
        super().__init__()
        self.module = nn.Linear(1, 1)
        self.grandchild = GrandChild()

    def forward(self):
        self.grandchild.dummy_log()
        return 1


class Parent(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.child = Child()

    def training_step(self, batch, batch_idx):
        return self.child()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(torch.randn(10, 1)), batch_size=1
        )


# model
parent = Parent()

# train model
trainer = L.Trainer()
trainer.fit(model=parent)
optimizer = parent.configure_optimizers()

loss = parent.training_step(batch=None, batch_idx=0)

Error messages and logs

You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet. This is most likely because the model hasn't been passed to the `Trainer`

Environment

Current environment
  • CUDA:
    - GPU:
    - NVIDIA A100-SXM4-80GB
    - available: True
    - version: 12.1
  • Lightning:
    - lightning: 2.2.1
    - lightning-utilities: 0.11.2
    - pytorch-lightning: 2.2.1
    - torch: 2.3.1
    - torchmetrics: 1.3.2
    - torchvision: 0.18.1
  • System:
    - OS: Linux
    - architecture:
    - 64bit
    - ELF
    - processor: x86_64
    - python: 3.11.9
    - release: 5.15.0-113-generic
    - version: #123-Ubuntu SMP Mon Jun 10 08:16:17 UTC 2024

More info

No response

cc @carmocca @justusschock @awaelchli @Borda

@jackdent jackdent added bug Something isn't working needs triage Waiting to be triaged by maintainers labels Jul 5, 2024
@awaelchli
Member

awaelchli commented Jul 6, 2024

Hey @jackdent
This is not supported, and as far as I know we don't document any examples of nesting LightningModules like that. Since this goes against the design principles behind LightningModules, we can't really "fix" it. Even if the trainer reference were set, you would still run into design issues where hook calling is no longer well defined.

The LightningModule is meant to be the top-level module that organizes your code; nesting it conceptually does not make sense. I'm making a guess, but the reason you are doing this could be that you'd like to reuse some code you wrote and want to inherit from. That by itself is not a bad idea, but it can be achieved without your child modules being LightningModules. I recommend that you refactor your code so that your LightningModule is the top-level module.

@awaelchli awaelchli added question Further information is requested lightningmodule pl.LightningModule and removed bug Something isn't working needs triage Waiting to be triaged by maintainers ver: 2.2.x labels Jul 6, 2024
@awaelchli awaelchli changed the title LightningModule does not have trainer set when nested inside child module Can I nest LightningModules inside child modules? Jul 6, 2024
@jackdent
Author

jackdent commented Jul 8, 2024

@awaelchli thank you for the response--it's reasonable that this behaviour is unsupported. However, it is worth pointing out that making a LightningModule a direct child of another LightningModule does seem to be well supported (e.g. the code snippet I shared above handles this case explicitly and sets the trainer on all children).

Perhaps it's better to take a step back and explain what I'm trying to accomplish. I want to be able to log intermediate values (e.g. the mean and stddev of some activations in an nn.Module). Being able to introspect operations inside nn.Modules is extremely important, since we regularly need to look into the internal state of the model to find bugs (e.g. we want to track the flow of grad norms through our network).

Since the Lightning logger is only defined on the LightningModule, not on nn.Modules, we can't call log inside an nn.Module. Right now, our solution is to pass the logger as an argument to every forward method in all our child submodules, but that's fairly inelegant. Ideally, we'd just be able to call lightning_logger.log from anywhere inside our child modules--creating a Logger class that inherits from LightningModule and setting that as a child on the nn.Modules was my attempt to achieve the desired behaviour, but I'm open to better solutions if you can think of any (e.g. should Lightning expose a singleton logger?).
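
For context, a minimal sketch of the workaround described above (the names are hypothetical), where the logging callable has to be threaded through every forward call:

from torch import nn


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, x, log_fn):
        h = self.linear(x)
        # every submodule has to accept and forward the logging callable
        log_fn("block/activation_mean", h.mean())
        return h

The parent LightningModule then has to call self.block(x, log_fn=self.log) in its training_step, which is what makes the pattern feel inelegant.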

@awaelchli
Member

awaelchli commented Jul 8, 2024

I don't know why that code for assigning the trainer to children is there, but I have never seen a use case where it can be exploited to great benefit.

Having read about your use case (thanks for the context), I am only more confident that it is the wrong approach. Suppose what you proposed were supported in Lightning. You'd have a PyTorch model with nested layers, some of which are LightningModules. While you might be able to implement your logging strategy this way, what happens when you're done training? Very likely you want to use the model by loading a checkpoint. But now your model contains lots of code that is unrelated to your inference/deployment pipeline. In fact, at inference there won't be a trainer object defined! So you'll have to change/update your model code after the fact anyway and guard all your logging calls. All of this unnecessarily complicates your model code, and it is exactly why model code should not be mixed with orchestration code. It's a trap.

The better way, the Lightning way, is to separate training orchestration code from model code (the definition of your forward). That's what the LightningModule, as a top-level system, is there for:

# The pure nn.Module. Contains PURE modelling code, no training, logging or testing
class MyModel(nn.Module):
    def forward(self, x):
        ...


# The LightningModule contains code for all interactions with your model,
# for example training, evaluation, or inference
class MyTask(LightningModule):
    def __init__(self):
        super().__init__()
        self.model = MyModel()

    # special hooks


# Later on, when we're done training, we can use MyModel directly
# and throw away the LightningModule (no need for it anymore)
model = MyModel()
model.load_state_dict(...)
model(input)

This is the high-level approach behind Lightning's design principles.

To achieve this

Having the ability to Introspect operations inside nn.Modules is extremely important, since we regularly need to link into the internal state of the model to find bugs (e.g. we want to track the flow of grad norms through our network).

there are other ways. I see at least three:

a) Return the debugging information about your intermediate outputs as metadata from your forward call, perhaps as a dict. Collect that output in your training_step(), then process and log it there. Pro: maximum flexibility, no orchestration dependencies. Con: still some debugging-related code tied to your model.

b) Use forward or backward hooks (a PyTorch feature) to collect your intermediate module outputs and do something with them; see the sketch after this list. For example, I've done that in the past for plotting histograms of intermediate outputs. Pro: your model code remains completely clean of any debugging code! Con: less flexible, more code to write.

c) If you can't avoid using a logger in your PyTorch module directly, then I suggest passing it to __init__ or saving it as a reference rather than passing it through forward. You can access the logger in your LightningModule as self.logger. Pro: closest to what you've been doing so far, no new Lightning features required. Con: the logger is tied to your model.
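
To make option b) concrete, here is a minimal sketch (not an official Lightning recipe; MyModel, the layer, and the metric names are made up) that registers a forward hook from the LightningModule, so the hook closes over self.log and the model code stays free of any logging:

import torch
from torch import nn
import lightning as L


class MyModel(nn.Module):
    # pure model code, no logging
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(8, 8)

    def forward(self, x):
        return self.layer(x)


class MyTask(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = MyModel()
        # the hook is a bound method, so it closes over `self` and can call
        # self.log without the nn.Module ever knowing about Lightning
        self.model.layer.register_forward_hook(self._log_activations)

    def _log_activations(self, module, inputs, output):
        self.log("layer/act_mean", output.detach().mean())
        self.log("layer/act_std", output.detach().std())

    def training_step(self, batch, batch_idx):
        x = batch[0]
        out = self.model(x)
        return out.pow(2).mean()  # dummy loss, just for the sketch

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

At inference time you would instantiate MyModel directly (as in the snippet above), so none of the hook or logging code travels with the deployed model.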

I hope one of these fits your needs and you can give it some thought.

@jackdent
Author

jackdent commented Jul 9, 2024

Thank you for the comprehensive answer @awaelchli -- using forward/backward hooks, and passing the trainer/logger through as a closure when defining those hooks at the level of the task, is a great solution (far better than the direction I was going in). Your data monitor snippet is an extremely helpful reference implementation.
