Commit a18cdb7b authored by jrichter

set_torch_cuda_arch_list, add --log_dir argument

parent e9d2773b
-import wandb
 import argparse
+from argparse import ArgumentParser
 import pytorch_lightning as pl
-from argparse import ArgumentParser
 from pytorch_lightning.loggers import WandbLogger
 from pytorch_lightning.callbacks import ModelCheckpoint
+from os.path import join
+import wandb
+# Set CUDA architecture list
+from sgmse.util.other import set_torch_cuda_arch_list
+set_torch_cuda_arch_list()
 from sgmse.backbones.shared import BackboneRegistry
 from sgmse.data_module import SpecsDataModule
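The helper is imported and called before the backbone registry so that any CUDA extensions compiled on import are built for the GPUs that are actually present. Its real implementation lives in sgmse/util/other.py; the following is only a minimal sketch of what such a helper might do, assuming it queries the visible devices via torch.cuda and exports TORCH_CUDA_ARCH_LIST:

import os
import torch

def set_torch_cuda_arch_list():
    # Hypothetical sketch only -- the actual helper is defined in sgmse/util/other.py.
    # Collect the compute capability of every visible GPU and export it so that
    # JIT-compiled CUDA extensions target exactly the architectures present.
    if not torch.cuda.is_available():
        return
    caps = {f"{major}.{minor}" for major, minor in
            (torch.cuda.get_device_capability(i) for i in range(torch.cuda.device_count()))}
    os.environ["TORCH_CUDA_ARCH_LIST"] = ";".join(sorted(caps))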
@@ -31,6 +35,7 @@ if __name__ == '__main__':
parser_.add_argument("--nolog", action='store_true', help="Turn off logging.")
parser_.add_argument("--wandb_name", type=str, default=None, help="Name for wandb logger. If not set, a random name is generated.")
parser_.add_argument("--ckpt", type=str, default=None, help="Resume training from checkpoint.")
parser_.add_argument("--log_dir", type=str, default="logs", help="Directory to save logs.")
temp_args, _ = base_parser.parse_known_args()
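The default of "logs" keeps the previous behaviour when the flag is omitted. A small, self-contained illustration (the override value here is made up):

from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--log_dir", type=str, default="logs", help="Directory to save logs.")

print(parser.parse_args([]).log_dir)                               # "logs" -- same as before
print(parser.parse_args(["--log_dir", "/scratch/sgmse"]).log_dir)  # "/scratch/sgmse"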
@@ -76,11 +81,11 @@ if __name__ == '__main__':
     # Set up callbacks for logger
     if logger != None:
-        callbacks = [ModelCheckpoint(dirpath=f"logs/{logger.version}", save_last=True, filename='{epoch}-last')]
+        callbacks = [ModelCheckpoint(dirpath=join(args.log_dir, str(logger.version)), save_last=True, filename='{epoch}-last')]
         if args.num_eval_files:
-            checkpoint_callback_pesq = ModelCheckpoint(dirpath=f"logs/{logger.version}",
+            checkpoint_callback_pesq = ModelCheckpoint(dirpath=join(args.log_dir, str(logger.version)),
                 save_top_k=2, monitor="pesq", mode="max", filename='{epoch}-{pesq:.2f}')
-            checkpoint_callback_si_sdr = ModelCheckpoint(dirpath=f"logs/{logger.version}",
+            checkpoint_callback_si_sdr = ModelCheckpoint(dirpath=join(args.log_dir, str(logger.version)),
                 save_top_k=2, monitor="si_sdr", mode="max", filename='{epoch}-{si_sdr:.2f}')
             callbacks += [checkpoint_callback_pesq, checkpoint_callback_si_sdr]
     else:
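All three callbacks now build their checkpoint directory with join(args.log_dir, str(logger.version)) instead of the hard-coded f"logs/{logger.version}". A hypothetical illustration of the resulting layout (the version string is assumed; with a WandbLogger it is typically the run id):

from os.path import join

log_dir = "logs"        # --log_dir default
version = "3du2xwgb"    # hypothetical logger.version (e.g. a wandb run id)
print(join(log_dir, str(version)))  # logs/3du2xwgb -- same layout as the old f-string

As before, the two monitored callbacks keep the two best checkpoints by PESQ and by SI-SDR respectively, while the first callback always saves the latest epoch.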
@@ -95,4 +100,4 @@ if __name__ == '__main__':
     )
     # Train model
-    trainer.fit(model, ckpt_path=args.ckpt)
+    trainer.fit(model, ckpt_path=args.ckpt)
\ No newline at end of file
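The unchanged trainer.fit(model, ckpt_path=args.ckpt) call is what --ckpt feeds into: with the default of None, Lightning starts a fresh run; with a path, it restores the model, optimizer, and epoch state from that file before continuing. A hedged sketch of a resume path, assuming a checkpoint name produced by the save_last callback above (version and epoch are made up):

from os.path import join

log_dir = "logs"
version = "3du2xwgb"                                  # hypothetical logger.version
ckpt = join(log_dir, version, "epoch=57-last.ckpt")   # hypothetical file from the '{epoch}-last' callback
# launching with --ckpt logs/3du2xwgb/epoch=57-last.ckpt resumes training from that state
print(ckpt)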