From db0bbc80ac5cbb22aa99b98ed7be857e257fdef7 Mon Sep 17 00:00:00 2001 From: bay9355 <mia.le@studium.uni-hamburg.de> Date: Tue, 14 Mar 2023 15:48:21 +0100 Subject: [PATCH] Merged seed_variation plotting functions with main --- cami_src/cami.py | 134 +++++++++++++++++++++++++++++++++-------------- 1 file changed, 95 insertions(+), 39 deletions(-) diff --git a/cami_src/cami.py b/cami_src/cami.py index d5c1f58..4b6da42 100755 --- a/cami_src/cami.py +++ b/cami_src/cami.py @@ -14,6 +14,8 @@ import argparse import webbrowser import random import matplotlib.pyplot as plt +import seaborn as sb +import pandas as pd #MC: from configparser import ConfigParser @@ -21,7 +23,7 @@ from configparser import ConfigParser def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate, output_dir, identifier, save_temps, nvenn, save_image, force, drugstone, ncbi, configuration, - seed_variation, parallelization, external_results): + seed_variation, parallelization, external_results, debug): print('CAMI started') config = ConfigParser() config.read(configuration) @@ -55,7 +57,8 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate, # rename seeds with their corresponding indices in the ppi_graph name2index = preprocess.name2index_dict(ppi_graph) checked_seed_lst = [name2index[v] for v in seed_lst if v in name2index] - print('Created the PPI-network Graph and seed list.') + if debug: + print('Created the PPI-network Graph and seed list.') n = len(checked_seed_lst) m = len(seed_lst) if n != m and not force: @@ -69,20 +72,22 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate, exit(1) seed_lst = checked_seed_lst.copy() - print(f'Continuing with vertices at indices {[int(seed) for seed in seed_lst]} as input seeds') + if debug: print(f'Continuing with vertices at indices {[int(seed) for seed in seed_lst]} as input seeds') # change directory to ~/cami/cami (home of cami.py) cami_home = sys.argv[0].rsplit('/', 1) os.chdir(cami_home[0]) home_path = os.path.dirname(os.getcwd()) - print(f"Home directory of cami: {home_path}") + if debug: + print(f"Home directory of cami: {home_path}") if identifier==None: identifier = str(uuid.uuid4()) if output_dir==None: output_dir = os.path.join(home_path, f'data/output/{identifier}') output_dir = os.path.abspath(output_dir) - print(f"Output directory of cami: {output_dir}") + if debug: + print(f"Output directory of cami: {output_dir}") if not os.path.exists(output_dir): os.makedirs(output_dir) @@ -91,7 +96,8 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate, #create temporary directory tmp_dir = os.path.join(home_path, f'data/tmp/{identifier}') - print(f'Creating unique temporary directory for CAMI: {tmp_dir}') + if debug: + print(f'Creating unique temporary directory for CAMI: {tmp_dir}') if not os.path.exists(tmp_dir): os.makedirs(tmp_dir) @@ -114,6 +120,9 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate, if ncbi: cami.ncbi = True + + if debug: + cami.debug = True for tool in tool_wrappers: cami.initialize_tool(tool) @@ -126,7 +135,7 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate, cami.set_initial_seed_lst(initial_seedlst) - if consensus or (not consensus and not evaluate and not seed_variation): + if (consensus or (not consensus and not evaluate and not seed_variation)) and not seed_variation: result_sets = cami.make_predictions() if len(external_results) > 0: assert(len(ext_wrappers) == len(external_results)) @@ -150,7 +159,7 @@ def main(ppi_network, seeds, 
tools, tool_weights, consensus, evaluate, cami.download_diagram(url) if drugstone is not None: - print('Sending results to Drugst.One') + print('Sending results to DrugstOne') cami.use_drugstone() if consensus: cami.reset_cami() @@ -165,11 +174,37 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate, # SEED VARIATION if seed_variation: + def predict_and_make_consensus(vis=False): result_sets = cami.make_predictions() cami.create_consensus(result_sets) + if vis: + n_results = len(cami.result_gene_sets) + comparison_matrix = pd.DataFrame([[int(0) for _ in range(n_results)] for __ in range(n_results)], + columns = list(cami.result_gene_sets.keys()), + index = list(cami.result_gene_sets.keys()), + dtype=int) + for algo1 in cami.result_gene_sets: + for algo2 in cami.result_gene_sets: + comparison_matrix.loc[algo1,algo2] = int(len(cami.result_gene_sets[algo1].intersection(cami.result_gene_sets[algo2]))) + + fig2, ax3 = plt.subplots(figsize=(20,20)) + + ax3 = sb.heatmap(comparison_matrix, annot=True, fmt='g') + ax3.set_title('Intersections of result_gene_sets of all analyzed algorithms.') + fig2.savefig(f'{output_dir}/heatmap_{cami.uid}.png', bbox_inches="tight") + print(f'saved intersection heatmap of all algorithms under: {output_dir}/heatmap_{cami.uid}.png') + plt.close(fig2) + + fig2a, ax3a = plt.subplots(figsize=(20,20)) + comparison_matrix_normalized = comparison_matrix.apply(lambda row: row/row.max(), axis=1) + ax3a = sb.heatmap(comparison_matrix_normalized, annot=True, fmt='.2f') + ax3a.set_title('Normalized intersections of result_gene_sets of all analyzed algorithms.') + fig2a.savefig(f'{output_dir}/heatmap_{cami.uid}_normalized.png', bbox_inches="tight") + print(f'saved intersection heatmap of all algorithms under: {output_dir}/heatmap_{cami.uid}_normalized.png') + plt.close(fig2a) if nvenn and vis: - url = cami.nvenn() + url = cami.use_nvenn() cami.download_diagram(url) with open(f'{output_dir}/00_node_degrees.tsv', 'w') as node_degrees: @@ -187,7 +222,7 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate, seedname = seeds for tool in cami.result_gene_sets: f.write(f'\n{seedname}\t{len(cami.seed_lst)}\t{tool}\t{len(cami.result_gene_sets[tool])}') - + cami.make_evaluation() #predict_and_make_consensus(vis=True) random.seed(50) @@ -201,6 +236,12 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate, result_file_2 = f'{output_dir}/00_seedvariation_results2.tsv' result_file_3 = f'{output_dir}/00_seedvariation_results3.tsv' result_file_4 = f'{output_dir}/00_seedvariation_results4.tsv' + n_results = len(cami.result_gene_sets) + + redisc_intersection_matrix = pd.DataFrame([[0 for _ in range(n_results)] for __ in range(n_results)], + columns = list(cami.result_gene_sets.keys()), + index = list(cami.result_gene_sets.keys()), + dtype=int) with open(result_file_1, 'w') as res_table1: with open(result_file_2, 'w') as res_table2: with open(result_file_3, 'w') as res_table3: @@ -249,7 +290,7 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate, for tool in result_dict: nof_predictions = len(result_dict[tool]) + len(used_seeds) - redisc_seeds = list(set(result_dict[tool]).intersection(set(rem_seeds))) + redisc_seeds = set(result_dict[tool]).intersection(set(rem_seeds)) redisc_prev = len(redisc_seeds) redisc_rate = redisc_prev / nof_removals redisc_rate_dict[tool] = redisc_rate #true positive rate: rate verhältnis von gefundenen und allgemein gefundenen @@ -263,43 +304,54 @@ def main(ppi_network, seeds, tools, 
tool_weights, consensus, evaluate,
                             res_table1.write('\t')
                             for idx,seed in enumerate(redisc_seeds):
                                 if idx == 0:
-                                    res_table1.write(f'{redisc_seeds[0]}')
+                                    res_table1.write(f'{list(redisc_seeds)[0]}')
                                 else:
                                     res_table1.write(f',{seed}')
                             print(f'{tool} rediscovered {redisc_seeds} after removing {rem_seeds}.')
                             res_table2.write(f'{ident}\t{tool}\t{nof_removals},\t{rem_seeds}\t{used_seeds}\t{redisc_prev}\t{redisc_rate}\t{redisc_seeds}\n')
                         res_table1.write('\n')
                         variation_results.append((redisc_rate_dict, redisc_seeds_dict, used_seeds, rem_seeds))
-
+                        for algo1 in redisc_seeds_dict:
+                            for algo2 in redisc_seeds_dict:
+                                redisc_intersection_matrix.loc[algo1,algo2] += len(redisc_seeds_dict[algo1].intersection(redisc_seeds_dict[algo2]))
         print(f'Result tables are saved in the following locations:')
         print(f'Rediscovered seeds: {result_file_1}')
         print(f'Rediscovery Rates: {result_file_2}')
         print(f'Sensitivity: {result_file_3}')
+        fig3, ax6 = plt.subplots(figsize=(20,20))
+        ax6 = sb.heatmap(redisc_intersection_matrix, annot=True, fmt='g')
+        ax6.set_title('Number of times the algorithms rediscovered the same seeds')
+        fig3.savefig(f'{output_dir}/00_seedvariation_heatmap_{identifier}.png', bbox_inches="tight")
+        print(f'Saved seed rediscovery intersection heatmap of all algorithms under: {output_dir}/00_seedvariation_heatmap_{identifier}.png')
+        plt.close(fig3)
+        fig3a, ax6a = plt.subplots(figsize=(20,20))
+        redisc_intersection_matrix_normalized = redisc_intersection_matrix.apply(lambda row: row/row.max(), axis=1)
+        ax6a = sb.heatmap(redisc_intersection_matrix_normalized, annot=True, fmt='.2f')
+        ax6a.set_title('Normalized number of times the algorithms rediscovered the same seeds')
+        fig3a.savefig(f'{output_dir}/00_seedvariation_heatmap_{identifier}_normalized.png', bbox_inches="tight")
+        print(f'Saved normalized seed rediscovery intersection heatmap of all algorithms under: {output_dir}/00_seedvariation_heatmap_{identifier}_normalized.png')
+        plt.close(fig3a)
+
         # print(variation_results)
         rediscovery_rates_results = [results[0] for results in variation_results]
         # print(rediscovery_rates_results)
         tools = [tool for tool in rediscovery_rates_results[0].keys()]
         tool_labels = tools.copy()
         redisc_rates = [[res[tool] for res in rediscovery_rates_results] for tool in tools]
-        for idx,tool in enumerate(tool_labels):
+        for idx,tool in enumerate(tools):
             if '_' in tool:
                 # find the index of the second occurrence of the underscore
                 second_occurrence_index = tool.find('_', tool.find('_') + 1)
                 if second_occurrence_index > -1:
                     # replace the underscore at that index with a line break so the label wraps
                     tool_name = tool[:second_occurrence_index] + '\n' + tool[second_occurrence_index + 1:]
-
-                tool_labels[idx] = tool_name
-
+                    tool_labels[idx] = tool_name
         #PLOT
-        # Create a figure instance
-        fig, axs = plt.subplots(2, 1, figsize=(20,15))
-        plt.subplots_adjust(left=0.2)
-
-        ax1 = axs[0]
-        ax5 = axs[1]
+        #print(sys.getrecursionlimit())
+        fig1, (ax1, ax5, ax4) = plt.subplots(3, 1, figsize=(20,20))
+        fig1.subplots_adjust(left=0.2)
         # Extract Figure and Axes instance
         # Create a plot
@@ -307,6 +359,7 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate,
         for violinpart in list(violins1.keys())[2:]:
             violins1[violinpart].set_color('k')
+
         for violin in violins1['bodies']:
             violin.set_facecolor('red')
         # Add title
@@ -319,19 +372,19 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate,
         ax1.set_ylabel('Rediscovery rate (<rediscovered seeds>/<removed seeds>)', wrap=True, fontsize=14)
-        # ax4 = plt.subplot(1,3,2, label='ax4')
-        # violins2 = ax4.violinplot([tp_rate_dict[tool] for tool in tools], showmeans=True, showextrema=True)
-        # for violinpart in list(violins2.keys())[2:]:
-        #     violins2[violinpart].set_color('k')
-        # for violin in violins2['bodies']:
-        #     violin.set_facecolor('orange')
-        # # Add title
-        # ax4.set_title(f'True positive rates after randomly removing {nof_removals} seeds\n{nof_iterations} times from {identifier} seeds.', wrap=True)
+        violins2 = ax4.violinplot([tp_rate_dict[tool] for tool in tools], showmeans=True, showextrema=True)
+        for violinpart in list(violins2.keys())[2:]:
+            violins2[violinpart].set_color('k')
+        for violin in violins2['bodies']:
+            violin.set_facecolor('orange')
+        # Add title
+        ax4.set_title(f'True positive rates after randomly removing {nof_removals} seeds {nof_iterations} times from {identifier} seeds.', wrap=True, fontsize=14)
-        # ax4.set_xticks(list(range(1,len(tools)+1)))
-        # ax4.set_xticklabels(tools)
+        ax4.set_xticks(list(range(1,len(tools)+1)))
+        ax4.set_xticklabels(tool_labels)
+        ax4.tick_params(axis='x', labelsize=11)
-        # ax4.set_ylabel('Sensitivity (TP/TP + FN)', wrap=True)
+        ax4.set_ylabel('Sensitivity (TP/(TP + FN))', wrap=True, fontsize=14)
         violins3 = ax5.violinplot([module_size_dict[tool] for tool in tools], showmeans=True, showextrema=True)
         # Add title
@@ -345,11 +398,11 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate,
         ax5.set_xticks(list(range(1,len(tools)+1)))
         ax5.set_xticklabels(tool_labels)
-        ax5.set_ylabel('(<rediscovered seeds>/<module size>)', fontsize=14)
+        ax5.set_ylabel('Precision (<rediscovered seeds>/<module size>)', fontsize=14)
         ax5.tick_params(axis='x', labelsize=11)
-        plt.tight_layout()
-        fig.savefig(f'{output_dir}/00_{identifier}_seed_variation_result.png', bbox_inches="tight")
-
+        fig1.tight_layout()
+        fig1.savefig(f'{output_dir}/00_{identifier}_seed_variation_result.png', bbox_inches="tight")
+        plt.close(fig1)
         print(f'Violin plot saved under: 00_{identifier}_seed_variation_result.png')
 
         if save_temps:
@@ -401,6 +454,8 @@ if __name__ == "__main__":
                         help="Choose a configuration for the static variables.")
     parser.add_argument('-p', '--parallelization', action='store_true',
                         help="run the tools for prediction parallelized")
+    parser.add_argument('-db', '--debug', action='store_true',
+                        help="run CAMI with verbose output")
     #TODO List with additional arguments if needed by certain tools
     args = vars(parser.parse_args())
@@ -412,7 +467,7 @@ if __name__ == "__main__":
     # add name for implemented tools here:
     implemented_tools = [
-        "domino",
+        "domino", 
         "diamond",
         "robust",
         "hotnet"
@@ -421,4 +476,5 @@
     for tool in args['tools']:
         if tool not in implemented_tools:
             raise RuntimeError(tool + f' not implemented yet. Implemented tools are: {implemented_tools}')
+
     main(**args)
-- 
GitLab
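
Note on the metrics plotted above: the seed-variation block computes, per tool, a rediscovery rate (share of the removed seeds that the tool predicts back) and a pairwise intersection matrix of rediscovered seeds that feeds the two heatmaps. The sketch below is illustrative only and not part of the patch; the input names (result_gene_sets, removed_seeds) are placeholders, not identifiers from cami.py.

    # Minimal sketch of the seed-variation bookkeeping, assuming toy inputs.
    import pandas as pd

    result_gene_sets = {                       # hypothetical per-tool predictions
        'diamond': {'g1', 'g2', 'g5'},
        'domino': {'g2', 'g3'},
        'cami': {'g1', 'g2', 'g3'},
    }
    removed_seeds = {'g1', 'g2', 'g4'}         # seeds hidden from the tools

    # Rediscovery rate: fraction of removed seeds that a tool predicts back.
    redisc_seeds = {tool: genes & removed_seeds for tool, genes in result_gene_sets.items()}
    redisc_rate = {tool: len(seeds) / len(removed_seeds) for tool, seeds in redisc_seeds.items()}

    # Pairwise intersection matrix of rediscovered seeds (basis of the heatmaps).
    tools = list(result_gene_sets)
    matrix = pd.DataFrame(
        [[len(redisc_seeds[a] & redisc_seeds[b]) for b in tools] for a in tools],
        index=tools, columns=tools, dtype=int)
    # Row-wise normalization, mirroring the apply(lambda row: row/row.max(), axis=1) call above.
    normalized = matrix.apply(lambda row: row / row.max(), axis=1)

    print(redisc_rate)
    print(normalized)

Dividing each row by its maximum makes rows comparable across tools that rediscover very different numbers of seeds, which is why both heatmaps are saved in a raw and a normalized variant.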