Spatial Semantic Pointers

Continuous spatial representation using Fractional Power Encoding.

NEW in v1.2.0 - Based on Komer et al. (2019).

SpatialSemanticPointers

vsax.spatial.SpatialSemanticPointers

Spatial Semantic Pointers for continuous spatial representation.

Encodes continuous spatial locations using fractional power encoding:

S(x, y) = X^x ⊗ Y^y

This enables:
  • Encoding arbitrary spatial coordinates
  • Binding objects to locations
  • Querying "what is at location (x, y)?"
  • Querying "where is object O?"
  • Shifting entire scenes by a displacement vector
  • Decoding approximate locations from encoded vectors

Based on Komer et al. (2019), which demonstrates that SSPs can represent continuous spatial relationships in a compositional, distributed manner.
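
To make S(x, y) = X^x ⊗ Y^y concrete, here is a minimal sketch in plain JAX, independent of vsax. It assumes FHRR axis vectors whose phases are drawn uniformly from (-π, π); the helpers `make_phases`, `encode_2d`, and `similarity` are hypothetical. Raising an FHRR vector to a real power x multiplies each of its phases by x, and binding adds phases, so nearby coordinates give similar vectors and distant ones give nearly orthogonal vectors.

```python
import jax
import jax.numpy as jnp

DIM = 512

def make_phases(key, dim=DIM):
    """Phases theta of a random FHRR axis vector X = exp(1j * theta)."""
    return jax.random.uniform(key, (dim,), minval=-jnp.pi, maxval=jnp.pi)

def encode_2d(theta_x, theta_y, x, y):
    """S(x, y) = X^x ⊗ Y^y: powers scale the phases, binding adds them."""
    return jnp.exp(1j * (x * theta_x + y * theta_y))

def similarity(a, b):
    """Cosine similarity of complex vectors (real part of the normalized inner product)."""
    return jnp.real(jnp.vdot(a, b)) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))

kx, ky = jax.random.split(jax.random.PRNGKey(0))
theta_x, theta_y = make_phases(kx), make_phases(ky)

s = encode_2d(theta_x, theta_y, 3.5, 2.1)
sim_self = float(similarity(s, s))
sim_near = float(similarity(s, encode_2d(theta_x, theta_y, 3.6, 2.1)))
sim_far = float(similarity(s, encode_2d(theta_x, theta_y, 0.5, 4.0)))
print(sim_self, sim_near, sim_far)
```

With phases uniform on (-π, π), the expected similarity along one axis falls off as sin(πd)/(πd) for a displacement d, which is why the "near" value stays close to 1 while the "far" value is close to 0.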

Attributes:

  • model: VSAModel instance (must use ComplexHypervector/FHRR).
  • memory: VSAMemory for storing axis basis vectors and named objects.
  • config: SSPConfig with spatial configuration.
  • encoder: FractionalPowerEncoder for encoding spatial coordinates.

Example

>>> import jax
>>> from vsax import create_fhrr_model, VSAMemory
>>> from vsax.spatial import SpatialSemanticPointers, SSPConfig
>>>
>>> # Create 2D spatial representation
>>> model = create_fhrr_model(dim=512, key=jax.random.PRNGKey(0))
>>> memory = VSAMemory(model)
>>> config = SSPConfig(dim=512, num_axes=2)  # 2D space
>>> ssp = SpatialSemanticPointers(model, memory, config)
>>>
>>> # Encode a location
>>> location = ssp.encode_location([3.5, 2.1])  # (x=3.5, y=2.1)
>>>
>>> # Bind object to location
>>> memory.add("apple")
>>> scene = ssp.bind_object_location("apple", [3.5, 2.1])
>>>
>>> # Query: what is at location (3.5, 2.1)?
>>> result = ssp.query_location(scene, [3.5, 2.1])
>>> # result should be similar to the apple hypervector

See Also
  • vsax.encoders.FractionalPowerEncoder: Underlying encoder
  • Komer et al. 2019: "A neural representation of continuous space using fractional binding"
Source code in vsax/spatial/ssp.py
class SpatialSemanticPointers:
    """Spatial Semantic Pointers for continuous spatial representation.

    Encodes continuous spatial locations using fractional power encoding:
        S(x, y) = X^x ⊗ Y^y

    This enables:
        - Encoding arbitrary spatial coordinates
        - Binding objects to locations
        - Querying "what is at location (x, y)?"
        - Querying "where is object O?"
        - Shifting entire scenes by a displacement vector
        - Decoding approximate locations from encoded vectors

    Based on Komer et al. 2019 which demonstrates that SSPs can represent
    continuous spatial relationships in a compositional, distributed manner.

    Attributes:
        model: VSAModel instance (must use ComplexHypervector/FHRR).
        memory: VSAMemory for storing axis basis vectors and named objects.
        config: SSPConfig with spatial configuration.
        encoder: FractionalPowerEncoder for encoding spatial coordinates.

    Example:
        >>> import jax
        >>> from vsax import create_fhrr_model, VSAMemory
        >>> from vsax.spatial import SpatialSemanticPointers, SSPConfig
        >>>
        >>> # Create 2D spatial representation
        >>> model = create_fhrr_model(dim=512, key=jax.random.PRNGKey(0))
        >>> memory = VSAMemory(model)
        >>> config = SSPConfig(dim=512, num_axes=2)  # 2D space
        >>> ssp = SpatialSemanticPointers(model, memory, config)
        >>>
        >>> # Encode a location
        >>> location = ssp.encode_location([3.5, 2.1])  # (x=3.5, y=2.1)
        >>>
        >>> # Bind object to location
        >>> memory.add("apple")
        >>> scene = ssp.bind_object_location("apple", [3.5, 2.1])
        >>>
        >>> # Query: what is at location (3.5, 2.1)?
        >>> result = ssp.query_location(scene, [3.5, 2.1])
        >>> # result should be similar to apple hypervector

    See Also:
        - :class:`~vsax.encoders.FractionalPowerEncoder`: Underlying encoder
        - Komer et al. 2019: "A neural representation of continuous space using
          fractional binding"
    """

    def __init__(
        self,
        model: VSAModel,
        memory: VSAMemory,
        config: Optional[SSPConfig] = None,
    ) -> None:
        """Initialize Spatial Semantic Pointers.

        Args:
            model: VSAModel instance (must use ComplexHypervector).
            memory: VSAMemory for basis vectors and objects.
            config: SSPConfig, defaults to 2D (512-dim) if not provided.

        Raises:
            TypeError: If model doesn't use ComplexHypervector.
        """
        self.model = model
        self.memory = memory
        self.config = config if config is not None else SSPConfig()

        # axis_names is always set by SSPConfig.__post_init__
        assert self.config.axis_names is not None

        # Create FractionalPowerEncoder
        self.encoder = FractionalPowerEncoder(model, memory, scale=self.config.scale)

        # Initialize axis basis vectors in memory
        for axis_name in self.config.axis_names:
            if axis_name not in memory:
                memory.add(axis_name)

    def encode_location(self, coordinates: list[float]) -> ComplexHypervector:
        """Encode a spatial location as a hypervector.

        For 2D: S(x, y) = X^x ⊗ Y^y
        For 3D: S(x, y, z) = X^x ⊗ Y^y ⊗ Z^z

        Args:
            coordinates: List of coordinate values, one per axis.

        Returns:
            ComplexHypervector representing the spatial location.

        Raises:
            ValueError: If coordinates length doesn't match num_axes.

        Example:
            >>> location = ssp.encode_location([3.5, 2.1])  # 2D point
            >>> location3d = ssp.encode_location([1.0, 2.0, 3.0])  # 3D point
        """
        if len(coordinates) != self.config.num_axes:
            raise ValueError(f"Expected {self.config.num_axes} coordinates, got {len(coordinates)}")

        assert self.config.axis_names is not None  # Always set in __init__
        return self.encoder.encode_multi(self.config.axis_names, coordinates)

    def bind_object_location(
        self, object_name: str, coordinates: list[float]
    ) -> ComplexHypervector:
        """Bind an object to a spatial location.

        Creates: Object ⊗ S(x, y) = Object ⊗ X^x ⊗ Y^y

        Args:
            object_name: Name of object in memory.
            coordinates: Spatial coordinates for the object.

        Returns:
            ComplexHypervector representing object-at-location.

        Raises:
            KeyError: If object_name not in memory.
            ValueError: If coordinates length doesn't match num_axes.

        Example:
            >>> memory.add("apple")
            >>> apple_at_pos = ssp.bind_object_location("apple", [3.5, 2.1])
        """
        # Get object hypervector
        object_hv = self.memory[object_name]

        # Encode location
        location_hv = self.encode_location(coordinates)

        # Bind object to location
        result_vec = self.model.opset.bind(object_hv.vec, location_hv.vec)

        return ComplexHypervector(result_vec)

    def query_location(
        self, scene: ComplexHypervector, coordinates: list[float]
    ) -> ComplexHypervector:
        """Query what object is at a given location in the scene.

        For scene containing Object ⊗ S(x, y), querying at (x, y) returns
        a vector similar to Object.

        Args:
            scene: Scene hypervector (typically a bundle of object-location pairs).
            coordinates: Location to query.

        Returns:
            ComplexHypervector representing the object at that location.

        Raises:
            ValueError: If coordinates length doesn't match num_axes.

        Example:
            >>> # Create scene with apple at (3.5, 2.1)
            >>> scene = ssp.bind_object_location("apple", [3.5, 2.1])
            >>> # Query: what's at (3.5, 2.1)?
            >>> result = ssp.query_location(scene, [3.5, 2.1])
            >>> # result should be similar to memory["apple"]
        """
        # Encode query location
        location_hv = self.encode_location(coordinates)

        # Unbind: Scene ⊗ S(x,y)^(-1) ≈ Object
        inv_location = self.model.opset.inverse(location_hv.vec)
        result_vec = self.model.opset.bind(scene.vec, inv_location)

        return ComplexHypervector(result_vec)

    def query_object(self, scene: ComplexHypervector, object_name: str) -> ComplexHypervector:
        """Query where an object is located in the scene.

        For scene containing Object ⊗ S(x, y), querying for Object returns
        a vector similar to S(x, y).

        Args:
            scene: Scene hypervector.
            object_name: Name of object to locate.

        Returns:
            ComplexHypervector representing the location of the object.

        Raises:
            KeyError: If object_name not in memory.

        Example:
            >>> scene = ssp.bind_object_location("apple", [3.5, 2.1])
            >>> # Query: where is the apple?
            >>> location_hv = ssp.query_object(scene, "apple")
            >>> # location_hv should be similar to ssp.encode_location([3.5, 2.1])
        """
        # Get object hypervector
        object_hv = self.memory[object_name]

        # Unbind: Scene ⊗ Object^(-1) ≈ S(x,y)
        inv_object = self.model.opset.inverse(object_hv.vec)
        result_vec = self.model.opset.bind(scene.vec, inv_object)

        return ComplexHypervector(result_vec)

    def shift_scene(self, scene: ComplexHypervector, offset: list[float]) -> ComplexHypervector:
        """Shift all objects in a scene by a displacement vector.

        For scene S containing objects at various locations, shift by (dx, dy)
        moves all objects by that offset.

        This works because: (Object ⊗ X^x ⊗ Y^y) ⊗ (X^dx ⊗ Y^dy) =
                             Object ⊗ X^(x+dx) ⊗ Y^(y+dy)

        Args:
            scene: Scene hypervector to shift.
            offset: Displacement vector, one value per axis.

        Returns:
            Shifted scene hypervector.

        Raises:
            ValueError: If offset length doesn't match num_axes.

        Example:
            >>> # Apple at (3.5, 2.1)
            >>> scene = ssp.bind_object_location("apple", [3.5, 2.1])
            >>> # Shift scene by (1.0, -0.5)
            >>> shifted = ssp.shift_scene(scene, [1.0, -0.5])
            >>> # Now apple is at (4.5, 1.6)
        """
        if len(offset) != self.config.num_axes:
            raise ValueError(f"Expected {self.config.num_axes} offset values, got {len(offset)}")

        # Encode displacement as a location
        displacement_hv = self.encode_location(offset)

        # Bind displacement to scene
        result_vec = self.model.opset.bind(scene.vec, displacement_hv.vec)

        return ComplexHypervector(result_vec)

    def decode_location(
        self,
        location_hv: ComplexHypervector,
        search_range: list[tuple[float, float]],
        resolution: int = 20,
    ) -> list[float]:
        """Decode approximate coordinates from a location hypervector.

        Uses grid search to find coordinates that best match the encoded location.

        Args:
            location_hv: Encoded location hypervector to decode.
            search_range: List of (min, max) tuples, one per axis.
            resolution: Number of grid points to sample per axis (default: 20).

        Returns:
            List of decoded coordinates (approximate).

        Raises:
            ValueError: If search_range length doesn't match num_axes.

        Example:
            >>> location_hv = ssp.encode_location([3.5, 2.1])
            >>> decoded = ssp.decode_location(
            ...     location_hv,
            ...     search_range=[(0.0, 5.0), (0.0, 5.0)],
            ...     resolution=50
            ... )
            >>> # decoded should be close to [3.5, 2.1]
        """
        if len(search_range) != self.config.num_axes:
            raise ValueError(
                f"Expected {self.config.num_axes} search ranges, got {len(search_range)}"
            )

        # Create grid of candidate coordinates
        grids = [jnp.linspace(min_val, max_val, resolution) for min_val, max_val in search_range]

        # Generate all combinations using meshgrid
        mesh = jnp.meshgrid(*grids, indexing="ij")
        # Flatten and stack to get all candidate points
        candidates = jnp.stack([g.flatten() for g in mesh], axis=1)

        # Encode all candidate locations and compute similarities
        best_similarity = -jnp.inf
        best_coords = None

        for candidate in candidates:
            candidate_hv = self.encode_location(candidate.tolist())
            similarity = cosine_similarity(location_hv.vec, candidate_hv.vec)

            if similarity > best_similarity:
                best_similarity = similarity
                best_coords = candidate

        assert best_coords is not None  # At least one candidate should exist
        result: list[float] = best_coords.tolist()
        return result

Functions

__init__(model, memory, config=None)

Initialize Spatial Semantic Pointers.

Parameters:

  • model (VSAModel): VSAModel instance (must use ComplexHypervector). Required.
  • memory (VSAMemory): VSAMemory for basis vectors and objects. Required.
  • config (Optional[SSPConfig]): SSPConfig; defaults to 2D (512-dim) if not provided. Default: None.

Raises:

  • TypeError: If model doesn't use ComplexHypervector. Note that the __init__ body shown in the class source above performs no explicit type check itself.


encode_location(coordinates)

Encode a spatial location as a hypervector.

For 2D: S(x, y) = X^x ⊗ Y^y
For 3D: S(x, y, z) = X^x ⊗ Y^y ⊗ Z^z
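
Because binding multiplies unit phasors elementwise, the per-axis phases simply add, so an N-axis location can be encoded with one matrix-vector product. The sketch below checks this against repeated binding. It is plain JAX, not the vsax `encode_multi` implementation; `encode_nd`, `encode_sequential`, and the `thetas` array of axis phases are hypothetical names.

```python
import jax
import jax.numpy as jnp

def encode_nd(thetas, coords):
    """S(c) = A_1^c_1 ⊗ ... ⊗ A_n^c_n for FHRR axes with phases thetas[k].

    Binding multiplies unit phasors, so phases add:
    angle(S) = sum_k c_k * thetas[k], i.e. one matrix-vector product.
    """
    coords = jnp.asarray(coords, dtype=thetas.dtype)
    if coords.shape[0] != thetas.shape[0]:
        raise ValueError(f"Expected {thetas.shape[0]} coordinates, got {coords.shape[0]}")
    return jnp.exp(1j * (coords @ thetas))

def encode_sequential(thetas, coords):
    """The same encoding written as repeated binding of per-axis powers A_k^c_k."""
    out = jnp.ones(thetas.shape[1], dtype=jnp.complex64)
    for theta, c in zip(thetas, coords):
        out = out * jnp.exp(1j * c * theta)
    return out

# Three axes (X, Y, Z) in 512 dimensions.
thetas = jax.random.uniform(jax.random.PRNGKey(1), (3, 512), minval=-jnp.pi, maxval=jnp.pi)
s3 = encode_nd(thetas, [1.0, 2.0, 3.0])
print(bool(jnp.allclose(s3, encode_sequential(thetas, [1.0, 2.0, 3.0]), atol=1e-4)))
```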

Parameters:

  • coordinates (list[float]): List of coordinate values, one per axis. Required.

Returns:

  • ComplexHypervector: The encoded spatial location.

Raises:

  • ValueError: If the length of coordinates doesn't match num_axes.

Example

>>> location = ssp.encode_location([3.5, 2.1])  # 2D point
>>> location3d = ssp.encode_location([1.0, 2.0, 3.0])  # 3D point (requires an SSP built with num_axes=3)


bind_object_location(object_name, coordinates)

Bind an object to a spatial location.

Creates: Object ⊗ S(x, y) = Object ⊗ X^x ⊗ Y^y
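
A minimal sketch of this binding in plain JAX, assuming the standard FHRR definitions: binding is elementwise complex multiplication, and the inverse of a unit-magnitude vector is its complex conjugate. Whether `model.opset` applies any additional normalization is not shown on this page; all names below are hypothetical.

```python
import jax
import jax.numpy as jnp

DIM = 512
k_obj, k_x, k_y = jax.random.split(jax.random.PRNGKey(2), 3)

def phases(key):
    """Random phases for a unit-magnitude FHRR vector."""
    return jax.random.uniform(key, (DIM,), minval=-jnp.pi, maxval=jnp.pi)

apple = jnp.exp(1j * phases(k_obj))          # object vector
theta_x, theta_y = phases(k_x), phases(k_y)  # phases of axis vectors X and Y

location = jnp.exp(1j * (3.5 * theta_x + 2.1 * theta_y))  # S(3.5, 2.1)
apple_at_pos = apple * location                            # Object ⊗ S(x, y)

# Every element has magnitude 1, so the inverse is the complex conjugate and
# unbinding a single pair recovers the object exactly (up to float error).
recovered = apple_at_pos * jnp.conj(location)
print(bool(jnp.allclose(recovered, apple, atol=1e-4)))
```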

Parameters:

  • object_name (str): Name of an object already added to memory. Required.
  • coordinates (list[float]): Spatial coordinates for the object. Required.

Returns:

  • ComplexHypervector: The object bound to the location (object-at-location).

Raises:

  • KeyError: If object_name is not in memory.
  • ValueError: If the length of coordinates doesn't match num_axes.

Example

>>> memory.add("apple")
>>> apple_at_pos = ssp.bind_object_location("apple", [3.5, 2.1])


query_location(scene, coordinates)

Query what object is at a given location in the scene.

For scene containing Object ⊗ S(x, y), querying at (x, y) returns a vector similar to Object.
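
The sketch below illustrates why the result is only *similar* to Object: querying exactly at the bound location recovers the object, a slightly offset query still resembles it, and a distant query does not. It is plain JAX under the FHRR assumptions above (elementwise binding, conjugate inverse); `encode`, `query_location`, and `similarity` here are hypothetical stand-ins, not the vsax methods.

```python
import jax
import jax.numpy as jnp

DIM = 512
keys = jax.random.split(jax.random.PRNGKey(3), 3)

def rand_phases(key):
    return jax.random.uniform(key, (DIM,), minval=-jnp.pi, maxval=jnp.pi)

apple = jnp.exp(1j * rand_phases(keys[0]))
theta_x, theta_y = rand_phases(keys[1]), rand_phases(keys[2])

def encode(x, y):
    """S(x, y) = X^x ⊗ Y^y for the FHRR axes above."""
    return jnp.exp(1j * (x * theta_x + y * theta_y))

def similarity(a, b):
    return jnp.real(jnp.vdot(a, b)) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))

def query_location(scene, x, y):
    """Unbind: Scene ⊗ S(x, y)^(-1), using the conjugate as the FHRR inverse."""
    return scene * jnp.conj(encode(x, y))

scene = apple * encode(3.5, 2.1)
queries = [(3.5, 2.1), (3.7, 2.1), (1.0, 4.0)]
results = {q: float(similarity(query_location(scene, *q), apple)) for q in queries}
print(results)
```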

Parameters:

  • scene (ComplexHypervector): Scene hypervector (typically a bundle of object-location pairs). Required.
  • coordinates (list[float]): Location to query. Required.

Returns:

  • ComplexHypervector: Approximation of the object at that location.

Raises:

  • ValueError: If the length of coordinates doesn't match num_axes.

Example

>>> # Create scene with apple at (3.5, 2.1)
>>> scene = ssp.bind_object_location("apple", [3.5, 2.1])
>>> # Query: what's at (3.5, 2.1)?
>>> result = ssp.query_location(scene, [3.5, 2.1])
>>> # result should be similar to memory["apple"]

query_object(scene, object_name)

Query where an object is located in the scene.

For scene containing Object ⊗ S(x, y), querying for Object returns a vector similar to S(x, y).
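
With more than one pair in the scene, unbinding the object yields S(x, y) plus crosstalk from the other pairs. This plain-JAX sketch assumes elementwise FHRR binding, the conjugate as inverse, and bundling by plain summation (vsax's `opset.bundle` may normalize differently, which this sketch does not model); all names are hypothetical.

```python
import jax
import jax.numpy as jnp

DIM = 1024
keys = jax.random.split(jax.random.PRNGKey(4), 4)

def rand_phases(key):
    return jax.random.uniform(key, (DIM,), minval=-jnp.pi, maxval=jnp.pi)

apple = jnp.exp(1j * rand_phases(keys[0]))
banana = jnp.exp(1j * rand_phases(keys[1]))
theta_x, theta_y = rand_phases(keys[2]), rand_phases(keys[3])

def encode(x, y):
    return jnp.exp(1j * (x * theta_x + y * theta_y))

def similarity(a, b):
    return jnp.real(jnp.vdot(a, b)) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))

# Two object-location pairs, bundled by plain summation (an assumption of this sketch).
scene = apple * encode(3.5, 2.1) + banana * encode(1.0, 4.0)

# Unbind the object: Scene ⊗ Apple^(-1) ≈ S(3.5, 2.1) + crosstalk from the banana pair.
apple_location = scene * jnp.conj(apple)

sim_true = float(similarity(apple_location, encode(3.5, 2.1)))
sim_other = float(similarity(apple_location, encode(1.0, 4.0)))
print(sim_true, sim_other)
```

The similarity to the true location is clearly higher than to the other object's location; the gap shrinks as more pairs are bundled into the same scene.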

Parameters:

  • scene (ComplexHypervector): Scene hypervector. Required.
  • object_name (str): Name of the object to locate. Required.

Returns:

  • ComplexHypervector: Approximation of the object's location hypervector S(x, y).

Raises:

  • KeyError: If object_name is not in memory.

Example

>>> scene = ssp.bind_object_location("apple", [3.5, 2.1])
>>> # Query: where is the apple?
>>> location_hv = ssp.query_object(scene, "apple")
>>> # location_hv should be similar to ssp.encode_location([3.5, 2.1])

shift_scene(scene, offset)

Shift all objects in a scene by a displacement vector.

For a scene S containing objects at various locations, shifting by (dx, dy) moves every object by that offset. This works because binding the displacement S(dx, dy) to each object-location pair gives:

(Object ⊗ X^x ⊗ Y^y) ⊗ (X^dx ⊗ Y^dy) = Object ⊗ X^(x+dx) ⊗ Y^(y+dy)
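
A plain-JAX sketch of the shift identity, first for a single pair (exact up to float error) and then for a two-pair scene. It assumes elementwise FHRR binding, the conjugate as inverse, and plain-sum bundling; all names are hypothetical rather than vsax API.

```python
import jax
import jax.numpy as jnp

DIM = 1024
keys = jax.random.split(jax.random.PRNGKey(8), 4)

def rand_phases(key):
    return jax.random.uniform(key, (DIM,), minval=-jnp.pi, maxval=jnp.pi)

apple = jnp.exp(1j * rand_phases(keys[0]))
banana = jnp.exp(1j * rand_phases(keys[1]))
theta_x, theta_y = rand_phases(keys[2]), rand_phases(keys[3])

def encode(x, y):
    return jnp.exp(1j * (x * theta_x + y * theta_y))

def similarity(a, b):
    return jnp.real(jnp.vdot(a, b)) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))

displacement = encode(1.0, -0.5)  # S(dx, dy) is encoded exactly like a location

# Single pair: (Apple ⊗ S(3.5, 2.1)) ⊗ S(1.0, -0.5) == Apple ⊗ S(4.5, 1.6)
single_ok = bool(jnp.allclose((apple * encode(3.5, 2.1)) * displacement,
                              apple * encode(4.5, 1.6), atol=1e-4))

# Whole scene: binding distributes over the sum, so every pair moves.
scene = apple * encode(3.5, 2.1) + banana * encode(1.0, 4.0)
shifted = scene * displacement
where_apple = shifted * jnp.conj(apple)
sim_new = float(similarity(where_apple, encode(4.5, 1.6)))
sim_old = float(similarity(where_apple, encode(3.5, 2.1)))
print(single_ok, sim_new, sim_old)
```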

Parameters:

  • scene (ComplexHypervector): Scene hypervector to shift. Required.
  • offset (list[float]): Displacement vector, one value per axis. Required.

Returns:

  • ComplexHypervector: Shifted scene hypervector.

Raises:

  • ValueError: If the length of offset doesn't match num_axes.

Example

>>> # Apple at (3.5, 2.1)
>>> scene = ssp.bind_object_location("apple", [3.5, 2.1])
>>> # Shift scene by (1.0, -0.5)
>>> shifted = ssp.shift_scene(scene, [1.0, -0.5])
>>> # Now apple is at (4.5, 1.6)

decode_location(location_hv, search_range, resolution=20)

Decode approximate coordinates from a location hypervector.

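Uses a grid search: encodes every point of a regular grid over search_range (resolution points per axis, so resolution**num_axes candidates in total) and returns the point whose encoding has the highest cosine similarity to location_hv.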
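
The source loops over candidates in Python and calls encode_location once per grid point. As an alternative, the sketch below scores all candidates in one batched step with a single matrix product. This vectorization is swapped in for illustration and is not the library's implementation; `decode_location_batched` and its `thetas` argument (the axis phases) are hypothetical.

```python
import jax
import jax.numpy as jnp

def decode_location_batched(thetas, location_hv, search_range, resolution=20):
    """Grid-search decoder that scores every candidate at once.

    thetas: (num_axes, dim) phases of the FHRR axis vectors (hypothetical input;
    the library keeps its axis vectors in VSAMemory instead).
    """
    if len(search_range) != thetas.shape[0]:
        raise ValueError(f"Expected {thetas.shape[0]} search ranges, got {len(search_range)}")
    grids = [jnp.linspace(lo, hi, resolution) for lo, hi in search_range]
    mesh = jnp.meshgrid(*grids, indexing="ij")
    candidates = jnp.stack([g.ravel() for g in mesh], axis=1)  # (resolution**num_axes, num_axes)
    candidate_hvs = jnp.exp(1j * (candidates @ thetas))         # one encoded row per candidate
    # All candidates have the same norm, so the real inner product ranks them
    # exactly as cosine similarity would.
    scores = jnp.real(jnp.conj(candidate_hvs) @ location_hv)
    return candidates[jnp.argmax(scores)].tolist()

thetas = jax.random.uniform(jax.random.PRNGKey(7), (2, 512), minval=-jnp.pi, maxval=jnp.pi)
location_hv = jnp.exp(1j * (jnp.array([3.5, 2.1]) @ thetas))
decoded = decode_location_batched(thetas, location_hv, [(0.0, 5.0), (0.0, 5.0)], resolution=51)
print(decoded)
```

The trade-off is memory: the candidate matrix holds resolution**num_axes × dim complex values, so fine grids or three or more axes may need to be scored in chunks.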

Parameters:

  • location_hv (ComplexHypervector): Encoded location hypervector to decode. Required.
  • search_range (list[tuple[float, float]]): List of (min, max) tuples, one per axis. Required.
  • resolution (int): Number of grid points to sample per axis. Default: 20.

Returns:

  • list[float]: Decoded coordinates (approximate: the best-matching grid point).

Raises:

  • ValueError: If the length of search_range doesn't match num_axes.

Example

>>> location_hv = ssp.encode_location([3.5, 2.1])
>>> decoded = ssp.decode_location(
...     location_hv,
...     search_range=[(0.0, 5.0), (0.0, 5.0)],
...     resolution=50
... )
>>> # decoded should be close to [3.5, 2.1]

SSPConfig

vsax.spatial.SSPConfig dataclass

Configuration for Spatial Semantic Pointers.

Attributes:

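  • dim (int): Dimensionality of hypervectors (e.g., 512, 1024). Default: 512.
  • num_axes (int): Number of spatial dimensions (1D, 2D, 3D, etc.). Default: 2.
  • scale (Optional[float]): Optional scaling factor for spatial coordinates, passed to FractionalPowerEncoder as scale. Default: None.
  • axis_names (Optional[list[str]]): Optional custom names for the axes. If omitted, the first num_axes names of ["x", "y", "z", "w", "v", "u"] are used, or ["axis_0", "axis_1", ...] when num_axes exceeds 6. If given, its length must equal num_axes; otherwise __post_init__ raises ValueError.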

Source code in vsax/spatial/ssp.py
@dataclass
class SSPConfig:
    """Configuration for Spatial Semantic Pointers.

    Attributes:
        dim: Dimensionality of hypervectors (e.g., 512, 1024).
        num_axes: Number of spatial dimensions (1D, 2D, 3D, etc.).
        scale: Optional scaling factor for spatial coordinates.
        axis_names: Optional custom names for axes (defaults to ["x", "y", "z", ...]).
    """

    dim: int = 512
    num_axes: int = 2
    scale: Optional[float] = None
    axis_names: Optional[list[str]] = None

    def __post_init__(self) -> None:
        """Set default axis names if not provided."""
        if self.axis_names is None:
            # Default axis names: x, y, z, w, v, u, ...
            default_names = ["x", "y", "z", "w", "v", "u"]
            if self.num_axes <= len(default_names):
                self.axis_names = default_names[: self.num_axes]
            else:
                # For higher dimensions, use axis_0, axis_1, ...
                self.axis_names = [f"axis_{i}" for i in range(self.num_axes)]
        elif len(self.axis_names) != self.num_axes:
            raise ValueError(
                f"axis_names length ({len(self.axis_names)}) must match num_axes ({self.num_axes})"
            )

Functions

__post_init__()

Set default axis names if not provided.


Utilities

create_spatial_scene

vsax.spatial.utils.create_spatial_scene(ssp, objects_and_locations)

Create a scene by bundling multiple object-location bindings.

Convenience function that:

1. Binds each object to its location: Object_i ⊗ S(x_i, y_i)
2. Bundles all bindings into a single scene hypervector
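
The two steps can be sketched in plain JAX, together with a simple cleanup step that maps each query result back to the closest known object. This assumes elementwise FHRR binding, the conjugate as inverse, and bundling by plain summation (vsax's `opset.bundle` may normalize the result differently); `create_scene`, `cleanup`, and `vocab` are hypothetical names.

```python
import jax
import jax.numpy as jnp

DIM = 1024
names = ["apple", "banana", "cherry"]
keys = jax.random.split(jax.random.PRNGKey(9), len(names) + 2)

def rand_phases(key):
    return jax.random.uniform(key, (DIM,), minval=-jnp.pi, maxval=jnp.pi)

vocab = {name: jnp.exp(1j * rand_phases(k)) for name, k in zip(names, keys[: len(names)])}
theta_x, theta_y = rand_phases(keys[-2]), rand_phases(keys[-1])

def encode(x, y):
    return jnp.exp(1j * (x * theta_x + y * theta_y))

def create_scene(objects_and_locations):
    """Bind each object to its location, then bundle (here: plain sum, an assumption)."""
    if not objects_and_locations:
        raise ValueError("objects_and_locations cannot be empty")
    bindings = [vocab[name] * encode(*coords) for name, coords in objects_and_locations.items()]
    return jnp.sum(jnp.stack(bindings), axis=0)

def cleanup(hv):
    """Return the vocabulary item with the highest real inner product (all items have equal norm)."""
    scores = {name: float(jnp.real(jnp.vdot(vec, hv))) for name, vec in vocab.items()}
    return max(scores, key=scores.get)

layout = {"apple": [1.0, 2.0], "banana": [3.0, 4.0], "cherry": [5.0, 1.0]}
scene = create_scene(layout)
answers = {name: cleanup(scene * jnp.conj(encode(x, y))) for name, (x, y) in layout.items()}
print(answers)
```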

Parameters:

  • ssp (SpatialSemanticPointers): SpatialSemanticPointers instance. Required.
  • objects_and_locations (dict[str, list[float]]): Dictionary mapping object names to coordinates, e.g. {"apple": [1.0, 2.0], "banana": [3.0, 4.0]}. Required.

Returns:

  • ComplexHypervector: Scene hypervector containing all object-location pairs.

Raises:

  • ValueError: If objects_and_locations is empty.
  • KeyError: If any object name is not in ssp.memory.

Example

>>> scene = create_spatial_scene(ssp, {
...     "apple": [1.0, 2.0],
...     "banana": [3.0, 4.0],
...     "cherry": [5.0, 1.0]
... })
>>> # Query: what's at (1.0, 2.0)?
>>> result = ssp.query_location(scene, [1.0, 2.0])

Source code in vsax/spatial/utils.py
def create_spatial_scene(
    ssp: SpatialSemanticPointers,
    objects_and_locations: dict[str, list[float]],
) -> ComplexHypervector:
    """Create a scene by bundling multiple object-location bindings.

    Convenience function that:
    1. Binds each object to its location: Object_i ⊗ S(x_i, y_i)
    2. Bundles all bindings into a single scene hypervector

    Args:
        ssp: SpatialSemanticPointers instance.
        objects_and_locations: Dictionary mapping object names to coordinates.
            Example: {"apple": [1.0, 2.0], "banana": [3.0, 4.0]}

    Returns:
        Scene hypervector containing all object-location pairs.

    Raises:
        ValueError: If objects_and_locations is empty.
        KeyError: If any object name is not in ssp.memory.

    Example:
        >>> scene = create_spatial_scene(ssp, {
        ...     "apple": [1.0, 2.0],
        ...     "banana": [3.0, 4.0],
        ...     "cherry": [5.0, 1.0]
        ... })
        >>> # Query: what's at (1.0, 2.0)?
        >>> result = ssp.query_location(scene, [1.0, 2.0])
    """
    if not objects_and_locations:
        raise ValueError("objects_and_locations cannot be empty")

    # Bind each object to its location
    bindings = []
    for obj_name, coords in objects_and_locations.items():
        binding = ssp.bind_object_location(obj_name, coords)
        bindings.append(binding.vec)

    # Bundle all bindings
    scene_vec = ssp.model.opset.bundle(*bindings)
    return ComplexHypervector(scene_vec)

similarity_map_2d

vsax.spatial.utils.similarity_map_2d(ssp, query_hv, x_range, y_range, resolution=50)

Generate 2D similarity heatmap for visualization.

Computes similarity between query_hv and encoded locations across a 2D grid. Useful for visualizing "where" an object is located, or what regions match a given hypervector.
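
A plain-JAX sketch of such a heatmap, encoding every grid cell at once by broadcasting and then locating the peak. It uses jnp.meshgrid's default "xy" indexing; whether vsax uses "xy" or "ij" is not visible in the truncated source below. `similarity_map_2d_sketch` and the FHRR setup are hypothetical, under the same assumptions as the earlier sketches.

```python
import jax
import jax.numpy as jnp

DIM = 512
k_obj, k_x, k_y = jax.random.split(jax.random.PRNGKey(6), 3)
apple = jnp.exp(1j * jax.random.uniform(k_obj, (DIM,), minval=-jnp.pi, maxval=jnp.pi))
theta_x = jax.random.uniform(k_x, (DIM,), minval=-jnp.pi, maxval=jnp.pi)
theta_y = jax.random.uniform(k_y, (DIM,), minval=-jnp.pi, maxval=jnp.pi)

def similarity_map_2d_sketch(query_hv, x_range, y_range, resolution=50):
    xs = jnp.linspace(x_range[0], x_range[1], resolution)
    ys = jnp.linspace(y_range[0], y_range[1], resolution)
    X, Y = jnp.meshgrid(xs, ys)  # "xy" indexing: X[i, j] = xs[j], Y[i, j] = ys[i]
    # S(X[i, j], Y[i, j]) for every grid cell at once: shape (resolution, resolution, DIM)
    grid_hvs = jnp.exp(1j * (X[..., None] * theta_x + Y[..., None] * theta_y))
    dots = jnp.real(jnp.einsum("ijd,d->ij", jnp.conj(grid_hvs), query_hv))
    sims = dots / (jnp.sqrt(DIM) * jnp.linalg.norm(query_hv))  # grid vectors have norm sqrt(DIM)
    return X, Y, sims

scene = apple * jnp.exp(1j * (3.5 * theta_x + 2.1 * theta_y))
apple_location = scene * jnp.conj(apple)  # unbind the object to get its location
X, Y, sims = similarity_map_2d_sketch(apple_location, (0.0, 5.0), (0.0, 5.0))
i, j = jnp.unravel_index(jnp.argmax(sims), sims.shape)
peak = (float(X[i, j]), float(Y[i, j]))
print(peak)
```

The arrays X, Y, and sims can be passed directly to a contour or pcolormesh plot.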

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| ssp | SpatialSemanticPointers | SpatialSemanticPointers instance (must be 2D). | required |
| query_hv | ComplexHypervector | Query hypervector to compare against locations. | required |
| x_range | tuple[float, float] | (min_x, max_x) range for the grid. | required |
| y_range | tuple[float, float] | (min_y, max_y) range for the grid. | required |
| resolution | int | Number of grid points per axis. | 50 |

Returns:

| Type | Description |
|---|---|
| tuple[ndarray, ndarray, ndarray] | Tuple of (X, Y, similarities): X is the 2D meshgrid of x coordinates, Y is the 2D meshgrid of y coordinates, and similarities is the 2D array of similarity values. Each has shape (resolution, resolution). |
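The grid is built with `meshgrid`'s default `"xy"` indexing, so row `i` follows the y axis and column `j` follows the x axis. To read a peak coordinate from the returned arrays, index `X` and `Y` with the same `(i, j)` as the argmax. The minimal NumPy sketch below uses a synthetic bump as a stand-in for `sims`, not real SSP output. `jnp.meshgrid` uses the same default indexing.

```python
import numpy as np

resolution = 50
x = np.linspace(0.0, 5.0, resolution)
y = np.linspace(0.0, 5.0, resolution)
X, Y = np.meshgrid(x, y)  # default indexing="xy": X[i, j] == x[j], Y[i, j] == y[i]

# Synthetic stand-in for the similarity array: a bump centred at (3.5, 2.1)
sims = np.exp(-((X - 3.5) ** 2 + (Y - 2.1) ** 2))

# argmax over the flattened array, then back to (row, col)
i, j = np.unravel_index(np.argmax(sims), sims.shape)
peak_x, peak_y = X[i, j], Y[i, j]
print(round(float(peak_x), 2), round(float(peak_y), 2))
```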

Raises:

| Type | Description |
|---|---|
| ValueError | If ssp is not 2D (num_axes != 2). |

Example

>>> # Create scene with apple at (3.5, 2.1)
>>> scene = ssp.bind_object_location("apple", [3.5, 2.1])
>>> # Query where apple is
>>> apple_location = ssp.query_object(scene, "apple")
>>> # Generate heatmap
>>> X, Y, sims = similarity_map_2d(
...     ssp, apple_location,
...     x_range=(0, 5), y_range=(0, 5), resolution=50
... )
>>> # Peak should be near (3.5, 2.1)

Source code in vsax/spatial/utils.py
def similarity_map_2d(
    ssp: SpatialSemanticPointers,
    query_hv: ComplexHypervector,
    x_range: tuple[float, float],
    y_range: tuple[float, float],
    resolution: int = 50,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Generate 2D similarity heatmap for visualization.

    Computes similarity between query_hv and encoded locations across a 2D grid.
    Useful for visualizing "where" an object is located, or what regions match
    a given hypervector.

    Args:
        ssp: SpatialSemanticPointers instance (must be 2D).
        query_hv: Query hypervector to compare against locations.
        x_range: (min_x, max_x) range for grid.
        y_range: (min_y, max_y) range for grid.
        resolution: Number of grid points per axis (default: 50).

    Returns:
        Tuple of (X, Y, similarities):
            - X: 2D meshgrid of x coordinates (resolution x resolution)
            - Y: 2D meshgrid of y coordinates (resolution x resolution)
            - similarities: 2D array of similarity values (resolution x resolution)

    Raises:
        ValueError: If ssp is not 2D (num_axes != 2).

    Example:
        >>> # Create scene with apple at (3.5, 2.1)
        >>> scene = ssp.bind_object_location("apple", [3.5, 2.1])
        >>> # Query where apple is
        >>> apple_location = ssp.query_object(scene, "apple")
        >>> # Generate heatmap
        >>> X, Y, sims = similarity_map_2d(
        ...     ssp, apple_location,
        ...     x_range=(0, 5), y_range=(0, 5), resolution=50
        ... )
        >>> # Peak should be near (3.5, 2.1)
    """
    if ssp.config.num_axes != 2:
        raise ValueError(
            f"similarity_map_2d only works with 2D SSP. Got num_axes={ssp.config.num_axes}"
        )

    # Create grid
    x = jnp.linspace(x_range[0], x_range[1], resolution)
    y = jnp.linspace(y_range[0], y_range[1], resolution)
    X, Y = jnp.meshgrid(x, y)

    # Compute similarities
    similarities = jnp.zeros((resolution, resolution))

    for i in range(resolution):
        for j in range(resolution):
            loc = ssp.encode_location([X[i, j].item(), Y[i, j].item()])
            sim = cosine_similarity(query_hv.vec, loc.vec)
            similarities = similarities.at[i, j].set(sim)

    return X, Y, similarities
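The double loop above calls `encode_location` `resolution**2` times and writes one cell at a time. Under FHRR, every grid encoding is determined by the axis phases, so the whole map can be computed with a single broadcast. The following is a standalone NumPy sketch, not vsax code. It assumes `S(x, y) = exp(i(x·θx + y·θy))` and real-part cosine similarity, and the parameter names `theta_x` and `theta_y` are hypothetical.

```python
import numpy as np

def similarity_map_2d_vectorized(query, theta_x, theta_y, x_range, y_range, resolution=50):
    """Real-part cosine similarity between `query` and S(x, y) on a square grid.

    Assumes S(x, y) = exp(i * (x * theta_x + y * theta_y)), i.e. FHRR axis
    vectors with phases theta_x and theta_y (hypothetical names, not vsax API).
    """
    x = np.linspace(x_range[0], x_range[1], resolution)
    y = np.linspace(y_range[0], y_range[1], resolution)
    X, Y = np.meshgrid(x, y)  # same "xy" indexing as the loop version
    # Phases of all grid encodings at once: shape (resolution, resolution, dim)
    phases = X[..., None] * theta_x + Y[..., None] * theta_y
    # Re(<query, S>) for every grid point, divided by |query| * |S| (= sqrt(dim))
    dots = np.real(np.exp(-1j * phases) @ query)
    sims = dots / (np.linalg.norm(query) * np.sqrt(theta_x.size))
    return X, Y, sims

# Usage: the map for a pure location vector should peak at that location
rng = np.random.default_rng(0)
dim = 512
theta_x = rng.uniform(-np.pi, np.pi, dim)
theta_y = rng.uniform(-np.pi, np.pi, dim)
query = np.exp(1j * (3.5 * theta_x + 2.1 * theta_y))  # S(3.5, 2.1)
X, Y, sims = similarity_map_2d_vectorized(query, theta_x, theta_y, (0, 5), (0, 5))
i, j = np.unravel_index(np.argmax(sims), sims.shape)
print(X[i, j], Y[i, j])  # within one grid step (5/49) of (3.5, 2.1)
```

The trade-off is memory: the intermediate phase array holds `resolution**2 * dim` values, so very fine grids may need to be processed in row chunks.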

plot_ssp_2d_scene

vsax.spatial.utils.plot_ssp_2d_scene(ssp, scene, object_names, x_range=(0, 5), y_range=(0, 5), resolution=30, figsize=(12, 4))

Plot 2D scene showing where each object is located.

Creates a figure with one subplot per object. Each subplot is a similarity heatmap of the object's decoded location, obtained with ssp.query_object and similarity_map_2d. Requires matplotlib.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| ssp | SpatialSemanticPointers | SpatialSemanticPointers instance (must be 2D). | required |
| scene | ComplexHypervector | Scene hypervector containing object-location bindings. | required |
| object_names | list[str] | List of object names to visualize. | required |
| x_range | tuple[float, float] | (min_x, max_x) for the plot. | (0, 5) |
| y_range | tuple[float, float] | (min_y, max_y) for the plot. | (0, 5) |
| resolution | int | Number of grid points per axis, passed to similarity_map_2d. | 30 |
| figsize | tuple[int, int] | Figure size (width, height) in inches. | (12, 4) |

Returns:

| Type | Description |
|---|---|
| Any | Matplotlib Figure object. |

Raises:

| Type | Description |
|---|---|
| ImportError | If matplotlib is not installed. |
| ValueError | If ssp is not 2D. |
| KeyError | If any object name is not in ssp.memory. |

Example

>>> import matplotlib.pyplot as plt
>>> scene = create_spatial_scene(ssp, {
...     "apple": [1.0, 2.0],
...     "banana": [3.0, 4.0]
... })
>>> fig = plot_ssp_2d_scene(ssp, scene, ["apple", "banana"])
>>> plt.show()

Source code in vsax/spatial/utils.py
def plot_ssp_2d_scene(
    ssp: SpatialSemanticPointers,
    scene: ComplexHypervector,
    object_names: list[str],
    x_range: tuple[float, float] = (0, 5),
    y_range: tuple[float, float] = (0, 5),
    resolution: int = 30,
    figsize: tuple[int, int] = (12, 4),
) -> Any:  # matplotlib.figure.Figure when matplotlib is installed
    """Plot 2D scene showing where each object is located.

    Creates a figure with subplots showing similarity heatmaps for each object.
    Requires matplotlib.

    Args:
        ssp: SpatialSemanticPointers instance (must be 2D).
        scene: Scene hypervector containing object-location bindings.
        object_names: List of object names to visualize.
        x_range: (min_x, max_x) for plot (default: (0, 5)).
        y_range: (min_y, max_y) for plot (default: (0, 5)).
        resolution: Grid resolution (default: 30).
        figsize: Figure size (default: (12, 4)).

    Returns:
        Matplotlib Figure object.

    Raises:
        ImportError: If matplotlib is not installed.
        ValueError: If ssp is not 2D.
        KeyError: If any object name is not in ssp.memory.

    Example:
        >>> import matplotlib.pyplot as plt
        >>> scene = create_spatial_scene(ssp, {
        ...     "apple": [1.0, 2.0],
        ...     "banana": [3.0, 4.0]
        ... })
        >>> fig = plot_ssp_2d_scene(ssp, scene, ["apple", "banana"])
        >>> plt.show()
    """
    try:
        import matplotlib.pyplot as plt  # type: ignore[import-not-found]
    except ImportError:
        raise ImportError(
            "matplotlib is required for plotting. Install with: pip install matplotlib"
        )

    if ssp.config.num_axes != 2:
        raise ValueError(
            f"plot_ssp_2d_scene only works with 2D SSP. Got num_axes={ssp.config.num_axes}"
        )

    n_objects = len(object_names)
    fig, axes = plt.subplots(1, n_objects, figsize=figsize)

    # Handle single object case
    if n_objects == 1:
        axes = [axes]

    for idx, obj_name in enumerate(object_names):
        # Query where this object is
        location_hv = ssp.query_object(scene, obj_name)

        # Generate similarity map
        X, Y, sims = similarity_map_2d(ssp, location_hv, x_range, y_range, resolution)

        # Plot heatmap
        ax = axes[idx]
        im = ax.contourf(X, Y, sims, levels=20, cmap="viridis")
        ax.set_xlabel("X")
        ax.set_ylabel("Y")
        ax.set_title(f"Location of '{obj_name}'")
        ax.set_aspect("equal")
        plt.colorbar(im, ax=ax, label="Similarity")

    plt.tight_layout()
    return fig
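Each subplot shows the grid similarity of one unbound object, so the same information can be extracted numerically, without matplotlib, by taking the argmax of each map. The standalone NumPy sketch below mirrors the loop above under the same assumptions as the other sketches on this page: unit-phasor FHRR vectors, phase-scaling fractional powers, conjugate unbinding, and real-part cosine similarity. `theta_x`, `theta_y`, `encode` and `decoded` are hypothetical names, not vsax API.

```python
import numpy as np

rng = np.random.default_rng(1)
dim = 1024
theta_x = rng.uniform(-np.pi, np.pi, dim)  # phases of a hypothetical X axis vector
theta_y = rng.uniform(-np.pi, np.pi, dim)  # phases of a hypothetical Y axis vector
objects = {name: np.exp(1j * rng.uniform(-np.pi, np.pi, dim)) for name in ["apple", "banana"]}
layout = {"apple": (1.0, 2.0), "banana": (3.0, 4.0)}

def encode(x, y):
    """S(x, y) under FHRR fractional power encoding."""
    return np.exp(1j * (x * theta_x + y * theta_y))

# Scene: bundle (sum) of Object * S(x, y) bindings
scene = sum(objects[name] * encode(*xy) for name, xy in layout.items())

# Same grid as the plotting defaults: x_range = y_range = (0, 5), resolution = 30
resolution = 30
grid = np.linspace(0.0, 5.0, resolution)
X, Y = np.meshgrid(grid, grid)
grid_codes = np.exp(1j * (X[..., None] * theta_x + Y[..., None] * theta_y))  # (30, 30, dim)

decoded = {}
for name in layout:
    location_hv = scene * np.conj(objects[name])  # unbind the object: noisy S(x, y)
    # Real-part cosine similarity against every grid encoding (each has norm sqrt(dim))
    sims = np.real(np.conj(grid_codes) @ location_hv) / (np.linalg.norm(location_hv) * np.sqrt(dim))
    i, j = np.unravel_index(np.argmax(sims), sims.shape)
    decoded[name] = (float(X[i, j]), float(Y[i, j]))

print(decoded)  # each entry should lie near its true location in `layout`
```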

region_query

vsax.spatial.utils.region_query(ssp, scene, object_names, center, radius, resolution=20)

Find which objects are within a spatial region.

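Samples points in a circular region (2D) or spherical region (3D) around a center point; higher dimensions use random samples around the center. For each candidate object, the function reports the highest similarity between the object's decoded location and any sampled point.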

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| ssp | SpatialSemanticPointers | SpatialSemanticPointers instance. | required |
| scene | ComplexHypervector | Scene hypervector. | required |
| object_names | list[str] | List of candidate object names to check. | required |
| center | list[float] | Center coordinates of the search region. | required |
| radius | float | Radius of the search region. | required |
| resolution | int | Sampling density: the number of angles per ring in 2D or per angular coordinate in 3D (each with max(resolution // 4, 1) radii), or resolution * num_axes random points in higher dimensions. | 20 |

Returns:

| Type | Description |
|---|---|
| dict[str, float] | Dictionary mapping object names to maximum similarity scores. Higher scores indicate the object is likely in the region. |

Raises:

| Type | Description |
|---|---|
| ValueError | If center length doesn't match num_axes. |
| KeyError | If any object name is not in ssp.memory. |

Example

>>> scene = create_spatial_scene(ssp, {
...     "apple": [1.0, 1.0],
...     "banana": [3.0, 3.0],
...     "cherry": [5.0, 5.0]
... })
>>> # Search region around (3.0, 3.0) with radius 0.5
>>> results = region_query(
...     ssp, scene, ["apple", "banana", "cherry"],
...     center=[3.0, 3.0], radius=0.5
... )
>>> # results["banana"] should be highest
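The 2D branch of `region_query` samples a disc in polar coordinates. It then scores each object by the best similarity between the object's unbound location and any sample. The standalone NumPy sketch below reproduces that logic for the example scene above. It assumes unit-phasor FHRR vectors, phase-scaling fractional powers, conjugate unbinding, and real-part cosine similarity. `region_query_sketch`, `encode` and `cosine` are hypothetical names, not vsax API.

```python
import numpy as np

rng = np.random.default_rng(2)
dim = 1024
theta_x = rng.uniform(-np.pi, np.pi, dim)  # phases of a hypothetical X axis vector
theta_y = rng.uniform(-np.pi, np.pi, dim)  # phases of a hypothetical Y axis vector

def encode(point):
    """S(x, y) under FHRR fractional power encoding."""
    x, y = point
    return np.exp(1j * (x * theta_x + y * theta_y))

def cosine(a, b):
    """Cosine similarity of complex vectors, using the real part of <a, b>."""
    return float(np.real(np.vdot(b, a)) / (np.linalg.norm(a) * np.linalg.norm(b)))

objects = {n: np.exp(1j * rng.uniform(-np.pi, np.pi, dim)) for n in ["apple", "banana", "cherry"]}
layout = {"apple": (1.0, 1.0), "banana": (3.0, 3.0), "cherry": (5.0, 5.0)}
scene = sum(objects[n] * encode(p) for n, p in layout.items())

def region_query_sketch(scene, objects, names, center, radius, resolution=20):
    """Max similarity between each decoded object location and points in a disc."""
    # Same polar sampling pattern as the 2D branch of region_query
    angles = np.linspace(0.0, 2 * np.pi, resolution)
    radii = np.linspace(0.0, radius, max(resolution // 4, 1))
    samples = [(center[0] + r * np.cos(a), center[1] + r * np.sin(a))
               for r in radii for a in angles]
    results = {}
    for name in names:
        location_hv = scene * np.conj(objects[name])  # unbind: noisy S(x_name, y_name)
        results[name] = max(cosine(location_hv, encode(p)) for p in samples)
    return results

results = region_query_sketch(scene, objects, ["apple", "banana", "cherry"],
                              center=(3.0, 3.0), radius=0.5)
print(max(results, key=results.get))  # banana
```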

Source code in vsax/spatial/utils.py
def region_query(
    ssp: SpatialSemanticPointers,
    scene: ComplexHypervector,
    object_names: list[str],
    center: list[float],
    radius: float,
    resolution: int = 20,
) -> dict[str, float]:
    """Find which objects are within a spatial region.

    Searches a circular region (2D) or spherical region (3D) around a center point
    and returns objects with high similarity to locations in that region.

    Args:
        ssp: SpatialSemanticPointers instance.
        scene: Scene hypervector.
        object_names: List of candidate object names to check.
        center: Center coordinates of the search region.
        radius: Radius of the search region.
        resolution: Sampling density (default: 20): angles per ring in 2D or per
            angular coordinate in 3D, or resolution * num_axes random points in
            higher dimensions.

    Returns:
        Dictionary mapping object names to maximum similarity scores.
        Higher scores indicate the object is likely in the region.

    Raises:
        ValueError: If center length doesn't match num_axes.
        KeyError: If any object name is not in ssp.memory.

    Example:
        >>> scene = create_spatial_scene(ssp, {
        ...     "apple": [1.0, 1.0],
        ...     "banana": [3.0, 3.0],
        ...     "cherry": [5.0, 5.0]
        ... })
        >>> # Search region around (3.0, 3.0) with radius 0.5
        >>> results = region_query(
        ...     ssp, scene, ["apple", "banana", "cherry"],
        ...     center=[3.0, 3.0], radius=0.5
        ... )
        >>> # results["banana"] should be highest
    """
    if len(center) != ssp.config.num_axes:
        raise ValueError(f"center must have {ssp.config.num_axes} coordinates, got {len(center)}")

    # Sample points in the region
    if ssp.config.num_axes == 2:
        # 2D: sample points in a circle
        angles = jnp.linspace(0, 2 * jnp.pi, resolution)
        radii = jnp.linspace(0, radius, max(resolution // 4, 1))

        sample_points = []
        for r in radii:
            for angle in angles:
                x = center[0] + r * jnp.cos(angle)
                y = center[1] + r * jnp.sin(angle)
                sample_points.append([x.item(), y.item()])
    elif ssp.config.num_axes == 3:
        # 3D: sample points in a sphere (simplified uniform sampling)
        phi = jnp.linspace(0, jnp.pi, resolution)
        theta = jnp.linspace(0, 2 * jnp.pi, resolution)
        radii = jnp.linspace(0, radius, max(resolution // 4, 1))

        sample_points = []
        for r in radii:
            for p in phi:
                for t in theta:
                    x = center[0] + r * jnp.sin(p) * jnp.cos(t)
                    y = center[1] + r * jnp.sin(p) * jnp.sin(t)
                    z = center[2] + r * jnp.cos(p)
                    sample_points.append([x.item(), y.item(), z.item()])
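    else:
        # Higher dimensions: random offsets drawn uniformly from the cube
        # [-radius, radius]^num_axes (jnp has no `random`; use jax.random).
        import jax  # harmless if the module already imports jax

        key = jax.random.PRNGKey(0)  # fixed seed for reproducible sampling
        n_samples = resolution * ssp.config.num_axes
        offsets = jax.random.uniform(
            key, (n_samples, ssp.config.num_axes), minval=-radius, maxval=radius
        )
        sample_points = [
            [center[i] + offsets[k, i].item() for i in range(ssp.config.num_axes)]
            for k in range(n_samples)
        ]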

    # For each object, compute max similarity to sampled locations
    results = {}
    for obj_name in object_names:
        location_hv = ssp.query_object(scene, obj_name)
        max_sim = -jnp.inf

        for point in sample_points:
            point_hv = ssp.encode_location(point)
            sim = cosine_similarity(location_hv.vec, point_hv.vec)
            max_sim = max(max_sim, sim)

        results[obj_name] = float(max_sim)

    return results
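For more than three axes, the code above draws offsets from the cube [-radius, radius]^n. Some samples therefore land outside the requested radius, and that fraction grows with n. If strictly in-ball samples are preferred, one standard technique is to normalize isotropic Gaussian vectors to get uniform directions, then scale by radius · U^(1/n) so the density is uniform over the ball's volume. The following is a standalone NumPy sketch of that technique, not vsax code. `sample_in_ball` is a hypothetical helper, and its rows could be passed to `ssp.encode_location` via `.tolist()`.

```python
import numpy as np

def sample_in_ball(center, radius, num_samples, rng):
    """Uniform samples inside an n-dimensional ball around `center`."""
    center = np.asarray(center, dtype=float)
    n = center.size
    # Normalized isotropic Gaussian vectors are uniformly distributed directions
    directions = rng.standard_normal((num_samples, n))
    directions /= np.linalg.norm(directions, axis=1, keepdims=True)
    # Radius ~ radius * U^(1/n) makes the density uniform over the ball's volume
    radii = radius * rng.uniform(0.0, 1.0, num_samples) ** (1.0 / n)
    return center + directions * radii[:, None]

rng = np.random.default_rng(0)
points = sample_in_ball([0.0, 1.0, 2.0, 3.0], radius=0.5, num_samples=80, rng=rng)
print(points.shape)  # (80, 4)
```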