
Utah-Math-Data-Science/SPT

Test-Time Guidance for Flow-Based Generative Models via Parallel Tempering on Source Distributions

Code repository for Test-Time Guidance for Flow-Based Generative Models via Parallel Tempering on Source Distributions. Implementations for Source Parallel Tempering (SPT) and other test-time guidance methods are provided as an installable Python package (see Usage).

The code for replicating the paper's experiments is on separate branches of this Git repository:

Library Installation

Clone the repository, then install it from the repository root using pip or uv:

pip install .

or

uv pip install .

Warning

The PyTorch and Jax dependencies are deliberately left unpinned so that the package can be installed alongside the versions your project already uses. Compatibility with specific versions has not been verified, so check that the package works with your PyTorch or Jax version.
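Because the versions are unpinned, it can help to record which PyTorch and Jax versions are actually installed when reporting a problem. The following is a minimal sketch that uses only the standard library's importlib.metadata; the helper name installed_versions is hypothetical and is not part of spt_guidance.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("torch", "jax", "jaxlib")):
    """Return {distribution name: installed version string, or None if absent}.

    Hypothetical helper for bug reports; not part of spt_guidance.
    """
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None  # not installed in this environment
    return found


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```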

Usage

We provide implementations of the guidance algorithms in PyTorch and Jax.

Source Parallel Tempering
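As a rough picture of what the chains in the example sample, assume a standard Gaussian source N(0, I) (the setting the preconditioned Crank-Nicolson kernel is designed for) and a potential U evaluated on source samples, for instance U(z) = -r(T(z)) with reward r and transport map T. Under these assumptions, chain k with inverse temperature beta_k targets the source tilted by exp(-beta_k U): the beta_1 = 0 chain samples the source itself, and the coldest chain is mapped to the target by T. The pCN parameterisation and the neighbour-swap rule below are standard parallel-tempering choices and may differ in detail from spt_guidance.

```latex
\begin{aligned}
\pi_k(z) &\propto \mathcal{N}(z; 0, I)\,\exp\!\big(-\beta_k\, U(z)\big), && 0 = \beta_1 < \beta_2 < \dots < \beta_K,\\
z' &= \sqrt{1 - s_k^2}\, z + s_k\, \xi,\quad \xi \sim \mathcal{N}(0, I), && \alpha_{\mathrm{pCN}} = \min\!\Big(1,\, e^{-\beta_k\,[U(z') - U(z)]}\Big),\\
\alpha_{\mathrm{swap}}(k, k{+}1) &= \min\!\Big(1,\, e^{(\beta_k - \beta_{k+1})\,[U(z_k) - U(z_{k+1})]}\Big), && x = T(z_K).
\end{aligned}
```

Because the pCN proposal leaves N(0, I) invariant, its acceptance probability involves only the change in potential, which is why a step size s_k close to 1 is affordable in the hot chains and a small s_k is used in the cold ones.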

Example usage of Source Parallel Tempering.

import torch

import spt_guidance.torch.spt as spt
# import spt_guidance.jax.spt as spt   # for Jax

# data_dims, potential_fn, num_chains, device, batch_size and transport are
# defined by your own project; they are not provided by spt_guidance.
kernel = spt.KernelPreconditionedCrankNicolson(
    data_dims  # shape of one data sample, not a batch
)
ptmcmc = spt.PTMCMC(
    kernel,
    potential_fn,  # function to be minimized (negative reward function)
    torch.linspace(0., 5., num_chains, device=device),  # inverse temperatures, from hottest (0) to coldest (5)
    torch.linspace(torch.pi / 2 - 1e-3, 0.05, num_chains, device=device).sin(),  # pCN kernel step size, one per chain
)
last_chain = ptmcmc.run(
    # key,  # PRNG key if using Jax
    100,  # number of iterations (tmax) to run SPT for
    batch_size=batch_size,  # number of source samples to generate at once
)[-1]  # take the last (coldest) chain
final_samples = transport(last_chain)  # map source samples to the target distribution
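To make the roles of the inverse temperatures and pCN step sizes concrete, here is a self-contained NumPy sketch of parallel tempering on a standard Gaussian source. It is not the spt_guidance implementation: the function pt_source_sketch, the pCN parameterisation z' = sqrt(1 - s^2) z + s * xi, the sweep over every neighbouring pair of chains at each iteration, and the toy quadratic potential are all assumptions made for illustration. The temperature and step-size schedules mirror the ones in the example above.

```python
import numpy as np


def pt_source_sketch(potential_fn, data_dim, betas, step_sizes, n_iters, batch_size, rng):
    """Parallel tempering on a N(0, I) source tilted by exp(-beta * U) (illustrative sketch).

    potential_fn maps an array of shape (..., data_dim) to shape (...).
    Returns the states of all chains, shape (batch_size, n_chains, data_dim).
    """
    betas = np.asarray(betas, dtype=float)            # (n_chains,)
    s = np.asarray(step_sizes, dtype=float)[:, None]  # (n_chains, 1), each in (0, 1]
    n_chains = betas.shape[0]
    x = rng.standard_normal((batch_size, n_chains, data_dim))  # start every chain from the source
    u = potential_fn(x)                                         # (batch_size, n_chains)

    for _ in range(n_iters):
        # 1) pCN proposal in every chain. It leaves N(0, I) invariant, so the
        #    Metropolis-Hastings ratio only involves the potential difference.
        prop = np.sqrt(1.0 - s**2) * x + s * rng.standard_normal(x.shape)
        u_prop = potential_fn(prop)
        log_accept = -betas * (u_prop - u)
        accept = np.log(rng.uniform(size=u.shape)) < log_accept
        x = np.where(accept[..., None], prop, x)
        u = np.where(accept, u_prop, u)

        # 2) Propose swapping states between neighbouring temperatures i and i + 1.
        for i in range(n_chains - 1):
            log_swap = (betas[i] - betas[i + 1]) * (u[:, i] - u[:, i + 1])
            swap = np.log(rng.uniform(size=batch_size)) < log_swap
            x[swap, i], x[swap, i + 1] = x[swap, i + 1], x[swap, i]
            u[swap, i], u[swap, i + 1] = u[swap, i + 1], u[swap, i]
    return x


# Toy usage with the same schedules as the example: 8 chains, beta in [0, 5].
rng = np.random.default_rng(0)
n_chains = 8
betas = np.linspace(0.0, 5.0, n_chains)                        # inverse temperatures
steps = np.sin(np.linspace(np.pi / 2 - 1e-3, 0.05, n_chains))  # pCN step sizes
quadratic = lambda z: 0.5 * np.sum((z - 2.0) ** 2, axis=-1)    # toy potential U
chains = pt_source_sketch(quadratic, 2, betas, steps, n_iters=500, batch_size=64, rng=rng)
coldest = chains[:, -1]  # for this U, the beta = 5 target is N(5/3, I/6)
```

Unlike ptmcmc.run(...)[-1], this sketch returns arrays with the batch axis first, so the coldest chain is chains[:, -1]. The beta = 0 chain with step size close to 1 draws nearly independent source samples, and swaps carry good states down the ladder to the coldest chain, whose own small pCN steps would mix slowly on their own.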

Feynman-Kac Steering

Example usage of Feynman-Kac Steering.

from functools import partial

import torch

import spt_guidance.torch.fk_steering as fk_steering
# import spt_guidance.jax.fk_steering as fk_steering   # for Jax

# data_dims, reward_fn, propose_fn, times, device and batch_size are
# defined by your own project; they are not provided by spt_guidance.
fk = fk_steering.FeynmannKac(
    init_param_fn=partial(torch.randn, data_dims, device=device),  # draws standard normal initial particles; data_dims is the shape of one data sample, not a batch
    reward_fn=reward_fn,  # function to be maximized (reward function)
    propose_fn=propose_fn,  # function that updates the current particles to the next time step
    resample_fn=fk_steering.ResampleFn.adaptive_resample,  # function to resample the particles based on their potential
    potential=fk_steering.Potential.MAX,  # how the potential of each particle is computed
    intermediate_reward=fk_steering.IntermediateReward.EXPECTED_SAMPLE,  # how the intermediate reward of each particle is computed
    inverse_temperature=5.,  # the tilt of the target distribution
    kwargs_intermediate_reward=dict(
        expected_sample_fn=lambda t0, z: propose_fn(t0, 1., z),  # expected final samples of the current particles, via one propose_fn step from t0 to time 1
    ),
)
final_samples = fk.run(
    # key,  # PRNG key if using Jax
    times,  # sampling time interval discretization
    4,  # number of FK Steering chains to use
    batch_size=batch_size,  # number of target samples to generate at once
)
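FK Steering maintains a small set of particles per target sample and resamples them according to their potentials. The sketch below shows generic effective-sample-size (ESS) triggered multinomial resampling in NumPy, which is one common meaning of adaptive resampling in sequential Monte Carlo. It is not taken from fk_steering.ResampleFn.adaptive_resample: the names normalized_weights, effective_sample_size and adaptive_resample, the 0.5 threshold, and the log-weights (inverse temperature times a toy reward) are assumptions, and how the library builds potentials (for example Potential.MAX) and intermediate rewards is not reproduced.

```python
import numpy as np


def normalized_weights(log_weights):
    """Self-normalised importance weights, computed stably from log-weights."""
    w = np.exp(log_weights - np.max(log_weights))
    return w / w.sum()


def effective_sample_size(log_weights):
    """ESS = 1 / sum_i w_i^2: N for uniform weights, 1 when a single particle dominates."""
    w = normalized_weights(log_weights)
    return 1.0 / np.sum(w**2)


def adaptive_resample(particles, log_weights, rng, threshold=0.5):
    """Resample multinomially only if ESS < threshold * N (illustrative sketch).

    Returns (particles, log_weights). After a resampling step every particle
    carries equal weight, so the log-weights are reset to zero.
    """
    n = log_weights.shape[0]
    if effective_sample_size(log_weights) >= threshold * n:
        return particles, log_weights  # weights still balanced enough: keep particles
    idx = rng.choice(n, size=n, p=normalized_weights(log_weights))
    return particles[idx], np.zeros(n)


# Toy usage: 4 particles in 2-D, weighted by exp(inverse_temperature * reward).
rng = np.random.default_rng(0)
particles = rng.standard_normal((4, 2))
reward = -np.sum((particles - 1.0) ** 2, axis=-1)  # hypothetical reward to be maximized
log_w = 5.0 * reward                                # inverse_temperature = 5, as in the example
particles, log_w = adaptive_resample(particles, log_w, rng)
```

Resampling only when the ESS drops avoids discarding particle diversity at steps where the weights are still nearly uniform, while still pruning low-reward particles once a few of them dominate.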
