diff --git a/enhancement.py b/enhancement.py
index a51a32825c925045120e703839a07f3191fdeca0..cee35ace2f181e2ad7e4576f71c4caa1b8e2ef1e 100644
--- a/enhancement.py
+++ b/enhancement.py
@@ -1,12 +1,11 @@
 import glob
-from argparse import ArgumentParser
-from os.path import join
-
 import torch
+from os import makedirs
+from os.path import join, dirname
+from argparse import ArgumentParser
 from soundfile import write
 from torchaudio import load
 from tqdm import tqdm
-
 from sgmse.model import ScoreModel
 from sgmse.util.other import ensure_dir, pad_spec
 
@@ -19,6 +18,7 @@ if __name__ == '__main__':
     parser.add_argument("--corrector_steps", type=int, default=1, help="Number of corrector steps")
     parser.add_argument("--snr", type=float, default=0.5, help="SNR value for (annealed) Langevin dynmaics.")
     parser.add_argument("--N", type=int, default=30, help="Number of reverse steps")
+    parser.add_argument("--format", type=str, default='default', help='Directory structure of the test data: "default" (flat .wav files, 16 kHz) or "ears" (speaker subdirectories, 48 kHz).')
     args = parser.parse_args()
 
     noisy_dir = join(args.test_dir, 'noisy/')
@@ -29,7 +29,6 @@ if __name__ == '__main__':
     ensure_dir(target_dir)
 
     # Settings
-    sr = 16000
     snr = args.snr
     N = args.N
     corrector_steps = args.corrector_steps
@@ -39,10 +38,20 @@ if __name__ == '__main__':
     model.eval(no_ema=False)
     model.cuda()
 
-    noisy_files = sorted(glob.glob('{}/*.wav'.format(noisy_dir)))
+    # Check format
+    if args.format == 'default':
+        noisy_files = sorted(glob.glob(join(noisy_dir, '*.wav')))
+        sr = 16000
+        pad_mode = "zero_pad"
+    elif args.format == 'ears':
+        noisy_files = sorted(glob.glob(join(noisy_dir, '**', '*.wav')))
+        sr = 48000
+        pad_mode = "reflection"
+    else:
+        raise ValueError(f'Unknown format: {args.format}')
 
     for noisy_file in tqdm(noisy_files):
-        filename = noisy_file.split('/')[-1]
+        filename = noisy_file.replace(noisy_dir, "").lstrip("/")  # Path relative to noisy_dir, so EARS speaker subdirectories are kept
 
         # Load wav
         y, _ = load(noisy_file)
@@ -54,7 +63,7 @@ if __name__ == '__main__':
 
         # Prepare DNN input
         Y = torch.unsqueeze(model._forward_transform(model._stft(y.cuda())), 0)
-        Y = pad_spec(Y)
+        Y = pad_spec(Y, mode=pad_mode)
 
         # Reverse sampling
         sampler = model.get_pc_sampler(
@@ -69,4 +78,5 @@ if __name__ == '__main__':
         x_hat = x_hat * norm_factor
 
         # Write enhanced wav file
-        write(join(target_dir, filename), x_hat.cpu().numpy(), 16000)
+        makedirs(dirname(join(target_dir, filename)), exist_ok=True)
+        write(join(target_dir, filename), x_hat.cpu().numpy(), sr)
diff --git a/sgmse/backbones/__init__.py b/sgmse/backbones/__init__.py
index 386d0a022395bd8343e9f7810b40314cd50b4a70..508a62b40b8b623fca3563fefdf98799f71432f1 100644
--- a/sgmse/backbones/__init__.py
+++ b/sgmse/backbones/__init__.py
@@ -1,5 +1,6 @@
 from .shared import BackboneRegistry
 from .ncsnpp import NCSNpp
+from .ncsnpp_48k import NCSNpp_48k
 from .dcunet import DCUNet
 
-__all__ = ['BackboneRegistry', 'NCSNpp', 'DCUNet']
+__all__ = ['BackboneRegistry', 'NCSNpp', 'NCSNpp_48k', 'DCUNet']
diff --git a/sgmse/backbones/ncsnpp_48k.py b/sgmse/backbones/ncsnpp_48k.py
new file mode 100644
index 0000000000000000000000000000000000000000..1737baf0c8d518f59bc354635b415dbe48bffb20
--- /dev/null
+++ b/sgmse/backbones/ncsnpp_48k.py
@@ -0,0 +1,424 @@
+# coding=utf-8
+# Copyright 2020 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: skip-file
+
+from .ncsnpp_utils import layers, layerspp, normalization
+import torch.nn as nn
+import functools
+import torch
+import numpy as np
+
+from .shared import BackboneRegistry
+
+ResnetBlockDDPM = layerspp.ResnetBlockDDPMpp
+ResnetBlockBigGAN = layerspp.ResnetBlockBigGANpp
+Combine = layerspp.Combine
+conv3x3 = layerspp.conv3x3
+conv1x1 = layerspp.conv1x1
+get_act = layers.get_act
+get_normalization = normalization.get_normalization
+default_initializer = layers.default_init
+
+
+@BackboneRegistry.register("ncsnpp_48k")
+class NCSNpp_48k(nn.Module):
+    """NCSN++ model, adapted from https://github.com/yang-song/score_sde repository"""
+
+    @staticmethod
+    def add_argparse_args(parser):
+        parser.add_argument("--ch_mult", type=int, nargs='+', default=[1, 1, 2, 2, 2, 2, 2])
+        parser.add_argument("--num_res_blocks", type=int, default=2)
+        parser.add_argument("--attn_resolutions", type=int, nargs='+', default=[])
+        parser.add_argument("--nf", type=int, default=128, help="Number of channels to use in the model")
+        parser.add_argument("--no-centered", dest="centered", action="store_false", help="The data is not centered [-1, 1]")
+        parser.add_argument("--centered", dest="centered", action="store_true", help="The data is centered [-1, 1]")
+        parser.add_argument("--progressive", type=str, default='none', help="Progressive growing scheme for the model output ('none', 'output_skip', or 'residual')")
+        parser.add_argument("--progressive_input", type=str, default='none', help="Progressive growing scheme for the model input ('none', 'input_skip', or 'residual')")
+        parser.set_defaults(centered=True)
+        return parser
+
+    def __init__(self,
+        scale_by_sigma = True,
+        nonlinearity = 'swish',
+        nf = 128,
+        ch_mult = (1, 1, 2, 2, 2, 2, 2),
+        num_res_blocks = 2,
+        attn_resolutions = (),
+        resamp_with_conv = True,
+        conditional = True,
+        fir = True,
+        fir_kernel = [1, 3, 3, 1],
+        skip_rescale = True,
+        resblock_type = 'biggan',
+        progressive = 'none',
+        progressive_input = 'none',
+        progressive_combine = 'sum',
+        init_scale = 0.,
+        fourier_scale = 16,
+        image_size = 256,
+        embedding_type = 'fourier',
+        dropout = .0,
+        centered = True,
+        **unused_kwargs
+    ):
+        super().__init__()
+        self.act = act = get_act(nonlinearity)
+
+        self.nf = nf = nf
+        ch_mult = ch_mult
+        self.num_res_blocks = num_res_blocks = num_res_blocks
+        self.attn_resolutions = attn_resolutions = attn_resolutions
+        dropout = dropout
+        resamp_with_conv = resamp_with_conv
+        self.num_resolutions = num_resolutions = len(ch_mult)
+        self.all_resolutions = all_resolutions = [image_size // (2 ** i) for i in range(num_resolutions)]
+
+        self.conditional = conditional = conditional  # noise-conditional
+        self.centered = centered
+        self.scale_by_sigma = scale_by_sigma
+
+        fir = fir
+        fir_kernel = fir_kernel
+        self.skip_rescale = skip_rescale = skip_rescale
+        self.resblock_type = resblock_type = resblock_type.lower()
+        self.progressive = progressive = progressive.lower()
+        self.progressive_input = progressive_input = progressive_input.lower()
+        self.embedding_type = embedding_type = embedding_type.lower()
+        init_scale = init_scale
+        assert progressive in ['none', 'output_skip', 'residual']
+        assert progressive_input in ['none', 'input_skip', 'residual']
+        assert embedding_type in ['fourier', 'positional']
+        combine_method = progressive_combine.lower()
+        combiner = functools.partial(Combine, method=combine_method)
+
+        num_channels = 4  # x.real, x.imag, y.real, y.imag
+        self.output_layer = nn.Conv2d(num_channels, 2, 1)
+
+        modules = []
+        # timestep/noise_level embedding
+        if embedding_type == 'fourier':
+            # Gaussian Fourier features embeddings.
+            modules.append(layerspp.GaussianFourierProjection(
+                embedding_size=nf, scale=fourier_scale
+            ))
+            embed_dim = 2 * nf
+        elif embedding_type == 'positional':
+            embed_dim = nf
+        else:
+            raise ValueError(f'embedding type {embedding_type} unknown.')
+
+        if conditional:
+            modules.append(nn.Linear(embed_dim, nf * 4))
+            modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
+            nn.init.zeros_(modules[-1].bias)
+            modules.append(nn.Linear(nf * 4, nf * 4))
+            modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
+            nn.init.zeros_(modules[-1].bias)
+
+        AttnBlock = functools.partial(layerspp.AttnBlockpp,
+            init_scale=init_scale, skip_rescale=skip_rescale)
+
+        Upsample = functools.partial(layerspp.Upsample,
+            with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
+
+        if progressive == 'output_skip':
+            self.pyramid_upsample = layerspp.Upsample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
+        elif progressive == 'residual':
+            pyramid_upsample = functools.partial(layerspp.Upsample, fir=fir,
+                fir_kernel=fir_kernel, with_conv=True)
+
+        Downsample = functools.partial(layerspp.Downsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
+
+        if progressive_input == 'input_skip':
+            self.pyramid_downsample = layerspp.Downsample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
+        elif progressive_input == 'residual':
+            pyramid_downsample = functools.partial(layerspp.Downsample,
+                fir=fir, fir_kernel=fir_kernel, with_conv=True)
+
+        if resblock_type == 'ddpm':
+            ResnetBlock = functools.partial(ResnetBlockDDPM, act=act,
+                dropout=dropout, init_scale=init_scale,
+                skip_rescale=skip_rescale, temb_dim=nf * 4)
+
+        elif resblock_type == 'biggan':
+            ResnetBlock = functools.partial(ResnetBlockBigGAN, act=act,
+                dropout=dropout, fir=fir, fir_kernel=fir_kernel,
+                init_scale=init_scale, skip_rescale=skip_rescale, temb_dim=nf * 4)
+
+        else:
+            raise ValueError(f'resblock type {resblock_type} unrecognized.')
+
+        # Downsampling block
+
+        channels = num_channels
+        if progressive_input != 'none':
+            input_pyramid_ch = channels
+
+        modules.append(conv3x3(channels, nf))
+        hs_c = [nf]
+
+        in_ch = nf
+        for i_level in range(num_resolutions):
+            # Residual blocks for this resolution
+            for i_block in range(num_res_blocks):
+                out_ch = nf * ch_mult[i_level]
+                modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
+                in_ch = out_ch
+
+                if all_resolutions[i_level] in attn_resolutions:
+                    modules.append(AttnBlock(channels=in_ch))
+                hs_c.append(in_ch)
+
+            if i_level != num_resolutions - 1:
+                if resblock_type == 'ddpm':
+                    modules.append(Downsample(in_ch=in_ch))
+                else:
+                    modules.append(ResnetBlock(down=True, in_ch=in_ch))
+
+                if progressive_input == 'input_skip':
+                    modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
+                    if combine_method == 'cat':
+                        in_ch *= 2
+
+                elif progressive_input == 'residual':
+                    modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch))
+                    input_pyramid_ch = in_ch
+
+                hs_c.append(in_ch)
+
+        in_ch = hs_c[-1]
+        modules.append(ResnetBlock(in_ch=in_ch))
+        modules.append(AttnBlock(channels=in_ch))
+        modules.append(ResnetBlock(in_ch=in_ch))
+
+        pyramid_ch = 0
+        # Upsampling block
+        for i_level in reversed(range(num_resolutions)):
+            for i_block in range(num_res_blocks + 1):  # +1 blocks in upsampling because of skip connection from combiner (after downsampling)
+                out_ch = nf * ch_mult[i_level]
+                modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
+                in_ch = out_ch
+
+            if all_resolutions[i_level] in attn_resolutions:
+                modules.append(AttnBlock(channels=in_ch))
+
+            if progressive != 'none':
+                if i_level == num_resolutions - 1:
+                    if progressive == 'output_skip':
+                        modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
+                            num_channels=in_ch, eps=1e-6))
+                        modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
+                        pyramid_ch = channels
+                    elif progressive == 'residual':
+                        modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
+                        modules.append(conv3x3(in_ch, in_ch, bias=True))
+                        pyramid_ch = in_ch
+                    else:
+                        raise ValueError(f'{progressive} is not a valid name.')
+                else:
+                    if progressive == 'output_skip':
+                        modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
+                            num_channels=in_ch, eps=1e-6))
+                        modules.append(conv3x3(in_ch, channels, bias=True, init_scale=init_scale))
+                        pyramid_ch = channels
+                    elif progressive == 'residual':
+                        modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
+                        pyramid_ch = in_ch
+                    else:
+                        raise ValueError(f'{progressive} is not a valid name.')
+
+            if i_level != 0:
+                if resblock_type == 'ddpm':
+                    modules.append(Upsample(in_ch=in_ch))
+                else:
+                    modules.append(ResnetBlock(in_ch=in_ch, up=True))
+
+        assert not hs_c
+
+        if progressive != 'output_skip':
+            modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
+                num_channels=in_ch, eps=1e-6))
+            modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
+
+        self.all_modules = nn.ModuleList(modules)
+
+
+    def forward(self, x, time_cond):
+        # timestep/noise_level embedding; only for continuous training
+        modules = self.all_modules
+        m_idx = 0
+
+        # Convert real and imaginary parts of (x,y) into four channel dimensions
+        x = torch.cat((x[:,[0],:,:].real, x[:,[0],:,:].imag,
+            x[:,[1],:,:].real, x[:,[1],:,:].imag), dim=1)
+
+        if self.embedding_type == 'fourier':
+            # Gaussian Fourier features embeddings.
+            used_sigmas = time_cond
+            temb = modules[m_idx](torch.log(used_sigmas))
+            m_idx += 1
+
+        elif self.embedding_type == 'positional':
+            # Sinusoidal positional embeddings.
+            timesteps = time_cond
+            used_sigmas = self.sigmas[time_cond.long()]
+            temb = layers.get_timestep_embedding(timesteps, self.nf)
+
+        else:
+            raise ValueError(f'embedding type {self.embedding_type} unknown.')
+
+        if self.conditional:
+            temb = modules[m_idx](temb)
+            m_idx += 1
+            temb = modules[m_idx](self.act(temb))
+            m_idx += 1
+        else:
+            temb = None
+
+        if not self.centered:
+            # If input data is in [0, 1]
+            x = 2 * x - 1.
+
+        # Downsampling block
+        input_pyramid = None
+        if self.progressive_input != 'none':
+            input_pyramid = x
+
+        # Input layer: Conv2d: 4ch -> 128ch
+        hs = [modules[m_idx](x)]
+        m_idx += 1
+
+        # Down path in U-Net
+        for i_level in range(self.num_resolutions):
+            # Residual blocks for this resolution
+            for i_block in range(self.num_res_blocks):
+                h = modules[m_idx](hs[-1], temb)
+                m_idx += 1
+                # Attention layer (optional)
+                if h.shape[-2] in self.attn_resolutions:  # edit: check H dim (-2) not W dim (-1)
+                    h = modules[m_idx](h)
+                    m_idx += 1
+                hs.append(h)
+
+            # Downsampling
+            if i_level != self.num_resolutions - 1:
+                if self.resblock_type == 'ddpm':
+                    h = modules[m_idx](hs[-1])
+                    m_idx += 1
+                else:
+                    h = modules[m_idx](hs[-1], temb)
+                    m_idx += 1
+
+                if self.progressive_input == 'input_skip':  # Combine h with x
+                    input_pyramid = self.pyramid_downsample(input_pyramid)
+                    h = modules[m_idx](input_pyramid, h)
+                    m_idx += 1
+
+                elif self.progressive_input == 'residual':
+                    input_pyramid = modules[m_idx](input_pyramid)
+                    m_idx += 1
+                    if self.skip_rescale:
+                        input_pyramid = (input_pyramid + h) / np.sqrt(2.)
+                    else:
+                        input_pyramid = input_pyramid + h
+                    h = input_pyramid
+                hs.append(h)
+
+        h = hs[-1]  # actually equal to: h = h
+        h = modules[m_idx](h, temb)  # ResNet block
+        m_idx += 1
+        h = modules[m_idx](h)  # Attention block
+        m_idx += 1
+        h = modules[m_idx](h, temb)  # ResNet block
+        m_idx += 1
+
+        pyramid = None
+
+        # Upsampling block
+        for i_level in reversed(range(self.num_resolutions)):
+            for i_block in range(self.num_res_blocks + 1):
+                h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb)
+                m_idx += 1
+
+            # edit: check dim -2 (frequency) instead of -1 (time), as in the down path
+            if h.shape[-2] in self.attn_resolutions:
+                h = modules[m_idx](h)
+                m_idx += 1
+
+            if self.progressive != 'none':
+                if i_level == self.num_resolutions - 1:
+                    if self.progressive == 'output_skip':
+                        pyramid = self.act(modules[m_idx](h))  # GroupNorm
+                        m_idx += 1
+                        pyramid = modules[m_idx](pyramid)  # Conv2D: 256 -> 4
+                        m_idx += 1
+                    elif self.progressive == 'residual':
+                        pyramid = self.act(modules[m_idx](h))
+                        m_idx += 1
+                        pyramid = modules[m_idx](pyramid)
+                        m_idx += 1
+                    else:
+                        raise ValueError(f'{self.progressive} is not a valid name.')
+                else:
+                    if self.progressive == 'output_skip':
+                        pyramid = self.pyramid_upsample(pyramid)  # Upsample
+                        pyramid_h = self.act(modules[m_idx](h))  # GroupNorm
+                        m_idx += 1
+                        pyramid_h = modules[m_idx](pyramid_h)
+                        m_idx += 1
+                        pyramid = pyramid + pyramid_h
+                    elif self.progressive == 'residual':
+                        pyramid = modules[m_idx](pyramid)
+                        m_idx += 1
+                        if self.skip_rescale:
+                            pyramid = (pyramid + h) / np.sqrt(2.)
+                        else:
+                            pyramid = pyramid + h
+                        h = pyramid
+                    else:
+                        raise ValueError(f'{self.progressive} is not a valid name.')
+
+            # Upsampling Layer
+            if i_level != 0:
+                if self.resblock_type == 'ddpm':
+                    h = modules[m_idx](h)
+                    m_idx += 1
+                else:
+                    h = modules[m_idx](h, temb)  # Upsampling
+                    m_idx += 1
+
+        assert not hs
+
+        if self.progressive == 'output_skip':
+            h = pyramid
+        else:
+            h = self.act(modules[m_idx](h))
+            m_idx += 1
+            h = modules[m_idx](h)
+            m_idx += 1
+
+        assert m_idx == len(modules), "Implementation error"
+
+        # Convert back to complex number
+        h = self.output_layer(h)
+
+        if self.scale_by_sigma:
+            used_sigmas = used_sigmas.reshape((x.shape[0], *([1] * len(x.shape[1:]))))
+            h = h / used_sigmas
+
+        h = torch.permute(h, (0, 2, 3, 1)).contiguous()
+        h = torch.view_as_complex(h)[:, None, :, :]
+        return h
diff --git a/sgmse/sampling/__init__.py b/sgmse/sampling/__init__.py
index 0046a1c1948087d6d97871c4d251a6d2c8daf075..44ee62a5e02b3e3ee0900469f5150f4e22be74d0 100644
--- a/sgmse/sampling/__init__.py
+++ b/sgmse/sampling/__init__.py
@@ -56,9 +56,13 @@ def get_pc_sampler(
             timesteps = torch.linspace(sde.T, eps, sde.N, device=y.device)
             for i in range(sde.N):
                 t = timesteps[i]
+                if i != len(timesteps) - 1:
+                    stepsize = t - timesteps[i+1]
+                else:
+                    stepsize = timesteps[-1]  # the final step integrates from t=eps down to 0, so the step size is eps
                 vec_t = torch.ones(y.shape[0], device=y.device) * t
                 xt, xt_mean = corrector.update_fn(xt, vec_t, y)
-                xt, xt_mean = predictor.update_fn(xt, vec_t, y)
+                xt, xt_mean = predictor.update_fn(xt, vec_t, y, stepsize)
             x_result = xt_mean if denoise else xt
             ns = sde.N * (corrector.n_steps + 1)
             return x_result, ns
diff --git a/sgmse/sampling/predictors.py b/sgmse/sampling/predictors.py
index 93437d9fe3398494fbda113428292c63dbd98d02..d5520ddfd7bf54cdedc175f44e821370c190b1e8 100644
--- a/sgmse/sampling/predictors.py
+++ b/sgmse/sampling/predictors.py
@@ -57,8 +57,8 @@ class ReverseDiffusionPredictor(Predictor):
     def __init__(self, sde, score_fn, probability_flow=False):
         super().__init__(sde, score_fn, probability_flow=probability_flow)
 
-    def update_fn(self, x, t, *args):
-        f, g = self.rsde.discretize(x, t, *args)
+    def update_fn(self, x, t, y, stepsize):
+        f, g = self.rsde.discretize(x, t, y, stepsize)
         z = torch.randn_like(x)
         x_mean = x - f
         x = x_mean + g[:, None, None, None] * z
diff --git a/sgmse/sdes.py b/sgmse/sdes.py
index affa02e14e47401c14e3a706480c7bfb2bb8403d..6c6a230945b0e5dbc30e5f0896131cb0d1ba9bba 100644
--- a/sgmse/sdes.py
+++ b/sgmse/sdes.py
@@ -69,7 +69,7 @@ class SDE(abc.ABC):
         """
         pass
 
-    def discretize(self, x, t, *args):
+    def discretize(self, x, t, y, stepsize):
         """Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i.
 
         Useful for reverse diffusion sampling and probabiliy flow sampling.
@@ -82,10 +82,10 @@
         Returns:
             f, G
         """
-        dt = 1 / self.N
-        drift, diffusion = self.sde(x, t, *args)
+        dt = stepsize
+        drift, diffusion = self.sde(x, t, y)
         f = drift * dt
-        G = diffusion * torch.sqrt(torch.tensor(dt, device=t.device))
+        G = diffusion * torch.sqrt(dt)
         return f, G
 
     def reverse(oself, score_model, probability_flow=False):
@@ -127,10 +127,10 @@ class SDE(abc.ABC):
                 'sde_diffusion': sde_diffusion, 'score_drift': score_drift, 'score': score,
             }
 
-        def discretize(self, x, t, *args):
+        def discretize(self, x, t, y, stepsize):
             """Create discretized iteration rules for the reverse diffusion sampler."""
-            f, G = discretize_fn(x, t, *args)
-            rev_f = f - G[:, None, None, None] ** 2 * score_model(x, t, *args) * (0.5 if self.probability_flow else 1.)
+            f, G = discretize_fn(x, t, y, stepsize)
+            rev_f = f - G[:, None, None, None] ** 2 * score_model(x, t, y) * (0.5 if self.probability_flow else 1.)
             rev_G = torch.zeros_like(G) if self.probability_flow else G
             return rev_f, rev_G
 
@@ -198,6 +198,9 @@ class OUVESDE(SDE):
         exp_interp = torch.exp(-theta * t)[:, None, None, None]
         return exp_interp * x0 + (1 - exp_interp) * y
 
+    def alpha(self, t):
+        return torch.exp(-self.theta * t)
+
     def _std(self, t):
         # This is a full solution to the ODE for P(t) in our derivations, after choosing g(s) as in self.sde()
         sigma_min, theta, logsig = self.sigma_min, self.theta, self.logsig
diff --git a/sgmse/util/other.py b/sgmse/util/other.py
index 565b1cee67fcdabf667f27784db4f9db14982c4c..fbe5052f348def7eb05ed4682cfed3700f8e2ab0 100644
--- a/sgmse/util/other.py
+++ b/sgmse/util/other.py
@@ -80,16 +80,22 @@ def snr_dB(s,n):
     snr_dB = 10*np.log10(s_power/n_power)
     return snr_dB
 
-def pad_spec(Y):
+def pad_spec(Y, mode="zero_pad"):
     T = Y.size(3)
     if T%64 !=0:
        num_pad = 64-T%64
    else:
        num_pad = 0
-    pad2d = torch.nn.ZeroPad2d((0, num_pad, 0,0))
+    if mode == "zero_pad":
+        pad2d = torch.nn.ZeroPad2d((0, num_pad, 0,0))
+    elif mode == "reflection":
+        pad2d = torch.nn.ReflectionPad2d((0, num_pad, 0,0))
+    elif mode == "replication":
+        pad2d = torch.nn.ReplicationPad2d((0, num_pad, 0,0))
+    else:
+        raise NotImplementedError(f"Padding mode '{mode}' is not implemented.")
     return pad2d(Y)
 
-
 def ensure_dir(file_path):
     directory = file_path
     if not os.path.exists(directory):
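Usage note (appended for review; not part of the patch): a minimal sketch of how the extended pad_spec is expected to behave, assuming PyTorch and the patched sgmse package are importable. The tensor below is a hypothetical, real-valued stand-in for a batch of spectrograms with shape (batch, channels, freq, time).

import torch
from sgmse.util.other import pad_spec

# 300 time frames is not a multiple of 64, so pad_spec pads the last (time) axis up to 320.
Y = torch.randn(1, 2, 256, 300)

Y_zero = pad_spec(Y, mode="zero_pad")    # default behaviour, used for the 16 kHz "default" format
Y_refl = pad_spec(Y, mode="reflection")  # mode selected in enhancement.py for the 48 kHz EARS format

print(Y_zero.shape, Y_refl.shape)  # both: torch.Size([1, 2, 256, 320])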