class TensorNoise(TensorImageBase):
    "Tensor type that tags noise, so transforms can dispatch on it."


class TensorStep(TensorBase):
    "Tensor type that tags the diffusion timestep."
Add explanation of why subclassing TensorImageBase, and what is TypeDispatch. (title needs to be full name)
We create two new Tensor Types, one for our noise, and one for our timestep.
We would like `Normalize` to denormalize our noise before showing it, so that the noise we display looks similar to the noise in our noised image.
@Normalize
def decodes(self, x:TensorNoise):
    "Denormalize a `TensorNoise` as `x*std + mean`."
    # Move the stats to the CPU when `x` lives there; otherwise use them as they are.
    f = to_cpu if x.device.type=='cpu' else noop
    return (x*f(self.std) + f(self.mean))
# Build a normalizer from the ImageNet statistics and display the decoded `img`.
norm = Normalize.from_stats(*imagenet_stats)
show_images(norm.decode(img))
I patch ItemTransform here so that it can work off of typed tuples. Essentially, if we have a DiffusionTuple, the transform is applied to the tuple as a whole whenever it has an implementation for that type of tuple.
class ItemTransform(Transform):
    "A transform that always take tuples as items"
    _retain = True
    # Only showing important code
    def _call1(self:ItemTransform, x, name, **kwargs):
        "Apply method `name` to `x`; tuples go through `_call_tuple`, anything else through the parent."
        if not _is_tuple(x): return getattr(super(), name)(x, **kwargs)
        y = self._call_tuple(name, x, **kwargs)
        if not self._retain: return y
        # Turn list-like results back into a tuple before restoring the type of `x`.
        if is_listy(y) and not isinstance(y, tuple): y = tuple(y)
        return retain_type(y, x)

    def _call_tuple(self:ItemTransform, name, x, split_idx=None, **kwargs):
        "Call the implementation registered for `type(x)` if there is one, else the parent method on `list(x)`."
        f = getattr(super(), name)
        # Map the public entry point to the dispatch table that holds its implementations.
        f2name = 'encodes' if name == '__call__' else 'decodes' if name == 'decode' else name
        f2 = getattr(self, f2name)
        if isinstance(f2, TypeDispatch) and f2[type(x)] is not None:
            # Leave `x` untouched when this transform is restricted to a different split.
            if split_idx != self.split_idx and self.split_idx is not None: return x
            y = f2(x, **kwargs)
        else:
            y = f(list(x), **kwargs)
        return y
The general idea is to implement a named tuple, and use duck typing. In the future, we should look at the named tuple class and do something more similar to that.
class DiffusionTuple(fastuple):
    "A `fastuple` whose elements are also exposed as `x`, `x0`, `t`, `y`, `pred` and `sampled_pred`."
    def __new__(cls, *rest):
        self = super().__new__(cls, *rest)
        # `i` tracks the index of the last element that has been given a name.
        i = 0
        self.x = self[i]
        # `x0` is only present when the element right after `x` is a `TensorImage`.
        if isinstance(self[i+1], TensorImage): self.x0 = self[(i := i+1)]
        self.t = self[(i := i+1)]
        # Remaining elements are optional and are named in this fixed order.
        if len(self) > i+1: self.y = self[(i := i+1)]
        if len(self) > i+1: self.pred = self[(i := i+1)]
        if len(self) > i+1: self.sampled_pred = self[(i := i+1)]
        return self
A little transform to make our tuple a DiffusionTuple
class ToDiffusionTuple(ItemTransform):
    "Wrap an item in a `DiffusionTuple`, casting its last element to `TensorNoise`."
    order = 100
    def encodes(self, xy):
        return DiffusionTuple(*xy[:-1], TensorNoise(xy[-1]))
This Transform
expects y to contain an image, and just replaces it with noise. Our model tries to predict the noise in an image.
class LabelToNoise(ItemTransform):
    "Overwrite the `y` element of a `DiffusionTuple` with standard normal noise of the same shape."
    order = 101
    def encodes(self, xy:DiffusionTuple):
        y = xy.y
        # Written in place, so the tensor already held by the tuple is the one that changes.
        xy.y[:] = TensorNoise(torch.randn_like(y))
        return xy
