WeisonWEileen/mink_jax


mink_batch

GPU-parallel batched re-implementation of mink, with the same QP semantics and an API that mirrors mink almost 1:1. It is built on MJX (MuJoCo XLA, MuJoCo's JAX port) for vectorized forward kinematics and Jacobians, and on JAX for everything else. Inspired by pyroki for the batched-JAX patterns.

Status: pre-alpha, under construction (see Status / v1 scope).


What this is

Mink solves the following differential-IK quadratic program (QP) for the configuration displacement Δq:

$$\min_{\Delta q} \tfrac{1}{2}\Delta q^{\top} H \Delta q + c^{\top} \Delta q \quad \text{s.t.}\quad G \Delta q \le h$$

mink_batch solves B such QPs in parallel on a single GPU. Every public array gains a leading batch axis; the math is identical. An existing mink script translates roughly as follows:

| mink (NumPy + MuJoCo C) | mink_batch (JAX + MJX) |
| --- | --- |
| cfg = mink.Configuration(model) | cfg = mink_batch.Configuration.create(model, q=q_init), where q_init.shape == (B, nq) |
| cfg.update(q), with q of shape (nq,) | cfg = cfg.update(q), with q of shape (B, nq); returns a new pytree |
| mink.FrameTask(...) | mink_batch.FrameTask.create(...); target pose broadcastable to (B, 7) in wxyz_xyz order |
| mink.solve_ik(cfg, tasks, dt, "daqp") | mink_batch.solve_ik(cfg, tasks, dt, limits=limits, solver="qpax"); returns (B, nv) |
| Result v: (nv,) | Result v: (B, nv) |
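To make "every public array gains a leading batch axis" concrete, here is a minimal sketch of how one task's H and c can be built. It is modeled on mink's per-task objective, minimizing ‖W(J Δq + α e)‖² with W = diag(cost) plus an optional Levenberg-Marquardt term; mink sums these terms over all tasks. This is plain JAX, not mink_batch code: the helper name task_qp_terms is hypothetical, jax.vmap supplies the batch axis, and the constraints (G, h) are left out, so the final solve is only the unconstrained minimizer.

```python
import jax
import jax.numpy as jnp


def task_qp_terms(jac, err, cost, gain=1.0, lm_damping=0.0):
    """H and c for one task: minimize ||W (J dq + gain * e)||^2, W = diag(cost).

    jac: (k, nv) task Jacobian, err: (k,) task error, cost: (k,) weights.
    Hypothetical helper modeled on mink's task objective.
    """
    weighted_jac = cost[:, None] * jac                # W J
    weighted_err = cost * (-gain * err)               # W (-gain * e)
    mu = lm_damping * (weighted_err @ weighted_err)   # Levenberg-Marquardt damping
    nv = jac.shape[1]
    H = weighted_jac.T @ weighted_jac + mu * jnp.eye(nv, dtype=jac.dtype)
    c = -weighted_err @ weighted_jac
    return H, c


# Batch over Jacobians and errors; the cost vector is shared by all B problems.
batched_task_qp_terms = jax.vmap(task_qp_terms, in_axes=(0, 0, None))

B, nv = 8, 6
k_jac, k_err = jax.random.split(jax.random.PRNGKey(0))
jac = jnp.eye(6, nv) + 0.1 * jax.random.normal(k_jac, (B, 6, nv))  # well-conditioned
err = jax.random.normal(k_err, (B, 6))
# FrameTask-style cost: 3 position entries first, then 3 orientation entries.
cost = jnp.concatenate([jnp.full(3, 1.0), jnp.full(3, 0.5)])

H, c = batched_task_qp_terms(jac, err, cost)   # (B, nv, nv), (B, nv)
dq = jax.vmap(jnp.linalg.solve)(H, -c)         # unconstrained minimizer, (B, nv)
```

With lm_damping=0 and a square, invertible J, the minimizer satisfies J Δq = -α e, which matches the task convention mink documents.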

Why

Mink's NumPy, single-instance design is great for one robot in real time, but it becomes a bottleneck for batch workloads such as data generation, RL rollouts, batched trajectory optimization, and multi-start IK seeding. pyroki showed that a JAX-based pipeline can solve 1000 Franka IK problems in ~14 ms on an RTX 4090, but pyroki uses a soft-constraint least-squares formulation and URDF models, giving up mink's hard QP constraints and MJCF compatibility.

mink_batch keeps mink's QP hard-constraint semantics and its MJCF model files, and adds pyroki-style batching and GPU acceleration.

Install (planned)

```bash
cd mink_parrallel
pip install -e .
# JAX + CUDA:
pip install -U "jax[cuda12]"
```

Hard runtime deps:

  • jax>=0.4, jaxlib
  • mujoco>=3.3.6 and mujoco-mjx
  • jaxlie (Lie groups)
  • jax_dataclasses (pytree dataclasses + @jdc.jit)
  • qpax (default batched QP solver)

Soft deps for testing/benchmarking: mink, pytest, pyroki.
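After installing, a quick way to confirm that JAX actually picked up the CUDA backend (and is not silently falling back to CPU, which the Benchmark section shows is far slower) is to ask JAX which backend and devices it sees. This sketch uses only core JAX calls; the exact strings printed depend on your JAX version and hardware.

```python
import jax
import jax.numpy as jnp

backend = jax.default_backend()   # typically "gpu" with the CUDA plugin, "cpu" otherwise
devices = jax.devices()
print(f"jax {jax.__version__}: backend={backend}, devices={devices}")

# Run a tiny computation so a broken CUDA install fails here, not mid-benchmark.
x = jnp.arange(4.0)
dot = float(x @ x)                # 0 + 1 + 4 + 9
print(f"smoke test: {dot}")
```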

Quickstart (B=64 UR5e)

This is identical to the runnable test.py at the project root; copy and paste it, or run python test.py:

```python
import jax
import jax.numpy as jnp
import jaxlie
import mujoco
import numpy as np

import mink_batch as mb

mj_model = mujoco.MjModel.from_xml_path(
    "examples/assets/universal_robots_ur5e/ur5e.xml"
)
B = 64

# Batched home configuration (slightly off-zero to avoid the singular zero pose).
q0_single = jnp.asarray(mj_model.qpos0, dtype=jnp.float32) + 0.1
q0 = jnp.broadcast_to(q0_single, (B, mj_model.nq))
cfg = mb.Configuration.create(mj_model, q=q0)

# Sample B random target poses around the home end-effector pose.
home_pose = cfg._get_transform_frame_to_world_wxyz_xyz("attachment_site", "site")[0]
key = jax.random.PRNGKey(0)
twists = jax.random.normal(key, (B, 6)) * jnp.asarray(
    [0.05, 0.05, 0.05, 0.1, 0.1, 0.1], dtype=jnp.float32
)
targets = (jaxlie.SE3(home_pose) @ jaxlie.SE3.exp(twists)).wxyz_xyz  # (B, 7)

tasks = (
    mb.FrameTask.create(
        "attachment_site", "site",
        position_cost=1.0, orientation_cost=1.0, lm_damping=1e-6,
        transform_target_to_world=targets,
    ),
    mb.PostureTask.create(mj_model, cost=1e-3, target_q=q0_single),
)
velocities = {
    mj_model.joint(i).name: float(jnp.pi)
    for i in range(mj_model.njnt) if mj_model.joint(i).name
}
limits = (
    mb.ConfigurationLimit.create(mj_model),
    mb.VelocityLimit.create(mj_model, velocities=velocities),
)

dt = jnp.float32(0.01)
for step in range(100):
    vel = mb.solve_ik(cfg, tasks, dt, limits=limits, solver="qpax")
    cfg = cfg.integrate_inplace(vel, dt)

final_pose = np.array(cfg._get_transform_frame_to_world_wxyz_xyz("attachment_site", "site"))
pos_err = np.linalg.norm(final_pose[:, 4:] - np.array(targets)[:, 4:], axis=-1)
print(f"Final pos err (median/max mm): "
      f"{np.median(pos_err)*1e3:.2f} / {pos_err.max()*1e3:.2f}")
```

Notes for users translating from mink:

  • Use the .create(...) factory methods (Configuration.create, FrameTask.create, ...). The underlying classes are @jdc.pytree_dataclasses, so calling their constructors directly requires passing every field by name.
  • tasks and limits must be tuples (not lists) so the structure JIT sees is hashable and stable across calls.
  • cfg.integrate(vel, dt) returns a (B, nq) array; cfg.integrate_inplace(vel, dt) returns a new Configuration.
  • VelocityLimit.create takes a velocities mapping of {joint_name: max_vel}, not a flat array.
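The tuple requirement follows from how jax.jit caches compiled code: the cache key includes each argument's pytree structure, and a list and a tuple holding the same arrays have different structures. The plain-JAX sketch below is not mink_batch code; total_cost is a hypothetical stand-in for a jitted solve, and a Python counter records how often JAX traces it.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # bumped only when JAX traces (i.e. compiles) total_cost


@jax.jit
def total_cost(terms):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, skipped on cache hits
    return sum(jnp.sum(t) for t in terms)


a, b = jnp.ones(3), jnp.ones(3)
total_cost((a, b))  # first call with a tuple: traces
total_cost((a, b))  # same tuple structure: reuses the compiled function
total_cost([a, b])  # a list is a different pytree structure: traces again

same_structure = jax.tree_util.tree_structure((a, b)) == jax.tree_util.tree_structure([a, b])
print(trace_count, same_structure)  # 2 False
```

Passing the same container type on every call keeps the trace count at one.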

Status / v1 scope

Implemented (v1 target):

  • Configuration (MJX-backed, batched FK/Jacobian, manifold-aware integrate/differentiate_pos)
  • FrameTask
  • PostureTask
  • DampingTask
  • ConfigurationLimit
  • VelocityLimit
  • solve_ik with qpax backend
  • Parity tests against mink at B=1 and B=8 (15/15 passing)
  • UR5e batched example (examples/arm_ur5e_batch.py)
  • GPU benchmark on RTX 5090 (see Benchmark section)

Not in v1 (planned for v2+):

  • CollisionAvoidanceLimit — needs MJX contact/distance plumbing
  • ComTask, RelativeFrameTask, EqualityConstraintTask, DofFreezingTask, KineticEnergyRegularizationTask
  • Closed-chain / loop-closure equality constraints
  • Alternative QP backends (jaxopt.OSQP, hand-rolled ADMM)
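ConfigurationLimit and VelocityLimit, the two limits in the implemented list, both contribute rows to G Δq ≤ h in the QP. Below is a minimal sketch of that stacking for hinge/slide joints only (where differencing positions is plain subtraction), assuming mink's convention that the decision variable is the displacement Δq = v·dt. The helper names, the shared-limits batching, and the 0.95 gain default are illustrative assumptions, not mink_batch's implementation.

```python
import jax
import jax.numpy as jnp


def configuration_limit_rows(q, lower, upper, gain=0.95):
    """Rows for lower <= q + dq <= upper, shrunk by a gain < 1 (hinge/slide only)."""
    eye = jnp.eye(q.shape[-1], dtype=q.dtype)
    G = jnp.concatenate([eye, -eye], axis=0)                       # (2 nv, nv)
    h = jnp.concatenate([gain * (upper - q), gain * (q - lower)])  # (2 nv,)
    return G, h


def velocity_limit_rows(v_max, dt):
    """Rows for |dq| <= v_max * dt, since the joint velocity is dq / dt."""
    eye = jnp.eye(v_max.shape[-1], dtype=v_max.dtype)
    G = jnp.concatenate([eye, -eye], axis=0)
    h = jnp.concatenate([v_max * dt, v_max * dt])
    return G, h


def stacked_limits(q, lower, upper, v_max, dt):
    G_pos, h_pos = configuration_limit_rows(q, lower, upper)
    G_vel, h_vel = velocity_limit_rows(v_max, dt)
    return jnp.concatenate([G_pos, G_vel]), jnp.concatenate([h_pos, h_vel])


# Batch over configurations; joint limits, velocity limits, and dt are shared.
batched_limits = jax.vmap(stacked_limits, in_axes=(0, None, None, None, None))

B, nv = 4, 6
lower, upper = -jnp.pi * jnp.ones(nv), jnp.pi * jnp.ones(nv)
v_max = jnp.pi * jnp.ones(nv)
q = jax.random.uniform(jax.random.PRNGKey(0), (B, nv), minval=-1.0, maxval=1.0)
G, h = batched_limits(q, lower, upper, v_max, 0.01)  # (B, 4 nv, nv), (B, 4 nv)

# dq = 0 (stay put) is feasible whenever q is strictly inside its limits.
zero_ok = bool(jnp.all(jnp.einsum("bij,bj->bi", G, jnp.zeros((B, nv))) <= h))
```

This sketch ignores free and ball joints, which need manifold-aware differencing rather than subtraction.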

Differences from mink

These are intentional, documented behavioral deltas:

  1. limits is required. Mink defaults to [ConfigurationLimit(model)] when limits=None; we require an explicit tuple (even if empty) so JIT tracing is predictable. Pass limits=() to disable.
  2. tasks and limits must be tuples, not lists. JAX treats the container type as part of the pytree structure, so switching between lists and tuples triggers retracing, and lists are unhashable wherever a static argument is required.
  3. No safety_break. Pre-flight limit checking is not in v1; users can do it host-side.
  4. No implicit equality constraints. v1 does not support mj_makeConstraint or equality tasks, so loop closures in closed-chain models are not enforced. Tracked for v2.
  5. Configuration.update returns a new pytree instead of mutating in place, because JAX arrays are immutable.
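Point 5 in practice: a configuration is an immutable pytree, so "updating" it means building a new one and rebinding the name, which is also what lets it flow through jax.jit. Below is a minimal sketch using a NamedTuple (a built-in pytree) as a hypothetical stand-in for the real @jdc.pytree_dataclass-based Configuration. ToyConfiguration and its methods are illustrative only, and integrate_inplace here handles Euclidean (hinge/slide) joints only.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyConfiguration(NamedTuple):
    """Hypothetical stand-in for a batched configuration pytree."""

    q: jax.Array  # (B, nq)

    def update(self, q):
        # _replace builds a new tuple; self (and its arrays) are left untouched.
        return self._replace(q=q)

    def integrate_inplace(self, v, dt):
        # Euclidean joints only: q_next = q + v * dt, returned as a new object.
        return self.update(self.q + v * dt)


@jax.jit
def step(cfg, v, dt):
    # NamedTuples are pytrees, so they pass through jit with their structure intact.
    return cfg.integrate_inplace(v, dt)


cfg0 = ToyConfiguration(q=jnp.zeros((4, 6)))
cfg1 = step(cfg0, jnp.ones((4, 6)), 0.1)
print(float(cfg0.q.sum()), float(cfg1.q.sum()))  # cfg0 unchanged; cfg1.q is 0.1 everywhere
```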

Same-as-mink:

  • MJCF model files: same format (we reuse mink's example robot models; see MJX gotchas for the scene.xml sensor caveat).
  • Cost vectors: same length and semantics (e.g., FrameTask.cost has length 6, position entries first, then orientation).
  • Manifold-aware joint integration and differentiation for free and ball joints (we mirror mj_integratePos / mj_differentiatePos).

Benchmark

UR5e, 200 IK steps fused into a single jax.lax.scan and compiled with @jdc.jit, so the JIT compile cost is paid once for the whole loop. qpax runs with max_iter=40 and solver_tol=1e-4. Hardware: RTX 5090 (32 GB), CUDA 13.0. Run via python benchmarks/bench_batch.py.

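| B | JIT compile (ms) | Wall time, 200 steps (s) | Per call (ms) | Per problem (μs) |
| --- | --- | --- | --- | --- |
| 1 | 5198 | 0.136 | 0.68 | 681 |
| 16 | 5658 | 0.146 | 0.73 | 45.8 |
| 64 | 5927 | 0.145 | 0.73 | 11.4 |
| 256 | 5959 | 0.148 | 0.74 | 2.9 |
| 1024 | 6356 | 0.154 | 0.77 | 0.75 |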

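Per-call wall time is nearly flat for B from 1 to 1024: kernel-launch overhead and the sequential qpax primal-dual interior-point (PDIP) iterations dominate, while the extra parallel work from a larger batch is nearly free on the GPU. At B=1024 the amortized cost is 0.75 μs per problem.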
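The derived columns follow from the wall time alone: per call = wall / 200 steps, and per problem = per call / B. Recomputing them from the table's wall times reproduces the published values to within rounding (for example about 45.6 μs rather than 45.8 μs at B=16, because the wall column is rounded to the millisecond).

```python
STEPS = 200  # IK steps fused into one lax.scan per benchmark run

# (B, wall seconds for all 200 steps), copied from the benchmark table.
ROWS = [(1, 0.136), (16, 0.146), (64, 0.145), (256, 0.148), (1024, 0.154)]

per_call_ms = {}
per_problem_us = {}
for batch, wall_s in ROWS:
    per_call_ms[batch] = wall_s / STEPS * 1e3                  # one IK step, all B problems
    per_problem_us[batch] = per_call_ms[batch] * 1e3 / batch   # amortized per problem
    print(f"B={batch:5d}  per call {per_call_ms[batch]:.2f} ms  "
          f"per problem {per_problem_us[batch]:.2f} us")
```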

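For reference, the pyroki paper reports ~14 ms per call (≈14 μs per problem) for Franka at B=1000 on an RTX 4090. The numbers are not directly comparable: a call in the table above is a single differential-IK step (one QP), so end-to-end time depends on how many steps a problem needs to converge. Other differences: (1) the UR5e (6 DoF) is slightly smaller than the Franka arm (7 DoF), (2) the RTX 5090 is a newer, faster GPU than the 4090, and (3) qpax's PDIP and jaxls's Levenberg-Marquardt solver do different amounts of inner work per iteration.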

CPU fallback (JAX without CUDA) on the same machine: B=32 takes ~240 ms per call, more than 300x slower than the GPU per-call time in the table. That gap is what motivated this project.
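The benchmark's structure (many IK steps fused into one jax.lax.scan, a single jit compile for the whole loop, vmap over the batch) can be shown without MJX or qpax. The sketch below is a stand-in, not mink_batch code: it uses a hypothetical planar 2-link arm, jax.jacfwd for the Jacobian, and an unconstrained damped-least-squares step in place of the QP, so only the batching and fusion pattern carries over.

```python
import jax
import jax.numpy as jnp

LINK1, LINK2 = 1.0, 1.0          # link lengths of a hypothetical planar 2-link arm
STEPS, DAMPING, GAIN = 200, 1e-3, 0.5


def fk(q):
    """End-effector (x, y) for joint angles q = (q1, q2)."""
    x = LINK1 * jnp.cos(q[0]) + LINK2 * jnp.cos(q[0] + q[1])
    y = LINK1 * jnp.sin(q[0]) + LINK2 * jnp.sin(q[0] + q[1])
    return jnp.stack([x, y])


def ik_step(q, target):
    """One damped-least-squares differential IK step (stand-in for the QP)."""
    err = target - fk(q)
    J = jax.jacfwd(fk)(q)                                          # (2, 2)
    dq = J.T @ jnp.linalg.solve(J @ J.T + DAMPING * jnp.eye(2), err)
    return q + GAIN * dq


@jax.jit
def solve_batch(q0, targets):
    """All STEPS iterations for all B problems, compiled once as one lax.scan."""
    def body(q, _):
        return jax.vmap(ik_step)(q, targets), None

    q_final, _ = jax.lax.scan(body, q0, xs=None, length=STEPS)
    return q_final


B = 1024
q_goal = jax.random.uniform(
    jax.random.PRNGKey(0), (B, 2),
    minval=jnp.array([-1.0, 0.3]), maxval=jnp.array([1.0, 2.0]),
)
targets = jax.vmap(fk)(q_goal)                      # reachable targets, (B, 2)
q0 = jnp.broadcast_to(jnp.array([0.3, 0.5]), (B, 2))
q_final = solve_batch(q0, targets)
err = jnp.linalg.norm(jax.vmap(fk)(q_final) - targets, axis=-1)
print(f"median / max position error: {float(jnp.median(err)):.1e} / {float(err.max()):.1e}")
```

Timing solve_batch after a warm-up call separates compile time from run time, mirroring the JIT and wall columns in the table.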

Roadmap

  • v1.0: every item in the v1 scope list implemented, parity tests passing.
  • v1.1: ComTask, RelativeFrameTask.
  • v2.0: CollisionAvoidanceLimit using MJX contact arrays; equality constraints; closed-chain support.
  • v2.x: alternative solvers (jaxopt.OSQP, ADMM, GPU-LCP) selectable per call.

MJX gotchas (validated in Phase 0)

Findings from the API spike (benchmarks/spike_mjx_api.py):

  1. mjx.put_model rejects mink's scene.xml files because they include mjSENS_GEOMFROMTO sensors (used by the mocap viewer-style examples). Use the bare robot XML (e.g. ur5e.xml) or strip the sensor block. The bare UR5e files (ur5e.xml + meshes + scene_plain.xml) are vendored under examples/assets/universal_robots_ur5e/ (sourced from mink, BSD-3 license).
  2. mjx.jac returns (nv, 3), not (3, nv) like mj_jacSite/mj_jacBody. Transpose before concatenating into the 6×nv mink-shaped Jacobian.
  3. mjx.differentiate_pos and mjx.integrate_pos do not exist. We handwrite both, dispatching on model.jnt_type (free / ball / slide / hinge). Free + ball joints use jaxlie.SO3.exp/.log for the quaternion components.
  4. The FK pipeline is mjx.kinematics followed by mjx.com_pos (matching mink's mj_kinematics + mj_comPos); mjx.fwd_position runs the full position pipeline but does more work than IK needs.
  5. Numerical agreement vs mj_jacSite: max abs delta ≈ 3e-7 on UR5e, well below our 1e-6 target.

License

Apache 2.0 (matches mink). See LICENSE once added.

About

Batched Python inverse kinematics based on MuJoCo (a batch version of mink).
