Utilities API
Utility functions for batch operations and visualization.
Batch Operations
vsax.utils.vmap_bind(opset, X, Y)
Vectorized binding of two batches of hypervectors.
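Applies the bind operation pairwise across two batches of hypervectors, binding `X[i]` with `Y[i]` for each row `i`. It uses JAX's `vmap` to vectorize the computation over the batch, which allows efficient parallel execution on GPU/TPU.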
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `opset` | `AbstractOpSet` | The operation set defining the bind operation. | required |
| `X` | `Union[ndarray, list[ndarray]]` | First batch of hypervectors, shape `(batch_size, dim)`. | required |
| `Y` | `Union[ndarray, list[ndarray]]` | Second batch of hypervectors, shape `(batch_size, dim)`. | required |
Returns:
| Type | Description |
|---|---|
| `ndarray` | Batch of bound hypervectors, shape `(batch_size, dim)`. |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If batch sizes don't match. |
Example
```python
from vsax import create_fhrr_model, VSAMemory
from vsax.utils import vmap_bind
import jax.numpy as jnp

model = create_fhrr_model(dim=512)
memory = VSAMemory(model)
memory.add_many(["a", "b", "c", "x", "y", "z"])

# Batch bind: [a, b, c] with [x, y, z]
X = jnp.stack([memory["a"].vec, memory["b"].vec, memory["c"].vec])
Y = jnp.stack([memory["x"].vec, memory["y"].vec, memory["z"].vec])
result = vmap_bind(model.opset, X, Y)
print(result.shape)  # (3, 512)
```
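To make the pairwise batching concrete, here is a minimal sketch of the behavior described above, written in plain JAX. It is not vsax's implementation. `sketch_vmap_bind` is a hypothetical helper that takes a bind function where vsax takes `opset`. `fhrr_bind` assumes FHRR binding is the element-wise product of unit-magnitude complex phasors. The batch-size check mirrors the documented `ValueError`.

```python
import jax
import jax.numpy as jnp


def sketch_vmap_bind(bind_fn, X, Y):
    """Hypothetical stand-in for vmap_bind; bind_fn plays the role of opset.bind."""
    # Accept either a stacked array or a list of (dim,) arrays.
    X = jnp.stack(X) if isinstance(X, (list, tuple)) else jnp.asarray(X)
    Y = jnp.stack(Y) if isinstance(Y, (list, tuple)) else jnp.asarray(Y)
    if X.shape[0] != Y.shape[0]:
        raise ValueError(f"Batch sizes don't match: {X.shape[0]} vs {Y.shape[0]}")
    # vmap maps bind_fn over the leading (batch) axis: row i of X with row i of Y.
    return jax.vmap(bind_fn)(X, Y)


def fhrr_bind(a, b):
    # Assumed FHRR binding: element-wise product of unit phasors (phases add).
    return a * b


def random_phasors(key, shape):
    # Unit-magnitude complex vectors with uniformly random phases.
    return jnp.exp(1j * jax.random.uniform(key, shape, minval=-jnp.pi, maxval=jnp.pi))


key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
X = random_phasors(key_x, (3, 512))
Y = random_phasors(key_y, (3, 512))
result = sketch_vmap_bind(fhrr_bind, X, Y)
print(result.shape)  # (3, 512)
```

Because the sketch stacks list inputs first, passing `[X[0], X[1]]` and `[Y[0], Y[1]]` yields a `(2, 512)` result. Passing batches of different sizes raises `ValueError` before any binding happens.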
vsax.utils.vmap_bundle(opset, X)
Vectorized bundling across the batch dimension.
Bundles a batch of hypervectors into a single hypervector using JAX's efficient reduction operations. Unlike `vmap_bind`, it does not produce one output per batch item: it combines all vectors in the batch into a single result.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `opset` | `AbstractOpSet` | The operation set defining the bundle operation. | required |
| `X` | `Union[ndarray, list[ndarray]]` | Batch of hypervectors, shape `(batch_size, dim)`. | required |
Returns:
| Type | Description |
|---|---|
| `ndarray` | Single bundled hypervector, shape `(dim,)`. |
Example
```python
from vsax import create_map_model, VSAMemory
from vsax.utils import vmap_bundle
import jax.numpy as jnp

model = create_map_model(dim=512)
memory = VSAMemory(model)
memory.add_many(["red", "green", "blue"])

# Bundle colors together
colors = jnp.stack([
    memory["red"].vec,
    memory["green"].vec,
    memory["blue"].vec,
])
color_set = vmap_bundle(model.opset, colors)
print(color_set.shape)  # (512,)
```
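Bundling collapses the batch axis instead of mapping over it. The sketch below assumes a MAP-style bundle that simply sums bipolar vectors along axis 0. vsax's own `opset.bundle` may additionally normalize or threshold the result, and `sketch_vmap_bundle` is a hypothetical name. The sketch also checks the property that makes bundles useful: the result stays measurably similar to every member.

```python
import jax
import jax.numpy as jnp


def sketch_vmap_bundle(X):
    """Hypothetical MAP-style bundle: element-wise sum over the batch axis."""
    X = jnp.stack(X) if isinstance(X, (list, tuple)) else jnp.asarray(X)
    # One reduction turns (batch_size, dim) into (dim,).
    return jnp.sum(X, axis=0)


def cosine(a, b):
    return jnp.dot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))


# Three random bipolar (+1/-1) vectors stand in for "red", "green", "blue".
key = jax.random.PRNGKey(0)
colors = jnp.where(jax.random.bernoulli(key, 0.5, (3, 512)), 1.0, -1.0)

color_set = sketch_vmap_bundle(colors)
print(color_set.shape)  # (512,)

# For k near-orthogonal members, each similarity is roughly 1/sqrt(k) (about 0.58 here).
member_sims = jnp.array([cosine(color_set, c) for c in colors])
print(member_sims)
```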
vsax.utils.vmap_similarity(similarity_fn, query, candidates)
Vectorized similarity computation between a query and multiple candidates.
Computes the similarity between a single query vector and each vector in a batch of candidates, using JAX's vmap for efficient parallel execution.
Note: This function returns raw similarity scores as a JAX array instead of converting them to Python floats, so it can be used inside JAX transformations such as `jit`.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `similarity_fn` | `Callable` | Similarity function that operates on arrays. Must accept two `jnp.ndarray` arguments and return a scalar. | required |
| `query` | `Union[ndarray, list[ndarray]]` | Single query hypervector, shape `(dim,)`. | required |
| `candidates` | `Union[ndarray, list[ndarray]]` | Batch of candidate hypervectors, shape `(batch_size, dim)`. | required |
Returns:
| Type | Description |
|---|---|
| `ndarray` | Array of similarity scores, shape `(batch_size,)`. |
Example
```python
from vsax import create_fhrr_model, VSAMemory
from vsax.utils import vmap_similarity
import jax.numpy as jnp

model = create_fhrr_model(dim=512)
memory = VSAMemory(model)
memory.add_many(["dog", "cat", "bird", "animal"])

# Define a similarity function that works on raw arrays
def array_cosine(a, b):
    # vdot conjugates `a`; its real part is the cosine numerator for both
    # real-valued and complex (FHRR) hypervectors.
    dot = jnp.real(jnp.vdot(a, b))
    return dot / (jnp.linalg.norm(a) * jnp.linalg.norm(b) + 1e-10)

# Find the candidate most similar to "animal"
query = memory["animal"].vec
candidates = jnp.stack([
    memory["dog"].vec,
    memory["cat"].vec,
    memory["bird"].vec,
])
similarities = vmap_similarity(array_cosine, query, candidates)
best_match = jnp.argmax(similarities)
```
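The core of a query-versus-candidates search is `jax.vmap` with `in_axes=(None, 0)`: the query is broadcast while the candidates are mapped row by row. The sketch below uses a hypothetical `sketch_vmap_similarity`, which is not vsax's code, to show that pattern. Because the scores stay JAX arrays, as the note above describes, the whole search can also be wrapped in `jax.jit`. Small hand-made vectors are used so the best match is known in advance.

```python
import jax
import jax.numpy as jnp


def sketch_vmap_similarity(similarity_fn, query, candidates):
    """Hypothetical stand-in for vmap_similarity."""
    if isinstance(candidates, (list, tuple)):
        candidates = jnp.stack(candidates)
    # in_axes=(None, 0): reuse the same query for every row of candidates.
    return jax.vmap(similarity_fn, in_axes=(None, 0))(query, candidates)


def array_cosine(a, b):
    dot = jnp.real(jnp.vdot(a, b))
    return dot / (jnp.linalg.norm(a) * jnp.linalg.norm(b) + 1e-10)


candidates = jnp.eye(4)[:3]              # three orthogonal unit vectors
query = jnp.array([0.1, 0.9, 0.0, 0.0])  # closest to the second candidate

scores = sketch_vmap_similarity(array_cosine, query, candidates)
print(scores.shape)  # (3,)

# No Python floats are created, so the search traces cleanly under jit.
best = jax.jit(lambda q, c: jnp.argmax(sketch_vmap_similarity(array_cosine, q, c)))
print(int(best(query, candidates)))  # 1
```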
Visualization
vsax.utils.pretty_repr(hv, max_elements=5)
Generate a pretty string representation of a hypervector.
Creates a human-readable representation showing shape, dtype, and a sample of the vector values. Useful for debugging and interactive exploration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `hv` | `Union[AbstractHypervector, ndarray]` | Hypervector to represent (`AbstractHypervector` or `jnp.ndarray`). | required |
| `max_elements` | `int` | Maximum number of elements to display. | `5` |
Returns:
| Type | Description |
|---|---|
| `str` | Formatted string representation. |
Example
```python
from vsax import create_fhrr_model, VSAMemory
from vsax.utils import pretty_repr

model = create_fhrr_model(dim=512)
memory = VSAMemory(model)
memory.add("example")
print(pretty_repr(memory["example"]))
# Example output (sampled values depend on the random hypervector):
# ComplexHypervector(dim=512, dtype=complex64)
# Sample: [0.123+0.456j, -0.789+0.234j, ..., 0.567-0.890j]
```
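For raw arrays outside vsax, a comparable summary takes only a few lines. The sketch below is a simplified, hypothetical `sketch_pretty_repr`. Its layout is an assumption: it shows the first `max_elements` values followed by an ellipsis, and it does not reproduce vsax's exact format. Values are copied to the host with NumPy so they format as plain Python numbers.

```python
import numpy as np
import jax.numpy as jnp


def sketch_pretty_repr(vec, max_elements=5):
    """Hypothetical, simplified summary of a 1-D array (not vsax's formatter)."""
    arr = np.asarray(vec)  # copy to host so elements format as plain numbers
    is_complex = np.iscomplexobj(arr)

    def fmt(x):
        if is_complex:
            return f"{float(x.real):.3f}{float(x.imag):+.3f}j"
        return f"{float(x):.3f}"

    sample = [fmt(x) for x in arr[:max_elements]]
    if arr.size > max_elements:
        sample.append("...")  # signal that values were truncated
    return f"Array(dim={arr.shape[-1]}, dtype={arr.dtype}) Sample: [{', '.join(sample)}]"


print(sketch_pretty_repr(jnp.arange(8, dtype=jnp.float32), max_elements=3))
# Array(dim=8, dtype=float32) Sample: [0.000, 1.000, 2.000, ...]
```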
vsax.utils.format_similarity_results(query_name, candidate_names, similarities, top_k=5)
Format similarity search results as a readable table, listing the `top_k` highest-scoring candidates in descending order of similarity.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `query_name` | `str` | Name of the query item. | required |
| `candidate_names` | `list[str]` | Names of candidate items. | required |
| `similarities` | `ndarray` | Array of similarity scores, shape `(n_candidates,)`. | required |
| `top_k` | `int` | Number of top results to display. | `5` |
Returns:
| Type | Description |
|---|---|
| `str` | Formatted table string. |
Example
```python
from vsax.utils import format_similarity_results
import jax.numpy as jnp

results = format_similarity_results(
    "dog",
    ["cat", "wolf", "bird", "puppy"],
    jnp.array([0.85, 0.92, 0.23, 0.95]),
    top_k=3,
)
print(results)
# Query: dog
# Top 3 matches:
# 1. puppy 0.950
# 2. wolf 0.920
# 3. cat 0.850
```
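The ranking behind such a table is a descending sort truncated to `top_k`. The following is a minimal sketch that assumes the plain-text layout from the example above. The hypothetical `sketch_format_results` is not vsax's implementation, and its spacing may differ.

```python
import jax.numpy as jnp


def sketch_format_results(query_name, candidate_names, similarities, top_k=5):
    """Hypothetical re-creation of a top-k similarity table."""
    scores = jnp.asarray(similarities)
    # argsort on negated scores gives descending order; slicing keeps at most top_k.
    order = [int(i) for i in jnp.argsort(-scores)[:top_k]]
    lines = [f"Query: {query_name}", f"Top {len(order)} matches:"]
    for rank, i in enumerate(order, start=1):
        lines.append(f"  {rank}. {candidate_names[i]} {float(scores[i]):.3f}")
    return "\n".join(lines)


table = sketch_format_results(
    "dog",
    ["cat", "wolf", "bird", "puppy"],
    jnp.array([0.85, 0.92, 0.23, 0.95]),
    top_k=3,
)
print(table)
```

Slicing the sorted indices means a `top_k` larger than the number of candidates simply lists every candidate.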