
Utilities API

Utility functions for batch operations and visualization.

Batch Operations

vsax.utils.vmap_bind(opset, X, Y)

Vectorized binding of two batches of hypervectors.

Binds two batches of hypervectors pairwise (row i of X with row i of Y), using JAX's vmap to map the bind operation over the batch axis for efficient parallel execution on GPU/TPU.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| opset | AbstractOpSet | The operation set defining the bind operation. | required |
| X | Union[ndarray, list[ndarray]] | First batch of hypervectors, shape (batch_size, dim). | required |
| Y | Union[ndarray, list[ndarray]] | Second batch of hypervectors, shape (batch_size, dim). | required |

Returns:

| Type | Description |
|------|-------------|
| ndarray | Batch of bound hypervectors, shape (batch_size, dim). |

Raises:

| Type | Description |
|------|-------------|
| ValueError | If the batch sizes of X and Y don't match. |

Example

    from vsax import create_fhrr_model, VSAMemory
    from vsax.utils import vmap_bind
    import jax.numpy as jnp

    model = create_fhrr_model(dim=512)
    memory = VSAMemory(model)
    memory.add_many(["a", "b", "c", "x", "y", "z"])

    # Batch bind: [a, b, c] with [x, y, z]
    X = jnp.stack([memory["a"].vec, memory["b"].vec, memory["c"].vec])
    Y = jnp.stack([memory["x"].vec, memory["y"].vec, memory["z"].vec])
    result = vmap_bind(model.opset, X, Y)
    print(result.shape)  # (3, 512)
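The idea behind vmap_bind can be sketched in plain JAX. The sketch assumes FHRR-style binding, i.e. element-wise multiplication of unit-magnitude complex phasors; `fhrr_bind`, `sketch_vmap_bind` and `random_phasors` are hypothetical helpers for illustration, not part of vsax, whose vmap_bind takes its bind operation from `opset`.

```python
import jax
import jax.numpy as jnp


def fhrr_bind(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical FHRR bind: element-wise product of complex phasors."""
    return a * b


def sketch_vmap_bind(bind_fn, X, Y):
    """Hypothetical stand-in for vmap_bind: binds row i of X with row i of Y."""
    X = jnp.stack(X) if isinstance(X, list) else jnp.asarray(X)
    Y = jnp.stack(Y) if isinstance(Y, list) else jnp.asarray(Y)
    if X.shape[0] != Y.shape[0]:
        raise ValueError(f"Batch sizes don't match: {X.shape[0]} != {Y.shape[0]}")
    # jax.vmap maps bind_fn over the leading (batch) axis of both inputs.
    return jax.vmap(bind_fn)(X, Y)


def random_phasors(key, shape):
    """Unit-magnitude complex vectors with uniformly random phases."""
    phases = jax.random.uniform(key, shape, minval=-jnp.pi, maxval=jnp.pi)
    return jnp.exp(1j * phases)


key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
X = random_phasors(key_x, (3, 512))
Y = random_phasors(key_y, (3, 512))

result = sketch_vmap_bind(fhrr_bind, X, Y)
print(result.shape)  # (3, 512)

# FHRR unbinding multiplies by the complex conjugate, recovering X.
recovered = sketch_vmap_bind(fhrr_bind, result, jnp.conj(Y))
print(bool(jnp.allclose(recovered, X, atol=1e-5)))  # True
```

Because jax.vmap applies the per-vector function across the whole batch in one vectorized call, the same pattern would work for any bind function that takes two (dim,) vectors.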

vsax.utils.vmap_bundle(opset, X)

Vectorized bundling across batch dimension.

Bundles a batch of hypervectors into a single hypervector using JAX's efficient reduction operations. This is not a pairwise operation: it reduces over the batch axis, combining every vector in the batch into one result.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| opset | AbstractOpSet | The operation set defining the bundle operation. | required |
| X | Union[ndarray, list[ndarray]] | Batch of hypervectors, shape (batch_size, dim). | required |

Returns:

| Type | Description |
|------|-------------|
| ndarray | Single bundled hypervector, shape (dim,). |

Example

    from vsax import create_map_model, VSAMemory
    from vsax.utils import vmap_bundle
    import jax.numpy as jnp

    model = create_map_model(dim=512)
    memory = VSAMemory(model)
    memory.add_many(["red", "green", "blue"])

    # Bundle colors together
    colors = jnp.stack([
        memory["red"].vec,
        memory["green"].vec,
        memory["blue"].vec,
    ])
    color_set = vmap_bundle(model.opset, colors)
    print(color_set.shape)  # (512,)
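Bundling as a reduction over the batch axis can be sketched as follows. This assumes a bipolar MAP-style bundle (element-wise sum followed by sign(), i.e. a per-dimension majority vote); vsax's MAP opset may normalize differently, and `sketch_vmap_bundle`, `random_bipolar` and `cosine` are hypothetical names used only here.

```python
import jax
import jax.numpy as jnp


def random_bipolar(key, shape):
    """Hypothetical helper: components are +1 or -1 with equal probability."""
    return jnp.where(jax.random.bernoulli(key, 0.5, shape), 1.0, -1.0)


def sketch_vmap_bundle(X):
    """Hypothetical MAP-style bundle: (batch_size, dim) -> (dim,)."""
    X = jnp.stack(X) if isinstance(X, list) else jnp.asarray(X)
    summed = jnp.sum(X, axis=0)  # reduce over the batch axis
    # Majority vote per dimension; with an odd batch size the sum is never 0,
    # so the result is again a bipolar vector.
    return jnp.sign(summed)


def cosine(a, b):
    return jnp.dot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))


keys = jax.random.split(jax.random.PRNGKey(0), 4)
colors = jnp.stack([random_bipolar(k, (512,)) for k in keys[:3]])
unrelated = random_bipolar(keys[3], (512,))

color_set = sketch_vmap_bundle(colors)
print(color_set.shape)  # (512,)

# The bundle stays similar to each member (about 0.5 in expectation for three
# members) and nearly orthogonal to an unrelated vector.
member_sims = [float(cosine(color_set, c)) for c in colors]
unrelated_sim = float(cosine(color_set, unrelated))
print(member_sims, unrelated_sim)
```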

vsax.utils.vmap_similarity(similarity_fn, query, candidates)

Vectorized similarity computation between query and multiple candidates.

Computes similarity between a single query vector and a batch of candidate vectors using JAX's vmap for efficient parallel execution.

Note: This function returns the raw similarity scores as a JAX array rather than converting them to Python floats, so it can be used inside JAX transformations such as jax.jit.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| similarity_fn | Callable | Similarity function that operates on raw arrays. Must accept two jnp.ndarray arguments and return a scalar. | required |
| query | Union[ndarray, list[ndarray]] | Single query hypervector, shape (dim,). | required |
| candidates | Union[ndarray, list[ndarray]] | Batch of candidate hypervectors, shape (batch_size, dim). | required |

Returns:

| Type | Description |
|------|-------------|
| ndarray | Array of similarity scores, shape (batch_size,). |

Example

    from vsax import create_fhrr_model, VSAMemory
    from vsax.utils import vmap_similarity
    import jax.numpy as jnp

    model = create_fhrr_model(dim=512)
    memory = VSAMemory(model)
    memory.add_many(["dog", "cat", "bird", "animal"])

    # Define a similarity function that works on raw arrays
    def array_cosine(a, b):
        dot = jnp.vdot(a, b) if jnp.iscomplexobj(a) else jnp.dot(a, b)
        if jnp.iscomplexobj(a):
            dot = jnp.real(dot)  # keep the real part of the complex inner product
        return dot / (jnp.linalg.norm(a) * jnp.linalg.norm(b) + 1e-10)

    # Find the candidate most similar to "animal"
    query = memory["animal"].vec
    candidates = jnp.stack([
        memory["dog"].vec,
        memory["cat"].vec,
        memory["bird"].vec,
    ])
    similarities = vmap_similarity(array_cosine, query, candidates)
    best_match = jnp.argmax(similarities)  # index into ["dog", "cat", "bird"]
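Scoring one shared query against every candidate row corresponds to jax.vmap with in_axes=(None, 0). The sketch below is a hypothetical stand-in (`sketch_similarities` and `random_phasors` are illustrative names, not vsax code); it wraps the mapped call in jax.jit to show why the scores stay JAX arrays instead of Python floats.

```python
import jax
import jax.numpy as jnp


def array_cosine(a, b):
    # vdot conjugates `a`; taking the real part covers real and complex inputs.
    dot = jnp.real(jnp.vdot(a, b))
    return dot / (jnp.linalg.norm(a) * jnp.linalg.norm(b) + 1e-10)


@jax.jit
def sketch_similarities(query, candidates):
    # Hypothetical stand-in for vmap_similarity(array_cosine, query, candidates).
    # in_axes=(None, 0): reuse the same query, map over candidate rows.
    return jax.vmap(array_cosine, in_axes=(None, 0))(query, candidates)


def random_phasors(key, shape):
    phases = jax.random.uniform(key, shape, minval=-jnp.pi, maxval=jnp.pi)
    return jnp.exp(1j * phases)


key_q, key_o = jax.random.split(jax.random.PRNGKey(1))
query = random_phasors(key_q, (512,))
other = random_phasors(key_o, (512,))
candidates = jnp.stack([other, query, -query])

scores = sketch_similarities(query, candidates)
print(scores.shape)             # (3,)
print(int(jnp.argmax(scores)))  # 1: the candidate identical to the query
```

Inside the jitted function the scores are abstract tracers, so calling float() on them there would raise an error; conversion to Python numbers has to happen after the transformed function returns.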

Visualization

vsax.utils.pretty_repr(hv, max_elements=5)

Generate a pretty string representation of a hypervector.

Creates a human-readable representation showing shape, dtype, and a sample of the vector values. Useful for debugging and interactive exploration.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| hv | Union[AbstractHypervector, ndarray] | Hypervector to represent (AbstractHypervector or jnp.ndarray). | required |
| max_elements | int | Maximum number of elements to display. | 5 |

Returns:

| Type | Description |
|------|-------------|
| str | Formatted string representation. |

Example

    from vsax import create_fhrr_model, VSAMemory
    from vsax.utils import pretty_repr

    model = create_fhrr_model(dim=512)
    memory = VSAMemory(model)
    memory.add("example")
    print(pretty_repr(memory["example"]))
    # Example output (sampled values will differ):
    # ComplexHypervector(dim=512, dtype=complex64)
    # Sample: [0.123+0.456j, -0.789+0.234j, ..., 0.567-0.890j]
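The kind of truncated summary pretty_repr prints can be sketched as below. The head/tail split, the three-decimal formatting and the `Array(...)` header are assumptions made for illustration; `sketch_pretty_repr` is hypothetical and does not reproduce vsax's exact output.

```python
import jax.numpy as jnp


def _fmt(x) -> str:
    # Python's format spec handles both float and complex values.
    return format(x, ".3f")


def sketch_pretty_repr(vec, max_elements=5):
    """Hypothetical summary of a 1-D array: header plus a truncated sample."""
    vec = jnp.asarray(vec)
    n = vec.shape[0]
    if n <= max_elements:
        shown = [_fmt(x) for x in vec.tolist()]
    else:
        head = (max_elements + 1) // 2  # leading values, e.g. 3 for max_elements=5
        tail = max_elements - head      # trailing values, e.g. 2 for max_elements=5
        shown = (
            [_fmt(x) for x in vec[:head].tolist()]
            + ["..."]
            + [_fmt(x) for x in vec[n - tail:].tolist()]
        )
    return f"Array(dim={n}, dtype={vec.dtype})\nSample: [{', '.join(shown)}]"


print(sketch_pretty_repr(jnp.arange(10.0)))
# Array(dim=10, dtype=float32)
# Sample: [0.000, 1.000, 2.000, ..., 8.000, 9.000]
```

Slicing before calling tolist() keeps the cost independent of the vector's dimension, which matters for hypervectors with thousands of components.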

vsax.utils.format_similarity_results(query_name, candidate_names, similarities, top_k=5)

Format similarity search results in a readable table.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| query_name | str | Name of the query item. | required |
| candidate_names | list[str] | Names of candidate items, in the same order as similarities. | required |
| similarities | ndarray | Array of similarity scores, shape (n_candidates,). | required |
| top_k | int | Number of top results to display. | 5 |

Returns:

| Type | Description |
|------|-------------|
| str | Formatted table string. |

Example

    from vsax.utils import format_similarity_results
    import jax.numpy as jnp

    results = format_similarity_results(
        "dog",
        ["cat", "wolf", "bird", "puppy"],
        jnp.array([0.85, 0.92, 0.23, 0.95]),
        top_k=3,
    )
    print(results)
    # Query: dog
    # Top 3 matches:
    # 1. puppy 0.950
    # 2. wolf 0.920
    # 3. cat 0.850
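The ranking step behind format_similarity_results can be sketched with jax.lax.top_k, which returns the k largest scores in descending order together with their indices. `sketch_format_results` is a hypothetical helper, and its layout only approximates the table printed by the library.

```python
import jax
import jax.numpy as jnp


def sketch_format_results(query_name, candidate_names, similarities, top_k=5):
    """Hypothetical sketch: rank candidates by score and format the best top_k."""
    k = min(top_k, len(candidate_names))  # never ask for more results than exist
    values, indices = jax.lax.top_k(jnp.asarray(similarities), k)
    lines = [f"Query: {query_name}", f"Top {k} matches:"]
    for rank, (score, idx) in enumerate(zip(values.tolist(), indices.tolist()), start=1):
        lines.append(f"  {rank}. {candidate_names[idx]} {score:.3f}")
    return "\n".join(lines)


table = sketch_format_results(
    "dog",
    ["cat", "wolf", "bird", "puppy"],
    jnp.array([0.85, 0.92, 0.23, 0.95]),
    top_k=3,
)
print(table)
# Query: dog
# Top 3 matches:
#   1. puppy 0.950
#   2. wolf 0.920
#   3. cat 0.850
```

Because candidate_names is indexed by the positions top_k returns, the names must be in the same order as the similarity scores.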