from fastai.basics import *
from fastai.vision.models.unet import *
from fastai.vision.all import *
from fastai.torch_basics import *
from denoising_diffusion_pytorch import Unet
import wandb
from fastai.callback.wandb import *
wandb.init(reinit=True)
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: marii. Use `wandb login --relogin` to force relogin
wandb version 0.13.4 is available! To upgrade, please run:
$ pip install wandb --upgrade
Tracking run with wandb version 0.12.21
Run data is saved locally in
/home/molly/Projects/quatro-blog/posts/diffusion/wandb/run-20221006_011034-nhpww4wt
def gather(consts: torch.Tensor, t: torch.Tensor):
    """Gather consts for $t$ and reshape to feature map shape"""
    c = consts.gather(-1, t)
    return c.reshape(-1, 1, 1, 1)
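As a quick sanity check of what `gather` does (a minimal sketch with made-up constants, not part of the training code):

```python
import torch

consts = torch.linspace(0.1, 0.5, 5)      # 5 scheduled constants, shape (5,)
t = torch.tensor([0, 2, 4])               # a batch of 3 timesteps, shape (3,)
out = gather(consts, t)                   # shape (3, 1, 1, 1), broadcastable over image batches
assert out.shape == (3, 1, 1, 1)
assert out[1].item() == consts[2].item()  # each entry is consts[t_i]
```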
class DenoiseDiffusion:
    """
    ## Denoise Diffusion
    """
    def __init__(self, eps_model: nn.Module, n_steps: int, device: torch.device):
        """
        * `eps_model` is the $\textcolor{lightgreen}{\epsilon_\theta}(x_t, t)$ model
        * `n_steps` is $T$
        * `device` is the device to place constants on
        """
        super().__init__()
        self.eps_model = eps_model
        # Create $\beta_1, \dots, \beta_T$ linearly increasing variance schedule
        self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)
        # $\alpha_t = 1 - \beta_t$
        self.alpha = 1. - self.beta
        # $\bar\alpha_t = \prod_{s=1}^t \alpha_s$
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
        # $T$
        self.n_steps = n_steps
        # $\sigma^2 = \beta$
        self.sigma2 = self.beta
    def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        #### Get $q(x_t|x_0)$ distribution
        \begin{align}
        q(x_t|x_0) &= \mathcal{N} \Big(x_t; \sqrt{\bar\alpha_t} x_0, (1-\bar\alpha_t) \mathbf{I} \Big)
        \end{align}
        """
        # [gather](utils.html) $\bar\alpha_t$ and compute $\sqrt{\bar\alpha_t} x_0$
        mean = gather(self.alpha_bar, t) ** 0.5 * tensor(x0)
        # $(1-\bar\alpha_t) \mathbf{I}$
        var = 1 - gather(self.alpha_bar, t)
        return mean, var
    def q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None):
        """
        #### Sample from $q(x_t|x_0)$
        \begin{align}
        q(x_t|x_0) &= \mathcal{N} \Big(x_t; \sqrt{\bar\alpha_t} x_0, (1-\bar\alpha_t) \mathbf{I} \Big)
        \end{align}
        """
        # $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
        if eps is None:
            eps = torch.randn_like(x0)
        # get $q(x_t|x_0)$
        mean, var = self.q_xt_x0(x0, t)
        # Sample from $q(x_t|x_0)$
        return mean + (var ** 0.5) * eps
    def p_sample(self, xt: torch.Tensor, t: torch.Tensor):
        """
        #### Sample from $\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)$
        \begin{align}
        \textcolor{lightgreen}{p_\theta}(x_{t-1} | x_t) &= \mathcal{N}\big(x_{t-1};
        \textcolor{lightgreen}{\mu_\theta}(x_t, t), \sigma_t^2 \mathbf{I} \big) \\
        \textcolor{lightgreen}{\mu_\theta}(x_t, t)
          &= \frac{1}{\sqrt{\alpha_t}} \Big(x_t -
            \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\textcolor{lightgreen}{\epsilon_\theta}(x_t, t) \Big)
        \end{align}
        """
        # $\textcolor{lightgreen}{\epsilon_\theta}(x_t, t)$
        eps_theta = self.eps_model(xt, t)
        # [gather](utils.html) $\bar\alpha_t$
        alpha_bar = gather(self.alpha_bar, t)
        # $\alpha_t$
        alpha = gather(self.alpha, t)
        # $\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}$
        eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5
        # $$\frac{1}{\sqrt{\alpha_t}} \Big(x_t -
        #   \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\textcolor{lightgreen}{\epsilon_\theta}(x_t, t) \Big)$$
        mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)
        # $\sigma^2$
        var = gather(self.sigma2, t)
        # $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
        eps = torch.randn(xt.shape, device=xt.device)
        # Sample
        return mean + (var ** .5) * eps
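Before wiring this class into fastai, a rough smoke test of both directions (a minimal sketch, assuming a stand-in `ZeroEps` model that predicts zero noise in place of the real U-Net):

```python
class ZeroEps(nn.Module):
    "Stand-in for eps_model: always predicts zero noise."
    def forward(self, x, t): return torch.zeros_like(x)

dd = DenoiseDiffusion(ZeroEps(), n_steps=1000, device=torch.device('cpu'))
x0 = torch.randn(4, 3, 32, 32)
t  = torch.full((4,), 500, dtype=torch.long)
xt = dd.q_sample(x0, t)   # forward process: noise x0 to step t in closed form
xp = dd.p_sample(xt, t)   # reverse process: one denoising step
assert xt.shape == xp.shape == x0.shape
```

With zero predicted noise, `p_sample` just rescales `xt` and re-adds $\sigma \epsilon$, which is enough to confirm that shapes and devices line up.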
class Q_sample(ItemTransform):
    "Noise the input image with `q_sample`, using the (noise) label as $\epsilon$."
    order = 101
    def __init__(self, diffusion):
        self.diffusion = diffusion
    def encodes(self, xy):
        x = xy[0]
        y = xy[-1]
        # t for each item comes from the third input, shape (bs, 1) -> (bs,)
        ts = xy[2][:, 0]  # torch.randint(0, self.diffusion.n_steps, (x.shape[0],), device=x.device, dtype=torch.long)
        x_type = type(x)
        # Cast ts to the input's tensor type so fastai type dispatch is preserved
        x = self.diffusion.q_sample(x, x_type(ts), eps=y)
        return (x, *xy[1:-1], y)
class LabelToNoise(ItemTransform):
    "Replace the label with Gaussian noise $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$."
    order = 100
    def encodes(self, xy):
        y = xy[-1]
        return (*xy[:-1], retain_type(torch.randn(y.shape, device=y.device), old=y))
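Order matters here: `LabelToNoise` (order 100) replaces the label with $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$, then `Q_sample` (order 101) mixes exactly that $\epsilon$ into the input, so the training target is the true noise. A hypothetical dry run on a fake batch (CPU only; `q_sample` never touches the model, so `eps_model=None` is fine here):

```python
dd_cpu = DenoiseDiffusion(None, 1000, torch.device('cpu'))
xy = (TensorImage(torch.randn(2, 3, 32, 32)),          # input to be noised
      TensorImage(torch.randn(2, 3, 32, 32)),          # untouched copy, for display
      TensorCategory(torch.randint(0, 1000, (2, 1))),  # t, shape (bs, 1)
      TensorImage(torch.zeros(2, 3, 32, 32)))          # label placeholder
xy = LabelToNoise()(xy)     # label -> eps ~ N(0, I)
xy = Q_sample(dd_cpu)(xy)   # input -> q_sample(input, t, eps)
x_noised, _, _, eps_target = xy
```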
def sample():
    """
    ### Sample images
    """
    with torch.no_grad():
        # $x_T \sim p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$
        x = torch.randn([n_samples, image_channels, 32, 32], device=device)
        # Remove noise for $T$ steps
        for t_ in range(n_steps):
            # $t$
            t = n_steps - t_ - 1
            # Sample from $\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)$
            x = diffusion.p_sample(x, x.new_full((n_samples,), t, dtype=torch.long))
        return x
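Note that `sample` closes over notebook globals rather than taking arguments: `n_samples`, `image_channels`, `device`, `n_steps`, and `diffusion` must all be defined before it runs. They are assigned further down, right before the real `xs = sample()` call; for reference, that call amounts to:

```python
# Values set later in the notebook (not defined yet at this point):
# n_samples, image_channels, device = 12, 3, torch.device(0)
# xs = sample()   # -> tensor of shape (n_samples, image_channels, 32, 32)
```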
n_steps = 1000
path = untar_data(URLs.MNIST)
path = untar_data(URLs.CIFAR)  # overrides the MNIST path
m = Unet(dim=32, channels=3)  # UnetTime(img_channels=1, dims=[32, 64, 128, 256, 256], ks=3, stem_stride=2).cuda()
@typedispatch
def show_batch(x:tuple, y:TensorImage, samples, ctxs=None, max_n=10, nrows=None, ncols=None, figsize=None, **kwargs):
    if ctxs is None: ctxs = get_grid(3*min(len(samples), max_n), nrows=nrows, ncols=3, figsize=figsize, title='Input/Original/Target')
    ctxs[0::3] = [b.show(ctx=c, **kwargs) for b,c,_ in zip(samples.itemgot(0), ctxs[0::3], range(max_n))]
    # Draw t (item 2) onto the same axes as the noised input (a Category shows as the axes title)
    ctxs[0::3] = [b.show(ctx=c, **kwargs) for b,c,_ in zip(samples.itemgot(2), ctxs[0::3], range(max_n))]
    ctxs[1::3] = [b.show(ctx=c, **kwargs) for b,c,_ in zip(samples.itemgot(1), ctxs[1::3], range(max_n))]
    ctxs[2::3] = [b.show(ctx=c, **kwargs) for b,c,_ in zip(samples.itemgot(3), ctxs[2::3], range(max_n))]
    return ctxs
diffusion = DenoiseDiffusion(m, n_steps, torch.device(0))
dls = DataBlock((ImageBlock(cls=PILImageBW),
                 ImageBlock(cls=PILImageBW),
                 TransformBlock(type_tfms=[DisplayedTransform(enc=lambda o: TensorCategory(o), dec=Category)]),
                 ImageBlock(cls=PILImageBW)),
                n_inp=3,
                item_tfms=[Resize(32)],
                batch_tfms=(Normalize.from_stats(0.5, 1.), LabelToNoise, Q_sample(diffusion)),
                get_items=get_image_files,
                get_x=[lambda x: x, lambda x: x,
                       lambda x: torch.randint(0, n_steps, (1,), dtype=torch.long)],
                splitter=GrandparentSplitter(train_name='training', valid_name='testing'),
               ).dataloaders(path, bs=128, val_bs=2*128)
dls.show_batch()
IndexError: list index out of range

The split comes up empty because `path` now points at CIFAR, whose folders are `train`/`test` rather than MNIST's `training`/`testing`, as `path.ls()` confirms:
path.ls()
(#3) [Path('/home/molly/data/cifar10/labels.txt'),Path('/home/molly/data/cifar10/test'),Path('/home/molly/data/cifar10/train')]
bs = 128
diffusion = DenoiseDiffusion(m, n_steps, torch.device(0))
dls = DataBlock((ImageBlock(),
                 ImageBlock(),
                 TransformBlock(type_tfms=[DisplayedTransform(enc=lambda o: TensorCategory(o), dec=Category)]),
                 ImageBlock()),
                n_inp=3,
                item_tfms=[Resize(32)],
                batch_tfms=(Normalize.from_stats(0.5, 1.), LabelToNoise, Q_sample(diffusion)),
                get_items=get_image_files,
                get_x=[lambda x: x, lambda x: x,
                       lambda x: torch.randint(0, n_steps, (1,), dtype=torch.long)],
                splitter=IndexSplitter(range(bs)),
               ).dataloaders(path, bs=bs, val_bs=2*bs)
dls.show_batch()
class FlattenCallback(Callback):
    "Flatten t from shape (bs, 1) to (bs,) before it reaches the model."
    def before_batch(self):
        self.learn.xb = (self.xb[0], self.xb[-1].view(self.xb[-1].shape[::2]),)
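The `shape[::2]` trick works because `get_x` yields `t` with shape `(bs, 1)`: taking every other element of `torch.Size([bs, 1])` leaves `torch.Size([bs])`, the 1-D timestep tensor the `denoising_diffusion_pytorch` `Unet` expects as its second argument. A quick standalone check:

```python
import torch

t = torch.zeros(128, 1, dtype=torch.long)
print(t.shape[::2])                # torch.Size([128])
print(t.view(t.shape[::2]).shape)  # torch.Size([128])
```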
learn = Learner(dls, m, MSELossFlat(), opt_func=Lamb, cbs=[FlattenCallback, WandbCallback(log_preds_every_epoch=True)])
# Leftover scratch from the custom UnetTime experiment (not applicable to the imported Unet):
# inp = m.layers[0:0]
# inp.seq_dict['t'] = torch.tensor([5]).cuda()
# m.layers[1:4].shape
learn.fit_flat_cos(6, lr=1e-4, wd=0.)
| epoch | train_loss | valid_loss | time  |
|-------|------------|------------|-------|
| 0     | 0.699252   | 0.648226   | 01:58 |
| 1     | 0.375051   | 0.354264   | 01:54 |
| 2     | 0.185453   | 0.168021   | 01:52 |
| 3     | 0.102715   | 0.080112   | 01:53 |
| 4     | 0.064207   | 0.043586   | 01:56 |
| 5     | 0.055305   | 0.053669   | 01:52 |
WandbCallback was not able to get prediction samples -> Match length mismatch
@typedispatch
def show_results(x:tuple, y:TensorImage, samples, outs, ctxs=None, max_n=10, figsize=None, **kwargs):
    if ctxs is None: ctxs = get_grid(6*min(len(samples), max_n), ncols=6, figsize=figsize, title='Input/Original/DenoisedImage/Target/Prediction/Diff')
    ctxs[0::6] = [b.show(ctx=c, **kwargs) for b,c,_ in zip(samples.itemgot(0), ctxs[0::6], range(max_n))]
    ctxs[1::6] = [b.show(ctx=c, **kwargs) for b,c,_ in zip(samples.itemgot(1), ctxs[1::6], range(max_n))]
    # Draw t (item 2) onto the input axes as a title
    ctxs[0::6] = [b.show(ctx=c, **kwargs) for b,c,_ in zip(samples.itemgot(2), ctxs[0::6], range(max_n))]
    # Denoised image: noised input minus predicted noise
    ctxs[2::6] = [(b-o).show(ctx=c, **kwargs) for b,o,c,_ in zip(samples.itemgot(0), outs.itemgot(0), ctxs[2::6], range(max_n))]
    ctxs[3::6] = [b.show(ctx=c, **kwargs) for b,c,_ in zip(samples.itemgot(3), ctxs[3::6], range(max_n))]
    ctxs[4::6] = [b.show(ctx=c, **kwargs) for b,c,_ in zip(outs.itemgot(0), ctxs[4::6], range(max_n))]
    # Diff: predicted noise minus target noise
    ctxs[5::6] = [(b-targ).show(ctx=c, **kwargs) for b,targ,c,_ in zip(outs.itemgot(0), samples.itemgot(3), ctxs[5::6], range(max_n))]
    return ctxs
learn.show_results()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). (repeated for each image shown)
n_samples = 12
image_channels = 3
diffusion = DenoiseDiffusion(m, n_steps, torch.device(0))
device = torch.device(0)
xs = sample()
show_images((logit((xs.repeat(1,3,1,1)-xs.repeat(1,3,1,1).mean())/xs.repeat(1,3,1,1).std()).sigmoid()), nrows=4)
show_images((logit((xs-xs.mean())/xs.std()).sigmoid()), nrows=4)
xs.min()
show_images(xs)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). (repeated for each image shown)
with learn.removed_cbs(WandbCallback):
    show_images(dls.one_batch()[0][:4])
learn.show_results??
dls.show_batch(show=False)[2]
@typedispatch
def wandb_process(x:tuple, y, samples, outs, preds):
    "Process `sample` and `out` depending on the type of `x/y`"
    res_input, res_pred, res_label = [], [], []
    for s, o in zip(samples, outs):
        img = s[0].permute(1, 2, 0)
        res_input.append(wandb.Image(img, caption='Input_data'))
        for t, capt, res in ((o[0], "Prediction", res_pred), (s[1], "Ground_Truth", res_label)):
            fig, ax = _make_plt(img)
            # Superimpose label or prediction on the input image
            ax = img.show(ctx=ax)
            ax = t.show(ctx=ax)
            res.append(wandb.Image(fig, caption=capt))
            plt.close(fig)
    return {"Inputs": res_input, "Predictions": res_pred, "Ground_Truth": res_label}
learn.show_results()
%debug
#| hide
from nbdev import nbdev_export
nbdev_export()