From db0bbc80ac5cbb22aa99b98ed7be857e257fdef7 Mon Sep 17 00:00:00 2001 From: bay9355 <mia.le@studium.uni-hamburg.de> Date: Tue, 14 Mar 2023 15:48:21 +0100 Subject: [PATCH] Merged seed_variation plotting functions with main --- cami_src/cami.py | 134 +++++++++++++++++++++++++++++++++-------------- 1 file changed, 95 insertions(+), 39 deletions(-) diff --git a/cami_src/cami.py b/cami_src/cami.py index d5c1f58..4b6da42 100755 --- a/cami_src/cami.py +++ b/cami_src/cami.py @@ -14,6 +14,8 @@ import argparse import webbrowser import random import matplotlib.pyplot as plt +import seaborn as sb +import pandas as pd #MC: from configparser import ConfigParser @@ -21,7 +23,7 @@ from configparser import ConfigParser def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate, output_dir, identifier, save_temps, nvenn, save_image, force, drugstone, ncbi, configuration, - seed_variation, parallelization, external_results): + seed_variation, parallelization, external_results, debug): print('CAMI started') config = ConfigParser() config.read(configuration) @@ -55,7 +57,8 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate, # rename seeds with their corresponding indices in the ppi_graph name2index = preprocess.name2index_dict(ppi_graph) checked_seed_lst = [name2index[v] for v in seed_lst if v in name2index] - print('Created the PPI-network Graph and seed list.') + if debug: + print('Created the PPI-network Graph and seed list.') n = len(checked_seed_lst) m = len(seed_lst) if n != m and not force: @@ -69,20 +72,22 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate, exit(1) seed_lst = checked_seed_lst.copy() - print(f'Continuing with vertices at indices {[int(seed) for seed in seed_lst]} as input seeds') + if debug: print(f'Continuing with vertices at indices {[int(seed) for seed in seed_lst]} as input seeds') # change directory to ~/cami/cami (home of cami.py) cami_home = sys.argv[0].rsplit('/', 1) os.chdir(cami_home[0]) home_path = os.path.dirname(os.getcwd()) - print(f"Home directory of cami: {home_path}") + if debug: + print(f"Home directory of cami: {home_path}") if identifier==None: identifier = str(uuid.uuid4()) if output_dir==None: output_dir = os.path.join(home_path, f'data/output/{identifier}') output_dir = os.path.abspath(output_dir) - print(f"Output directory of cami: {output_dir}") + if debug: + print(f"Output directory of cami: {output_dir}") if not os.path.exists(output_dir): os.makedirs(output_dir) @@ -91,7 +96,8 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate, #create temporary directory tmp_dir = os.path.join(home_path, f'data/tmp/{identifier}') - print(f'Creating unique temporary directory for CAMI: {tmp_dir}') + if debug: + print(f'Creating unique temporary directory for CAMI: {tmp_dir}') if not os.path.exists(tmp_dir): os.makedirs(tmp_dir) @@ -114,6 +120,9 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate, if ncbi: cami.ncbi = True + + if debug: + cami.debug = True for tool in tool_wrappers: cami.initialize_tool(tool) @@ -126,7 +135,7 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate, cami.set_initial_seed_lst(initial_seedlst) - if consensus or (not consensus and not evaluate and not seed_variation): + if (consensus or (not consensus and not evaluate and not seed_variation)) and not seed_variation: result_sets = cami.make_predictions() if len(external_results) > 0: assert(len(ext_wrappers) == len(external_results)) @@ -150,7 +159,7 @@ def main(ppi_network, seeds, 
tools, tool_weights, consensus, evaluate, cami.download_diagram(url) if drugstone is not None: - print('Sending results to Drugst.One') + print('Sending results to DrugstOne') cami.use_drugstone() if consensus: cami.reset_cami() @@ -165,11 +174,37 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate, # SEED VARIATION if seed_variation: + def predict_and_make_consensus(vis=False): result_sets = cami.make_predictions() cami.create_consensus(result_sets) + if vis: + n_results = len(cami.result_gene_sets) + comparison_matrix = pd.DataFrame([[int(0) for _ in range(n_results)] for __ in range(n_results)], + columns = list(cami.result_gene_sets.keys()), + index = list(cami.result_gene_sets.keys()), + dtype=int) + for algo1 in cami.result_gene_sets: + for algo2 in cami.result_gene_sets: + comparison_matrix.loc[algo1,algo2] = int(len(cami.result_gene_sets[algo1].intersection(cami.result_gene_sets[algo2]))) + + fig2, ax3 = plt.subplots(figsize=(20,20)) + + ax3 = sb.heatmap(comparison_matrix, annot=True, fmt='g') + ax3.set_title('Intersections of result_gene_sets of all analyzed algorithms.') + fig2.savefig(f'{output_dir}/heatmap_{cami.uid}.png', bbox_inches="tight") + print(f'saved intersection heatmap of all algorithms under: {output_dir}/heatmap_{cami.uid}.png') + plt.close(fig2) + + fig2a, ax3a = plt.subplots(figsize=(20,20)) + comparison_matrix_normalized = comparison_matrix.apply(lambda row: row/row.max(), axis=1) + ax3a = sb.heatmap(comparison_matrix_normalized, annot=True, fmt='.2f') + ax3a.set_title('Normalized intersections of result_gene_sets of all analyzed algorithms.') + fig2a.savefig(f'{output_dir}/heatmap_{cami.uid}_normalized.png', bbox_inches="tight") + print(f'saved intersection heatmap of all algorithms under: {output_dir}/heatmap_{cami.uid}_normalized.png') + plt.close(fig2a) if nvenn and vis: - url = cami.nvenn() + url = cami.use_nvenn() cami.download_diagram(url) with open(f'{output_dir}/00_node_degrees.tsv', 'w') as node_degrees: @@ -187,7 +222,7 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate, seedname = seeds for tool in cami.result_gene_sets: f.write(f'\n{seedname}\t{len(cami.seed_lst)}\t{tool}\t{len(cami.result_gene_sets[tool])}') - + cami.make_evaluation() #predict_and_make_consensus(vis=True) random.seed(50) @@ -201,6 +236,12 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate, result_file_2 = f'{output_dir}/00_seedvariation_results2.tsv' result_file_3 = f'{output_dir}/00_seedvariation_results3.tsv' result_file_4 = f'{output_dir}/00_seedvariation_results4.tsv' + n_results = len(cami.result_gene_sets) + + redisc_intersection_matrix = pd.DataFrame([[0 for _ in range(n_results)] for __ in range(n_results)], + columns = list(cami.result_gene_sets.keys()), + index = list(cami.result_gene_sets.keys()), + dtype=int) with open(result_file_1, 'w') as res_table1: with open(result_file_2, 'w') as res_table2: with open(result_file_3, 'w') as res_table3: @@ -249,7 +290,7 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate, for tool in result_dict: nof_predictions = len(result_dict[tool]) + len(used_seeds) - redisc_seeds = list(set(result_dict[tool]).intersection(set(rem_seeds))) + redisc_seeds = set(result_dict[tool]).intersection(set(rem_seeds)) redisc_prev = len(redisc_seeds) redisc_rate = redisc_prev / nof_removals redisc_rate_dict[tool] = redisc_rate #true positive rate: rate verhältnis von gefundenen und allgemein gefundenen @@ -263,43 +304,54 @@ def main(ppi_network, seeds, tools, 
tool_weights, consensus, evaluate,
                             res_table1.write('\t')
                             for idx,seed in enumerate(redisc_seeds):
                                 if idx == 0:
-                                    res_table1.write(f'{redisc_seeds[0]}')
+                                    res_table1.write(f'{list(redisc_seeds)[0]}')
                                 else:
                                     res_table1.write(f',{seed}')
                             print(f'{tool} rediscovered {redisc_seeds} after removing {rem_seeds}.')
                             res_table2.write(f'{ident}\t{tool}\t{nof_removals},\t{rem_seeds}\t{used_seeds}\t{redisc_prev}\t{redisc_rate}\t{redisc_seeds}\n')
                         res_table1.write('\n')
                         variation_results.append((redisc_rate_dict, redisc_seeds_dict, used_seeds, rem_seeds))
-
+                        for algo1 in redisc_seeds_dict:
+                            for algo2 in redisc_seeds_dict:
+                                redisc_intersection_matrix.loc[algo1,algo2] += len(redisc_seeds_dict[algo1].intersection(redisc_seeds_dict[algo2]))
         print(f'Result tables are saved in the following locations:')
         print(f'Rediscovered seeds: {result_file_1}')
         print(f'Rediscovery Rates: {result_file_2}')
         print(f'Sensitivity: {result_file_3}')
+        fig3, ax6 = plt.subplots(figsize=(20,20))
+        ax6 = sb.heatmap(redisc_intersection_matrix, annot=True, fmt='g')
+        ax6.set_title('Number of times the algorithms rediscovered the same seeds')
+        fig3.savefig(f'{output_dir}/00_seedvariation_heatmap_{identifier}.png', bbox_inches="tight")
+        print(f'Saved seed rediscovery intersection heatmap of all algorithms under: {output_dir}/00_seedvariation_heatmap_{identifier}.png')
+        plt.close(fig3)
+        fig3a, ax6a = plt.subplots(figsize=(20,20))
+        redisc_intersection_matrix_normalized = redisc_intersection_matrix.apply(lambda row: row/row.max(), axis=1)
+        ax6a = sb.heatmap(redisc_intersection_matrix_normalized, annot=True, fmt='.2f')
+        ax6a.set_title('Normalized number of times the algorithms rediscovered the same seeds')
+        fig3a.savefig(f'{output_dir}/00_seedvariation_heatmap_{identifier}_normalized.png', bbox_inches="tight")
+        print(f'Saved normalized seed rediscovery intersection heatmap of all algorithms under: {output_dir}/00_seedvariation_heatmap_{identifier}_normalized.png')
+        plt.close(fig3a)
+
         # print(variation_results)
         rediscovery_rates_results = [results[0] for results in variation_results]
         # print(rediscovery_rates_results)
         tools = [tool for tool in rediscovery_rates_results[0].keys()]
         tool_labels = tools.copy()
         redisc_rates = [[res[tool] for res in rediscovery_rates_results] for tool in tools]
-        for idx,tool in enumerate(tool_labels):
+        for idx,tool in enumerate(tools):
             if '_' in tool:
                 # find the index of the second occurrence of the underscore
                 second_occurrence_index = tool.find('_', tool.find('_') + 1)
                 if second_occurrence_index > -1:
                     # replace the underscore at that index with a line break so the label wraps
                     tool_name = tool[:second_occurrence_index] + '\n' + tool[second_occurrence_index + 1:]
-
-                tool_labels[idx] = tool_name
-
+                    tool_labels[idx] = tool_name
         #PLOT
-        # Create a figure instance
-        fig, axs = plt.subplots(2, 1, figsize=(20,15))
-        plt.subplots_adjust(left=0.2)
-
-        ax1 = axs[0]
-        ax5 = axs[1]
+        #print(sys.getrecursionlimit())
+        fig1, (ax1, ax5, ax4) = plt.subplots(3, 1, figsize=(20,20))
+        fig1.subplots_adjust(left=0.2)
         # Extract Figure and Axes instance
         # Create a plot
@@ -307,6 +359,7 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate,
         for violinpart in list(violins1.keys())[2:]:
             violins1[violinpart].set_color('k')
+
         for violin in violins1['bodies']:
             violin.set_facecolor('red')
         # Add title
@@ -319,19 +372,19 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate,
         ax1.set_ylabel('Rediscovery rate (<rediscovered seeds>/<removed seeds>)', wrap=True, fontsize=14)
-        # ax4 = plt.subplot(1,3,2, label='ax4')
-        # violins2 = ax4.violinplot([tp_rate_dict[tool] for tool in tools], showmeans=True, showextrema=True)
-        # for violinpart in list(violins2.keys())[2:]:
-        #     violins2[violinpart].set_color('k')
-        # for violin in violins2['bodies']:
-        #     violin.set_facecolor('orange')
-        # # Add title
-        # ax4.set_title(f'True positive rates after randomly removing {nof_removals} seeds\n{nof_iterations} times from {identifier} seeds.', wrap=True)
+        violins2 = ax4.violinplot([tp_rate_dict[tool] for tool in tools], showmeans=True, showextrema=True)
+        for violinpart in list(violins2.keys())[2:]:
+            violins2[violinpart].set_color('k')
+        for violin in violins2['bodies']:
+            violin.set_facecolor('orange')
+        # Add title
+        ax4.set_title(f'True positive rates after randomly removing {nof_removals} seeds {nof_iterations} times from {identifier} seeds.', wrap=True, fontsize=14)
-        # ax4.set_xticks(list(range(1,len(tools)+1)))
-        # ax4.set_xticklabels(tools)
+        ax4.set_xticks(list(range(1,len(tools)+1)))
+        ax4.set_xticklabels(tool_labels)
+        ax4.tick_params(axis='x', labelsize=11)
-        # ax4.set_ylabel('Sensitivity (TP/TP + FN)', wrap=True)
+        ax4.set_ylabel('Sensitivity (TP/(TP + FN))', wrap=True, fontsize=14)
         violins3 = ax5.violinplot([module_size_dict[tool] for tool in tools], showmeans=True, showextrema=True)
         # Add title
@@ -345,11 +398,11 @@ def main(ppi_network, seeds, tools, tool_weights, consensus, evaluate,
         ax5.set_xticks(list(range(1,len(tools)+1)))
         ax5.set_xticklabels(tool_labels)
-        ax5.set_ylabel('(<rediscovered seeds>/<module size>)', fontsize=14)
+        ax5.set_ylabel('Precision (<rediscovered seeds>/<module size>)', fontsize=14)
         ax5.tick_params(axis='x', labelsize=11)
-        plt.tight_layout()
-        fig.savefig(f'{output_dir}/00_{identifier}_seed_variation_result.png', bbox_inches="tight")
-
+        fig1.tight_layout()
+        fig1.savefig(f'{output_dir}/00_{identifier}_seed_variation_result.png', bbox_inches="tight")
+        plt.close(fig1)
         print(f'Violin plot saved under: 00_{identifier}_seed_variation_result.png')
 
         if save_temps:
@@ -401,6 +454,8 @@ if __name__ == "__main__":
                         help="Choose a configuration for the static variables.")
     parser.add_argument('-p', '--parallelization', action='store_true',
                         help="run the tools for prediction parallelized")
+    parser.add_argument('-db', '--debug', action='store_true',
+                        help="run CAMI with verbose output")
     #TODO List with additional arguments if needed by certain tools
     args = vars(parser.parse_args())
@@ -412,7 +467,7 @@ if __name__ == "__main__":
     # add name for implemented tools here:
     implemented_tools = [
-        "domino",
+        "domino", 
         "diamond",
         "robust",
         "hotnet"
@@ -421,4 +476,5 @@
     for tool in args['tools']:
         if tool not in implemented_tools:
             raise RuntimeError(tool + f' not implemented yet. Implemented tools are: {implemented_tools}')
+
     main(**args)
-- 
GitLab
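
Note on the metrics plotted above: the seed-variation block computes, per tool, a rediscovery rate (share of the removed seeds that the tool predicts back) and a pairwise intersection matrix of rediscovered seeds that feeds the two heatmaps. The sketch below is illustrative only and not part of the patch; the input names (result_gene_sets, removed_seeds) are placeholders, not identifiers from cami.py.

    # Minimal sketch of the seed-variation bookkeeping, assuming toy inputs.
    import pandas as pd

    result_gene_sets = {                       # hypothetical per-tool predictions
        'diamond': {'g1', 'g2', 'g5'},
        'domino': {'g2', 'g3'},
        'cami': {'g1', 'g2', 'g3'},
    }
    removed_seeds = {'g1', 'g2', 'g4'}         # seeds hidden from the tools

    # Rediscovery rate: fraction of removed seeds that a tool predicts back.
    redisc_seeds = {tool: genes & removed_seeds for tool, genes in result_gene_sets.items()}
    redisc_rate = {tool: len(seeds) / len(removed_seeds) for tool, seeds in redisc_seeds.items()}

    # Pairwise intersection matrix of rediscovered seeds (basis of the heatmaps).
    tools = list(result_gene_sets)
    matrix = pd.DataFrame(
        [[len(redisc_seeds[a] & redisc_seeds[b]) for b in tools] for a in tools],
        index=tools, columns=tools, dtype=int)
    # Row-wise normalization, mirroring the apply(lambda row: row/row.max(), axis=1) call above.
    normalized = matrix.apply(lambda row: row / row.max(), axis=1)

    print(redisc_rate)
    print(normalized)

Dividing each row by its maximum makes rows comparable across tools that rediscover very different numbers of seeds, which is why both heatmaps are saved in a raw and a normalized variant.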