GPU-parallel batched re-implementation of mink, with the same QP semantics and an API that mirrors mink almost 1:1. Built on MJX (MuJoCo XLA) for vectorized FK/Jacobians and on JAX for everything else. Inspired by pyroki's batched-JAX patterns.
Status: pre-alpha, under construction (see Status / v1 scope).
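Mink solves a differential IK QP at each control step:

$$
\min_{\Delta q}\ \sum_i \tfrac{1}{2}\,\bigl\lVert J_i(q)\,\Delta q + \alpha_i\, e_i(q)\bigr\rVert^2_{W_i}
\quad \text{s.t.} \quad G(q)\,\Delta q \le h(q),
$$

where each task $i$ contributes an error $e_i$, a Jacobian $J_i$, a gain $\alpha_i$ and a cost weight $W_i$, the limits stack into $G\,\Delta q \le h$, and the returned velocity is $v = \Delta q / dt$.

mink_batch solves B such QPs in parallel on a single GPU. Every public array gains a leading batch axis; the math is identical. You can take an existing mink script and roughly translate it as follows: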
| mink (NumPy + MuJoCo C) | mink_batch (JAX + MJX) |
|---|---|
| `cfg = mink.Configuration(model)` | `cfg = mink_batch.Configuration.create(model, q=q_init)`, where `q_init.shape == (B, nq)` |
| `cfg.update(q)` (shape `(nq,)`, mutates in place) | `cfg = cfg.update(q)` (shape `(B, nq)`, returns a new pytree) |
| `mink.FrameTask(...)` | `mink_batch.FrameTask.create(...)`, target pose broadcastable to `(B, 7)` wxyz_xyz |
| `mink.solve_ik(cfg, tasks, dt, "daqp")` | `mink_batch.solve_ik(cfg, tasks, dt, limits=limits, solver="qpax")` |
| Result `v`: `(nv,)` | Result `v`: `(B, nv)` |
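For a differential-IK objective of the form $\sum_i \tfrac12\lVert J_i\Delta q + \alpha_i e_i\rVert^2_{W_i}$ (task Jacobian $J_i$, error $e_i$, gain $\alpha_i$, weight $W_i$), the QP cost expands to $\tfrac12\Delta q^\top H\Delta q + c^\top\Delta q$ with $H=\sum_i J_i^\top W_i J_i$ and $c=\sum_i \alpha_i J_i^\top W_i e_i$. The NumPy sketch below assembles that pair with a leading batch axis on every array. It is not mink_batch's implementation: `qp_objective` is a hypothetical helper, the weights are assumed diagonal, the damping is a plain constant Tikhonov term (a simplification of mink's own damping), and the inequality constraints are omitted (a QP solver such as qpax handles those), so the minimizer is simply $\Delta q = -H^{-1}c$.

```python
import numpy as np


def qp_objective(jacobians, errors, weights, gains, damping=1e-6):
    """Batched (H, c) for sum_i 1/2 ||J_i dq + alpha_i e_i||^2_{W_i} + damping/2 ||dq||^2.

    jacobians: sequence of (B, k_i, nv) arrays
    errors:    sequence of (B, k_i) arrays
    weights:   sequence of (k_i,) diagonals of W_i, shared across the batch
    gains:     sequence of scalar task gains alpha_i
    """
    B, _, nv = jacobians[0].shape
    H = np.zeros((B, nv, nv))
    c = np.zeros((B, nv))
    for J, e, w, alpha in zip(jacobians, errors, weights, gains):
        WJ = w[None, :, None] * J                    # W_i J_i for every batch element
        H += np.einsum("bki,bkj->bij", J, WJ)        # J_i^T W_i J_i
        c += alpha * np.einsum("bki,bk->bi", WJ, e)  # alpha_i J_i^T W_i e_i
    H += damping * np.eye(nv)[None]                  # constant Tikhonov term (simplification)
    return H, c


# Example: B = 4 problems, nv = 6 joints, a 6-row "frame" task plus a posture task (J = I).
rng = np.random.default_rng(0)
B, nv, dt = 4, 6, 0.01
J_frame = rng.normal(size=(B, 6, nv))
e_frame = rng.normal(size=(B, 6))
J_post = np.broadcast_to(np.eye(nv), (B, nv, nv))
e_post = rng.normal(size=(B, nv))

H, c = qp_objective(
    [J_frame, J_post], [e_frame, e_post],
    weights=[np.ones(6), np.full(nv, 1e-3)], gains=[1.0, 1.0],
)
dq = np.linalg.solve(H, -c[..., None])[..., 0]  # unconstrained minimizer, shape (B, nv)
v = dq / dt                                     # velocity, the quantity solve_ik returns
```

Every line operates on whole `(B, ...)` stacks, which is the same shape discipline the batched API imposes; swapping `np` for `jax.numpy` would be the natural next step on a GPU.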
Mink's NumPy + single-instance design is great for one robot in real time, but becomes a bottleneck for batch workloads (data generation, RL rollouts, batched trajectory optimization, multi-start IK seeding). pyroki showed that a JAX-based pipeline can solve 1000 Franka IK problems in ~14 ms on an RTX 4090, but pyroki uses a soft-constraint least-squares formulation and URDF, dropping mink's QP hard constraints and MJCF compatibility.
mink_batch keeps mink's QP hard-constraint semantics and its MJCF model files, and adds pyroki-style batching and GPU acceleration.
```bash
cd mink_parrallel
pip install -e .
# JAX + CUDA:
pip install -U "jax[cuda12]"
```

Hard runtime deps:

- `jax>=0.4`, `jaxlib`
- `mujoco>=3.3.6` and `mujoco-mjx`
- `jaxlie` (Lie groups)
- `jax_dataclasses` (pytree dataclasses + `@jdc.jit`)
- `qpax` (default batched QP solver)
Soft deps for testing/benchmarking: mink, pytest, pyroki.
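A quick standard-library check can confirm that the hard dependencies are importable before anything else runs. This is a hypothetical helper, not part of mink_batch; the names are the import names of the packages listed above (assuming `mujoco-mjx` is imported as `mujoco.mjx`).

```python
import importlib.util

# Import names for the hard runtime deps (mujoco-mjx is imported as mujoco.mjx).
HARD_DEPS = ("jax", "jaxlib", "mujoco", "mujoco.mjx", "jaxlie", "jax_dataclasses", "qpax")


def missing_modules(names=HARD_DEPS):
    """Return the subset of `names` that cannot be found on the current Python path."""
    missing = []
    for name in names:
        try:
            # find_spec returns None for a missing top-level module; for a dotted name it
            # imports the parent package first and raises if that parent is absent.
            found = importlib.util.find_spec(name) is not None
        except ImportError:
            found = False
        if not found:
            missing.append(name)
    return missing


if __name__ == "__main__":
    gone = missing_modules()
    print("all hard deps found" if not gone else f"missing: {', '.join(gone)}")
```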
Identical to the runnable `test.py` at the project root; copy-paste it and run `python test.py`:

```python
import jax
import jax.numpy as jnp
import jaxlie
import mujoco
import numpy as np

import mink_batch as mb

mj_model = mujoco.MjModel.from_xml_path(
    "examples/assets/universal_robots_ur5e/ur5e.xml"
)
B = 64

# Batched home configuration (slightly off-zero to avoid the singular zero pose).
q0_single = jnp.asarray(mj_model.qpos0, dtype=jnp.float32) + 0.1
q0 = jnp.broadcast_to(q0_single, (B, mj_model.nq))
cfg = mb.Configuration.create(mj_model, q=q0)

# Sample B random target poses around the home end-effector pose.
home_pose = cfg._get_transform_frame_to_world_wxyz_xyz("attachment_site", "site")[0]
key = jax.random.PRNGKey(0)
twists = jax.random.normal(key, (B, 6)) * jnp.asarray(
    [0.05, 0.05, 0.05, 0.1, 0.1, 0.1], dtype=jnp.float32
)
targets = (jaxlie.SE3(home_pose) @ jaxlie.SE3.exp(twists)).wxyz_xyz  # (B, 7)

tasks = (
    mb.FrameTask.create(
        "attachment_site", "site",
        position_cost=1.0, orientation_cost=1.0, lm_damping=1e-6,
        transform_target_to_world=targets,
    ),
    mb.PostureTask.create(mj_model, cost=1e-3, target_q=q0_single),
)
velocities = {
    mj_model.joint(i).name: float(jnp.pi)
    for i in range(mj_model.njnt) if mj_model.joint(i).name
}
limits = (
    mb.ConfigurationLimit.create(mj_model),
    mb.VelocityLimit.create(mj_model, velocities=velocities),
)

dt = jnp.float32(0.01)
for step in range(100):
    vel = mb.solve_ik(cfg, tasks, dt, limits=limits, solver="qpax")
    cfg = cfg.integrate_inplace(vel, dt)

final_pose = np.array(cfg._get_transform_frame_to_world_wxyz_xyz("attachment_site", "site"))
pos_err = np.linalg.norm(final_pose[:, 4:] - np.array(targets)[:, 4:], axis=-1)
print(f"Final pos err (median/max mm): "
      f"{np.median(pos_err)*1e3:.2f} / {pos_err.max()*1e3:.2f}")
```

Notes for users translating from mink:
- Use the `.create(...)` factory methods (`Configuration.create`, `FrameTask.create`, ...); direct constructor calls on the underlying `@jdc.pytree_dataclass`es require all named fields.
- `tasks` and `limits` must be tuples (not lists) so the JIT structure is hashable and stable.
- `cfg.integrate(vel, dt)` returns a `(B, nq)` array; `cfg.integrate_inplace(vel, dt)` returns a new `Configuration`.
- `VelocityLimit.create` takes a `velocities: {joint_name: max_vel}` mapping, not a flat array.
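The tuple requirement comes down to hashability: compilation caches are keyed on static structure (`jax.jit`, for instance, requires static arguments to be hashable), and Python lists cannot be hashed. The sketch below uses `functools.lru_cache` as a stand-in for such a cache; `compile_for` is hypothetical and is not how mink_batch or JAX actually cache programs.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def compile_for(task_names):
    """Stand-in for a JIT cache: 'compiles' once per distinct, hashable structure."""
    return f"program for {len(task_names)} tasks"


# Tuples hash by content, so an equal tuple reuses the cached entry.
compile_for(("FrameTask", "PostureTask"))
compile_for(("FrameTask", "PostureTask"))
hits = compile_for.cache_info().hits  # one hit, from the second call

# Lists are unhashable, so they cannot serve as a cache key at all.
try:
    compile_for(["FrameTask", "PostureTask"])
    list_error = None
except TypeError as exc:
    list_error = exc
```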
Implemented (v1 target):

- `Configuration` (MJX-backed, batched FK/Jacobian, manifold-aware `integrate`/`differentiate_pos`)
- `FrameTask`
- `PostureTask`
- `DampingTask`
- `ConfigurationLimit`
- `VelocityLimit`
- `solve_ik` with `qpax` backend
- Parity tests against mink at `B=1` and `B=8` (15/15 passing)
- UR5e batched example (`examples/arm_ur5e_batch.py`)
- GPU benchmark on RTX 5090 (see Benchmark section)
Not in v1 (planned for v2+):

- `CollisionAvoidanceLimit` (needs MJX contact/distance plumbing)
- `ComTask`, `RelativeFrameTask`, `EqualityConstraintTask`, `DofFreezingTask`, `KineticEnergyRegularizationTask`
- Closed-chain / loop-closure equality constraints
- Alternative QP backends (`jaxopt.OSQP`, hand-rolled ADMM)
These are intentional, documented behavioral deltas:

- `limits` is required. Mink defaults to `[ConfigurationLimit(model)]` when `limits=None`; we require an explicit tuple (even if empty) so JIT tracing is predictable. Pass `limits=()` to disable.
- `tasks` and `limits` must be tuples, not lists. JAX uses the container type as part of the pytree structure; lists cause unnecessary retracing.
- No `safety_break`. Pre-flight limit checking is not in v1; users can do it host-side.
- No implicit equality constraints. v1 does not support `mj_makeConstraint` / equality tasks, so closed-chain models won't be respected. Tracked for v2.
- `Configuration.update` returns a new pytree; it does not mutate. JAX arrays are immutable.
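Without `safety_break`, a host-side pre-flight check over the whole batch takes only a few lines of NumPy. A minimal sketch, assuming per-coordinate bounds `lower`/`upper` of shape `(nq,)`; for hinge/slide-only arms such as the UR5e these could plausibly be read from `mj_model.jnt_range`, with ±inf for unlimited joints. `violating_rows` and `check_limits` are hypothetical helpers.

```python
import numpy as np


def violating_rows(q, lower, upper, tol=0.0):
    """Indices of batch rows in q (B, nq) with any coordinate outside [lower - tol, upper + tol]."""
    q = np.asarray(q, dtype=float)
    outside = (q < np.asarray(lower) - tol) | (q > np.asarray(upper) + tol)  # (B, nq) bool
    return np.flatnonzero(outside.any(axis=-1))


def check_limits(q, lower, upper, tol=1e-6):
    """Raise, as a safety break would, if any batch element starts outside its limits."""
    bad = violating_rows(q, lower, upper, tol)
    if bad.size:
        raise ValueError(f"{bad.size} configuration(s) violate joint limits: rows {bad.tolist()}")
```

Run it on the host copy of the batched configuration before calling `solve_ik`; returning indices rather than a single boolean lets you drop or reseed only the offending problems.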
Same as mink:

- MJCF model files: same format; mink's bare robot XMLs load directly (mink's `scene.xml` files need their sensor block stripped first, as noted in the API-spike findings).
- Cost vectors: same length and semantics (e.g., `FrameTask.cost` is length 6, position then orientation).
- Manifold-aware joint integration and differentiation for free and ball joints (we mirror `mj_integratePos`/`mj_differentiatePos`).
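The position-then-orientation layout also matches the (translation, rotation) tangent ordering of `jaxlie.SE3`. A small sketch of expanding position/orientation costs into that 6-vector and broadcasting it across a batch; `frame_cost_vector` is hypothetical, and accepting either a scalar or a 3-vector for each cost is an assumption modeled on mink's `FrameTask`.

```python
import numpy as np


def frame_cost_vector(position_cost, orientation_cost):
    """Length-6 cost: 3 position entries first, then 3 orientation entries."""
    pos = np.broadcast_to(np.asarray(position_cost, dtype=float), (3,))
    ori = np.broadcast_to(np.asarray(orientation_cost, dtype=float), (3,))
    return np.concatenate([pos, ori])


cost = frame_cost_vector(1.0, [0.5, 0.5, 0.1])
batched_cost = np.broadcast_to(cost, (8, 6))  # same weights for all B = 8 problems
errors = np.ones((8, 6))                      # a batch of 6-D errors, (translation, rotation)
scaled = batched_cost * errors                # row-wise scaling, one cost per error component
```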
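UR5e, 200 IK steps fused into a single `jax.lax.scan` and compiled with `@jdc.jit`, so the loop is compiled once up front instead of per step. qpax uses `max_iter=40`, `solver_tol=1e-4`. Hardware: RTX 5090 (32 GB), CUDA 13.0. Run via `python benchmarks/bench_batch.py`. In the table, JIT is the one-time compile cost (not included in the wall-clock column), per call is one `solve_ik` step (wall / 200), and per problem is per call / B.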
| B | JIT (ms) | wall, 200 steps (s) | per call (ms) | per problem (μs) |
|---|---|---|---|---|
| 1 | 5198 | 0.136 | 0.68 | 681 |
| 16 | 5658 | 0.146 | 0.73 | 45.8 |
| 64 | 5927 | 0.145 | 0.73 | 11.4 |
| 256 | 5959 | 0.148 | 0.74 | 2.9 |
| 1024 | 6356 | 0.154 | 0.77 | 0.75 |
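The per-call and per-problem columns are derived from the wall-clock column: per call = wall / 200 steps, and per problem = per call / B. A quick check with the inputs copied from the table (small mismatches come from the wall times being rounded to the millisecond):

```python
# Wall-clock seconds for 200 fused solve_ik steps, copied from the benchmark table.
wall_s = {1: 0.136, 16: 0.146, 64: 0.145, 256: 0.148, 1024: 0.154}
STEPS = 200

derived = {}
for B, wall in wall_s.items():
    per_call_ms = wall / STEPS * 1e3        # milliseconds per solve_ik step
    per_problem_us = per_call_ms / B * 1e3  # microseconds per individual IK problem
    derived[B] = (per_call_ms, per_problem_us)
    print(f"B={B:5d}  per call {per_call_ms:.3f} ms  per problem {per_problem_us:.2f} us")
```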
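Per-call wall time is nearly flat across B ∈ [1, 1024] (0.68 to 0.77 ms): fixed per-step costs, namely kernel launches and the iterations of qpax's primal-dual interior-point (PDIP) solver, dominate rather than per-problem work. Amortized over the batch, each problem at B=1024 costs 0.75 μs.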
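One way to quantify "nearly flat" is a least-squares line through per-call time versus B: the intercept approximates the fixed per-step cost and the slope the marginal cost of one extra problem. A sketch using `numpy.polyfit` on the table's wall-clock column; a two-parameter linear model is a simplification, since GPU timings are rarely exactly linear in B.

```python
import numpy as np

B = np.array([1, 16, 64, 256, 1024], dtype=float)
per_call_ms = np.array([0.136, 0.146, 0.145, 0.148, 0.154]) / 200 * 1e3  # wall / 200 steps

slope_ms, intercept_ms = np.polyfit(B, per_call_ms, deg=1)  # per_call ≈ intercept + slope * B
marginal_us = slope_ms * 1e3                   # extra microseconds per additional problem
fixed_share = intercept_ms / per_call_ms[-1]   # share of the B=1024 call explained by the intercept
print(f"fixed ≈ {intercept_ms:.3f} ms/step, marginal ≈ {marginal_us:.3f} us/problem, "
      f"fixed share at B=1024 ≈ {fixed_share:.0%}")
```

The fitted marginal cost per problem is far below the 0.75 μs amortized figure, which is consistent with fixed overheads dominating at these batch sizes.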
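For reference, the pyroki paper reports ~14 ms per call (≈14 μs per problem) for Franka at B=1000 on an RTX 4090. mink_batch's per-call time at B=1024 is lower (0.77 ms), but the numbers are not directly comparable: (1) the UR5e is slightly smaller than the Franka, (2) the RTX 5090 is a newer, faster GPU than the 4090, (3) qpax's PDIP and jaxls's LM do different inner work, and (4) each mink_batch call here is one differential-IK step, while reaching a target usually takes many steps (the quickstart runs 100).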
CPU fallback (JAX without CUDA) on the same machine: B=32 takes ~240 ms per call versus ~0.73 ms on the GPU at nearby batch sizes, illustrating the 300x+ gap that motivated this project.
- v1.0: above table green, parity tests passing.
- v1.1: `ComTask`, `RelativeFrameTask`.
- v2.0: `CollisionAvoidanceLimit` using MJX contact arrays; equality constraints; closed-chain support.
- v2.x: alternative solvers (`jaxopt.OSQP`, ADMM, GPU-LCP) selectable per call.
Findings from the API spike (`benchmarks/spike_mjx_api.py`):

- `mjx.put_model` rejects mink's `scene.xml` files because they include `mjSENS_GEOMFROMTO` sensors (used by the mocap viewer-style examples). Use the bare robot XML (e.g. `ur5e.xml`) or strip the sensor block. The bare UR5e files (`ur5e.xml` + meshes + `scene_plain.xml`) are vendored under `examples/assets/universal_robots_ur5e/` (sourced from mink, BSD-3 license).
- `mjx.jac` returns `(nv, 3)`, not `(3, nv)` like `mj_jacSite`/`mj_jacBody`. Transpose before concatenating into the 6×nv mink-shaped Jacobian.
- `mjx.differentiate_pos` and `mjx.integrate_pos` do not exist. We handwrite both, dispatching on `model.jnt_type` (free / ball / slide / hinge). Free and ball joints use `jaxlie.SO3.exp`/`.log` for the quaternion components.
- FK pipeline: `mjx.kinematics` → `mjx.com_pos` (matches mink's `mj_kinematics` + `mj_comPos`); `mjx.fwd_position` is the full pipeline but does more work than IK needs.
- Numerical agreement vs `mj_jacSite`: max abs delta ≈ 3e-7 on UR5e, well below our 1e-6 target.
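The handwritten integrate/differentiate pair can be sketched per joint type as below. This is a single-configuration NumPy sketch under stated assumptions, not mink_batch's code: joint-type codes follow MuJoCo's `mjtJoint` enum (free=0, ball=1, slide=2, hinge=3), quaternions are wxyz, and, as in `mj_integratePos`, free-joint translation is integrated in the world frame while quaternions are right-multiplied by the rotation increment (local-frame angular velocity). `differentiate_pos` is the inverse map, so a round trip recovers the velocity. In JAX, the loop over joint types is static and unrolls under `jit`, and `jax.vmap` would add the batch axis.

```python
import numpy as np

# Joint type codes from MuJoCo's mjtJoint enum.
FREE, BALL, SLIDE, HINGE = 0, 1, 2, 3
CONJ = np.array([1.0, -1.0, -1.0, -1.0])  # multiply a unit quaternion by this to invert it


def quat_mul(a, b):
    """Hamilton product of two wxyz quaternions."""
    aw, ax, ay, az = a
    bw, bx, by, bz = b
    return np.array([
        aw * bw - ax * bx - ay * by - az * bz,
        aw * bx + ax * bw + ay * bz - az * by,
        aw * by - ax * bz + ay * bw + az * bx,
        aw * bz + ax * by - ay * bx + az * bw,
    ])


def quat_exp(rotvec):
    """Unit quaternion for a rotation vector (axis * angle)."""
    angle = np.linalg.norm(rotvec)
    if angle < 1e-12:
        return np.array([1.0, 0.0, 0.0, 0.0])
    return np.concatenate([[np.cos(angle / 2)], np.sin(angle / 2) * rotvec / angle])


def quat_log(q):
    """Rotation vector of a unit wxyz quaternion, taking the shorter equivalent rotation."""
    w, v = q[0], q[1:]
    if w < 0:
        w, v = -w, -v
    s = np.linalg.norm(v)
    if s < 1e-12:
        return 2.0 * v  # small-angle limit
    return 2.0 * np.arctan2(s, w) * v / s


def _joint_blocks(jnt_type):
    """Yield (type, qpos slice, qvel slice) for each joint, in model order."""
    qi = vi = 0
    for t in jnt_type:
        nq, nv = {FREE: (7, 6), BALL: (4, 3)}.get(t, (1, 1))
        yield t, slice(qi, qi + nq), slice(vi, vi + nv)
        qi, vi = qi + nq, vi + nv


def integrate_pos(jnt_type, qpos, qvel, dt):
    qpos = np.array(qpos, dtype=float)
    qvel = np.asarray(qvel, dtype=float)
    for t, qs, vs in _joint_blocks(jnt_type):
        q, v = qpos[qs], qvel[vs]
        if t == FREE:
            quat = quat_mul(q[3:], quat_exp(v[3:] * dt))
            qpos[qs] = np.concatenate([q[:3] + v[:3] * dt, quat / np.linalg.norm(quat)])
        elif t == BALL:
            quat = quat_mul(q, quat_exp(v * dt))
            qpos[qs] = quat / np.linalg.norm(quat)
        else:  # slide / hinge: plain Euclidean update
            qpos[qs] = q + v * dt
    return qpos


def differentiate_pos(jnt_type, qpos1, qpos2, dt):
    """Velocity that carries qpos1 to qpos2 in time dt (inverse of integrate_pos)."""
    qpos1, qpos2 = np.asarray(qpos1, float), np.asarray(qpos2, float)
    qvel = np.zeros(sum(6 if t == FREE else 3 if t == BALL else 1 for t in jnt_type))
    for t, qs, vs in _joint_blocks(jnt_type):
        a, b = qpos1[qs], qpos2[qs]
        if t == FREE:
            rot = quat_log(quat_mul(a[3:] * CONJ, b[3:]))
            qvel[vs] = np.concatenate([(b[:3] - a[:3]) / dt, rot / dt])
        elif t == BALL:
            qvel[vs] = quat_log(quat_mul(a * CONJ, b)) / dt
        else:
            qvel[vs] = (b - a) / dt
    return qvel
```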
- mink: https://github.com/kevinzakka/mink
- pyroki: https://github.com/chungmin99/pyroki
- mujoco-mjx: https://mujoco.readthedocs.io/en/stable/mjx.html
- qpax: https://github.com/kevin-tracy/qpax
Apache 2.0 (matches mink). See LICENSE once added.