Refactor act #68

Merged: alexander-soare merged 28 commits into huggingface:user/rcadene/2024_03_31_remove_torchrl from alexander-soare:refactor_act on Apr 9, 2024
Conversation

@alexander-soare alexander-soare commented Apr 8, 2024

Contributor

Notes for reviewers:

  • I suggest primarily reading lerobot/common/policies/act.py from scratch (not in diff mode), as it's almost a complete rewrite.
  • There are some TODO(now) which need to be resolved before merging remove torchrl into the main branch.
  • Keep in mind while reviewing, there are still things to be done for the refactor but they are being deferred to near future PRs:
  • You should be able to use DATA_DIR=data python lerobot/scripts/eval.py --hub-id lerobot/act_aloha_transfer_cube_human-original_repo eval_episodes=1 rollout_batch_size=1 to evaluate the ported weights from the original repo.
  • So far, I've verified that a model trained on sim_transfer_cube_human can match or surpass the original weights.

Still left to do

  • Reproduce original repo's eval score with original repo's weights (with torchrl)
  • Reproduce original repo's eval score with original repo's weights (using new Aloha Env)
    • Make sure this is reproducible with scripts/configs. Upload converted weights, conversion script, converted stats, stats conversion script, to hub.
  • Train models for one human and one sim dataset, reproducing original results.
    • sim_transfer_cube_human
    • sim_insertion_scripted

Train on LeRobot with:

export DATA_DIR=data

python lerobot/scripts/train.py \
    hydra.job.name=act_aloha_sim_insertion_scripted \
    env=aloha \
    env.task=sim_insertion \
    dataset_id=aloha_sim_insertion_scripted \
    policy=act \
    log_freq=50 \
    eval_freq=2500 \
    rollout_batch_size=20 \
    eval_episodes=20 \
    policy.grad_clip_norm=100 \
    policy.use_vae=true \
    horizon=100 \
    wandb.enable=true \
    hydra.run.dir=outputs/train/act_aloha_sim_insertion_scripted \
    device=cuda \
    offline_steps=80000 \
    prefetch=4 \
    save_model=true \
    save_freq=5000

@alexander-soare alexander-soare marked this pull request as draft April 8, 2024 08:02
@alexander-soare alexander-soare marked this pull request as ready for review April 8, 2024 12:16
@alexander-soare alexander-soare changed the title from "[WIP] Refactor act" to "Refactor act" Apr 8, 2024
@alexander-soare alexander-soare changed the base branch from main to user/rcadene/2024_03_31_remove_torchrl April 8, 2024 12:18
@alexander-soare alexander-soare requested a review from Cadene April 8, 2024 12:29
Comment thread lerobot/common/policies/act/policy.py
Comment thread lerobot/configs/policy/act.yaml
Comment thread lerobot/common/policies/act/policy.py Outdated
Comment on lines +574 to +578
x = self.multihead_attn(
    query=self.maybe_add_pos_embed(x, decoder_pos_embed),
    key=self.maybe_add_pos_embed(encoder_out, encoder_pos_embed),
    value=encoder_out,
)[0]

Contributor

Why [0]?

Contributor Author

ptal

Comment thread lerobot/common/policies/act/policy.py Outdated
if self.normalize_before:
    x = self.norm1(x)
q = k = self.maybe_add_pos_embed(x, decoder_pos_embed)
x = self.self_attn(q, k, value=x)[0]

Contributor

Why [0]?

Contributor Author

ptal
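
For context on the two "Why [0]?" questions: torch.nn.MultiheadAttention returns a tuple of (attention output, attention weights), so indexing with [0] keeps only the output. The sketch below illustrates this with a standalone module; the sizes and variable names are made up for illustration and are not taken from the PR.

```python
import torch
from torch import nn

# Hypothetical sizes, chosen only for illustration.
embed_dim, num_heads = 8, 2
attn = nn.MultiheadAttention(embed_dim, num_heads)

# Default layout is (sequence, batch, embedding) since batch_first=False.
query = torch.randn(5, 1, embed_dim)
memory = torch.randn(7, 1, embed_dim)

# forward() returns a tuple: (attention output, attention weights).
result = attn(query=query, key=memory, value=memory)
attn_output, attn_weights = result

# Indexing with [0] keeps the output and discards the weights.
x = attn(query=query, key=memory, value=memory)[0]
```

If the weights are never used, tuple unpacking such as `x, _ = attn(...)` is an equivalent spelling that makes the discarded value explicit.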

Comment thread lerobot/common/policies/act/policy.py Outdated
Returns:
    A (1, C, H, W) batch of corresponding sinusoidal positional embeddings.
"""
not_mask = torch.ones_like(x[0, [0]])  # (1, H, W)

Contributor

Why x[0, [0]]? Could we do something more readable?

@alexander-soare alexander-soare Apr 8, 2024

Contributor Author

Good shout. This is what I normally do (see revision). Is that more readable for you? Otherwise, I need to do ones and get the dtype and device.
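
The revision itself is not shown in this thread, so the following is only a sketch of the options being discussed, assuming x is a (B, C, H, W) tensor as the docstring suggests. It compares the original indexing with a slice-based spelling and with the explicit "ones plus dtype and device" form mentioned in the reply.

```python
import torch

# Hypothetical (B, C, H, W) feature map; sizes are for illustration only.
x = torch.zeros(2, 3, 4, 5, dtype=torch.float32)

# Original form: integer index on the batch dim, list index on the channel
# dim. The list index keeps a size-1 leading dimension.
a = torch.ones_like(x[0, [0]])  # (1, H, W)

# Slice-based spelling with the same result.
b = torch.ones_like(x[0, :1])  # (1, H, W)

# Explicit construction, which needs dtype and device passed by hand.
c = torch.ones((1, *x.shape[-2:]), dtype=x.dtype, device=x.device)
```

All three produce a (1, H, W) tensor of ones; the ones_like variants inherit dtype and device from x automatically.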

Comment thread lerobot/configs/policy/act.yaml Outdated
Comment thread lerobot/scripts/train.py Outdated
Comment thread lerobot/scripts/train.py
Comment thread lerobot/scripts/train.py Outdated
  dataloader = torch.utils.data.DataLoader(
      dataset,
-     num_workers=4,
+     num_workers=0,

Contributor

remove before merging no?

Contributor Author

Yep sorry. Btw IMO this goes in config.
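
As a rough illustration of the "this goes in config" suggestion, the worker count could be read from a config mapping instead of being hard-coded. The helper name make_dataloader and the keys batch_size and num_workers are hypothetical; they are not the repository's actual config schema.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_dataloader(dataset, cfg):
    """Build a DataLoader whose worker count comes from a config mapping."""
    return DataLoader(
        dataset,
        batch_size=cfg["batch_size"],
        # Fall back to 0 (load in the main process) if the key is absent.
        num_workers=cfg.get("num_workers", 0),
        shuffle=True,
    )


# Hypothetical config values; 0 workers is convenient when debugging.
cfg = {"batch_size": 4, "num_workers": 0}
dataset = TensorDataset(torch.arange(10).float())
dataloader = make_dataloader(dataset, cfg)
```

With this shape, switching between num_workers=0 for debugging and a higher value for training becomes a config override rather than a code change that has to be reverted before merging.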

Comment thread lerobot/scripts/eval.py
Comment thread lerobot/common/policies/act/policy.py
Contributor Author

@Cadene many thanks for the review. Bty

@Cadene Cadene left a comment

Contributor

Thanks for this PR :)
I pushed some changes to user/rcadene/2024_03_31_remove_torchrl.
In particular, I got test_policies.py passing with Aloha/Act.
You will need to solve some non-trivial merge issues.
Don't hesitate to call me so that we can solve them together.
Thanks!

@alexander-soare alexander-soare merged commit 07c28a2 into huggingface:user/rcadene/2024_03_31_remove_torchrl Apr 9, 2024
@alexander-soare alexander-soare deleted the refactor_act branch April 9, 2024 17:55
menhguin pushed a commit to menhguin/lerobot that referenced this pull request Feb 9, 2025
Kalcy-U referenced this pull request in Kalcy-U/lerobot May 13, 2025
ZoreAnuj pushed a commit to luckyrobots/lerobot that referenced this pull request Jul 29, 2025