diff --git a/cami_src/cami.py b/cami_src/cami.py index 12cf3ef55e38fd9e991f6ff30bf41847a43696fa..9d08515dce066b941b62ae4747f497ae2604971f 100755 --- a/cami_src/cami.py +++ b/cami_src/cami.py @@ -14,6 +14,8 @@ import argparse import webbrowser import random import matplotlib.pyplot as plt +import seaborn as sb +import pandas as pd #MC: from configparser import ConfigParser @@ -172,9 +174,35 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate, # SEED VARIATION if seed_variation: + def predict_and_make_consensus(vis=False): result_sets = cami.make_predictions() cami.create_consensus(result_sets) + if vis: + n_results = len(cami.result_gene_sets) + comparison_matrix = pd.DataFrame([[int(0) for _ in range(n_results)] for __ in range(n_results)], + columns = list(cami.result_gene_sets.keys()), + index = list(cami.result_gene_sets.keys()), + dtype=int) + for algo1 in cami.result_gene_sets: + for algo2 in cami.result_gene_sets: + comparison_matrix.loc[algo1,algo2] = int(len(cami.result_gene_sets[algo1].intersection(cami.result_gene_sets[algo2]))) + + fig2, ax3 = plt.subplots(figsize=(20,20)) + + ax3 = sb.heatmap(comparison_matrix, annot=True, fmt='g') + ax3.set_title('Intersections of result_gene_sets of all analyzed algorithms.') + fig2.savefig(f'{output_dir}/heatmap_{cami.uid}.png', bbox_inches="tight") + print(f'saved intersection heatmap of all algorithms under: {output_dir}/heatmap_{cami.uid}.png') + plt.close(fig2) + + fig2a, ax3a = plt.subplots(figsize=(20,20)) + comparison_matrix_normalized = comparison_matrix.apply(lambda row: row/row.max(), axis=1) + ax3a = sb.heatmap(comparison_matrix_normalized, annot=True, fmt='.2f') + ax3a.set_title('Normalized intersections of result_gene_sets of all analyzed algorithms.') + fig2a.savefig(f'{output_dir}/heatmap_{cami.uid}_normalized.png', bbox_inches="tight") + print(f'saved intersection heatmap of all algorithms under: {output_dir}/heatmap_{cami.uid}_normalized.png') + plt.close(fig2a) if nvenn and vis: url = cami.use_nvenn() cami.download_diagram(url) @@ -208,6 +236,12 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate, result_file_2 = f'{output_dir}/00_seedvariation_results2.tsv' result_file_3 = f'{output_dir}/00_seedvariation_results3.tsv' result_file_4 = f'{output_dir}/00_seedvariation_results4.tsv' + n_results = len(cami.result_gene_sets) + + redisc_intersection_matrix = pd.DataFrame([[0 for _ in range(n_results)] for __ in range(n_results)], + columns = list(cami.result_gene_sets.keys()), + index = list(cami.result_gene_sets.keys()), + dtype=int) with open(result_file_1, 'w') as res_table1: with open(result_file_2, 'w') as res_table2: with open(result_file_3, 'w') as res_table3: @@ -256,7 +290,7 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate, for tool in result_dict: nof_predictions = len(result_dict[tool]) + len(used_seeds) - redisc_seeds = list(set(result_dict[tool]).intersection(set(rem_seeds))) + redisc_seeds = set(result_dict[tool]).intersection(set(rem_seeds)) redisc_prev = len(redisc_seeds) redisc_rate = redisc_prev / nof_removals redisc_rate_dict[tool] = redisc_rate #true positive rate: rate verhältnis von gefundenen und allgemein gefundenen @@ -270,18 +304,34 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate, res_table1.write('\t') for idx,seed in enumerate(redisc_seeds): if idx == 0: - res_table1.write(f'{redisc_seeds[0]}') + res_table1.write(f'{list(redisc_seeds)[0]}') else: res_table1.write(f',{seed}') print(f'{tool} rediscovered {redisc_seeds} after removing {rem_seeds}.') res_table2.write(f'{ident}\t{tool}\t{nof_removals},\t{rem_seeds}\t{used_seeds}\t{redisc_prev}\t{redisc_rate}\t{redisc_seeds}\n') res_table1.write('\n') variation_results.append((redisc_rate_dict, redisc_seeds_dict, used_seeds, rem_seeds)) - + for algo1 in redisc_seeds_dict: + for algo2 in redisc_rate_dict: + redisc_intersection_matrix.loc[algo1,algo2] += len(redisc_seeds_dict[algo1].intersection(redisc_seeds_dict[algo2])) print(f'Result tables are saved in the following locations:') print(f'Rediscovered seeds: {result_file_1}') print(f'Rediscovery Rates: {result_file_2}') print(f'Sensitivity: {result_file_3}') + fig3, ax6 = plt.subplots(figsize=(20,20)) + ax6 = sb.heatmap(redisc_intersection_matrix, annot=True, fmt='g') + ax6.set_title('Number of times the algorithms rediscovered the same seeds') + fig3.savefig(f'{output_dir}/00_seedvariation_heatmap_{identifier}.png', bbox_inches="tight") + print(f'saved intersection heatmap of all algorithms under:{output_dir}/heatmap_{identifier}.png') + plt.close(fig3) + fig3a, ax6a = plt.subplots(figsize=(20,20)) + redisc_intersection_matrix_normalized = redisc_intersection_matrix.apply(lambda row: row/row.max(), axis=1) + ax6a = sb.heatmap(redisc_intersection_matrix_normalized, annot=True, fmt='.2f') + ax6a.set_title('Normalized numbers of times the algorithms rediscovered the same seeds') + fig3a.savefig(f'{output_dir}/00_seedvariation_heatmap_{identifier}_normalized.png', bbox_inches="tight") + print(f'saved intersection heatmap of all algorithms under:{output_dir}/heatmap_normalized_{cami.uid}.png') + plt.close(fig3a) + # print(variation_results) rediscovery_rates_results = [results[0] for results in variation_results] # print(rediscovery_rates_results) @@ -300,8 +350,8 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate, #PLOT # Create a figure instance #print(sys.getrecursionlimit()) - fig, (ax1, ax5, ax4) = plt.subplots(3, 1, figsize=(20,20)) - fig.subplots_adjust(left=0.2) + fig1, (ax1, ax5, ax4) = plt.subplots(3, 1, figsize=(20,20)) + fig1.subplots_adjust(left=0.2) # Extract Figure and Axes instance # Create a plot @@ -350,9 +400,9 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate, ax5.set_ylabel('precision (<rediscovered seeds>/<module size>)', fontsize=14) ax5.tick_params(axis='x', labelsize=11) - fig.tight_layout() - fig.savefig(f'{output_dir}/00_{identifier}_seed_variation_result.png', bbox_inches="tight") - plt.close(fig) + fig1.tight_layout() + fig1.savefig(f'{output_dir}/00_{identifier}_seed_variation_result.png', bbox_inches="tight") + plt.close(fig1) print(f'Violin plot saved under: 00_{identifier}_seed_variation_result.png') if save_temps: