QuaternionHypervector
Quaternion-valued hypervector representation for non-commutative VSA.
vsax.representations.QuaternionHypervector
Bases: AbstractHypervector
Quaternion-valued hypervector for non-commutative VSA.
QuaternionHypervector uses quaternions to represent hypervectors, where each quaternion coordinate is a 4-tuple (a, b, c, d) representing q = a + bi + cj + dk. This enables non-commutative binding via the Hamilton product.
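The Hamilton product can be sketched in plain jax.numpy to show why binding is non-commutative. The helper name hamilton_product is hypothetical. The sketch also assumes that binding multiplies corresponding quaternions coordinate by coordinate. It is not taken from the vsax source.

```python
import jax
import jax.numpy as jnp


def hamilton_product(p, q):
    """Coordinate-wise Hamilton product of two (..., D, 4) quaternion arrays.

    Hypothetical helper for illustration; not part of the documented vsax API.
    """
    a1, b1, c1, d1 = p[..., 0], p[..., 1], p[..., 2], p[..., 3]
    a2, b2, c2, d2 = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
    return jnp.stack(
        [
            a1 * a2 - b1 * b2 - c1 * c2 - d1 * d2,  # real part a
            a1 * b2 + b1 * a2 + c1 * d2 - d1 * c2,  # i component b
            a1 * c2 - b1 * d2 + c1 * a2 + d1 * b2,  # j component c
            a1 * d2 + b1 * c2 - c1 * b2 + d1 * a2,  # k component d
        ],
        axis=-1,
    )


# Basis quaternions i and j, stored as (a, b, c, d).
i = jnp.array([0.0, 1.0, 0.0, 0.0])
j = jnp.array([0.0, 0.0, 1.0, 0.0])
ij = hamilton_product(i, j)  # [0, 0, 0, 1], i.e. k
ji = hamilton_product(j, i)  # [0, 0, 0, -1], i.e. -k

# The same asymmetry holds for whole hypervectors with D = 512 coordinates.
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_a, (512, 4))
y = jax.random.normal(key_b, (512, 4))
commutes = bool(jnp.allclose(hamilton_product(x, y), hamilton_product(y, x)))
print(commutes)  # False
```

Because ij = k while ji = -k, swapping the operands of a binding generally changes the result.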
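For a single hypervector the storage format is (D, 4), where:
- D is the number of quaternion coordinates (the VSA dimensionality).
- 4 is the number of quaternion components (a, b, c, d).
Additional leading axes are allowed, giving the general shape (..., D, 4).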
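The (D, 4) layout can be built from four separate component arrays by stacking them along a new last axis. This is a minimal sketch. The random components and the batch of three copies are illustrative assumptions, not vsax behavior.

```python
import jax
import jax.numpy as jnp

D = 512  # VSA dimensionality: number of quaternion coordinates
keys = jax.random.split(jax.random.PRNGKey(0), 4)

# Four component arrays, one per quaternion part, each of shape (D,).
a, b, c, d = (jax.random.normal(k, (D,)) for k in keys)

# Stacking along a new last axis gives the (D, 4) storage layout:
# row n holds the quaternion a[n] + b[n] i + c[n] j + d[n] k.
vec = jnp.stack([a, b, c, d], axis=-1)

# Extra leading axes give the general (..., D, 4) shape, here (3, D, 4).
batch = jnp.stack([vec, vec, vec], axis=0)
```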
The normalization operation projects all quaternions to unit length on S³.
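The norm and the unit 3-sphere in standard quaternion notation are shown below. The projection formula assumes that normalization divides each nonzero quaternion by its Euclidean norm. This page does not state that explicitly.

$$
\lVert q \rVert = \sqrt{a^2 + b^2 + c^2 + d^2}, \qquad
\hat{q} = \frac{q}{\lVert q \rVert}, \qquad
S^3 = \{\, q \in \mathbb{H} : \lVert q \rVert = 1 \,\}
$$

For example, q = 1 + 2i + 3j + 4k has norm √30 ≈ 5.477, so its projection onto S³ is (1, 2, 3, 4)/√30.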
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| vec | ndarray | Quaternion-valued JAX array of shape (..., D, 4). | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If the last dimension is not 4. |
Example
```python
import jax.numpy as jnp
from vsax.representations import QuaternionHypervector

# Create a hypervector with 512 quaternion coordinates.
# Every coordinate is (0.5, 0.5, 0.5, 0.5), which is already a unit quaternion.
vec = jnp.ones((512, 4)) / 2
hv = QuaternionHypervector(vec)
normalized = hv.normalize()
```
Source code in vsax/representations/quaternion_hv.py
Attributes
dim (property)
Return the number of quaternion coordinates.
This is the VSA dimensionality D, not the total number of floats.
Returns:
| Type | Description |
|---|---|
| int | Number of quaternion coordinates (D). |
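D can be read straight from the array shape, and it differs from the total number of stored floats. This minimal sketch assumes `dim` corresponds to the second-to-last axis, because the quaternion component axis is always last. That assumption is not confirmed against the vsax source.

```python
import jax.numpy as jnp

batch = jnp.zeros((8, 512, 4))  # 8 hypervectors, each with D = 512 quaternion coordinates

# Assumed equivalent of `dim`: the second-to-last axis, unaffected by leading axes.
D = batch.shape[-2]

# Total number of stored floats, which `dim` deliberately does not report.
total_floats = batch.size
```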
quaternion_norms (property)
Return the norm of each quaternion coordinate.
Returns:
| Type | Description |
|---|---|
| ndarray | Array of shape (..., D) containing quaternion magnitudes. |
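This is a minimal sketch of per-coordinate magnitudes. It assumes `quaternion_norms` is the Euclidean norm over the component axis, sqrt(a² + b² + c² + d²), which collapses the last axis.

```python
import jax.numpy as jnp

# D = 2 quaternion coordinates: 1 + 2i + 3j + 4k and (0.5, 0.5, 0.5, 0.5).
vec = jnp.array([[1.0, 2.0, 3.0, 4.0], [0.5, 0.5, 0.5, 0.5]])

# Euclidean norm over the last axis: shape (..., D, 4) -> (..., D).
norms = jnp.linalg.norm(vec, axis=-1)  # [sqrt(30), 1.0], about [5.477, 1.0]
```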
scalar_part (property)
Extract the scalar (real) component of each quaternion.
The scalar part is the 'a' in q = a + bi + cj + dk.
Returns:
| Type | Description |
|---|---|
| ndarray | Array of shape (..., D) containing scalar parts. |
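This minimal sketch assumes `scalar_part` amounts to taking component 0 of the last axis. Leading axes pass through unchanged.

```python
import jax.numpy as jnp

# Two hypervectors, each with D = 3 quaternion coordinates: shape (2, 3, 4).
vec = jnp.arange(24.0).reshape(2, 3, 4)

# Scalar (real) part a of every quaternion: drops the last axis -> shape (2, 3).
scalar = vec[..., 0]
```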
vector_part (property)
Extract the vector (imaginary) components of each quaternion.
The vector part is (b, c, d) in q = a + bi + cj + dk.
Returns:
| Type | Description |
|---|---|
| ndarray | Array of shape (..., D, 3) containing vector parts. |
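This minimal sketch assumes `vector_part` amounts to slicing components 1 to 3 of the last axis. It also checks that the scalar and vector parts together reconstruct the original array.

```python
import jax.numpy as jnp

vec = jnp.arange(24.0).reshape(2, 3, 4)  # shape (..., D, 4) with D = 3

# Vector (imaginary) part (b, c, d): keep components 1..3 -> shape (2, 3, 3).
vector = vec[..., 1:]

# Re-attaching the scalar part a along the last axis restores the original layout.
scalar = vec[..., 0]
rebuilt = jnp.concatenate([scalar[..., None], vector], axis=-1)
```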
Functions
__init__(vec)
Initialize a quaternion hypervector.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| vec | ndarray | Quaternion-valued JAX array of shape (..., D, 4). | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If the last dimension is not 4. |
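The documented shape check can be sketched as a standalone function. The name check_quaternion_shape and the error message are hypothetical. Only the rule that the last dimension must be 4 comes from this page.

```python
import jax.numpy as jnp


def check_quaternion_shape(vec):
    """Hypothetical stand-in for the __init__ check: the last axis must hold (a, b, c, d)."""
    if vec.shape[-1] != 4:
        raise ValueError(f"Expected last dimension 4 (a, b, c, d), got {vec.shape[-1]}")


check_quaternion_shape(jnp.ones((512, 4)))  # valid (D, 4) input passes silently

rejected = False
try:
    check_quaternion_shape(jnp.ones((512, 3)))  # only 3 components per coordinate
except ValueError as err:
    rejected = True
    print(err)
```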
normalize()
Normalize all quaternions to unit length.
Projects each quaternion onto the unit 3-sphere S³.
Returns:
| Type | Description |
|---|---|
| QuaternionHypervector | New QuaternionHypervector with all quaternions having magnitude 1.0. |
Example
```python
import jax.numpy as jnp
from vsax.representations import QuaternionHypervector

vec = jnp.array([[1, 2, 3, 4], [5, 6, 7, 8]])  # not unit quaternions
hv = QuaternionHypervector(vec.astype(float))
normalized = hv.normalize()
# All quaternions in `normalized` now have unit norm
```
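The projection onto S³ can be sketched in jax.numpy, assuming normalization divides each quaternion by its Euclidean norm. The name normalize_quaternions is hypothetical. This sketch does not guard against zero quaternions, for which 0/0 produces NaN.

```python
import jax.numpy as jnp


def normalize_quaternions(vec):
    """Hypothetical sketch: scale every quaternion in a (..., D, 4) array to unit norm."""
    # keepdims=True keeps shape (..., D, 1) so the division broadcasts over (a, b, c, d).
    norms = jnp.linalg.norm(vec, axis=-1, keepdims=True)
    return vec / norms  # a zero quaternion yields NaN here


vec = jnp.array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]])
unit = normalize_quaternions(vec)
print(jnp.linalg.norm(unit, axis=-1))  # each entry is 1 up to float32 rounding
```

Dividing by the norm changes only the magnitude, so each quaternion keeps its direction in four-dimensional space.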
is_unit(atol=1e-06)
Check if all quaternions are unit quaternions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| atol | float | Absolute tolerance for norm comparison. | 1e-06 |
Returns:
| Type | Description |
|---|---|
| bool | True if all quaternions have magnitude approximately 1.0. |
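An equivalent unit check can be sketched as a standalone function. The name all_unit_quaternions is hypothetical. The sketch assumes the check compares every Euclidean norm with 1.0 using only the absolute tolerance atol.

```python
import jax.numpy as jnp


def all_unit_quaternions(vec, atol=1e-6):
    """Hypothetical stand-in for is_unit: True if every norm is within atol of 1."""
    norms = jnp.linalg.norm(vec, axis=-1)
    # Purely absolute comparison, matching the "absolute tolerance" description of atol.
    return bool(jnp.all(jnp.abs(norms - 1.0) <= atol))


print(all_unit_quaternions(jnp.ones((512, 4)) / 2))  # True: each norm is exactly 1
print(all_unit_quaternions(jnp.ones((512, 4))))      # False: each norm is 2
```

A plain absolute test was chosen over jnp.allclose, which would also add a relative tolerance (rtol) to the comparison.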