Code repository for the paper "Test-Time Guidance for Flow-Based Generative Models via Parallel Tempering on Source Distributions". Implementations of Source Parallel Tempering (SPT) and other test-time guidance methods are provided as an installable Python package (see Usage).
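For intuition about what "parallel tempering on source distributions" can look like, here is a small NumPy sketch. It assumes each chain k targets N(z; 0, I) · exp(-β_k U(z)) in the source space, uses preconditioned Crank-Nicolson (pCN) proposals z' = sqrt(1 - s²) z + s ξ (an assumed parametrization that keeps N(0, I) invariant), and swaps states between neighbouring temperatures. The function `run_source_pt`, the toy `transport` and `potential_fn`, and all settings are hypothetical illustrations, not the package's `PTMCMC` implementation.

```python
import numpy as np


def run_source_pt(potential_fn, betas, steps, batch, dim, num_iters, rng):
    """Toy parallel tempering on a standard Gaussian source.

    Chain k targets N(z; 0, I) * exp(-betas[k] * potential_fn(z)), with `batch`
    independent copies per chain. potential_fn maps (..., dim) -> (...).
    Returns the final states, shape (num_chains, batch, dim).
    """
    num_chains = len(betas)
    z = rng.standard_normal((num_chains, batch, dim))  # start every chain at source samples
    u = potential_fn(z)                                  # (num_chains, batch)
    b = betas[:, None]                                   # broadcast over the batch
    s = steps[:, None, None]                             # broadcast over batch and dims
    for _ in range(num_iters):
        # pCN move: z' = sqrt(1 - s^2) z + s xi leaves N(0, I) invariant, so the
        # Metropolis-Hastings ratio only involves the tempered potential.
        prop = np.sqrt(1.0 - s**2) * z + s * rng.standard_normal(z.shape)
        u_prop = potential_fn(prop)
        accept = np.log(rng.uniform(size=u.shape)) < -b * (u_prop - u)
        z = np.where(accept[..., None], prop, z)
        u = np.where(accept, u_prop, u)
        # Swap moves between neighbouring chains: the Gaussian source terms cancel,
        # leaving the log ratio (beta_k - beta_{k+1}) * (U_k - U_{k+1}).
        for k in range(num_chains - 1):
            log_ratio = (betas[k] - betas[k + 1]) * (u[k] - u[k + 1])
            swap = np.log(rng.uniform(size=batch)) < log_ratio
            z[k], z[k + 1] = (np.where(swap[:, None], z[k + 1], z[k]),
                              np.where(swap[:, None], z[k], z[k + 1]))
            u[k], u[k + 1] = np.where(swap, u[k + 1], u[k]), np.where(swap, u[k], u[k + 1])
    return z


# Hypothetical toy problem: a "flow" that shifts N(0, 1) to N(2, 1), and a
# potential (negative reward) that prefers outputs near 0.
def transport(z):
    return z + 2.0


def potential_fn(z):
    return np.sum(transport(z) ** 2, axis=-1)


rng = np.random.default_rng(0)
betas = np.linspace(0.0, 5.0, 8)                        # inverse temperatures, hot -> cold
steps = np.sin(np.linspace(np.pi / 2 - 1e-3, 0.05, 8))  # large pCN steps when hot, small when cold
chains = run_source_pt(potential_fn, betas, steps, batch=500, dim=1, num_iters=1000, rng=rng)
samples = transport(chains[-1])                         # coldest chain mapped to the target space
```

In this toy setup the coldest chain (β = 5) targets N(z; 0, 1) · exp(-5 (z + 2)²), which is N(-20/11, 1/11), so `samples` should be approximately distributed as N(2/11, 1/11).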
The code for replicating the paper's experiments is on separate branches of this Git repository:
- mixture-gaussians: Code for Section 5.1: Analytic Examples.
- images: Code for Section 5.2: Text to Image Guidance.
- dynamical-systems: Code for Section 5.3: Sampling dynamical system trajectories.
Clone the repository and install using pip or uv:
pip install .  # or, with uv: uv pip install .
Warning: The PyTorch and Jax dependency versions are deliberately left unpinned to maximize compatibility, and no specific versions have been verified to work. The package should work with the PyTorch or Jax version your project already uses, but this is not guaranteed.
We provide implementations of the guidance algorithms in PyTorch and Jax.
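Both usage examples rely on functions you supply yourself, such as `transport` (SPT) and `propose_fn` / `expected_sample_fn` (Feynman-Kac Steering). As a hedged sketch, assuming your flow model exposes a velocity field `velocity_fn(t, x)` (a hypothetical name) that carries the source at t = 0 to the target at t = 1, both can be built from explicit Euler steps. The helper `make_euler_fns` and the toy straight-line flow from N(0, 1) to N(μ, σ²) are illustrative assumptions, not part of the package.

```python
import numpy as np


def make_euler_fns(velocity_fn, num_steps=50):
    """Build hypothetical propose_fn / transport helpers from a flow velocity field.

    velocity_fn(t, x) returns dx/dt; t = 0 is the source and t = 1 the target.
    """

    def propose_fn(t0, t1, x):
        # One explicit Euler step of the flow ODE from time t0 to time t1.
        return x + (t1 - t0) * velocity_fn(t0, x)

    def transport(x):
        # Integrate the ODE from t = 0 to t = 1 on a uniform time grid.
        times = np.linspace(0.0, 1.0, num_steps + 1)
        for t0, t1 in zip(times[:-1], times[1:]):
            x = propose_fn(t0, t1, x)
        return x

    return propose_fn, transport


# Toy straight-line flow from N(0, 1) to N(mu, sigma^2): x_t = t*mu + (1 - t + t*sigma) * x0.
mu, sigma = 2.0, 0.5


def velocity_fn(t, x):
    x0 = (x - t * mu) / (1.0 - t + t * sigma)  # recover the source point of this trajectory
    return mu + (sigma - 1.0) * x0            # constant velocity x1 - x0 along the line


propose_fn, transport = make_euler_fns(velocity_fn)


def expected_sample_fn(t0, z):
    # Same construction as in the FK Steering example: one step straight to t = 1.
    return propose_fn(t0, 1.0, z)


x0 = np.random.default_rng(0).standard_normal(5)
x1 = transport(x0)  # equals mu + sigma * x0 up to floating-point error for this toy flow
```

Because every trajectory of this toy flow is a straight line traversed at constant speed, Euler integration is exact here, and a single Euler step to t = 1 recovers the final sample; for a trained model both are approximations. FK Steering may also need a stochastic `propose_fn` so that resampled copies of a particle can diverge; that choice is left open here.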
Example usage of Source Parallel Tempering.
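import torch
import spt_guidance.torch.spt as spt
# import spt_guidance.jax.spt as spt # for Jax
# data_dims, potential_fn, num_chains, device, batch_size and transport are user-defined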
kernel = spt.KernelPreconditionedCrankNicolson(
data_dims # shape of one data sample, not a batch
)
ptmcmc = spt.PTMCMC(
kernel,
potential_fn, # function to be minimized (negative reward function)
torch.linspace(0., 5., num_chains, device=device), # inverse temperatures, from 0 (hottest) to 5 (coldest)
torch.linspace(torch.pi / 2 - 1e-3, 0.05, num_chains, device=device).sin(), # pCN kernel step sizes, from ~1.0 (hottest chain) to ~0.05 (coldest)
)
last_chain = ptmcmc.run(
# key, # PRNG key if using Jax
100, # number of iterations (tmax) to run SPT for
batch_size=batch_size, # number of source samples to generate at once
)[-1] # take the last (coldest) chain
final_samples = transport(last_chain) # map source samples to the target distribution
Example usage of Feynman-Kac Steering.
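from functools import partial
import torch
import spt_guidance.torch.fk_steering as fk_steering
# import spt_guidance.jax.fk_steering as fk_steering # for Jax
# data_dims, device, reward_fn, propose_fn, times and batch_size are user-defined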
fk = fk_steering.FeynmannKac(
init_param_fn=partial(torch.randn, data_dims, device=device), # data_dims is the shape of one data sample, not a batch
reward_fn=reward_fn, # function to be maximized (reward function)
propose_fn=propose_fn, # function that updates the current particles to the next time step
resample_fn=fk_steering.ResampleFn.adaptive_resample, # function to resample the particles based on their potential
potential=fk_steering.Potential.MAX, # how the potential of each particle is computed
intermediate_reward=fk_steering.IntermediateReward.EXPECTED_SAMPLE, # how the intermediate reward of each particle is computed
inverse_temperature=5., # the tilt of the target distribution
kwargs_intermediate_reward=dict(
expected_sample_fn=lambda t0, z: propose_fn(t0, 1., z), # function to return the expected final samples of the current particles
),
)
final_samples = fk.run(
# key, # PRNG key if using Jax
times, # discretization of the sampling time interval
4, # number of FK Steering chains to use
batch_size=batch_size, # number of target samples to generate at once
)
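The `potential` and `resample_fn` options control how particles are scored and when they are duplicated. As an illustration only, the following NumPy sketch shows one common way such options can work: a MAX-style log-potential λ · max_s r_s over the intermediate rewards seen so far, and multinomial resampling triggered only when the effective sample size (ESS) drops below half the number of particles. The helper names, the 0.5 ESS threshold, and the choice not to carry weights between steps are assumptions; `fk_steering.Potential.MAX` and `fk_steering.ResampleFn.adaptive_resample` may be defined differently.

```python
import numpy as np


def max_log_potential(reward_history, inverse_temperature):
    """MAX-style log-potential: inverse_temperature * running max of rewards.

    reward_history has shape (num_steps_so_far, num_particles) and holds the
    intermediate reward of every particle at each step taken so far.
    """
    return inverse_temperature * reward_history.max(axis=0)


def adaptive_resample(log_potentials, rng, ess_fraction=0.5):
    """Return ancestor indices, resampling only when the ESS is too small."""
    n = log_potentials.shape[0]
    w = np.exp(log_potentials - log_potentials.max())  # subtract the max for stability
    w /= w.sum()
    ess = 1.0 / np.sum(w**2)  # effective sample size, between 1 and n
    if ess >= ess_fraction * n:
        return np.arange(n)  # weights are balanced enough: keep every particle
    return rng.choice(n, size=n, p=w)  # multinomial resampling


rng = np.random.default_rng(0)
particles = rng.standard_normal((4, 2))  # hypothetical: 4 particles of dimension 2
history = np.array([[0.1, 0.2, 0.3, 0.4],   # made-up intermediate rewards, step 1
                    [0.5, 0.0, 0.3, 0.1]])  # made-up intermediate rewards, step 2
ancestors = adaptive_resample(max_log_potential(history, inverse_temperature=5.0), rng)
particles, history = particles[ancestors], history[:, ancestors]  # keep histories aligned
```

Returning ancestor indices rather than resampled particles makes it easy to reorder any per-particle state (here the reward history) consistently with the particles.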