diff --git a/enhancement.py b/enhancement.py
index a51a32825c925045120e703839a07f3191fdeca0..cee35ace2f181e2ad7e4576f71c4caa1b8e2ef1e 100644
--- a/enhancement.py
+++ b/enhancement.py
@@ -1,12 +1,11 @@
 import glob
-from argparse import ArgumentParser
-from os.path import join
-
 import torch
+from os import makedirs
+from os.path import join, dirname
+from argparse import ArgumentParser
 from soundfile import write
 from torchaudio import load
 from tqdm import tqdm
-
 from sgmse.model import ScoreModel
 from sgmse.util.other import ensure_dir, pad_spec
 
@@ -19,6 +18,7 @@ if __name__ == '__main__':
     parser.add_argument("--corrector_steps", type=int, default=1, help="Number of corrector steps")
     parser.add_argument("--snr", type=float, default=0.5, help="SNR value for (annealed) Langevin dynmaics.")
     parser.add_argument("--N", type=int, default=30, help="Number of reverse steps")
+    parser.add_argument("--format", type=str, default='default', help='Format of the directory structure. Use "default" for the default format and "ears" for the EARS format.')
     args = parser.parse_args()
 
     noisy_dir = join(args.test_dir, 'noisy/')
@@ -29,7 +29,6 @@ if __name__ == '__main__':
     ensure_dir(target_dir)
 
     # Settings
-    sr = 16000
     snr = args.snr
     N = args.N
     corrector_steps = args.corrector_steps
@@ -39,10 +38,21 @@ if __name__ == '__main__':
     model.eval(no_ema=False)
     model.cuda()
 
-    noisy_files = sorted(glob.glob('{}/*.wav'.format(noisy_dir)))
+    # Check format
+    if args.format == 'default':
+        noisy_files = sorted(glob.glob(join(noisy_dir, '*.wav')))
+        sr = 16000
+        pad_mode = "zero_pad"
+    elif args.format == 'ears':
+        noisy_files = sorted(glob.glob(join(noisy_dir, '**', '*.wav')))
+        sr = 48000
+        pad_mode = "reflection"
+    else:
+        raise ValueError('Unknown format')
 
     for noisy_file in tqdm(noisy_files):
         filename = noisy_file.split('/')[-1]
+        filename = noisy_file.replace(noisy_dir, "")[1:] # Remove the first character which is a slash
         
         # Load wav
         y, _ = load(noisy_file) 
@@ -54,7 +64,7 @@ if __name__ == '__main__':
         
         # Prepare DNN input
         Y = torch.unsqueeze(model._forward_transform(model._stft(y.cuda())), 0)
