Optimal Transport-Guided Conditional Score-Based Diffusion Model (OTCS)

Xi'an Jiaotong University
NeurIPS 2023

Abstract

This paper investigates conditional score-based diffusion models for unpaired or partially paired data. We build OTCS, the first conditional score-based model for unpaired data. It can be applied to real-world unpaired image restoration, unpaired cross-modal medical image translation, unpaired text-to-image generation, and similar tasks. For simplicity, we use unpaired image super-resolution as an example to introduce our method.


Unpaired Super-resolution Problem

In unpaired super-resolution, we are given two sets of images: a set of low-resolution images and a set of high-resolution images. In traditional paired super-resolution, the low-resolution and high-resolution images are paired one to one. By contrast, unpaired super-resolution has no coupling relationship (e.g., a one-to-one correspondence) between the low-resolution and high-resolution images (see the figures below). Unpaired super-resolution is more realistic in real-world applications, because collecting one-to-one paired images with the same content is difficult.

./otcs/figures/images.png

The goal of unpaired super-resolution is to train a model on the unpaired datasets so that, at inference time, it outputs a high-resolution image for a given low-resolution test image.

Overall Framework of OTCS

./otcs/figures/OT_guided_SBDMs1.png

Stage I: Building Coupling Relation Using Optimal Transport

Since the low-resolution and high-resolution images are unpaired, we build the coupling relationship between them using optimal transport. We denote the low-resolution images as $\{{x}_i\}_{i=1}^m$ and the high-resolution images as $\{{y}_j\}_{j=1}^n$. We learn the coupling relationship with the dual formulation of regularized optimal transport. Specifically, we train two neural networks $u_{\omega}$ and $v_{\omega}$, with parameters $\omega$, by solving

$$\max_{\omega}\mathcal{F}_{\rm OT}\left(u_\omega,v_\omega\right)=\frac{1}{m}\sum_{i}{u_\omega({x}_i)}+\frac{1}{n}\sum_{j}{v_\omega({y}_j)}-\frac{1}{mn}\sum_{ij}{\frac{1}{4\epsilon}\left[\left(u_\omega\left({x}_i\right)+v_\omega\left({y}_j\right)-c\left({x}_i,{y}_j\right)\right)_+\right]^2},$$

where $c(x_i,y_j)$ is the transport cost, $\epsilon>0$ is the regularization parameter, and $(\cdot)_+=\max(\cdot,0)$.
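To make the dual objective concrete, here is a minimal NumPy sketch that performs gradient ascent on $\mathcal{F}_{\rm OT}$, which is concave in the potentials and bounded above but not below. For illustration it assumes a squared Euclidean cost and replaces the networks $u_\omega$ and $v_\omega$ with free (tabular) potentials, one value per training image. The function names, toy data, step size, and iteration count are hypothetical choices, not the authors' implementation.

```python
import numpy as np


def sq_cost(x, y):
    """Squared Euclidean cost c(x_i, y_j) for all pairs; shape (m, n)."""
    return ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)


def compatibility(u, v, cost, eps):
    """H(x_i, y_j) = (u_i + v_j - c_ij)_+ / (2 eps)."""
    return np.maximum(u[:, None] + v[None, :] - cost, 0.0) / (2.0 * eps)


def dual_objective(u, v, cost, eps):
    """F_OT = mean(u) + mean(v) - mean_ij[(u_i + v_j - c_ij)_+^2] / (4 eps)."""
    pos = np.maximum(u[:, None] + v[None, :] - cost, 0.0)
    return u.mean() + v.mean() - (pos ** 2).mean() / (4.0 * eps)


def fit_potentials(cost, eps, steps=10000):
    """Gradient ascent on F_OT with tabular potentials
    (a hypothetical stand-in for the networks u_omega and v_omega)."""
    m, n = cost.shape
    u, v = np.zeros(m), np.zeros(n)
    lr = 0.5 * eps  # small enough that each preconditioned step increases F_OT
    for _ in range(steps):
        H = compatibility(u, v, cost, eps)
        # dF/du_i = (1 - mean_j H_ij) / m and dF/dv_j = (1 - mean_i H_ij) / n;
        # the factors 1/m and 1/n are dropped (diagonal preconditioning).
        grad_u = 1.0 - H.mean(axis=1)
        grad_v = 1.0 - H.mean(axis=0)
        u, v = u + lr * grad_u, v + lr * grad_v
    return u, v


rng = np.random.default_rng(0)
x = rng.normal(size=(6, 2))          # toy "low-resolution" features (hypothetical)
y = rng.normal(size=(8, 2)) + 0.5    # toy "high-resolution" features (hypothetical)
C = sq_cost(x, y)
eps = 1.0

u, v = fit_potentials(C, eps)
H = compatibility(u, v, C, eps)
print(dual_objective(np.zeros(6), np.zeros(8), C, eps), dual_objective(u, v, C, eps))
print(np.abs(H.mean(axis=1) - 1.0).max(), np.abs(H.mean(axis=0) - 1.0).max())
```

At the optimum the gradients vanish, so every row and column mean of $H$ equals 1; equivalently, $\hat{\pi}=H/(mn)$ has uniform marginals $1/m$ and $1/n$. The second printed line reports how far the sketch is from that condition.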

Using the trained networks, we can compute the compatibility function $H(x_i,y_j)$ as $$ H\left(x_i,y_j\right)=\frac{1}{2\epsilon}\left(u_\omega\left(x_i\right)+v_\omega\left(y_j\right)-c\left(x_i,y_j\right)\right)_+. $$ $H(x_i,y_j)$ models the coupling relationship between $x_i$ and $y_j$. Using $H$, the optimal transport plan is given by $\hat{\pi}\left(x_i,y_j\right)=\frac{1}{mn}H\left(x_i,y_j\right)$. Images with $H>0$ are coupled; the coupled images are shown below.
./otcs/figures/guided_images.png

Coupled images ($H>0$) built using optimal transport
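The relation between the potentials, $H$, and $\hat{\pi}$ can be checked on a hand-sized example. The NumPy sketch below uses hypothetical potentials and costs for two low-resolution images (rows) and two high-resolution images (columns), computes $H$ and $\hat{\pi}=H/(mn)$, and lists the coupled pairs, meaning those with $H>0$.

```python
import numpy as np


def transport_plan(u, v, cost, eps):
    """Return H and the plan pi_hat = H / (m n) from given dual potentials."""
    m, n = cost.shape
    H = np.maximum(u[:, None] + v[None, :] - cost, 0.0) / (2.0 * eps)
    return H, H / (m * n)


# Hypothetical potentials and costs: x_0 is cheap to match with y_0, x_1 with y_1.
u = np.array([1.0, 1.0])
v = np.array([1.0, 1.0])
cost = np.array([[0.0, 3.0],
                 [3.0, 0.0]])

H, pi_hat = transport_plan(u, v, cost, eps=0.5)
coupled = [(int(i), int(j)) for i, j in zip(*np.nonzero(H > 0))]
print(H)        # H = [[2, 0], [0, 2]]
print(pi_hat)   # pi_hat = [[0.5, 0], [0, 0.5]]
print(coupled)  # [(0, 0), (1, 1)]
```

Here every row of $\hat{\pi}$ sums to $1/m=0.5$ and every column sums to $1/n=0.5$, so these toy potentials already satisfy the optimality conditions of the dual objective.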

Stage II: Training Conditional Score-Based Model

For each low-resolution image $x_i$, we compute the compatibility values $H(x_i,y_j)$ for all high-resolution images $\left\{y_j\right\}_{j=1}^n$. These give the transport plan values $\hat{\pi}\left(x_i,y_j\right)$ between $x_i$ and every $y_j$. We then randomly draw a high-resolution image, denoted $y_i$, from $\left\{y_j\right\}_{j=1}^n$ with probability proportional to $\hat{\pi}(x_i,\cdot)$. Taking $y_i$ as the initial value, we run the forward stochastic differential equation $dy_t=f\left(y_t,t\right)dt+g\left(t\right)dw$ to produce noisy images. Here $w$ is a standard Wiener process, $t$ is time, and $p_{0t}(y|y_i)$ is the transition probability from time 0 to time $t$. The noisy image $y_{i,t}$ is sampled from $p_{0t}(y|y_i)$. We train the conditional score-based model $s_\theta$ using the conditional denoising score matching loss:

$$\mathcal{J}_{\rm CDSM}\left(\theta\right)=\mathbb{E}_{t}\left[\frac{1}{m}\sum_{i}\mathbb{E}_{y_{i,t}\sim p_{0t}(\cdot|y_i)}\Vert s_{\theta}(y_{i,t}|x_i,t)-\nabla_{y_{i,t}}\log p_{0t}(y_{i,t}|y_i)\Vert_2^2\right].$$
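The Stage II training step can be sketched in NumPy under several assumptions. The forward SDE is taken to be the toy variance-exploding SDE $dy_t=\sigma^t dw$, so $f=0$, $g(t)=\sigma^t$, and $p_{0t}(y|y_i)$ is Gaussian with variance $(\sigma^{2t}-1)/(2\ln\sigma)$. The plan $\hat{\pi}$ is a hand-written toy matrix, and `zero_score` is a hypothetical placeholder for the network $s_\theta$. All helper names are illustrative, not taken from the OTCS code.

```python
import numpy as np

SIGMA = 25.0  # hypothetical scale of the toy forward SDE dy = SIGMA**t dw (f = 0)


def marginal_std(t):
    """Std of the transition p_0t(y | y0) = N(y0, std(t)^2 I) of the toy SDE."""
    return np.sqrt((SIGMA ** (2 * t) - 1.0) / (2.0 * np.log(SIGMA)))


def sample_coupled_targets(pi_hat, rng):
    """For every x_i, draw an index j with probability proportional to pi_hat(x_i, .)."""
    probs = pi_hat / pi_hat.sum(axis=1, keepdims=True)
    return np.array([rng.choice(pi_hat.shape[1], p=p) for p in probs])


def cdsm_loss(score_fn, x, y0, t, z):
    """Monte Carlo estimate of J_CDSM for a batch of coupled pairs (x_i, y_i).

    y0: coupled targets y_i, shape (B, D); t: times, shape (B,);
    z: standard normal noise, shape (B, D).
    """
    std = marginal_std(t)[:, None]
    y_t = y0 + std * z                  # y_{i,t} ~ p_0t(y | y_i)
    target = -z / std                   # grad log p_0t(y_t | y_i) = -(y_t - y_i) / std^2
    residual = score_fn(y_t, x, t) - target
    return np.mean(np.sum(residual ** 2, axis=1))  # batch mean of squared L2 norms


def zero_score(y_t, x, t):
    """Hypothetical placeholder for s_theta(y_t | x, t); a real model is a network."""
    return np.zeros_like(y_t)


rng = np.random.default_rng(0)
pi_hat = np.array([[0.5, 0.0], [0.0, 0.5]])          # toy plan: x_0 <-> y_0, x_1 <-> y_1
x = np.array([[0.0, 0.0], [1.0, 1.0]])               # toy conditions x_i
y = np.array([[1.0, 1.0, 1.0], [-1.0, -1.0, -1.0]])  # toy high-resolution images y_j
idx = sample_coupled_targets(pi_hat, rng)            # idx[i] is the index of y_i
t = rng.uniform(1e-3, 1.0, size=len(x))              # diffusion times
z = rng.standard_normal(y[idx].shape)
print(idx, cdsm_loss(zero_score, x, y[idx], t, z))
```

Because the transition density is Gaussian, the regression target reduces to $-z/\mathrm{std}(t)$. Many score-based models also weight each term by a time-dependent factor; this sketch omits that weighting to match the unweighted loss above.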

Generating Samples in Inference

With the trained conditional score-based model, we generate samples at inference time as follows. Given a degraded image $x$ to be restored, we obtain a corresponding noisy image $y_M$ by sampling from the transition probability $p_{0M}(y|x)$. Taking $y_M$ as the initial value, we solve the reverse stochastic differential equation $dy_t=\left[f\left(y_t,t\right)-g^2\left(t\right)s_\theta\left(y_t\middle| x,t\right)\right]dt+g\left(t\right)d\bar{w}$ backward in time from $t=M$ to $t=0$, where $\bar{w}$ is a reverse-time Wiener process, to generate the restored image corresponding to $x$.
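Inference can likewise be sketched with an Euler–Maruyama discretization of the reverse SDE. This is a standard solver choice assumed here, not necessarily the one used for OTCS. The sketch reuses the toy SDE $dy_t=\sigma^t dw$ and checks the sampler on a Gaussian toy problem where $y\,|\,x\sim\mathcal{N}(x,1)$. In that problem the exact conditional score $-(y-x)/(1+\mathrm{std}(t)^2)$ stands in for a trained $s_\theta$, and `t_start` plays the role of $M$.

```python
import numpy as np

SIGMA = 25.0  # hypothetical scale of the toy forward SDE dy = SIGMA**t dw (f = 0)


def diffusion_coeff(t):
    """g(t) of the toy SDE."""
    return SIGMA ** t


def marginal_std(t):
    """Std of p_0t(y | y0) = N(y0, std(t)^2 I) for the toy SDE."""
    return np.sqrt((SIGMA ** (2 * t) - 1.0) / (2.0 * np.log(SIGMA)))


def reverse_sde_sample(score_fn, x, rng, t_start=1.0, t_end=1e-3, n_steps=1000):
    """Euler-Maruyama for dy = [f - g^2 s(y | x, t)] dt + g dw_bar, run from
    t_start down to t_end, starting at y_M ~ p_0M(y | x)."""
    y = x + marginal_std(t_start) * rng.standard_normal(x.shape)
    ts = np.linspace(t_start, t_end, n_steps + 1)
    for k in range(n_steps):
        t, dt = ts[k], ts[k] - ts[k + 1]          # dt > 0 because time runs backward
        g = diffusion_coeff(t)
        y = y + g ** 2 * score_fn(y, x, t) * dt   # reverse drift (f = 0 here)
        y = y + g * np.sqrt(dt) * rng.standard_normal(y.shape)
    return y


def gaussian_score(y, x, t):
    """Exact score of p_t(y | x) when y | x ~ N(x, 1); a stand-in for s_theta."""
    return -(y - x) / (1.0 + marginal_std(t) ** 2)


rng = np.random.default_rng(0)
x = np.full((2000, 1), 3.0)           # 2000 copies of the toy condition x = 3
samples = reverse_sde_sample(gaussian_score, x, rng)
print(samples.mean(), samples.std())  # should be close to 3 and 1
```

With an exact score, the samples recover the toy conditional distribution $\mathcal{N}(3,1)$ up to discretization and Monte Carlo error. A trained $s_\theta$ would replace `gaussian_score` and take the degraded image $x$ as its condition.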

We show the generated images as follows.

./otcs/figures/qualitative_results_celeba.png

Images generated by our method, OTCS.

BibTeX

@inproceedings{Gu2023optimal,
  title={Optimal Transport-Guided Conditional Score-Based Diffusion Model},
  author={Gu, Xiang and Yang, Liwei and Sun, Jian and Xu, Zongben},
  booktitle={NeurIPS},
  year={2023}
}