GAN training crashes with unused parameters #20034

Open
samsara-ku opened this issue Jul 1, 2024 · 0 comments
Labels: bug, needs triage


Bug description

My GAN training crashes unless I set strategy="ddp_find_unused_parameters_true". With that strategy enabled, training runs and there appear to be no parameters that fail to participate in training. With it disabled, training always crashes with an error along the lines of "your model has unused parameters ~~~".

The code is attached below. The error seems to be caused by the .detach() call: when I remove it, training runs without any problem. How can I solve this?
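To make the role of .detach() concrete, here is a small plain-PyTorch sketch (no Lightning, no DDP) that lists which parameters received no gradient after each backward pass. The tiny Linear "generator" and "discriminator" and the helper params_without_grad are hypothetical stand-ins, not the modules from the report. Whether this is the exact cause of the crash in the reported setup is an assumption, but it matches the situation that DDP's unused-parameter error describes.

```python
import torch
from torch import nn


def params_without_grad(module: nn.Module) -> list[str]:
    """Names of trainable parameters whose .grad is still None."""
    return [
        name
        for name, p in module.named_parameters()
        if p.requires_grad and p.grad is None
    ]


torch.manual_seed(0)

# Hypothetical stand-ins for the generator and discriminator.
model = nn.ModuleDict(
    {"generator": nn.Linear(4, 4), "discriminator": nn.Linear(4, 1)}
)

x = torch.randn(2, 4)
fake = model["generator"](x)

# Discriminator-style loss on a detached fake: the graph is cut at
# .detach(), so backward never reaches the generator's parameters.
d_loss = model["discriminator"](fake.detach()).mean()
d_loss.backward()
unused_after_d_step = params_without_grad(model)
print(unused_after_d_step)

# Generator-style loss on the non-detached fake: gradients now flow
# through the discriminator and back into the generator.
g_loss = model["discriminator"](fake).mean()
g_loss.backward()
unused_after_g_step = params_without_grad(model)
print(unused_after_g_step)
```

After the first backward only the generator's parameters lack gradients; after the second, every parameter has one. Running a check like this after each manual_backward call is one way to see which parameters a given backward pass skips.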

What version are you seeing the problem on?

v2.2

How to reproduce the bug

def training_step(self, batch, batch_idx):
    # Manual optimization: fetch the generator and discriminator optimizers.
    g_opt, d_opt = self.optimizers()

    src_img, drv_img = batch["src"], batch["drv"]
    gen_img = self.generator(src_img, drv_img)

    # Discriminator step: gen_img is detached, so no gradient reaches the generator here.
    errD = self.gan_loss(drv_img, gen_img.detach(), opt_d=True)["errD"]

    d_opt.zero_grad(set_to_none=True)
    self.manual_backward(errD, retain_graph=True)
    d_opt.step()

    # Generator step: gen_img is not detached, so gradients flow back into the generator.
    gan_loss = self.gan_loss(drv_img, gen_img, opt_d=False)
    perceptual_loss = self.perceptual_loss(drv_img, gen_img)

    errG = (
        gan_loss["errG_GAN"]
        + gan_loss["errG_FM"]
        + perceptual_loss["vgg_imagenet"]
        + perceptual_loss["vgg_face"]
    )

    g_opt.zero_grad(set_to_none=True)
    self.manual_backward(errG)
    g_opt.step()

Error messages and logs

# Error messages and logs here please

Environment

Current environment
Package                      Version
---------------------------- -----------
absl-py                      2.1.0
aiohttp                      3.9.5
aiosignal                    1.3.1
antlr4-python3-runtime       4.9.3
asttokens                    2.4.1
astunparse                   1.6.3
async-timeout                4.0.3
attrs                        23.2.0
backcall                     0.2.0
cachetools                   5.3.3
certifi                      2024.2.2
charset-normalizer           3.3.2
click                        8.1.7
comm                         0.2.2
contourpy                    1.1.1
cycler                       0.12.1
debugpy                      1.6.7
decorator                    5.1.1
docker-pycreds               0.4.0
executing                    2.0.1
facenet-pytorch              2.6.0
filelock                     3.14.0
flatbuffers                  24.3.25
fonttools                    4.53.0
frozenlist                   1.4.1
fsspec                       2024.5.0
gast                         0.4.0
gitdb                        4.0.11
GitPython                    3.1.43
google-auth                  2.30.0
google-auth-oauthlib         1.0.0
google-pasta                 0.2.0
grpcio                       1.64.1
h5py                         3.11.0
idna                         3.7
importlib_metadata           7.1.0
importlib_resources          6.4.0
ipykernel                    6.29.3
ipython                      8.12.2
jedi                         0.19.1
Jinja2                       3.1.4
jupyter_client               8.6.2
jupyter_core                 5.7.2
keras                        2.13.1
kiwisolver                   1.4.5
libclang                     18.1.1
lightning-utilities          0.11.2
Markdown                     3.6
MarkupSafe                   2.1.5
matplotlib                   3.7.5
matplotlib-inline            0.1.7
mpmath                       1.3.0
mtcnn                        0.1.1
multidict                    6.0.5
nest_asyncio                 1.6.0
networkx                     3.1
numpy                        1.24.3
nvidia-cublas-cu12           12.1.3.1
nvidia-cuda-cupti-cu12       12.1.105
nvidia-cuda-nvrtc-cu12       12.1.105
nvidia-cuda-runtime-cu12     12.1.105
nvidia-cudnn-cu12            8.9.2.26
nvidia-cufft-cu12            11.0.2.54
nvidia-curand-cu12           10.3.2.106
nvidia-cusolver-cu12         11.4.5.107
nvidia-cusparse-cu12         12.1.0.106
nvidia-nccl-cu12             2.19.3
nvidia-nvjitlink-cu12        12.5.40
nvidia-nvtx-cu12             12.1.105
oauthlib                     3.2.2
omegaconf                    2.3.0
opencv-python                4.10.0.82
opt-einsum                   3.3.0
packaging                    24.0
parso                        0.8.4
pexpect                      4.9.0
pickleshare                  0.7.5
pillow                       10.2.0
pip                          24.0
platformdirs                 4.2.2
prompt-toolkit               3.0.42
protobuf                     4.25.3
psutil                       5.9.8
ptyprocess                   0.7.0
pure-eval                    0.2.2
pyasn1                       0.6.0
pyasn1_modules               0.4.0
Pygments                     2.18.0
pyparsing                    3.1.2
python-dateutil              2.9.0
pytorch-lightning            2.2.5
PyYAML                       6.0.1
pyzmq                        25.1.2
requests                     2.32.3
requests-oauthlib            2.0.0
rsa                          4.9
scipy                        1.10.1
sentry-sdk                   2.4.0
setproctitle                 1.3.3
setuptools                   69.5.1
six                          1.16.0
slack_sdk                    3.29.0
smmap                        5.0.1
stack-data                   0.6.2
sympy                        1.12.1
tensorboard                  2.13.0
tensorboard-data-server      0.7.2
tensorflow-estimator         2.13.0
tensorflow-io-gcs-filesystem 0.34.0
termcolor                    2.4.0
torch                        2.2.2
torchaudio                   2.3.0
torchmetrics                 1.4.0.post0
torchvision                  0.17.2
tornado                      6.4
tqdm                         4.66.4
traitlets                    5.14.3
triton                       2.2.0
typing_extensions            4.12.0
urllib3                      2.2.1
wandb                        0.17.0
wcwidth                      0.2.13
Werkzeug                     3.0.3
wheel                        0.43.0
wrapt                        1.16.0
yarl                         1.9.4
zipp                         3.17.0

More info

No response
