From b7fa075eda56525c08280e449491e14a38ae22d9 Mon Sep 17 00:00:00 2001
From: jrichter <jrichter@exchange.informatik.uni-hamburg.de>
Date: Sun, 4 Feb 2024 14:10:48 +0000
Subject: [PATCH] Support for torch>=2.0 and pytorch-lightning>=2.0

---
 .gitignore               |  3 ++-
 README.md                |  7 ++---
 requirements.txt         | 56 ++++++++++++++++++----------------------
 requirements_version.txt | 25 ++++++++++++++++++
 train.py                 | 44 ++++++++++++++++++-------------
 5 files changed, 82 insertions(+), 53 deletions(-)
 create mode 100644 requirements_version.txt

diff --git a/.gitignore b/.gitignore
index cb13be1..dff23bb 100644
--- a/.gitignore
+++ b/.gitignore
@@ -331,4 +331,5 @@ $RECYCLE.BIN/
 ## Logs from W&B and PyTorch Lightning
 /wandb
 /lightning_logs
-/logs
\ No newline at end of file
+/logs
+/jobs
\ No newline at end of file
diff --git a/README.md b/README.md
index b089808..7f32de7 100644
--- a/README.md
+++ b/README.md
@@ -20,12 +20,13 @@ Please also check out our follow-up work with code available:
 
 - Create a new virtual environment with Python 3.8 (we have not tested other Python versions, but they may work).
 - Install the package dependencies via `pip install -r requirements.txt`.
+  - Let pip resolve the dependencies for you. If you encounter any issues, please check `requirements_version.txt` for the exact versions we used.
 - If using W&B logging (default):
     - Set up a [wandb.ai](https://wandb.ai/) account
     - Log in via `wandb login` before running our code.
 - If not using W&B logging:
-    - Pass the option `--no_wandb` to `train.py`.
-    - Your logs will be stored as local TensorBoard logs. Run `tensorboard --logdir logs/` to see them.
+    - Pass the option `--nolog` to `train.py`.
+    - Your logs will be stored as local CSVLogger logs in `lightning_logs/`.
 
 
 ## Pretrained checkpoints
@@ -35,7 +36,7 @@ Please also check out our follow-up work with code available:
     - Note that this checkpoint works better with sampler settings `--N 50 --snr 0.33`.
 
 Usage:
-- For resuming training, you can use the `--resume_from_checkpoint` option of `train.py`.
+- For resuming training, you can use the `--ckpt` option of `train.py`.
 - For evaluating these checkpoints, use the `--ckpt` option of `enhancement.py` (see section **Evaluation** below).
 
 
diff --git a/requirements.txt b/requirements.txt
index 2f97aa7..72562b0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,31 +1,25 @@
-h5py==3.6.0
-ipympl==0.8.8
-ipywidgets==7.6.5
-jupyter==1.0.0
-jupyter-client==6.1.12
-jupyter-console==6.4.0
-jupyter-core==4.7.1
-jupyterlab-pygments==0.1.2
-jupyterlab-widgets==1.0.2
-librosa==0.9.1
-Ninja==1.10.2.3
-numpy==1.22.2
-pandas==1.4.0
-pesq==0.0.4
-Pillow==9.0.1
-protobuf==3.19.4
-pyroomacoustics==0.6.0
-pystoi==0.3.3
-pytorch-lightning==1.6.5
-scipy==1.8.0
-sdeint==0.2.4
-setuptools==59.5.0  # fixes https://github.com/pytorch/pytorch/issues/69894
-seaborn==0.11.2
-torch==1.12.0
-torch-ema==0.3
-torchaudio==0.12.0
-torchvision==0.13.0
-torchinfo==1.6.3
-torchsde==0.2.5
-tqdm==4.63.0
-wandb==0.12.11
+h5py
+ipympl
+librosa
+ninja
+numpy
+pandas
+pesq
+pillow
+protobuf
+pyarrow
+pyroomacoustics
+pystoi
+pytorch-lightning
+scipy
+sdeint
+setuptools
+seaborn
+torch
+torch-ema
+torchaudio
+torchvision
+torchinfo
+torchsde
+tqdm
+wandb
\ No newline at end of file
diff --git a/requirements_version.txt b/requirements_version.txt
new file mode 100644
index 0000000..90cfc2f
--- /dev/null
+++ b/requirements_version.txt
@@ -0,0 +1,25 @@
+h5py==3.10.0
+ipympl==0.9.3
+librosa==0.10.1
+ninja==1.11.1.1
+numpy==1.24.4
+pandas==2.0.3
+pesq==0.0.4
+pillow==10.2.0
+protobuf==4.25.2
+pyarrow==15.0.0
+pyroomacoustics==0.7.3
+pystoi==0.4.1
+pytorch-lightning==2.1.4
+scipy==1.10.1
+sdeint==0.3.0
+setuptools==44.0.0
+seaborn==0.13.2
+torch==2.2.0
+torch-ema==0.3
+torchaudio==2.2.0
+torchvision==0.17.0
+torchinfo==1.8.0
+torchsde==0.2.6
+tqdm==4.66.1
+wandb==0.16.2
\ No newline at end of file
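
The pinned list above is only a fallback for when pip's resolver picks incompatible versions. As a quick sanity check that an existing environment already meets the torch>=2.0 / pytorch-lightning>=2.0 floor this patch targets, a sketch like the following works (it assumes the `packaging` helper is importable, which is not listed in either requirements file; it is not part of the patch):

# Sanity-check sketch (not part of the patch): verify that the installed
# torch and pytorch-lightning meet the >= 2.0 floor this patch targets.
# Assumes the `packaging` package is available in the environment.
from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version

for pkg in ("torch", "pytorch-lightning"):
    try:
        installed = Version(version(pkg))
    except PackageNotFoundError:
        print(f"{pkg}: not installed")
        continue
    status = "ok" if installed >= Version("2.0") else "too old for this patch"
    print(f"{pkg} {installed}: {status}")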
diff --git a/train.py b/train.py
index a82d9c8..56266de 100644
--- a/train.py
+++ b/train.py
@@ -2,8 +2,7 @@ import argparse
 from argparse import ArgumentParser
 
 import pytorch_lightning as pl
-from pytorch_lightning.plugins import DDPPlugin
-from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger
+from pytorch_lightning.loggers import WandbLogger
 from pytorch_lightning.callbacks import ModelCheckpoint
 
 import wandb
@@ -29,14 +28,20 @@ if __name__ == '__main__':
      for parser_ in (base_parser, parser):
           parser_.add_argument("--backbone", type=str, choices=BackboneRegistry.get_all_names(), default="ncsnpp")
           parser_.add_argument("--sde", type=str, choices=SDERegistry.get_all_names(), default="ouve")
-          parser_.add_argument("--no_wandb", action='store_true', help="Turn off logging to W&B, using local default logger instead")
+          parser_.add_argument("--nolog", action='store_true', help="Turn off logging.")
+          parser_.add_argument("--wandb_name", type=str, default=None, help="Name for wandb logger. If not set, a random name is generated.")
+          parser_.add_argument("--ckpt", type=str, default=None, help="Resume training from checkpoint.")
           
      temp_args, _ = base_parser.parse_known_args()
 
      # Add specific args for ScoreModel, pl.Trainer, the SDE class and backbone DNN class
      backbone_cls = BackboneRegistry.get_by_name(temp_args.backbone)
      sde_class = SDERegistry.get_by_name(temp_args.sde)
-     parser = pl.Trainer.add_argparse_args(parser)
+     trainer_parser = parser.add_argument_group("Trainer", description="Lightning Trainer")
+     trainer_parser.add_argument("--accelerator", type=str, default="gpu", help="Supports passing different accelerator types.")
+     trainer_parser.add_argument("--devices", default="auto", help="How many gpus to use.")
+     trainer_parser.add_argument("--accumulate_grad_batches", type=int, default=1, help="Accumulate gradients.")
+     
      ScoreModel.add_argparse_args(
           parser.add_argument_group("ScoreModel", description=ScoreModel.__name__))
      sde_class.add_argparse_args(
@@ -63,28 +68,31 @@ if __name__ == '__main__':
      )
 
      # Set up logger configuration
-     if args.no_wandb:
-          logger = TensorBoardLogger(save_dir="logs", name="tensorboard")
+     if args.nolog:
+          logger = None
      else:
-          logger = WandbLogger(project="sgmse", log_model=True, save_dir="logs")
+          logger = WandbLogger(project="sgmse", log_model=True, save_dir="logs", name=args.wandb_name)
           logger.experiment.log_code(".")
 
      # Set up callbacks for logger
-     callbacks = [ModelCheckpoint(dirpath=f"logs/{logger.version}", save_last=True, filename='{epoch}-last')]
-     if args.num_eval_files:
-          checkpoint_callback_pesq = ModelCheckpoint(dirpath=f"logs/{logger.version}", 
-               save_top_k=2, monitor="pesq", mode="max", filename='{epoch}-{pesq:.2f}')
-          checkpoint_callback_si_sdr = ModelCheckpoint(dirpath=f"logs/{logger.version}", 
-               save_top_k=2, monitor="si_sdr", mode="max", filename='{epoch}-{si_sdr:.2f}')
-          callbacks += [checkpoint_callback_pesq, checkpoint_callback_si_sdr]
+     if logger is not None:
+          callbacks = [ModelCheckpoint(dirpath=f"logs/{logger.version}", save_last=True, filename='{epoch}-last')]
+          if args.num_eval_files:
+               checkpoint_callback_pesq = ModelCheckpoint(dirpath=f"logs/{logger.version}", 
+                    save_top_k=2, monitor="pesq", mode="max", filename='{epoch}-{pesq:.2f}')
+               checkpoint_callback_si_sdr = ModelCheckpoint(dirpath=f"logs/{logger.version}", 
+                    save_top_k=2, monitor="si_sdr", mode="max", filename='{epoch}-{si_sdr:.2f}')
+               callbacks += [checkpoint_callback_pesq, checkpoint_callback_si_sdr]
+     else:
+          callbacks = None
 
      # Initialize the Trainer and the DataModule
-     trainer = pl.Trainer.from_argparse_args(
-          arg_groups['pl.Trainer'],
-          strategy=DDPPlugin(find_unused_parameters=False), logger=logger,
+     trainer = pl.Trainer(
+          **vars(arg_groups['Trainer']),
+          strategy="ddp", logger=logger,
           log_every_n_steps=10, num_sanity_val_steps=0,
           callbacks=callbacks
      )
 
      # Train model
-     trainer.fit(model)
+     trainer.fit(model, ckpt_path=args.ckpt)
-- 
GitLab
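
For reference, the train.py changes above follow the standard pytorch-lightning >= 2.0 migration: `pl.Trainer.add_argparse_args` / `from_argparse_args` and the `DDPPlugin` import no longer exist, Trainer options are declared on a plain argparse group, the strategy is passed as the string "ddp", and resuming moved from a Trainer argument to the `ckpt_path` argument of `fit()`. A minimal self-contained sketch of that pattern is below; the toy model and random data are placeholders and not part of sgmse, which trains ScoreModel instead.

# Minimal sketch of the pytorch-lightning >= 2.0 pattern adopted in this patch.
# ToyModel and the random data are placeholders; the patch itself trains
# ScoreModel from the sgmse package.
from argparse import ArgumentParser

import torch
import pytorch_lightning as pl


class ToyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())


if __name__ == "__main__":
    parser = ArgumentParser()
    # Replaces the removed pl.Trainer.add_argparse_args(parser): only the
    # Trainer options that are actually used are declared by hand.
    trainer_parser = parser.add_argument_group("Trainer")
    trainer_parser.add_argument("--accelerator", type=str, default="cpu")
    trainer_parser.add_argument("--devices", default="auto")
    trainer_parser.add_argument("--accumulate_grad_batches", type=int, default=1)
    parser.add_argument("--ckpt", type=str, default=None)
    args = parser.parse_args()

    # Replaces the removed pl.Trainer.from_argparse_args(...): the collected
    # Trainer options are passed directly to the constructor.
    trainer = pl.Trainer(
        accelerator=args.accelerator,
        devices=args.devices,
        accumulate_grad_batches=args.accumulate_grad_batches,
        max_epochs=1,
    )

    data = torch.utils.data.TensorDataset(torch.randn(32, 4), torch.randn(32, 1))
    loader = torch.utils.data.DataLoader(data, batch_size=8)

    # resume_from_checkpoint is gone; resuming is now ckpt_path on fit().
    trainer.fit(ToyModel(), loader, ckpt_path=args.ckpt)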