From a18cdb7b5e393c68e396f3f5919ec757bdf45ed0 Mon Sep 17 00:00:00 2001
From: jrichter <jrichter@exchange.informatik.uni-hamburg.de>
Date: Tue, 18 Jun 2024 08:24:26 +0200
Subject: [PATCH] set_torch_cuda_arch_list, add --log_dir argument

---
 train.py | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/train.py b/train.py
index 56266de..0597208 100644
--- a/train.py
+++ b/train.py
@@ -1,11 +1,15 @@
+import wandb
 import argparse
-from argparse import ArgumentParser
-
 import pytorch_lightning as pl
+
+from argparse import ArgumentParser
 from pytorch_lightning.loggers import WandbLogger
 from pytorch_lightning.callbacks import ModelCheckpoint
+from os.path import join
 
-import wandb
+# Set CUDA architecture list
+from sgmse.util.other import set_torch_cuda_arch_list
+set_torch_cuda_arch_list()
 
 from sgmse.backbones.shared import BackboneRegistry
 from sgmse.data_module import SpecsDataModule
@@ -31,6 +35,7 @@ if __name__ == '__main__':
           parser_.add_argument("--nolog", action='store_true', help="Turn off logging.")
           parser_.add_argument("--wandb_name", type=str, default=None, help="Name for wandb logger. If not set, a random name is generated.")
           parser_.add_argument("--ckpt", type=str, default=None, help="Resume training from checkpoint.")
+          parser_.add_argument("--log_dir", type=str, default="logs", help="Directory to save logs.")
           
      temp_args, _ = base_parser.parse_known_args()
 
@@ -76,11 +81,11 @@ if __name__ == '__main__':
 
      # Set up callbacks for logger
      if logger != None:
-          callbacks = [ModelCheckpoint(dirpath=f"logs/{logger.version}", save_last=True, filename='{epoch}-last')]
+          callbacks = [ModelCheckpoint(dirpath=join(args.log_dir, str(logger.version)), save_last=True, filename='{epoch}-last')]
           if args.num_eval_files:
-               checkpoint_callback_pesq = ModelCheckpoint(dirpath=f"logs/{logger.version}", 
+               checkpoint_callback_pesq = ModelCheckpoint(dirpath=join(args.log_dir, str(logger.version)), 
                     save_top_k=2, monitor="pesq", mode="max", filename='{epoch}-{pesq:.2f}')
-               checkpoint_callback_si_sdr = ModelCheckpoint(dirpath=f"logs/{logger.version}", 
+               checkpoint_callback_si_sdr = ModelCheckpoint(dirpath=join(args.log_dir, str(logger.version)), 
                     save_top_k=2, monitor="si_sdr", mode="max", filename='{epoch}-{si_sdr:.2f}')
                callbacks += [checkpoint_callback_pesq, checkpoint_callback_si_sdr]
      else:
@@ -95,4 +100,4 @@ if __name__ == '__main__':
      )
 
      # Train model
-     trainer.fit(model, ckpt_path=args.ckpt)
+     trainer.fit(model, ckpt_path=args.ckpt)
\ No newline at end of file
-- 
GitLab