Skip to content
Snippets Groups Projects
Commit e69d61af authored by Le, Mia's avatar Le, Mia
Browse files

added seedvar script

parent 0bf2a6d5
No related branches found
No related tags found
No related merge requests found
import matplotlib.pyplot as plt
import seaborn as sb
import pandas as pd
import os
import random
from cami_suite import cami
import utils.comparison_matrix as comparison_matrix
import numpy as np
from utils import kolmogorov_smirnoff
def predict_and_make_consensus(cami, vis=False):
result_sets = cami.make_predictions()
cami.create_consensus(result_sets, save_output=False)
if vis:
n_results = len(cami.result_gene_sets)
cami.visualize_and_save_comparison_matrix()
if vis:
cami.use_nvenn(download=True)
def make_seedvariation(cami, n_iterations, removal_frac=0.2, vis=False, plot=True):
identifier = cami.uid
base_seeds = cami.origin_seed_lst
original_seeds = [cami.ppi_vertex2gene[seed] for seed in base_seeds]
print(f'All given seeds:{original_seeds}')
random.seed(50)
removal_frac = removal_frac
nof_iterations = int(n_iterations)
used_tools = list(cami.result_gene_sets.keys())
prediction_tools = cami.prediction_tools
nof_seeds = len(base_seeds)
nof_removals = max([int(nof_seeds * removal_frac), 1])
redisc_seeds_file = f'{cami.output_dir}/00_seedvariation_rediscovered_seeds.tsv'
result_table_file = f'{cami.output_dir}/00_seedvariation_result_table.tsv'
n_results = len(cami.result_gene_sets)
redisc_intersection_matrix = pd.DataFrame([[0 for _ in range(n_results)] for __ in range(n_results)],
columns = list(cami.result_gene_sets.keys()),
index = list(cami.result_gene_sets.keys()),
dtype=int)
with open(redisc_seeds_file, 'w') as redisc_table:
with open(result_table_file, 'w') as res_table:
redisc_table.write('id')
for tool in used_tools:
redisc_table.write(f'\t{tool}')
redisc_table.write('\n')
res_table.write('tool\trdr\trdr_std\tsensitivity\tsensitivity_std\tprecision\tprecision_std')
for tool in prediction_tools:
res_table.write(f'\t{tool}_rdr_ks_pvalue')
for tool in prediction_tools:
res_table.write(f'\t{tool}_msr_ks_pvalue')
with open(os.path.join(cami.tmp_dir, f'{used_tools[0]}_{cami.uid}_relevance_scores.tsv)'), 'r') as f:
for line in f:
val_name = line.split('\t')[0]
redisc_table.write(f'\t{val_name}')
res_table.write('\n')
# result dictionaries of the form {tool:list(value for each iteration)}
tp_rate_dict = {k:list() for k in used_tools}
redisc_rate_dict = {k:list() for k in used_tools}
module_size_dict = {k:list() for k in used_tools}
# removed and used seeds per iteration
all_removed_seeds = list()
all_used_seeds = list()
all_redisc_seeds = []
for ident in range(nof_iterations):
redisc_table.write(f'{ident}')
# update uid
new_identifier = identifier + f'_{ident}'
# reset cami
cami.reset_cami(new_uid=new_identifier)
# cami.ppi_graph = original_ppi
#remove seeds (again)
print(f'Removing {nof_removals} seeds from the original seed list...')
removed_seeds_idx = random.sample(list(range(nof_seeds)), nof_removals)
removed_seeds = cami.remove_seeds(removed_seeds_idx)
rem_seeds = [cami.ppi_vertex2gene[seed] for seed in removed_seeds]
print(f'Removed: {rem_seeds} from the seed list')
print('Updating tools and repeat CAMI')
# reinitialize tools
cami.initialize_all_tools()
# repeat consensus
if ident%20==0:
predict_and_make_consensus(cami)
else:
predict_and_make_consensus(cami)
used_seeds = [cami.ppi_vertex2gene[seed] for seed in cami.seed_lst]
redisc_seeds_dict = {}
result_dict = cami.result_gene_sets
for tool in result_dict:
nof_predictions = len(result_dict[tool]) + len(used_seeds)
redisc_seeds = set(result_dict[tool]).intersection(set(rem_seeds))
redisc_prev = len(redisc_seeds)
redisc_rate = redisc_prev / nof_removals
redisc_rate_dict[tool].append(redisc_rate)
redisc_seeds_dict[tool] = redisc_seeds
tp_rate = redisc_prev / len(removed_seeds)
tp_rate_dict[tool].append(tp_rate)
module_size_frac = redisc_prev / nof_predictions
assert module_size_frac <= 1
module_size_dict[tool].append(module_size_frac)
redisc_table.write('\t')
for idx,seed in enumerate(redisc_seeds):
if idx == 0:
redisc_table.write(f'{list(redisc_seeds)[0]}')
else:
redisc_table.write(f',{seed}')
print(f'{tool} rediscovered {redisc_seeds} after removing {rem_seeds}.')
all_redisc_seeds.append(redisc_seeds_dict)
redisc_table.write('\n')
all_used_seeds.append(used_seeds)
all_removed_seeds.append(rem_seeds)
for algo1 in redisc_seeds_dict:
for algo2 in redisc_seeds_dict:
redisc_intersection_matrix.loc[algo1,algo2] += len(redisc_seeds_dict[algo1].intersection(redisc_seeds_dict[algo2]))
for tool in redisc_rate_dict:
res_table.write(f'{tool}\t')
res_table.write(f'{np.mean(redisc_rate_dict[tool])}\t')
res_table.write(f'{np.std(redisc_rate_dict[tool])}\t')
res_table.write(f'{np.mean(tp_rate_dict[tool])}\t')
res_table.write(f'{np.std(tp_rate_dict[tool])}\t')
res_table.write(f'{np.mean(module_size_dict[tool])}\t')
res_table.write(f'{np.std(module_size_dict[tool])}')
for pred_tool in prediction_tools:
p_val = kolmogorov_smirnoff.calculate_ks_p_value(list(redisc_rate_dict[tool]),
list(redisc_rate_dict[pred_tool]))
res_table.write(f'\t{p_val}')
for pred_tool in prediction_tools:
p_val = kolmogorov_smirnoff.calculate_ks_p_value(list(module_size_dict[tool]),
list(module_size_dict[pred_tool]))
res_table.write(f'\t{p_val}')
with open(os.path.join(cami.tmp_dir, f'{tool}_{cami.uid}_relevance_scores.tsv)'), 'r') as f:
for line in f:
rel_score = line.split('\t')[1].strip()
res_table.write(f'\t{rel_score}')
res_table.write('\n')
print(f'Result tables are saved in the following locations:')
fig1,ax1, fig2,ax2 = comparison_matrix.plot_comparison_matrix(redisc_intersection_matrix, n_rows=cami.nof_tools,
title=f'number of times algorithms rediscovered the same seeds after removing {nof_removals} seeds')
fig1.savefig(f'{cami.output_dir}/same_rediscs_{identifier}_comparison_matrix.png')
fig2.savefig(f'{cami.output_dir}/same_rediscs_{identifier}_comparison_matrix_normalized.png')
# print(variation_results)
# print(rediscovery_rates_results)
tools = [tool for tool in redisc_rate_dict.keys()]
tool_labels = tools.copy()
for idx,tool in enumerate(tools):
if '_' in tool:
# find the index of the second occurrence of the character
second_occurrence_index = tool.find('_', tool.find('_') + 1)
if second_occurrence_index > -1:
# replace the character at that index with the replacement character
tool_name = tool[:second_occurrence_index] + '\n' + tool[second_occurrence_index + 1:]
tool_labels[idx] = tool_name
if plot:
#PLOT
# Create a figure instance
#print(sys.getrecursionlimit())
fig1, (ax1, ax5, ax4) = plt.subplots(3, 1, figsize=(20,20))
fig1.subplots_adjust(left=0.2)
# Extract Figure and Axes instance
# Create a plot
violins1 = ax1.violinplot([redisc_rate_dict[tool] for tool in tools], showmeans=True, showextrema=True)
for violinpart in list(violins1.keys())[2:]:
violins1[violinpart].set_color('k')
for violin, tool in zip(violins1['bodies'], tools):
if tool in [tw.name for tw in cami.tool_wrappers]:
violin.set_facecolor('saddlebrown')
elif tool == 'first_neighbors':
violin.set_facecolor('orange')
elif tool in ['union', 'intersection']:
violin.set_facecolor('peachpuff')
else:
violin.set_facecolor('red')
# Add title
ax1.set_title(f'Rediscovery rate after randomly removing {nof_removals} seeds {nof_iterations} times from {identifier} seeds.', wrap=True, fontsize=14)
ax1.set_xticks(list(range(1,len(tools)+1)))
ax1.set_xticklabels(tool_labels)
ax1.tick_params(axis='x', labelsize=11)
ax1.set_ylabel('Rediscovery rate (<rediscovered seeds>/<removed seeds>)', wrap=True, fontsize=14)
violins2 = ax4.violinplot([tp_rate_dict[tool] for tool in tools], showmeans=True, showextrema=True)
for violinpart in list(violins2.keys())[2:]:
violins2[violinpart].set_color('k')
for violin, tool in zip(violins2['bodies'], tools):
if tool in [tw.name for tw in cami.tool_wrappers]:
violin.set_facecolor('tan')
elif tool == 'first_neighbors':
violin.set_facecolor('peachpuff')
elif tool in ['union', 'intersection']:
violin.set_facecolor('orange')
else:
violin.set_facecolor('darkorange')
# Add title
ax4.set_title(f'True positive rates after randomly removing {nof_removals} seeds {nof_iterations} times from {identifier} seeds.', wrap=True, fontsize=14)
ax4.set_xticks(list(range(1,len(tools)+1)))
ax4.set_xticklabels(tool_labels)
ax4.tick_params(axis='x', labelsize=11)
ax4.set_ylabel('Sensitivity (TP/TP + FN)', wrap=True, fontsize=14)
violins3 = ax5.violinplot([module_size_dict[tool] for tool in tools], showmeans=True, showextrema=True)
# Add title
for violinpart in list(violins3.keys())[2:]:
violins3[violinpart].set_color('k')
for violin, tool in zip(violins3['bodies'], tools):
if tool in [tw.name for tw in cami.tool_wrappers]:
violin.set_facecolor('midnightblue')
elif tool == 'first_neighbors':
violin.set_facecolor('mediumblue')
elif tool in ['union', 'intersection']:
violin.set_facecolor('lightsteelblue')
else:
violin.set_facecolor('royalblue')
ax5.set_title(f'Ratio of number of rediscovered seeds and predicted module size after removing {nof_removals} seeds {nof_iterations} times from {identifier} seeds.', wrap=True, fontsize=14)
ax5.set_xticks(list(range(1,len(tools)+1)))
ax5.set_xticklabels(tool_labels)
ax5.set_ylabel('precision (<rediscovered seeds>/<module size>)', fontsize=14)
ax5.tick_params(axis='x', labelsize=11)
fig1.tight_layout()
fig1.savefig(f'{cami.output_dir}/00_{identifier}_seed_variation_result.png', bbox_inches="tight")
plt.close(fig1)
print(f'Violin plot saved under: 00_{identifier}_seed_variation_result.png')
return cami
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment