diff --git a/train.py b/train.py
index 56266debe9a3362a7d50e0d3e362e492a6acbb86..0597208b99bd6e81ab328d9d2a33997ca3aa9f77 100644
--- a/train.py
+++ b/train.py
@@ -1,11 +1,15 @@
+import wandb
 import argparse
-from argparse import ArgumentParser
-
 import pytorch_lightning as pl
+
+from argparse import ArgumentParser
 from pytorch_lightning.loggers import WandbLogger
 from pytorch_lightning.callbacks import ModelCheckpoint
+from os.path import join
 
-import wandb
+# Set CUDA architecture list
+from sgmse.util.other import set_torch_cuda_arch_list
+set_torch_cuda_arch_list()
 
 from sgmse.backbones.shared import BackboneRegistry
 from sgmse.data_module import SpecsDataModule
@@ -31,6 +35,7 @@ if __name__ == '__main__':
           parser_.add_argument("--nolog", action='store_true', help="Turn off logging.")
           parser_.add_argument("--wandb_name", type=str, default=None, help="Name for wandb logger. If not set, a random name is generated.")
           parser_.add_argument("--ckpt", type=str, default=None, help="Resume training from checkpoint.")
+          parser_.add_argument("--log_dir", type=str, default="logs", help="Directory to save logs.")
           
      temp_args, _ = base_parser.parse_known_args()
 
@@ -76,11 +81,11 @@ if __name__ == '__main__':
 
      # Set up callbacks for logger
      if logger != None:
-          callbacks = [ModelCheckpoint(dirpath=f"logs/{logger.version}", save_last=True, filename='{epoch}-last')]
+          callbacks = [ModelCheckpoint(dirpath=join(args.log_dir, str(logger.version)), save_last=True, filename='{epoch}-last')]
           if args.num_eval_files:
-               checkpoint_callback_pesq = ModelCheckpoint(dirpath=f"logs/{logger.version}", 
+               checkpoint_callback_pesq = ModelCheckpoint(dirpath=join(args.log_dir, str(logger.version)), 
                     save_top_k=2, monitor="pesq", mode="max", filename='{epoch}-{pesq:.2f}')
-               checkpoint_callback_si_sdr = ModelCheckpoint(dirpath=f"logs/{logger.version}", 
+               checkpoint_callback_si_sdr = ModelCheckpoint(dirpath=join(args.log_dir, str(logger.version)), 
                     save_top_k=2, monitor="si_sdr", mode="max", filename='{epoch}-{si_sdr:.2f}')
                callbacks += [checkpoint_callback_pesq, checkpoint_callback_si_sdr]
      else:
@@ -95,4 +100,4 @@ if __name__ == '__main__':
      )
 
      # Train model
-     trainer.fit(model, ckpt_path=args.ckpt)
+     trainer.fit(model, ckpt_path=args.ckpt)
\ No newline at end of file