# Example tuple: a copy of the image, timestep 200, and a target that is replaced by noise.
diff_tuple=LabelToNoise.encodes(DiffusionTuple(img[0].detach().clone(),TensorStep(torch.tensor([[200]])),TensorNoise(img[0].clone())))
We can access tuple elements by attribute. This is useful when you don’t know at which index a particular value is located.
diff_tuple.x.shape,diff_tuple.t.shape,diff_tuple.y.shape
(torch.Size([3, 320, 480]), torch.Size([1, 1]), torch.Size([3, 320, 480]))
Now we have a way to create an image, and convert the label to noise.
norm.decode(diff_tuple).show(show_noise=True)
<AxesSubplot:title={'center':'TensorStep([[200]])'}>
Next, we need to go create a noised image, to pass to our model.
First, how much noise to apply to each step?
class LinearNoiseSchedule:
    "Schedule like used in DDPM"
    def __init__(self,betas=None,n_steps=None,device='cuda'):
        """Store `betas`, `alphas = 1 - betas` and `alpha_bar = cumprod(alphas)`.

        betas:   optional 1-D tensor of betas; when given it is used as is and
                 `n_steps` is taken from its length.
        n_steps: number of steps for the default linear schedule (1000 if None).
        device:  device on which the default linear schedule is created.
        """
        if betas is not None:
            # Explicit betas define the schedule; they are not moved to `device`.
            self.betas = betas
            self.n_steps = betas.shape[0]
        else:
            self.n_steps = 1000 if n_steps is None else n_steps
            self.betas = torch.linspace(0.0001, 0.02, self.n_steps, device=device)
        self.alphas = 1. - self.betas
        self.alpha_bar = torch.cumprod(self.alphas, dim=0)
Let's graph the various values here, in order to see what happens. Pay particularly close attention to alpha_bar, as that controls the balance between our signal (image) and our noise.
lns=LinearNoiseSchedule()
# Plot the three schedule tensors on one set of axes.
plt.plot((lns.betas).cpu())
plt.plot((lns.alphas).cpu())
plt.plot((lns.alpha_bar).cpu())
plt.legend(['betas', 'alphas','alpha_bar'])
<matplotlib.legend.Legend at 0x7f89721ee710>
Next is DDPM-style Q-sampling. This is used for pretty much all diffusion models, and is the process that takes us from an image to noise.
class DDPM_Q_Sampling():
    "DDPM-style Q-sampling: blend an image with noise using `alpha_bar` at a given step."
    def __init__(self,predicts_x=False,noise_schedule=None,n_steps=1000,device='cuda'):
        # Build the default schedule per instance, not once when the class is defined.
        if noise_schedule is None: noise_schedule=LinearNoiseSchedule()
        self.device=device
        self.predicts_x=predicts_x
        self.ns=noise_schedule
        self.n_steps=n_steps
        # Map each of the `n_steps` steps to an index into the schedule. The three
        # trailing unit axes are presumably there to broadcast against image batches.
        self.t_sched=torch.linspace(0,len(self.ns.alpha_bar)-1,n_steps,dtype=torch.long)[...,None,None,None]

    def __call__(self,x,es,t):
        "Return `sqrt(a)*x + sqrt(1-a)*es`, where `a` is `alpha_bar` at step `t`."
        t=self.t_sched[t]
        a=self.ns.alpha_bar[t].to(device=x.device)
        signal = (a ** .5)*x
        noise = (1-a)**.5 * es
        return signal + noise

    def undo(self,z,es,t):
        "Goes back to the original image given noise. Only works if es is the original noise. If es is a TensorImage, assumes it is the original."
        if(isinstance(es,TensorImage)):
            return es
        t=self.t_sched[t]
        a=TensorBase(self.ns.alpha_bar[t].to(device=z.device))
        noise=TensorBase((1-a)**.5 * es)
        # Invert `__call__`: remove the noise term, then undo the signal scaling.
        return TensorImage((z-noise)/(a ** .5))
