Commit 2b0609b9 authored by jrichter

48 kHz model

parent b7fa075e
import glob
-from argparse import ArgumentParser
-from os.path import join
import torch
+from os import makedirs
+from os.path import join, dirname
+from argparse import ArgumentParser
from soundfile import write
from torchaudio import load
from tqdm import tqdm
from sgmse.model import ScoreModel
from sgmse.util.other import ensure_dir, pad_spec
@@ -19,6 +18,7 @@ if __name__ == '__main__':
parser.add_argument("--corrector_steps", type=int, default=1, help="Number of corrector steps")
parser.add_argument("--snr", type=float, default=0.5, help="SNR value for (annealed) Langevin dynmaics.")
parser.add_argument("--N", type=int, default=30, help="Number of reverse steps")
parser.add_argument("--format", type=str, default='default', help='Format of the directory structure. Use "default" for the default format and "ears" for the EARS format.')
args = parser.parse_args()
noisy_dir = join(args.test_dir, 'noisy/')
@@ -29,7 +29,6 @@ if __name__ == '__main__':
ensure_dir(target_dir)
# Settings
-sr = 16000
snr = args.snr
N = args.N
corrector_steps = args.corrector_steps
@@ -39,10 +38,21 @@ if __name__ == '__main__':
model.eval(no_ema=False)
model.cuda()
-noisy_files = sorted(glob.glob('{}/*.wav'.format(noisy_dir)))
+# Check format
+if args.format == 'default':
+noisy_files = sorted(glob.glob(join(noisy_dir, '*.wav')))
+sr = 16000
+pad_mode = "zero_pad"
+elif args.format == 'ears':
+noisy_files = sorted(glob.glob(join(noisy_dir, '**', '*.wav')))
+sr = 48000
+pad_mode = "reflection"
+else:
+raise ValueError('Unknown format')
for noisy_file in tqdm(noisy_files):
-filename = noisy_file.split('/')[-1]
+filename = noisy_file.replace(noisy_dir, "")[1:] # Remove the first character which is a slash
# Load wav
y, _ = load(noisy_file)
@@ -54,7 +64,7 @@ if __name__ == '__main__':
# Prepare DNN input
Y = torch.unsqueeze(model._forward_transform(model._stft(y.cuda())), 0)
-Y = pad_spec(Y)
+Y = pad_spec(Y, mode=pad_mode)
# Reverse sampling
sampler = model.get_pc_sampler(
@@ -69,4 +79,5 @@ if __name__ == '__main__':
x_hat = x_hat * norm_factor
# Write enhanced wav file
-write(join(target_dir, filename), x_hat.cpu().numpy(), 16000)
+makedirs(dirname(join(target_dir, filename)), exist_ok=True)
+write(join(target_dir, filename), x_hat.cpu().numpy(), sr)
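For illustration, here is how the new path handling behaves for the EARS layout, as a standalone sketch (all paths are made up): the relative path, including the speaker subfolder, is preserved so the enhanced files mirror the input tree.

from os import makedirs
from os.path import join, dirname

noisy_dir = '/data/EARS-WHAM/test/noisy'                    # hypothetical test directory (no trailing slash)
noisy_file = '/data/EARS-WHAM/test/noisy/p001/take_03.wav'  # EARS format: one subfolder per speaker
filename = noisy_file.replace(noisy_dir, "")[1:]            # strip prefix and leading slash -> 'p001/take_03.wav'
target_dir = '/data/EARS-WHAM/test/enhanced'
makedirs(dirname(join(target_dir, filename)), exist_ok=True)  # creates .../enhanced/p001/

Two caveats: glob.glob(join(noisy_dir, '**', '*.wav')) without recursive=True matches exactly one directory level, which suits a speaker/file layout but misses deeper nesting; and the [1:] assumes noisy_dir carries no trailing slash, whereas join(args.test_dir, 'noisy/') above keeps one, in which case replace() already removes the slash and [1:] would clip the first character of the relative path.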
from .shared import BackboneRegistry
from .ncsnpp import NCSNpp
+from .ncsnpp_48k import NCSNpp_48k
from .dcunet import DCUNet
-__all__ = ['BackboneRegistry', 'NCSNpp', 'DCUNet']
+__all__ = ['BackboneRegistry', 'NCSNpp', 'NCSNpp_48k', 'DCUNet']
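For context, the register decorator used by the new backbone presumably implements a simple name-to-class mapping; a minimal sketch of that pattern (an illustration, not the actual BackboneRegistry code, which this commit does not show):

class Registry:
    _registry = {}

    @classmethod
    def register(cls, name):
        # Decorator factory: store the decorated class under `name` so it
        # can later be looked up and instantiated from a string option.
        def wrapper(klass):
            cls._registry[name] = klass
            return klass
        return wrapper

Under such a scheme, the @BackboneRegistry.register("ncsnpp_48k") line below makes the 48 kHz backbone selectable by name, e.g. from a command-line option.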
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: skip-file
from .ncsnpp_utils import layers, layerspp, normalization
import torch.nn as nn
import functools
import torch
import numpy as np
from .shared import BackboneRegistry
ResnetBlockDDPM = layerspp.ResnetBlockDDPMpp
ResnetBlockBigGAN = layerspp.ResnetBlockBigGANpp
Combine = layerspp.Combine
conv3x3 = layerspp.conv3x3
conv1x1 = layerspp.conv1x1
get_act = layers.get_act
get_normalization = normalization.get_normalization
default_initializer = layers.default_init
@BackboneRegistry.register("ncsnpp_48k")
class NCSNpp_48k(nn.Module):
"""NCSN++ model, adapted from https://github.com/yang-song/score_sde repository"""
@staticmethod
def add_argparse_args(parser):
parser.add_argument("--ch_mult",type=int, nargs='+', default=[1,1,2,2,2,2,2])
parser.add_argument("--num_res_blocks", type=int, default=2)
parser.add_argument("--attn_resolutions", type=int, nargs='+', default=[])
parser.add_argument("--nf", type=int, default=128, help="Number of channels to use in the model")
parser.add_argument("--attn_resolutions", type=int, nargs='+', default=[])
parser.add_argument("--no-centered", dest="centered", action="store_false", help="The data is not centered [-1, 1]")
parser.add_argument("--centered", dest="centered", action="store_true", help="The data is centered [-1, 1]")
parser.add_argument("--progressive", type=str, default='none', help="Progressive downsampling method")
parser.add_argument("--progressive_input", type=str, default='none', help="Progressive upsampling method")
parser.set_defaults(centered=True)
return parser
def __init__(self,
scale_by_sigma = True,
nonlinearity = 'swish',
nf = 128,
ch_mult = (1, 1, 2, 2, 2, 2, 2),
num_res_blocks = 2,
attn_resolutions = (),
resamp_with_conv = True,
conditional = True,
fir = True,
fir_kernel = [1, 3, 3, 1],
skip_rescale = True,
resblock_type = 'biggan',
progressive = 'none',
progressive_input = 'none',
progressive_combine = 'sum',
init_scale = 0.,
fourier_scale = 16,
image_size = 256,
embedding_type = 'fourier',
dropout = .0,
centered = True,
**unused_kwargs
):
super().__init__()
self.act = act = get_act(nonlinearity)
self.nf = nf
self.num_res_blocks = num_res_blocks
self.attn_resolutions = attn_resolutions
self.num_resolutions = num_resolutions = len(ch_mult)
self.all_resolutions = all_resolutions = [image_size // (2 ** i) for i in range(num_resolutions)]
self.conditional = conditional # noise-conditional
self.centered = centered
self.scale_by_sigma = scale_by_sigma
self.skip_rescale = skip_rescale
self.resblock_type = resblock_type = resblock_type.lower()
self.progressive = progressive = progressive.lower()
self.progressive_input = progressive_input = progressive_input.lower()
self.embedding_type = embedding_type = embedding_type.lower()
assert progressive in ['none', 'output_skip', 'residual']
assert progressive_input in ['none', 'input_skip', 'residual']
assert embedding_type in ['fourier', 'positional']
combine_method = progressive_combine.lower()
combiner = functools.partial(Combine, method=combine_method)
num_channels = 4 # x.real, x.imag, y.real, y.imag
self.output_layer = nn.Conv2d(num_channels, 2, 1)
modules = []
# timestep/noise_level embedding
if embedding_type == 'fourier':
# Gaussian Fourier features embeddings.
modules.append(layerspp.GaussianFourierProjection(
embedding_size=nf, scale=fourier_scale
))
embed_dim = 2 * nf
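# Note: GaussianFourierProjection (as in score_sde) projects its scalar input
# (here the log noise level, see forward) through fixed random frequencies W and
# returns [sin(2*pi*W*t), cos(2*pi*W*t)], so the concatenated sin/cos halves
# give the 2 * nf embedding size used here.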
elif embedding_type == 'positional':
embed_dim = nf
else:
raise ValueError(f'embedding type {embedding_type} unknown.')
if conditional:
modules.append(nn.Linear(embed_dim, nf * 4))
modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
nn.init.zeros_(modules[-1].bias)
modules.append(nn.Linear(nf * 4, nf * 4))
modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
nn.init.zeros_(modules[-1].bias)
AttnBlock = functools.partial(layerspp.AttnBlockpp,
init_scale=init_scale, skip_rescale=skip_rescale)
Upsample = functools.partial(layerspp.Upsample,
with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
if progressive == 'output_skip':
self.pyramid_upsample = layerspp.Upsample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
elif progressive == 'residual':
pyramid_upsample = functools.partial(layerspp.Upsample, fir=fir,
fir_kernel=fir_kernel, with_conv=True)
Downsample = functools.partial(layerspp.Downsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
if progressive_input == 'input_skip':
self.pyramid_downsample = layerspp.Downsample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
elif progressive_input == 'residual':
pyramid_downsample = functools.partial(layerspp.Downsample,
fir=fir, fir_kernel=fir_kernel, with_conv=True)
if resblock_type == 'ddpm':
ResnetBlock = functools.partial(ResnetBlockDDPM, act=act,
dropout=dropout, init_scale=init_scale,
skip_rescale=skip_rescale, temb_dim=nf * 4)
elif resblock_type == 'biggan':
ResnetBlock = functools.partial(ResnetBlockBigGAN, act=act,
dropout=dropout, fir=fir, fir_kernel=fir_kernel,
init_scale=init_scale, skip_rescale=skip_rescale, temb_dim=nf * 4)
else:
raise ValueError(f'resblock type {resblock_type} unrecognized.')
# Downsampling block
channels = num_channels
if progressive_input != 'none':
input_pyramid_ch = channels
modules.append(conv3x3(channels, nf))
hs_c = [nf]
in_ch = nf
for i_level in range(num_resolutions):
# Residual blocks for this resolution
for i_block in range(num_res_blocks):
out_ch = nf * ch_mult[i_level]
modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
in_ch = out_ch
if all_resolutions[i_level] in attn_resolutions:
modules.append(AttnBlock(channels=in_ch))
hs_c.append(in_ch)
if i_level != num_resolutions - 1:
if resblock_type == 'ddpm':
modules.append(Downsample(in_ch=in_ch))
else:
modules.append(ResnetBlock(down=True, in_ch=in_ch))
if progressive_input == 'input_skip':
modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
if combine_method == 'cat':
in_ch *= 2
elif progressive_input == 'residual':
modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch))
input_pyramid_ch = in_ch
hs_c.append(in_ch)
in_ch = hs_c[-1]
modules.append(ResnetBlock(in_ch=in_ch))
modules.append(AttnBlock(channels=in_ch))
modules.append(ResnetBlock(in_ch=in_ch))
pyramid_ch = 0
# Upsampling block
for i_level in reversed(range(num_resolutions)):
for i_block in range(num_res_blocks + 1): # +1 blocks in upsampling because of skip connection from combiner (after downsampling)
out_ch = nf * ch_mult[i_level]
modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
in_ch = out_ch
if all_resolutions[i_level] in attn_resolutions:
modules.append(AttnBlock(channels=in_ch))
if progressive != 'none':
if i_level == num_resolutions - 1:
if progressive == 'output_skip':
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
num_channels=in_ch, eps=1e-6))
modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
pyramid_ch = channels
elif progressive == 'residual':
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
modules.append(conv3x3(in_ch, in_ch, bias=True))
pyramid_ch = in_ch
else:
raise ValueError(f'{progressive} is not a valid name.')
else:
if progressive == 'output_skip':
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
num_channels=in_ch, eps=1e-6))
modules.append(conv3x3(in_ch, channels, bias=True, init_scale=init_scale))
pyramid_ch = channels
elif progressive == 'residual':
modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
pyramid_ch = in_ch
else:
raise ValueError(f'{progressive} is not a valid name')
if i_level != 0:
if resblock_type == 'ddpm':
modules.append(Upsample(in_ch=in_ch))
else:
modules.append(ResnetBlock(in_ch=in_ch, up=True))
assert not hs_c
if progressive != 'output_skip':
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
num_channels=in_ch, eps=1e-6))
modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
self.all_modules = nn.ModuleList(modules)
def forward(self, x, time_cond):
# timestep/noise_level embedding; only for continuous training
modules = self.all_modules
m_idx = 0
# Convert real and imaginary parts of (x,y) into four channel dimensions
x = torch.cat((x[:,[0],:,:].real, x[:,[0],:,:].imag,
x[:,[1],:,:].real, x[:,[1],:,:].imag), dim=1)
if self.embedding_type == 'fourier':
# Gaussian Fourier features embeddings.
used_sigmas = time_cond
temb = modules[m_idx](torch.log(used_sigmas))
m_idx += 1
elif self.embedding_type == 'positional':
# Sinusoidal positional embeddings.
timesteps = time_cond
used_sigmas = self.sigmas[time_cond.long()]
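# NOTE: self.sigmas is not defined anywhere in this file, so the 'positional'
# branch would fail as written; with the default 'fourier' embedding it is never taken.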
temb = layers.get_timestep_embedding(timesteps, self.nf)
else:
raise ValueError(f'embedding type {self.embedding_type} unknown.')
if self.conditional:
temb = modules[m_idx](temb)
m_idx += 1
temb = modules[m_idx](self.act(temb))
m_idx += 1
else:
temb = None
if not self.centered:
# If input data is in [0, 1]
x = 2 * x - 1.
# Downsampling block
input_pyramid = None
if self.progressive_input != 'none':
input_pyramid = x
# Input layer: Conv2d: 4ch -> 128ch
hs = [modules[m_idx](x)]
m_idx += 1
# Down path in U-Net
for i_level in range(self.num_resolutions):
# Residual blocks for this resolution
for i_block in range(self.num_res_blocks):
h = modules[m_idx](hs[-1], temb)
m_idx += 1
# Attention layer (optional)
if h.shape[-2] in self.attn_resolutions: # edit: check H dim (-2) not W dim (-1)
h = modules[m_idx](h)
m_idx += 1
hs.append(h)
# Downsampling
if i_level != self.num_resolutions - 1:
if self.resblock_type == 'ddpm':
h = modules[m_idx](hs[-1])
m_idx += 1
else:
h = modules[m_idx](hs[-1], temb)
m_idx += 1
if self.progressive_input == 'input_skip': # Combine h with x
input_pyramid = self.pyramid_downsample(input_pyramid)
h = modules[m_idx](input_pyramid, h)
m_idx += 1
elif self.progressive_input == 'residual':
input_pyramid = modules[m_idx](input_pyramid)
m_idx += 1
if self.skip_rescale:
input_pyramid = (input_pyramid + h) / np.sqrt(2.)
else:
input_pyramid = input_pyramid + h
h = input_pyramid
hs.append(h)
h = hs[-1] # actually equal to: h = h
h = modules[m_idx](h, temb) # ResNet block
m_idx += 1
h = modules[m_idx](h) # Attention block
m_idx += 1
h = modules[m_idx](h, temb) # ResNet block
m_idx += 1
pyramid = None
# Upsampling block
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb)
m_idx += 1
# edit: from -1 to -2
if h.shape[-2] in self.attn_resolutions:
h = modules[m_idx](h)
m_idx += 1
if self.progressive != 'none':
if i_level == self.num_resolutions - 1:
if self.progressive == 'output_skip':
pyramid = self.act(modules[m_idx](h)) # GroupNorm
m_idx += 1
pyramid = modules[m_idx](pyramid) # Conv2D: 256 -> 4
m_idx += 1
elif self.progressive == 'residual':
pyramid = self.act(modules[m_idx](h))
m_idx += 1
pyramid = modules[m_idx](pyramid)
m_idx += 1
else:
raise ValueError(f'{self.progressive} is not a valid name.')
else:
if self.progressive == 'output_skip':
pyramid = self.pyramid_upsample(pyramid) # Upsample
pyramid_h = self.act(modules[m_idx](h)) # GroupNorm
m_idx += 1
pyramid_h = modules[m_idx](pyramid_h)
m_idx += 1
pyramid = pyramid + pyramid_h
elif self.progressive == 'residual':
pyramid = modules[m_idx](pyramid)
m_idx += 1
if self.skip_rescale:
pyramid = (pyramid + h) / np.sqrt(2.)
else:
pyramid = pyramid + h
h = pyramid
else:
raise ValueError(f'{self.progressive} is not a valid name')
# Upsampling Layer
if i_level != 0:
if self.resblock_type == 'ddpm':
h = modules[m_idx](h)
m_idx += 1
else:
h = modules[m_idx](h, temb) # Upsampling
m_idx += 1
assert not hs
if self.progressive == 'output_skip':
h = pyramid
else:
h = self.act(modules[m_idx](h))
m_idx += 1
h = modules[m_idx](h)
m_idx += 1
assert m_idx == len(modules), "Implementation error"
# Convert back to complex number
h = self.output_layer(h)
if self.scale_by_sigma:
used_sigmas = used_sigmas.reshape((x.shape[0], *([1] * len(x.shape[1:]))))
h = h / used_sigmas
h = torch.permute(h, (0, 2, 3, 1)).contiguous()
h = torch.view_as_complex(h)[:,None, :, :]
return h
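As a quick smoke test of the new backbone's input/output convention, a sketch under assumptions (shapes are made up; with the default seven-entry ch_mult, the spatial sizes should be divisible by 2**6 = 64):

import torch
from sgmse.backbones import NCSNpp_48k

net = NCSNpp_48k()                                    # defaults from __init__ above
x = torch.randn(1, 2, 256, 256, dtype=torch.cfloat)   # channel 0: state x_t, channel 1: conditioner y
t = torch.rand(1) + 0.1                               # continuous noise level, kept positive for torch.log
out = net(x, t)                                       # expected: complex tensor of shape (1, 1, 256, 256)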
@@ -56,9 +56,13 @@ def get_pc_sampler(
timesteps = torch.linspace(sde.T, eps, sde.N, device=y.device)
for i in range(sde.N):
t = timesteps[i]
+if i != len(timesteps) - 1:
+stepsize = t - timesteps[i+1]
+else:
+stepsize = timesteps[-1] # from eps to 0
vec_t = torch.ones(y.shape[0], device=y.device) * t
xt, xt_mean = corrector.update_fn(xt, vec_t, y)
-xt, xt_mean = predictor.update_fn(xt, vec_t, y)
+xt, xt_mean = predictor.update_fn(xt, vec_t, y, stepsize)
x_result = xt_mean if denoise else xt
ns = sde.N * (corrector.n_steps + 1)
return x_result, ns
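With this change the step sizes cover the whole horizon: each step spans the gap to the next grid point and the final step bridges the remaining eps down to 0. A standalone numeric check with made-up values:

import torch

T, eps, N = 1.0, 0.03, 5
timesteps = torch.linspace(T, eps, N)                    # [1.0, 0.7575, 0.515, 0.2725, 0.03]
stepsizes = [timesteps[i] - timesteps[i + 1] for i in range(N - 1)]
stepsizes.append(timesteps[-1])                          # last step: eps -> 0
assert torch.isclose(sum(stepsizes), torch.tensor(T))    # steps sum to the full interval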
@@ -57,8 +57,8 @@ class ReverseDiffusionPredictor(Predictor):
def __init__(self, sde, score_fn, probability_flow=False):
super().__init__(sde, score_fn, probability_flow=probability_flow)
-def update_fn(self, x, t, *args):
-f, g = self.rsde.discretize(x, t, *args)
+def update_fn(self, x, t, y, stepsize):
+f, g = self.rsde.discretize(x, t, y, stepsize)
z = torch.randn_like(x)
x_mean = x - f
x = x_mean + g[:, None, None, None] * z
@@ -69,7 +69,7 @@ class SDE(abc.ABC):
"""
pass
-def discretize(self, x, t, *args):
+def discretize(self, x, t, y, stepsize):
"""Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i.
Useful for reverse diffusion sampling and probability flow sampling.
@@ -82,10 +82,10 @@ class SDE(abc.ABC):
Returns:
f, G
"""
-dt = 1 / self.N
-drift, diffusion = self.sde(x, t, *args)
+dt = stepsize
+drift, diffusion = self.sde(x, t, y)
f = drift * dt
-G = diffusion * torch.sqrt(torch.tensor(dt, device=t.device))
+G = diffusion * torch.sqrt(dt)
return f, G
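This is the Euler-Maruyama rule on a non-uniform grid, x_{i+1} = x_i + f(x_i, t_i)*dt + g(t_i)*sqrt(dt)*z_i, which is why the noise term G must scale with the square root of the actual step. A minimal sketch of one such step (names made up for illustration; dt is a 0-dim tensor like the sampler's stepsize):

import torch

def euler_maruyama_step(x, drift, diffusion, dt):
    # Deterministic drift over the step plus sqrt(dt)-scaled Gaussian noise.
    z = torch.randn_like(x)
    return x + drift * dt + diffusion * torch.sqrt(dt) * z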
def reverse(oself, score_model, probability_flow=False): # note: 'oself' is deliberate, the nested reverse-SDE class defines its own 'self'
@@ -127,10 +127,10 @@ class SDE(abc.ABC):
'sde_diffusion': sde_diffusion, 'score_drift': score_drift, 'score': score,
}
-def discretize(self, x, t, *args):
+def discretize(self, x, t, y, stepsize):
"""Create discretized iteration rules for the reverse diffusion sampler."""
-f, G = discretize_fn(x, t, *args)
-rev_f = f - G[:, None, None, None] ** 2 * score_model(x, t, *args) * (0.5 if self.probability_flow else 1.)
+f, G = discretize_fn(x, t, y, stepsize)
+rev_f = f - G[:, None, None, None] ** 2 * score_model(x, t, y) * (0.5 if self.probability_flow else 1.)
rev_G = torch.zeros_like(G) if self.probability_flow else G
return rev_f, rev_G
@@ -198,6 +198,9 @@ class OUVESDE(SDE):
exp_interp = torch.exp(-theta * t)[:, None, None, None]
return exp_interp * x0 + (1 - exp_interp) * y
+def alpha(self, t):
+return torch.exp(-self.theta * t)
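The new alpha(t) is exactly the interpolation weight used by _mean above: the perturbed mean is alpha(t) * x0 + (1 - alpha(t)) * y, decaying from pure x0 at t = 0 toward the noisy signal y. A tiny numeric illustration (the theta value is made up):

import torch

theta = 1.5                      # hypothetical stiffness value
t = torch.tensor([0.0, 0.5, 5.0])
alpha = torch.exp(-theta * t)    # -> [1.000, 0.472, 0.001]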
def _std(self, t):
# This is a full solution to the ODE for P(t) in our derivations, after choosing g(s) as in self.sde()
sigma_min, theta, logsig = self.sigma_min, self.theta, self.logsig
@@ -80,16 +80,22 @@ def snr_dB(s,n):
snr_dB = 10*np.log10(s_power/n_power)
return snr_dB
-def pad_spec(Y):
+def pad_spec(Y, mode="zero_pad"):
T = Y.size(3)
if T%64 !=0:
num_pad = 64-T%64
else:
num_pad = 0
if mode == "zero_pad":
pad2d = torch.nn.ZeroPad2d((0, num_pad, 0,0))
elif mode == "reflection":
pad2d = torch.nn.ReflectionPad2d((0, num_pad, 0,0))
elif mode == "replication":
pad2d = torch.nn.ReplicationPad2d((0, num_pad, 0,0))
else:
raise NotImplementedError("This function hasn't been implemented yet.")
return pad2d(Y)
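A quick usage sketch for the extended helper (the shape is made up; the real pipeline passes complex spectrograms, but a real-valued tensor suffices to show the padding):

import torch
from sgmse.util.other import pad_spec

Y = torch.randn(1, 1, 256, 250)            # (batch, channel, freq, frames)
Y_padded = pad_spec(Y, mode="reflection")  # 250 % 64 != 0 -> pad 6 frames on the right
assert Y_padded.size(3) == 256             # padded up to the next multiple of 64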
def ensure_dir(file_path):
directory = file_path
if not os.path.exists(directory):