diff --git a/cami_src/evaluation_scripts/seed_variation_script.py b/cami_src/evaluation_scripts/seed_variation_script.py index 8135b79c32b8657ae0e49678848878f09ff805e4..ddfadb035a4084123132f5f83253b6550edc0e78 100644 --- a/cami_src/evaluation_scripts/seed_variation_script.py +++ b/cami_src/evaluation_scripts/seed_variation_script.py @@ -17,7 +17,7 @@ def predict_and_make_consensus(cami, vis=False): if vis: cami.use_nvenn(download=True) -def make_seedvariation(cami, n_iterations, removal_frac=0.2, vis=False, plot=True): +def make_seedvariation(cami, n_iterations, removal_frac=0.2, vis=False, plot=False): identifier = cami.uid base_seeds = cami.origin_seed_lst original_seeds = [cami.ppi_vertex2gene[seed] for seed in base_seeds] @@ -52,11 +52,11 @@ def make_seedvariation(cami, n_iterations, removal_frac=0.2, vis=False, plot=Tru for tool in prediction_tools: res_table.write(f'\t{tool}_msr_ks_pvalue') - with open(os.path.join(cami.tmp_dir, f'{used_tools[0]}_{cami.uid}_relevance_scores.tsv'), 'r') as f: - for line in f: - val_name = line.split('\t')[0] - redisc_table.write(f'\t{val_name}') - res_table.write('\n') + # with open(os.path.join(cami.tmp_dir, f'{used_tools[0]}_{cami.uid}_relevance_scores.tsv'), 'r') as f: + # for line in f: + # val_name = line.split('\t')[0] + # redisc_table.write(f'\t{val_name}') + # res_table.write('\n') # result dictionaries of the form {tool:list(value for each iteration)} tp_rate_dict = {k:list() for k in used_tools} @@ -142,10 +142,10 @@ def make_seedvariation(cami, n_iterations, removal_frac=0.2, vis=False, plot=Tru list(module_size_dict[pred_tool])) res_table.write(f'\t{p_val}') - with open(os.path.join(cami.tmp_dir, f'{tool}_{identifier}_relevance_scores.tsv)'), 'r') as f: - for line in f: - rel_score = line.split('\t')[1].strip() - res_table.write(f'\t{rel_score}') + # with open(os.path.join(cami.tmp_dir, f'{tool}_{identifier}_relevance_scores.tsv)'), 'r') as f: + # for line in f: + # rel_score = line.split('\t')[1].strip() + # res_table.write(f'\t{rel_score}') res_table.write('\n') print(f'Result tables are saved in the following locations:') @@ -171,7 +171,7 @@ def make_seedvariation(cami, n_iterations, removal_frac=0.2, vis=False, plot=Tru #PLOT # Create a figure instance #print(sys.getrecursionlimit()) - fig1, (ax1, ax5, ax4) = plt.subplots(3, 1, figsize=(20,20)) + fig1, (ax1, ax5) = plt.subplots(2, 1, figsize=(20,20)) fig1.subplots_adjust(left=0.2) # Extract Figure and Axes instance @@ -200,26 +200,26 @@ def make_seedvariation(cami, n_iterations, removal_frac=0.2, vis=False, plot=Tru ax1.set_ylabel('Rediscovery rate (<rediscovered seeds>/<removed seeds>)', wrap=True, fontsize=14) - violins2 = ax4.violinplot([tp_rate_dict[tool] for tool in tools], showmeans=True, showextrema=True) - for violinpart in list(violins2.keys())[2:]: - violins2[violinpart].set_color('k') - for violin, tool in zip(violins2['bodies'], tools): - if tool in [tw.name for tw in cami.tool_wrappers]: - violin.set_facecolor('tan') - elif tool == 'first_neighbors': - violin.set_facecolor('peachpuff') - elif tool in ['union', 'intersection']: - violin.set_facecolor('orange') - else: - violin.set_facecolor('darkorange') - # Add title - ax4.set_title(f'True positive rates after randomly removing {nof_removals} seeds {nof_iterations} times from {identifier} seeds.', wrap=True, fontsize=14) + # violins2 = ax4.violinplot([tp_rate_dict[tool] for tool in tools], showmeans=True, showextrema=True) + # for violinpart in list(violins2.keys())[2:]: + # violins2[violinpart].set_color('k') + # for violin, tool in zip(violins2['bodies'], tools): + # if tool in [tw.name for tw in cami.tool_wrappers]: + # violin.set_facecolor('tan') + # elif tool == 'first_neighbors': + # violin.set_facecolor('peachpuff') + # elif tool in ['union', 'intersection']: + # violin.set_facecolor('orange') + # else: + # violin.set_facecolor('darkorange') + # # Add title + # ax4.set_title(f'True positive rates after randomly removing {nof_removals} seeds {nof_iterations} times from {identifier} seeds.', wrap=True, fontsize=14) - ax4.set_xticks(list(range(1,len(tools)+1))) - ax4.set_xticklabels(tool_labels) - ax4.tick_params(axis='x', labelsize=11) + # ax4.set_xticks(list(range(1,len(tools)+1))) + # ax4.set_xticklabels(tool_labels) + # ax4.tick_params(axis='x', labelsize=11) - ax4.set_ylabel('Sensitivity (TP/TP + FN)', wrap=True, fontsize=14) + # ax4.set_ylabel('Sensitivity (TP/TP + FN)', wrap=True, fontsize=14) violins3 = ax5.violinplot([module_size_dict[tool] for tool in tools], showmeans=True, showextrema=True) # Add title