Skip to content
Snippets Groups Projects
Commit db0bbc80 authored by Le, Mia's avatar Le, Mia
Browse files

Merged seed_variation plotting functions with main

parent 9ae6ae2c
No related branches found
No related tags found
No related merge requests found
......@@ -14,6 +14,8 @@ import argparse
import webbrowser
import random
import matplotlib.pyplot as plt
import seaborn as sb
import pandas as pd
#MC:
from configparser import ConfigParser
......@@ -21,7 +23,7 @@ from configparser import ConfigParser
def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate,
output_dir, identifier, save_temps, nvenn, save_image, force, drugstone, ncbi, configuration,
seed_variation, parallelization, external_results):
seed_variation, parallelization, external_results, debug):
print('CAMI started')
config = ConfigParser()
config.read(configuration)
......@@ -55,6 +57,7 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate,
# rename seeds with their corresponding indices in the ppi_graph
name2index = preprocess.name2index_dict(ppi_graph)
checked_seed_lst = [name2index[v] for v in seed_lst if v in name2index]
if debug:
print('Created the PPI-network Graph and seed list.')
n = len(checked_seed_lst)
m = len(seed_lst)
......@@ -69,12 +72,13 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate,
exit(1)
seed_lst = checked_seed_lst.copy()
print(f'Continuing with vertices at indices {[int(seed) for seed in seed_lst]} as input seeds')
if debug: print(f'Continuing with vertices at indices {[int(seed) for seed in seed_lst]} as input seeds')
# change directory to ~/cami/cami (home of cami.py)
cami_home = sys.argv[0].rsplit('/', 1)
os.chdir(cami_home[0])
home_path = os.path.dirname(os.getcwd())
if debug:
print(f"Home directory of cami: {home_path}")
if identifier==None:
identifier = str(uuid.uuid4())
......@@ -82,6 +86,7 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate,
if output_dir==None:
output_dir = os.path.join(home_path, f'data/output/{identifier}')
output_dir = os.path.abspath(output_dir)
if debug:
print(f"Output directory of cami: {output_dir}")
if not os.path.exists(output_dir):
os.makedirs(output_dir)
......@@ -91,6 +96,7 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate,
#create temporary directory
tmp_dir = os.path.join(home_path, f'data/tmp/{identifier}')
if debug:
print(f'Creating unique temporary directory for CAMI: {tmp_dir}')
if not os.path.exists(tmp_dir):
......@@ -115,6 +121,9 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate,
if ncbi:
cami.ncbi = True
if debug:
cami.debug = True
for tool in tool_wrappers:
cami.initialize_tool(tool)
......@@ -126,7 +135,7 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate,
cami.set_initial_seed_lst(initial_seedlst)
if consensus or (not consensus and not evaluate and not seed_variation):
if (consensus or (not consensus and not evaluate and not seed_variation)) and not seed_variation:
result_sets = cami.make_predictions()
if len(external_results) > 0:
assert(len(ext_wrappers) == len(external_results))
......@@ -150,7 +159,7 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate,
cami.download_diagram(url)
if drugstone is not None:
print('Sending results to Drugst.One')
print('Sending results to DrugstOne')
cami.use_drugstone()
if consensus:
cami.reset_cami()
......@@ -165,11 +174,37 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate,
# SEED VARIATION
if seed_variation:
def predict_and_make_consensus(vis=False):
result_sets = cami.make_predictions()
cami.create_consensus(result_sets)
if vis:
n_results = len(cami.result_gene_sets)
comparison_matrix = pd.DataFrame([[int(0) for _ in range(n_results)] for __ in range(n_results)],
columns = list(cami.result_gene_sets.keys()),
index = list(cami.result_gene_sets.keys()),
dtype=int)
for algo1 in cami.result_gene_sets:
for algo2 in cami.result_gene_sets:
comparison_matrix.loc[algo1,algo2] = int(len(cami.result_gene_sets[algo1].intersection(cami.result_gene_sets[algo2])))
fig2, ax3 = plt.subplots(figsize=(20,20))
ax3 = sb.heatmap(comparison_matrix, annot=True, fmt='g')
ax3.set_title('Intersections of result_gene_sets of all analyzed algorithms.')
fig2.savefig(f'{output_dir}/heatmap_{cami.uid}.png', bbox_inches="tight")
print(f'saved intersection heatmap of all algorithms under: {output_dir}/heatmap_{cami.uid}.png')
plt.close(fig2)
fig2a, ax3a = plt.subplots(figsize=(20,20))
comparison_matrix_normalized = comparison_matrix.apply(lambda row: row/row.max(), axis=1)
ax3a = sb.heatmap(comparison_matrix_normalized, annot=True, fmt='.2f')
ax3a.set_title('Normalized intersections of result_gene_sets of all analyzed algorithms.')
fig2a.savefig(f'{output_dir}/heatmap_{cami.uid}_normalized.png', bbox_inches="tight")
print(f'saved intersection heatmap of all algorithms under: {output_dir}/heatmap_{cami.uid}_normalized.png')
plt.close(fig2a)
if nvenn and vis:
url = cami.nvenn()
url = cami.use_nvenn()
cami.download_diagram(url)
with open(f'{output_dir}/00_node_degrees.tsv', 'w') as node_degrees:
......@@ -187,7 +222,7 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate,
seedname = seeds
for tool in cami.result_gene_sets:
f.write(f'\n{seedname}\t{len(cami.seed_lst)}\t{tool}\t{len(cami.result_gene_sets[tool])}')
cami.make_evaluation()
#predict_and_make_consensus(vis=True)
random.seed(50)
......@@ -201,6 +236,12 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate,
result_file_2 = f'{output_dir}/00_seedvariation_results2.tsv'
result_file_3 = f'{output_dir}/00_seedvariation_results3.tsv'
result_file_4 = f'{output_dir}/00_seedvariation_results4.tsv'
n_results = len(cami.result_gene_sets)
redisc_intersection_matrix = pd.DataFrame([[0 for _ in range(n_results)] for __ in range(n_results)],
columns = list(cami.result_gene_sets.keys()),
index = list(cami.result_gene_sets.keys()),
dtype=int)
with open(result_file_1, 'w') as res_table1:
with open(result_file_2, 'w') as res_table2:
with open(result_file_3, 'w') as res_table3:
......@@ -249,7 +290,7 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate,
for tool in result_dict:
nof_predictions = len(result_dict[tool]) + len(used_seeds)
redisc_seeds = list(set(result_dict[tool]).intersection(set(rem_seeds)))
redisc_seeds = set(result_dict[tool]).intersection(set(rem_seeds))
redisc_prev = len(redisc_seeds)
redisc_rate = redisc_prev / nof_removals
redisc_rate_dict[tool] = redisc_rate #true positive rate: rate verhältnis von gefundenen und allgemein gefundenen
......@@ -263,43 +304,54 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate,
res_table1.write('\t')
for idx,seed in enumerate(redisc_seeds):
if idx == 0:
res_table1.write(f'{redisc_seeds[0]}')
res_table1.write(f'{list(redisc_seeds)[0]}')
else:
res_table1.write(f',{seed}')
print(f'{tool} rediscovered {redisc_seeds} after removing {rem_seeds}.')
res_table2.write(f'{ident}\t{tool}\t{nof_removals},\t{rem_seeds}\t{used_seeds}\t{redisc_prev}\t{redisc_rate}\t{redisc_seeds}\n')
res_table1.write('\n')
variation_results.append((redisc_rate_dict, redisc_seeds_dict, used_seeds, rem_seeds))
for algo1 in redisc_seeds_dict:
for algo2 in redisc_rate_dict:
redisc_intersection_matrix.loc[algo1,algo2] += len(redisc_seeds_dict[algo1].intersection(redisc_seeds_dict[algo2]))
print(f'Result tables are saved in the following locations:')
print(f'Rediscovered seeds: {result_file_1}')
print(f'Rediscovery Rates: {result_file_2}')
print(f'Sensitivity: {result_file_3}')
fig3, ax6 = plt.subplots(figsize=(20,20))
ax6 = sb.heatmap(redisc_intersection_matrix, annot=True, fmt='g')
ax6.set_title('Number of times the algorithms rediscovered the same seeds')
fig3.savefig(f'{output_dir}/00_seedvariation_heatmap_{identifier}.png', bbox_inches="tight")
print(f'saved intersection heatmap of all algorithms under:{output_dir}/heatmap_{identifier}.png')
plt.close(fig3)
fig3a, ax6a = plt.subplots(figsize=(20,20))
redisc_intersection_matrix_normalized = redisc_intersection_matrix.apply(lambda row: row/row.max(), axis=1)
ax6a = sb.heatmap(redisc_intersection_matrix_normalized, annot=True, fmt='.2f')
ax6a.set_title('Normalized numbers of times the algorithms rediscovered the same seeds')
fig3a.savefig(f'{output_dir}/00_seedvariation_heatmap_{identifier}_normalized.png', bbox_inches="tight")
print(f'saved intersection heatmap of all algorithms under:{output_dir}/heatmap_normalized_{cami.uid}.png')
plt.close(fig3a)
# print(variation_results)
rediscovery_rates_results = [results[0] for results in variation_results]
# print(rediscovery_rates_results)
tools = [tool for tool in rediscovery_rates_results[0].keys()]
tool_labels = tools.copy()
redisc_rates = [[res[tool] for res in rediscovery_rates_results] for tool in tools]
for idx,tool in enumerate(tool_labels):
for idx,tool in enumerate(tools):
if '_' in tool:
# find the index of the second occurrence of the character
second_occurrence_index = tool.find('_', tool.find('_') + 1)
if second_occurrence_index > -1:
# replace the character at that index with the replacement character
tool_name = tool[:second_occurrence_index] + '\n' + tool[second_occurrence_index + 1:]
tool_labels[idx] = tool_name
#PLOT
# Create a figure instance
fig, axs = plt.subplots(2, 1, figsize=(20,15))
plt.subplots_adjust(left=0.2)
ax1 = axs[0]
ax5 = axs[1]
#print(sys.getrecursionlimit())
fig1, (ax1, ax5, ax4) = plt.subplots(3, 1, figsize=(20,20))
fig1.subplots_adjust(left=0.2)
# Extract Figure and Axes instance
# Create a plot
......@@ -307,6 +359,7 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate,
for violinpart in list(violins1.keys())[2:]:
violins1[violinpart].set_color('k')
for violin in violins1['bodies']:
violin.set_facecolor('red')
# Add title
......@@ -319,19 +372,19 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate,
ax1.set_ylabel('Rediscovery rate (<rediscovered seeds>/<removed seeds>)', wrap=True, fontsize=14)
# ax4 = plt.subplot(1,3,2, label='ax4')
# violins2 = ax4.violinplot([tp_rate_dict[tool] for tool in tools], showmeans=True, showextrema=True)
# for violinpart in list(violins2.keys())[2:]:
# violins2[violinpart].set_color('k')
# for violin in violins2['bodies']:
# violin.set_facecolor('orange')
# # Add title
# ax4.set_title(f'True positive rates after randomly removing {nof_removals} seeds\n{nof_iterations} times from {identifier} seeds.', wrap=True)
violins2 = ax4.violinplot([tp_rate_dict[tool] for tool in tools], showmeans=True, showextrema=True)
for violinpart in list(violins2.keys())[2:]:
violins2[violinpart].set_color('k')
for violin in violins2['bodies']:
violin.set_facecolor('orange')
# Add title
ax4.set_title(f'True positive rates after randomly removing {nof_removals} seeds {nof_iterations} times from {identifier} seeds.', wrap=True, fontsize=14)
# ax4.set_xticks(list(range(1,len(tools)+1)))
# ax4.set_xticklabels(tools)
ax4.set_xticks(list(range(1,len(tools)+1)))
ax4.set_xticklabels(tool_labels)
ax4.tick_params(axis='x', labelsize=11)
# ax4.set_ylabel('Sensitivity (TP/TP + FN)', wrap=True)
ax4.set_ylabel('Sensitivity (TP/TP + FN)', wrap=True, fontsize=14)
violins3 = ax5.violinplot([module_size_dict[tool] for tool in tools], showmeans=True, showextrema=True)
# Add title
......@@ -345,11 +398,11 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate,
ax5.set_xticks(list(range(1,len(tools)+1)))
ax5.set_xticklabels(tool_labels)
ax5.set_ylabel('(<rediscovered seeds>/<module size>)', fontsize=14)
ax5.set_ylabel('precision (<rediscovered seeds>/<module size>)', fontsize=14)
ax5.tick_params(axis='x', labelsize=11)
plt.tight_layout()
fig.savefig(f'{output_dir}/00_{identifier}_seed_variation_result.png', bbox_inches="tight")
fig1.tight_layout()
fig1.savefig(f'{output_dir}/00_{identifier}_seed_variation_result.png', bbox_inches="tight")
plt.close(fig1)
print(f'Violin plot saved under: 00_{identifier}_seed_variation_result.png')
if save_temps:
......@@ -401,6 +454,8 @@ if __name__ == "__main__":
help="Choose a configuration for the static variables.")
parser.add_argument('-p', '--parallelization', action='store_true',
help="run the tools for prediction parallelized")
parser.add_argument('-db', '--debug', action='store_true',
help="run CAMI with verbose outputs")
#TODO List with additional arguments if needed by certain tools
args = vars(parser.parse_args())
......@@ -421,4 +476,5 @@ if __name__ == "__main__":
for tool in args['tools']:
if tool not in implemented_tools:
raise RuntimeError(tool + f' not implemented yet. Implemented tools are: {implemented_tools}')
main(**args)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment