Save save_hyperparameters no longer respects linked arguments. #20311

Open
Erotemic opened this issue Sep 30, 2024 · 4 comments
Labels
bug, needs triage, ver: 2.3.x, ver: 2.4.x

Comments

@Erotemic
Contributor

Erotemic commented Sep 30, 2024

Bug description

As of lightning 2.3.0 save_hyperparameters no longer seems to respect linked arguments.

Based on my investigation, this seems to be due to #18105. That change also seems to have caused other errors, which have since been resolved, but as far as I can tell this one persists in the latest release (2.4.0) and on the master branch (66508ff).

What version are you seeing the problem on?

v2.3, v2.4, master

How to reproduce the bug

Save the following script as: lightning_cli_save_hyperaparams_error_on_link_args.py

import torch
import torch.nn
from torch import nn
from torch.utils.data import Dataset
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.cli import LightningCLI
from typing import List, Dict


class MWE_Model(pl.LightningModule):
    """
    Example:
        >>> dataset = MWE_Dataset()
        >>> self = MWE_Model(dataset_stats=dataset.dataset_stats)
        >>> batch = [dataset[i] for i in range(2)]
        >>> self.forward(batch)
    """
    def __init__(self, sorting=False, dataset_stats=None, d_model=16):
        super().__init__()
        self.save_hyperparameters()

        if dataset_stats is None:
            raise ValueError('must be given dataset stats')

        self.d_model = d_model
        self.dataset_stats = dataset_stats

        self.known_sensorchan = {
            (mode['sensor'], mode['channels'], mode['num_bands'])
            for mode in self.dataset_stats['known_modalities']
        }
        self.known_tasks = self.dataset_stats['known_tasks']
        if sorting:
            self.known_sensorchan = sorted(self.known_sensorchan)
            self.known_tasks = sorted(self.known_tasks, key=lambda t: t['name'])

        # Construct stems based on the dataset
        self.stems = torch.nn.ModuleDict()
        for sensor, channels, num_bands in self.known_sensorchan:
            if sensor not in self.stems:
                self.stems[sensor] = torch.nn.ModuleDict()
            self.stems[sensor][channels] = torch.nn.Conv2d(num_bands, self.d_model, kernel_size=1)

        # Backbone is small generic transformer
        self.backbone = torch.nn.Transformer(
            d_model=self.d_model,
            nhead=4,
            num_encoder_layers=2,
            num_decoder_layers=2,
            dim_feedforward=8,
            batch_first=True
        )

        # Construct heads based on the dataset
        self.heads = torch.nn.ModuleDict()
        for head_info in self.known_tasks:
            head_name = head_info['name']
            head_classes = head_info['classes']
            num_classes = len(head_classes)
            self.heads[head_name] = torch.nn.Conv2d(
                self.d_model, num_classes, kernel_size=1)

    @property
    def main_device(self):
        """ Helper to get a device for the model. """
        for key, item in self.state_dict().items():
            return item.device

    def tokenize_inputs(self, item: Dict):
        """
        Process a single batch item's heterogeneous sequence into a flat list
        of tokens for the encoder and decoder.
        """
        device = self.device

        input_sequence = []
        for input_item in item['inputs']:
            stem = self.stems[input_item['sensor_code']][input_item['channel_code']]
            out = stem(input_item['data'])
            tokens = out.view(self.d_model, -1).T
            input_sequence.append(tokens)

        output_sequence = []
        for output_item in item['outputs']:
            shape = tuple(output_item['dims']) + (self.d_model,)
            tokens = torch.rand(shape, device=device).view(-1, self.d_model)
            output_sequence.append(tokens)
        if len(input_sequence) == 0 or len(output_sequence) == 0:
            return None, None
        in_tokens = torch.concat(input_sequence, dim=0)
        out_tokens = torch.concat(output_sequence, dim=0)
        return in_tokens, out_tokens

    def forward(self, batch: List[Dict]) -> List[Dict]:
        """
        Runs prediction on multiple batch items. The input is assumed to be an
        uncollated list of dictionaries, each containing information about some
        heterogeneous sequence. The output is a corresponding list of
        dictionaries containing the logits for each head.
        """
        batch_in_tokens = []
        batch_out_tokens = []

        given_batch_size = len(batch)
        valid_batch_indexes = []

        # Prepopulate an output for each input
        batch_logits = [{} for _ in range(given_batch_size)]

        # Handle heterogeneous style inputs on a per-item level
        for batch_idx, item in enumerate(batch):
            in_tokens, out_tokens = self.tokenize_inputs(item)
            if in_tokens is not None:
                valid_batch_indexes.append(batch_idx)
                batch_in_tokens.append(in_tokens)
                batch_out_tokens.append(out_tokens)

        # Some batch items might not be valid
        valid_batch_size = len(valid_batch_indexes)
        if not valid_batch_size:
            # No inputs were valid
            return batch_logits

        # Pad everything into a batch to be more efficient
        padding_value = -9999.0
        input_seqs = nn.utils.rnn.pad_sequence(
            batch_in_tokens,
            batch_first=True,
            padding_value=padding_value,
        )
        output_seqs = nn.utils.rnn.pad_sequence(
            batch_out_tokens,
            batch_first=True,
            padding_value=padding_value,
        )

        input_masks = input_seqs[..., 0] > padding_value
        output_masks = output_seqs[..., 0] > padding_value
        input_seqs[~input_masks] = 0.
        output_seqs[~output_masks] = 0.

        decoded = self.backbone(
            src=input_seqs,
            tgt=output_seqs,
            src_key_padding_mask=~input_masks,
            tgt_key_padding_mask=~output_masks,
        )
        B = valid_batch_size
        # Note output h/w is hardcoded here and uses the fact that the mwe only
        # has one task; could be generalized.
        oh, ow = 3, 3
        decoded_features = decoded.view(B, -1, oh, ow, self.d_model)
        decoded_masks = output_masks.view(B, -1, oh, ow)

        # Reconstruct outputs corresponding to the inputs
        for batch_idx, feat, mask in zip(valid_batch_indexes, decoded_features, decoded_masks):
            item_feat = feat[mask].view(-1, oh, ow, self.d_model).permute(0, 3, 1, 2)
            item_logits = batch_logits[batch_idx]
            for head_name, head_layer in self.heads.items():
                head_logits = head_layer(item_feat)
                item_logits[head_name] = head_logits
        return batch_logits

    def forward_step(self, batch: List[Dict], with_loss=False, stage='unspecified'):
        """
        Generic forward step used for test / train / validation
        """
        batch_logits : List[Dict] = self.forward(batch)
        outputs = {}
        outputs['logits'] = batch_logits

        if with_loss:
            losses = []
            valid_batch_size = 0
            for item, item_logits in zip(batch, batch_logits):
                if len(item_logits):
                    valid_batch_size += 1
                for head_name, head_logits in item_logits.items():
                    head_target = torch.stack([label['data'] for label in item['labels'] if label['head'] == head_name], dim=0)
                    # dummy loss function
                    head_loss = torch.nn.functional.mse_loss(head_logits, head_target)
                    losses.append(head_loss)
            total_loss = sum(losses) if len(losses) > 0 else None
            if total_loss is not None:
                self.log(f'{stage}_loss', total_loss, prog_bar=True, batch_size=valid_batch_size, sync_dist=True)
            outputs['loss'] = total_loss

        return outputs

    def training_step(self, batch, batch_idx=None):
        outputs = self.forward_step(batch, with_loss=True, stage='train')
        if outputs['loss'] is None:
            return None
        return outputs

    def validation_step(self, batch, batch_idx=None):
        outputs = self.forward_step(batch, with_loss=True, stage='val')
        return outputs

    def test_step(self, batch, batch_idx=None):
        outputs = self.forward_step(batch, with_loss=True, stage='test')
        return outputs


class MWE_Dataset(Dataset):
    """
    A dataset that produces heterogeneous outputs

    Example:
        >>> self = MWE_Dataset()
        >>> self[0]
    """
    def __init__(self, max_items_per_epoch=100):
        super().__init__()
        self.max_items_per_epoch = max_items_per_epoch
        self.rng = np.random
        self.dataset_stats =  {
            'known_modalities': [
                {'sensor': 'sensor1', 'channels': 'rgb', 'num_bands': 3, 'dims': (23, 23)},
            ],
            'known_tasks': [
                {'name': 'class', 'classes': ['a', 'b', 'c', 'd', 'e'], 'dims': (3, 3)},
            ]
        }

    def __len__(self):
        return self.max_items_per_epoch

    def __getitem__(self, index) -> Dict:
        """
        Returns:
            Dict: containing
                * inputs - a list of observations
                * outputs - a list of what we want to predict
                * labels - ground truth if we have it
        """
        inputs = []
        outputs = []
        labels = []
        max_timesteps_per_item = 5
        num_frames = max_timesteps_per_item
        p_drop_input = 0

        for frame_index in range(num_frames):
            had_input = 0
            # In general we may have any number of observations per frame
            for modality in self.dataset_stats['known_modalities']:
                sensor = modality['sensor']
                channels = modality['channels']
                c = modality['num_bands']
                h, w = modality['dims']

                # Randomly include each sensorchan on each frame
                if self.rng.rand() >= p_drop_input:
                    had_input = 1
                    inputs.append({
                        'type': 'input',
                        'channel_code': channels,
                        'sensor_code': sensor,
                        'frame_index': frame_index,
                        'data': torch.rand(c, h, w),
                    })
            if had_input:
                for task_info in self.dataset_stats['known_tasks']:
                    task = task_info['name']
                    oh, ow = task_info['dims']
                    oc = len(task_info['classes'])
                    outputs.append({
                        'type': 'output',
                        'head': task,
                        'frame_index': frame_index,
                        'dims': (oh, ow),
                    })
                    labels.append({
                        'type': 'label',
                        'head': task,
                        'frame_index': frame_index,
                        'data': torch.rand(oc, oh, ow),
                    })
        item = {
            'inputs': inputs,
            'outputs': outputs,
            'labels': labels,
        }
        return item

    def make_loader(self, batch_size=1, num_workers=0, shuffle=False,
                    pin_memory=False):
        """
        Create a dataloader with sensible defaults for the problem
        """
        loader = torch.utils.data.DataLoader(
            self, batch_size=batch_size, num_workers=num_workers,
            shuffle=shuffle, pin_memory=pin_memory,
            collate_fn=lambda x: x
        )
        return loader


class MWE_Datamodule(pl.LightningDataModule):
    def __init__(self, batch_size=1, num_workers=0, max_items_per_epoch=100):
        super().__init__()
        self.save_hyperparameters()
        self.torch_datasets = {}
        self.dataset_stats = None
        self.dataset_kwargs = {
            'max_items_per_epoch': max_items_per_epoch,
        }
        self._did_setup = False

    def setup(self, stage):
        if self._did_setup:
            return
        self.torch_datasets['train'] = MWE_Dataset(**self.dataset_kwargs)
        self.torch_datasets['test'] = MWE_Dataset(**self.dataset_kwargs)
        self.torch_datasets['vali'] = MWE_Dataset(**self.dataset_kwargs)
        self.dataset_stats = self.torch_datasets['train'].dataset_stats
        self._did_setup = True
        print('Setup MWE_Datamodule')
        print(self.__dict__)

    def train_dataloader(self):
        return self._make_dataloader('train', shuffle=True)

    def val_dataloader(self):
        return self._make_dataloader('vali', shuffle=False)

    def test_dataloader(self):
        return self._make_dataloader('test', shuffle=False)

    @property
    def train_dataset(self):
        return self.torch_datasets.get('train', None)

    @property
    def test_dataset(self):
        return self.torch_datasets.get('test', None)

    @property
    def vali_dataset(self):
        return self.torch_datasets.get('vali', None)

    def _make_dataloader(self, stage, shuffle=False):
        loader = self.torch_datasets[stage].make_loader(
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            shuffle=shuffle,
            pin_memory=True,
        )
        return loader


class MWE_LightningCLI(LightningCLI):
    """
    Customized LightningCLI to ensure the expected model inputs / outputs are
    coupled with what the dataset is able to provide.
    """

    def add_arguments_to_parser(self, parser):
        def data_value_getter(key):
            # Hack to call setup on the datamodule before linking args
            def get_value(data):
                if not data._did_setup:
                    data.setup('fit')
                return getattr(data, key)
            return get_value
        # pass dataset stats to model after datamodule initialization
        parser.link_arguments(
            "data",
            "model.dataset_stats",
            compute_fn=data_value_getter('dataset_stats'),
            apply_on="instantiate")
        super().add_arguments_to_parser(parser)


def main():
    MWE_LightningCLI(
        model_class=MWE_Model,
        datamodule_class=MWE_Datamodule,
    )


if __name__ == '__main__':
    """
    CommandLine:
        cd ~/code/geowatch/dev/mwe/

    """
    main()

Apologies for the length of the MWE. It could probably be a few hundred lines shorter, but I had it on hand and it demonstrates the issue well enough. The link_arguments call and the model __init__ are the important parts:

class MWE_LightningCLI(LightningCLI):
    def add_arguments_to_parser(self, parser):
        def data_value_getter(key):
            # Hack to call setup on the datamodule before linking args
            def get_value(data):
                if not data._did_setup:
                    data.setup('fit')
                return getattr(data, key)
            return get_value
        # pass dataset stats to model after datamodule initialization
        parser.link_arguments(
            "data",
            "model.dataset_stats",
            compute_fn=data_value_getter('dataset_stats'),
            apply_on="instantiate")
        super().add_arguments_to_parser(parser)


class MWE_Model(pl.LightningModule):
    def __init__(self, sorting=False, dataset_stats=None, d_model=16):
        super().__init__()
        self.save_hyperparameters()
    ...
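
The compute_fn closure above can be exercised on its own, without Lightning. The following is a minimal sketch, assuming a stand-in object (FakeDatamodule, a hypothetical name that is not part of the MWE) that mimics only the _did_setup flag, the setup method, and the dataset_stats attribute of MWE_Datamodule. It shows that the getter triggers setup lazily, only once, and then returns the requested attribute, which is the value the link is expected to hand to model.dataset_stats.

```python
class FakeDatamodule:
    """Hypothetical stand-in for MWE_Datamodule (illustration only)."""

    def __init__(self):
        self.dataset_stats = None
        self._did_setup = False
        self.setup_calls = 0

    def setup(self, stage):
        # Count calls so we can check that setup only runs once
        self.setup_calls += 1
        self.dataset_stats = {'known_tasks': [{'name': 'class'}]}
        self._did_setup = True


def data_value_getter(key):
    # Same closure as in add_arguments_to_parser above
    def get_value(data):
        if not data._did_setup:
            data.setup('fit')
        return getattr(data, key)
    return get_value


get_stats = data_value_getter('dataset_stats')
data = FakeDatamodule()
first = get_stats(data)   # triggers setup
second = get_stats(data)  # setup is skipped the second time
print(first, data.setup_calls)
```

This only checks the getter itself; it says nothing about when or whether the CLI applies the link, which is the part that appears to have changed.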

Given the above script saved as lightning_cli_save_hyperaparams_error_on_link_args.py, I invoke it as:

DEFAULT_ROOT_DIR=./mwe_train_dir

python lightning_cli_save_hyperaparams_error_on_link_args.py fit --config "
    model:
        sorting: True
    data:
        num_workers: 8
        batch_size: 2
        max_items_per_epoch: 200
    optimizer:
      class_path: torch.optim.Adam
      init_args:
        lr: 1e-7
    trainer:
      default_root_dir     : $DEFAULT_ROOT_DIR
      accelerator          : gpu
      devices              : 1
      max_epochs: 100
"

CKPT_FPATH=$(python -c "import pathlib; print(list(pathlib.Path('$DEFAULT_ROOT_DIR/lightning_logs').glob('*/checkpoints/*.ckpt'))[0])")
HPARAM_FPATH=$(python -c "import pathlib; print(sorted(pathlib.Path('$DEFAULT_ROOT_DIR/lightning_logs').glob('*/hparams.yaml'))[-1])")
cat "$HPARAM_FPATH"
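
The HPARAM_FPATH one-liner above can also be written as a small helper, which is easier to reuse when comparing several Lightning versions. This is a minimal sketch using only pathlib; the function name find_latest_hparams is hypothetical and is introduced here for illustration.

```python
import pathlib


def find_latest_hparams(root_dir):
    """
    Return the last hparams.yaml under <root_dir>/lightning_logs, using the
    same sorted glob as the shell one-liner above, or None if none exists.
    """
    log_dpath = pathlib.Path(root_dir) / 'lightning_logs'
    candidates = sorted(log_dpath.glob('*/hparams.yaml'))
    return candidates[-1] if candidates else None
```

For example, find_latest_hparams('./mwe_train_dir') would return the path that the cat command prints. Note that the sort is lexicographic, so a directory named version_10 sorts before version_2; with many runs in the same root directory, the last entry may not be the most recent one.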

Error messages and logs

When using pytorch_lightning 2.2.5, running:

        HPARAM_FPATH=$(python -c "import pathlib; print(sorted(pathlib.Path('$DEFAULT_ROOT_DIR/lightning_logs').glob('*/hparams.yaml'))[-1])")
        cat "$HPARAM_FPATH"

correctly prints hyperparameters that include the linked dataset_stats argument:

sorting: true
dataset_stats:
  known_modalities:
  - sensor: sensor1
    channels: rgb
    num_bands: 3
    dims:
    - 23
    - 23
  known_tasks:
  - name: class
    classes:
    - a
    - b
    - c
    - d
    - e
    dims:
    - 3
    - 3
d_model: 16
batch_size: 2
num_workers: 8
max_items_per_epoch: 200

But on the latest master branch and 2.4.0 it incorrectly prints the following, which omits dataset_stats and adds an _instantiator entry:

sorting: true
d_model: 16
_instantiator: pytorch_lightning.cli.instantiate_module
batch_size: 2
num_workers: 8
max_items_per_epoch: 200
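
To make the difference between the two outputs explicit, the top-level keys can be compared directly. This is a small sketch with the key sets copied by hand from the two hparams.yaml listings above; the variable names are illustrative.

```python
# Top-level keys copied from the two hparams.yaml outputs above
keys_2_2_5 = {
    'sorting', 'dataset_stats', 'd_model',
    'batch_size', 'num_workers', 'max_items_per_epoch',
}
keys_2_4_0 = {
    'sorting', 'd_model', '_instantiator',
    'batch_size', 'num_workers', 'max_items_per_epoch',
}

missing = sorted(keys_2_2_5 - keys_2_4_0)
added = sorted(keys_2_4_0 - keys_2_2_5)
print('missing in 2.4.0:', missing)  # ['dataset_stats']
print('added in 2.4.0:', added)      # ['_instantiator']
```

The only key that disappears is the one populated through link_arguments, while the keys given directly in the config are saved in both versions.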

Environment

Current environment
  • CUDA:
    • GPU:
      • NVIDIA GeForce RTX 3090
      • NVIDIA GeForce RTX 3090
    • available: True
    • version: 12.4
  • Lightning:
    • lightning: 2.4.0
    • lightning-utilities: 0.11.2
    • perceiver-pytorch: 0.8.3
    • performer-pytorch: 1.0.11
    • pytorch-lightning: 2.4.0
    • pytorch-msssim: 0.1.5
    • pytorch-ranger: 0.1.1
    • reformer-pytorch: 1.4.3
    • torch: 2.4.0+cu124
    • torch-liberator: 0.2.2
    • torch-optimizer: 0.1.0
    • torchaudio: 2.4.0+cu124
    • torchmetrics: 0.11.0
    • torchvision: 0.19.0
  • Packages:
    • absl-py: 1.4.0
    • accelerate: 0.30.1
    • addict: 2.4.0
    • affine: 2.3.0
    • aiobotocore: 2.5.4
    • aiohttp: 3.9.5
    • aiohttp-retry: 2.8.3
    • aioitertools: 0.11.0
    • aiosignal: 1.3.1
    • alabaster: 0.7.16
    • albumentations: 1.0.0
    • amqp: 5.2.0
    • annotated-types: 0.7.0
    • antlr4-python3-runtime: 4.9.3
    • anyio: 4.6.0
    • anytree: 2.12.1
    • appdirs: 1.4.4
    • argcomplete: 3.5.0
    • argo-workflows: 6.5.6
    • arrow: 1.3.0
    • asciitree: 0.3.3
    • astor: 0.8.1
    • astroid: 3.2.2
    • asttokens: 2.4.1
    • astunparse: 1.6.3
    • asyncssh: 2.14.2
    • atomicwrites: 1.4.0
    • atpublic: 4.1.0
    • attrs: 23.2.0
    • auditwheel: 6.1.0
    • autobahn: 24.4.2
    • autodocsumm: 0.2.13
    • automat: 22.10.0
    • autopep8: 2.0.0
    • axial-positional-embedding: 0.2.1
    • babel: 2.15.0
    • backports.tarfile: 1.2.0
    • baron: 0.10.1
    • bashlex: 0.18
    • bcrypt: 4.1.3
    • beautifulsoup4: 4.12.3
    • bidict: 0.23.1
    • billiard: 4.2.0
    • black: 24.4.2
    • blake3: 0.3.1
    • bleach: 6.1.0
    • blinker: 1.8.2
    • boto: 2.49.0
    • boto3: 1.28.17
    • botocore: 1.31.17
    • bpytop: 1.0.68
    • bracex: 2.4
    • brotli: 1.1.0
    • build: 1.2.2
    • cachecontrol: 0.14.0
    • cachetools: 5.4.0
    • celery: 5.4.0
    • certifi: 2024.2.2
    • cffi: 1.16.0
    • cfgv: 3.4.0
    • chardet: 5.2.0
    • charset-normalizer: 2.0.12
    • chromecontroller: 0.3.26
    • cibuildwheel: 2.21.0
    • cleo: 2.1.0
    • click: 8.1.7
    • click-didyoumean: 0.3.1
    • click-plugins: 1.1.1
    • click-repl: 0.3.0
    • cligj: 0.7.2
    • cloudpickle: 3.0.0
    • cmake: 3.29.3
    • cmd-queue: 0.1.21
    • codecarbon: 2.2.4
    • colorama: 0.4.6
    • colormath: 3.0.0
    • colt5-attention: 0.10.20
    • comm: 0.2.2
    • commonmark: 0.9.1
    • configargparse: 1.7
    • configobj: 5.0.8
    • constantly: 23.10.4
    • contourpy: 1.2.1
    • coverage: 7.4.3
    • crashtest: 0.4.1
    • cryptography: 42.0.7
    • cssutils: 2.10.2
    • cycler: 0.12.1
    • cython: 0.29.34
    • dask: 2023.8.1
    • dataframe-image: 0.1.13
    • dataproperty: 1.0.1
    • dbus-python: 1.3.2
    • debugpy: 1.8.2
    • decorator: 5.1.1
    • defusedxml: 0.7.1
    • delayed-image: 0.3.2
    • delorean: 1.0.0
    • detectron2: 0.6
    • diceware: 0.10
    • dictdiffer: 0.9.0
    • diskcache: 5.6.3
    • distinctipy: 1.2.1
    • distlib: 0.3.8
    • distro: 1.9.0
    • docopt: 0.6.2
    • docstring-parser: 0.16
    • docutils: 0.20.1
    • dominate: 2.9.1
    • dpath: 2.1.6
    • dtool-ibeis: 1.1.2
    • dulwich: 0.22.1
    • dvc: 3.51.2
    • dvc-data: 3.15.1
    • dvc-http: 2.32.0
    • dvc-objects: 5.1.0
    • dvc-render: 1.0.2
    • dvc-s3: 3.2.0
    • dvc-ssh: 4.1.1
    • dvc-studio-client: 0.20.0
    • dvc-task: 0.4.0
    • einops: 0.6.0
    • entrypoints: 0.4
    • et-xmlfile: 1.1.0
    • executing: 2.0.1
    • faiss-cpu: 1.8.0
    • fasteners: 0.17.3
    • fastjsonschema: 2.19.1
    • filelock: 3.15.4
    • filterpy: 1.4.5
    • fiona: 1.8.22
    • fire: 0.4.0
    • flake8: 7.0.0
    • flask: 3.0.3
    • flask-basicauth: 0.2.0
    • flask-cors: 3.0.10
    • flask-socketio: 5.3.6
    • flatten-dict: 0.4.2
    • flexcache: 0.3
    • flexparser: 0.3.1
    • flufl.lock: 7.1.1
    • fonttools: 4.51.0
    • frozenlist: 1.4.1
    • fsspec: 2024.6.0
    • funcy: 2.0
    • futures-actors: 0.0.5
    • fuzzywuzzy: 0.18.0
    • fvcore: 0.1.5.post20221221
    • gdal: 3.5.2
    • geodatasets: 2023.12.0
    • geographiclib: 2.0
    • geojson: 3.0.1
    • geomet: 1.1.0
    • geopandas: 0.14.4
    • geopy: 2.4.1
    • geowatch: 0.18.4
    • gevent: 24.2.1
    • girder-client: 3.2.4.dev30+gcacd0e706
    • git-of-theseus: 0.3.4
    • git-python: 1.0.3
    • git-well: 0.2.1
    • gitdb: 4.0.11
    • gitpython: 3.1.43
    • google-api-core: 2.19.0
    • google-api-python-client: 2.130.0
    • google-auth: 2.29.0
    • google-auth-httplib2: 0.2.0
    • google-auth-oauthlib: 1.0.0
    • googleapis-common-protos: 1.63.0
    • grandalf: 0.8
    • graphid: 0.1.0
    • greenlet: 3.0.3
    • grpcio: 1.63.0
    • gto: 1.7.1
    • guitool-ibeis: 2.2.0
    • h11: 0.14.0
    • h3: 3.7.7
    • hardware: 0.31.0
    • hkdf: 0.0.3
    • html2image: 2.0.4.3
    • httpcore: 0.16.3
    • httplib2: 0.22.0
    • httpx: 0.23.3
    • huggingface-hub: 0.23.0
    • humanize: 4.8.0
    • hydra-core: 1.3.2
    • hyperlink: 21.0.0
    • ibeis: 2.3.2
    • identify: 2.6.0
    • idna: 3.7
    • ijson: 3.2.1
    • imageio: 2.34.1
    • imagesize: 1.4.1
    • importlib-metadata: 7.2.1
    • importlib-resources: 6.4.0
    • incremental: 24.7.2
    • iniconfig: 2.0.0
    • installer: 0.7.0
    • instant-rst: 0.9.9.1
    • iopath: 0.1.9
    • ipykernel: 6.29.5
    • ipython: 8.18.1
    • ipython-genutils: 0.2.0
    • isort: 5.13.2
    • iterable-io: 1.0.0
    • iterative-telemetry: 0.0.8
    • itk: 5.4.0
    • itk-core: 5.4.0
    • itk-filtering: 5.4.0
    • itk-io: 5.4.0
    • itk-numerics: 5.4.0
    • itk-registration: 5.4.0
    • itk-segmentation: 5.4.0
    • itsdangerous: 2.2.0
    • jaraco.classes: 3.4.0
    • jaraco.context: 5.3.0
    • jaraco.functools: 4.0.2
    • jedi: 0.19.1
    • jeepney: 0.8.0
    • jellyfin-apiclient-python: 1.9.2
    • jellyfin-migrator: 0.0.0
    • jinja2: 3.1.4
    • jmespath: 1.0.1
    • joblib: 1.4.2
    • johnnydep: 1.20.4
    • jq: 1.7.0
    • jsonargparse: 4.32.1
    • jsonnet: 0.20.0
    • jsonpath: 0.82.2
    • jsonschema: 4.19.2
    • jsonschema-specifications: 2023.12.1
    • jupyter-client: 8.6.1
    • jupyter-core: 5.7.2
    • jupyterlab-pygments: 0.3.0
    • kafka-python: 2.0.2
    • keyring: 24.3.1
    • kiwisolver: 1.4.5
    • kombu: 5.3.7
    • kornia: 0.6.8
    • kornia-rs: 0.1.3
    • kubernetes: 29.0.0
    • kwalop: 0.1.0
    • kwarray: 0.6.19
    • kwcoco: 0.8.5
    • kwcoco-explorer: 0.0.1
    • kwgis: 0.1.1
    • kwimage: 0.10.1
    • kwimage-ext: 0.2.1
    • kwplot: 0.5.2
    • kwutil: 0.3.3
    • lark: 1.1.7
    • lark-cython: 0.0.15
    • lazy-loader: 0.3
    • levenshtein: 0.25.1
    • liberator: 0.1.0
    • lightning: 2.4.0
    • lightning-utilities: 0.11.2
    • line-profiler: 4.1.3
    • linkify-it-py: 2.0.3
    • lit: 18.1.4
    • livereload: 2.7.0
    • llvmlite: 0.42.0
    • local-attention: 1.9.1
    • locket: 1.0.0
    • lockfile: 0.12.2
    • logmatic-python: 0.1.7
    • lxml: 4.9.2
    • magic-wormhole: 0.14.0
    • markdown: 3.6
    • markdown-it-py: 3.0.0
    • markupsafe: 2.1.5
    • mathutf: 0.1.0
    • matplotlib: 3.8.2
    • matplotlib-inline: 0.1.7
    • maturin: 1.7.4
    • mbstrdecoder: 1.1.3
    • mccabe: 0.7.0
    • mdit-py-plugins: 0.4.1
    • mdurl: 0.1.2
    • mgrs: 1.4.6
    • mistune: 3.0.2
    • mkinit: 1.1.0
    • mmcv: 2.0.0
    • mmengine: 0.10.4
    • monai: 0.8.0
    • more-itertools: 8.12.0
    • mpmath: 1.3.0
    • msgpack: 1.0.8
    • multidict: 6.0.5
    • munch: 4.0.0
    • mutagen: 1.47.0
    • mypy: 1.10.0
    • mypy-extensions: 1.0.0
    • myst-parser: 3.0.1
    • nbclient: 0.10.0
    • nbconvert: 7.16.4
    • nbformat: 5.10.4
    • ndsampler: 0.7.9
    • nest-asyncio: 1.6.0
    • netharn: 0.6.2
    • networkx: 3.3
    • networkx-algo-common-subtree: 0.2.1
    • nh3: 0.2.18
    • nodeenv: 1.9.1
    • nrtk: 0.11.0
    • nrtk-explorer: 0.3.0
    • numba: 0.59.1
    • numcodecs: 0.13.0
    • numexpr: 2.8.4
    • numpy: 1.25.2
    • nvidia-cublas-cu11: 11.10.3.66
    • nvidia-cublas-cu12: 12.4.2.65
    • nvidia-cuda-cupti-cu11: 11.7.101
    • nvidia-cuda-cupti-cu12: 12.4.99
    • nvidia-cuda-nvrtc-cu11: 11.7.99
    • nvidia-cuda-nvrtc-cu12: 12.4.99
    • nvidia-cuda-runtime-cu11: 11.7.99
    • nvidia-cuda-runtime-cu12: 12.4.99
    • nvidia-cudnn-cu11: 8.5.0.96
    • nvidia-cudnn-cu12: 9.1.0.70
    • nvidia-cufft-cu11: 10.9.0.58
    • nvidia-cufft-cu12: 11.2.0.44
    • nvidia-curand-cu11: 10.2.10.91
    • nvidia-curand-cu12: 10.3.5.119
    • nvidia-cusolver-cu11: 11.4.0.1
    • nvidia-cusolver-cu12: 11.6.0.99
    • nvidia-cusparse-cu11: 11.7.4.91
    • nvidia-cusparse-cu12: 12.3.0.142
    • nvidia-nccl-cu11: 2.14.3
    • nvidia-nccl-cu12: 2.20.5
    • nvidia-nvjitlink-cu12: 12.4.99
    • nvidia-nvtx-cu11: 11.7.91
    • nvidia-nvtx-cu12: 12.4.99
    • oauthlib: 3.2.2
    • omegaconf: 2.3.0
    • openapi-python-client: 0.20.0
    • openapi-python-generator: 0.5.0
    • openapi-schema-pydantic: 1.2.4
    • opencv-python-headless: 4.10.0.84
    • openpyxl: 3.0.9
    • opentimestamps: 0.4.5
    • opentimestamps-client: 0.7.1
    • ordered-set: 4.1.0
    • orjson: 3.10.3
    • osmnx: 1.9.4
    • oyaml: 1.0
    • packaging: 24.1
    • pandas: 1.5.3
    • pandocfilters: 1.5.1
    • parse: 1.19.0
    • parso: 0.8.4
    • partd: 1.4.2
    • pathspec: 0.12.1
    • pathvalidate: 3.2.1
    • patsy: 0.5.6
    • pbr: 6.0.0
    • pendulum: 3.0.0
    • perceiver-pytorch: 0.8.3
    • performer-pytorch: 1.0.11
    • pexpect: 4.9.0
    • pillow: 10.3.0
    • pint: 0.24.3
    • pip: 24.2
    • pkginfo: 1.10.0
    • platformdirs: 3.11.0
    • plotly: 5.24.0
    • plottool-ibeis: 2.3.0
    • pls-dont-shadow-me: 1.0.0
    • pluggy: 1.5.0
    • pockets: 0.9.1
    • poetry: 1.8.3
    • poetry-core: 1.9.0
    • poetry-plugin-export: 1.8.0
    • pooch: 1.8.2
    • portalocker: 2.10.1
    • portion: 2.4.1
    • pre-commit: 3.8.0
    • prettytable: 3.11.0
    • product-key-memory: 0.2.2
    • progiter: 2.0.0
    • prometheus-client: 0.20.0
    • prompt-toolkit: 3.0.43
    • proto-plus: 1.23.0
    • protobuf: 4.25.3
    • psutil: 5.9.6
    • psycopg2-binary: 2.9.5
    • ptyprocess: 0.7.0
    • pure-eval: 0.2.2
    • purepy-root-demo-pkg-lzcsutvo: 1.0.0
    • purepy-src-demo-pkg: 1.0.0
    • purepy-src-demo-pkg-dbrmcjpb: 1.0.0
    • purepy-src-demo-pkg-lzcsutvo: 1.0.0
    • py-cpuinfo: 9.0.0
    • pyasn1: 0.6.0
    • pyasn1-modules: 0.4.0
    • pybsm: 0.5.1
    • pycocotools: 2.0.7
    • pycodestyle: 2.11.1
    • pycparser: 2.22
    • pycryptodomex: 3.20.0
    • pydantic: 2.7.1
    • pydantic-core: 2.18.2
    • pydot: 2.0.0
    • pyelftools: 0.31
    • pyfiglet: 1.0.2
    • pyflakes: 3.2.0
    • pyflann-ibeis: 2.4.0
    • pygame: 2.6.0
    • pygit2: 1.15.0
    • pygments: 2.18.0
    • pygraphviz: 1.13
    • pygtrie: 2.5.0
    • pyhesaff: 2.1.1
    • pylatex: 0+untagged.769.gb48e8ec
    • pylatexenc: 3.0a29
    • pymongo: 3.13.0
    • pynacl: 1.5.0
    • pynmea2: 1.19.0
    • pynndescent: 0.5.12
    • pynvim: 0.5.0
    • pynvml: 11.5.0
    • pyo3-example: 0.1.0
    • pyopenssl: 24.1.0
    • pyparsing: 3.1.2
    • pyperclip: 1.8.2
    • pypistats: 1.6.0
    • pypng: 0.20220715.0
    • pypogo: 0.1.0
    • pyproj: 3.4.1
    • pyproject-api: 1.7.1
    • pyproject-hooks: 1.1.0
    • pyqrcode: 1.2.1
    • pyqt5: 5.15.10
    • pyqt5-qt5: 5.15.2
    • pyqt5-sip: 12.13.0
    • pyqtree: 1.0.0
    • pysocks: 1.7.1
    • pystac: 1.10.1
    • pystac-client: 0.8.1
    • pytablewriter: 1.2.0
    • pytest: 8.0.2
    • pytest-cov: 5.0.0
    • pytest-subtests: 0.13.1
    • python-bitcoinlib: 0.12.2
    • python-dateutil: 2.9.0.post1.dev3+g9eaa5de
    • python-engineio: 4.9.1
    • python-gitlab: 4.6.0
    • python-json-logger: 2.0.7
    • python-levenshtein: 0.25.1
    • python-slugify: 8.0.4
    • python-socketio: 5.11.3
    • pytimeparse: 1.1.8
    • pytorch-lightning: 2.4.0
    • pytorch-msssim: 0.1.5
    • pytorch-ranger: 0.1.1
    • pytz: 2024.1
    • pywavelets: 1.6.0
    • pyyaml: 6.0.1
    • pyzmq: 26.0.3
    • quantities: 0.15.0
    • rapidfuzz: 3.9.1
    • rasterio: 1.3.10
    • readme-renderer: 44.0
    • reconplogger: 4.16.1
    • redbaron: 0.9.2
    • referencing: 0.35.1
    • reformer-pytorch: 1.4.3
    • regex: 2024.5.10
    • requests: 2.32.2
    • requests-oauthlib: 2.0.0
    • requests-toolbelt: 1.0.0
    • responses: 0.25.3
    • rfc3986: 1.5.0
    • rgd-client: 0.2.7
    • rgd-imagery-client: 0.2.7
    • rich: 12.5.1
    • rich-argparse: 1.1.0
    • rpds-py: 0.18.1
    • rply: 0.7.8
    • rsa: 4.9
    • rtree: 1.0.1
    • ruamel.yaml: 0.17.32
    • ruamel.yaml.clib: 0.2.8
    • ruff: 0.4.5
    • ruyaml: 0.91.0
    • s3fs: 2024.6.0
    • s3transfer: 0.6.2
    • s5cmd: 0.2.0
    • safer: 4.12.3
    • safetensors: 0.4.3
    • scikit-build: 0.17.6
    • scikit-image: 0.21.0
    • scikit-learn: 1.5.1
    • scipy: 1.14.0
    • scmrepo: 3.3.5
    • scriptconfig: 0.7.16
    • seaborn: 0.13.2
    • secretstorage: 3.3.3
    • semver: 3.0.2
    • service-identity: 24.1.0
    • setuptools: 67.7.2
    • shapely: 2.0.1
    • shellingham: 1.5.4
    • shitspotter: 0.0.1
    • shortuuid: 1.0.13
    • shtab: 1.7.1
    • simple-dvc: 0.2.2
    • simple-websocket: 1.0.0
    • simpleitk: 2.3.1
    • simplejson: 3.19.2
    • simplekml: 1.3.3
    • six: 1.16.0
    • smartflow: 3.1.3
    • smmap: 5.0.1
    • smqtk-classifier: 0.19.0
    • smqtk-core: 0.19.0
    • smqtk-dataprovider: 0.18.0
    • smqtk-descriptors: 0.19.0
    • smqtk-detection: 0.20.1
    • smqtk-image-io: 0.17.1
    • smqtk-indexing: 0.18.0
    • smqtk-iqr: 0.15.1
    • smqtk-relevancy: 0.17.0
    • sniffio: 1.3.1
    • snowballstemmer: 2.2.0
    • snuggs: 1.4.7
    • sortedcontainers: 2.4.0
    • soupsieve: 2.5
    • spake2: 0.8
    • sphinx: 7.3.7
    • sphinx-autoapi: 3.1.1
    • sphinx-autobuild: 2024.4.16
    • sphinx-autodoc-typehints: 2.3.0
    • sphinx-reredirects: 0.1.3
    • sphinx-rtd-theme: 2.0.0
    • sphinxcontrib-applehelp: 1.0.8
    • sphinxcontrib-devhelp: 1.0.6
    • sphinxcontrib-htmlhelp: 2.0.5
    • sphinxcontrib-jquery: 4.1
    • sphinxcontrib-jsmath: 1.0.1
    • sphinxcontrib-napoleon: 0.7
    • sphinxcontrib-qthelp: 1.0.7
    • sphinxcontrib-serializinghtml: 1.1.10
    • sqlalchemy: 1.4.50
    • sqlalchemy-utils: 0.41.2
    • sqltrie: 0.11.0
    • sshfs: 2024.4.1
    • stack-data: 0.6.3
    • starlette: 0.37.2
    • statsmodels: 0.14.2
    • structlog: 24.2.0
    • sympy: 1.12
    • tabledata: 1.3.3
    • tabulate: 0.9.0
    • tcolorpy: 0.1.6
    • tempenv: 0.2.0
    • tenacity: 9.0.0
    • tensorboard: 2.14.0
    • tensorboard-data-server: 0.7.2
    • tensorrt-bindings: 8.6.1
    • tensorrt-cu12: 10.0.1
    • tensorrt-cu12-bindings: 10.0.1
    • tensorrt-cu12-libs: 10.0.1
    • tensorrt-libs: 8.6.1
    • termcolor: 2.4.0
    • text-unidecode: 1.3
    • textual: 0.1.18
    • threadpoolctl: 3.5.0
    • tifffile: 2024.5.22
    • timerit: 1.1.0
    • timezonefinder: 6.5.2
    • timm: 0.6.13
    • tinycss2: 1.3.0
    • tokenizers: 0.15.2
    • toml: 0.10.2
    • tomli: 2.0.1
    • tomlkit: 0.12.5
    • toolz: 0.12.1
    • torch: 2.4.0+cu124
    • torch-liberator: 0.2.2
    • torch-optimizer: 0.1.0
    • torchaudio: 2.4.0+cu124
    • torchmetrics: 0.11.0
    • torchvision: 0.19.0
    • tornado: 6.4
    • tox: 4.17.1
    • tqdm: 4.64.1
    • traitlets: 5.14.3
    • trame: 3.6.1
    • trame-client: 3.0.3
    • trame-plotly: 3.0.2
    • trame-quasar: 0.2.1
    • trame-server: 3.0.1
    • trame-vuetify: 2.7.0
    • transformers: 4.37.2
    • triton: 3.0.0
    • trove-classifiers: 2024.9.12
    • twine: 5.1.1
    • twisted: 24.3.0
    • txaio: 23.1.1
    • txtorcon: 23.11.0
    • typepy: 1.3.2
    • typer: 0.12.3
    • types-python-dateutil: 2.9.0.20240316
    • types-pyyaml: 6.0.12.20240808
    • types-requests: 2.32.0.20240907
    • types-setuptools: 70.0.0.20240524
    • typeshed-client: 2.5.1
    • typing-extensions: 4.11.0
    • tzdata: 2024.1
    • tzlocal: 5.2
    • ubelt: 1.3.6
    • uc-micro-py: 1.0.3
    • ujson: 5.6.0
    • umap-learn: 0.5.6
    • uncertainties: 3.2.2
    • uritemplate: 4.1.1
    • uritools: 4.0.2
    • urllib3: 1.26.20
    • utm: 0.7.0
    • utool: 2.2.0
    • uv: 0.3.4
    • uvicorn: 0.29.0
    • validators: 0.28.1
    • vimtk: 0.5.0
    • vine: 5.1.0
    • virtualenv: 20.26.3
    • voluptuous: 0.14.2
    • vtool-ibeis: 2.3.0
    • vtool-ibeis-ext: 0.1.1
    • watchfiles: 0.21.0
    • wcmatch: 8.5.2
    • wcwidth: 0.2.13
    • webencodings: 0.5.1
    • websocket-client: 1.8.0
    • websockets: 12.0
    • werkzeug: 3.0.4
    • wheel: 0.40.0
    • wimpy: 0.6
    • wrapt: 1.14.1
    • wslink: 2.0.4
    • wsproto: 1.2.0
    • xarray: 0.17.0
    • xcookie: 0.2.2
    • xdev: 1.5.2
    • xdoctest: 1.1.5
    • xinspect: 0.2.0
    • xmltodict: 0.12.0
    • xxhash: 3.4.1
    • yacs: 0.1.8
    • yapf: 0.40.2
    • yarl: 1.9.4
    • yt-dlp: 2024.8.6
    • zarr: 2.18.2
    • zc.lockfile: 3.0.post1
    • zipp: 3.18.1
    • zipstream-ng: 1.7.1
    • zope.event: 5.0
    • zope.interface: 6.4.post2
  • System:

More info

No response

@Erotemic Erotemic added bug Something isn't working needs triage Waiting to be triaged by maintainers labels Sep 30, 2024
@mauvilsa
Contributor

mauvilsa commented Oct 5, 2024

Was the reproduction script supposed to be attached? I don't see it.

@Erotemic
Contributor Author

Erotemic commented Oct 5, 2024

The MWE is in the collapsed details section of the issue description; click the arrow to expand it. For convenience, here it is as a gist as well: https://gist.github.com/Erotemic/dfdbf192004486e9f108b0334dd7fdcd

@noamsgl

noamsgl commented Oct 27, 2024

I am also affected by this issue... :(

@mauvilsa
Contributor

Now I understand what the problem is. I will think about a proper solution later and create a pull request. For the time being, to disable the current behavior, you can implement the following in your LightningCLI subclass:

    def _add_instantiators(self):
        pass
