From bda13dbdf935e7d5b9c1ab4ebadbe4f52c1e89e7 Mon Sep 17 00:00:00 2001
From: bay9355 <mia.le@studium.uni-hamburg.de>
Date: Thu, 30 Mar 2023 19:05:13 +0200
Subject: [PATCH] added comparison seedvariation output

---
 cami_src/cami.py       |  6 ++----
 cami_src/cami_suite.py | 20 ++++++++++++--------
 2 files changed, 14 insertions(+), 12 deletions(-)

diff --git a/cami_src/cami.py b/cami_src/cami.py
index 9d75ece..06969d3 100755
--- a/cami_src/cami.py
+++ b/cami_src/cami.py
@@ -67,7 +67,7 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate,
                 if save_image:
                     cami.download_diagram(url)
         
-        if visualize or comparison_matrix:
+        if visualize or comparison_matrix or seed_variation:
             cami.visualize_and_save_comparison_matrix()
         
         if drugstone is not None:
@@ -77,10 +77,8 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate,
         if consensus:
             cami.reset_cami()
 
-    if evaluate or (not consensus and not evaluate) or seed_variation:
+    if evaluate or (not consensus and not evaluate and not seed_variation):
         cami.make_evaluation()
-    
-    cami.reset_cami()
         
     # SEED VARIATION
     if seed_variation:
diff --git a/cami_src/cami_suite.py b/cami_src/cami_suite.py
index 2c88390..889a197 100644
--- a/cami_src/cami_suite.py
+++ b/cami_src/cami_suite.py
@@ -8,6 +8,7 @@ from algorithms.RobustWrapper import RobustWrapper
 from configparser import ConfigParser
 import preprocess
 from consensus import cami_v1, cami_v2, cami_v3
+import matplotlib.pyplot as plt
 
 def initialize_cami(path_to_ppi_file=''):
     cami_params = {}
@@ -197,7 +198,7 @@ class cami():
                 tar_id='entrez',
                 mode='set-set',
                 distance='jaccard',
-                ref=set(self.seed_lst),
+                ref=set(seed_gene_lst),
                 ref_id='entrez')
             
             if set_validation_results['status'] == 'ok':
@@ -218,7 +219,9 @@ class cami():
                 distance='jaccard',
                 network_data={"network_file":ppi_graph_file,
                               "prop_name":"name",
-                              "id_type":"entrez"}
+                              "id_type":"entrez"},
+                ref=set(seed_gene_lst),
+                ref_id='entrez'
                 )
             if sub_validation_results['status'] == 'ok':
                 biodigest.single_validation.save_results(sub_validation_results, f'{result_set}_{self.uid}',
@@ -322,10 +325,10 @@ class cami():
         params = {'hub_pentalty': [0, 0.25, 0.5, 0.75, 1.0], 'damping_factor': [0.1, 0.25, 0.5, 0.75], 'confidence_level': [0.2, 0.35, 0.5, 0.75], 'ranking':["trustrank", "betweenness", "harmonic"], 'function':[cami_v2.run_cami, cami_v3.run_cami]}
 
         camis = {
-            'cami_v1': {'function': cami_v1.run_cami, 'params': {'consens_threshold': consens_threshold}},
             'union': {'function': cami_v1.make_union, 'params': {}},
             'intersection': {'function': cami_v1.make_intersection, 'params': {}},
             'first_neighbors': {'function': cami_v1.make_first_neighbor_result_set, 'params': {}},
+            'cami_v1': {'function': cami_v1.run_cami, 'params': {'consens_threshold': consens_threshold}},
             'cami_v2_param1_tr': {'function': cami_v2.run_cami, 'params': {
                 'hub_penalty': 0.3, 'damping_factor': 0.7, 'confidence_level': 0.5
             }},
@@ -466,17 +469,17 @@ class cami():
         comp_matrix = comparison_matrix.make_comparison_matrix(self.result_gene_sets)
         comp_fig, comp_ax, norm_fig, norm_ax = comparison_matrix.plot_comparison_matrix(comp_matrix,
                                                                                         title=title,
-                                                                                        n_rows=self.nof_tools,)
+                                                                                        n_rows=self.nof_tools)
         comp_fig_file = f'{self.output_dir}/comparison_matrix_{identifier}.png'
         comp_fig.savefig(comp_fig_file, bbox_inches="tight")
         if self.debug:
             print(f'saved comparison matrix in: {comp_fig_file}')
-        norm_fig_file = f'{self.output_dir}/normalized_comparison_matrix_{identifier}.png'
+        norm_fig_file = f'{self.output_dir}/comparison_matrix_{identifier}_normalized.png'
         if self.debug:
             print(f'saved normalized comparison matrix in: {norm_fig_file}')
         norm_fig.savefig(norm_fig_file, bbox_inches="tight")
-        comp_fig.close()
-        norm_fig.close()
+        plt.close(comp_fig)
+        plt.close(norm_fig)
         return comp_fig_file, norm_fig_file
         
     def use_nvenn(self, download=False):
@@ -485,7 +488,8 @@ class cami():
            Returns the URL of the result.
         """
         # visualize with degradome
-        if self.nof_tools < 7:
+        n_venns = len(self.result_gene_sets)
+        if n_venns < 7:
             degradome_sets = {tool: self.result_gene_sets[tool]
                               for tool in self.result_gene_sets
                               if len(self.result_gene_sets[tool]) > 0}
-- 
GitLab