THRML (Thermodynamic Hypergraphical Model Library) is built upon a set of fundamental abstractions that collectively enable the definition, representation, and efficient sampling of probabilistic graphical models (PGMs). These core components facilitate the library's focus on blocked Gibbs sampling and energy-based models, especially within a JAX-accelerated environment. Understanding these building blocks is key to grasping how THRML operates, translating complex graph structures and probabilistic relationships into performant computational graphs.
At the most granular level, an AbstractNode represents an individual random variable within a PGM. All variables in THRML, regardless of their specific type or behavior, must inherit from this base class. Each node is assigned a unique identifier upon creation, which is crucial for internal bookkeeping and graph traversal.
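The node abstraction can be pictured with a short sketch. The class names follow the text, but the implementation below is illustrative only, not THRML's actual API:

```python
# Illustrative sketch only; class names follow the text but this is not
# THRML's actual implementation.
import itertools

class AbstractNode:
    """Base class for a single random variable in a PGM."""
    _ids = itertools.count()  # process-wide counter

    def __init__(self):
        # Every node gets a unique identifier upon creation, used for
        # internal bookkeeping and graph traversal.
        self.uid = next(AbstractNode._ids)

class SpinNode(AbstractNode):
    """A binary variable taking values in {-1, +1}."""

a, b = SpinNode(), SpinNode()
print(a.uid, b.uid)  # two distinct identifiers
```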
A SpinNode is an AbstractNode representing a binary random variable that can take on one of two states, typically $\{-1, 1\}$. Internally, these are often represented by True and False boolean values.

These specialized node types allow THRML to handle heterogeneous graphical models, where different parts of the graph may consist of different kinds of variables.
A Block serves as a collection of AbstractNodes. Its primary purpose is to group nodes that can be efficiently sampled or processed together in a JAX-friendly, Single Instruction, Multiple Data (SIMD) manner. A fundamental constraint of a Block is that all nodes within it must be of the same AbstractNode subclass (e.g., all SpinNodes or all CategoricalNodes). This homogeneity is essential for parallelizing computations.
Blocks are the fundamental units for blocked Gibbs sampling, allowing large sections of the graph to be updated in parallel given the states of their neighbors.
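A minimal sketch of the homogeneity constraint, using hypothetical stand-in classes rather than THRML's real ones:

```python
# Hypothetical sketch of a homogeneous Block; THRML's real class differs.
class SpinNode:
    """Binary variable in {-1, +1}."""

class CategoricalNode:
    """Variable over a finite set of categories."""

class Block:
    """A group of same-type nodes that can be updated in one SIMD pass."""
    def __init__(self, nodes):
        if not nodes:
            raise ValueError("a Block must contain at least one node")
        node_type = type(nodes[0])
        # Homogeneity constraint: every node must be the exact same subclass
        # so the whole block can be updated by one vectorized kernel.
        if any(type(n) is not node_type for n in nodes):
            raise TypeError("all nodes in a Block must share one node class")
        self.nodes = tuple(nodes)
        self.node_type = node_type

spins = Block([SpinNode() for _ in range(4)])
try:
    Block([SpinNode(), CategoricalNode()])
    mixed_allowed = True
except TypeError:
    mixed_allowed = False
print(spins.node_type.__name__, mixed_allowed)
```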
The BlockSpec is a crucial component for managing the translation between user-friendly block-local state representations and the optimized "global" state representation favored by JAX for array-level parallelism. When working with complex PGMs and many Blocks, the states of individual blocks (lists of PyTrees where each PyTree leaf corresponds to a node's state) need to be efficiently accessed.
By handling this translation at the array level, the BlockSpec minimizes Python loops and maximizes array operations, directly supporting THRML's goal of GPU-accelerated performance. It also differentiates between "free" blocks (whose states are actively sampled) and "clamped" blocks (whose states remain fixed), which is vital for conditional sampling tasks.
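The idea of packing block-local states into one global representation can be illustrated as follows. The dictionary-and-slices scheme is an assumption for exposition (with numpy standing in for jax.numpy), not BlockSpec's actual layout:

```python
import numpy as np

# Illustrative packing scheme (not THRML's API): gather per-block states
# into one flat global array so updates become array ops, not Python loops.
block_states = {"spins_a": np.array([1, -1, 1]),
                "spins_b": np.array([-1, -1])}
clamped = {"spins_b"}  # clamped blocks keep their states fixed

# Record each block's slice into the global state vector.
offsets, total = {}, 0
for name, s in block_states.items():
    offsets[name] = slice(total, total + s.size)
    total += s.size

global_state = np.empty(total, dtype=int)
for name, s in block_states.items():
    global_state[offsets[name]] = s

# A "free" block's slice can be overwritten in place by a sampler;
# clamped slices are simply never written.
free = [n for n in block_states if n not in clamped]
print(free, global_state)
```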
AbstractFactor is the base class for defining probabilistic interactions or relationships within a PGM. Conceptually, a factor represents a "batch" of parallel interactions defined over groups of nodes. For instance, in an Ising model, a factor might define all the pairwise interactions between adjacent nodes in a grid.
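For example, the pairwise couplings of a small Ising chain can be treated as one batched interaction and evaluated with a single vectorized expression. The arrays and names below are illustrative (numpy stands in for jax.numpy):

```python
import numpy as np

# One "batch of parallel interactions": all pairwise couplings of a small
# Ising chain, evaluated with a vectorized gather instead of a Python loop.
spins = np.array([1, -1, -1, 1])            # node states
edges = np.array([[0, 1], [1, 2], [2, 3]])  # one row per interaction
weights = np.array([0.5, -1.0, 0.25])       # leading dim == batch of edges

# Per-edge energy contribution: -w_ij * s_i * s_j
pair_energy = -weights * spins[edges[:, 0]] * spins[edges[:, 1]]
total_energy = pair_energy.sum()
print(total_energy)
```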
A parametric factor is an AbstractFactor that includes learnable parameters, typically a weight tensor (weights). The leading dimension of this weights tensor must align with the batch dimension of the factor, meaning it corresponds to the number of parallel interactions being defined.

The core role of an AbstractFactor is to specify how it contributes to the overall energy function of an Energy-Based Model (EBM) and, crucially, to decompose itself into the InteractionGroups (via to_interaction_groups) that facilitate conditional sampling.
An InteractionGroup provides a directed, computational view of how one set of nodes influences another for the purpose of conditional sampling. It acts as a bridge between the abstract definition of relationships in AbstractFactors and the concrete computational steps required by a sampler.
Each InteractionGroup specifies:
- head_nodes: A Block of nodes whose state updates are affected by this interaction.
- tail_nodes: A list of Blocks whose current states are required to compute the update for the head_nodes. These blocks are typically "neighbors" in the PGM.
- interaction: A PyTree containing static, parametric information (like weights or biases) that defines the nature of the influence. This information is independent of the current state of the nodes.

When a sampler updates a Block of head_nodes, it receives a concise summary of all relevant InteractionGroups, including the values of the tail_nodes and the static interaction parameters. This structured input is critical for vectorized computation.
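A hypothetical rendering of this structure; the field names follow the text, but the concrete types and the field computation are illustrative assumptions:

```python
import numpy as np
from dataclasses import dataclass

# Hypothetical InteractionGroup; field names follow the text, but the real
# THRML types differ in detail. The directed view lets a sampler compute
# the local field on each head node from its tail neighbors.
@dataclass
class InteractionGroup:
    head_nodes: str          # name of the block being updated
    tail_nodes: list         # names of the neighboring blocks
    interaction: np.ndarray  # static parameters, e.g. a coupling matrix

states = {"left": np.array([1.0, -1.0])}
group = InteractionGroup(
    head_nodes="right",
    tail_nodes=["left"],
    interaction=np.array([[0.5, -0.5],
                          [1.0, 0.0],
                          [0.0, 2.0]]),  # shape (heads, tails)
)

# Field on each head node: tail states weighted by the static couplings.
field = group.interaction @ states[group.tail_nodes[0]]
print(field)
```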
The AbstractConditionalSampler defines the core mechanism for updating the state of a Block of nodes. During a Gibbs sampling iteration, a specific AbstractConditionalSampler instance is assigned to each Block that needs to be updated.
Its primary method, sample, takes the following key inputs:
- The relevant InteractionGroups that define the influence from neighboring nodes and static parameters.
- The current states of the neighboring blocks (the tail_nodes of the InteractionGroups).
- An optional sampler_state to maintain memory across sampling steps.

The sampler then computes and returns a new set of states for its target Block. This update rule can be exact (e.g., drawing from a conditional Boltzmann distribution) or approximate (e.g., performing a Metropolis-within-Gibbs step).
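As a concrete illustration, an exact conditional update for a block of spins can be written as follows. This is a hand-rolled stand-in, not THRML's sampler API; it assumes an Ising-style energy so that $P(s=+1 \mid h) = \sigma(2h)$ for local field $h$:

```python
import numpy as np

# Stand-in for an exact conditional sampler (not THRML's API): given the
# local field accumulated from all InteractionGroups, each spin in the block
# is drawn independently from its conditional Boltzmann distribution.
rng = np.random.default_rng(0)

def sample_spin_block(local_field, rng):
    # P(s = +1 | h) = sigmoid(2h) for an Ising-style energy at beta = 1.
    p_up = 1.0 / (1.0 + np.exp(-2.0 * local_field))
    return np.where(rng.random(local_field.shape) < p_up, 1, -1)

field = np.array([5.0, -5.0, 0.0])
new_spins = sample_spin_block(field, rng)
print(new_spins)  # strong fields make the first two spins near-deterministic
```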
A common pattern splits the conditional update into two steps: compute_parameters (e.g., calculating the mean and variance of a Gaussian) and sample_given_parameters (e.g., drawing from the Gaussian using the computed parameters). This pattern is often useful for discrete distributions like Bernoulli or Softmax.

These fundamental components work in concert to define and efficiently sample complex probabilistic graphical models in THRML. AbstractNodes form the basic variables, Blocks group them for parallel updates, and BlockSpec optimizes state management. AbstractFactors define high-level probabilistic relationships, which are then broken down into InteractionGroups to specify computational dependencies for AbstractConditionalSamplers, which finally perform the actual state updates. This modular design, optimized for JAX, allows for flexible model construction and high-performance simulation on accelerators, laying the groundwork for experimentation with novel hardware architectures as described in [THRML: Thermodynamic Hypergraphical Model Library].
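Returning to the two-step pattern (compute_parameters / sample_given_parameters) described above, here is a minimal sketch for a Bernoulli conditional; the function names come from the text, but the signatures and the sigmoid parameterization are assumptions:

```python
import numpy as np

# Sketch of the two-step conditional-sampling pattern; names follow the
# text, signatures are assumptions for illustration.
def compute_parameters(field):
    """Reduce local fields to Bernoulli probabilities via a sigmoid."""
    return 1.0 / (1.0 + np.exp(-field))

def sample_given_parameters(p, rng):
    """Draw 0/1 samples given the computed probabilities."""
    return (rng.random(p.shape) < p).astype(int)

rng = np.random.default_rng(1)
p = compute_parameters(np.array([-10.0, 0.0, 10.0]))
draw = sample_given_parameters(p, rng)
print(p.round(3), draw)
```

Splitting the update this way keeps the deterministic parameter computation separate from the random draw, which makes each half easy to vectorize and test on its own.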