Skip to content
Snippets Groups Projects
Commit 13b988b5 authored by Le, Mia's avatar Le, Mia
Browse files

added comparison_matrix utils

parent eae22010
Branches
No related tags found
No related merge requests found
import pandas as pd
import seaborn as sb
import matplotlib.pyplot as plt
def make_comparison_matrix(result_gene_sets):
n_results = len(result_gene_sets)
comparison_matrix = pd.DataFrame([[int(0) for _ in range(n_results)] for __ in range(n_results)],
columns = list(result_gene_sets.keys()),
index = list(result_gene_sets.keys()),
dtype=int)
for algo1 in result_gene_sets:
for algo2 in result_gene_sets:
comparison_matrix.loc[algo1,algo2] = int(len(result_gene_sets[algo1].intersection(result_gene_sets[algo2])))
return comparison_matrix
def plot_comparison_matrix(comparison_matrix, title='', n_rows=-1):
"""plot the comparison matrix
Args:
comparison_matrix (DataFrame): comparison matrix that compares the result_gene_sets of all algorithms to each other
n_rows (int, optional): number of rows to plot. Defaults to -1.
Returns:
fig1, ax2: matplotlib figure and axis of the comparison matrix
fig2, ax2: matplotlib figure and axis of the normalized comparison matrix
"""
comparison_matrix_slice = comparison_matrix.iloc[:n_rows]
fig1, ax1 = plt.subplots(figsize=(20,10))
ax1 = sb.heatmap(comparison_matrix_slice, annot=True, fmt='g')
ax1.set_title(f'{title.capitalize()}')
fig2, ax2 = plt.subplots(figsize=(20,20))
comparison_matrix_slice_normalized = comparison_matrix_slice.apply(lambda row: row/row.max(), axis=1)
ax2 = sb.heatmap(comparison_matrix_slice_normalized, annot=True, fmt='.2f')
ax2.set_title(f'Normalized {title}')
return fig1,ax1, fig2,ax2
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment