diff --git a/train.py b/train.py
index 56266debe9a3362a7d50e0d3e362e492a6acbb86..0597208b99bd6e81ab328d9d2a33997ca3aa9f77 100644
--- a/train.py
+++ b/train.py
@@ -1,11 +1,15 @@
+import wandb
 import argparse
-from argparse import ArgumentParser
-
 import pytorch_lightning as pl
+
+from argparse import ArgumentParser
 from pytorch_lightning.loggers import WandbLogger
 from pytorch_lightning.callbacks import ModelCheckpoint
+from os.path import join
 
-import wandb
+# Set CUDA architecture list
+from sgmse.util.other import set_torch_cuda_arch_list
+set_torch_cuda_arch_list()
 
 from sgmse.backbones.shared import BackboneRegistry
 from sgmse.data_module import SpecsDataModule
@@ -31,6 +35,7 @@ if __name__ == '__main__':
     parser_.add_argument("--nolog", action='store_true', help="Turn off logging.")
     parser_.add_argument("--wandb_name", type=str, default=None, help="Name for wandb logger. If not set, a random name is generated.")
     parser_.add_argument("--ckpt", type=str, default=None, help="Resume training from checkpoint.")
+    parser_.add_argument("--log_dir", type=str, default="logs", help="Directory to save logs.")
 
     temp_args, _ = base_parser.parse_known_args()
 
@@ -76,11 +81,11 @@ if __name__ == '__main__':
     # Set up callbacks for logger
     if logger != None:
-        callbacks = [ModelCheckpoint(dirpath=f"logs/{logger.version}", save_last=True, filename='{epoch}-last')]
+        callbacks = [ModelCheckpoint(dirpath=join(args.log_dir, str(logger.version)), save_last=True, filename='{epoch}-last')]
         if args.num_eval_files:
-            checkpoint_callback_pesq = ModelCheckpoint(dirpath=f"logs/{logger.version}",
+            checkpoint_callback_pesq = ModelCheckpoint(dirpath=join(args.log_dir, str(logger.version)),
                 save_top_k=2, monitor="pesq", mode="max", filename='{epoch}-{pesq:.2f}')
-            checkpoint_callback_si_sdr = ModelCheckpoint(dirpath=f"logs/{logger.version}",
+            checkpoint_callback_si_sdr = ModelCheckpoint(dirpath=join(args.log_dir, str(logger.version)),
                 save_top_k=2, monitor="si_sdr", mode="max",
                 filename='{epoch}-{si_sdr:.2f}')
             callbacks += [checkpoint_callback_pesq, checkpoint_callback_si_sdr]
         else:
@@ -95,4 +100,4 @@ if __name__ == '__main__':
     )
 
     # Train model
-    trainer.fit(model, ckpt_path=args.ckpt)
+    trainer.fit(model, ckpt_path=args.ckpt)