Commit b7fa075e authored by jrichter

Support for torch>=2.0 and pytorch-lightning>=2.0

parent c6e3291e
.gitignore
@@ -332,3 +332,4 @@ $RECYCLE.BIN/
 /wandb
 /lightning_logs
 /logs
+/jobs
\ No newline at end of file
README.md
@@ -20,12 +20,13 @@ Please also check out our follow-up work with code available:
 - Create a new virtual environment with Python 3.8 (we have not tested other Python versions, but they may work).
 - Install the package dependencies via `pip install -r requirements.txt`.
+  - Let pip resolve the dependencies for you. If you encounter any issues, please check `requirements_version.txt` for the exact versions we used.
 - If using W&B logging (default):
   - Set up a [wandb.ai](https://wandb.ai/) account
   - Log in via `wandb login` before running our code.
 - If not using W&B logging:
-  - Pass the option `--no_wandb` to `train.py`.
-  - Your logs will be stored as local TensorBoard logs. Run `tensorboard --logdir logs/` to see them.
+  - Pass the option `--nolog` to `train.py`.
+  - Your logs will be stored as local CSVLogger logs in `lightning_logs/`.

 ## Pretrained checkpoints
@@ -35,7 +36,7 @@ Please also check out our follow-up work with code available:
 - Note that this checkpoint works better with sampler settings `--N 50 --snr 0.33`.

 Usage:
-- For resuming training, you can use the `--resume_from_checkpoint` option of `train.py`.
+- For resuming training, you can use the `--ckpt` option of `train.py`.
 - For evaluating these checkpoints, use the `--ckpt` option of `enhancement.py` (see section **Evaluation** below).
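The two `--ckpt` usages above map onto the standard Lightning checkpoint API. A minimal sketch, assuming the score model is importable as `sgmse.model.ScoreModel` (the import path is not visible in these hunks) and using a placeholder checkpoint path:

```python
# Illustrative only. Assumptions: the import path sgmse.model.ScoreModel and the
# checkpoint filename below are placeholders, not taken from this commit.
from sgmse.model import ScoreModel

ckpt = "logs/epoch=123-last.ckpt"  # placeholder path to a downloaded checkpoint

# Evaluation (roughly what enhancement.py does with --ckpt): restore the module.
model = ScoreModel.load_from_checkpoint(ckpt, map_location="cpu")
model.eval()

# Resuming training (what train.py does with --ckpt under Lightning >= 2.0):
# trainer.fit(model, ckpt_path=ckpt)
```

Resuming goes through `trainer.fit(..., ckpt_path=...)` because Lightning 2.x dropped the Trainer's `resume_from_checkpoint` argument, which is what the `train.py` changes below implement.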
requirements.txt
-h5py==3.6.0
-ipympl==0.8.8
-ipywidgets==7.6.5
-jupyter==1.0.0
-jupyter-client==6.1.12
-jupyter-console==6.4.0
-jupyter-core==4.7.1
-jupyterlab-pygments==0.1.2
-jupyterlab-widgets==1.0.2
-librosa==0.9.1
-Ninja==1.10.2.3
-numpy==1.22.2
-pandas==1.4.0
-pesq==0.0.4
-Pillow==9.0.1
-protobuf==3.19.4
-pyroomacoustics==0.6.0
-pystoi==0.3.3
-pytorch-lightning==1.6.5
-scipy==1.8.0
-sdeint==0.2.4
-setuptools==59.5.0 # fixes https://github.com/pytorch/pytorch/issues/69894
-seaborn==0.11.2
-torch==1.12.0
-torch-ema==0.3
-torchaudio==0.12.0
-torchvision==0.13.0
-torchinfo==1.6.3
-torchsde==0.2.5
-tqdm==4.63.0
-wandb==0.12.11
+h5py
+ipympl
+librosa
+ninja
+numpy
+pandas
+pesq
+pillow
+protobuf
+pyarrow
+pyroomacoustics
+pystoi
+pytorch-lightning
+scipy
+sdeint
+setuptools
+seaborn
+torch
+torch-ema
+torchaudio
+torchvision
+torchinfo
+torchsde
+tqdm
+wandb
\ No newline at end of file
requirements_version.txt (new file)
+h5py==3.10.0
+ipympl==0.9.3
+librosa==0.10.1
+ninja==1.11.1.1
+numpy==1.24.4
+pandas==2.0.3
+pesq==0.0.4
+pillow==10.2.0
+protobuf==4.25.2
+pyarrow==15.0.0
+pyroomacoustics==0.7.3
+pystoi==0.4.1
+pytorch-lightning==2.1.4
+scipy==1.10.1
+sdeint==0.3.0
+setuptools==44.0.0
+seaborn==0.13.2
+torch==2.2.0
+torch-ema==0.3
+torchaudio==2.2.0
+torchvision==0.17.0
+torchinfo==1.8.0
+torchsde==0.2.6
+tqdm==4.66.1
+wandb==0.16.2
\ No newline at end of file
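With `requirements.txt` now unpinned, the exact tested versions live in the pinned file above. A small helper, assuming the plain `name==version` format shown there, can report where a pip-resolved environment differs from those pins (an illustration, not part of the commit):

```python
# Compare installed package versions against the pins in requirements_version.txt.
# Assumes each non-comment line has the plain "name==version" format shown above.
from importlib.metadata import version, PackageNotFoundError

def check_pins(path="requirements_version.txt"):
    for line in open(path):
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        name, _, pinned = line.partition("==")
        try:
            installed = version(name)
        except PackageNotFoundError:
            print(f"{name}: not installed (pinned {pinned})")
            continue
        if installed != pinned:
            print(f"{name}: installed {installed}, pinned {pinned}")

if __name__ == "__main__":
    check_pins()
```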
train.py
@@ -2,8 +2,7 @@ import argparse
 from argparse import ArgumentParser
 import pytorch_lightning as pl
-from pytorch_lightning.plugins import DDPPlugin
-from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger
+from pytorch_lightning.loggers import WandbLogger
 from pytorch_lightning.callbacks import ModelCheckpoint
 import wandb
@@ -29,14 +28,20 @@ if __name__ == '__main__':
     for parser_ in (base_parser, parser):
         parser_.add_argument("--backbone", type=str, choices=BackboneRegistry.get_all_names(), default="ncsnpp")
         parser_.add_argument("--sde", type=str, choices=SDERegistry.get_all_names(), default="ouve")
-        parser_.add_argument("--no_wandb", action='store_true', help="Turn off logging to W&B, using local default logger instead")
+        parser_.add_argument("--nolog", action='store_true', help="Turn off logging.")
+        parser_.add_argument("--wandb_name", type=str, default=None, help="Name for wandb logger. If not set, a random name is generated.")
+        parser_.add_argument("--ckpt", type=str, default=None, help="Resume training from checkpoint.")
     temp_args, _ = base_parser.parse_known_args()
     # Add specific args for ScoreModel, pl.Trainer, the SDE class and backbone DNN class
     backbone_cls = BackboneRegistry.get_by_name(temp_args.backbone)
     sde_class = SDERegistry.get_by_name(temp_args.sde)
-    parser = pl.Trainer.add_argparse_args(parser)
+    trainer_parser = parser.add_argument_group("Trainer", description="Lightning Trainer")
+    trainer_parser.add_argument("--accelerator", type=str, default="gpu", help="Supports passing different accelerator types.")
+    trainer_parser.add_argument("--devices", default="auto", help="How many gpus to use.")
+    trainer_parser.add_argument("--accumulate_grad_batches", type=int, default=1, help="Accumulate gradients.")
     ScoreModel.add_argparse_args(
         parser.add_argument_group("ScoreModel", description=ScoreModel.__name__))
     sde_class.add_argparse_args(
@@ -63,13 +68,14 @@ if __name__ == '__main__':
     )
     # Set up logger configuration
-    if args.no_wandb:
-        logger = TensorBoardLogger(save_dir="logs", name="tensorboard")
+    if args.nolog:
+        logger = None
     else:
-        logger = WandbLogger(project="sgmse", log_model=True, save_dir="logs")
+        logger = WandbLogger(project="sgmse", log_model=True, save_dir="logs", name=args.wandb_name)
         logger.experiment.log_code(".")
     # Set up callbacks for logger
+    if logger != None:
         callbacks = [ModelCheckpoint(dirpath=f"logs/{logger.version}", save_last=True, filename='{epoch}-last')]
         if args.num_eval_files:
             checkpoint_callback_pesq = ModelCheckpoint(dirpath=f"logs/{logger.version}",
@@ -77,14 +83,16 @@ if __name__ == '__main__':
             checkpoint_callback_si_sdr = ModelCheckpoint(dirpath=f"logs/{logger.version}",
                 save_top_k=2, monitor="si_sdr", mode="max", filename='{epoch}-{si_sdr:.2f}')
             callbacks += [checkpoint_callback_pesq, checkpoint_callback_si_sdr]
+    else:
+        callbacks = None
     # Initialize the Trainer and the DataModule
-    trainer = pl.Trainer.from_argparse_args(
-        arg_groups['pl.Trainer'],
-        strategy=DDPPlugin(find_unused_parameters=False), logger=logger,
+    trainer = pl.Trainer(
+        **vars(arg_groups['Trainer']),
+        strategy="ddp", logger=logger,
         log_every_n_steps=10, num_sanity_val_steps=0,
         callbacks=callbacks
     )
     # Train model
-    trainer.fit(model)
+    trainer.fit(model, ckpt_path=args.ckpt)
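The new Trainer call unpacks `**vars(arg_groups['Trainer'])`, but the code that builds `arg_groups` sits outside the hunks shown here. The sketch below is the usual argparse pattern for splitting parsed arguments by group; treat the exact helper as an assumption rather than the commit's own code:

```python
# Sketch: collect parsed arguments per argparse group so that
# pl.Trainer(**vars(arg_groups['Trainer']), ...) receives only the Trainer flags.
# The group/flag names mirror the hunk above; the bookkeeping itself is assumed.
from argparse import ArgumentParser, Namespace

parser = ArgumentParser()
trainer_parser = parser.add_argument_group("Trainer", description="Lightning Trainer")
trainer_parser.add_argument("--accelerator", type=str, default="gpu")
trainer_parser.add_argument("--devices", default="auto")
trainer_parser.add_argument("--accumulate_grad_batches", type=int, default=1)

args = parser.parse_args([])  # empty argv: defaults only, for illustration

# One Namespace per argument group, keyed by the group's title.
arg_groups = {}
for group in parser._action_groups:
    group_args = {a.dest: getattr(args, a.dest, None) for a in group._group_actions}
    arg_groups[group.title] = Namespace(**group_args)

print(vars(arg_groups["Trainer"]))  # {'accelerator': 'gpu', 'devices': 'auto', 'accumulate_grad_batches': 1}
```

This replaces the removed `pl.Trainer.add_argparse_args` / `from_argparse_args` helpers: only `--accelerator`, `--devices`, and `--accumulate_grad_batches` are forwarded to the Trainer, while the remaining flags stay with the ScoreModel and SDE argument groups.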