diff --git a/.gitignore b/.gitignore
index cb13be1f59627e9ce555651c24906c4914e868c3..dff23bbbc7e2ca0bcf542dd35c68c829c34eef50 100644
--- a/.gitignore
+++ b/.gitignore
@@ -331,4 +331,5 @@ $RECYCLE.BIN/
 ## Logs from W&B and PyTorch Lightning
 /wandb
 /lightning_logs
-/logs
\ No newline at end of file
+/logs
+/jobs
\ No newline at end of file
diff --git a/README.md b/README.md
index b089808f8dcb3ce3dfe0294c24605fd010858d9f..7f32de7869393e00fdaaa0df774996ae63efcf7d 100644
--- a/README.md
+++ b/README.md
@@ -20,12 +20,13 @@ Please also check out our follow-up work with code available:
 
 - Create a new virtual environment with Python 3.8 (we have not tested other Python versions, but they may work).
 - Install the package dependencies via `pip install -r requirements.txt`.
+  - Let pip resolve the dependencies for you. If you encounter any issues, please check `requirements_version.txt` for the exact versions we used.
 - If using W&B logging (default):
   - Set up a [wandb.ai](https://wandb.ai/) account
   - Log in via `wandb login` before running our code.
 - If not using W&B logging:
-  - Pass the option `--no_wandb` to `train.py`.
-  - Your logs will be stored as local TensorBoard logs. Run `tensorboard --logdir logs/` to see them.
+  - Pass the option `--nolog` to `train.py`.
+  - Your logs will be stored as local CSVLogger logs in `lightning_logs/`.
 
 ## Pretrained checkpoints
 
@@ -35,7 +36,7 @@ Please also check out our follow-up work with code available:
 
 - Note that this checkpoint works better with sampler settings `--N 50 --snr 0.33`.
 
 Usage:
-- For resuming training, you can use the `--resume_from_checkpoint` option of `train.py`.
+- For resuming training, you can use the `--ckpt` option of `train.py`.
 - For evaluating these checkpoints, use the `--ckpt` option of `enhancement.py` (see section **Evaluation** below).
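The README hunks above, together with the `train.py` changes further down, introduce the `--nolog`, `--wandb_name`, and `--ckpt` options. As a quick reference, here is a hedged usage sketch; the run name and checkpoint path are placeholders, not files or names from the repository.

```bash
# Hedged usage sketch for the options introduced by this patch.
# "my_run" and the checkpoint path are placeholders, not repository files.

# Train without W&B logging (no logger is attached; see the train.py diff below):
python train.py --nolog

# Give the W&B run an explicit name instead of a randomly generated one:
python train.py --wandb_name my_run

# Resume training from a previously saved checkpoint:
python train.py --ckpt logs/my_run/epoch=100-last.ckpt
```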
diff --git a/requirements.txt b/requirements.txt
index 2f97aa701c6b3dc26a88c92ed6054ea060a581e6..72562b0cbde42251b113ff95137bf54e120578ed 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,31 +1,25 @@
-h5py==3.6.0
-ipympl==0.8.8
-ipywidgets==7.6.5
-jupyter==1.0.0
-jupyter-client==6.1.12
-jupyter-console==6.4.0
-jupyter-core==4.7.1
-jupyterlab-pygments==0.1.2
-jupyterlab-widgets==1.0.2
-librosa==0.9.1
-Ninja==1.10.2.3
-numpy==1.22.2
-pandas==1.4.0
-pesq==0.0.4
-Pillow==9.0.1
-protobuf==3.19.4
-pyroomacoustics==0.6.0
-pystoi==0.3.3
-pytorch-lightning==1.6.5
-scipy==1.8.0
-sdeint==0.2.4
-setuptools==59.5.0 # fixes https://github.com/pytorch/pytorch/issues/69894
-seaborn==0.11.2
-torch==1.12.0
-torch-ema==0.3
-torchaudio==0.12.0
-torchvision==0.13.0
-torchinfo==1.6.3
-torchsde==0.2.5
-tqdm==4.63.0
-wandb==0.12.11
+h5py
+ipympl
+librosa
+ninja
+numpy
+pandas
+pesq
+pillow
+protobuf
+pyarrow
+pyroomacoustics
+pystoi
+pytorch-lightning
+scipy
+sdeint
+setuptools
+seaborn
+torch
+torch-ema
+torchaudio
+torchvision
+torchinfo
+torchsde
+tqdm
+wandb
\ No newline at end of file
diff --git a/requirements_version.txt b/requirements_version.txt
new file mode 100644
index 0000000000000000000000000000000000000000..90cfc2f8ffcaed3d5a801972fe86890519e4536d
--- /dev/null
+++ b/requirements_version.txt
@@ -0,0 +1,25 @@
+h5py==3.10.0
+ipympl==0.9.3
+librosa==0.10.1
+ninja==1.11.1.1
+numpy==1.24.4
+pandas==2.0.3
+pesq==0.0.4
+pillow==10.2.0
+protobuf==4.25.2
+pyarrow==15.0.0
+pyroomacoustics==0.7.3
+pystoi==0.4.1
+pytorch-lightning==2.1.4
+scipy==1.10.1
+sdeint==0.3.0
+setuptools==44.0.0
+seaborn==0.13.2
+torch==2.2.0
+torch-ema==0.3
+torchaudio==2.2.0
+torchvision==0.17.0
+torchinfo==1.8.0
+torchsde==0.2.6
+tqdm==4.66.1
+wandb==0.16.2
\ No newline at end of file
diff --git a/train.py b/train.py
index a82d9c850c8f5098745d8f5542838817318c9257..56266debe9a3362a7d50e0d3e362e492a6acbb86 100644
--- a/train.py
+++ b/train.py
@@ -2,8 +2,7 @@ import argparse
 from argparse import ArgumentParser
 
 import pytorch_lightning as pl
-from pytorch_lightning.plugins import DDPPlugin
-from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger
+from pytorch_lightning.loggers import WandbLogger
 from pytorch_lightning.callbacks import ModelCheckpoint
 
 import wandb
@@ -29,14 +28,20 @@ if __name__ == '__main__':
     for parser_ in (base_parser, parser):
         parser_.add_argument("--backbone", type=str, choices=BackboneRegistry.get_all_names(), default="ncsnpp")
         parser_.add_argument("--sde", type=str, choices=SDERegistry.get_all_names(), default="ouve")
-        parser_.add_argument("--no_wandb", action='store_true', help="Turn off logging to W&B, using local default logger instead")
+        parser_.add_argument("--nolog", action='store_true', help="Turn off logging.")
+        parser_.add_argument("--wandb_name", type=str, default=None, help="Name for wandb logger. If not set, a random name is generated.")
+        parser_.add_argument("--ckpt", type=str, default=None, help="Resume training from checkpoint.")
     temp_args, _ = base_parser.parse_known_args()
 
     # Add specific args for ScoreModel, pl.Trainer, the SDE class and backbone DNN class
     backbone_cls = BackboneRegistry.get_by_name(temp_args.backbone)
     sde_class = SDERegistry.get_by_name(temp_args.sde)
-    parser = pl.Trainer.add_argparse_args(parser)
+    trainer_parser = parser.add_argument_group("Trainer", description="Lightning Trainer")
+    trainer_parser.add_argument("--accelerator", type=str, default="gpu", help="Supports passing different accelerator types.")
+    trainer_parser.add_argument("--devices", default="auto", help="How many gpus to use.")
+    trainer_parser.add_argument("--accumulate_grad_batches", type=int, default=1, help="Accumulate gradients.")
+
     ScoreModel.add_argparse_args(
         parser.add_argument_group("ScoreModel", description=ScoreModel.__name__))
     sde_class.add_argparse_args(
@@ -63,28 +68,31 @@ if __name__ == '__main__':
     )
 
     # Set up logger configuration
-    if args.no_wandb:
-        logger = TensorBoardLogger(save_dir="logs", name="tensorboard")
+    if args.nolog:
+        logger = None
     else:
-        logger = WandbLogger(project="sgmse", log_model=True, save_dir="logs")
+        logger = WandbLogger(project="sgmse", log_model=True, save_dir="logs", name=args.wandb_name)
         logger.experiment.log_code(".")
 
     # Set up callbacks for logger
-    callbacks = [ModelCheckpoint(dirpath=f"logs/{logger.version}", save_last=True, filename='{epoch}-last')]
-    if args.num_eval_files:
-        checkpoint_callback_pesq = ModelCheckpoint(dirpath=f"logs/{logger.version}",
-            save_top_k=2, monitor="pesq", mode="max", filename='{epoch}-{pesq:.2f}')
-        checkpoint_callback_si_sdr = ModelCheckpoint(dirpath=f"logs/{logger.version}",
-            save_top_k=2, monitor="si_sdr", mode="max", filename='{epoch}-{si_sdr:.2f}')
-        callbacks += [checkpoint_callback_pesq, checkpoint_callback_si_sdr]
+    if logger != None:
+        callbacks = [ModelCheckpoint(dirpath=f"logs/{logger.version}", save_last=True, filename='{epoch}-last')]
+        if args.num_eval_files:
+            checkpoint_callback_pesq = ModelCheckpoint(dirpath=f"logs/{logger.version}",
+                save_top_k=2, monitor="pesq", mode="max", filename='{epoch}-{pesq:.2f}')
+            checkpoint_callback_si_sdr = ModelCheckpoint(dirpath=f"logs/{logger.version}",
+                save_top_k=2, monitor="si_sdr", mode="max", filename='{epoch}-{si_sdr:.2f}')
+            callbacks += [checkpoint_callback_pesq, checkpoint_callback_si_sdr]
+    else:
+        callbacks = None
 
     # Initialize the Trainer and the DataModule
-    trainer = pl.Trainer.from_argparse_args(
-        arg_groups['pl.Trainer'],
-        strategy=DDPPlugin(find_unused_parameters=False), logger=logger,
+    trainer = pl.Trainer(
+        **vars(arg_groups['Trainer']),
+        strategy="ddp", logger=logger,
         log_every_n_steps=10, num_sanity_val_steps=0,
         callbacks=callbacks
     )
 
     # Train model
-    trainer.fit(model)
+    trainer.fit(model, ckpt_path=args.ckpt)
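In the `train.py` hunks above, `pl.Trainer.from_argparse_args` (removed in Lightning 2.x) is replaced by constructing the `Trainer` directly from an argparse argument group via `**vars(arg_groups['Trainer'])`. The diff does not show how `arg_groups` is built, so the following is only a sketch of one common way to split parsed arguments by group; the helper name `get_argparse_groups` and the standalone parser are illustrative assumptions, not code from the repository.

```python
# Sketch only: one way `arg_groups` could be assembled from argparse groups so that
# `pl.Trainer(**vars(arg_groups["Trainer"]), ...)` receives just the Trainer options.
# The helper name and this standalone parser are assumptions, not repository code.
import argparse

def get_argparse_groups(parser, args):
    """Split a parsed Namespace into one Namespace per argument group."""
    groups = {}
    for group in parser._action_groups:  # includes argparse's default groups as well
        group_args = {a.dest: getattr(args, a.dest, None) for a in group._group_actions}
        groups[group.title] = argparse.Namespace(**group_args)
    return groups

parser = argparse.ArgumentParser()
trainer_parser = parser.add_argument_group("Trainer", description="Lightning Trainer")
trainer_parser.add_argument("--accelerator", type=str, default="gpu")
trainer_parser.add_argument("--devices", default="auto")
trainer_parser.add_argument("--accumulate_grad_batches", type=int, default=1)

args = parser.parse_args(["--devices", "2"])
arg_groups = get_argparse_groups(parser, args)
print(vars(arg_groups["Trainer"]))
# -> {'accelerator': 'gpu', 'devices': '2', 'accumulate_grad_batches': 1}
# train.py then forwards only this group to the Trainer:
#   trainer = pl.Trainer(**vars(arg_groups["Trainer"]), strategy="ddp", logger=logger, ...)
```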