-        Y = pad_spec(Y)
+        Y = pad_spec(Y, mode=pad_mode)
         
         # Reverse sampling
         sampler = model.get_pc_sampler(
@@ -69,4 +79,5 @@ if __name__ == '__main__':
         x_hat = x_hat * norm_factor
 
         # Write enhanced wav file
-        write(join(target_dir, filename), x_hat.cpu().numpy(), 16000)
+        makedirs(dirname(join(target_dir, filename)), exist_ok=True)
+        write(join(target_dir, filename), x_hat.cpu().numpy(), sr)
diff --git a/sgmse/backbones/__init__.py b/sgmse/backbones/__init__.py
index 386d0a022395bd8343e9f7810b40314cd50b4a70..508a62b40b8b623fca3563fefdf98799f71432f1 100644
--- a/sgmse/backbones/__init__.py
+++ b/sgmse/backbones/__init__.py
@@ -1,5 +1,6 @@
 from .shared import BackboneRegistry
 from .ncsnpp import NCSNpp
+from .ncsnpp_48k import NCSNpp_48k
 from .dcunet import DCUNet
 
-__all__ = ['BackboneRegistry', 'NCSNpp', 'DCUNet']
+__all__ = ['BackboneRegistry', 'NCSNpp', 'NCSNpp_48k', 'DCUNet']
diff --git a/sgmse/backbones/ncsnpp_48k.py b/sgmse/backbones/ncsnpp_48k.py
new file mode 100644
index 0000000000000000000000000000000000000000..1737baf0c8d518f59bc354635b415dbe48bffb20
--- /dev/null
+++ b/sgmse/backbones/ncsnpp_48k.py
@@ -0,0 +1,425 @@
+# coding=utf-8
+# Copyright 2020 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: skip-file
+
+from .ncsnpp_utils import layers, layerspp, normalization
+import torch.nn as nn
+import functools
+import torch
+import numpy as np
+
+from .shared import BackboneRegistry
+
+ResnetBlockDDPM = layerspp.ResnetBlockDDPMpp
+ResnetBlockBigGAN = layerspp.ResnetBlockBigGANpp
+Combine = layerspp.Combine
+conv3x3 = layerspp.conv3x3
+conv1x1 = layerspp.conv1x1
+get_act = layers.get_act
+get_normalization = normalization.get_normalization
+default_initializer = layers.default_init
+
+
+@BackboneRegistry.register("ncsnpp_48k")
+class NCSNpp_48k(nn.Module):
+    """NCSN++ model, adapted from https://github.com/yang-song/score_sde repository"""
+
+    @staticmethod
+    def add_argparse_args(parser):
+        parser.add_argument("--ch_mult",type=int, nargs='+', default=[1,1,2,2,2,2,2])
+        parser.add_argument("--num_res_blocks", type=int, default=2)
+        parser.add_argument("--attn_resolutions", type=int, nargs='+', default=[])
+        parser.add_argument("--nf", type=int, default=128, help="Number of channels to use in the model")
+        parser.add_argument("--attn_resolutions", type=int, nargs='+', default=[])
+        parser.add_argument("--no-centered", dest="centered", action="store_false", help="The data is not centered [-1, 1]")
+        parser.add_argument("--centered", dest="centered", action="store_true", help="The data is centered [-1, 1]")
+        parser.add_argument("--progressive", type=str, default='none', help="Progressive downsampling method")
+        parser.add_argument("--progressive_input", type=str, default='none', help="Progressive upsampling method")
+        parser.set_defaults(centered=True)
+        return parser
+
+    def __init__(self,
+        scale_by_sigma = True,
+        nonlinearity = 'swish',
+        nf = 128,
+        ch_mult = (1, 1, 2, 2, 2, 2, 2),
+        num_res_blocks = 2,
+        attn_resolutions = (),
+        resamp_with_conv = True,
+        conditional = True,
+        fir = True,
+        fir_kernel = [1, 3, 3, 1],
+        skip_rescale = True,
+        resblock_type = 'biggan',
+        progressive = 'none',
+        progressive_input = 'none',
+        progressive_combine = 'sum',
+        init_scale = 0.,
+        fourier_scale = 16,
+        image_size = 256,
+        embedding_type = 'fourier',
+        dropout = .0,
+        centered = True,
+        **unused_kwargs
+    ):
+        super().__init__()
+        self.act = act = get_act(nonlinearity)
+
+        self.nf = nf = nf
+        ch_mult = ch_mult
+        self.num_res_blocks = num_res_blocks = num_res_blocks
+        self.attn_resolutions = attn_resolutions = attn_resolutions
+        dropout = dropout
+        resamp_with_conv = resamp_with_conv
+        self.num_resolutions = num_resolutions = len(ch_mult)
+        self.all_resolutions = all_resolutions = [image_size // (2 ** i) for i in range(num_resolutions)]
+
+        self.conditional = conditional = conditional  # noise-conditional
+        self.centered = centered
+        self.scale_by_sigma = scale_by_sigma
+
+        fir = fir
+        fir_kernel = fir_kernel
+        self.skip_rescale = skip_rescale = skip_rescale
+        self.resblock_type = resblock_type = resblock_type.lower()
+        self.progressive = progressive = progressive.lower()
+        self.progressive_input = progressive_input = progressive_input.lower()
+        self.embedding_type = embedding_type = embedding_type.lower()
+        init_scale = init_scale
+        assert progressive in ['none', 'output_skip', 'residual']
+        assert progressive_input in ['none', 'input_skip', 'residual']
+        assert embedding_type in ['fourier', 'positional']
+        combine_method = progressive_combine.lower()
+        combiner = functools.partial(Combine, method=combine_method)
+
+        num_channels = 4  # x.real, x.imag, y.real, y.imag
+        self.output_layer = nn.Conv2d(num_channels, 2, 1)
+
+        modules = []
+        # timestep/noise_level embedding
+        if embedding_type == 'fourier':
+            # Gaussian Fourier features embeddings.
+            modules.append(layerspp.GaussianFourierProjection(
+                embedding_size=nf, scale=fourier_scale
+            ))
+            embed_dim = 2 * nf
+        elif embedding_type == 'positional':
+            embed_dim = nf
+        else:
+            raise ValueError(f'embedding type {embedding_type} unknown.')
+
+        if conditional:
+            modules.append(nn.Linear(embed_dim, nf * 4))
+            modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
+            nn.init.zeros_(modules[-1].bias)
+            modules.append(nn.Linear(nf * 4, nf * 4))
+            modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
+            nn.init.zeros_(modules[-1].bias)
+
+        AttnBlock = functools.partial(layerspp.AttnBlockpp,
+            init_scale=init_scale, skip_rescale=skip_rescale)
+
+        Upsample = functools.partial(layerspp.Upsample,
+            with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
+
+        if progressive == 'output_skip':
+            self.pyramid_upsample = layerspp.Upsample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
+        elif progressive == 'residual':
+            pyramid_upsample = functools.partial(layerspp.Upsample, fir=fir,
+                fir_kernel=fir_kernel, with_conv=True)
+
+        Downsample = functools.partial(layerspp.Downsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
+
+        if progressive_input == 'input_skip':
+            self.pyramid_downsample = layerspp.Downsample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
+        elif progressive_input == 'residual':
+            pyramid_downsample = functools.partial(layerspp.Downsample,
+                fir=fir, fir_kernel=fir_kernel, with_conv=True)
+
+        if resblock_type == 'ddpm':
+            ResnetBlock = functools.partial(ResnetBlockDDPM, act=act,
+                dropout=dropout, init_scale=init_scale,
+                skip_rescale=skip_rescale, temb_dim=nf * 4)
+
+        elif resblock_type == 'biggan':
+            ResnetBlock = functools.partial(ResnetBlockBigGAN, act=act,
+                dropout=dropout, fir=fir, fir_kernel=fir_kernel,
+                init_scale=init_scale, skip_rescale=skip_rescale, temb_dim=nf * 4)
+
+        else:
+            raise ValueError(f'resblock type {resblock_type} unrecognized.')
+
+        # Downsampling block
+
+        channels = num_channels
+        if progressive_input != 'none':
+            input_pyramid_ch = channels
+
+        modules.append(conv3x3(channels, nf))
+        hs_c = [nf]
+
+        in_ch = nf
+        for i_level in range(num_resolutions):
+            # Residual blocks for this resolution
+            for i_block in range(num_res_blocks):
+                out_ch = nf * ch_mult[i_level]
+                modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
+                in_ch = out_ch
+
+                if all_resolutions[i_level] in attn_resolutions:
+                    modules.append(AttnBlock(channels=in_ch))
+                hs_c.append(in_ch)
+
+            if i_level != num_resolutions - 1:
+                if resblock_type == 'ddpm':
+                    modules.append(Downsample(in_ch=in_ch))
+                else:
+                    modules.append(ResnetBlock(down=True, in_ch=in_ch))
+
+                if progressive_input == 'input_skip':
+                    modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
+                    if combine_method == 'cat':
+                        in_ch *= 2
+
+                elif progressive_input == 'residual':
+                    modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch))
+                    input_pyramid_ch = in_ch
+
+                hs_c.append(in_ch)
+
+        in_ch = hs_c[-1]
+        modules.append(ResnetBlock(in_ch=in_ch))
+        modules.append(AttnBlock(channels=in_ch))
+        modules.append(ResnetBlock(in_ch=in_ch))
+
+        pyramid_ch = 0
+        # Upsampling block
+        for i_level in reversed(range(num_resolutions)):
+            for i_block in range(num_res_blocks + 1):  # +1 blocks in upsampling because of skip connection from combiner (after downsampling)
+                out_ch = nf * ch_mult[i_level]
+                modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
+                in_ch = out_ch
+
+            if all_resolutions[i_level] in attn_resolutions:
+                modules.append(AttnBlock(channels=in_ch))
+
+            if progressive != 'none':
+                if i_level == num_resolutions - 1:
+                    if progressive == 'output_skip':
+                        modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
+                            num_channels=in_ch, eps=1e-6))
+                        modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
+                        pyramid_ch = channels
+                    elif progressive == 'residual':
+                        modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
+                        modules.append(conv3x3(in_ch, in_ch, bias=True))
+                        pyramid_ch = in_ch
+                    else:
+                        raise ValueError(f'{progressive} is not a valid name.')
+                else:
+                    if progressive == 'output_skip':
+                        modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
+                            num_channels=in_ch, eps=1e-6))
+                        modules.append(conv3x3(in_ch, channels, bias=True, init_scale=init_scale))
+                        pyramid_ch = channels
+                    elif progressive == 'residual':
+                        modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
+                        pyramid_ch = in_ch
+                    else:
+                        raise ValueError(f'{progressive} is not a valid name')
+
+            if i_level != 0:
+                if resblock_type == 'ddpm':
+                    modules.append(Upsample(in_ch=in_ch))
+                else:
+                    modules.append(ResnetBlock(in_ch=in_ch, up=True))
+
+        assert not hs_c
+
+        if progressive != 'output_skip':
+            modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
+                                                                    num_channels=in_ch, eps=1e-6))
+            modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
+
+        self.all_modules = nn.ModuleList(modules)
+        
+
+    def forward(self, x, time_cond):
+        # timestep/noise_level embedding; only for continuous training
+        modules = self.all_modules
+        m_idx = 0
+
+        # Convert real and imaginary parts of (x,y) into four channel dimensions
+        x = torch.cat((x[:,[0],:,:].real, x[:,[0],:,:].imag,
+                x[:,[1],:,:].real, x[:,[1],:,:].imag), dim=1)
+
+        if self.embedding_type == 'fourier':
+            # Gaussian Fourier features embeddings.
+            used_sigmas = time_cond
+            temb = modules[m_idx](torch.log(used_sigmas))
+            m_idx += 1
+
+        elif self.embedding_type == 'positional':
+            # Sinusoidal positional embeddings.
+            timesteps = time_cond
+            used_sigmas = self.sigmas[time_cond.long()]
+            temb = layers.get_timestep_embedding(timesteps, self.nf)
+
+        else:
+            raise ValueError(f'embedding type {self.embedding_type} unknown.')
+
+        if self.conditional:
+            temb = modules[m_idx](temb)
+            m_idx += 1
+            temb = modules[m_idx](self.act(temb))
+            m_idx += 1
+        else:
+            temb = None
+
+        if not self.centered:
+            # If input data is in [0, 1]
+            x = 2 * x - 1.
+
+        # Downsampling block
+        input_pyramid = None
+        if self.progressive_input != 'none':
+            input_pyramid = x
+
+        # Input layer: Conv2d: 4ch -> 128ch
+        hs = [modules[m_idx](x)]
+        m_idx += 1
+
+        # Down path in U-Net
+        for i_level in range(self.num_resolutions):
+            # Residual blocks for this resolution
+            for i_block in range(self.num_res_blocks):
+                h = modules[m_idx](hs[-1], temb)
+                m_idx += 1
+                # Attention layer (optional)
+                if h.shape[-2] in self.attn_resolutions: # edit: check H dim (-2) not W dim (-1)
+                    h = modules[m_idx](h)
+                    m_idx += 1
+                hs.append(h)
+
+            # Downsampling
+            if i_level != self.num_resolutions - 1:
+                if self.resblock_type == 'ddpm':
+                    h = modules[m_idx](hs[-1])
+                    m_idx += 1
+                else:
+                    h = modules[m_idx](hs[-1], temb)
+                    m_idx += 1
+
+                if self.progressive_input == 'input_skip':   # Combine h with x
+                    input_pyramid = self.pyramid_downsample(input_pyramid)
+                    h = modules[m_idx](input_pyramid, h)
+                    m_idx += 1
+
+                elif self.progressive_input == 'residual':
+                    input_pyramid = modules[m_idx](input_pyramid)
+                    m_idx += 1
+                    if self.skip_rescale:
+                        input_pyramid = (input_pyramid + h) / np.sqrt(2.)
+                    else:
+                        input_pyramid = input_pyramid + h
+                    h = input_pyramid
+                hs.append(h)
+
+        h = hs[-1] # actualy equal to: h = h
+        h = modules[m_idx](h, temb)  # ResNet block
+        m_idx += 1
+        h = modules[m_idx](h)  # Attention block
+        m_idx += 1
+        h = modules[m_idx](h, temb)  # ResNet block
+        m_idx += 1
+
+        pyramid = None
+
+        # Upsampling block
+        for i_level in reversed(range(self.num_resolutions)):
+            for i_block in range(self.num_res_blocks + 1):
+                h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb)
+                m_idx += 1
+
+            # edit: from -1 to -2
+            if h.shape[-2] in self.attn_resolutions:
+                h = modules[m_idx](h)
+                m_idx += 1
+
+            if self.progressive != 'none':
+                if i_level == self.num_resolutions - 1:
+                    if self.progressive == 'output_skip':
+                        pyramid = self.act(modules[m_idx](h))  # GroupNorm
+                        m_idx += 1
+                        pyramid = modules[m_idx](pyramid)  # Conv2D: 256 -> 4
+                        m_idx += 1
+                    elif self.progressive == 'residual':
+                        pyramid = self.act(modules[m_idx](h))
+                        m_idx += 1
+                        pyramid = modules[m_idx](pyramid)
+                        m_idx += 1
+                    else:
+                        raise ValueError(f'{self.progressive} is not a valid name.')
+                else:
+                    if self.progressive == 'output_skip':
+                        pyramid = self.pyramid_upsample(pyramid)  # Upsample
+                        pyramid_h = self.act(modules[m_idx](h))  # GroupNorm
+                        m_idx += 1
+                        pyramid_h = modules[m_idx](pyramid_h)
+                        m_idx += 1
+                        pyramid = pyramid + pyramid_h
+                    elif self.progressive == 'residual':
+                        pyramid = modules[m_idx](pyramid)
+                        m_idx += 1
+                        if self.skip_rescale:
+                            pyramid = (pyramid + h) / np.sqrt(2.)
+                        else:
+                            pyramid = pyramid + h
+                        h = pyramid
+                    else:
+                        raise ValueError(f'{self.progressive} is not a valid name')
+
+            # Upsampling Layer
+            if i_level != 0:
+                if self.resblock_type == 'ddpm':
+                    h = modules[m_idx](h)
+                    m_idx += 1
+                else:
+                    h = modules[m_idx](h, temb)  # Upspampling
+                    m_idx += 1
+
+        assert not hs
+
+        if self.progressive == 'output_skip':
+            h = pyramid
+        else:
+            h = self.act(modules[m_idx](h))
+            m_idx += 1
+            h = modules[m_idx](h)
+            m_idx += 1
+
+        assert m_idx == len(modules), "Implementation error"
+        
+        # Convert back to complex number
+        h = self.output_layer(h)
+
+        if self.scale_by_sigma:
+            used_sigmas = used_sigmas.reshape((x.shape[0], *([1] * len(x.shape[1:]))))
+            h = h / used_sigmas
+
+        h = torch.permute(h, (0, 2, 3, 1)).contiguous()
+        h = torch.view_as_complex(h)[:,None, :, :]
+        return h
diff --git a/sgmse/sampling/__init__.py b/sgmse/sampling/__init__.py
index 0046a1c1948087d6d97871c4d251a6d2c8daf075..44ee62a5e02b3e3ee0900469f5150f4e22be74d0 100644
--- a/sgmse/sampling/__init__.py
+++ b/sgmse/sampling/__init__.py
@@ -56,9 +56,13 @@ def get_pc_sampler(
             timesteps = torch.linspace(sde.T, eps, sde.N, device=y.device)
             for i in range(sde.N):
                 t = timesteps[i]
+                if i != len(timesteps) - 1:
+                    stepsize = t - timesteps[i+1]
+                else:
+                    stepsize = timesteps[-1] # from eps to 0
                 vec_t = torch.ones(y.shape[0], device=y.device) * t
                 xt, xt_mean = corrector.update_fn(xt, vec_t, y)
-                xt, xt_mean = predictor.update_fn(xt, vec_t, y)
+                xt, xt_mean = predictor.update_fn(xt, vec_t, y, stepsize)
             x_result = xt_mean if denoise else xt
             ns = sde.N * (corrector.n_steps + 1)
             return x_result, ns
diff --git a/sgmse/sampling/predictors.py b/sgmse/sampling/predictors.py
index 93437d9fe3398494fbda113428292c63dbd98d02..d5520ddfd7bf54cdedc175f44e821370c190b1e8 100644
--- a/sgmse/sampling/predictors.py
+++ b/sgmse/sampling/predictors.py
@@ -57,8 +57,8 @@ class ReverseDiffusionPredictor(Predictor):
     def __init__(self, sde, score_fn, probability_flow=False):
         super().__init__(sde, score_fn, probability_flow=probability_flow)
 
-    def update_fn(self, x, t, *args):
-        f, g = self.rsde.discretize(x, t, *args)
+    def update_fn(self, x, t, y, stepsize):
+        f, g = self.rsde.discretize(x, t, y, stepsize)
         z = torch.randn_like(x)
         x_mean = x - f
         x = x_mean + g[:, None, None, None] * z
diff --git a/sgmse/sdes.py b/sgmse/sdes.py
index affa02e14e47401c14e3a706480c7bfb2bb8403d..6c6a230945b0e5dbc30e5f0896131cb0d1ba9bba 100644
--- a/sgmse/sdes.py
+++ b/sgmse/sdes.py
@@ -69,7 +69,7 @@ class SDE(abc.ABC):
         """
         pass
 
-    def discretize(self, x, t, *args):
+    def discretize(self, x, t, y, stepsize):
         """Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i.
 
         Useful for reverse diffusion sampling and probabiliy flow sampling.
@@ -82,10 +82,10 @@ class SDE(abc.ABC):
         Returns:
             f, G
         """
-        dt = 1 / self.N
-        drift, diffusion = self.sde(x, t, *args)
+        dt = stepsize
+        drift, diffusion = self.sde(x, t, y)
         f = drift * dt
-        G = diffusion * torch.sqrt(torch.tensor(dt, device=t.device))
+        G = diffusion * torch.sqrt(dt)
         return f, G
 
     def reverse(oself, score_model, probability_flow=False):
@@ -127,10 +127,10 @@ class SDE(abc.ABC):
                     'sde_diffusion': sde_diffusion, 'score_drift': score_drift, 'score': score,
                 }
 
-            def discretize(self, x, t, *args):
+            def discretize(self, x, t, y, stepsize):
                 """Create discretized iteration rules for the reverse diffusion sampler."""
-                f, G = discretize_fn(x, t, *args)
-                rev_f = f - G[:, None, None, None] ** 2 * score_model(x, t, *args) * (0.5 if self.probability_flow else 1.)
+                f, G = discretize_fn(x, t, y, stepsize)
+                rev_f = f - G[:, None, None, None] ** 2 * score_model(x, t, y) * (0.5 if self.probability_flow else 1.)
                 rev_G = torch.zeros_like(G) if self.probability_flow else G
                 return rev_f, rev_G
 
@@ -198,6 +198,9 @@ class OUVESDE(SDE):
         exp_interp = torch.exp(-theta * t)[:, None, None, None]
         return exp_interp * x0 + (1 - exp_interp) * y
 
+    def alpha(self, t):
+        return torch.exp(-self.theta * t)
+
     def _std(self, t):
         # This is a full solution to the ODE for P(t) in our derivations, after choosing g(s) as in self.sde()
         sigma_min, theta, logsig = self.sigma_min, self.theta, self.logsig
diff --git a/sgmse/util/other.py b/sgmse/util/other.py
index 565b1cee67fcdabf667f27784db4f9db14982c4c..fbe5052f348def7eb05ed4682cfed3700f8e2ab0 100644
--- a/sgmse/util/other.py
+++ b/sgmse/util/other.py
@@ -80,16 +80,22 @@ def snr_dB(s,n):
     snr_dB = 10*np.log10(s_power/n_power)
     return snr_dB
 
-def pad_spec(Y):
+def pad_spec(Y, mode="zero_pad"):
     T = Y.size(3)
     if T%64 !=0:
         num_pad = 64-T%64
     else:
         num_pad = 0
-    pad2d = torch.nn.ZeroPad2d((0, num_pad, 0,0))
+    if mode == "zero_pad":
+        pad2d = torch.nn.ZeroPad2d((0, num_pad, 0,0))
+    elif mode == "reflection":
+        pad2d = torch.nn.ReflectionPad2d((0, num_pad, 0,0))
+    elif mode == "replication":
+        pad2d = torch.nn.ReplicationPad2d((0, num_pad, 0,0))
+    else:
+        raise NotImplementedError("This function hasn't been implemented yet.")
     return pad2d(Y)
 
-
 def ensure_dir(file_path):
     directory = file_path
     if not os.path.exists(directory):