
QuaternionOperations

Hamilton product-based operations for quaternion hypervectors. Provides non-commutative binding.

vsax.ops.QuaternionOperations

Bases: AbstractOpSet

Quaternion VSA operations using Hamilton product for binding.

Quaternion Hypervectors (QHV) use the Hamilton product for binding, which is NON-COMMUTATIVE. This makes them suitable for order-sensitive role/filler bindings where bind(role, filler) != bind(filler, role).

Key properties:
- Binding: Hamilton product (non-commutative, associative)
- Bundling: Sum + normalize to unit quaternions
- Inverse: Quaternion inverse (conjugate / norm²)
- Unbind: Right-unbind recovers x from bind(x, y) given y
- Unbind-left: Left-unbind recovers y from bind(x, y) given x

Vector shape: (D, 4) where D is the number of quaternion coordinates.
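For reference, writing a quaternion as a scalar/vector pair p = (p_0, \mathbf{p}), the Hamilton product, its commutator, and the inverse are the standard identities below. The cross-product term flips sign when the operands are swapped, which is exactly where the non-commutativity of bind comes from. Given the documented (..., D, 4) shapes, these formulas are presumably applied independently at each of the D coordinates (an assumption, not stated by the source).

$$
pq = \bigl(p_0 q_0 - \mathbf{p}\cdot\mathbf{q},\; p_0\mathbf{q} + q_0\mathbf{p} + \mathbf{p}\times\mathbf{q}\bigr),
\qquad
pq - qp = \bigl(0,\; 2\,\mathbf{p}\times\mathbf{q}\bigr),
\qquad
q^{-1} = \frac{\bar{q}}{\lVert q\rVert^{2}},\quad \bar{q} = (q_0, -\mathbf{q}).
$$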

Example

import jax
import jax.numpy as jnp
from vsax.ops import QuaternionOperations
from vsax.sampling import sample_quaternion_random

ops = QuaternionOperations()
key = jax.random.PRNGKey(0)
vecs = sample_quaternion_random(dim=512, n=2, key=key)
x, y = vecs[0], vecs[1]

# Non-commutative binding
xy = ops.bind(x, y)
yx = ops.bind(y, x)
# xy != yx (different results)

# Right-unbind: recover x from xy using y
recovered_x = ops.unbind(xy, y)

# Left-unbind: recover y from xy using x
recovered_y = ops.unbind_left(x, xy)
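The example above needs vsax installed. As a self-contained sketch of the same round trip, the block below re-implements the operations in plain JAX. `qmul`, `qinverse` and `random_qhv` are hypothetical stand-ins written for illustration (scalar-first (w, x, y, z) component order assumed), not vsax's own helpers, and `random_qhv` is not guaranteed to match `sample_quaternion_random`.

```python
import jax
import jax.numpy as jnp


def qmul(a, b):
    """Coordinate-wise Hamilton product; last axis is (w, x, y, z) (assumed order)."""
    aw, ax, ay, az = a[..., 0], a[..., 1], a[..., 2], a[..., 3]
    bw, bx, by, bz = b[..., 0], b[..., 1], b[..., 2], b[..., 3]
    return jnp.stack(
        [
            aw * bw - ax * bx - ay * by - az * bz,
            aw * bx + ax * bw + ay * bz - az * by,
            aw * by - ax * bz + ay * bw + az * bx,
            aw * bz + ax * by - ay * bx + az * bw,
        ],
        axis=-1,
    )


def qinverse(q):
    """Conjugate divided by squared norm, per quaternion coordinate."""
    conj = q * jnp.array([1.0, -1.0, -1.0, -1.0])
    return conj / jnp.sum(q * q, axis=-1, keepdims=True)


def random_qhv(key, dim):
    """Hypothetical sampler: `dim` independent, uniformly random unit quaternions."""
    q = jax.random.normal(key, (dim, 4))
    return q / jnp.linalg.norm(q, axis=-1, keepdims=True)


kx, ky = jax.random.split(jax.random.PRNGKey(0))
x, y = random_qhv(kx, 512), random_qhv(ky, 512)

xy = qmul(x, y)  # bind(x, y)
yx = qmul(y, x)  # bind(y, x)

recovered_x = qmul(xy, qinverse(y))  # right-unbind: (x y) y^-1 = x
recovered_y = qmul(qinverse(x), xy)  # left-unbind: x^-1 (x y) = y

print(bool(jnp.allclose(xy, yx)))                     # False: operand order matters
print(bool(jnp.allclose(recovered_x, x, atol=1e-5)))  # True
print(bool(jnp.allclose(recovered_y, y, atol=1e-5)))  # True
```

Both unbinds are exact up to float32 rounding because every coordinate is a nonzero, hence invertible, quaternion.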

Source code in vsax/ops/quaternion.py
class QuaternionOperations(AbstractOpSet):
    """Quaternion VSA operations using Hamilton product for binding.

    Quaternion Hypervectors (QHV) use the Hamilton product for binding,
    which is NON-COMMUTATIVE. This makes them suitable for order-sensitive
    role/filler bindings where bind(role, filler) != bind(filler, role).

    Key properties:
    - Binding: Hamilton product (non-commutative, associative)
    - Bundling: Sum + normalize to unit quaternions
    - Inverse: Quaternion inverse (conjugate / norm²)
    - Unbind: Right-unbind recovers x from bind(x, y) given y
    - Unbind-left: Left-unbind recovers y from bind(x, y) given x

    Vector shape: (D, 4) where D is the number of quaternion coordinates.

    Example:
        >>> import jax
        >>> import jax.numpy as jnp
        >>> from vsax.ops import QuaternionOperations
        >>> from vsax.sampling import sample_quaternion_random
        >>>
        >>> ops = QuaternionOperations()
        >>> key = jax.random.PRNGKey(0)
        >>> vecs = sample_quaternion_random(dim=512, n=2, key=key)
        >>> x, y = vecs[0], vecs[1]
        >>>
        >>> # Non-commutative binding
        >>> xy = ops.bind(x, y)
        >>> yx = ops.bind(y, x)
        >>> # xy != yx (different results)
        >>>
        >>> # Right-unbind: recover x from xy using y
        >>> recovered_x = ops.unbind(xy, y)
        >>>
        >>> # Left-unbind: recover y from xy using x
        >>> recovered_y = ops.unbind_left(x, xy)
    """

    def bind(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
        """Bind two quaternion hypervectors using Hamilton product.

        This operation is NON-COMMUTATIVE: bind(a, b) != bind(b, a).

        Args:
            a: First quaternion hypervector of shape (..., D, 4).
            b: Second quaternion hypervector of shape (..., D, 4).

        Returns:
            Bound quaternion hypervector of shape (..., D, 4).
        """
        return qmul(a, b)

    def bundle(self, *vecs: jnp.ndarray) -> jnp.ndarray:
        """Bundle multiple quaternion hypervectors.

        Sums the vectors and normalizes to unit quaternions.

        Args:
            *vecs: Variable number of quaternion hypervectors.

        Returns:
            Bundled quaternion hypervector, normalized to unit length.

        Raises:
            ValueError: If no vectors are provided.
        """
        if len(vecs) == 0:
            raise ValueError("bundle() requires at least one vector")

        # Sum all vectors
        result = jnp.sum(jnp.stack(vecs), axis=0)

        # Normalize to unit quaternions
        return qnormalize(result)

    def inverse(self, a: jnp.ndarray) -> jnp.ndarray:
        """Compute the quaternion inverse.

        For unit quaternions, this equals the conjugate.

        Args:
            a: Quaternion hypervector of shape (..., D, 4).

        Returns:
            Inverse quaternion hypervector of shape (..., D, 4).
        """
        return qinverse(a)

    def unbind(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
        """Right-unbind: recover x from z = bind(x, y) given y.

        Computes: z * y⁻¹ = x

        For quaternion binding z = x * y, right multiplication by y⁻¹
        recovers x: z * y⁻¹ = x * y * y⁻¹ = x

        Args:
            a: Bound quaternion hypervector (z = bind(x, y)).
            b: Quaternion hypervector to unbind (y).

        Returns:
            Recovered quaternion hypervector (x).
        """
        return qmul(a, qinverse(b))

    def unbind_left(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
        """Left-unbind: recover y from z = bind(x, y) given x.

        Computes: x⁻¹ * z = y

        For quaternion binding z = x * y, left multiplication by x⁻¹
        recovers y: x⁻¹ * z = x⁻¹ * x * y = y

        Args:
            a: Quaternion hypervector used in binding (x).
            b: Bound quaternion hypervector (z = bind(x, y)).

        Returns:
            Recovered quaternion hypervector (y).
        """
        return qmul(qinverse(a), b)

    def permute(self, a: jnp.ndarray, shift: int) -> jnp.ndarray:
        """Permute a quaternion hypervector by circular shift.

        Shifts along the D axis (axis -2, the quaternion coordinate dimension),
        not the last axis (the four quaternion components).

        Args:
            a: Quaternion hypervector of shape (..., D, 4).
            shift: Number of positions to shift.

        Returns:
            Permuted quaternion hypervector.
        """
        # Roll along the second-to-last axis (D dimension)
        return jnp.roll(a, shift, axis=-2)

    def sandwich(self, rotor: jnp.ndarray, v: jnp.ndarray) -> jnp.ndarray:
        """Apply sandwich product: rotor * v * rotor^-1.

        This is the core operation for quaternion-based rotations/transformations.
        Useful for learning state transformations where s' = sandwich(U, s).

        Args:
            rotor: Quaternion rotor of shape (..., D, 4).
            v: Quaternion hypervector to transform of shape (..., D, 4).

        Returns:
            Transformed quaternion hypervector.
        """
        return sandwich(rotor, v)

    def sandwich_unit(self, rotor: jnp.ndarray, v: jnp.ndarray) -> jnp.ndarray:
        """Apply unit rotor sandwich product: rotor * v * conj(rotor).

        More efficient version for unit quaternions where inverse = conjugate.

        Args:
            rotor: Unit quaternion rotor of shape (..., D, 4).
            v: Quaternion hypervector to transform of shape (..., D, 4).

        Returns:
            Transformed quaternion hypervector.
        """
        return sandwich_unit(rotor, v)

Functions

bind(a, b)

Bind two quaternion hypervectors using Hamilton product.

This operation is NON-COMMUTATIVE: bind(a, b) != bind(b, a).

Parameters:

Name  Type     Description                                          Default
a     ndarray  First quaternion hypervector of shape (..., D, 4).   required
b     ndarray  Second quaternion hypervector of shape (..., D, 4).  required

Returns:

Type     Description
ndarray  Bound quaternion hypervector of shape (..., D, 4).
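A minimal sketch of coordinate-wise Hamilton binding on (..., D, 4) arrays, assuming (w, x, y, z) component order. `qmul` here is a hypothetical stand-in for vsax's internal function. It checks the two properties the docstring relies on: basis units anticommute (i·j = k but j·i = -k), and the product is associative, including with a leading batch axis.

```python
import jax
import jax.numpy as jnp


def qmul(a, b):
    """Coordinate-wise Hamilton product; last axis is (w, x, y, z) (assumed order)."""
    aw, ax, ay, az = a[..., 0], a[..., 1], a[..., 2], a[..., 3]
    bw, bx, by, bz = b[..., 0], b[..., 1], b[..., 2], b[..., 3]
    return jnp.stack(
        [
            aw * bw - ax * bx - ay * by - az * bz,
            aw * bx + ax * bw + ay * bz - az * by,
            aw * by - ax * bz + ay * bw + az * bx,
            aw * bz + ax * by - ay * bx + az * bw,
        ],
        axis=-1,
    )


# Basis units as one-coordinate hypervectors of shape (D=1, 4).
i = jnp.array([[0.0, 1.0, 0.0, 0.0]])
j = jnp.array([[0.0, 0.0, 1.0, 0.0]])
k = jnp.array([[0.0, 0.0, 0.0, 1.0]])

ij = qmul(i, j)  # equals k
ji = qmul(j, i)  # equals -k, so bind(i, j) != bind(j, i)

# Associativity on random inputs of shape (batch=2, D=256, 4).
ka, kb, kc = jax.random.split(jax.random.PRNGKey(1), 3)
a = jax.random.normal(ka, (2, 256, 4))
b = jax.random.normal(kb, (2, 256, 4))
c = jax.random.normal(kc, (2, 256, 4))
left = qmul(qmul(a, b), c)
right = qmul(a, qmul(b, c))
```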

bundle(*vecs)

Bundle multiple quaternion hypervectors.

Sums the vectors and normalizes to unit quaternions.

Parameters:

Name   Type     Description                                  Default
*vecs  ndarray  Variable number of quaternion hypervectors.  ()

Returns:

Type     Description
ndarray  Bundled quaternion hypervector; each of its D coordinates is normalized to a unit quaternion.

Raises:

Type        Description
ValueError  If no vectors are provided.
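A minimal sketch of sum-then-normalize bundling, assuming normalization is applied per quaternion coordinate (last axis). `bundle`, `random_qhv` and `similarity` are hypothetical helpers; in particular, the mean per-coordinate 4-D dot product is an assumed similarity measure, not necessarily the one vsax uses. The bundle stays noticeably similar to each input and roughly orthogonal to an unrelated vector.

```python
import jax
import jax.numpy as jnp


def bundle(*vecs):
    """Sum (..., D, 4) quaternion hypervectors, then renormalize each coordinate."""
    if len(vecs) == 0:
        raise ValueError("bundle() requires at least one vector")
    total = jnp.sum(jnp.stack(vecs), axis=0)
    norm = jnp.linalg.norm(total, axis=-1, keepdims=True)
    # Guard against a zero sum at some coordinate (e.g. bundling q with -q);
    # how vsax's qnormalize treats that case is not documented here.
    return total / jnp.maximum(norm, 1e-12)


def random_qhv(key, dim):
    """Hypothetical sampler: `dim` independent, uniformly random unit quaternions."""
    q = jax.random.normal(key, (dim, 4))
    return q / jnp.linalg.norm(q, axis=-1, keepdims=True)


def similarity(a, b):
    """Assumed metric: mean over coordinates of the 4-D dot product."""
    return jnp.mean(jnp.sum(a * b, axis=-1))


keys = jax.random.split(jax.random.PRNGKey(2), 4)
a, b, c, unrelated = (random_qhv(key, 1024) for key in keys)
s = bundle(a, b, c)

member_sims = [float(similarity(s, v)) for v in (a, b, c)]
unrelated_sim = float(similarity(s, unrelated))
```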

inverse(a)

Compute the quaternion inverse.

For unit quaternions, this equals the conjugate.

Parameters:

Name  Type     Description                                   Default
a     ndarray  Quaternion hypervector of shape (..., D, 4).  required

Returns:

Type     Description
ndarray  Inverse quaternion hypervector of shape (..., D, 4).
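A short sketch of the general inverse, conj(q) / |q|², computed per coordinate. `qmul`, `qconj` and `qinverse` are hypothetical re-implementations (scalar-first order assumed). It checks that q·q⁻¹ and q⁻¹·q both give the identity quaternion (1, 0, 0, 0) even for non-unit q, and that for unit q the inverse reduces to the conjugate, as the docstring states.

```python
import jax
import jax.numpy as jnp


def qmul(a, b):
    """Coordinate-wise Hamilton product; last axis is (w, x, y, z) (assumed order)."""
    aw, ax, ay, az = a[..., 0], a[..., 1], a[..., 2], a[..., 3]
    bw, bx, by, bz = b[..., 0], b[..., 1], b[..., 2], b[..., 3]
    return jnp.stack(
        [
            aw * bw - ax * bx - ay * by - az * bz,
            aw * bx + ax * bw + ay * bz - az * by,
            aw * by - ax * bz + ay * bw + az * bx,
            aw * bz + ax * by - ay * bx + az * bw,
        ],
        axis=-1,
    )


def qconj(q):
    """Negate the vector part (x, y, z); keep the scalar part w."""
    return q * jnp.array([1.0, -1.0, -1.0, -1.0])


def qinverse(q):
    """General inverse: conjugate divided by the squared norm, per coordinate."""
    return qconj(q) / jnp.sum(q * q, axis=-1, keepdims=True)


q = jax.random.normal(jax.random.PRNGKey(3), (256, 4))  # deliberately not unit norm
identity = jnp.broadcast_to(jnp.array([1.0, 0.0, 0.0, 0.0]), q.shape)

right_prod = qmul(q, qinverse(q))  # identity at every coordinate
left_prod = qmul(qinverse(q), q)   # identity at every coordinate

unit = q / jnp.linalg.norm(q, axis=-1, keepdims=True)
```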

unbind(a, b)

Right-unbind: recover x from z = bind(x, y) given y.

Computes: z * y⁻¹ = x

For quaternion binding z = x * y, right multiplication by y⁻¹ recovers x: z * y⁻¹ = x * y * y⁻¹ = x

Parameters:

Name  Type     Description                                     Default
a     ndarray  Bound quaternion hypervector (z = bind(x, y)).  required
b     ndarray  Quaternion hypervector to unbind (y).           required

Returns:

Type     Description
ndarray  Recovered quaternion hypervector (x).
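A sketch of the derivation z·y⁻¹ = x·y·y⁻¹ = x, using hypothetical `qmul`/`qinverse` stand-ins (scalar-first order assumed). Because the true inverse is used, recovery also works for non-unit operands. The last line shows why the side matters: multiplying by y⁻¹ on the left instead gives y⁻¹·x·y, a rotated copy of x, not x itself.

```python
import jax
import jax.numpy as jnp


def qmul(a, b):
    """Coordinate-wise Hamilton product; last axis is (w, x, y, z) (assumed order)."""
    aw, ax, ay, az = a[..., 0], a[..., 1], a[..., 2], a[..., 3]
    bw, bx, by, bz = b[..., 0], b[..., 1], b[..., 2], b[..., 3]
    return jnp.stack(
        [
            aw * bw - ax * bx - ay * by - az * bz,
            aw * bx + ax * bw + ay * bz - az * by,
            aw * by - ax * bz + ay * bw + az * bx,
            aw * bz + ax * by - ay * bx + az * bw,
        ],
        axis=-1,
    )


def qinverse(q):
    """Conjugate divided by squared norm, per quaternion coordinate."""
    conj = q * jnp.array([1.0, -1.0, -1.0, -1.0])
    return conj / jnp.sum(q * q, axis=-1, keepdims=True)


kx, ky = jax.random.split(jax.random.PRNGKey(4))
x = jax.random.normal(kx, (512, 4))  # arbitrary, not unit norm
y = jax.random.normal(ky, (512, 4))

z = qmul(x, y)                        # bind(x, y)
recovered = qmul(z, qinverse(y))      # right-unbind: z y^-1 = x
wrong_side = qmul(qinverse(y), z)     # y^-1 x y: not x in general
```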

unbind_left(a, b)

Left-unbind: recover y from z = bind(x, y) given x.

Computes: x⁻¹ * z = y

For quaternion binding z = x * y, left multiplication by x⁻¹ recovers y: x⁻¹ * z = x⁻¹ * x * y = y

Parameters:

Name  Type     Description                                     Default
a     ndarray  Quaternion hypervector used in binding (x).     required
b     ndarray  Bound quaternion hypervector (z = bind(x, y)).  required

Returns:

Type     Description
ndarray  Recovered quaternion hypervector (y).
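Left-unbinding is what a role/filler query needs when roles are bound on the left, as in bind(role, filler). The sketch below stores two pairs in one bundled record and queries it with role1. All helpers are hypothetical re-implementations (scalar-first order assumed), and the similarity measure (mean per-coordinate 4-D dot product) is an assumption. The answer is a noisy copy of filler1: clearly similar to it, and roughly orthogonal to filler2.

```python
import jax
import jax.numpy as jnp


def qmul(a, b):
    """Coordinate-wise Hamilton product; last axis is (w, x, y, z) (assumed order)."""
    aw, ax, ay, az = a[..., 0], a[..., 1], a[..., 2], a[..., 3]
    bw, bx, by, bz = b[..., 0], b[..., 1], b[..., 2], b[..., 3]
    return jnp.stack(
        [
            aw * bw - ax * bx - ay * by - az * bz,
            aw * bx + ax * bw + ay * bz - az * by,
            aw * by - ax * bz + ay * bw + az * bx,
            aw * bz + ax * by - ay * bx + az * bw,
        ],
        axis=-1,
    )


def qinverse(q):
    """Conjugate divided by squared norm, per quaternion coordinate."""
    conj = q * jnp.array([1.0, -1.0, -1.0, -1.0])
    return conj / jnp.sum(q * q, axis=-1, keepdims=True)


def random_qhv(key, dim):
    """Hypothetical sampler: `dim` independent, uniformly random unit quaternions."""
    q = jax.random.normal(key, (dim, 4))
    return q / jnp.linalg.norm(q, axis=-1, keepdims=True)


def similarity(a, b):
    """Assumed metric: mean over coordinates of the 4-D dot product."""
    return jnp.mean(jnp.sum(a * b, axis=-1))


keys = jax.random.split(jax.random.PRNGKey(5), 4)
role1, role2, filler1, filler2 = (random_qhv(key, 1024) for key in keys)

# Record {role1: filler1, role2: filler2}: bind each pair, sum, renormalize.
pairs = qmul(role1, filler1) + qmul(role2, filler2)
record = pairs / jnp.linalg.norm(pairs, axis=-1, keepdims=True)

# Left-unbind with role1: role1^-1 * record = filler1 plus crosstalk from pair 2.
answer = qmul(qinverse(role1), record)
sim_f1 = float(similarity(answer, filler1))
sim_f2 = float(similarity(answer, filler2))
```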

permute(a, shift)

Permute a quaternion hypervector by circular shift.

Shifts along the D axis (axis -2, the quaternion coordinate dimension), not the last axis (the four quaternion components).

Parameters:

Name   Type     Description                                   Default
a      ndarray  Quaternion hypervector of shape (..., D, 4).  required
shift  int      Number of positions to shift.                 required

Returns:

Type     Description
ndarray  Permuted quaternion hypervector.
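A small sketch of what rolling along axis -2 does: each quaternion moves as a unit, so its four components stay together. Because binding acts coordinate by coordinate, permuting a bound vector gives the same result as binding the permuted operands. `qmul` is a hypothetical stand-in (scalar-first order assumed); `jnp.roll` is the call the source uses.

```python
import jax
import jax.numpy as jnp


def qmul(a, b):
    """Coordinate-wise Hamilton product; last axis is (w, x, y, z) (assumed order)."""
    aw, ax, ay, az = a[..., 0], a[..., 1], a[..., 2], a[..., 3]
    bw, bx, by, bz = b[..., 0], b[..., 1], b[..., 2], b[..., 3]
    return jnp.stack(
        [
            aw * bw - ax * bx - ay * by - az * bz,
            aw * bx + ax * bw + ay * bz - az * by,
            aw * by - ax * bz + ay * bw + az * bx,
            aw * bz + ax * by - ay * bx + az * bw,
        ],
        axis=-1,
    )


a = jnp.arange(12.0).reshape(3, 4)   # D = 3 quaternion coordinates
rolled = jnp.roll(a, 1, axis=-2)     # rows move as whole quaternions
# rolled == [[8, 9, 10, 11], [0, 1, 2, 3], [4, 5, 6, 7]]

# Permutation commutes with coordinate-wise binding; a batch axis is fine too.
kx, ky = jax.random.split(jax.random.PRNGKey(6))
x = jax.random.normal(kx, (2, 64, 4))
y = jax.random.normal(ky, (2, 64, 4))
permuted_bound = jnp.roll(qmul(x, y), 5, axis=-2)
bound_permuted = qmul(jnp.roll(x, 5, axis=-2), jnp.roll(y, 5, axis=-2))

# A negative shift undoes a positive one.
restored = jnp.roll(jnp.roll(x, 5, axis=-2), -5, axis=-2)
```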

sandwich(rotor, v)

Apply sandwich product: rotor * v * rotor^-1.

This is the core operation for quaternion-based rotations/transformations. Useful for learning state transformations where s' = sandwich(U, s).

Parameters:

Name   Type     Description                                                Default
rotor  ndarray  Quaternion rotor of shape (..., D, 4).                     required
v      ndarray  Quaternion hypervector to transform of shape (..., D, 4).  required

Returns:

Type     Description
ndarray  Transformed quaternion hypervector.
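When sandwich products model state transitions s' = sandwich(U, s), two useful facts follow from the algebra: successive transformations compose into a single rotor (sandwich(U2, sandwich(U1, s)) = sandwich(U2·U1, s)), and each coordinate keeps its scalar part and its norm. The sketch checks both on random hypervectors; `qmul`, `qinverse`, `sandwich` and `random_qhv` are hypothetical re-implementations (scalar-first order assumed).

```python
import jax
import jax.numpy as jnp


def qmul(a, b):
    """Coordinate-wise Hamilton product; last axis is (w, x, y, z) (assumed order)."""
    aw, ax, ay, az = a[..., 0], a[..., 1], a[..., 2], a[..., 3]
    bw, bx, by, bz = b[..., 0], b[..., 1], b[..., 2], b[..., 3]
    return jnp.stack(
        [
            aw * bw - ax * bx - ay * by - az * bz,
            aw * bx + ax * bw + ay * bz - az * by,
            aw * by - ax * bz + ay * bw + az * bx,
            aw * bz + ax * by - ay * bx + az * bw,
        ],
        axis=-1,
    )


def qinverse(q):
    """Conjugate divided by squared norm, per quaternion coordinate."""
    conj = q * jnp.array([1.0, -1.0, -1.0, -1.0])
    return conj / jnp.sum(q * q, axis=-1, keepdims=True)


def sandwich(rotor, v):
    """rotor * v * rotor^-1, coordinate-wise."""
    return qmul(qmul(rotor, v), qinverse(rotor))


def random_qhv(key, dim):
    """Hypothetical sampler: `dim` independent, uniformly random unit quaternions."""
    q = jax.random.normal(key, (dim, 4))
    return q / jnp.linalg.norm(q, axis=-1, keepdims=True)


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(7), 3)
u1, u2, s = random_qhv(k1, 256), random_qhv(k2, 256), random_qhv(k3, 256)

two_steps = sandwich(u2, sandwich(u1, s))   # apply U1, then U2
one_step = sandwich(qmul(u2, u1), s)        # single composed rotor U2 * U1
```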

sandwich_unit(rotor, v)

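Apply unit rotor sandwich product: rotor * v * conj(rotor).

More efficient version for unit quaternions, where the inverse equals the conjugate. For a rotor that is not unit length, the result is scaled by |rotor|² at each coordinate.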

Parameters:

Name   Type     Description                                                Default
rotor  ndarray  Unit quaternion rotor of shape (..., D, 4).                required
v      ndarray  Quaternion hypervector to transform of shape (..., D, 4).  required

Returns:

Type     Description
ndarray  Transformed quaternion hypervector.
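The sketch below compares the general and unit variants on a rotor that is deliberately not unit length. Normalizing the rotor first and then using the conjugate reproduces the general sandwich exactly, since rotor·v·rotor⁻¹ does not depend on the rotor's scale. Skipping the normalization scales each coordinate by |rotor|². All functions are hypothetical re-implementations (scalar-first order assumed).

```python
import jax
import jax.numpy as jnp


def qmul(a, b):
    """Coordinate-wise Hamilton product; last axis is (w, x, y, z) (assumed order)."""
    aw, ax, ay, az = a[..., 0], a[..., 1], a[..., 2], a[..., 3]
    bw, bx, by, bz = b[..., 0], b[..., 1], b[..., 2], b[..., 3]
    return jnp.stack(
        [
            aw * bw - ax * bx - ay * by - az * bz,
            aw * bx + ax * bw + ay * bz - az * by,
            aw * by - ax * bz + ay * bw + az * bx,
            aw * bz + ax * by - ay * bx + az * bw,
        ],
        axis=-1,
    )


def qconj(q):
    """Negate the vector part (x, y, z); keep the scalar part w."""
    return q * jnp.array([1.0, -1.0, -1.0, -1.0])


def sandwich(rotor, v):
    """General form: rotor * v * rotor^-1."""
    inv = qconj(rotor) / jnp.sum(rotor * rotor, axis=-1, keepdims=True)
    return qmul(qmul(rotor, v), inv)


def sandwich_unit(rotor, v):
    """Unit form: rotor * v * conj(rotor); correct only for unit rotors."""
    return qmul(qmul(rotor, v), qconj(rotor))


kr, kv = jax.random.split(jax.random.PRNGKey(8))
rotor = jax.random.normal(kr, (256, 4))          # deliberately not unit norm
v = jax.random.normal(kv, (256, 4))
v = v / jnp.linalg.norm(v, axis=-1, keepdims=True)

unit_rotor = rotor / jnp.linalg.norm(rotor, axis=-1, keepdims=True)
general = sandwich(rotor, v)
fast = sandwich_unit(unit_rotor, v)              # matches `general`
unnormalized = sandwich_unit(rotor, v)           # scaled by |rotor|^2
scale = jnp.sum(rotor * rotor, axis=-1, keepdims=True)
```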

Sandwich Product Functions

The sandwich product applies a rotor transformation via rotor * v * rotor⁻¹:

vsax.ops.sandwich(rotor, v)

Apply rotor transformation via sandwich product: rotor * v * rotor^-1.

This is the core operation for quaternion-based rotations/transformations. When rotor is a unit quaternion encoding a rotation, this rotates the vector part of v by that rotation; the scalar part of v is left unchanged.

The sandwich product is useful for learning transformations between states. Given state traces (s, action, s'), you can learn a rotor U such that s' = sandwich(U, s).
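As one illustration of fitting a rotor from a single transition (s, s'), the sketch below uses a closed form that is not taken from vsax: when every coordinate of s and s' is a pure unit quaternion, the per-coordinate rotor r = normalize(1 + a·b, a × b) rotates a onto b, where a and b are the vector parts. `fit_rotor` and `random_pure_qhv` are hypothetical names; the formula is undefined when a = -b at some coordinate, and states with nonzero scalar parts would need a different method.

```python
import jax
import jax.numpy as jnp


def qmul(a, b):
    """Coordinate-wise Hamilton product; last axis is (w, x, y, z) (assumed order)."""
    aw, ax, ay, az = a[..., 0], a[..., 1], a[..., 2], a[..., 3]
    bw, bx, by, bz = b[..., 0], b[..., 1], b[..., 2], b[..., 3]
    return jnp.stack(
        [
            aw * bw - ax * bx - ay * by - az * bz,
            aw * bx + ax * bw + ay * bz - az * by,
            aw * by - ax * bz + ay * bw + az * bx,
            aw * bz + ax * by - ay * bx + az * bw,
        ],
        axis=-1,
    )


def sandwich(rotor, v):
    """rotor * v * rotor^-1, coordinate-wise."""
    inv = rotor * jnp.array([1.0, -1.0, -1.0, -1.0]) / jnp.sum(rotor * rotor, axis=-1, keepdims=True)
    return qmul(qmul(rotor, v), inv)


def random_pure_qhv(key, dim):
    """Hypothetical sampler: pure unit quaternions (w = 0, random unit 3-vector)."""
    v = jax.random.normal(key, (dim, 3))
    v = v / jnp.linalg.norm(v, axis=-1, keepdims=True)
    return jnp.concatenate([jnp.zeros((dim, 1)), v], axis=-1)


def fit_rotor(s, s_next):
    """Per-coordinate unit rotor taking pure unit s to pure unit s_next."""
    a, b = s[..., 1:], s_next[..., 1:]
    w = 1.0 + jnp.sum(a * b, axis=-1, keepdims=True)
    r = jnp.concatenate([w, jnp.cross(a, b)], axis=-1)
    return r / jnp.linalg.norm(r, axis=-1, keepdims=True)


ks, kn = jax.random.split(jax.random.PRNGKey(9))
s = random_pure_qhv(ks, 256)
s_next = random_pure_qhv(kn, 256)

U = fit_rotor(s, s_next)
predicted = sandwich(U, s)   # should reproduce s_next
```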

Parameters:

Name   Type     Description                                           Default
rotor  ndarray  Quaternion rotor, shape (..., 4) or (..., D, 4).      required
v      ndarray  Quaternion to transform, shape (..., 4) or (..., D, 4).  required

Returns:

Type     Description
ndarray  Transformed quaternion rotor * v * rotor^-1.

Example

import jax.numpy as jnp
from vsax.ops import sandwich

# Apply rotor transformation
rotor = jnp.array([0.707, 0.707, 0, 0])  # 90-degree rotation
v = jnp.array([0, 1, 0, 0])  # Pure quaternion
transformed = sandwich(rotor, v)
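To see what the example rotor actually does, the sketch below applies it to the basis units, assuming scalar-first (w, x, y, z) order so that [0.707, 0.707, 0, 0] is a 90-degree rotation about i. The example's own v = [0, 1, 0, 0] lies on the rotation axis and is left unchanged, while j is carried to k. Because the general form divides by |rotor|², the slightly non-unit 0.707 entries still yield an exact rotation. `qmul`, `qinverse` and `sandwich` here are hypothetical re-implementations.

```python
import jax.numpy as jnp


def qmul(a, b):
    """Hamilton product; last axis is (w, x, y, z) (assumed order)."""
    aw, ax, ay, az = a[..., 0], a[..., 1], a[..., 2], a[..., 3]
    bw, bx, by, bz = b[..., 0], b[..., 1], b[..., 2], b[..., 3]
    return jnp.stack(
        [
            aw * bw - ax * bx - ay * by - az * bz,
            aw * bx + ax * bw + ay * bz - az * by,
            aw * by - ax * bz + ay * bw + az * bx,
            aw * bz + ax * by - ay * bx + az * bw,
        ],
        axis=-1,
    )


def qinverse(q):
    """Conjugate divided by squared norm."""
    conj = q * jnp.array([1.0, -1.0, -1.0, -1.0])
    return conj / jnp.sum(q * q, axis=-1, keepdims=True)


def sandwich(rotor, v):
    """rotor * v * rotor^-1."""
    return qmul(qmul(rotor, v), qinverse(rotor))


rotor = jnp.array([0.707, 0.707, 0.0, 0.0])  # rotor from the example above
i = jnp.array([0.0, 1.0, 0.0, 0.0])
j = jnp.array([0.0, 0.0, 1.0, 0.0])
k = jnp.array([0.0, 0.0, 0.0, 1.0])

on_axis = sandwich(rotor, i)   # equals i: the rotation axis is fixed
turned = sandwich(rotor, j)    # equals k: j rotated 90 degrees about i
```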

Source code in vsax/ops/quaternion.py
def sandwich(rotor: jnp.ndarray, v: jnp.ndarray) -> jnp.ndarray:
    """Apply rotor transformation via sandwich product: rotor * v * rotor^-1.

    This is the core operation for quaternion-based rotations/transformations.
    For unit quaternions representing rotations, this performs the rotation
    of v by the rotation encoded in rotor.

    The sandwich product is useful for learning transformations between states.
    Given state traces (s, action, s'), you can learn a rotor U such that
    s' = sandwich(U, s).

    Args:
        rotor: Quaternion rotor, shape (..., 4) or (..., D, 4).
        v: Quaternion to transform, shape (..., 4) or (..., D, 4).

    Returns:
        Transformed quaternion rotor * v * rotor^-1.

    Example:
        >>> import jax.numpy as jnp
        >>> # Apply rotor transformation
        >>> rotor = jnp.array([0.707, 0.707, 0, 0])  # 90-degree rotation
        >>> v = jnp.array([0, 1, 0, 0])  # Pure quaternion
        >>> transformed = sandwich(rotor, v)
    """
    return qmul(qmul(rotor, v), qinverse(rotor))

vsax.ops.sandwich_unit(rotor, v)

Apply unit rotor transformation: rotor * v * conj(rotor).

More efficient version for unit quaternions, where the inverse equals the conjugate. Use this only when the rotor is already normalized to unit length; otherwise the result is scaled by |rotor|² at each coordinate.

Parameters:

Name   Type     Description                                              Default
rotor  ndarray  Unit quaternion rotor, shape (..., 4) or (..., D, 4).    required
v      ndarray  Quaternion to transform, shape (..., 4) or (..., D, 4).  required

Returns:

Type     Description
ndarray  Transformed quaternion rotor * v * conj(rotor).

Example

import jax.numpy as jnp
from vsax.ops import sandwich_unit

# Apply unit rotor transformation (more efficient)
rotor = jnp.array([0.707, 0.707, 0, 0])  # Approximately unit length (0.707 ≈ 1/√2)
v = jnp.array([0, 1, 0, 0])
transformed = sandwich_unit(rotor, v)
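The rounded 0.707 entries make |rotor|² = 2 × 0.707² ≈ 0.9997 rather than exactly 1, so the conjugate-based form shrinks its output by that factor. The sketch shows this on j (scalar-first order assumed, so j should map to k) and shows that normalizing the rotor first removes the error. `qmul` and `sandwich_unit` are hypothetical re-implementations, not vsax's code.

```python
import jax.numpy as jnp


def qmul(a, b):
    """Hamilton product; last axis is (w, x, y, z) (assumed order)."""
    aw, ax, ay, az = a[..., 0], a[..., 1], a[..., 2], a[..., 3]
    bw, bx, by, bz = b[..., 0], b[..., 1], b[..., 2], b[..., 3]
    return jnp.stack(
        [
            aw * bw - ax * bx - ay * by - az * bz,
            aw * bx + ax * bw + ay * bz - az * by,
            aw * by - ax * bz + ay * bw + az * bx,
            aw * bz + ax * by - ay * bx + az * bw,
        ],
        axis=-1,
    )


def sandwich_unit(rotor, v):
    """rotor * v * conj(rotor); exact only when |rotor| == 1."""
    return qmul(qmul(rotor, v), rotor * jnp.array([1.0, -1.0, -1.0, -1.0]))


rotor = jnp.array([0.707, 0.707, 0.0, 0.0])
j = jnp.array([0.0, 0.0, 1.0, 0.0])
k = jnp.array([0.0, 0.0, 0.0, 1.0])

shrunk = sandwich_unit(rotor, j)                    # 2 * 0.707**2 * k, about 0.9997 k
exact = sandwich_unit(rotor / jnp.linalg.norm(rotor), j)  # k after normalizing
```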

Source code in vsax/ops/quaternion.py
def sandwich_unit(rotor: jnp.ndarray, v: jnp.ndarray) -> jnp.ndarray:
    """Apply unit rotor transformation: rotor * v * conj(rotor).

    More efficient version for unit quaternions where inverse = conjugate.
    Use this when you know the rotor is already normalized to unit length.

    Args:
        rotor: Unit quaternion rotor, shape (..., 4) or (..., D, 4).
        v: Quaternion to transform, shape (..., 4) or (..., D, 4).

    Returns:
        Transformed quaternion rotor * v * conj(rotor).

    Example:
        >>> import jax.numpy as jnp
        >>> # Apply unit rotor transformation (more efficient)
        >>> rotor = jnp.array([0.707, 0.707, 0, 0])  # Already unit length
        >>> v = jnp.array([0, 1, 0, 0])
        >>> transformed = sandwich_unit(rotor, v)
    """
    return qmul(qmul(rotor, v), qconj(rotor))