| title | Example Notes |
|---|---|
| title_url | https://arxiv.org/abs/2111.09266 |
| subtitle | An example subtitle. |
| toc | true |
Question: Can we improve on MCMC-based methods for sampling using Generative Flow Networks?
Answer: Yes, when there is structure shared between the modes in a distribution.
- DAG: Directed Acyclic Graph
- GFN: Generative Flow Network
- MCMC: Markov Chain Monte Carlo
- PPO: Proximal Policy Optimization
- RL: Reinforcement Learning {% marginnote "mn-1" "This is an example Liquid marginnote." %}
- Proxy: A function that approximates an oracle, i.e. $R(x)$, trained using $(x,y)$ pairs, which can incorporate (Bayesian) uncertainty.
- Active learning: A context in which "student" and "teacher" interact during training.
- Flow matching: The total flow going into a state must match the total flow leaving the state, except for the source, $s_0$, and sink(s), $s_f$ or $s_T$. {% marginfigure "mf-1" "assets/img/side.png" "This is side image." %}
- Assay: Measures the composition or quality of a substance. {% sidenote "sn-1" "This is an example Liquid sidenote." %}
- $\mathcal{S}$: The set of states, i.e. the state space.
- $\mathcal{A}$: The set of actions, i.e. the action space.
- $\mathcal{A}(s)$: The set of allowed actions in state $s$.
- $\mathcal{A}^\ast(s)$: The set of all sequences of actions allowed after state $s$, i.e. $\vec{a}$.
- $C: \mathcal{A}^\ast\rightarrow \mathcal{S}$: Function that maps a sequence of actions to a state; when the sequence is incomplete, the reward is 0.
  - When the correspondence between actions and states is bijective, the states are uniquely described by some sequence $\vec{a}$ and the generative process can be represented as a tree.
  - When the correspondence is non-injective, i.e. multiple action sequences define the same $x$, the generative process resembles a DAG.
- $\tilde{V}: \mathcal{S}\rightarrow\mathbb{R}^+$: Maps states to their expected values.
- $\tau$: A trajectory, i.e. a sequence of states and actions (or edges).
- $\pi(a\mid s)$: The probability of an action $a\in\mathcal{A}$ conditional on a state $s\in\mathcal{S}$.
- $\pi(s)$: The probability of visiting state $s$ when starting at $s_0$ and following $\pi(\cdot\mid\cdot)$.
- $\mathcal{X}$: A set of discrete objects.
- $R(x)$: The reward associated with terminal state $x\in\mathcal{X}$.
- $T(s, a)$: The deterministic transition function, i.e. the state reached by taking action $a$ in state $s$.
- $F(s)$: Total flow going through state $s$; terminal states have out-flow $R(s)$.
- $F(s, a)$: The total flow through edge $(s, a)$.
- $F_\theta$: A flow parameterized by $\theta$.
- $Z=F(s_0)=F(s_f)$: The partition function, equal to the total flow out of the source node and the total flow into the sink node.
{% maincolumn "assets/img/full-width.png" "A main column image." %}
- Definitions: {% marginnote 'm1' 'Example margin note.' %}
  - $f$: energy function $\rightarrow$ generative distribution. Transforms a positive reward function into a generative policy that samples in proportion to returns.
  - $\pi(x) \approx \frac{R(x)}{Z}=\frac{R(x)}{\sum_{x^{\prime} \in \mathcal{X}} R\left(x^{\prime}\right)}$, where $x\in\mathcal{X}$.
  - The flow consistency equations ensure that inflow equals outflow (note that for interior nodes $R(s)=0$ and for terminal nodes $\mathcal{A}(s)=\emptyset$), as checked numerically in the sketch below:
    $$
    \sum_{s, a: T(s, a)=s^{\prime}} F(s, a) = R\left(s^{\prime}\right)+\sum_{a^{\prime} \in \mathcal{A}\left(s^{\prime}\right)} F\left(s^{\prime}, a^{\prime}\right)
    $$
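A quick numerical check of the consistency condition on a hand-built DAG; the graph, rewards, and flow values are invented purely for illustration:

```python
# Minimal sketch: verify flow consistency on a hand-built DAG.
# States, edges, rewards, and flow values are invented for illustration.
edge_flow = {
    ("s0", "sA"): 3.0, ("s0", "sB"): 2.0,  # flow out of the source
    ("sA", "x1"): 1.0, ("sA", "x2"): 2.0,  # sA splits its flow
    ("sB", "x2"): 2.0,                     # x2 is reachable two ways (non-injective)
}
reward = {"s0": 0.0, "sA": 0.0, "sB": 0.0, "x1": 1.0, "x2": 4.0}  # R = 0 on interior nodes

for s_prime in set(reward) - {"s0"}:
    in_flow = sum(f for (s, t), f in edge_flow.items() if t == s_prime)
    out_flow = sum(f for (s, t), f in edge_flow.items() if s == s_prime)
    # Flow consistency: in-flow = R(s') + out-flow.
    assert abs(in_flow - (reward[s_prime] + out_flow)) < 1e-9, s_prime

Z = sum(f for (s, t), f in edge_flow.items() if s == "s0")
print("Z =", Z, "= sum of terminal rewards =", reward["x1"] + reward["x2"])
```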
Propositions: {% sidenote 's1' 'Example side note.' %}
-
Proposition 1: Illustrates the "overcounting" problem in the non-injective case.
- Given: $$ \begin{aligned} s_0 &= C(\emptyset) \ \pi(a\mid s) &= \frac{\tilde{V}(s+a)} {\sum_{b\in\mathcal{A}(s)}\tilde{V}(s+b)} \ \pi(\vec{a}=(a_1,\ldots, a_N)) &= \Pi_{i=1}^N \pi(a_i\mid C(a_1,\ldots,a_{i-1})) \ \end{aligned} $$
    - Then:
      - The probability of a given state is the sum over all action sequences leading to that state, $\pi(s)=\sum_{\vec{a}_i:C(\vec{a}_i)=s}\pi(\vec{a}_i)$.
      - If $C$ is bijective, then $\pi(s)=\frac{\tilde{V}(s)}{\tilde{V}(s_0)}$ and, as a special case for terminal states $x$, $\pi(x)=\frac{R(x)}{\sum_{x'\in\mathcal{X}}R(x')}$.
      - If $C$ is non-injective and there are $n(x)$ distinct action sequences $\vec{a}_i$ such that $C(\vec{a}_i)=x$, then $\pi(x)=\frac{n(x)R(x)}{\sum_{x'\in\mathcal{X}}n(x')R(x')}$.
    - Comments:
      - In combinatorial spaces, e.g. molecules, larger objects would be exponentially more likely to be sampled because there are more paths leading to them, which means that $\tilde{V}$ is biased, since it assumes the MDP's structure is a tree; the enumeration sketch below makes this concrete.
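The overcounting in Proposition 1 is easy to reproduce by brute-force enumeration. A minimal sketch with an invented two-action toy, where objects are unordered letter pairs and all rewards are equal:

```python
# Minimal sketch of Proposition 1's overcounting (invented, equal rewards).
# Objects are unordered letter pairs, so C is non-injective: the sequences
# "ab" and "ba" build the same object, hence n(("a","b")) = 2.
from itertools import product
from collections import defaultdict

R = {("a", "a"): 1.0, ("a", "b"): 1.0, ("b", "b"): 1.0}

# A tree-based tilde-V treats every action sequence as a distinct leaf:
seq_reward = {seq: R[tuple(sorted(seq))] for seq in product("ab", repeat=2)}
Z_tree = sum(seq_reward.values())  # sums n(x) * R(x), not R(x)

pi = defaultdict(float)
for seq, r in seq_reward.items():
    pi[tuple(sorted(seq))] += r / Z_tree  # aggregate sequences mapping to x

print(dict(pi))
# {('a','a'): 0.25, ('a','b'): 0.5, ('b','b'): 0.25}
# Despite equal rewards, ('a','b') gets twice the mass: pi(x) ∝ n(x) R(x).
```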
  - Proposition 2: Shows that $\pi(a\mid s)=\frac{F(s,a)}{F(s)}$ creates a generative policy where $\pi(x)\propto R(x)$, regardless of whether $C$ is bijective or non-injective.
    - Given:
      $$
      \begin{aligned}
      \pi(a\mid s) &= \frac{F(s,a)}{F(s)},\text{ where } F(s,a) > 0 \\
      F(s) &= R(s) + \sum_{a\in\mathcal{A}(s)}F(s,a) \\
      \sum_{s, a: T(s, a)=s^{\prime}} F(s, a) &= R\left(s^{\prime}\right)+\sum_{a^{\prime} \in \mathcal{A}\left(s^{\prime}\right)} F\left(s^{\prime}, a^{\prime}\right)
      \end{aligned}
      $$
    - Then:
      - $\pi(s)=\frac{F(s)}{F(s_0)}$ for non-terminal $s$, with $F(s_0)=\sum_{x\in\mathcal{X}}R(x)$.
      - $\pi(x)=\frac{R(x)}{\sum_{x'\in\mathcal{X}} R(x')}$ for terminal $x\in\mathcal{X}$.
    - Comments:
      - In all cases, $\sum_{x\in\mathcal{X}}\pi(x)=1$ because terminal states are mutually exclusive.
      - In the non-injective case, $\sum_{s\in\mathcal{S}}\pi(s)\ne 1$ in general because internal states are not mutually exclusive.
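Continuing the same toy, a minimal sketch of Proposition 2: build one consistent flow (splitting each child's flow evenly among its parents is an arbitrary but valid choice, cf. Proposition 3), then propagate $\pi(s)$ forward and compare against $R(x)/Z$:

```python
# Minimal sketch of Proposition 2 on the toy non-injective DAG (invented
# rewards): states are sorted letter tuples, actions append a letter,
# so ("a", "b") is reachable from both ("a",) and ("b",).
from collections import defaultdict

R = {("a", "a"): 1.0, ("a", "b"): 1.0, ("b", "b"): 1.0}
children = {
    (): [("a",), ("b",)],
    ("a",): [("a", "a"), ("a", "b")],
    ("b",): [("a", "b"), ("b", "b")],
}
parents = defaultdict(list)
for s, cs in children.items():
    for c in cs:
        parents[c].append(s)

# Build one consistent flow: split each child's flow evenly among its parents.
F, edge_F = dict(R), {}  # terminal states carry F(x) = R(x)
for s in [("a",), ("b",), ()]:  # reverse topological order; R(s) = 0 here
    for c in children[s]:
        edge_F[(s, c)] = F[c] / len(parents[c])
    F[s] = sum(edge_F[(s, c)] for c in children[s])

# Forward-propagate pi(s) under pi(a|s) = F(s, a) / F(s).
prob = {(): 1.0}
for s in [(), ("a",), ("b",)]:  # topological order
    for c in children[s]:
        prob[c] = prob.get(c, 0.0) + prob[s] * edge_F[(s, c)] / F[s]

Z = F[()]
for x in R:
    print(x, round(prob[x], 4), "target:", R[x] / Z)  # all 1/3: no overcounting
```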
  - Proposition 3: Off-policy sampling works provided the sampling policy $P$ has adequate support.
    - Given:
      - The exploratory sampling policy $P$, parameterized by $\theta'$, has the same support as the optimal $\pi$ for a consistent flow $F^\ast\in\mathcal{F}^\ast$, i.e. in-flow equals out-flow.
      - $\exists \theta: F_\theta=F^\ast$, i.e. there is a sufficiently rich family of predictors.
      - $\theta^\ast\in \arg\min_\theta\mathbb{E}_{\tau\sim P(\theta')}\left[L_\theta(\tau)\right]$, where $L_\theta$ is the flow-matching loss.
    - Then:
      - A global optimum for the expected loss generates the correct flows: $F_{\theta^\ast}=F^\ast$ and $L_{\theta^\ast}(\tau)=0$ for all $\tau\sim P(\theta')$.
      - If $\pi_{\theta^\ast}(a\mid s)=\frac{F_{\theta^\ast}(s, a)}{\sum_{a'\in\mathcal{A}(s)}F_{\theta^\ast}(s,a')}$, then $\pi_{\theta^\ast}(x)=\frac{R(x)}{Z}$.
    - Comments:
      - In general, there are infinitely many solutions. Imagine a case where only two trajectories are possible, $\tau_1=(s_0,a_1,s_A,a_2,s_T)$ and $\tau_2=(s_0,a_3,s_B,a_4,s_T)$. Then $F(s_A)+F(s_B)=R(s_T)=r$, and any split of $r$ between the two is a solution, i.e. $F(s_A)=\alpha$ and $F(s_B)=r-\alpha$ for $\alpha\in[0, r]$ (checked numerically in the sketch below).
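The non-uniqueness comment can be checked directly; in this sketch, $r$ is an invented value and every split $\alpha$ passes the consistency equations:

```python
# Two-trajectory graph: s0 -> sA -> sT and s0 -> sB -> sT, with R(s_T) = r.
# Any alpha in [0, r] gives a consistent flow (r is an invented value).
r = 2.0
for alpha in (0.0, 0.7, 1.5, 2.0):
    F = {("s0", "sA"): alpha, ("sA", "sT"): alpha,
         ("s0", "sB"): r - alpha, ("sB", "sT"): r - alpha}
    # Interior states: in-flow equals out-flow (R = 0 at sA and sB).
    assert F[("s0", "sA")] == F[("sA", "sT")]
    assert F[("s0", "sB")] == F[("sB", "sT")]
    # Sink: total in-flow equals the terminal reward r.
    assert abs(F[("sA", "sT")] + F[("sB", "sT")] - r) < 1e-9
print("every alpha in [0, r] satisfies the consistency equations")
```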
- Objective Functions:
  - Naive flow-matching loss:
    $$
    \tilde{\mathcal{L}}_\theta(\tau)
    =\sum_{s^{\prime} \in \tau,\, s^{\prime} \neq s_0}\left(\sum_{s, a: T(s, a)=s^{\prime}}
    F_\theta(s, a)
    -R\left(s^{\prime}\right)
    -\sum_{a^{\prime} \in \mathcal{A}\left(s^{\prime}\right)}
    F_\theta\left(s^{\prime}, a^{\prime}\right)\right)^2
    $$
    - Issues:
      - Flow variability:
        - Problem: flow will be much larger for nodes near the source; the higher the cardinality of $\mathcal{X}$, the smaller the terminal flow for each terminal state $x$.
        - Solution: estimate flow on the log scale.
  - Flow-matching loss:
    $$
    \mathcal{L}_{\theta, \epsilon}(\tau)
    =\sum_{s^{\prime} \in \tau,\, s^{\prime} \neq s_0}\left(\log \left[\epsilon
    +\sum_{s, a: T(s, a)=s^{\prime}} \exp F_\theta^{\log }(s, a)\right]
    -\log \left[\epsilon
    +R\left(s^{\prime}\right)+\sum_{a^{\prime} \in \mathcal{A}\left(s^{\prime}\right)}
    \exp F_\theta^{\log }\left(s^{\prime}, a^{\prime}\right)\right]\right)^2
    $$
    - Issues:
      - Numerical instability:
        - Problem: taking the logarithm of very small flows.
        - Solution: the hyperparameter $\epsilon$ adjusts the pressure on matching small vs. large flows; the larger the value, the less sensitive the loss is to small flows; in practice, $\epsilon\approx \min_s R(s)$.
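As a concrete reading of $\mathcal{L}_{\theta,\epsilon}$, here is a minimal PyTorch-style sketch for a single trajectory. The interfaces (`log_flow`, `parents`, `allowed`, `reward`) are assumptions for illustration, not the paper's reference implementation:

```python
import torch

def flow_matching_loss(log_flow, traj, parents, allowed, reward, eps=1e-2):
    """eps-smoothed flow-matching loss for one trajectory (a sketch).

    log_flow(s, a) -> predicted log F(s, a) as a tensor;
    parents(s') -> the (s, a) pairs with T(s, a) = s';
    allowed(s') -> the action set A(s') (empty for terminal states).
    """
    zero = torch.tensor(0.0)
    loss = torch.tensor(0.0)
    for s_prime in traj[1:]:  # every visited state except the source s_0
        in_flow = eps + sum(
            (torch.exp(log_flow(s, a)) for s, a in parents(s_prime)), zero
        )
        out_flow = eps + reward(s_prime) + sum(
            (torch.exp(log_flow(s_prime, a)) for a in allowed(s_prime)), zero
        )
        # Squared log-ratio of in-flow to reward-plus-out-flow.
        loss = loss + (torch.log(in_flow) - torch.log(out_flow)) ** 2
    return loss
```

Predicting $F^{\log}_\theta$ and exponentiating inside the logarithms keeps the loss well-scaled across the large flows near $s_0$ and the small terminal flows, with $\epsilon$ bounding the sensitivity to the latter.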
- Details:
  - When the mapping between action sequences and states is bijective, generating one $x$ is like an episode in a tree-structured deterministic MDP where the leaves are terminal states; in this case $\tilde{V}(s)$ is the sum of all descendant rewards (see the recursion sketched below).
  - When the mapping is non-injective, methods like MaxEnt RL and other autoregressive methods "overcount."
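In the bijective case, $\tilde{V}$ has a simple recursive form; a sketch on an invented toy tree:

```python
# Sketch of the bijective (tree) case: tilde-V(s) is the sum of all
# descendant rewards, computed by a simple recursion (toy tree invented).
tree = {"s0": ["s1", "s2"], "s1": ["x1", "x2"], "s2": ["x3"]}
R = {"x1": 1.0, "x2": 2.0, "x3": 3.0}

def v_tilde(s):
    if s not in tree:  # leaf = terminal state
        return R[s]
    return sum(v_tilde(c) for c in tree[s])  # interior reward is 0

print(v_tilde("s0"))  # 6.0 = sum of all terminal rewards = tilde-V(s_0)
```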
- Examples:
  - Molecules:
    - $\mathcal{X}$ is a collection of molecules.
    - $R(x)$ measures a chemical property of the molecule, which is a proxy for actual values obtained from assays.
- Controls:
  - Temperature controls exploration around modes.
  - Powers of the reward, i.e. $R(x)^\beta$, also control exploration: $\beta > 1$ sharpens the target distribution around high-reward modes, while $\beta < 1$ flattens it (illustrated below).
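A toy illustration of the $\beta$ control (rewards invented): raising $R$ to a power reshapes the target distribution $\pi(x)\propto R(x)^\beta$ that the GFN is trained to match:

```python
# Toy illustration (invented rewards): the effect of beta on pi(x) ∝ R(x)^beta.
rewards = [1.0, 2.0, 4.0]
for beta in (0.5, 1.0, 2.0):
    w = [r ** beta for r in rewards]
    Z = sum(w)
    print(beta, [round(p / Z, 3) for p in w])
# beta > 1 concentrates mass on high-reward modes; beta < 1 flattens the target.
```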
- Advantages:
  - Off-policy training: can use samples generated by $\pi_T$, which is not the same as the trained distribution, provided it has adequate support for the true distribution.
- Hypergrid:
  - Converges to $\pi(x)\propto R(x)$.
  - Requires fewer samples than MCMC and PPO under various performance metrics.
  - Recovers all the modes, and does so faster than MCMC and PPO.
  - Robust to separation between modes.
  - Top-k returns are greater in an active learning setting for nearly all rounds.
- Molecule generation:
  - Generates higher-reward molecules than baselines.
  - Generates more diverse candidates than baselines.
  - Top-k returns are greater in an active learning setting for nearly all rounds.
- RL methods approximate the best possible policy in static and stochastic environments, i.e. they put the probability mass on the best action in each state.
- GFNs approximate distributions in deterministic environments.
- GFNs trade sampling complexity (MCMC hallmark) for training complexity; training this model amortizes the cost of generating samples.
- Bootstrapping can cause optimization challenges and limit performance.
- MCMC suffers from the mode-mixing problem, i.e. probability deserts between high-value modes.
- MCMC methods are expensive when proposals are uninformed (e.g. uniform) because they waste computation on low-value samples.
- When the modes are random, GFNs should provide no added benefit.
- Lemma 5: