
Vector Function Architecture

Function approximation in Reproducing Kernel Hilbert Space (RKHS).

NEW in v1.2.0 - Based on Frady et al. (2021).

VectorFunctionEncoder

vsax.vfa.VectorFunctionEncoder

Vector Function Architecture for encoding functions in RKHS.

Represents functions f(x) in a Reproducing Kernel Hilbert Space (RKHS) using hypervectors: f(x) ≈ Σ_i α_i * z_i^x

where z_i is the i-th component of the basis vector z (a unit-magnitude complex phasor), z_i^x is its fractional power, and α_i are coefficients learned from samples.
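The following is a minimal, self-contained sketch in plain JAX (independent of vsax) of the fractional power encoding behind this representation. It follows sample_kernel_basis for UNIFORM by drawing phases uniformly from [0, 2π), and computes z^x with jnp.power as this module does. Assuming jnp.power follows the usual principal-branch convention for complex powers (as NumPy does), the effective phases lie in (-π, π], so the normalized similarity Re(Σ_i z_i^x * conj(z_i^y)) / D approximates sinc(x - y) = sin(π(x - y)) / (π(x - y)). The helper names are illustrative, not vsax API.

```python
import jax
import jax.numpy as jnp

def sample_uniform_basis(key, dim):
    """Unit-magnitude phasors with phases uniform in [0, 2*pi) (UNIFORM kernel)."""
    phases = jax.random.uniform(key, (dim,), minval=0.0, maxval=2 * jnp.pi)
    return jnp.exp(1j * phases)

def encode_point(z, x):
    """Fractional power encoding z**x, computed element-wise with jnp.power."""
    return jnp.power(z, x)

def similarity(z, x, y):
    """Normalized similarity Re(sum_i z_i**x * conj(z_i**y)) / D."""
    # jnp.vdot conjugates its first argument.
    return jnp.real(jnp.vdot(encode_point(z, y), encode_point(z, x))) / z.shape[0]

# Usage: the similarity depends only on x - y and tracks the sinc kernel.
z = sample_uniform_basis(jax.random.PRNGKey(0), 10_000)
for d in (0.0, 0.25, 0.5, 1.5):
    print(d, float(similarity(z, d, 0.0)), float(jnp.sinc(d)))
```

This shift invariance (similarity depends only on x - y) is what lets Σ_i α_i z_i^x behave like a kernel expansion.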

This enables
  • Function approximation from samples
  • Point evaluation: f(x_query)
  • Function arithmetic: αf + βg
  • Function shifting: f(x - shift)
  • Function convolution: f * g

Based on Frady et al. (2021), which shows that VFA can represent and manipulate functions using vector symbolic operations.

Attributes:

model: VSAModel instance (must use ComplexHypervector).
memory: VSAMemory for storing basis vectors.
kernel_config: KernelConfig for basis vector sampling.
basis_vector: The function basis vector z (sampled once).

Example

import jax
import jax.numpy as jnp
from vsax import create_fhrr_model, VSAMemory
from vsax.vfa import VectorFunctionEncoder, KernelConfig

# Create VFA encoder
model = create_fhrr_model(dim=512, key=jax.random.PRNGKey(0))
memory = VSAMemory(model)
vfa = VectorFunctionEncoder(model, memory)

# Sample a simple function
x = jnp.linspace(0, 2*jnp.pi, 20)
y = jnp.sin(x)

# Encode the function
f_hv = vfa.encode_function_1d(x, y)

# Evaluate at a query point
y_pred = vfa.evaluate_1d(f_hv, 1.5)

See Also
  • Frady et al. 2021: "Computing on Functions Using Randomized Vector Representations"
Source code in vsax/vfa/function_encoder.py
class VectorFunctionEncoder:
    """Vector Function Architecture for encoding functions in RKHS.

    Represents functions f(x) in a Reproducing Kernel Hilbert Space (RKHS)
    using hypervectors:
        f(x) ≈ Σ α_i * z_i^x

    where z_i is the i-th component of the basis vector z and α_i are
    coefficients learned from samples.

    This enables:
        - Function approximation from samples
        - Point evaluation: f(x_query)
        - Function arithmetic: α*f + β*g
        - Function shifting: f(x - shift)
        - Function convolution: f * g

    Based on Frady et al. 2021, which demonstrates that VFA can represent
    and manipulate functions using vector symbolic operations.

    Attributes:
        model: VSAModel instance (must use ComplexHypervector).
        memory: VSAMemory for storing basis vectors.
        kernel_config: KernelConfig for basis vector sampling.
        basis_vector: The function basis vector z (sampled once).

    Example:
        >>> import jax
        >>> import jax.numpy as jnp
        >>> from vsax import create_fhrr_model, VSAMemory
        >>> from vsax.vfa import VectorFunctionEncoder, KernelConfig
        >>>
        >>> # Create VFA encoder
        >>> model = create_fhrr_model(dim=512, key=jax.random.PRNGKey(0))
        >>> memory = VSAMemory(model)
        >>> vfa = VectorFunctionEncoder(model, memory)
        >>>
        >>> # Sample a simple function
        >>> x = jnp.linspace(0, 2*jnp.pi, 20)
        >>> y = jnp.sin(x)
        >>>
        >>> # Encode the function
        >>> f_hv = vfa.encode_function_1d(x, y)
        >>>
        >>> # Evaluate at a query point
        >>> y_pred = vfa.evaluate_1d(f_hv, 1.5)

    See Also:
        - Frady et al. 2021: "Computing on Functions Using Randomized
          Vector Representations"
    """

    def __init__(
        self,
        model: VSAModel,
        memory: VSAMemory,
        kernel_config: Optional[KernelConfig] = None,
        basis_key: Optional[jax.Array] = None,
    ) -> None:
        """Initialize VectorFunctionEncoder.

        Args:
            model: VSAModel instance (must use ComplexHypervector).
            memory: VSAMemory for storing basis vectors.
            kernel_config: KernelConfig for basis sampling (default: UNIFORM).
            basis_key: Optional JAX key for basis sampling. If None, uses PRNGKey(0).

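        Raises:
            TypeError: If model doesn't use ComplexHypervector.
            ValueError: If kernel_config.dim differs from model.dim.
        """
        from vsax.representations import ComplexHypervector

        if model.rep_cls != ComplexHypervector:
            raise TypeError(
                "VectorFunctionEncoder requires ComplexHypervector (FHRR) model. "
                f"Got {model.rep_cls.__name__}. "
                "Use create_fhrr_model() to create a compatible model."
            )

        self.model = model
        self.memory = memory

        # Use the given kernel_config (its dim must match the model's) or
        # default to a UNIFORM kernel with the model's dimensionality
        if kernel_config is not None:
            if kernel_config.dim != model.dim:
                raise ValueError(
                    f"kernel_config.dim ({kernel_config.dim}) must match "
                    f"model.dim ({model.dim})"
                )
            self.kernel_config = kernel_config
        else:
            self.kernel_config = KernelConfig(dim=model.dim)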

        # Sample a single basis vector for function encoding
        key = basis_key if basis_key is not None else jax.random.PRNGKey(0)
        self.basis_vector = sample_kernel_basis(self.kernel_config, key)

    def encode_function_1d(
        self,
        x_samples: jnp.ndarray,
        y_samples: jnp.ndarray,
        regularization: float = 1e-6,
    ) -> ComplexHypervector:
        """Encode a 1D function from samples.

        Learns coefficients α such that f(x_k) ≈ Σ_i α_i * z_i^(x_k) at each
        sample point x_k, using regularized least squares.

        Args:
            x_samples: Array of x coordinates (shape: (n_samples,)).
            y_samples: Array of y values (shape: (n_samples,)).
            regularization: Regularization parameter for least squares (default: 1e-6).

        Returns:
            ComplexHypervector encoding the function.

        Raises:
            ValueError: If x_samples and y_samples have different lengths.

        Example:
            >>> # Encode sine function
            >>> x = jnp.linspace(0, 2*jnp.pi, 50)
            >>> y = jnp.sin(x)
            >>> f_hv = vfa.encode_function_1d(x, y)
        """
        if len(x_samples) != len(y_samples):
            raise ValueError(
                f"x_samples and y_samples must have same length. "
                f"Got {len(x_samples)} and {len(y_samples)}"
            )

        # Build the design matrix in one broadcasted call:
        # Z[k, i] = basis_vector[i] ** x_samples[k]  (shape: (n_samples, dim))
        x = jnp.asarray(x_samples)
        design_matrix = jnp.power(self.basis_vector[None, :], x[:, None]).astype(
            jnp.complex64
        )

        # Solve for coefficients: Z * alpha = y
        # Using regularized least squares: alpha = (Z^H Z + lambda I)^(-1) Z^H y
        ZH_Z = jnp.dot(design_matrix.conj().T, design_matrix)
        reg_term = regularization * jnp.eye(self.model.dim)
        ZH_y = jnp.dot(design_matrix.conj().T, y_samples)

        # Solve linear system
        coefficients = jnp.linalg.solve(ZH_Z + reg_term, ZH_y)

        return ComplexHypervector(coefficients)

    def evaluate_1d(
        self,
        function_hv: ComplexHypervector,
        x_query: float,
    ) -> float:
        """Evaluate encoded function at a query point.

        Computes f(x_query) = Re(Σ_i α_i * z_i^x_query), the unconjugated dot
        product of the coefficients with z^x_query. This is the same product
        that encode_function_1d fits (Z @ α ≈ y).

        Args:
            function_hv: Encoded function hypervector (from encode_function_1d).
            x_query: Point at which to evaluate the function.

        Returns:
            Estimated function value at x_query (real-valued).

        Example:
            >>> f_hv = vfa.encode_function_1d(x_train, y_train)
            >>> y_pred = vfa.evaluate_1d(f_hv, 1.5)
        """
        # Compute z^x_query
        query_vec = jnp.power(self.basis_vector, x_query)

        # f(x) = Σ_i α_i * z_i^x. Use jnp.dot, not jnp.vdot: vdot conjugates
        # z^x and would no longer match the least-squares fit.
        result = jnp.dot(query_vec, function_hv.vec)

        # Return real part (imaginary part should be ~0 for real functions)
        return float(jnp.real(result))

    def add_functions(
        self,
        f1: ComplexHypervector,
        f2: ComplexHypervector,
        alpha: float = 1.0,
        beta: float = 1.0,
    ) -> ComplexHypervector:
        """Compute linear combination of two functions.

        Returns h = alpha * f1 + beta * f2

        Args:
            f1: First encoded function.
            f2: Second encoded function.
            alpha: Coefficient for f1 (default: 1.0).
            beta: Coefficient for f2 (default: 1.0).

        Returns:
            Encoded function representing alpha*f1 + beta*f2.

        Example:
            >>> # Compute f + g
            >>> h = vfa.add_functions(f_hv, g_hv)
            >>> # Compute 2*f - 0.5*g
            >>> h = vfa.add_functions(f_hv, g_hv, alpha=2.0, beta=-0.5)
        """
        result = alpha * f1.vec + beta * f2.vec
        return ComplexHypervector(result)

    def shift_function(
        self,
        function_hv: ComplexHypervector,
        shift: float,
    ) -> ComplexHypervector:
        """Shift a function: f(x) -> f(x - shift).

        Uses the property: if f(x) = Σ_i α_i * z_i^x, then
        f(x - s) = Σ_i α_i * z_i^(x-s) = Σ_i (z_i^(-s) * α_i) * z_i^x,
        i.e. f(x - s) is encoded by z^(-s) ⊙ α,

        where ⊙ is element-wise multiplication.

        Args:
            function_hv: Encoded function to shift.
            shift: Amount to shift (positive shifts right).

        Returns:
            Encoded function representing f(x - shift).

        Example:
            >>> # Shift sine function to the right by π/2
            >>> x = jnp.linspace(0, 2*jnp.pi, 50)
            >>> y = jnp.sin(x)
            >>> f_hv = vfa.encode_function_1d(x, y)
            >>> shifted = vfa.shift_function(f_hv, jnp.pi/2)
            >>> # shifted approximates sin(x - π/2) = -cos(x)
        """
        # Compute shift factor: z^(-shift)
        shift_factor = jnp.power(self.basis_vector, -shift)

        # Apply element-wise multiplication
        result = shift_factor * function_hv.vec

        return ComplexHypervector(result)

    def convolve_functions(
        self,
        f1: ComplexHypervector,
        f2: ComplexHypervector,
    ) -> ComplexHypervector:
        """Compute convolution of two functions.

        For functions represented in VFA, convolution can be approximated
        by binding, which for FHRR is element-wise multiplication of the
        complex coefficient vectors (the frequency-domain counterpart of
        circular convolution).

        Args:
            f1: First encoded function.
            f2: Second encoded function.

        Returns:
            Encoded function representing approximate convolution f1 * f2.

        Example:
            >>> # Convolve two functions
            >>> h = vfa.convolve_functions(f_hv, g_hv)

        Note:
            This is an approximation. The quality depends on the dimensionality
            and the specific functions being convolved.
        """
        # FHRR binding: element-wise multiplication of the complex coefficients
        result = self.model.opset.bind(f1.vec, f2.vec)
        return ComplexHypervector(result)

    def evaluate_batch(
        self,
        function_hv: ComplexHypervector,
        x_queries: jnp.ndarray,
    ) -> jnp.ndarray:
        """Evaluate function at multiple query points.

        Convenience function for batch evaluation.

        Args:
            function_hv: Encoded function hypervector.
            x_queries: Array of query points (shape: (n_queries,)).

        Returns:
            Array of function values (shape: (n_queries,)).

        Example:
            >>> x_test = jnp.linspace(0, 2*jnp.pi, 100)
            >>> y_pred = vfa.evaluate_batch(f_hv, x_test)
        """
        return jnp.array([self.evaluate_1d(function_hv, x) for x in x_queries])

Functions

__init__(model, memory, kernel_config=None, basis_key=None)

Initialize VectorFunctionEncoder.

Parameters:

model (VSAModel, required): VSAModel instance (must use ComplexHypervector).
memory (VSAMemory, required): VSAMemory for storing basis vectors.
kernel_config (Optional[KernelConfig], default None): KernelConfig for basis sampling. If None, a UNIFORM KernelConfig with dim=model.dim is used.
basis_key (Optional[Array], default None): Optional JAX key for basis sampling. If None, uses PRNGKey(0).

Raises:

TypeError: If model doesn't use ComplexHypervector.


encode_function_1d(x_samples, y_samples, regularization=1e-06)

Encode a 1D function from samples.

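Learns coefficients α such that f(x_k) ≈ Σ_i α_i * z_i^(x_k) at every sample point x_k. The coefficients are found by regularized least squares, α = (Z^H Z + λI)^(-1) Z^H y, where row k of the design matrix Z is z^(x_k) and λ is the regularization parameter.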
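The fit can be sketched in plain JAX as follows. This is an illustration, not the vsax implementation: it enables 64-bit mode for a well-conditioned solve, and it swaps in the dual (kernel) form α = Z^H (Z Z^H + λI)^(-1) y, an n × n solve that equals the primal α = (Z^H Z + λI)^(-1) Z^H y from the source comments in exact arithmetic. Evaluation uses the same unconjugated product Σ_i α_i z_i^x that defines the least-squares residual. The helper names are hypothetical.

```python
import jax
import jax.numpy as jnp

# Sketch choice: 64-bit mode keeps the regularized solve well conditioned.
jax.config.update("jax_enable_x64", True)

def fit_coefficients(z, x_samples, y_samples, reg=1e-6):
    """Ridge fit of alpha so that sum_i alpha_i * z_i**x_k ~= y_k for every sample k."""
    # Design matrix Z: row k is z**x_k, shape (n_samples, D).
    Z = jnp.power(z[None, :], x_samples[:, None])
    # Dual form of (Z^H Z + reg*I)^-1 Z^H y: an n x n solve instead of D x D.
    gram = Z @ Z.conj().T
    c = jnp.linalg.solve(gram + reg * jnp.eye(Z.shape[0]), y_samples.astype(Z.dtype))
    return Z.conj().T @ c

def evaluate(z, alpha, x):
    """f(x) = Re(sum_i alpha_i * z_i**x): unconjugated, matching the fit above."""
    return jnp.real(jnp.dot(jnp.power(z, x), alpha))

# Usage: fit 20 samples of sin(x) with a 2048-dimensional uniform-phase basis.
phases = jax.random.uniform(jax.random.PRNGKey(0), (2048,), minval=0.0, maxval=2 * jnp.pi)
z = jnp.exp(1j * phases)
x_train = jnp.linspace(0.0, 2 * jnp.pi, 20)
y_train = jnp.sin(x_train)
alpha = fit_coefficients(z, x_train, y_train)
train_pred = jnp.array([evaluate(z, alpha, x) for x in x_train])
```

Note that using jnp.vdot in evaluate would conjugate z^x and evaluate a different function from the one that was fitted.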

Parameters:

x_samples (ndarray, required): Array of x coordinates (shape: (n_samples,)).
y_samples (ndarray, required): Array of y values (shape: (n_samples,)).
regularization (float, default 1e-06): Regularization parameter for least squares.

Returns:

ComplexHypervector: ComplexHypervector encoding the function.

Raises:

ValueError: If x_samples and y_samples have different lengths.

Example

# Encode sine function
x = jnp.linspace(0, 2*jnp.pi, 50)
y = jnp.sin(x)
f_hv = vfa.encode_function_1d(x, y)


evaluate_1d(function_hv, x_query)

Evaluate encoded function at a query point.

Computes f(x_query) = <function_hv, z^x_query>, where <·,·> is the inner product.

Parameters:

function_hv (ComplexHypervector, required): Encoded function hypervector (from encode_function_1d).
x_query (float, required): Point at which to evaluate the function.

Returns:

float: Estimated function value at x_query (real-valued).

Example

f_hv = vfa.encode_function_1d(x_train, y_train)
y_pred = vfa.evaluate_1d(f_hv, 1.5)


add_functions(f1, f2, alpha=1.0, beta=1.0)

Compute linear combination of two functions.

Returns h = alpha * f1 + beta * f2
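Because evaluation is linear in the coefficient vector, combining coefficient vectors combines the encoded functions. The short sketch below checks this in plain JAX; the evaluate helper is hypothetical and assumes evaluation is Re(Σ_i α_i z_i^x), and random complex vectors stand in for two encoded functions.

```python
import jax
import jax.numpy as jnp

def evaluate(z, coeffs, x):
    """Hypothetical evaluator: f(x) = Re(sum_i coeffs_i * z_i**x)."""
    return jnp.real(jnp.dot(jnp.power(z, x), coeffs))

def random_coeffs(key, dim):
    """Random complex coefficient vector standing in for an encoded function."""
    k_re, k_im = jax.random.split(key)
    return jax.random.normal(k_re, (dim,)) + 1j * jax.random.normal(k_im, (dim,))

dim = 512
k_z, k_f, k_g = jax.random.split(jax.random.PRNGKey(0), 3)
z = jnp.exp(1j * jax.random.uniform(k_z, (dim,), minval=0.0, maxval=2 * jnp.pi))
f, g = random_coeffs(k_f, dim), random_coeffs(k_g, dim)

alpha, beta = 2.0, -0.5
h = alpha * f + beta * g  # same arithmetic as add_functions

x = 1.3
lhs = evaluate(z, h, x)
rhs = alpha * evaluate(z, f, x) + beta * evaluate(z, g, x)
```

The two values agree up to floating-point rounding, which is why add_functions needs no basis vector.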

Parameters:

f1 (ComplexHypervector, required): First encoded function.
f2 (ComplexHypervector, required): Second encoded function.
alpha (float, default 1.0): Coefficient for f1.
beta (float, default 1.0): Coefficient for f2.

Returns:

ComplexHypervector: Encoded function representing alpha * f1 + beta * f2.

Example

# Compute f + g
h = vfa.add_functions(f_hv, g_hv)

# Compute 2*f - 0.5*g
h = vfa.add_functions(f_hv, g_hv, alpha=2.0, beta=-0.5)


shift_function(function_hv, shift)

Shift a function: f(x) -> f(x - shift).

Uses the property: if f(x) = Σ_i α_i * z_i^x, then f(x - s) = Σ_i α_i * z_i^(x-s) = Σ_i (z_i^(-s) * α_i) * z_i^x, so f(x - s) is encoded by z^(-s) ⊙ α,

where ⊙ is element-wise multiplication.
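The identity can be checked numerically. In this hypothetical plain-JAX sketch (assuming evaluation is Re(Σ_i α_i z_i^x), with a random complex vector standing in for an encoded function), multiplying the coefficients element-wise by z^(-s) yields a function whose value at x equals the original's value at x - s.

```python
import jax
import jax.numpy as jnp

def evaluate(z, coeffs, x):
    """Hypothetical evaluator: f(x) = Re(sum_i coeffs_i * z_i**x)."""
    return jnp.real(jnp.dot(jnp.power(z, x), coeffs))

def shift(z, coeffs, s):
    """Encode f(x - s) by multiplying element-wise with z**(-s), as shift_function does."""
    return jnp.power(z, -s) * coeffs

dim = 512
k_z, k_re, k_im = jax.random.split(jax.random.PRNGKey(1), 3)
z = jnp.exp(1j * jax.random.uniform(k_z, (dim,), minval=0.0, maxval=2 * jnp.pi))
coeffs = jax.random.normal(k_re, (dim,)) + 1j * jax.random.normal(k_im, (dim,))

s = jnp.pi / 2
shifted = shift(z, coeffs, s)
for x in (0.0, 0.7, 2.0):
    # Both values should agree up to floating-point rounding.
    print(float(evaluate(z, shifted, x)), float(evaluate(z, coeffs, x - s)))
```

Applied to an encoding of sin, a shift of s = π/2 therefore represents sin(x - π/2) = -cos(x).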

Parameters:

function_hv (ComplexHypervector, required): Encoded function to shift.
shift (float, required): Amount to shift (positive shifts right).

Returns:

ComplexHypervector: Encoded function representing f(x - shift).

Example

# Shift sine function to the right by π/2
x = jnp.linspace(0, 2*jnp.pi, 50)
y = jnp.sin(x)
f_hv = vfa.encode_function_1d(x, y)
shifted = vfa.shift_function(f_hv, jnp.pi/2)
# shifted approximates sin(x - π/2) = -cos(x)

convolve_functions(f1, f2)

Compute convolution of two functions.

For functions represented in VFA, convolution can be approximated by binding, which for FHRR is element-wise multiplication of the complex coefficient vectors (the frequency-domain counterpart of circular convolution).
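For fractional power encodings, element-wise multiplication adds exponents: z^a ⊙ z^b = z^(a+b). Binding the encodings of two points therefore encodes the sum of their positions, the discrete analogue of convolving two point masses (δ_a * δ_b = δ_(a+b)); by bilinearity this extends to superpositions Σ_k a_k z^(x_k). The minimal plain-JAX check below uses element-wise multiplication as a stand-in for model.opset.bind, which is an assumption about the FHRR opset.

```python
import jax
import jax.numpy as jnp

dim = 1024
phases = jax.random.uniform(jax.random.PRNGKey(2), (dim,), minval=0.0, maxval=2 * jnp.pi)
z = jnp.exp(1j * phases)

def encode(x):
    """Fractional power encoding z**x."""
    return jnp.power(z, x)

def bind(a, b):
    """FHRR-style binding: element-wise complex multiplication."""
    return a * b

a, b = 0.8, 1.7
bound = bind(encode(a), encode(b))
target = encode(a + b)  # encoding of the summed position
```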

Parameters:

f1 (ComplexHypervector, required): First encoded function.
f2 (ComplexHypervector, required): Second encoded function.

Returns:

ComplexHypervector: Encoded function representing the approximate convolution f1 * f2.

Example

# Convolve two functions
h = vfa.convolve_functions(f_hv, g_hv)

Note

This is an approximation. The quality depends on the dimensionality and the specific functions being convolved.


evaluate_batch(function_hv, x_queries)

Evaluate function at multiple query points.

Convenience function for batch evaluation.
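evaluate_batch calls evaluate_1d in a Python loop. A vectorized alternative could raise z to every query point at once by broadcasting and take a single matrix-vector product. The sketch below is hypothetical (not part of vsax) and assumes evaluation is Re(Σ_i α_i z_i^x); because it avoids float() conversions it can also be wrapped in jax.jit.

```python
import jax
import jax.numpy as jnp

def evaluate_batch_vectorized(z, coeffs, x_queries):
    """Re(sum_i coeffs_i * z_i**x) for every x in x_queries, without a Python loop."""
    powered = jnp.power(z[None, :], jnp.asarray(x_queries)[:, None])  # (n_queries, D)
    return jnp.real(powered @ coeffs)

def evaluate_single(z, coeffs, x):
    """Reference: one query at a time."""
    return jnp.real(jnp.dot(jnp.power(z, x), coeffs))

# Usage with a random basis and random coefficients.
k_z, k_c = jax.random.split(jax.random.PRNGKey(5))
z = jnp.exp(1j * jax.random.uniform(k_z, (512,), minval=0.0, maxval=2 * jnp.pi))
coeffs = jax.random.normal(k_c, (512,)).astype(jnp.complex64)
xs = jnp.linspace(0.0, 2 * jnp.pi, 100)
batch = evaluate_batch_vectorized(z, coeffs, xs)
looped = jnp.array([evaluate_single(z, coeffs, x) for x in xs])
```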

Parameters:

function_hv (ComplexHypervector, required): Encoded function hypervector.
x_queries (ndarray, required): Array of query points (shape: (n_queries,)).

Returns:

ndarray: Array of function values (shape: (n_queries,)).

Example

x_test = jnp.linspace(0, 2*jnp.pi, 100)
y_pred = vfa.evaluate_batch(f_hv, x_test)


Kernel Configuration

KernelConfig

vsax.vfa.KernelConfig dataclass

Configuration for VFA kernel basis vectors.

Attributes:

kernel_type (KernelType): Type of kernel (UNIFORM, GAUSSIAN, LAPLACE, or CUSTOM).
bandwidth (float): Kernel bandwidth parameter (default: 1.0). For GAUSSIAN: standard deviation of the frequency distribution. For LAPLACE: scale parameter. For UNIFORM: ignored (uses the full frequency range). sample_kernel_basis currently falls back to uniform sampling for GAUSSIAN and LAPLACE, so bandwidth has no effect yet.
dim (int): Dimensionality of basis vectors (default: 512).
custom_sampler (Optional[Callable[[Array, int], ndarray]]): Optional custom sampling function for the CUSTOM kernel, with signature (key, dim) -> complex array of shape (dim,).
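To illustrate the custom_sampler signature (key, dim) -> complex array of shape (dim,), the hypothetical sampler below draws phases θ from a zero-mean Gaussian with standard deviation bandwidth. For the encoding z^x = exp(iθx), the induced similarity E[cos(θ(x - y))] is then the Gaussian kernel exp(-(bandwidth * (x - y))² / 2). One caveat, assuming jnp.power uses the principal branch of the complex logarithm: VectorFunctionEncoder raises z with jnp.power, so phases with |θ| > π are wrapped into (-π, π], and the Gaussian kernel only holds when such phases are rare (small bandwidth).

```python
import jax
import jax.numpy as jnp

def make_gaussian_sampler(bandwidth):
    """Hypothetical factory for a KernelConfig.custom_sampler with Gaussian phases."""
    def sampler(key, dim):
        theta = bandwidth * jax.random.normal(key, (dim,))
        return jnp.exp(1j * theta)  # unit-magnitude phasors
    return sampler

def empirical_kernel(z, d):
    """Mean of Re(z_i**d): Monte Carlo estimate of E[cos(theta * d)]."""
    return jnp.mean(jnp.real(jnp.power(z, d)))

sampler = make_gaussian_sampler(bandwidth=1.0)
z = sampler(jax.random.PRNGKey(3), 10_000)
for d in (0.5, 1.0, 2.0):
    print(d, float(empirical_kernel(z, d)), float(jnp.exp(-0.5 * d**2)))
```

Such a sampler could then be passed as KernelConfig(kernel_type=KernelType.CUSTOM, custom_sampler=make_gaussian_sampler(1.0), dim=512).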

Example

# Default: uniform kernel (standard FHRR)
config = KernelConfig()

# Gaussian kernel with specific bandwidth
config = KernelConfig(
    kernel_type=KernelType.GAUSSIAN,
    bandwidth=2.0,
    dim=1024
)

Source code in vsax/vfa/kernels.py
@dataclass
class KernelConfig:
    """Configuration for VFA kernel basis vectors.

    Attributes:
        kernel_type: Type of kernel (UNIFORM, GAUSSIAN, LAPLACE, or CUSTOM).
        bandwidth: Kernel bandwidth parameter (default: 1.0).
            For GAUSSIAN: standard deviation of frequency distribution.
            For LAPLACE: scale parameter.
            For UNIFORM: ignored (uses full frequency range).
        dim: Dimensionality of basis vectors (default: 512).
        custom_sampler: Optional custom sampling function for CUSTOM kernel.
            Should have signature: (key, dim) -> complex array of shape (dim,)

    Example:
        >>> # Default: uniform kernel (standard FHRR)
        >>> config = KernelConfig()
        >>>
        >>> # Gaussian kernel with specific bandwidth
        >>> config = KernelConfig(
        ...     kernel_type=KernelType.GAUSSIAN,
        ...     bandwidth=2.0,
        ...     dim=1024
        ... )
    """

    kernel_type: KernelType = KernelType.UNIFORM
    bandwidth: float = 1.0
    dim: int = 512
    custom_sampler: Optional[Callable[[jax.Array, int], jnp.ndarray]] = None

    def __post_init__(self) -> None:
        """Validate configuration."""
        if self.bandwidth <= 0:
            raise ValueError(f"bandwidth must be positive, got {self.bandwidth}")
        if self.dim <= 0:
            raise ValueError(f"dim must be positive, got {self.dim}")
        if self.kernel_type == KernelType.CUSTOM and self.custom_sampler is None:
            raise ValueError("custom_sampler must be provided for CUSTOM kernel type")

Functions

__post_init__()

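Validate configuration. Raises ValueError if bandwidth or dim is not positive, or if kernel_type is CUSTOM and no custom_sampler is provided.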


KernelType

vsax.vfa.KernelType

Bases: Enum

Kernel types for VFA basis vector sampling.

Different kernel types correspond to different frequency sampling distributions, which affect the smoothness and approximation properties of encoded functions.

Attributes:

UNIFORM: Uniform distribution over the unit circle (standard FHRR). Equivalent to sampling phases uniformly from [0, 2π). Default choice; works well for most applications.

GAUSSIAN: Gaussian-weighted frequency distribution. Concentrates frequencies near zero, which produces smoother functions; better suited to low-frequency functions. Not yet implemented: sample_kernel_basis currently falls back to UNIFORM sampling.

LAPLACE: Laplace (double exponential) distribution. Heavier tails than Gaussian, allowing more high-frequency content; suited to functions with sharp features. Not yet implemented: sample_kernel_basis currently falls back to UNIFORM sampling.

CUSTOM: User-defined sampling function, giving complete control over the frequency distribution.

Source code in vsax/vfa/kernels.py
class KernelType(Enum):
    """Kernel types for VFA basis vector sampling.

    Different kernel types correspond to different frequency sampling
    distributions, which affect the smoothness and approximation
    properties of encoded functions.

    Attributes:
        UNIFORM: Uniform distribution over unit circle (standard FHRR).
            Equivalent to sampling phases uniformly from [0, 2π).
            Default choice, works well for most applications.

        GAUSSIAN: Gaussian-weighted frequency distribution.
            Concentrates frequencies near center, produces smoother functions.
            Better for low-frequency functions.

        LAPLACE: Laplace (double exponential) distribution.
            Heavier tails than Gaussian, allows more high-frequency content.
            Good for functions with sharp features.

        CUSTOM: User-defined sampling function.
            Allows complete control over frequency distribution.
    """

    UNIFORM = "uniform"
    GAUSSIAN = "gaussian"
    LAPLACE = "laplace"
    CUSTOM = "custom"

sample_kernel_basis

vsax.vfa.sample_kernel_basis(config, key)

Sample a basis vector according to kernel configuration.

Generates a random complex hypervector with frequency distribution determined by the kernel type. For VFA, these basis vectors are raised to fractional powers to encode function values.

Parameters:

config (KernelConfig, required): KernelConfig specifying kernel type and parameters.
key (Array, required): JAX random key for sampling.

Returns:

ndarray: Complex array of shape (config.dim,) representing a basis vector. All elements have unit magnitude (phase-only representation).

Raises:

ValueError: If kernel_type is not recognized.

Example

key = jax.random.PRNGKey(42)
config = KernelConfig(kernel_type=KernelType.UNIFORM, dim=512)
basis = sample_kernel_basis(config, key)
assert basis.shape == (512,)
assert jnp.iscomplexobj(basis)
# All magnitudes should be 1.0
assert jnp.allclose(jnp.abs(basis), 1.0)

Note

Currently only UNIFORM is fully implemented. GAUSSIAN and LAPLACE are placeholders for future extension: both currently fall back to uniform phase sampling and ignore bandwidth.

Source code in vsax/vfa/kernels.py
def sample_kernel_basis(config: KernelConfig, key: jax.Array) -> jnp.ndarray:
    """Sample a basis vector according to kernel configuration.

    Generates a random complex hypervector with frequency distribution
    determined by the kernel type. For VFA, these basis vectors are raised
    to fractional powers to encode function values.

    Args:
        config: KernelConfig specifying kernel type and parameters.
        key: JAX random key for sampling.

    Returns:
        Complex array of shape (config.dim,) representing a basis vector.
        All elements have unit magnitude (phase-only representation).

    Raises:
        ValueError: If kernel_type is not recognized.

    Example:
        >>> key = jax.random.PRNGKey(42)
        >>> config = KernelConfig(kernel_type=KernelType.UNIFORM, dim=512)
        >>> basis = sample_kernel_basis(config, key)
        >>> assert basis.shape == (512,)
        >>> assert jnp.iscomplexobj(basis)
        >>> # All magnitudes should be 1.0
        >>> assert jnp.allclose(jnp.abs(basis), 1.0)

    Note:
        Currently only UNIFORM is fully implemented. GAUSSIAN and LAPLACE
        are placeholders for future extension.
    """
    if config.kernel_type == KernelType.UNIFORM:
        # Standard FHRR: sample phases uniformly from [0, 2π)
        phases = jax.random.uniform(key, shape=(config.dim,), minval=0, maxval=2 * jnp.pi)
        return jnp.exp(1j * phases)

    elif config.kernel_type == KernelType.GAUSSIAN:
        # Gaussian kernel: sample phases with Gaussian-weighted frequencies
        # For future implementation: sample frequencies from Gaussian,
        # then construct phases
        # For now, fall back to uniform
        phases = jax.random.uniform(key, shape=(config.dim,), minval=0, maxval=2 * jnp.pi)
        # TODO: Implement Gaussian frequency weighting
        # frequencies = jax.random.normal(key, (config.dim,)) * config.bandwidth
        # phases = ...
        return jnp.exp(1j * phases)

    elif config.kernel_type == KernelType.LAPLACE:
        # Laplace kernel: sample phases with Laplace-distributed frequencies
        # For future implementation
        phases = jax.random.uniform(key, shape=(config.dim,), minval=0, maxval=2 * jnp.pi)
        # TODO: Implement Laplace frequency weighting
        return jnp.exp(1j * phases)

    elif config.kernel_type == KernelType.CUSTOM:
        # Use custom sampler
        if config.custom_sampler is None:
            raise ValueError("custom_sampler must be provided for CUSTOM kernel")
        return config.custom_sampler(key, config.dim)

    else:
        raise ValueError(f"Unknown kernel type: {config.kernel_type}")
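The docstring above says VFA raises the sampled basis vector to fractional powers. The standalone JAX sketch below (not a vsax API) shows what that buys. It writes z^x as exp(i·ω·x) for a vector of frequencies ω and measures the similarity Re⟨z^x, z^y⟩/d, which approximates a shift-invariant kernel K(y − x).

For UNIFORM, the sketch draws ω uniformly from [−π, π). These are the principal angles of the [0, 2π) phases sampled above, so they match z**x only if the power is taken on the principal branch, which is an assumption here.

The Gaussian and Laplace lines are hypothetical illustrations of what the TODO branches could do. By Bochner's theorem, ω ~ N(0, 1/ℓ²) yields exp(−Δ²/(2ℓ²)). A Laplace kernel exp(−|Δ|/ℓ) corresponds to Cauchy-distributed frequencies, not Laplace-distributed ones.

```python
import jax
import jax.numpy as jnp


def similarity(freqs: jnp.ndarray, x: float, y: float) -> jnp.ndarray:
    """Re<z^x, z^y> / d with z^x = exp(i * freqs * x); equals mean(cos(freqs * (y - x)))."""
    z_x = jnp.exp(1j * freqs * x)
    z_y = jnp.exp(1j * freqs * y)
    return jnp.real(jnp.vdot(z_x, z_y)) / freqs.shape[0]


dim = 4096
length_scale = 1.0  # hypothetical bandwidth parameter for the non-uniform cases
k_uniform, k_gauss, k_laplace = jax.random.split(jax.random.PRNGKey(0), 3)

# UNIFORM: frequencies in [-pi, pi) -> sinc kernel sin(pi*D)/(pi*D)
uniform_freqs = jax.random.uniform(k_uniform, (dim,), minval=-jnp.pi, maxval=jnp.pi)
# GAUSSIAN (hypothetical): omega ~ N(0, 1/l^2) -> exp(-D^2 / (2 l^2))
gauss_freqs = jax.random.normal(k_gauss, (dim,)) / length_scale
# LAPLACE (hypothetical): omega ~ Cauchy(0, 1/l) -> exp(-|D| / l)
laplace_freqs = jax.random.cauchy(k_laplace, (dim,)) / length_scale

print("uniform ", float(similarity(uniform_freqs, 0.0, 0.5)), float(jnp.sinc(0.5)))
print("gaussian", float(similarity(gauss_freqs, 0.0, 1.0)), float(jnp.exp(-0.5)))
print("laplace ", float(similarity(laplace_freqs, 0.0, 1.0)), float(jnp.exp(-1.0)))
```

Each printed pair compares the empirical similarity with the exact kernel. The gap shrinks roughly as 1/√d.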

Applications

DensityEstimator

vsax.vfa.applications.DensityEstimator

Kernel density estimation using VFA (Frady et al. 2021 §7.2.1).

Estimates probability density functions from sample data by encoding the density as a hypervector. Uses VFA to represent the density function in RKHS.

The density is estimated as

p(x) ∝ Σ_i K(x - x_i)

where x_i are the sample points and K is the kernel. fit() uses a Gaussian kernel K(u) = exp(−u²/(2h²)) with h = bandwidth, normalized by n·h·√(2π).

Attributes:

| Name | Description |
|---|---|
| vfa | VectorFunctionEncoder for function encoding. |
| bandwidth | Kernel bandwidth for density estimation: the standard deviation h of the Gaussian kernel used by fit(). |

Example

import jax
import jax.numpy as jnp
from vsax import create_fhrr_model, VSAMemory
from vsax.vfa import DensityEstimator

# Create estimator
key = jax.random.PRNGKey(42)
model = create_fhrr_model(dim=512, key=key)
memory = VSAMemory(model)
estimator = DensityEstimator(model, memory, bandwidth=0.5)

# Fit to data
samples = jax.random.normal(key, (100,))
estimator.fit(samples)

# Evaluate density
x_query = jnp.array([0.0, 1.0, 2.0])
densities = estimator.evaluate(x_query)

Source code in vsax/vfa/applications.py
class DensityEstimator:
    """Kernel density estimation using VFA (Frady et al. 2021 §7.2.1).

    Estimates probability density functions from sample data by encoding
    the density as a hypervector. Uses VFA to represent the density function
    in RKHS.

    The density is estimated as:
        p(x) ∝ Σ_i K(x - x_i)
    where K is a kernel function and x_i are the sample points.

    Attributes:
        vfa: VectorFunctionEncoder for function encoding.
        bandwidth: Kernel bandwidth for density estimation.

    Example:
        >>> import jax
        >>> import jax.numpy as jnp
        >>> from vsax import create_fhrr_model, VSAMemory
        >>> from vsax.vfa import DensityEstimator
        >>>
        >>> # Create estimator
        >>> key = jax.random.PRNGKey(42)
        >>> model = create_fhrr_model(dim=512, key=key)
        >>> memory = VSAMemory(model)
        >>> estimator = DensityEstimator(model, memory, bandwidth=0.5)
        >>>
        >>> # Fit to data
        >>> samples = jax.random.normal(key, (100,))
        >>> estimator.fit(samples)
        >>>
        >>> # Evaluate density
        >>> x_query = jnp.array([0.0, 1.0, 2.0])
        >>> densities = estimator.evaluate(x_query)
    """

    def __init__(
        self,
        model: VSAModel,
        memory: VSAMemory,
        bandwidth: float = 1.0,
        kernel_config: Optional[KernelConfig] = None,
    ) -> None:
        """Initialize DensityEstimator.

        Args:
            model: VSAModel instance (must use ComplexHypervector).
            memory: VSAMemory for storing basis vectors.
            bandwidth: Kernel bandwidth for density estimation (default: 1.0).
            kernel_config: Optional KernelConfig for VFA basis.

        Raises:
            TypeError: If model doesn't use ComplexHypervector.
            ValueError: If bandwidth is not positive.
        """
        if bandwidth <= 0:
            raise ValueError(f"bandwidth must be positive, got {bandwidth}")

        self.vfa = VectorFunctionEncoder(model, memory, kernel_config)
        self.bandwidth = bandwidth
        self._density_hv: Optional[ComplexHypervector] = None
        self._fitted = False

    def fit(self, samples: jnp.ndarray) -> None:
        """Fit density estimator to sample data.

        Estimates the density function from samples by creating a kernel
        density estimate and encoding it as a hypervector.

        Args:
            samples: 1D array of sample points (shape: (n_samples,)).

        Raises:
            ValueError: If samples is not 1D.
        """
        if samples.ndim != 1:
            raise ValueError(f"samples must be 1D array, got shape {samples.shape}")

        # Create grid for density evaluation
        x_min, x_max = jnp.min(samples), jnp.max(samples)
        margin = 3 * self.bandwidth
        x_grid = jnp.linspace(x_min - margin, x_max + margin, 100)

        # Compute kernel density estimate on grid
        # p(x) ∝ Σ_i exp(-(x - x_i)^2 / (2 * h^2))
        density_values = jnp.zeros_like(x_grid)
        for sample in samples:
            kernel_values = jnp.exp(-((x_grid - sample) ** 2) / (2 * self.bandwidth**2))
            density_values += kernel_values

        # Normalize
        density_values /= len(samples) * self.bandwidth * jnp.sqrt(2 * jnp.pi)

        # Encode density function
        self._density_hv = self.vfa.encode_function_1d(x_grid, density_values)
        self._fitted = True

    def evaluate(self, x_query: jnp.ndarray) -> jnp.ndarray:
        """Evaluate density at query points.

        Args:
            x_query: Array of query points (shape: (n_queries,)).

        Returns:
            Estimated density values at query points (shape: (n_queries,)).

        Raises:
            RuntimeError: If estimator has not been fitted.
        """
        if not self._fitted:
            raise RuntimeError(
                "DensityEstimator must be fitted before evaluation. Call fit() first."
            )

        assert self._density_hv is not None  # Set by fit()
        return self.vfa.evaluate_batch(self._density_hv, x_query)

    def sample(self, key: jax.Array, n_samples: int = 1) -> jnp.ndarray:
        """Sample from the estimated density (not implemented).

        Args:
            key: JAX random key.
            n_samples: Number of samples to generate.

        Returns:
            Samples from the density.

        Raises:
            NotImplementedError: Sampling not yet implemented.
        """
        raise NotImplementedError(
            "Sampling from VFA density estimates is not yet implemented. "
            "This would require inverse transform sampling or MCMC."
        )

Functions

__init__(model, memory, bandwidth=1.0, kernel_config=None)

Initialize DensityEstimator.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model | VSAModel | VSAModel instance (must use ComplexHypervector). | required |
| memory | VSAMemory | VSAMemory for storing basis vectors. | required |
| bandwidth | float | Kernel bandwidth for density estimation. | 1.0 |
| kernel_config | Optional[KernelConfig] | Optional KernelConfig for VFA basis. | None |

Raises:

| Type | Description |
|---|---|
| TypeError | If model doesn't use ComplexHypervector. |
| ValueError | If bandwidth is not positive. |

Source code in vsax/vfa/applications.py
def __init__(
    self,
    model: VSAModel,
    memory: VSAMemory,
    bandwidth: float = 1.0,
    kernel_config: Optional[KernelConfig] = None,
) -> None:
    """Initialize DensityEstimator.

    Args:
        model: VSAModel instance (must use ComplexHypervector).
        memory: VSAMemory for storing basis vectors.
        bandwidth: Kernel bandwidth for density estimation (default: 1.0).
        kernel_config: Optional KernelConfig for VFA basis.

    Raises:
        TypeError: If model doesn't use ComplexHypervector.
        ValueError: If bandwidth is not positive.
    """
    if bandwidth <= 0:
        raise ValueError(f"bandwidth must be positive, got {bandwidth}")

    self.vfa = VectorFunctionEncoder(model, memory, kernel_config)
    self.bandwidth = bandwidth
    self._density_hv: Optional[ComplexHypervector] = None
    self._fitted = False

fit(samples)

Fit density estimator to sample data.

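Evaluates a Gaussian kernel density estimate on a 100-point grid spanning [min(samples) − 3·bandwidth, max(samples) + 3·bandwidth]. It normalizes the estimate by n·bandwidth·√(2π), then encodes the resulting curve as a hypervector with encode_function_1d.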
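fit() builds its estimate with a Python loop over samples. The standalone JAX sketch below computes the same grid, kernel sum, and normalization in a single broadcasted expression. It can help when checking what fit() passes to encode_function_1d. The function name gaussian_kde_on_grid is hypothetical, and the sketch does not call vsax.

```python
import jax
import jax.numpy as jnp


def gaussian_kde_on_grid(samples: jnp.ndarray, bandwidth: float, n_grid: int = 100):
    """Gaussian KDE tabulated on [min - 3h, max + 3h], mirroring the steps fit() performs."""
    margin = 3 * bandwidth
    x_grid = jnp.linspace(samples.min() - margin, samples.max() + margin, n_grid)
    # Pairwise differences of shape (n_grid, n_samples) replace the per-sample loop
    diffs = x_grid[:, None] - samples[None, :]
    kernels = jnp.exp(-(diffs ** 2) / (2 * bandwidth ** 2))
    density = kernels.sum(axis=1) / (samples.shape[0] * bandwidth * jnp.sqrt(2 * jnp.pi))
    return x_grid, density


samples = jax.random.normal(jax.random.PRNGKey(42), (100,))
x_grid, density = gaussian_kde_on_grid(samples, bandwidth=0.5)

# Trapezoidal rule: the tabulated density should integrate to roughly 1
mass = jnp.sum(0.5 * (density[1:] + density[:-1]) * jnp.diff(x_grid))
print(float(mass))
```

The ±3·bandwidth margin keeps almost all of each Gaussian bump on the grid, so the printed mass is slightly below 1.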

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| samples | ndarray | 1D array of sample points (shape: (n_samples,)). | required |

Raises:

| Type | Description |
|---|---|
| ValueError | If samples is not 1D. |

Source code in vsax/vfa/applications.py
def fit(self, samples: jnp.ndarray) -> None:
    """Fit density estimator to sample data.

    Estimates the density function from samples by creating a kernel
    density estimate and encoding it as a hypervector.

    Args:
        samples: 1D array of sample points (shape: (n_samples,)).

    Raises:
        ValueError: If samples is not 1D.
    """
    if samples.ndim != 1:
        raise ValueError(f"samples must be 1D array, got shape {samples.shape}")

    # Create grid for density evaluation
    x_min, x_max = jnp.min(samples), jnp.max(samples)
    margin = 3 * self.bandwidth
    x_grid = jnp.linspace(x_min - margin, x_max + margin, 100)

    # Compute kernel density estimate on grid
    # p(x) ∝ Σ_i exp(-(x - x_i)^2 / (2 * h^2))
    density_values = jnp.zeros_like(x_grid)
    for sample in samples:
        kernel_values = jnp.exp(-((x_grid - sample) ** 2) / (2 * self.bandwidth**2))
        density_values += kernel_values

    # Normalize
    density_values /= len(samples) * self.bandwidth * jnp.sqrt(2 * jnp.pi)

    # Encode density function
    self._density_hv = self.vfa.encode_function_1d(x_grid, density_values)
    self._fitted = True

evaluate(x_query)

Evaluate density at query points.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x_query | ndarray | Array of query points (shape: (n_queries,)). | required |

Returns:

| Type | Description |
|---|---|
| ndarray | Estimated density values at query points (shape: (n_queries,)). |

Raises:

| Type | Description |
|---|---|
| RuntimeError | If estimator has not been fitted. |

Source code in vsax/vfa/applications.py
def evaluate(self, x_query: jnp.ndarray) -> jnp.ndarray:
    """Evaluate density at query points.

    Args:
        x_query: Array of query points (shape: (n_queries,)).

    Returns:
        Estimated density values at query points (shape: (n_queries,)).

    Raises:
        RuntimeError: If estimator has not been fitted.
    """
    if not self._fitted:
        raise RuntimeError(
            "DensityEstimator must be fitted before evaluation. Call fit() first."
        )

    assert self._density_hv is not None  # Set by fit()
    return self.vfa.evaluate_batch(self._density_hv, x_query)

sample(key, n_samples=1)

Sample from the estimated density (not implemented).

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| key | Array | JAX random key. | required |
| n_samples | int | Number of samples to generate. | 1 |

Returns:

| Type | Description |
|---|---|
| ndarray | Samples from the density. |

Raises:

| Type | Description |
|---|---|
| NotImplementedError | Sampling not yet implemented. |

Source code in vsax/vfa/applications.py
def sample(self, key: jax.Array, n_samples: int = 1) -> jnp.ndarray:
    """Sample from the estimated density (not implemented).

    Args:
        key: JAX random key.
        n_samples: Number of samples to generate.

    Returns:
        Samples from the density.

    Raises:
        NotImplementedError: Sampling not yet implemented.
    """
    raise NotImplementedError(
        "Sampling from VFA density estimates is not yet implemented. "
        "This would require inverse transform sampling or MCMC."
    )
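sample() is not implemented, and its message points to inverse transform sampling. Below is a minimal user-side sketch of that idea, assuming you already have real-valued density values tabulated on a sorted grid, for example from estimator.evaluate(x_grid). The helper sample_from_grid_density is hypothetical and is not part of vsax. Negative values are clipped because an RKHS approximation is not guaranteed to stay non-negative.

```python
import jax
import jax.numpy as jnp


def sample_from_grid_density(key, x_grid, density, n_samples):
    """Inverse-transform sampling from density values tabulated on an increasing grid."""
    # Clip small negative excursions before treating the values as a density
    density = jnp.maximum(density, 0.0)
    # Cumulative trapezoidal integral, starting at 0 at the first grid point
    increments = 0.5 * (density[1:] + density[:-1]) * jnp.diff(x_grid)
    cdf = jnp.concatenate([jnp.zeros(1), jnp.cumsum(increments)])
    cdf = cdf / cdf[-1]
    # Invert the piecewise-linear CDF at uniform draws
    u = jax.random.uniform(key, (n_samples,))
    return jnp.interp(u, cdf, x_grid)


# Usage with a standard normal density tabulated on [-5, 5]
x_grid = jnp.linspace(-5.0, 5.0, 401)
density = jnp.exp(-0.5 * x_grid ** 2) / jnp.sqrt(2 * jnp.pi)
draws = sample_from_grid_density(jax.random.PRNGKey(0), x_grid, density, 20_000)
print(float(draws.mean()), float(draws.std()))
```

The grid must cover the support of the density. Any mass outside it cannot be sampled.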

NonlinearRegressor

vsax.vfa.applications.NonlinearRegressor

Nonlinear regression using VFA (Frady et al. 2021 §7.2.2).

Fits nonlinear functions to (x, y) data using vector function architecture. Provides a scikit-learn-like interface for regression tasks.

Attributes:

| Name | Description |
|---|---|
| vfa | VectorFunctionEncoder for function encoding. |
| regularization | Regularization parameter for least squares. |
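This page documents regularization as a "Regularization parameter for least squares" that fit() forwards to encode_function_1d. The exact solver is not shown here. The conceptual sketch below assumes it plays the role of a ridge term λ. It solves (K + λI)α = y for the Gram matrix K of the fractional-power encodings, then bundles f_hv = Σ α_i z^{x_i}. All helper names are hypothetical, and this is not vsax's implementation.

```python
import jax
import jax.numpy as jnp


def encode_points(phases, x):
    # Row j is the fractional power z**x[j] = exp(i * phases * x[j])
    return jnp.exp(1j * jnp.outer(x, phases))


def fit_ridge_hv(phases, x_train, y_train, lam):
    """Solve (K + lam*I) alpha = y, then bundle f_hv = sum_i alpha_i z**x_i."""
    z = encode_points(phases, x_train)                  # (n, d)
    gram = jnp.real(z @ z.conj().T) / phases.shape[0]   # (n, n) kernel matrix
    alpha = jnp.linalg.solve(gram + lam * jnp.eye(x_train.shape[0]), y_train)
    return alpha @ z                                    # (d,) function hypervector


def evaluate_hv(phases, f_hv, x_query):
    # f(x) = Re<f_hv, z**x> / d = sum_i alpha_i K(x_i, x)
    zq = encode_points(phases, x_query)
    return jnp.real(zq.conj() @ f_hv) / phases.shape[0]


key = jax.random.PRNGKey(0)
phases = jax.random.uniform(key, (2048,), minval=-jnp.pi, maxval=jnp.pi)
x_train = jnp.linspace(0.0, 10.0, 15)
y_train = jnp.sin(x_train)
f_hv = fit_ridge_hv(phases, x_train, y_train, lam=1e-2)
y_fit = evaluate_hv(phases, f_hv, x_train)
```

At the training points the fit equals y − λα, so larger λ trades training accuracy for a smaller, better-conditioned coefficient vector.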

Example

import jax
import jax.numpy as jnp
from vsax import create_fhrr_model, VSAMemory
from vsax.vfa import NonlinearRegressor

# Create regressor
key = jax.random.PRNGKey(42)
model = create_fhrr_model(dim=1024, key=key)
memory = VSAMemory(model)
regressor = NonlinearRegressor(model, memory)

# Fit to nonlinear data
x_train = jnp.linspace(0, 10, 50)
y_train = jnp.sin(x_train) + 0.1 * jax.random.normal(key, (50,))
regressor.fit(x_train, y_train)

# Predict
x_test = jnp.linspace(0, 10, 100)
y_pred = regressor.predict(x_test)

Source code in vsax/vfa/applications.py
class NonlinearRegressor:
    """Nonlinear regression using VFA (Frady et al. 2021 §7.2.2).

    Fits nonlinear functions to (x, y) data using vector function architecture.
    Provides a scikit-learn-like interface for regression tasks.

    Attributes:
        vfa: VectorFunctionEncoder for function encoding.
        regularization: Regularization parameter for least squares.

    Example:
        >>> import jax
        >>> import jax.numpy as jnp
        >>> from vsax import create_fhrr_model, VSAMemory
        >>> from vsax.vfa import NonlinearRegressor
        >>>
        >>> # Create regressor
        >>> key = jax.random.PRNGKey(42)
        >>> model = create_fhrr_model(dim=1024, key=key)
        >>> memory = VSAMemory(model)
        >>> regressor = NonlinearRegressor(model, memory)
        >>>
        >>> # Fit to nonlinear data
        >>> x_train = jnp.linspace(0, 10, 50)
        >>> y_train = jnp.sin(x_train) + 0.1 * jax.random.normal(key, (50,))
        >>> regressor.fit(x_train, y_train)
        >>>
        >>> # Predict
        >>> x_test = jnp.linspace(0, 10, 100)
        >>> y_pred = regressor.predict(x_test)
    """

    def __init__(
        self,
        model: VSAModel,
        memory: VSAMemory,
        regularization: float = 1e-6,
        kernel_config: Optional[KernelConfig] = None,
    ) -> None:
        """Initialize NonlinearRegressor.

        Args:
            model: VSAModel instance (must use ComplexHypervector).
            memory: VSAMemory for storing basis vectors.
            regularization: Regularization parameter (default: 1e-6).
            kernel_config: Optional KernelConfig for VFA basis.

        Raises:
            TypeError: If model doesn't use ComplexHypervector.
            ValueError: If regularization is negative.
        """
        if regularization < 0:
            raise ValueError(f"regularization must be non-negative, got {regularization}")

        self.vfa = VectorFunctionEncoder(model, memory, kernel_config)
        self.regularization = regularization
        self._function_hv: Optional[ComplexHypervector] = None
        self._fitted = False

    def fit(self, x_train: jnp.ndarray, y_train: jnp.ndarray) -> None:
        """Fit regressor to training data.

        Args:
            x_train: Training input values (shape: (n_samples,)).
            y_train: Training target values (shape: (n_samples,)).

        Raises:
            ValueError: If x_train and y_train have different lengths.
        """
        self._function_hv = self.vfa.encode_function_1d(
            x_train, y_train, regularization=self.regularization
        )
        self._fitted = True

    def predict(self, x_test: jnp.ndarray) -> jnp.ndarray:
        """Predict target values for test inputs.

        Args:
            x_test: Test input values (shape: (n_test,)).

        Returns:
            Predicted target values (shape: (n_test,)).

        Raises:
            RuntimeError: If regressor has not been fitted.
        """
        if not self._fitted:
            raise RuntimeError(
                "NonlinearRegressor must be fitted before prediction. Call fit() first."
            )

        assert self._function_hv is not None  # Set by fit()
        return self.vfa.evaluate_batch(self._function_hv, x_test)

    def score(self, x_test: jnp.ndarray, y_test: jnp.ndarray) -> float:
        """Compute R² score on test data.

        Args:
            x_test: Test input values (shape: (n_test,)).
            y_test: True target values (shape: (n_test,)).

        Returns:
            R² score (1.0 is perfect, 0.0 is baseline, negative is worse than baseline).

        Raises:
            RuntimeError: If regressor has not been fitted.
        """
        if not self._fitted:
            raise RuntimeError(
                "NonlinearRegressor must be fitted before scoring. Call fit() first."
            )

        y_pred = self.predict(x_test)

        # R² = 1 - (SS_res / SS_tot)
        ss_res = jnp.sum((y_test - y_pred) ** 2)
        ss_tot = jnp.sum((y_test - jnp.mean(y_test)) ** 2)

        r2 = 1.0 - (ss_res / ss_tot)
        return float(r2)

Functions

__init__(model, memory, regularization=1e-06, kernel_config=None)

Initialize NonlinearRegressor.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model | VSAModel | VSAModel instance (must use ComplexHypervector). | required |
| memory | VSAMemory | VSAMemory for storing basis vectors. | required |
| regularization | float | Regularization parameter. | 1e-06 |
| kernel_config | Optional[KernelConfig] | Optional KernelConfig for VFA basis. | None |

Raises:

| Type | Description |
|---|---|
| TypeError | If model doesn't use ComplexHypervector. |
| ValueError | If regularization is negative. |

Source code in vsax/vfa/applications.py
def __init__(
    self,
    model: VSAModel,
    memory: VSAMemory,
    regularization: float = 1e-6,
    kernel_config: Optional[KernelConfig] = None,
) -> None:
    """Initialize NonlinearRegressor.

    Args:
        model: VSAModel instance (must use ComplexHypervector).
        memory: VSAMemory for storing basis vectors.
        regularization: Regularization parameter (default: 1e-6).
        kernel_config: Optional KernelConfig for VFA basis.

    Raises:
        TypeError: If model doesn't use ComplexHypervector.
        ValueError: If regularization is negative.
    """
    if regularization < 0:
        raise ValueError(f"regularization must be non-negative, got {regularization}")

    self.vfa = VectorFunctionEncoder(model, memory, kernel_config)
    self.regularization = regularization
    self._function_hv: Optional[ComplexHypervector] = None
    self._fitted = False

fit(x_train, y_train)

Fit regressor to training data.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x_train | ndarray | Training input values (shape: (n_samples,)). | required |
| y_train | ndarray | Training target values (shape: (n_samples,)). | required |

Raises:

| Type | Description |
|---|---|
| ValueError | If x_train and y_train have different lengths. |

Source code in vsax/vfa/applications.py
def fit(self, x_train: jnp.ndarray, y_train: jnp.ndarray) -> None:
    """Fit regressor to training data.

    Args:
        x_train: Training input values (shape: (n_samples,)).
        y_train: Training target values (shape: (n_samples,)).

    Raises:
        ValueError: If x_train and y_train have different lengths.
    """
    self._function_hv = self.vfa.encode_function_1d(
        x_train, y_train, regularization=self.regularization
    )
    self._fitted = True

predict(x_test)

Predict target values for test inputs.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x_test | ndarray | Test input values (shape: (n_test,)). | required |

Returns:

| Type | Description |
|---|---|
| ndarray | Predicted target values (shape: (n_test,)). |

Raises:

| Type | Description |
|---|---|
| RuntimeError | If regressor has not been fitted. |

Source code in vsax/vfa/applications.py
def predict(self, x_test: jnp.ndarray) -> jnp.ndarray:
    """Predict target values for test inputs.

    Args:
        x_test: Test input values (shape: (n_test,)).

    Returns:
        Predicted target values (shape: (n_test,)).

    Raises:
        RuntimeError: If regressor has not been fitted.
    """
    if not self._fitted:
        raise RuntimeError(
            "NonlinearRegressor must be fitted before prediction. Call fit() first."
        )

    assert self._function_hv is not None  # Set by fit()
    return self.vfa.evaluate_batch(self._function_hv, x_test)

score(x_test, y_test)

Compute R² score on test data.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x_test | ndarray | Test input values (shape: (n_test,)). | required |
| y_test | ndarray | True target values (shape: (n_test,)). | required |

Returns:

| Type | Description |
|---|---|
| float | R² score: 1.0 is a perfect fit, 0.0 matches always predicting the mean of y_test, and negative values are worse than that baseline. |

Raises:

| Type | Description |
|---|---|
| RuntimeError | If regressor has not been fitted. |

Source code in vsax/vfa/applications.py
def score(self, x_test: jnp.ndarray, y_test: jnp.ndarray) -> float:
    """Compute R² score on test data.

    Args:
        x_test: Test input values (shape: (n_test,)).
        y_test: True target values (shape: (n_test,)).

    Returns:
        R² score (1.0 is perfect, 0.0 is baseline, negative is worse than baseline).

    Raises:
        RuntimeError: If regressor has not been fitted.
    """
    if not self._fitted:
        raise RuntimeError(
            "NonlinearRegressor must be fitted before scoring. Call fit() first."
        )

    y_pred = self.predict(x_test)

    # R² = 1 - (SS_res / SS_tot)
    ss_res = jnp.sum((y_test - y_pred) ** 2)
    ss_tot = jnp.sum((y_test - jnp.mean(y_test)) ** 2)

    r2 = 1.0 - (ss_res / ss_tot)
    return float(r2)
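score() divides SS_res by SS_tot. If y_test is constant, SS_tot is zero, and JAX floating-point division then yields -inf or nan instead of raising. The standalone variant below computes the same R² = 1 − SS_res/SS_tot but rejects that degenerate case explicitly. The name r2_score is hypothetical here and is not a vsax function.

```python
import jax.numpy as jnp


def r2_score(y_true: jnp.ndarray, y_pred: jnp.ndarray) -> float:
    """Coefficient of determination R^2 = 1 - SS_res / SS_tot."""
    ss_res = jnp.sum((y_true - y_pred) ** 2)
    ss_tot = jnp.sum((y_true - jnp.mean(y_true)) ** 2)
    # Guard the case score() leaves to floating-point semantics
    if float(ss_tot) == 0.0:
        raise ValueError("R^2 is undefined when y_true is constant")
    return float(1.0 - ss_res / ss_tot)


y = jnp.array([1.0, 2.0, 3.0, 4.0])
print(r2_score(y, y))                              # 1.0
print(r2_score(y, jnp.full_like(y, jnp.mean(y))))  # 0.0
```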

ImageProcessor

vsax.vfa.applications.ImageProcessor

Image processing using VFA (Frady et al. 2021 §7.1).

Encodes 2D images as vector functions f(x, y) that map coordinates to pixel values. The current implementation flattens each image into a 1D function and supports encoding, decoding, and blending. Spatial shifting is planned but not yet implemented.

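Attributes:

| Name | Description |
|---|---|
| vfa | VectorFunctionEncoder for function encoding. |
| image_shape | Shape (height, width) of the most recently encoded image. In the source shown below it is stored in the private attribute _image_shape, which is set by encode(). |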

Example

import jax
import jax.numpy as jnp
from vsax import create_fhrr_model, VSAMemory
from vsax.vfa import ImageProcessor

# Create processor
key = jax.random.PRNGKey(42)
model = create_fhrr_model(dim=2048, key=key)
memory = VSAMemory(model)
processor = ImageProcessor(model, memory)

# Encode image
image = jnp.zeros((32, 32))
image = image.at[10:20, 10:20].set(1.0)  # White square
image_hv = processor.encode(image)

# Decode back to pixels (shift() currently raises NotImplementedError)
reconstructed = processor.decode(image_hv, shape=(32, 32))

Source code in vsax/vfa/applications.py
class ImageProcessor:
    """Image processing using VFA (Frady et al. 2021 §7.1).

    Encodes 2D images as vector functions and supports spatial transformations.
    Images are represented as functions f(x, y) mapping coordinates to pixel values.

    Attributes:
        vfa: VectorFunctionEncoder for function encoding.
        image_shape: Shape of the encoded image (height, width).

    Example:
        >>> import jax
        >>> import jax.numpy as jnp
        >>> from vsax import create_fhrr_model, VSAMemory
        >>> from vsax.vfa import ImageProcessor
        >>>
        >>> # Create processor
        >>> key = jax.random.PRNGKey(42)
        >>> model = create_fhrr_model(dim=2048, key=key)
        >>> memory = VSAMemory(model)
        >>> processor = ImageProcessor(model, memory)
        >>>
        >>> # Encode image
        >>> image = jnp.zeros((32, 32))
        >>> image = image.at[10:20, 10:20].set(1.0)  # White square
        >>> image_hv = processor.encode(image)
        >>>
        >>> # Decode back to pixels (shift() currently raises NotImplementedError)
        >>> reconstructed = processor.decode(image_hv, shape=(32, 32))
    """

    def __init__(
        self,
        model: VSAModel,
        memory: VSAMemory,
        kernel_config: Optional[KernelConfig] = None,
    ) -> None:
        """Initialize ImageProcessor.

        Args:
            model: VSAModel instance (must use ComplexHypervector).
            memory: VSAMemory for storing basis vectors.
            kernel_config: Optional KernelConfig for VFA basis.

        Raises:
            TypeError: If model doesn't use ComplexHypervector.
        """
        self.vfa = VectorFunctionEncoder(model, memory, kernel_config)
        self._image_hv: Optional[ComplexHypervector] = None
        self._image_shape: Optional[tuple[int, int]] = None
        self._encoded = False

    def encode(self, image: jnp.ndarray) -> ComplexHypervector:
        """Encode a 2D image as a hypervector.

        Note: Currently implements a simplified version that flattens the image
        and encodes it as a 1D function. Full 2D encoding would require
        multi-dimensional VFA which is planned for future work.

        Args:
            image: 2D array representing the image (shape: (height, width)).

        Returns:
            Encoded image hypervector.

        Raises:
            ValueError: If image is not 2D.
        """
        if image.ndim != 2:
            raise ValueError(f"image must be 2D array, got shape {image.shape}")

        self._image_shape = image.shape
        height, width = image.shape

        # Flatten image for 1D encoding (simplified approach)
        # Full 2D encoding would use multi-dimensional VFA
        x_coords = jnp.arange(height * width, dtype=jnp.float32)
        pixel_values = image.flatten()

        self._image_hv = self.vfa.encode_function_1d(x_coords, pixel_values)
        self._encoded = True

        return self._image_hv

    def decode(
        self,
        image_hv: ComplexHypervector,
        shape: tuple[int, int],
    ) -> jnp.ndarray:
        """Decode hypervector back to image.

        Args:
            image_hv: Encoded image hypervector.
            shape: Output image shape (height, width).

        Returns:
            Reconstructed 2D image array.
        """
        height, width = shape

        # Evaluate at grid points
        x_coords = jnp.arange(height * width, dtype=jnp.float32)
        pixel_values = self.vfa.evaluate_batch(image_hv, x_coords)

        # Reshape to image
        return pixel_values.reshape(height, width)

    def shift(self, dx: float = 0.0, dy: float = 0.0) -> ComplexHypervector:
        """Shift the encoded image.

        Note: Current simplified implementation shifts in the flattened space.
        Full 2D shifting would require multi-dimensional VFA.

        Args:
            dx: Horizontal shift amount (not used in current simplified version).
            dy: Vertical shift amount (not used in current simplified version).

        Returns:
            Shifted image hypervector.

        Raises:
            RuntimeError: If no image has been encoded.
            NotImplementedError: 2D shifting not yet implemented.
        """
        if not self._encoded:
            raise RuntimeError(
                "ImageProcessor must encode an image before shifting. Call encode() first."
            )

        # For proper 2D shifting, we would need multi-dimensional VFA
        raise NotImplementedError(
            "2D image shifting requires multi-dimensional VFA which is planned "
            "for future work. Current implementation only supports 1D flattened encoding."
        )

    def blend(
        self,
        image_hv1: ComplexHypervector,
        image_hv2: ComplexHypervector,
        alpha: float = 0.5,
    ) -> ComplexHypervector:
        """Blend two encoded images.

        Args:
            image_hv1: First encoded image.
            image_hv2: Second encoded image.
            alpha: Blending factor (0.0 = all image1, 1.0 = all image2).

        Returns:
            Blended image hypervector.

        Raises:
            ValueError: If alpha is not in [0, 1].
        """
        if not (0.0 <= alpha <= 1.0):
            raise ValueError(f"alpha must be in [0, 1], got {alpha}")

        # Blend using function addition
        return self.vfa.add_functions(image_hv1, image_hv2, alpha=(1.0 - alpha), beta=alpha)

Functions

__init__(model, memory, kernel_config=None)

Initialize ImageProcessor.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model | VSAModel | VSAModel instance (must use ComplexHypervector). | required |
| memory | VSAMemory | VSAMemory for storing basis vectors. | required |
| kernel_config | Optional[KernelConfig] | Optional KernelConfig for VFA basis. | None |

Raises:

| Type | Description |
|---|---|
| TypeError | If model doesn't use ComplexHypervector. |

Source code in vsax/vfa/applications.py
def __init__(
    self,
    model: VSAModel,
    memory: VSAMemory,
    kernel_config: Optional[KernelConfig] = None,
) -> None:
    """Initialize ImageProcessor.

    Args:
        model: VSAModel instance (must use ComplexHypervector).
        memory: VSAMemory for storing basis vectors.
        kernel_config: Optional KernelConfig for VFA basis.

    Raises:
        TypeError: If model doesn't use ComplexHypervector.
    """
    self.vfa = VectorFunctionEncoder(model, memory, kernel_config)
    self._image_hv: Optional[ComplexHypervector] = None
    self._image_shape: Optional[tuple[int, int]] = None
    self._encoded = False

encode(image)

Encode a 2D image as a hypervector.

Note: This is currently a simplified version. It flattens the image in row-major order and encodes it as a 1D function, so pixel (row, col) is stored at coordinate row · width + col. Full 2D encoding would require multi-dimensional VFA, which is planned for future work.
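To see what the flattened encoding does to image geometry, the sketch below reproduces the coordinate assignment in plain JAX. It uses image.flatten() paired with jnp.arange(height * width), as in encode(). The helpers flat_coordinate and pixel_position are hypothetical names used only for illustration. Horizontal neighbours end up one unit apart on the 1D axis, while vertical neighbours end up width units apart.

```python
import jax.numpy as jnp

height, width = 4, 5
# Each pixel value equals its own row-major index, to make the mapping visible
image = jnp.arange(height * width, dtype=jnp.float32).reshape(height, width)
flat = image.flatten()  # row-major flattening, as encode() does


def flat_coordinate(row: int, col: int, width: int) -> int:
    # 1D coordinate that the flattened encoding assigns to pixel (row, col)
    return row * width + col


def pixel_position(coord: int, width: int) -> tuple[int, int]:
    # Inverse mapping: 1D coordinate back to (row, col)
    return divmod(coord, width)


row, col = 2, 3
idx = flat_coordinate(row, col, width)
print(bool(flat[idx] == image[row, col]))          # True
print(pixel_position(idx, width))                  # (2, 3)
print(flat_coordinate(row, col + 1, width) - idx)  # 1: horizontal neighbour
print(flat_coordinate(row + 1, col, width) - idx)  # 5: vertical neighbour is `width` away
```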

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| image | ndarray | 2D array representing the image (shape: (height, width)). | required |

Returns:

| Type | Description |
|---|---|
| ComplexHypervector | Encoded image hypervector. |

Raises:

| Type | Description |
|---|---|
| ValueError | If image is not 2D. |

Source code in vsax/vfa/applications.py
def encode(self, image: jnp.ndarray) -> ComplexHypervector:
    """Encode a 2D image as a hypervector.

    Note: Currently implements a simplified version that flattens the image
    and encodes it as a 1D function. Full 2D encoding would require
    multi-dimensional VFA which is planned for future work.

    Args:
        image: 2D array representing the image (shape: (height, width)).

    Returns:
        Encoded image hypervector.

    Raises:
        ValueError: If image is not 2D.
    """
    if image.ndim != 2:
        raise ValueError(f"image must be 2D array, got shape {image.shape}")

    self._image_shape = image.shape
    height, width = image.shape

    # Flatten image for 1D encoding (simplified approach)
    # Full 2D encoding would use multi-dimensional VFA
    x_coords = jnp.arange(height * width, dtype=jnp.float32)
    pixel_values = image.flatten()

    self._image_hv = self.vfa.encode_function_1d(x_coords, pixel_values)
    self._encoded = True

    return self._image_hv

decode(image_hv, shape)

Decode hypervector back to image.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| image_hv | ComplexHypervector | Encoded image hypervector. | required |
| shape | tuple[int, int] | Output image shape (height, width). | required |

Returns:

| Type | Description |
|---|---|
| ndarray | Reconstructed 2D image array. |

Source code in vsax/vfa/applications.py
def decode(
    self,
    image_hv: ComplexHypervector,
    shape: tuple[int, int],
) -> jnp.ndarray:
    """Decode hypervector back to image.

    Args:
        image_hv: Encoded image hypervector.
        shape: Output image shape (height, width).

    Returns:
        Reconstructed 2D image array.
    """
    height, width = shape

    # Evaluate at grid points
    x_coords = jnp.arange(height * width, dtype=jnp.float32)
    pixel_values = self.vfa.evaluate_batch(image_hv, x_coords)

    # Reshape to image
    return pixel_values.reshape(height, width)

shift(dx=0.0, dy=0.0)

Shift the encoded image.

Note: Not yet implemented. If an image has been encoded, this method raises NotImplementedError, because 2D shifting requires multi-dimensional VFA. dx and dy are currently ignored.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| dx | float | Horizontal shift amount (currently ignored). | 0.0 |
| dy | float | Vertical shift amount (currently ignored). | 0.0 |

Returns:

| Type | Description |
|---|---|
| ComplexHypervector | Shifted image hypervector, once implemented. The method currently never returns. |

Raises:

| Type | Description |
|---|---|
| RuntimeError | If no image has been encoded. |
| NotImplementedError | 2D shifting not yet implemented. |

Source code in vsax/vfa/applications.py
def shift(self, dx: float = 0.0, dy: float = 0.0) -> ComplexHypervector:
    """Shift the encoded image.

    Note: Current simplified implementation shifts in the flattened space.
    Full 2D shifting would require multi-dimensional VFA.

    Args:
        dx: Horizontal shift amount (not used in current simplified version).
        dy: Vertical shift amount (not used in current simplified version).

    Returns:
        Shifted image hypervector.

    Raises:
        RuntimeError: If no image has been encoded.
        NotImplementedError: 2D shifting not yet implemented.
    """
    if not self._encoded:
        raise RuntimeError(
            "ImageProcessor must encode an image before shifting. Call encode() first."
        )

    # For proper 2D shifting, we would need multi-dimensional VFA
    raise NotImplementedError(
        "2D image shifting requires multi-dimensional VFA which is planned "
        "for future work. Current implementation only supports 1D flattened encoding."
    )
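The VFA supports function shifting, f(x − shift), for 1D functions. The hypothetical sketch below shows one way a multi-dimensional version could make shift() possible. It is not vsax's API. It assumes one independent phase vector per image axis, with a point (x, y) encoded by binding (elementwise multiplication) z_x^x with z_y^y. Shifting then reduces to a single binding with the encoding of (dx, dy).

```python
import jax
import jax.numpy as jnp

dim = 2048
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
# Hypothetical: one independent set of phases per image axis
phases_x = jax.random.uniform(key_x, (dim,), minval=-jnp.pi, maxval=jnp.pi)
phases_y = jax.random.uniform(key_y, (dim,), minval=-jnp.pi, maxval=jnp.pi)


def encode_point(x, y):
    # Binding z_x**x with z_y**y (elementwise product) gives exp(i(phases_x*x + phases_y*y))
    return jnp.exp(1j * (phases_x * x + phases_y * y))


def evaluate(f_hv, x, y):
    # Read out f(x, y) as Re<f_hv, z(x, y)> / d
    return jnp.real(jnp.vdot(encode_point(x, y), f_hv)) / dim


def shift_2d(f_hv, dx, dy):
    # Binding with z(dx, dy) moves every term from (x_p, y_p) to (x_p + dx, y_p + dy)
    return f_hv * encode_point(dx, dy)


# Toy "image": three weighted pixels superposed into one hypervector
pixels = [(1.0, 2.0, 0.5), (3.0, 1.0, 1.0), (4.0, 4.0, -0.7)]
f_hv = sum(v * encode_point(x, y) for x, y, v in pixels)
g_hv = shift_2d(f_hv, dx=5.0, dy=-2.0)

print(float(evaluate(f_hv, 3.0, 1.0)))   # close to 1.0
print(float(evaluate(g_hv, 8.0, -1.0)))  # same value, read at the shifted location
```

Because the binding is a phase rotation, reading g_hv at (x + dx, y + dy) reproduces f_hv at (x, y) up to floating-point rounding.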

blend(image_hv1, image_hv2, alpha=0.5)

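Blend two encoded images by linear interpolation in function space. The result encodes (1 − alpha)·f₁ + alpha·f₂ and is computed with VectorFunctionEncoder.add_functions.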

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| image_hv1 | ComplexHypervector | First encoded image. | required |
| image_hv2 | ComplexHypervector | Second encoded image. | required |
| alpha | float | Blending factor (0.0 = all image1, 1.0 = all image2). | 0.5 |

Returns:

| Type | Description |
|---|---|
| ComplexHypervector | Blended image hypervector. |

Raises:

| Type | Description |
|---|---|
| ValueError | If alpha is not in [0, 1]. |

Source code in vsax/vfa/applications.py
def blend(
    self,
    image_hv1: ComplexHypervector,
    image_hv2: ComplexHypervector,
    alpha: float = 0.5,
) -> ComplexHypervector:
    """Blend two encoded images.

    Args:
        image_hv1: First encoded image.
        image_hv2: Second encoded image.
        alpha: Blending factor (0.0 = all image1, 1.0 = all image2).

    Returns:
        Blended image hypervector.

    Raises:
        ValueError: If alpha is not in [0, 1].
    """
    if not (0.0 <= alpha <= 1.0):
        raise ValueError(f"alpha must be in [0, 1], got {alpha}")

    # Blend using function addition
    return self.vfa.add_functions(image_hv1, image_hv2, alpha=(1.0 - alpha), beta=alpha)
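blend() relies on function arithmetic (αf + βg) being linear. A weighted sum of two function hypervectors decodes to the same weighted sum of the functions' values. The standalone sketch below checks that property with hypothetical helpers (encode_weights, evaluate) built on a uniform-phase basis. It assumes add_functions forms alpha·f + beta·g, as the VFA overview describes. It does not call vsax.

```python
import jax
import jax.numpy as jnp

dim = 1024
key_p, key_a, key_b = jax.random.split(jax.random.PRNGKey(1), 3)
phases = jax.random.uniform(key_p, (dim,), minval=-jnp.pi, maxval=jnp.pi)
x_pts = jnp.arange(8, dtype=jnp.float32)
basis = jnp.exp(1j * jnp.outer(x_pts, phases))  # row i is z**x_i


def encode_weights(weights):
    # f_hv = sum_i w_i * z**x_i
    return weights @ basis


def evaluate(f_hv, x):
    # Read out f(x) as Re<f_hv, z**x> / d
    return jnp.real(jnp.vdot(jnp.exp(1j * phases * x), f_hv)) / dim


f1 = encode_weights(jax.random.normal(key_a, (8,)))
f2 = encode_weights(jax.random.normal(key_b, (8,)))
alpha = 0.3
blended = (1.0 - alpha) * f1 + alpha * f2  # the combination blend() requests from add_functions

x = 2.5
print(float(evaluate(blended, x)))
print(float((1.0 - alpha) * evaluate(f1, x) + alpha * evaluate(f2, x)))  # same value
```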