luke-a-thompson/Stochastax

Stochastax

A JAX library for advanced stochastic analysis.

Python 3.12+ | License: MIT

Features

  • Fast: Built on JAX, with JIT compilation through XLA
  • Flexible: Supports both path signatures and log signatures
  • GPU Support: Leverages JAX's GPU acceleration when available

Installation

pip install quicksig

For GPU support (CUDA 12):

pip install "quicksig[cuda]"
# or, from a clone of the repository:
uv sync --extra cuda
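After installing the CUDA extra, a quick way to confirm that JAX actually sees the GPU is to ask which backend it selected. This sketch uses only standard JAX calls (`jax.default_backend`, `jax.devices`); without a working CUDA setup it will report `cpu`.

```python
import jax

# Report the backend JAX selected ("gpu" with a working CUDA install, otherwise "cpu")
backend = jax.default_backend()

# List the devices JAX can place computations on
devices = jax.devices()
print(backend, devices)
```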

For development:

pip install "quicksig[dev]"
# or, from a clone of the repository:
uv sync --extra dev

To install all extras:

pip install "quicksig[all]"
# or, from a clone of the repository:
uv sync --all-extras

Run Benchmarks

  uv run pytest --benchmark-only --benchmark-autosave

Quick Start

import jax.numpy as jnp
from quicksig import get_signature, get_log_signature

# Create a simple 2D path
path = jnp.array([[0.0, 0.0], [1.0, 1.0], [2.0, 0.0]])

# Compute path signature up to depth 3
signature = get_signature(path, depth=3)
print(f"Signature shape: {signature.shape}")

# Compute log signature
log_sig = get_log_signature(path, depth=3, log_signature_type="lyndon")
print(f"Log signature shape: {log_sig.shape}")
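To make the flattened output concrete, here is a minimal pure-JAX sketch. It does not use quicksig and is not its implementation. It computes levels 1 to `depth` of the signature of a piecewise-linear path: each straight segment with increment Δ has level k equal to Δ^{⊗k}/k!, and segments are combined with Chen's identity. All function names are hypothetical. `depth` is marked static under `jax.jit` because it fixes the output shape. The sketch omits the constant level-0 term, giving 2 + 4 + 8 = 14 entries for a 2D path at depth 3; whether quicksig's output includes that term is an assumption not checked here.

```python
import jax
import jax.numpy as jnp


def outer(a, b):
    # Tensor product a ⊗ b: result[i..., j...] = a[i...] * b[j...]
    return a.reshape(a.shape + (1,) * b.ndim) * b


def segment_signature(delta, depth):
    # Signature of one straight segment: level k is delta^{⊗k} / k!,
    # built iteratively as (previous level ⊗ delta) / k
    levels = [delta]
    for k in range(2, depth + 1):
        levels.append(outer(levels[-1], delta) / k)
    return levels


def chen_product(s, t):
    # Truncated Chen identity: (s * t)_k = s_k + t_k + sum_{i=1}^{k-1} s_i ⊗ t_{k-i}
    out = []
    for k in range(1, len(s) + 1):
        level = s[k - 1] + t[k - 1]
        for i in range(1, k):
            level = level + outer(s[i - 1], t[k - i - 1])
        out.append(level)
    return out


def truncated_signature(path, depth):
    # Hypothetical helper: levels 1..depth of the signature, flattened
    increments = jnp.diff(path, axis=0)
    sig = segment_signature(increments[0], depth)
    for delta in increments[1:]:
        sig = chen_product(sig, segment_signature(delta, depth))
    return jnp.concatenate([level.ravel() for level in sig])


# depth determines the output shape, so it must be a static argument under jit
fast_signature = jax.jit(truncated_signature, static_argnames="depth")

path = jnp.array([[0.0, 0.0], [1.0, 1.0], [2.0, 0.0]])
sig = fast_signature(path, depth=3)
print(sig.shape)  # (14,)
```

Level 1 is the total increment (2, 0); level 2 flattens to 2, -1, 1, 0, whose antisymmetric part is the signed area swept by the path.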

Batch Processing

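import jax
import jax.numpy as jnp
from quicksig import get_signature

# Process multiple paths at once: shape (batch, length, dim)
batch_paths = jnp.array([
    [[0.0, 0.0], [1.0, 1.0], [2.0, 0.0]],
    [[0.0, 0.0], [1.0, -1.0], [2.0, 0.0]]
])

# Compute depth-2 signatures for all paths; depth and stream are shared
# across the batch, so they are not mapped (in_axes=None)
batch_signatures = jax.vmap(get_signature, in_axes=(0, None, None))(batch_paths, 2, False)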
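The same `vmap` pattern can be tried without quicksig. This sketch uses a hypothetical closed-form depth-2 signature for piecewise-linear paths, S1 = Σ_j Δ_j and S2 = Σ_j (X_{t_j} - X_0) ⊗ Δ_j + ½ Δ_j ⊗ Δ_j, and maps it over the leading batch axis. It is an illustration under those assumptions, not quicksig's code.

```python
import jax
import jax.numpy as jnp


def depth2_signature(path):
    # Hypothetical helper: levels 1 and 2 of the signature of a piecewise-linear path
    increments = jnp.diff(path, axis=0)   # (length - 1, dim)
    offsets = path[:-1] - path[0]         # X_{t_j} - X_0 at the start of each segment
    level1 = increments.sum(axis=0)       # total increment
    # S2 = sum_j offsets_j ⊗ Δ_j + Δ_j ⊗ Δ_j / 2, written as matrix products
    level2 = offsets.T @ increments + 0.5 * increments.T @ increments
    return jnp.concatenate([level1, level2.ravel()])


batch_paths = jnp.array([
    [[0.0, 0.0], [1.0, 1.0], [2.0, 0.0]],
    [[0.0, 0.0], [1.0, -1.0], [2.0, 0.0]],
])

# vmap maps over the batch axis; jit compiles the batched function once
batched = jax.jit(jax.vmap(depth2_signature))
batch_signatures = batched(batch_paths)
print(batch_signatures.shape)  # (2, 6)
```

The two paths share their endpoints, so level 1 agrees; only the antisymmetric part of level 2 (the signed area) flips sign. Since the API reference states that `get_signature` also accepts a `(batch, length, dim)` array, the explicit `vmap` may not be needed with quicksig itself.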

API Reference

get_signature(path, depth, stream=False)

Compute the signature of a path or batch of paths.

Parameters:

  • path (jax.Array): Input path(s) of shape (length, dim) for single path or (batch, length, dim) for batch
  • depth (int): Maximum signature depth to compute
  • stream (bool): Whether to compute streaming signatures

Returns:

  • jax.Array: Flattened signature tensor
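The effect of `stream` is not spelled out above. A common convention in signature libraries is that a streamed signature returns the signature of every prefix of the path, one row per increment; assuming quicksig follows it, this hypothetical depth-2 sketch shows the idea with running sums. quicksig's exact output layout for `stream=True` is not confirmed here.

```python
import jax.numpy as jnp


def depth2_stream_signature(path):
    # Hypothetical helper: depth-2 signature of each prefix path[: j + 2]
    increments = jnp.diff(path, axis=0)              # (length - 1, dim)
    offsets = path[:-1] - path[0]                    # X_{t_j} - X_0
    level1 = jnp.cumsum(increments, axis=0)          # running total increment
    # Contribution of segment j to level 2, accumulated over segments
    step = (jnp.einsum("ja,jb->jab", offsets, increments)
            + 0.5 * jnp.einsum("ja,jb->jab", increments, increments))
    level2 = jnp.cumsum(step, axis=0)                # (length - 1, dim, dim)
    return jnp.concatenate([level1, level2.reshape(level2.shape[0], -1)], axis=1)


path = jnp.array([[0.0, 0.0], [1.0, 1.0], [2.0, 0.0]])
stream = depth2_stream_signature(path)
print(stream.shape)  # (2, 6)
```

The last row is the signature of the whole path, so a streamed result carries the non-streamed one as its final entry.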

get_log_signature(path, depth, log_signature_type)

Compute the log signature of a path or batch of paths.

Parameters:

  • path (jax.Array): Input path(s)
  • depth (int): Maximum signature depth
  • log_signature_type (Literal["expanded", "lyndon"]): Type of log signature computation

Returns:

  • jax.Array: Flattened log signature tensor
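A Lyndon-basis log signature has one coordinate per Lyndon word of length at most `depth`, and Witt's formula counts them. For the 2D Quick Start path at depth 3 that gives 2 + 1 + 2 = 5, so we'd expect a length-5 output if `log_signature_type="lyndon"` returns plain Lyndon coordinates (an assumption). The sketch below also computes a depth-2 Lyndon log signature by hand: level 2 of log(S) is S2 - S1 ⊗ S1 / 2, and its coordinate on each bracket [e_i, e_j] with i < j is a signed area. The helper names are hypothetical, and quicksig may order or normalize coordinates differently.

```python
import jax.numpy as jnp


def mobius(n):
    # Möbius function mu(n), computed by trial division
    result = 1
    p = 2
    while p * p <= n:
        if n % p == 0:
            n //= p
            if n % p == 0:
                return 0  # squared prime factor
            result = -result
        p += 1
    if n > 1:
        result = -result  # one remaining prime factor
    return result


def lyndon_dim(dim, depth):
    # Witt's formula: (1/k) * sum_{m | k} mu(m) * dim**(k/m) Lyndon words of length k,
    # summed over k = 1..depth
    total = 0
    for k in range(1, depth + 1):
        total += sum(mobius(m) * dim ** (k // m) for m in range(1, k + 1) if k % m == 0) // k
    return total


def depth2_lyndon_log_signature(path):
    # Hypothetical helper: depth-2 log signature of a piecewise-linear path
    increments = jnp.diff(path, axis=0)
    offsets = path[:-1] - path[0]
    level1 = increments.sum(axis=0)
    level2 = offsets.T @ increments + 0.5 * increments.T @ increments
    # Level 2 of log(S): S2 - S1 ⊗ S1 / 2, the antisymmetric (area) part of S2
    area = level2 - 0.5 * jnp.outer(level1, level1)
    # One coordinate per Lyndon word (i, j) with i < j
    i, j = jnp.triu_indices(path.shape[1], k=1)
    return jnp.concatenate([level1, area[i, j]])


path = jnp.array([[0.0, 0.0], [1.0, 1.0], [2.0, 0.0]])
log_sig = depth2_lyndon_log_signature(path)  # values 2, 0, -1
print(lyndon_dim(2, 3))  # 5
```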

Development

# Clone the repository
git clone https://github.com/luke-a-thompson/stochastax.git
cd stochastax

# Install development dependencies
pip install -e ".[dev]"

# Run tests
pytest

Requirements

  • Python 3.12+
  • JAX >= 0.6.0

License

MIT License - see LICENSE file for details.

Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

About

A JAX library for advanced stochastic analysis. Documentation is a work in progress.
