Bug description

As of lightning 2.3.0, save_hyperparameters no longer seems to respect linked arguments.

Based on my investigation this appears to be due to #18105, which caused some other errors that have since been resolved, but as far as I can tell this one persists in the latest 2.4.0 and on the master branch (66508ff).
What version are you seeing the problem on?
v2.3, v2.4, master
How to reproduce the bug
Save the following script as: lightning_cli_save_hyperaparams_error_on_link_args.py
import torch
import torch.nn
from torch import nn
from torch.utils.data import Dataset
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.cli import LightningCLI
from typing import List, Dict


class MWE_Model(pl.LightningModule):
    """
    Example:
        >>> dataset = MWE_Dataset()
        >>> self = MWE_Model(dataset_stats=dataset.dataset_stats)
        >>> batch = [dataset[i] for i in range(2)]
        >>> self.forward(batch)
    """
    def __init__(self, sorting=False, dataset_stats=None, d_model=16):
        super().__init__()
        self.save_hyperparameters()
        if dataset_stats is None:
            raise ValueError('must be given dataset stats')
        self.d_model = d_model
        self.dataset_stats = dataset_stats
        self.known_sensorchan = {
            (mode['sensor'], mode['channels'], mode['num_bands'])
            for mode in self.dataset_stats['known_modalities']
        }
        self.known_tasks = self.dataset_stats['known_tasks']
        if sorting:
            self.known_sensorchan = sorted(self.known_sensorchan)
            self.known_tasks = sorted(self.known_tasks, key=lambda t: t['name'])

        # Construct stems based on the dataset
        self.stems = torch.nn.ModuleDict()
        for sensor, channels, num_bands in self.known_sensorchan:
            if sensor not in self.stems:
                self.stems[sensor] = torch.nn.ModuleDict()
            self.stems[sensor][channels] = torch.nn.Conv2d(num_bands, self.d_model, kernel_size=1)

        # Backbone is a small generic transformer
        self.backbone = torch.nn.Transformer(
            d_model=self.d_model,
            nhead=4,
            num_encoder_layers=2,
            num_decoder_layers=2,
            dim_feedforward=8,
            batch_first=True
        )

        # Construct heads based on the dataset
        self.heads = torch.nn.ModuleDict()
        for head_info in self.known_tasks:
            head_name = head_info['name']
            head_classes = head_info['classes']
            num_classes = len(head_classes)
            self.heads[head_name] = torch.nn.Conv2d(
                self.d_model, num_classes, kernel_size=1)

    @property
    def main_device(self):
        """
        Helper to get a device for the model.
        """
        for key, item in self.state_dict().items():
            return item.device

    def tokenize_inputs(self, item: Dict):
        """
        Process a single batch item's heterogeneous sequence into a flat list
        of tokens for the encoder and decoder.
        """
        device = self.device
        input_sequence = []
        for input_item in item['inputs']:
            stem = self.stems[input_item['sensor_code']][input_item['channel_code']]
            out = stem(input_item['data'])
            tokens = out.view(self.d_model, -1).T
            input_sequence.append(tokens)

        output_sequence = []
        for output_item in item['outputs']:
            shape = tuple(output_item['dims']) + (self.d_model,)
            tokens = torch.rand(shape, device=device).view(-1, self.d_model)
            output_sequence.append(tokens)

        if len(input_sequence) == 0 or len(output_sequence) == 0:
            return None, None

        in_tokens = torch.concat(input_sequence, dim=0)
        out_tokens = torch.concat(output_sequence, dim=0)
        return in_tokens, out_tokens

    def forward(self, batch: List[Dict]) -> List[Dict]:
""" Runs prediction on multiple batch items. The input is assumed to an uncollated list of dictionaries, each containing information about some heterogeneous sequence. The output is a corresponding list of dictionaries containing the logits for each head. """batch_in_tokens= []
batch_out_tokens= []
given_batch_size=len(batch)
valid_batch_indexes= []
# Prepopulate an output for each inputbatch_logits= [{} for_inrange(given_batch_size)]
# Handle heterogeneous style inputs on a per-item levelforbatch_idx, iteminenumerate(batch):
in_tokens, out_tokens=self.tokenize_inputs(item)
ifin_tokensisnotNone:
valid_batch_indexes.append(batch_idx)
batch_in_tokens.append(in_tokens)
batch_out_tokens.append(out_tokens)
# Some batch items might not be validvalid_batch_size=len(valid_batch_indexes)
ifnotvalid_batch_size:
# No inputs were validreturnbatch_logits# Pad everything into a batch to be more efficientpadding_value=-9999.0input_seqs=nn.utils.rnn.pad_sequence(
batch_in_tokens,
batch_first=True,
padding_value=padding_value,
)
output_seqs=nn.utils.rnn.pad_sequence(
batch_out_tokens,
batch_first=True,
padding_value=padding_value,
)
input_masks=input_seqs[..., 0] >padding_valueoutput_masks=output_seqs[..., 0] >padding_valueinput_seqs[~input_masks] =0.output_seqs[~output_masks] =0.decoded=self.backbone(
src=input_seqs,
tgt=output_seqs,
src_key_padding_mask=~input_masks,
tgt_key_padding_mask=~output_masks,
)
B=valid_batch_size# Note output h/w is hardcoded here and uses the fact that the mwe only# has one task; could be generalized.oh, ow=3, 3decoded_features=decoded.view(B, -1, oh, ow, self.d_model)
decoded_masks=output_masks.view(B, -1, oh, ow)
# Reconstruct outputs corresponding to the inputsforbatch_idx, feat, maskinzip(valid_batch_indexes, decoded_features, decoded_masks):
item_feat=feat[mask].view(-1, oh, ow, self.d_model).permute(0, 3, 1, 2)
item_logits=batch_logits[batch_idx]
forhead_name, head_layerinself.heads.items():
head_logits=head_layer(item_feat)
item_logits[head_name] =head_logitsreturnbatch_logitsdefforward_step(self, batch: List[Dict], with_loss=False, stage='unspecified'):
""" Generic forward step used for test / train / validation """batch_logits : List[Dict] =self.forward(batch)
outputs= {}
outputs['logits'] =batch_logitsifwith_loss:
losses= []
valid_batch_size=0foritem, item_logitsinzip(batch, batch_logits):
iflen(item_logits):
valid_batch_size+=1forhead_name, head_logitsinitem_logits.items():
head_target=torch.stack([label['data'] forlabelinitem['labels'] iflabel['head'] ==head_name], dim=0)
# dummy loss functionhead_loss=torch.nn.functional.mse_loss(head_logits, head_target)
losses.append(head_loss)
total_loss=sum(losses) iflen(losses) >0elseNoneiftotal_lossisnotNone:
self.log(f'{stage}_loss', total_loss, prog_bar=True, batch_size=valid_batch_size, sync_dist=True)
outputs['loss'] =total_lossreturnoutputsdeftraining_step(self, batch, batch_idx=None):
outputs=self.forward_step(batch, with_loss=True, stage='train')
ifoutputs['loss'] isNone:
returnNonereturnoutputsdefvalidation_step(self, batch, batch_idx=None):
outputs=self.forward_step(batch, with_loss=True, stage='val')
returnoutputsdeftest_step(self, batch, batch_idx=None):
outputs=self.forward_step(batch, with_loss=True, stage='test')
returnoutputsclassMWE_Dataset(Dataset):
""" A dataset that produces heterogeneous outputs Example: >>> self = MWE_Dataset() >>> self[0] """def__init__(self, max_items_per_epoch=100):
super().__init__()
self.max_items_per_epoch=max_items_per_epochself.rng=np.randomself.dataset_stats= {
'known_modalities': [
{'sensor': 'sensor1', 'channels': 'rgb', 'num_bands': 3, 'dims': (23, 23)},
],
'known_tasks': [
{'name': 'class', 'classes': ['a', 'b', 'c', 'd', 'e'], 'dims': (3, 3)},
]
}
def__len__(self):
returnself.max_items_per_epochdef__getitem__(self, index) ->Dict:
""" Returns: Dict: containing * inputs - a list of observations * outputs - a list of what we want to predict * labels - ground truth if we have it """inputs= []
outputs= []
labels= []
max_timesteps_per_item=5num_frames=max_timesteps_per_itemp_drop_input=0forframe_indexinrange(num_frames):
had_input=0# In general we may have any number of observations per frameformodalityinself.dataset_stats['known_modalities']:
sensor=modality['sensor']
channels=modality['channels']
c=modality['num_bands']
h, w=modality['dims']
# Randomly include each sensorchan on each frameifself.rng.rand() >=p_drop_input:
had_input=1inputs.append({
'type': 'input',
'channel_code': channels,
'sensor_code': sensor,
'frame_index': frame_index,
'data': torch.rand(c, h, w),
})
ifhad_input:
fortask_infoinself.dataset_stats['known_tasks']:
task=task_info['name']
oh, ow=task_info['dims']
oc=len(task_info['classes'])
outputs.append({
'type': 'output',
'head': task,
'frame_index': frame_index,
'dims': (oh, ow),
})
labels.append({
'type': 'label',
'head': task,
'frame_index': frame_index,
'data': torch.rand(oc, oh, ow),
})
item= {
'inputs': inputs,
'outputs': outputs,
'labels': labels,
}
returnitemdefmake_loader(self, batch_size=1, num_workers=0, shuffle=False,
pin_memory=False):
""" Create a dataloader option with sensible defaults for the problem """loader=torch.utils.data.DataLoader(
self, batch_size=batch_size, num_workers=num_workers,
shuffle=shuffle, pin_memory=pin_memory,
collate_fn=lambdax: x
)
returnloaderclassMWE_Datamodule(pl.LightningDataModule):
    def __init__(self, batch_size=1, num_workers=0, max_items_per_epoch=100):
        super().__init__()
        self.save_hyperparameters()
        self.torch_datasets = {}
        self.dataset_stats = None
        self.dataset_kwargs = {
            'max_items_per_epoch': max_items_per_epoch,
        }
        self._did_setup = False

    def setup(self, stage):
        if self._did_setup:
            return
        self.torch_datasets['train'] = MWE_Dataset(**self.dataset_kwargs)
        self.torch_datasets['test'] = MWE_Dataset(**self.dataset_kwargs)
        self.torch_datasets['vali'] = MWE_Dataset(**self.dataset_kwargs)
        self.dataset_stats = self.torch_datasets['train'].dataset_stats
        self._did_setup = True
        print('Setup MWE_Datamodule')
        print(self.__dict__)

    def train_dataloader(self):
        return self._make_dataloader('train', shuffle=True)

    def val_dataloader(self):
        return self._make_dataloader('vali', shuffle=False)

    def test_dataloader(self):
        return self._make_dataloader('test', shuffle=False)

    @property
    def train_dataset(self):
        return self.torch_datasets.get('train', None)

    @property
    def test_dataset(self):
        return self.torch_datasets.get('test', None)

    @property
    def vali_dataset(self):
        return self.torch_datasets.get('vali', None)

    def _make_dataloader(self, stage, shuffle=False):
        loader = self.torch_datasets[stage].make_loader(
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            shuffle=shuffle,
            pin_memory=True,
        )
        return loader


class MWE_LightningCLI(LightningCLI):
    """
    Customized LightningCLI to ensure the expected model inputs / outputs are
    coupled with what the dataset is able to provide.
    """
    def add_arguments_to_parser(self, parser):
        def data_value_getter(key):
            # Hack to call setup on the datamodule before linking args
            def get_value(data):
                if not data._did_setup:
                    data.setup('fit')
                return getattr(data, key)
            return get_value

        # pass dataset stats to model after datamodule initialization
        parser.link_arguments(
            "data",
            "model.dataset_stats",
            compute_fn=data_value_getter('dataset_stats'),
            apply_on="instantiate")
        super().add_arguments_to_parser(parser)


def main():
    MWE_LightningCLI(
        model_class=MWE_Model,
        datamodule_class=MWE_Datamodule,
    )


if __name__ == '__main__':
    """
    CommandLine:
        cd ~/code/geowatch/dev/mwe/
    """
    main()
Apologies for the length of the MWE; it could probably be a few hundred lines shorter, but I had it on hand and it demonstrates the issue well enough. The link_arguments call and the model __init__ are the important parts:
class MWE_LightningCLI(LightningCLI):
    def add_arguments_to_parser(self, parser):
        def data_value_getter(key):
            # Hack to call setup on the datamodule before linking args
            def get_value(data):
                if not data._did_setup:
                    data.setup('fit')
                return getattr(data, key)
            return get_value

        # pass dataset stats to model after datamodule initialization
        parser.link_arguments(
            "data",
            "model.dataset_stats",
            compute_fn=data_value_getter('dataset_stats'),
            apply_on="instantiate")
        super().add_arguments_to_parser(parser)
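To make the failure observable at run time, one option is a small debugging hook that prints what save_hyperparameters() captured. The subclass below is a sketch for illustration only and is not part of the original MWE:

class MWE_Model_Debug(MWE_Model):
    # Hypothetical debugging subclass (an assumption, not in the report).
    def on_fit_start(self):
        # On pytorch_lightning 2.2.5 the linked 'dataset_stats' key is
        # reportedly present in self.hparams; on >= 2.3 it goes missing.
        print('captured hparams keys =', sorted(self.hparams.keys()))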
Given the above script saved as lightning_cli_save_hyperaparams_error_on_link_args.py, I invoke it as:

Error messages and logs

When using pytorch_lightning 2.2.5, running it correctly prints hparams that include the dataset_stats linked arguments. But on the latest master branch and 2.4.0 it incorrectly prints:

Environment

Current environment

More info

No response

Now I understand what the problem is. Later I will think about a proper solution and create a pull request. For the time being, to disable the current behavior you can implement the following in your LightningCLI subclass:
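A minimal sketch of the shape such an override might take, assuming the new behavior is driven by the _add_instantiators hook that #18105 introduced (this sketch is a guess, not the verbatim workaround):

class WorkaroundCLI(LightningCLI):
    # Assumption: LightningCLI._add_instantiators registers the config-based
    # instantiator added in #18105; skipping it should fall back to plain
    # class instantiation, so save_hyperparameters() again captures the
    # linked arguments. This is not the commenter's verbatim code.
    def _add_instantiators(self):
        pass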