diff --git a/sgmse/backbones/ncsnpp_utils/op/__init__.py b/sgmse/backbones/ncsnpp_utils/op/__init__.py index d0918d92285955855be89f00096b888ee5597ce3..857a9fc459361c342ff4f0e53085ab8712b6609e 100755 --- a/sgmse/backbones/ncsnpp_utils/op/__init__.py +++ b/sgmse/backbones/ncsnpp_utils/op/__init__.py @@ -1,2 +1 @@ -from .fused_act import FusedLeakyReLU, fused_leaky_relu -from .upfirdn2d import upfirdn2d +from .upfirdn2d import upfirdn2d \ No newline at end of file diff --git a/sgmse/backbones/ncsnpp_utils/op/upfirdn2d.py b/sgmse/backbones/ncsnpp_utils/op/upfirdn2d.py index a4cf05dbcb2628fff56b0cb917c2941ce4a5f18b..0f4dc563d23a336fa129a925040d8ef7303b878b 100755 --- a/sgmse/backbones/ncsnpp_utils/op/upfirdn2d.py +++ b/sgmse/backbones/ncsnpp_utils/op/upfirdn2d.py @@ -7,14 +7,17 @@ from torch.utils.cpp_extension import load module_path = os.path.dirname(__file__) -upfirdn2d_op = load( - "upfirdn2d", - sources=[ - os.path.join(module_path, "upfirdn2d.cpp"), - os.path.join(module_path, "upfirdn2d_kernel.cu"), - ], -) +if torch.cuda.is_available(): + upfirdn2d_op = load( + "upfirdn2d", + sources=[ + os.path.join(module_path, "upfirdn2d.cpp"), + os.path.join(module_path, "upfirdn2d_kernel.cu"), + ], + ) +else: + upfirdn2d_op = None class UpFirDn2dBackward(Function): @staticmethod