# Only the Q-sampling side is exercised here; an identity function is passed as the second argument.
diff_trans = DiffusionSamplingTransform(DDPM_Q_Sampling(),lambda x:x)
norm.decode(diff_trans(diff_tuple)).show()
<AxesSubplot:title={'center':'TensorStep([[200]])'}>
Let's now test to make sure our noise is being generated correctly.
# Same example as before, but at timestep 999.
noise_tuple=LabelToNoise.encodes(DiffusionTuple(img[0].detach().clone(),TensorStep(torch.tensor([[999]])),TensorNoise(img[0].clone())))
norm.decode(diff_trans(noise_tuple)).show(show_noise=True)
<AxesSubplot:title={'center':'TensorStep([[999]])'}>
These are not exactly the same as it is one noising step, but they are fairly close.
# Compare element 0 of the decoded tuple with element 2, using a loose tolerance.
is_close(norm.decode(diff_trans(noise_tuple))[0],TensorImage(norm.decode(diff_trans(noise_tuple))[2]),eps=1e-02)
TensorImage(True, device='cuda:0')
Going from noise to an image, p_sampling
P sampling is going from noise to image!
class Diffusion_P_Sampler():
def __init__(self,model,sampling_function):
self.device=sampling_function.device
self.model=model
self.sampling_function=sampling_function
# __call__ implemented, but not shown.
def iter_noise(self,x_t,ts,t_start):
=0
iwhile((ts>0).any()):
=x_t[ts>0],ts[ts>0]
x,twith autocast(device_type=self.device, dtype=x.dtype):
with torch.no_grad():
= self.model(x,self.deconvert(t) if i!=0 else t_start)
e >0]=self.sampling_function(x,e,t,t=t_start if i==0 else None)
x_t[ts>0]-=1
ts[ts+=1
iyield x_t
Notice here that we generate noise as random numbers during our sampling process. This makes DDPM sampling a stochastic process.
@patch
def __call__(self:DDPM_P_Sampling,x,es,ns_t,t=None):
    "One DDPM sampling step: remove the scaled predicted noise `es` from `x`, then add fresh random noise."
    # Use the scheduled timestep unless an explicit `t` is supplied.
    t = self.t_sched[ns_t] if(t is None) else t[...,None,None,None]
    n=torch.randn_like(x)
    e,a,b=self._noise_at_t(es,t),self.ns.alphas[t],self.ns.betas[t]
    signal = (x - e) / (a ** 0.5)
    noise = b**.5 * n
    return signal + noise
@patch
def _noise_at_t(self:DDPM_P_Sampling,es,t):
    "Scale `es` by `(1 - alphas[t]) / sqrt(1 - alpha_bar[t])`."
    eps_coef = (1 - self.ns.alphas[t]) / (1 - self.ns.alpha_bar[t]) ** .5
    return eps_coef* es
We implement DDIM sampling here, as it drastically reduces sampling time from 1000 steps to 50. DDIM is also a deterministic sampler, meaning we do not have random numbers generated as part of our sampling process. Just generally helps us keep our sanity when trying to show our results. https://arxiv.org/abs/2010.02502
@patch
def __call__(self:DDIM_P_Sampling,z,es,ns_t,t=None):
    "One DDIM step: estimate the image `xs` and noise `es` at step `t`, then re-mix them with `alpha_bar` of the preceding step."
    if(t is None): t=self.t_sched[ns_t]
    # Despite the name, `tp1` is looked up at `ns_t-1`, the step being moved to.
    tp1=self.t_sched[ns_t-1]
    a,a_tp1=self.ns.alpha_bar[t][...,None,None,None],self.ns.alpha_bar[tp1][...,None,None,None]
    if self.predicts_x:
        # The model output is the image itself; derive the matching noise from it.
        xs=es
        es=(z - (a)**.5 * xs)/(1-a)**.5
    else:
        # The model output is the noise; derive the matching image from it.
        xs=(z - (1-a)**.5 * es)/ (a ** .5)
    signal = a_tp1**.5*(xs)
    noise = (1-a_tp1)**.5*es
    return signal + noise
Training a model
path = untar_data(URLs.CIFAR)
m=Unet(dim=192+192//8,channels=3,).cuda()
bs=128
n_steps=1000
diffusion_transform = DiffusionSamplingTransform(DDPM_Q_Sampling(),Diffusion_P_Sampler(m,DDIM_P_Sampling()))
# Three blocks: the image, a timestep, and the target. With `n_inp=2` the first two are inputs.
dls=DataBlock((ImageBlock(),
               TransformBlock(type_tfms=[Transform(enc=lambda o: TensorStep(o))]),
               ImageBlock()),
              n_inp=2,
              item_tfms=[Resize(32)],
              batch_tfms=(Normalize.from_stats(*cifar_stats),ToDiffusionTuple,LabelToNoise,diffusion_transform),
              get_items=get_image_files,
              # The second getter ignores the item and draws a random timestep in [1, n_steps).
              get_x=[lambda x:x,
                     lambda x: torch.randint(1, n_steps, (1,), dtype=torch.long)],
              splitter=IndexSplitter(range(bs)),
).dataloaders(path,bs=bs,val_bs=2*bs)
dls.show_batch()
/home/molly/miniconda3/envs/fastai/lib/python3.10/site-packages/torch/_tensor.py:1121: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
ret = func(*args, **kwargs)
def mse_loss_weighted(ys,targ):
    "Mean of the squared error between `ys` and `targ`, with each item scaled by its weight in `targ.w_sched`."
    # Flatten everything after the first axis so the per-item weights broadcast across it.
    sq_err = (ys - targ).flatten(start_dim=1) ** 2
    return torch.mean(targ.w_sched[...,None] * sq_err)
def snr(at):
    "Return the ratio `at/(1-at)`."
    remainder = 1-at
    return at/remainder
Taken from the paper “Perception Prioritized Training of Diffusion Models”, this is a continuous version of the weights: our signal-to-noise ratio over the change in our signal-to-noise ratio. https://arxiv.org/abs/2204.00227
def continuous_weights(at):
    "Weights `-snr(at[1:]) / (snr(at[1:]) - snr(at[:-1]))`, one per entry of `at`."
    snr_next = snr(at[1:])
    weights = -snr_next/(snr_next-snr(at[:-1]))
    # The difference is one element shorter than `at`; repeat the first weight to match its length.
    return torch.cat((weights[0:1],weights))
This is what the weights look like. Notice I clip the weights at 1.
plt.plot(continuous_weights(LinearNoiseSchedule().alpha_bar).cpu().clip(min=1))
class WeightedLinSched(Callback):
    "After each prediction, attach per-item loss weights to the target as `w_sched`."
    def after_pred(self):
        # Build the weight table once: clip at a minimum of 1, then normalise to mean 1.
        if(not hasattr(self,'ws')):
            self.ws = continuous_weights(LinearNoiseSchedule().alpha_bar).clip(min=1)
            self.ws /= self.ws.mean()
        # The second input holds the timesteps; they index into the weight table.
        ts=self.learn.xb[1].flatten()
        self.learn.yb[0].w_sched=self.ws[ts]
learn = Learner(dls,m,mse_loss_weighted,opt_func=Adam,cbs=[FlattenCallback,WeightedLinSched])
learn = learn.to_fp16()
learn.fit_flat_cos(10,lr=2e-4,wd=1e-4)
epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | 0.049204 | 0.042475 | 04:47 |
1 | 0.043398 | 0.044983 | 04:48 |
2 | 0.042191 | 0.034068 | 04:49 |
3 | 0.040477 | 0.038086 | 04:48 |
4 | 0.039220 | 0.038858 | 04:49 |
5 | 0.038735 | 0.032376 | 04:49 |
6 | 0.038864 | 0.029669 | 04:49 |
7 | 0.038684 | 0.041698 | 04:49 |
8 | 0.037839 | 0.031916 | 04:50 |
9 | 0.037525 | 0.041718 | 04:50 |
next check show_results
learn.show_results()
learn.show_results(show_noise=True)
channels last, fused optimizers(foreach benjamin’s), jit model.