Skip to content
Snippets Groups Projects
Commit 2f22b677 authored by Le, Mia's avatar Le, Mia
Browse files

added heatmap plots for intersection analysis

parent 126fa355
No related branches found
No related tags found
No related merge requests found
...@@ -14,6 +14,8 @@ import argparse ...@@ -14,6 +14,8 @@ import argparse
import webbrowser import webbrowser
import random import random
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import seaborn as sb
import pandas as pd
#MC: #MC:
from configparser import ConfigParser from configparser import ConfigParser
...@@ -172,9 +174,35 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate, ...@@ -172,9 +174,35 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate,
# SEED VARIATION # SEED VARIATION
if seed_variation: if seed_variation:
def predict_and_make_consensus(vis=False): def predict_and_make_consensus(vis=False):
result_sets = cami.make_predictions() result_sets = cami.make_predictions()
cami.create_consensus(result_sets) cami.create_consensus(result_sets)
if vis:
    # Pairwise overlap of the predicted gene sets: cell (algo1, algo2) holds
    # the number of genes that both algorithms returned. The diagonal is
    # therefore each algorithm's own result size.
    algo_names = list(cami.result_gene_sets.keys())
    n_results = len(algo_names)
    comparison_matrix = pd.DataFrame(
        [[len(cami.result_gene_sets[algo1].intersection(cami.result_gene_sets[algo2]))
          for algo2 in algo_names]
         for algo1 in algo_names],
        columns=algo_names,
        index=algo_names,
        dtype=int)

    # Absolute intersection sizes.
    heatmap_path = f'{output_dir}/heatmap_{cami.uid}.png'
    fig2, ax3 = plt.subplots(figsize=(20,20))
    # Draw on the axes created above instead of relying on the implicit
    # "current axes" state of pyplot.
    sb.heatmap(comparison_matrix, annot=True, fmt='g', ax=ax3)
    ax3.set_title('Intersections of result_gene_sets of all analyzed algorithms.')
    fig2.savefig(heatmap_path, bbox_inches="tight")
    print(f'saved intersection heatmap of all algorithms under: {heatmap_path}')
    plt.close(fig2)

    # Row-normalised intersections: every row is divided by its maximum.
    # An algorithm with an empty result set has an all-zero row; dividing by
    # 1 instead of 0 keeps that row at 0.0 rather than turning it into NaN.
    row_max = comparison_matrix.max(axis=1)
    comparison_matrix_normalized = comparison_matrix.div(row_max.replace(0, 1), axis=0)
    heatmap_path_normalized = f'{output_dir}/heatmap_{cami.uid}_normalized.png'
    fig2a, ax3a = plt.subplots(figsize=(20,20))
    sb.heatmap(comparison_matrix_normalized, annot=True, fmt='.2f', ax=ax3a)
    ax3a.set_title('Normalized intersections of result_gene_sets of all analyzed algorithms.')
    fig2a.savefig(heatmap_path_normalized, bbox_inches="tight")
    print(f'saved intersection heatmap of all algorithms under: {heatmap_path_normalized}')
    plt.close(fig2a)
if nvenn and vis: if nvenn and vis:
url = cami.use_nvenn() url = cami.use_nvenn()
cami.download_diagram(url) cami.download_diagram(url)
...@@ -208,6 +236,12 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate, ...@@ -208,6 +236,12 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate,
result_file_2 = f'{output_dir}/00_seedvariation_results2.tsv' result_file_2 = f'{output_dir}/00_seedvariation_results2.tsv'
result_file_3 = f'{output_dir}/00_seedvariation_results3.tsv' result_file_3 = f'{output_dir}/00_seedvariation_results3.tsv'
result_file_4 = f'{output_dir}/00_seedvariation_results4.tsv' result_file_4 = f'{output_dir}/00_seedvariation_results4.tsv'
# Square counter table, one row and one column per key of
# cami.result_gene_sets, with every cell starting at 0.
tool_names = list(cami.result_gene_sets)
n_results = len(tool_names)
redisc_intersection_matrix = pd.DataFrame(0,
                                          index=tool_names,
                                          columns=tool_names,
                                          dtype=int)
