Releases: Lightning-AI/pytorch-lightning

Patch release v2.3.3

08 Jul 20:42

This release removes the code from the main lightning package that was reported in CVE-2024-5980.

Patch release v2.3.2

04 Jul 09:06
056bb08

Includes a minor bugfix that avoids a conflict between the entrypoint command and that of another package (#20041).

Patch release v2.3.1

27 Jun 17:47
8b69285

Includes minor bugfixes and stability improvements.

Full Changelog: 2.3.0...2.3.1

Lightning v2.3: Tensor Parallelism and 2D Parallelism

13 Jun 21:30
a42484c

Lightning AI is excited to announce the release of Lightning 2.3 ⚡

Did you know? The Lightning philosophy extends beyond a boilerplate-free deep learning framework: We've been hard at work bringing you Lightning Studio. Code together, prototype, train, deploy, host AI web apps. All from your browser, with zero setup.

This release introduces experimental support for Tensor Parallelism and 2D Parallelism, PyTorch 2.3 support, and several bugfixes and stability improvements.

Highlights

Tensor Parallelism (beta)

Tensor parallelism (TP) is a technique that splits up the computation of selected layers across GPUs to save memory and speed up distributed models. To enable TP as well as other forms of parallelism, we introduce a ModelParallelStrategy for both Lightning Trainer and Fabric. Under the hood, TP is enabled through new experimental PyTorch APIs like DTensor and torch.distributed.tensor.parallel.
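To make the idea concrete, here is a minimal single-process sketch of the arithmetic behind column-wise and row-wise sharding, written in plain Python without PyTorch or any real GPUs. It assumes a gated feed-forward block of the same shape as the examples in these notes, stores weights as (in, out) matrices, and simulates two "devices" by slicing those matrices. All helper names (`matmul`, `gated_ffn`, `split_columns`, `split_rows`) are hypothetical and exist only for this illustration; the sum at the end stands in for the all-reduce that a real tensor-parallel run would perform.

```python
import math
import random


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def silu(v):
    """SiLU activation: v * sigmoid(v)."""
    return v / (1.0 + math.exp(-v))


def gated_ffn(x, w1, w2, w3):
    """Compute (silu(x @ w1) * (x @ w3)) @ w2 with (in, out) weight matrices."""
    h1 = matmul(x, w1)
    h3 = matmul(x, w3)
    gated = [[silu(a) * b for a, b in zip(r1, r3)] for r1, r3 in zip(h1, h3)]
    return matmul(gated, w2)


def split_columns(w, parts):
    """Split a matrix along its output features (columns)."""
    step = len(w[0]) // parts
    return [[row[i * step:(i + 1) * step] for row in w] for i in range(parts)]


def split_rows(w, parts):
    """Split a matrix along its input features (rows)."""
    step = len(w) // parts
    return [w[i * step:(i + 1) * step] for i in range(parts)]


def random_matrix(rows, cols, rng):
    return [[rng.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
dim, hidden, world_size = 4, 8, 2  # toy sizes, chosen for illustration only
x = random_matrix(3, dim, rng)
w1 = random_matrix(dim, hidden, rng)
w2 = random_matrix(hidden, dim, rng)
w3 = random_matrix(dim, hidden, rng)

# Unsharded reference result
reference = gated_ffn(x, w1, w2, w3)

# Each simulated device holds half of w1/w3 (by columns) and half of w2 (by rows)
partials = [
    gated_ffn(x, w1_shard, w2_shard, w3_shard)
    for w1_shard, w2_shard, w3_shard in zip(
        split_columns(w1, world_size),
        split_rows(w2, world_size),
        split_columns(w3, world_size),
    )
]

# Summing the partial outputs plays the role of the all-reduce across devices
combined = [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]

max_error = max(
    abs(a - b) for ra, rb in zip(reference, combined) for a, b in zip(ra, rb)
)
print(f"max abs difference vs. unsharded result: {max_error:.2e}")
```

Each simulated device only ever stores half of every weight matrix and half of the hidden activations, which is where the memory saving comes from. Pairing a column-wise split with a following row-wise split means the elementwise activation and gating need no communication, and only the final output has to be summed.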

PyTorch Lightning

Enabling TP in a model with PyTorch Lightning requires you to implement the LightningModule.configure_model() method, in which you convert selected layers of the model to parallelized layers. This is an advanced feature, because it requires a deep understanding of the model architecture. Open the tutorial Studio to learn the basics of Tensor Parallelism.

import lightning as L
from lightning.pytorch.strategies import ModelParallelStrategy
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
from torch.distributed.tensor.parallel import parallelize_module


# 1. Implement the `configure_model()` method in LightningModule
class LitModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        # `FeedForward` is the module defined in the full training example below
        self.model = FeedForward(8192, 8192)

    def configure_model(self):
        # Lightning will set up a `self.device_mesh` for you
        tp_mesh = self.device_mesh["tensor_parallel"]
        # Use PyTorch's distributed tensor APIs to parallelize the model
        plan = {
            "w1": ColwiseParallel(),
            "w2": RowwiseParallel(),
            "w3": ColwiseParallel(),
        }
        parallelize_module(self.model, tp_mesh, plan)

    def training_step(self, batch):
        ...


# 2. Create the strategy
strategy = ModelParallelStrategy()

# 3. Configure devices and set the strategy in Trainer
trainer = L.Trainer(accelerator="cuda", devices=2, strategy=strategy)
trainer.fit(...)

Full training example (requires at least 2 GPUs):

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
from torch.distributed.tensor.parallel import parallelize_module

import lightning as L
from lightning.pytorch.demos.boring_classes import RandomDataset
from lightning.pytorch.strategies import ModelParallelStrategy


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class LitModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = FeedForward(8192, 8192)

    def configure_model(self):
        if self.device_mesh is None:
            return

        # Lightning will set up a `self.device_mesh` for you
        tp_mesh = self.device_mesh["tensor_parallel"]
        # Use PyTorch's distributed tensor APIs to parallelize the model
        plan = {
            "w1": ColwiseParallel(),
            "w2": RowwiseParallel(),
            "w3": ColwiseParallel(),
        }
        parallelize_module(self.model, tp_mesh, plan)

    def training_step(self, batch):
        output = self.model(batch)
        loss = output.sum()
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.model.parameters(), lr=3e-3)

    def train_dataloader(self):
        # Trainer configures the sampler automatically for you such that
        # all batches in a tensor-parallel group are identical
        dataset = RandomDataset(8192, 64)
        return torch.utils.data.DataLoader(dataset, batch_size=8, num_workers=2)


strategy = ModelParallelStrategy()
trainer = L.Trainer(
    accelerator="cuda",
    devices=2,
    strategy=strategy,
    max_epochs=1,
)

model = LitModel()
trainer.fit(model)

trainer.print(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")

Lightning Fabric

Applying TP in a model with Fabric requires you to implement a special function in which you convert selected layers of the model to parallelized layers. This is an advanced feature, because it requires a deep understanding of the model architecture. Open the tutorial Studio to learn the basics of Tensor Parallelism.

import lightning as L
from lightning.fabric.strategies import ModelParallelStrategy
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
from torch.distributed.tensor.parallel import parallelize_module


# 1. Implement the parallelization function for your model
def parallelize_feedforward(model, device_mesh):
    # Lightning will set up a device mesh for you
    tp_mesh = device_mesh["tensor_parallel"]
    # Use PyTorch's distributed tensor APIs to parallelize the model
    plan = {
        "w1": ColwiseParallel(),
        "w2": RowwiseParallel(),
        "w3": ColwiseParallel(),
    }
    parallelize_module(model, tp_mesh, plan)
    return model


# 2. Pass the parallelization function to the strategy
strategy = ModelParallelStrategy(parallelize_fn=parallelize_feedforward)

# 3. Configure devices and set the strategy in Fabric
fabric = L.Fabric(accelerator="cuda", devices=2, strategy=strategy)
fabric.launch()

Full training example (requires at least 2 GPUs):

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
from torch.distributed.tensor.parallel import parallelize_module

import lightning as L
from lightning.pytorch.demos.boring_classes import RandomDataset
from lightning.fabric.strategies import ModelParallelStrategy


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


def parallelize_feedforward(model, device_mesh):
    # Lightning will set up a device mesh for you
    tp_mesh = device_mesh["tensor_parallel"]
    # Use PyTorch's distributed tensor APIs to parallelize the model
    plan = {
        "w1": ColwiseParallel(),
        "w2": RowwiseParallel(),
        "w3": ColwiseParallel(),
    }
    parallelize_module(model, tp_mesh, plan)
    return model


strategy = ModelParallelStrategy(parallelize_fn=parallelize_feedforward)
fabric = L.Fabric(accelerator="cuda", devices=2, strategy=strategy)
fabric.launch()

# Initialize the model
model = FeedForward(8192, 8192)
model = fabric.setup(model)

# Define the optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3)
optimizer = fabric.setup_optimizers(optimizer)

# Define dataset/dataloader
dataset = RandomDataset(8192, 64)
dataloader = torch.utils.data.DataLoader(dataset, batch_si...

Patch release v2.2.5

22 May 17:28
ac3f1ee

PyTorch Lightning + Fabric

Fixed

  • Fixed a matrix shape mismatch issue when running a model loaded from a quantized checkpoint (bitsandbytes) (#19886)

Full Changelog: 2.2.4...2.2.5

Patch release v2.2.4

01 May 22:50

App

Fixed

  • Fixed HTTPClient retry for flow/work queue (#19837)

PyTorch

No Changes.

Fabric

No Changes.

Full Changelog: 2.2.3...2.2.4

Patch release v2.2.3

23 Apr 18:44
f4cd9df

PyTorch

Fixed

  • Fixed WandbLogger.log_hyperparameters() raising an error if hyperparameters are not JSON serializable (#19769)
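
For context on this class of error, the sketch below shows one generic way to make a hyperparameter dictionary JSON-safe before logging it. It is not the change made in #19769 and not WandbLogger's implementation; the helper name `to_json_safe` and the sample `hparams` values are hypothetical and use only the Python standard library.

```python
import json
from pathlib import Path


def to_json_safe(value):
    """Return `value` unchanged if `json` can encode it, otherwise its string form."""
    try:
        json.dumps(value)
        return value
    except (TypeError, ValueError):
        return str(value)


# Hypothetical hyperparameters: a Path and a set are not JSON serializable
hparams = {"lr": 3e-3, "data_dir": Path("data/train"), "tags": {"baseline"}}

safe_hparams = {key: to_json_safe(value) for key, value in hparams.items()}
encoded = json.dumps(safe_hparams)  # succeeds after sanitizing
print(encoded)
```

Values that `json` can already encode pass through untouched, so numeric hyperparameters keep their type; only the offending values fall back to a string representation.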

Fabric

No Changes.

Full Changelog: 2.2.2...2.2.3

Patch release v2.2.2

11 Apr 13:27

PyTorch

Fixed

  • Fixed an issue causing a TypeError when using torch.compile as a decorator (#19627)
  • Fixed a KeyError when saving an FSDP sharded checkpoint and setting save_weights_only=True (#19524)

Fabric

Fixed

  • Fixed an issue causing a TypeError when using torch.compile as a decorator (#19627)
  • Fixed an issue where some model methods couldn't be monkeypatched after being wrapped by Fabric (#19705)
  • Fixed an issue causing weights to be reset in Fabric.setup() when using FSDP (#19755)

Full Changelog: 2.2.1...2.2.2

Contributors

@ankitgola005 @awaelchli @Borda @carmocca @dmitsf @dvoytan-spark @fnhirwa

Patch release v2.2.1

04 Mar 17:21

PyTorch

Fixed

  • Fixed an issue with CSVLogger trying to append to file from a previous run when the version is set manually (#19446)
  • Fixed the divisibility check for Trainer.accumulate_grad_batches and Trainer.log_every_n_steps in ThroughputMonitor (#19470)
  • Fixed support for Remote Stop and Remote Abort with NeptuneLogger (#19130)
  • Fixed infinite recursion error in precision plugin graveyard (#19542)

Fabric

Fixed

  • Fixed an issue with CSVLogger trying to append to file from a previous run when the version is set manually (#19446)

Full Changelog: 2.2.0post...2.2.1

Contributors

@Raalsky @awaelchli @carmocca @Borda

If we forgot someone due to not matching commit email with GitHub account, let us know :]

Minor release correction

12 Feb 19:43