From d1b84282718392fa95d189249c63a6ec6edda2c7 Mon Sep 17 00:00:00 2001
From: bay9355 <mia.le@studium.uni-hamburg.de>
Date: Mon, 13 Mar 2023 23:37:20 +0100
Subject: [PATCH] seed variation implementation with iterative find_bridges
 function
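
Gate most progress printouts behind a new -db/--debug flag, re-enable the
true positive rate panel in the seed variation plots, write one output file
per consensus approach, and replace the recursive bridge search used by
must() with the iterative find_bridges_it() to avoid Python's recursion
limit on large PPI networks.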

---
 cami_src/algorithms/AlgorithmWrapper.py |  4 +-
 cami_src/algorithms/DiamondWrapper.py   |  2 +-
 cami_src/algorithms/DominoWrapper.py    | 16 ++---
 cami_src/algorithms/RobustWrapper.py    |  2 +-
 cami_src/cami.py                        | 71 ++++++++++---------
 cami_src/cami_suite.py                  | 91 ++++++++++++-------------
 cami_src/seed_variation_2.py            |  2 +-
 cami_src/utils/degradome.py             |  5 +-
 cami_src/utils/networks.py              | 72 +++++++++++++++++--
 9 files changed, 161 insertions(+), 104 deletions(-)

diff --git a/cami_src/algorithms/AlgorithmWrapper.py b/cami_src/algorithms/AlgorithmWrapper.py
index 6378fad..17de832 100644
--- a/cami_src/algorithms/AlgorithmWrapper.py
+++ b/cami_src/algorithms/AlgorithmWrapper.py
@@ -38,10 +38,10 @@ class AlgorithmWrapper(object):
 
     def create_tmp_output_dir(self, tmp_dir):
         out_dir = os.path.join(tmp_dir, self.name)
-        print(tmp_dir)
+        #print(tmp_dir)
         if not os.path.exists(out_dir):
             os.mkdir(out_dir)
-            print(f"created temporary directory for {self.name} named {out_dir}...")
+            if self.debug: print(f"created temporary directory for {self.name} named {out_dir}...")
         self.output_dir = out_dir
 
     def name_file(self, kind, ending='txt'):
diff --git a/cami_src/algorithms/DiamondWrapper.py b/cami_src/algorithms/DiamondWrapper.py
index 93ca932..ee61636 100644
--- a/cami_src/algorithms/DiamondWrapper.py
+++ b/cami_src/algorithms/DiamondWrapper.py
@@ -42,7 +42,7 @@ class DiamondWrapper(AlgorithmWrapper):
         command = f'{diamond} "{ppi}" "{seeds}" {nof_predictions} {self.alpha} "{algo_output}"'
         subprocess.call(command, shell=True, stdout=subprocess.PIPE)
         assert os.path.exists(algo_output), f'DIAMOnD failed to save output to {algo_output}'
-        print(f"DIAMOnD results saved in {algo_output}")
+        if self.debug: print(f"DIAMOnD results saved in {algo_output}")
         return self.extract_output(algo_output)
 
     def prepare_input(self):
diff --git a/cami_src/algorithms/DominoWrapper.py b/cami_src/algorithms/DominoWrapper.py
index d642ecf..34953a3 100644
--- a/cami_src/algorithms/DominoWrapper.py
+++ b/cami_src/algorithms/DominoWrapper.py
@@ -54,7 +54,7 @@ class DominoWrapper(AlgorithmWrapper):
         command = f'mv "{algo_output}" "{os.path.join(self.output_dir, outputfilename)}"'
         subprocess.call(command, shell=True, stdout=subprocess.PIPE)
         algo_output = os.path.join(self.output_dir, outputfilename)
-        print(f"{self.name} results saved in {algo_output}")
+        if self.debug: print(f"{self.name} results saved in {algo_output}")
 
         return self.extract_output(algo_output)
 
@@ -64,8 +64,7 @@ class DominoWrapper(AlgorithmWrapper):
         """
         inputparams = []
         # prepare inputfiles
-        if self.debug:
-            print(f'creating {self.name} input files in {self.output_dir}')
+        if self.debug: print(f'creating {self.name} input files in {self.output_dir}')
 
         ppi_filename = self.name_file('ppi', 'sif')
         ppi_file = os.path.join(self.output_dir, ppi_filename)
@@ -81,26 +80,23 @@ class DominoWrapper(AlgorithmWrapper):
                 file.write(f"{str(edge.source()) + '_'}\tppi\t{str(edge.target()) + '_'}\n")
                 # the nodes need to be appended by '_' so that pandas recognizes the vertices as strings
         inputparams.append(ppi_file)
-        print(f'{self.name} ppi is saved in {ppi_file}')
+        if self.debug: print(f'{self.name} ppi is saved in {ppi_file}')
 
         with open(seed_file, "w") as file:
             file.write('#node\n')
             for seed in self.seeds:
                 file.write(f"{seed}_\n")
         inputparams.append(seed_file)
-        if self.debug:
-            print(f'{self.name} seeds are saved in {seed_file}')
+        if self.debug: print(f'{self.name} seeds are saved in {seed_file}')
 
         slices_filename = self.name_file('slices')
         slices_output = os.path.join(self.output_dir, slices_filename)
 
         if not os.path.exists(slices_output):
-            if self.debug:
-                print('creating domino slices_file...')
+            if self.debug: print('creating domino slices_file...')
             command = f'slicer --network_file "{ppi_file}" --output_file "{slices_output}"'
             subprocess.call(command, shell=True, stdout=subprocess.PIPE)
-        if self.debug:
-            print(f'{self.name} slices are saved in {slices_output}')
+        if self.debug: print(f'{self.name} slices are saved in {slices_output}')
         inputparams.append(slices_output)
         return inputparams
 
diff --git a/cami_src/algorithms/RobustWrapper.py b/cami_src/algorithms/RobustWrapper.py
index 84d00ef..fb89b20 100755
--- a/cami_src/algorithms/RobustWrapper.py
+++ b/cami_src/algorithms/RobustWrapper.py
@@ -58,7 +58,7 @@ class RobustWrapper(AlgorithmWrapper):
             {self.initialFraction} {self.reductionFactor} \
             {self.numberSteinerTrees} {self.threshold}'
         subprocess.call(command, shell=True, stdout=subprocess.PIPE)
-        print(f"Robust results saved in {algo_output}")
+        if self.debug: print(f"Robust results saved in {algo_output}")
         return self.extract_output(algo_output)
 
         # inputparams = []
diff --git a/cami_src/cami.py b/cami_src/cami.py
index 7747d8d..12cf3ef 100755
--- a/cami_src/cami.py
+++ b/cami_src/cami.py
@@ -21,7 +21,7 @@ from configparser import ConfigParser
 
 def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate,
          output_dir, identifier, save_temps, nvenn, save_image, force, drugstone, ncbi, configuration, 
-         seed_variation, parallelization, external_results):
+         seed_variation, parallelization, external_results, debug):
     print('CAMI started')
     config = ConfigParser()
     config.read(configuration)
@@ -55,7 +55,8 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate,
     # rename seeds with their corresponding indices in the ppi_graph
     name2index = preprocess.name2index_dict(ppi_graph)
     checked_seed_lst = [name2index[v] for v in seed_lst if v in name2index]
-    print('Created the PPI-network Graph and seed list.')
+    if debug:
+        print('Created the PPI-network Graph and seed list.')
     n = len(checked_seed_lst)
     m = len(seed_lst)
     if n != m  and not force:
@@ -69,20 +70,22 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate,
             exit(1)
 
     seed_lst = checked_seed_lst.copy()
-    print(f'Continuing with vertices at indices {[int(seed) for seed in seed_lst]} as input seeds')
+    if debug: print(f'Continuing with vertices at indices {[int(seed) for seed in seed_lst]} as input seeds')
 
     # change directory to ~/cami/cami (home of cami.py)
     cami_home = sys.argv[0].rsplit('/', 1)
     os.chdir(cami_home[0])
     home_path = os.path.dirname(os.getcwd())
-    print(f"Home directory of cami: {home_path}")
+    if debug:
+        print(f"Home directory of cami: {home_path}")
     if identifier==None:
         identifier = str(uuid.uuid4())
 
     if output_dir==None:
         output_dir = os.path.join(home_path, f'data/output/{identifier}')
         output_dir = os.path.abspath(output_dir)
-        print(f"Output directory of cami: {output_dir}")
+        if debug:
+            print(f"Output directory of cami: {output_dir}")
         if not os.path.exists(output_dir):
             os.makedirs(output_dir)
 
@@ -91,7 +94,8 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate,
 
     #create temporary directory
     tmp_dir = os.path.join(home_path, f'data/tmp/{identifier}')
-    print(f'Creating unique temporary directory for CAMI: {tmp_dir}')
+    if debug:
+        print(f'Creating unique temporary directory for CAMI: {tmp_dir}')
 
     if not os.path.exists(tmp_dir):
         os.makedirs(tmp_dir)
@@ -114,6 +118,9 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate,
 
     if ncbi:
         cami.ncbi = True
+    
+    if debug:
+        cami.debug = True
         
     for tool in tool_wrappers:
         cami.initialize_tool(tool)
@@ -126,7 +133,7 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate,
 
     cami.set_initial_seed_lst(initial_seedlst)
 
-    if consensus or (not consensus and not evaluate and not seed_variation):
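+    # skip the regular prediction and consensus run when a seed variation analysis is requested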
+    if (consensus or not evaluate) and not seed_variation:
         result_sets = cami.make_predictions()
         if len(external_results) > 0:
             assert(len(ext_wrappers) == len(external_results))
@@ -169,7 +176,7 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate,
             result_sets = cami.make_predictions()
             cami.create_consensus(result_sets)
             if nvenn and vis:
-                url = cami.nvenn()
+                url = cami.use_nvenn()
                 cami.download_diagram(url)
                 
         with open(f'{output_dir}/00_node_degrees.tsv', 'w') as node_degrees:
@@ -281,25 +288,20 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate,
         tools = [tool for tool in rediscovery_rates_results[0].keys()]
         tool_labels = tools.copy()
         redisc_rates = [[res[tool] for res in rediscovery_rates_results] for tool in tools]
-        for idx,tool in enumerate(tool_labels):
+        for idx,tool in enumerate(tools):
             if '_' in tool:
                 # find the index of the second occurrence of the character
                 second_occurrence_index = tool.find('_', tool.find('_') + 1)
                 if second_occurrence_index > -1:
                     # replace the character at that index with the replacement character
                     tool_name = tool[:second_occurrence_index] + '\n' + tool[second_occurrence_index + 1:]
-
-                    tool_labels[idx] = tool_name
-        
+                    tool_labels[idx] = tool_name
 
         #PLOT
-
         # Create a figure instance
-        fig, axs = plt.subplots(2, 1, figsize=(20,15))
-        plt.subplots_adjust(left=0.2)
-        
-        ax1 = axs[0]
-        ax5 = axs[1]
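+        # three panels: rediscovery rate (ax1), precision (ax5) and true positive rate (ax4)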
+        fig, (ax1, ax5, ax4) = plt.subplots(3, 1, figsize=(20,20))
+        fig.subplots_adjust(left=0.2)
         # Extract Figure and Axes instance
 
         # Create a plot
@@ -307,6 +309,7 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate,
          
         for violinpart in list(violins1.keys())[2:]:
             violins1[violinpart].set_color('k')
+            
         for violin in violins1['bodies']:
             violin.set_facecolor('red')
         # Add title
@@ -319,19 +322,19 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate,
         ax1.set_ylabel('Rediscovery rate (<rediscovered seeds>/<removed seeds>)', wrap=True, fontsize=14)
 
         
-        # ax4 = plt.subplot(1,3,2, label='ax4')
-        # violins2 = ax4.violinplot([tp_rate_dict[tool] for tool in tools], showmeans=True, showextrema=True)
-        # for violinpart in list(violins2.keys())[2:]:
-        #     violins2[violinpart].set_color('k')
-        # for violin in violins2['bodies']:
-        #     violin.set_facecolor('orange')
-        # # Add title
-        # ax4.set_title(f'True positive rates after randomly removing {nof_removals} seeds\n{nof_iterations} times from {identifier} seeds.', wrap=True)
+        violins2 = ax4.violinplot([tp_rate_dict[tool] for tool in tools], showmeans=True, showextrema=True)
+        for violinpart in list(violins2.keys())[2:]:
+            violins2[violinpart].set_color('k')
+        for violin in violins2['bodies']:
+            violin.set_facecolor('orange')
+        # Add title
+        ax4.set_title(f'True positive rates after randomly removing {nof_removals} seeds {nof_iterations} times from {identifier} seeds.', wrap=True, fontsize=14)
 
-        # ax4.set_xticks(list(range(1,len(tools)+1)))
-        # ax4.set_xticklabels(tools)
+        ax4.set_xticks(list(range(1,len(tools)+1)))
+        ax4.set_xticklabels(tool_labels)
+        ax4.tick_params(axis='x', labelsize=11)
 
-        # ax4.set_ylabel('Sensitivity (TP/TP + FN)', wrap=True)
+        ax4.set_ylabel('Sensitivity (TP/(TP+FN))', wrap=True, fontsize=14)
 
         violins3 = ax5.violinplot([module_size_dict[tool] for tool in tools], showmeans=True, showextrema=True)
         # Add title
@@ -345,11 +348,11 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate,
         ax5.set_xticks(list(range(1,len(tools)+1)))
         ax5.set_xticklabels(tool_labels)
         
-        ax5.set_ylabel('(<rediscovered seeds>/<module size>)', fontsize=14)
+        ax5.set_ylabel('Precision (<rediscovered seeds>/<module size>)', fontsize=14)
         ax5.tick_params(axis='x', labelsize=11)
-        plt.tight_layout()
+        fig.tight_layout()
         fig.savefig(f'{output_dir}/00_{identifier}_seed_variation_result.png', bbox_inches="tight")
-        
+        plt.close(fig)
         print(f'Violin plot saved under: 00_{identifier}_seed_variation_result.png')
         
     if save_temps:
@@ -401,6 +404,8 @@ if __name__ == "__main__":
                         help="Choose a configuration for the static variables.")
     parser.add_argument('-p', '--parallelization', action='store_true', 
             help="run the tools for prediction parallelized")
+    parser.add_argument('-db', '--debug', action='store_true', 
+            help="run CAMI with verbose outputs")
     #TODO List with additional arguments if needed by certain tools
 
     args = vars(parser.parse_args())
@@ -412,7 +417,7 @@ if __name__ == "__main__":
 
 # add name for implemented tools here:
     implemented_tools = [
-        "domino",
+                         "domino",
                          "diamond",
                          "robust",
                          "hotnet"
diff --git a/cami_src/cami_suite.py b/cami_src/cami_suite.py
index 235b624..b8ac144 100644
--- a/cami_src/cami_suite.py
+++ b/cami_src/cami_suite.py
@@ -180,9 +180,9 @@ class cami():
         :rtype: set()
         """
         tool.create_tmp_output_dir(self.tmp_dir) # creates the temporary input directory
-        print(f"preparing {tool.name} input...")
+        if self.debug: print(f"preparing {tool.name} input...")
         inputparams = tool.prepare_input()
-        print(f'running {tool.name}...')
+        if self.debug: print(f'running {tool.name}...')
         preds = set(tool.run_algorithm(inputparams))
         if self.debug:
             print(f'{tool.name} predicted {len(preds)} active vertices (seeds not excluded):')
@@ -224,7 +224,8 @@ class cami():
                  to the corresponding tool
         :rtype: dict(AlgorithmWrapper():set(Graph.vertex()))
         """
-        print(f'Creating result sets of all {self.nof_tools} tools...')
+        if self.debug:
+            print(f'Creating result sets of all {self.nof_tools} tools...')
         pred_sets = {tool:None for tool in self.tool_wrappers}
         
         if self.threaded:
@@ -337,18 +338,21 @@ class cami():
             }},
         }
         
-        # create integer codes for cami_versions (needed for predicted_by vertex property)
-
+        # transform all vertex indices to their corresponding gene names in a result set
+        for tool in result_sets:
+            self.result_gene_sets[tool.name] = set([gene_name_map[vertex] for vertex in result_sets[tool]])
+                
         for cami_method_name, cami_params in camis.items():
-            print("Running " + cami_method_name)
+            if self.debug:
+                print("Running " + cami_method_name)
+            # create integer codes for cami_versions (needed for predicted_by vertex property)
             tool_code = max(list(tool_name_map.keys())) + 1
             tool_name_map[tool_code] = cami_method_name
             
             cami_vertices, putative_vertices, codes2tools = cami_params['function'](result_sets, ppi_graph, seed_list,
-                                                                                    predicted_by, cami_scores,
-                                                                                    tool_name_map, tool_code,
-                                                                                    cami_params['params'])
-
+                                                                                        predicted_by, cami_scores,
+                                                                                        tool_name_map, tool_code,
+                                                                                        cami_params['params'])
             # sort the resulting vertices according to their cami_score
             cami_vlist = sorted(cami_vertices, key=lambda v: cami_scores[v], reverse=True)
 
@@ -362,29 +366,23 @@ class cami():
                 for vertex in cami_vlist:
                     print(f'{gene_name_map[vertex]}\t{cami_scores[vertex]}\t{codes2tools[vertex]}')
             else:
-                print(f'With the {len(seed_genes)} seed genes CAMI ({cami_method_name}) proposes {len(cami_vlist)} to add to the Active Module')
-                
+                print(f'With the {len(seed_genes)} seed genes CAMI ({cami_method_name}) proposes {len(cami_vlist)} genes to add to the Active Module')
+            
             # for visualization with nvenn
-            self.result_gene_sets[cami_method_name] = cami_genes
-                
-            # transform all vertex indices to their corresponding gene names in a result set
-            for tool in result_sets:
-                self.result_gene_sets[tool.name] = set([gene_name_map[vertex] for vertex in result_sets[tool]])
-                
-            # add seeds to result sets for drugstone and digest
-            for tool in result_sets:
-                self.result_module_sets[tool.name] = set([gene_name_map[vertex] for vertex in result_sets[tool]]).union(self.seed_lst)
+            self.result_gene_sets[cami_method_name] = set(cami_genes)
             
-            assert(self.code2toolname == tool_name_map)
+            # add seeds to result sets for drugstone and digest
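+            # the seeds are vertex indices, so map them to gene names to match the result gene sets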
+            for toolname in self.result_gene_sets:
+                self.result_module_sets[toolname] = self.result_gene_sets[toolname].union(set([gene_name_map[svertex] for svertex in self.seed_lst]))
             
             # save the results in outputfiles
             self.generate_output(cami_method_name, seed_genes, cami_vlist, cami_vertices, putative_vertices, cami_genes,
                                  gene_name_map, codes2tools, cami_scores)
 
-    def generate_output(self, cami_method, seed_genes, cami_vlist, cami_vertices, putative_vertices, cami_genes,
-                        gene_name_map, codes2tools, cami_scores):
+        
         # save all predictions by all tools
-        print('Saving the results...')
+        if self.debug:
+            print('Saving the results...')
         with open(f'{self.output_dir}/all_predictions_{self.uid}.tsv', 'w') as outputfile:
             outputfile.write(f'CAMI predictions with {len(self.seed_lst)} of initially {len(self.initial_seed_lst)} seeds: {seed_genes},\n'+
                              f'initially: {self.initial_seed_lst}\n')
@@ -394,11 +392,15 @@ class cami():
                 outputfile.write(f'{gene_name_map[vertex]}\t{codes2tools[vertex]}\t{cami_scores[vertex]}\t{str(vertex)}\t{vertex.out_degree()}\n')
         print(f'saved all predictions by the used tools in: {self.output_dir}/all_predictions_{self.uid}.tsv')
 
+        
+    def generate_output(self, cami_method, seed_genes, cami_vlist, cami_vertices, putative_vertices, cami_genes,
+                        gene_name_map, codes2tools, cami_scores):
+
         # save the predictions made by cami
         ncbi_url = ('\tncbi_url' if self.ncbi else '')
         ncbi_summary = ('\tncbi_summary' if self.ncbi else '')
 
-        with open(f'{self.output_dir}/CAMI_output_{self.uid}.tsv', 'w') as outputfile:
+        with open(f'{self.output_dir}/{cami_method}_output_{self.uid}.tsv', 'w') as outputfile:
             outputfile.write(f'gene\tindex_in_graph\tcami_score\tdegree_in_graph{ncbi_url}{ncbi_summary}\n')     
             for vertex in cami_vlist:
                 if self.ncbi:
@@ -412,32 +414,28 @@ class cami():
                     url, summary = '',''
                 outputfile.write(f'{gene_name_map[vertex]}\t{str(vertex)}\t{cami_scores[vertex]}\t{vertex.out_degree()}{url}{summary}\n')
         
-        # save the whole module
-        whole_module = []
-        with open(f'{self.output_dir}/CAMI_module_{cami_method}_{self.uid}.txt', 'w') as modfile:
-                for vertex in seed_genes:
-                    modfile.write(f'{vertex}\n')
-                    whole_module.append(vertex)
-                for vertex in cami_genes:
-                    modfile.write(f'{vertex}\n')
-                    whole_module.append(vertex)
-
-        print(f'saved cami output in: {self.output_dir}/CAMI_output_{self.uid}.tsv')
-        print(f'saved the Consensus Active Module by CAMI in: {self.output_dir}/CAMI_nodes_{cami_method}_{self.uid}.txt')
+        # # save the whole module
+        # whole_module = []
+        # with open(f'{self.output_dir}/{cami_method}_module_{self.uid}.txt', 'w') as modfile:
+        #         for vertex in seed_genes:
+        #             modfile.write(f'{vertex}\n')
+        #             whole_module.append(vertex)
+        #         for vertex in cami_genes:
+        #             modfile.write(f'{vertex}\n')
+        #             whole_module.append(vertex)
+
+        # print(f'saved {cami_method} output in: {cami_method}_output_{self.uid}.tsv')
+        # print(f'saved the Consensus Active Module by CAMI in: {self.output_dir}/{cami_method}_module_{self.uid}.txt')
         
        
-        # save predictions by the other tools
-        for tool in self.result_gene_sets:
+        # save the gene-level predictions of all tools and consensus methods
+        for tool in self.result_module_sets:
             with open(f'{self.output_dir}/{tool}_output_{self.uid}.tsv', 'w') as outputfile:
                 outputfile.write('gene\n')
                 for gene in self.result_gene_sets[tool]:
                     outputfile.write(f'{gene}\n')
-            print(f'saved {tool} output in: {self.output_dir}/{tool}_output_{self.uid}.tsv')
-                    
-        # return values
-        consensus = {}
-        consensus['module'] = whole_module
-        consensus['seeds'] = self.seed_lst
+            if self.debug:
+                print(f'saved {tool} output in: {self.output_dir}/{tool}_output_{self.uid}.tsv')
 
 
     def use_nvenn(self):
@@ -486,7 +484,6 @@ class cami():
         #print(list(set(cami_symbol_edges)))
         url = drugstone.send_request(cami_symbols, cami_symbol_edges)
         print(f'You can find a network visualization of the CAMI module via: {url}')
-        print('The link was also saved in the outputfolder for later.')
         with open(f'{self.output_dir}/drugstone_link_{self.uid}.txt', 'w') as f:
             f.write(url)
         return url
diff --git a/cami_src/seed_variation_2.py b/cami_src/seed_variation_2.py
index 12bd963..278ed7d 100755
--- a/cami_src/seed_variation_2.py
+++ b/cami_src/seed_variation_2.py
@@ -9,5 +9,5 @@ seedfiles = sys.argv[2:]
 config = 'seed_variationconf'
 for seeds in seedfiles:
     identifier = basename(seeds).rsplit('.')[0] + '_seedvar_different_consensus_approaches'
-    command = f'./cami.py -n {network} -s {seeds} -id {identifier} -conf {config} -var 100 -f;'
+    command = f'./cami.py -n {network} -s {seeds} -id {identifier} -conf {config} -var 10 -f;'
     subprocess.call(command, shell=True)
\ No newline at end of file
diff --git a/cami_src/utils/degradome.py b/cami_src/utils/degradome.py
index b295427..1b13755 100644
--- a/cami_src/utils/degradome.py
+++ b/cami_src/utils/degradome.py
@@ -1,14 +1,13 @@
 import requests, re
 
-def send_request(sets=dict(), seeds=list()):
+def send_request(sets=dict()):
     url = 'http://degradome.uniovi.es/cgi-bin/nVenn/nvenn.cgi'
 
     groups = '['
     for tool in sets:
         groups += '{' + f'"name":"{tool}","els":['
         for node in sets[tool]:
-            if node not in seeds:
-                groups += f'"{node}",'
+            groups += f'"{node}",'
         groups = groups[:-1]
         groups += ']},'
     groups = groups[:-1]
diff --git a/cami_src/utils/networks.py b/cami_src/utils/networks.py
index f045996..244388e 100644
--- a/cami_src/utils/networks.py
+++ b/cami_src/utils/networks.py
@@ -121,14 +121,45 @@ def steiner_tree(g, seeds, seed_map, weights, non_zero_hub_penalty):
             except:
                 pass
             g2.remove_vertex(node)
-
     return g2
 
+def find_bridges_it(g):
+    r"""Finds all bridges in a graph iteratively.
+
+    Uses an explicit DFS stack instead of recursion, so it also works on
+    networks whose DFS depth would exceed Python's recursion limit.
+    """
+    visited = g.new_vertex_property("boolean", False)
+    disc = g.new_vertex_property("float", float("inf"))
+    low = g.new_vertex_property("float", float("inf"))
+    is_bridge = g.new_edge_property("boolean", False)
+    counter = 0
+    for root in g.get_vertices():
+        if visited[root]:
+            continue
+        visited[root] = True
+        disc[root] = counter
+        low[root] = counter
+        counter += 1
+        # each stack frame: (vertex, its DFS parent, iterator over its neighbours)
+        stack = [(int(root), -1, iter(g.get_all_neighbors(root)))]
+        while stack:
+            node, parent_node, neighbors = stack[-1]
+            pushed_child = False
+            for nb in neighbors:
+                if not visited[nb]:
+                    visited[nb] = True
+                    disc[nb] = counter
+                    low[nb] = counter
+                    counter += 1
+                    stack.append((int(nb), node, iter(g.get_all_neighbors(nb))))
+                    pushed_child = True
+                    break
+                elif int(nb) != parent_node:
+                    # back edge: update the low-link value of the current vertex
+                    low[node] = min(low[node], disc[nb])
+            if not pushed_child:
+                # all neighbours processed: propagate the low-link value upwards
+                stack.pop()
+                if parent_node != -1:
+                    low[parent_node] = min(low[parent_node], low[node])
+                    if low[node] > disc[parent_node]:
+                        # no back edge from the subtree of node reaches above parent_node
+                        is_bridge[g.edge(parent_node, node)] = True
+    return is_bridge
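+
+# usage sketch (assuming an undirected graph_tool Graph g):
+#     bridges = find_bridges_it(g)
+#     bridge_edges = [e for e in g.edges() if bridges[e]]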
+
+
 def find_bridges(g):
     r"""Finds all bridges in a graph."""
     global __time
     __time = 0
-    sys.setrecursionlimit(g.num_vertices() + 1)
+    #sys.setrecursionlimit(g.num_vertices() + 1)
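+    # note: must() below now uses the iterative find_bridges_it instead of this recursive variant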
     visited = g.new_vertex_property("boolean", False)
     disc = g.new_vertex_property("float", float("inf"))
     low = g.new_vertex_property("float", float("inf"))
@@ -139,13 +170,43 @@ def find_bridges(g):
             __dfs_find_bridges(g, node, visited, disc, low, parent, is_bridge)
     return is_bridge
 
-def __dfs_find_bridges(g, node, visited, disc, low, parent, is_bridge):
+def __dfs_find_bridges_it(g, node, visited, disc, low, parent, is_bridge):
     visited[node] = True
     global __time
     disc[node] = __time
     low[node] = __time
     __time += 1
-
+    # iterative DFS over the component of `node`, using an explicit stack of
+    # (vertex, neighbour-iterator) frames instead of recursion
+    stack = [(int(node), iter(g.get_all_neighbors(node)))]
+    while stack:
+        current, neighbors = stack[-1]
+        pushed_child = False
+        for nb in neighbors:
+            if not visited[nb]:
+                visited[nb] = True
+                parent[nb] = current
+                disc[nb] = __time
+                low[nb] = __time
+                __time += 1
+                stack.append((int(nb), iter(g.get_all_neighbors(nb))))
+                pushed_child = True
+                break
+            elif int(nb) != parent[current]:
+                # back edge: update the low-link value of the current vertex
+                low[current] = min(low[current], disc[nb])
+        if not pushed_child:
+            # all neighbours processed: propagate the low-link value to the parent
+            stack.pop()
+            if stack:
+                par = stack[-1][0]
+                low[par] = min(low[par], low[current])
+                if low[current] > disc[par]:
+                    try:
+                        is_bridge[g.edge(par, current)] = True
+                    except Exception:
+                        pass
+
+def __dfs_find_bridges(g, node, visited, disc, low, parent, is_bridge):
+    visited[node] = True
+    global __time
+    disc[node] = __time
+    low[node] = __time
+    __time += 1
     for nb in g.get_all_neighbors(node):
         if not visited[nb]:
             parent[nb] = node
@@ -159,7 +220,6 @@ def __dfs_find_bridges(g, node, visited, disc, low, parent, is_bridge):
         elif int(nb) != parent[node]: #TODO can in theory be removed
             low[node] = min(low[node], disc[nb])
 
-
 def must(g, seed_ids, num_trees, hub_penalty, weights=None, tolerance=10):
     if gt.openmp_enabled():
         gt.openmp_set_num_threads(6)
@@ -186,7 +246,7 @@ def must(g, seed_ids, num_trees, hub_penalty, weights=None, tolerance=10):
     for vertex in tree_nodes:
         scores[vertex] +=1
     if num_trees > 1:
-        is_bridge = find_bridges(g)
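+        # the iterative implementation avoids hitting Python's recursion limit on large networks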
+        is_bridge = find_bridges_it(g)
         edge_filter = g.new_edge_property("boolean", True)
         while len(tree_edges) > 0:
             tree_edge = tree_edges.pop()
-- 
GitLab