with open(result_file_1, 'w') as res_table1: with open(result_file_1, 'w') as res_table1:
with open(result_file_2, 'w') as res_table2: with open(result_file_2, 'w') as res_table2:
with open(result_file_3, 'w') as res_table3: with open(result_file_3, 'w') as res_table3:
...@@ -256,7 +290,7 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate, ...@@ -256,7 +290,7 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate,
for tool in result_dict: for tool in result_dict:
nof_predictions = len(result_dict[tool]) + len(used_seeds) nof_predictions = len(result_dict[tool]) + len(used_seeds)
redisc_seeds = list(set(result_dict[tool]).intersection(set(rem_seeds))) redisc_seeds = set(result_dict[tool]).intersection(set(rem_seeds))
redisc_prev = len(redisc_seeds) redisc_prev = len(redisc_seeds)
redisc_rate = redisc_prev / nof_removals redisc_rate = redisc_prev / nof_removals
redisc_rate_dict[tool] = redisc_rate #true positive rate: number of rediscovered seeds divided by the number of removed seeds redisc_rate_dict[tool] = redisc_rate #true positive rate: number of rediscovered seeds divided by the number of removed seeds
...@@ -270,18 +304,34 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate, ...@@ -270,18 +304,34 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate,
res_table1.write('\t') res_table1.write('\t')
for idx,seed in enumerate(redisc_seeds): for idx,seed in enumerate(redisc_seeds):
if idx == 0: if idx == 0:
res_table1.write(f'{redisc_seeds[0]}') res_table1.write(f'{list(redisc_seeds)[0]}')
else: else:
res_table1.write(f',{seed}') res_table1.write(f',{seed}')
print(f'{tool} rediscovered {redisc_seeds} after removing {rem_seeds}.') print(f'{tool} rediscovered {redisc_seeds} after removing {rem_seeds}.')
res_table2.write(f'{ident}\t{tool}\t{nof_removals},\t{rem_seeds}\t{used_seeds}\t{redisc_prev}\t{redisc_rate}\t{redisc_seeds}\n') res_table2.write(f'{ident}\t{tool}\t{nof_removals},\t{rem_seeds}\t{used_seeds}\t{redisc_prev}\t{redisc_rate}\t{redisc_seeds}\n')
res_table1.write('\n') res_table1.write('\n')
variation_results.append((redisc_rate_dict, redisc_seeds_dict, used_seeds, rem_seeds)) variation_results.append((redisc_rate_dict, redisc_seeds_dict, used_seeds, rem_seeds))
# Add, for every ordered pair of algorithms, the number of seeds that both of
# them rediscovered to the running totals in redisc_intersection_matrix.
# Both loops walk redisc_seeds_dict, the dict that is actually indexed below,
# so a key that exists only in redisc_rate_dict can no longer cause a KeyError.
# NOTE(review): presumably executed once per seed-variation iteration - confirm
# against the enclosing loop, which is not visible here.
for algo1 in redisc_seeds_dict:
    for algo2 in redisc_seeds_dict:
        redisc_intersection_matrix.loc[algo1,algo2] += len(redisc_seeds_dict[algo1].intersection(redisc_seeds_dict[algo2]))
print(f'Result tables are saved in the following locations:') print(f'Result tables are saved in the following locations:')
print(f'Rediscovered seeds: {result_file_1}') print(f'Rediscovered seeds: {result_file_1}')
print(f'Rediscovery Rates: {result_file_2}') print(f'Rediscovery Rates: {result_file_2}')
print(f'Sensitivity: {result_file_3}') print(f'Sensitivity: {result_file_3}')
# Heatmap of the accumulated pairwise rediscovery counts.
# The target path is built once so that the printed location always matches
# the file that was actually written.
redisc_heatmap_path = f'{output_dir}/00_seedvariation_heatmap_{identifier}.png'
fig3, ax6 = plt.subplots(figsize=(20,20))
# Draw on the axes created above instead of relying on the implicit
# "current axes" state of pyplot.
sb.heatmap(redisc_intersection_matrix, annot=True, fmt='g', ax=ax6)
ax6.set_title('Number of times the algorithms rediscovered the same seeds')
fig3.savefig(redisc_heatmap_path, bbox_inches="tight")
print(f'saved intersection heatmap of all algorithms under:{redisc_heatmap_path}')
plt.close(fig3)

# Row-normalised variant: every row is divided by its maximum.
# An algorithm that never rediscovered a seed has an all-zero row; dividing by
# 1 instead of 0 keeps that row at 0.0 rather than turning it into NaN.
redisc_row_max = redisc_intersection_matrix.max(axis=1)
redisc_intersection_matrix_normalized = redisc_intersection_matrix.div(redisc_row_max.replace(0, 1), axis=0)
redisc_heatmap_path_normalized = f'{output_dir}/00_seedvariation_heatmap_{identifier}_normalized.png'
fig3a, ax6a = plt.subplots(figsize=(20,20))
sb.heatmap(redisc_intersection_matrix_normalized, annot=True, fmt='.2f', ax=ax6a)
ax6a.set_title('Normalized numbers of times the algorithms rediscovered the same seeds')
fig3a.savefig(redisc_heatmap_path_normalized, bbox_inches="tight")
print(f'saved intersection heatmap of all algorithms under:{redisc_heatmap_path_normalized}')
plt.close(fig3a)
# print(variation_results) # print(variation_results)
rediscovery_rates_results = [results[0] for results in variation_results] rediscovery_rates_results = [results[0] for results in variation_results]
# print(rediscovery_rates_results) # print(rediscovery_rates_results)
...@@ -300,8 +350,8 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate, ...@@ -300,8 +350,8 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate,
#PLOT #PLOT
# Create a figure instance # Create a figure instance
#print(sys.getrecursionlimit()) #print(sys.getrecursionlimit())
fig, (ax1, ax5, ax4) = plt.subplots(3, 1, figsize=(20,20)) fig1, (ax1, ax5, ax4) = plt.subplots(3, 1, figsize=(20,20))
fig.subplots_adjust(left=0.2) fig1.subplots_adjust(left=0.2)
# Extract Figure and Axes instance # Extract Figure and Axes instance
# Create a plot # Create a plot
...@@ -350,9 +400,9 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate, ...@@ -350,9 +400,9 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate,
ax5.set_ylabel('precision (<rediscovered seeds>/<module size>)', fontsize=14) ax5.set_ylabel('precision (<rediscovered seeds>/<module size>)', fontsize=14)
ax5.tick_params(axis='x', labelsize=11) ax5.tick_params(axis='x', labelsize=11)
fig.tight_layout() fig1.tight_layout()
fig.savefig(f'{output_dir}/00_{identifier}_seed_variation_result.png', bbox_inches="tight") fig1.savefig(f'{output_dir}/00_{identifier}_seed_variation_result.png', bbox_inches="tight")
plt.close(fig) plt.close(fig1)
print(f'Violin plot saved under: 00_{identifier}_seed_variation_result.png') print(f'Violin plot saved under: 00_{identifier}_seed_variation_result.png')
if save_temps: if save_temps:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment