This page provides a detailed API reference for the core modeling components in THRML, including factors, interactions, energy-based models (EBMs), and conditional samplers. These components are essential for defining the probabilistic relationships within your graphical models and for implementing efficient sampling strategies.
thrml/factor.py

AbstractFactor
A factor represents a batch of undirected interactions between sets of random variables.

Attributes:
- node_groups: list[Block]

AbstractFactor.__init__(self, node_groups: list[Block]) -> None
Create a batch of Factors.

AbstractFactor.to_interaction_groups(self) -> list[InteractionGroup]
Compile a factor to a set of directed interactions.
WeightedFactor
A factor that is parameterized by a weight tensor.

Attributes:
- weights: Array

WeightedFactor.__init__(self, weights: Array, node_groups: list[Block]) -> None
Create an instance of WeightedFactor.
FactorSamplingProgram
A sampling program built out of factors.

FactorSamplingProgram.__init__(self, gibbs_spec: BlockGibbsSpec, samplers: list[AbstractConditionalSampler], factors: Sequence[AbstractFactor], other_interaction_groups: list[InteractionGroup]) -> None
Create a FactorSamplingProgram. A thin wrapper over BlockSamplingProgram.
thrml/interaction.py

InteractionGroup
Defines computational dependencies for conditional sampling updates.

Attributes:
- head_nodes: Block
- tail_nodes: list[Block]
- interaction: PyTree

InteractionGroup.__init__(self, interaction: PyTree, head_nodes: Block, tail_nodes: list[Block]) -> None
Create an InteractionGroup.
thrml/models/ebm.py

AbstractEBM
Something that has a well-defined energy function (a map from a state to a scalar).

AbstractEBM.energy(self, state: list[_State], blocks: list[Block]) -> Float[Array, ""]
Evaluate the energy function of the EBM given some state information.
EBMFactor
A factor that defines an energy function.

EBMFactor.energy(self, global_state: list[Array], block_spec: BlockSpec) -> Float[Array, ""]
Evaluate the energy function of the factor.
AbstractFactorizedEBM
An EBM that is made up of factors, i.e., an EBM with an energy function of the form $$\mathcal{E}(x) = \sum_i \mathcal{E}^i(x)$$ where the sum over $i$ runs over factors.

Attributes:
- node_shape_dtypes: _SD

AbstractFactorizedEBM.__init__(self, node_shape_dtypes: _SD = DEFAULT_NODE_SHAPE_DTYPES) -> None
Initializes the AbstractFactorizedEBM with the specified node shape dtypes.

AbstractFactorizedEBM.factors(self) -> list[EBMFactor]
A concrete implementation of this class must define this method, which returns the list of factors that make up the EBM.
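The factorized energy above can be sketched in a few lines of plain Python. The factor functions here are hypothetical stand-ins for illustration, not THRML's actual EBMFactor implementations:

```python
# Sketch of a factorized energy: the total energy is the sum of
# per-factor energies, E(x) = sum_i E^i(x).

def bias_energy(state, b):
    # E(s) = -sum_i b_i * s_i
    return -sum(bi * si for bi, si in zip(b, state))

def pair_energy(state, edges, w):
    # E(s) = -sum_(i,j) w_ij * s_i * s_j
    return -sum(wij * state[i] * state[j] for (i, j), wij in zip(edges, w))

def factorized_energy(state, factors):
    # A factorized EBM just sums the energies of its factors.
    return sum(f(state) for f in factors)

state = [1, -1, 1]  # spin values in {-1, +1}
factors = [
    lambda s: bias_energy(s, [0.5, -0.2, 0.1]),
    lambda s: pair_energy(s, [(0, 1), (1, 2)], [1.0, -0.5]),
]
total = factorized_energy(state, factors)  # -(0.5 + 0.2 + 0.1) - (-1.0 + 0.5)
```

Each factor only needs to see the part of the state it touches, which is what makes block-wise sampling updates cheap.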
FactorizedEBM
An EBM that is defined by a concrete list of factors.

Attributes:
- _factors: list[EBMFactor]

FactorizedEBM.__init__(self, factors: list[EBMFactor], node_shape_dtypes: _SD = DEFAULT_NODE_SHAPE_DTYPES) -> None
Initializes a FactorizedEBM with a list of factors and optional node shape dtypes.

FactorizedEBM.factors(self) -> list[EBMFactor]
Returns the list of factors that define this EBM.
thrml/models/discrete_ebm.py

DiscreteEBMInteraction
An interaction that shows up when sampling from discrete-variable EBMs.

Attributes:
- n_spin: int
- weights: Array

DiscreteEBMFactor
Implements batches of energy function terms of the form s_1 * ... * s_M * W[c_1, ..., c_N], where the s_i are spin variables and the c_i are categorical variables.

Attributes:
- spin_node_groups: list[Block]
- categorical_node_groups: list[Block]
- weights: Array
- is_spin: dict[Type[AbstractNode], bool]

DiscreteEBMFactor.__init__(self, spin_node_groups: list[Block], categorical_node_groups: list[Block], weights: Array) -> None
Create a DiscreteEBMFactor.

DiscreteEBMFactor.to_interaction_groups(self) -> list[InteractionGroup]
Produce interaction groups that implement this factor.

DiscreteEBMFactor.energy(self, global_state: list[Array], block_spec: BlockSpec) -> Float[Array, ""]
Compute the energy associated with this factor.
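A single term of the form s_1 * ... * s_M * W[c_1, ..., c_N] is easy to write down directly. This is a minimal numpy sketch of the formula, not THRML's internals:

```python
import numpy as np

# One discrete EBM energy term: the product of the spin values times
# the weight-tensor entry selected by the categorical indices.

def discrete_term(spins, cats, W):
    spin_product = np.prod(spins) if len(spins) else 1.0
    return spin_product * W[tuple(cats)]

W = np.arange(6.0).reshape(2, 3)   # weight tensor over two categoricals
spins = np.array([1, -1])          # two spin variables
cats = [1, 2]                      # categorical indices into W
e = discrete_term(spins, cats, W)  # 1 * -1 * W[1, 2] = -5.0
```

Batches of such terms (the leading `b` axis in the shapes below) would be evaluated in parallel over the weight tensor's first dimension.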
SquareDiscreteEBMFactor
A discrete factor with a square interaction weight tensor (shape [b, x, x, ..., x]).

SquareDiscreteEBMFactor.__init__(self, spin_node_groups: list[Block], categorical_node_groups: list[Block], weights: Array) -> None
Enforce that the weights are actually square.

SquareDiscreteEBMFactor.to_interaction_groups(self) -> list[InteractionGroup]
Call the parent class to_interaction_groups and merge the results.
SpinEBMFactor
A DiscreteEBMFactor that involves only spin variables.

SpinEBMFactor.__init__(self, node_groups: list[Block], weights: Array) -> None
Initializes a SpinEBMFactor with the specified node groups and weights, assuming only spin variables.

CategoricalEBMFactor
A DiscreteEBMFactor that involves only categorical variables.

CategoricalEBMFactor.__init__(self, node_groups: list[Block], weights: Array) -> None
Initializes a CategoricalEBMFactor with the specified node groups and weights, assuming only categorical variables.

SquareCategoricalEBMFactor
A DiscreteEBMFactor that involves only categorical variables and also has a square weight tensor.

SquareCategoricalEBMFactor.__init__(self, node_groups: list[Block], weights: Array) -> None
Initializes a SquareCategoricalEBMFactor with the specified node groups and weights, assuming only categorical variables and a square weight tensor.
SpinGibbsConditional
A conditional update for spin-valued random variables that performs a Gibbs sampling update given one or more DiscreteEBMInteractions.

SpinGibbsConditional.compute_parameters(self, key: Key, interactions: list[PyTree], active_flags: list[Array], states: list[list[_State]], sampler_state: None, output_sd: PyTree[jax.ShapeDtypeStruct]) -> PyTree
Compute the parameter $\gamma$ of a spin-valued Bernoulli distribution given DiscreteEBMInteractions:
$$\gamma = \sum_i s_1^i \cdots s_K^i \, W^i[x_1^i, \dots, x_M^i]$$
where the sum over $i$ runs over all the DiscreteEBMInteractions seen by this function.
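The $\gamma$ sum can be sketched as follows. The `(spins, indices, weights)` triples are an illustrative data layout, not THRML's actual interaction PyTree format:

```python
import numpy as np

# Each interaction contributes the product of its neighboring spin
# values times an indexed entry of its weight tensor; gamma is the sum.

def compute_gamma(interactions):
    # interactions: list of (spin_values, categorical_indices, weights)
    return sum(np.prod(s) * W[tuple(c)] for s, c, W in interactions)

# Two interactions: a pure-spin coupling (0-d weight) and a mixed term.
interactions = [
    (np.array([1.0, -1.0]), (), np.array(2.0)),       # s1*s2*w   = -2.0
    (np.array([1.0]), (1,), np.array([0.5, -1.5])),   # s * W[1]  = -1.5
]
gamma = compute_gamma(interactions)  # -3.5
```

Under the usual Gibbs convention for an energy contribution $-\gamma s$, this would give $P(s = +1) = \sigma(2\gamma)$; the exact parameterization is THRML's choice and is not spelled out here.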
CategoricalGibbsConditional
A conditional update for categorical random variables that performs a Gibbs sampling update given one or more DiscreteEBMInteractions.

Attributes:
- n_categories: int

CategoricalGibbsConditional.compute_parameters(self, key: Key, interactions: list[PyTree], active_flags: list[Array], states: list[list[_State]], sampler_state: None, output_sd: PyTree[jax.ShapeDtypeStruct]) -> PyTree
Compute the parameter $\theta$ of a softmax distribution given DiscreteEBMInteractions:
$$\theta = \sum_i s_1^i \cdots s_K^i \, W^i[:, x_1^i, \dots, x_M^i]$$
where the sum over $i$ runs over all the DiscreteEBMInteractions seen by this function.
thrml/models/ising.py

IsingEBM
An EBM with the energy function $$\mathcal{E}(s) = -\beta \left( \sum_{i \in S_1} b_i s_i + \sum_{(i, j) \in S_2} J_{ij} s_i s_j \right)$$ where $S_1$ and $S_2$ are the sets of biases and weights that make up the model, respectively.

Attributes:
- nodes: list[AbstractNode]
- biases: Array
- edges: list[Edge]
- weights: Array
- beta: Array

IsingEBM.__init__(self, nodes: list[AbstractNode], edges: list[Edge], biases: Array, weights: Array, beta: Array) -> None
Initialize an Ising EBM.

IsingEBM.factors(self) -> list[EBMFactor]
Returns the list of factors that define this EBM.
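The Ising energy above transcribes directly into numpy. This is a plain-Python illustration of the formula; THRML evaluates it through the EBM's factors:

```python
import numpy as np

# E(s) = -beta * (sum_i b_i s_i + sum_(i,j) J_ij s_i s_j)

def ising_energy(spins, biases, edges, weights, beta):
    field = np.dot(biases, spins)                      # bias term
    coupling = sum(w * spins[i] * spins[j]             # pairwise term
                   for (i, j), w in zip(edges, weights))
    return -beta * (field + coupling)

spins = np.array([1.0, -1.0, 1.0])
biases = np.array([0.1, 0.2, -0.3])
edges = [(0, 1), (1, 2)]
weights = np.array([1.0, -0.5])
e = ising_energy(spins, biases, edges, weights, beta=1.0)
# field = -0.4, coupling = -0.5, so e = -1.0 * (-0.9) = 0.9
```

Raising `beta` sharpens the Boltzmann distribution toward low-energy states, which is why it appears as a separate attribute rather than being folded into the weights.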
IsingSamplingProgram
A very thin wrapper on FactorSamplingProgram that specializes it to the case of an Ising model.

IsingSamplingProgram.__init__(self, ebm: IsingEBM, free_blocks: list[SuperBlock], clamped_blocks: list[Block]) -> None
Initialize an Ising sampling program.
IsingTrainingSpec
Contains a complete specification of an Ising EBM that can be trained using sampling-based gradients.

Attributes:
- ebm: IsingEBM
- program_positive: IsingSamplingProgram
- program_negative: IsingSamplingProgram
- schedule_positive: SamplingSchedule
- schedule_negative: SamplingSchedule

IsingTrainingSpec.__init__(self, ebm: IsingEBM, data_blocks: list[Block], conditioning_blocks: list[Block], positive_sampling_blocks: list[SuperBlock], negative_sampling_blocks: list[SuperBlock], schedule_positive: SamplingSchedule, schedule_negative: SamplingSchedule) -> None
Initializes an IsingTrainingSpec with the EBM, block configurations, and sampling schedules for training.
hinton_init(key: Key[Array, ""], model: IsingEBM, blocks: list[Block[AbstractNode]], batch_shape: tuple[int]) -> list[Bool[Array, "batch_size block_size"]]
Initialize the blocks according to the marginal bias.

estimate_moments(key: Key[Array, ""], first_moment_nodes: list[AbstractNode], second_moment_edges: list[Edge], program: BlockSamplingProgram, schedule: SamplingSchedule, init_state: list[Array], clamped_data: list[Array]) -> tuple
Estimates the first and second moments of an Ising model Boltzmann distribution via sampling.

estimate_kl_grad(key: Key[Array, ""], training_spec: IsingTrainingSpec, bias_nodes: list[AbstractNode], weight_edges: list[Edge], data: list[Array], conditioning_values: list[Array], init_state_positive: list[Array], init_state_negative: list[Array]) -> tuple
Estimate the KL gradients of an Ising model with respect to its weights and biases.
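For intuition: the textbook sampling-based KL gradient for a Boltzmann/Ising machine is the difference between free-running ("negative") and data-clamped ("positive") moments. The sketch below shows that standard result; sign conventions vary, and the actual return format of estimate_kl_grad may differ:

```python
import numpy as np

# dKL/db_i  ~ <s_i>_model     - <s_i>_data
# dKL/dJ_ij ~ <s_i s_j>_model - <s_i s_j>_data

def moment_grads(data_samples, model_samples, edges):
    # first moments (one per node)
    grad_b = model_samples.mean(axis=0) - data_samples.mean(axis=0)
    # second moments over the model's edge set
    grad_J = np.array([
        (model_samples[:, i] * model_samples[:, j]).mean()
        - (data_samples[:, i] * data_samples[:, j]).mean()
        for i, j in edges
    ])
    return grad_b, grad_J

data = np.array([[1, 1], [1, -1]], dtype=float)    # clamped-phase samples
model = np.array([[-1, 1], [-1, -1]], dtype=float) # free-phase samples
gb, gJ = moment_grads(data, model, edges=[(0, 1)])
print(gb, gJ)  # [-2.  0.] [0.]
```

This is why IsingTrainingSpec carries two sampling programs and two schedules: one pair for the clamped (positive) phase, one for the free (negative) phase.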
thrml/conditional_samplers.py

AbstractConditionalSampler
Base class for all conditional samplers.

AbstractConditionalSampler.sample(self, key: Key, interactions: list[PyTree], active_flags: list[Array], states: list[list[_State]], sampler_state: _SamplerState, output_sd: PyTree[jax.ShapeDtypeStruct]) -> tuple[_State, _SamplerState]
Draw a sample from this conditional.

AbstractConditionalSampler.init(self) -> None
Initialize the sampler state before sampling begins.
AbstractParametricConditionalSampler
A conditional sampler that leverages a parameterized distribution.

AbstractParametricConditionalSampler.compute_parameters(self, key: Key, interactions: list[PyTree], active_flags: list[Array], states: list[list[_State]], sampler_state: PyTree, output_sd: PyTree[jax.ShapeDtypeStruct]) -> PyTree
Compute the parameters of the distribution. For a description of the arguments, see [thrml.AbstractConditionalSampler.sample][].

AbstractParametricConditionalSampler.sample_given_parameters(self, key: Key, parameters: PyTree, sampler_state: _SamplerState, output_sd: PyTree[jax.ShapeDtypeStruct]) -> tuple[_State, _SamplerState]
Produce a sample given the parameters of the distribution, passed in as the parameters argument.

AbstractParametricConditionalSampler.sample(self, key: Key, interactions: list[PyTree], active_flags: list[Array], states: list[list[_State]], sampler_state: _SamplerState, output_sd: PyTree[jax.ShapeDtypeStruct]) -> tuple[_State, _SamplerState]
Sample from the distribution by first computing the parameters and then generating a sample based on them.
BernoulliConditional
Sample from a Bernoulli distribution.

BernoulliConditional.compute_parameters(self, key: Key, interactions: list[PyTree], active_flags: list[Array], states: list[list[_State]], sampler_state: None, output_sd: PyTree[jax.ShapeDtypeStruct]) -> PyTree
A concrete implementation of this function has to return a value of $\gamma$ for every node in the block that is being updated. This array should have shape [b].

BernoulliConditional.sample_given_parameters(self, key: Key, parameters: PyTree, sampler_state: None, output_sd: PyTree[jax.ShapeDtypeStruct]) -> tuple[_State, None]
Sample from a spin-valued Bernoulli distribution given the parameter $\gamma$. In THRML, +1 is represented by the boolean value True and -1 by False.
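A spin-valued Bernoulli draw can be sketched as below. The mapping $P(s = +1) = \sigma(2\gamma)$ is an assumption (the usual Gibbs parameterization for an energy contribution $-\gamma s$), not something this page specifies; only the True/+1, False/-1 encoding is documented:

```python
import numpy as np

def sample_spins(rng, gamma):
    # ASSUMPTION: gamma enters through P(s = +1) = sigmoid(2 * gamma).
    p_plus = 1.0 / (1.0 + np.exp(-2.0 * gamma))
    # Booleans encode spins: True -> +1, False -> -1 (as in THRML).
    return rng.random(gamma.shape) < p_plus

rng = np.random.default_rng(0)
gamma = np.array([-10.0, 0.0, 10.0])  # strongly -1, fair coin, strongly +1
spins_bool = sample_spins(rng, gamma)
spins = np.where(spins_bool, 1, -1)   # decode booleans back to spin values
```

Storing spins as booleans halves the memory of an int8 representation and matches the Bool[Array, ...] annotations used elsewhere on this page.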
SoftmaxConditional
Sample from a softmax distribution.

Attributes:
- n_categories: int

SoftmaxConditional.compute_parameters(self, key: Key, interactions: list[PyTree], active_flags: list[Array], states: list[list[_State]], sampler_state: None, output_sd: PyTree[jax.ShapeDtypeStruct]) -> PyTree
A concrete implementation of this function has to return a $\theta$ vector for every node in the block that is being updated. This array should have shape [b, M], where $M$ is the number of possible values that $X$ may take on.

SoftmaxConditional.sample_given_parameters(self, key: Key, parameters: PyTree, sampler_state: None, output_sd: PyTree[jax.ShapeDtypeStruct]) -> tuple[_State, None]
Sample from a softmax distribution given the parameter vector $\theta$.
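A categorical draw from softmax logits of shape [b, M] (one category per node in the updated block) can be sketched in plain numpy; THRML does the equivalent with JAX under the hood:

```python
import numpy as np

def sample_softmax(rng, theta):
    # Numerically stable softmax over the last axis.
    z = theta - theta.max(axis=-1, keepdims=True)
    p = np.exp(z)
    p /= p.sum(axis=-1, keepdims=True)
    # Inverse-CDF sampling: count how many cumulative-probability
    # boundaries fall below a uniform draw, one category per row.
    u = rng.random((theta.shape[0], 1))
    return (p.cumsum(axis=-1) < u).sum(axis=-1)

rng = np.random.default_rng(0)
theta = np.array([[100.0, 0.0, 0.0],   # overwhelmingly category 0
                  [0.0, 0.0, 100.0]])  # overwhelmingly category 2
cats = sample_softmax(rng, theta)
print(cats)  # [0 2]
```

The max-subtraction step matters in practice: the $\theta$ values produced by CategoricalGibbsConditional are unnormalized energies and can be large.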