# cami_src/utils/comparison_matrix.py
"""Pairwise overlap matrix of result gene sets and heatmap plots of it."""

import pandas as pd


def make_comparison_matrix(result_gene_sets):
    """Count the genes shared by every pair of result gene sets.

    Args:
        result_gene_sets (dict): maps a label (e.g. an algorithm name) to the
            set of genes reported under that label.

    Returns:
        DataFrame: square integer matrix whose rows and columns are the keys
        of ``result_gene_sets`` in iteration order. Cell ``[a, b]`` holds the
        size of the intersection of the two sets, so the diagonal holds the
        size of each individual set.
    """
    labels = list(result_gene_sets)
    # ``.intersection`` (rather than ``&``) keeps accepting any iterable as
    # the right-hand operand.
    overlap_counts = [
        [
            len(result_gene_sets[row_label].intersection(result_gene_sets[col_label]))
            for col_label in labels
        ]
        for row_label in labels
    ]
    return pd.DataFrame(overlap_counts, index=labels, columns=labels, dtype=int)


def plot_comparison_matrix(comparison_matrix, title='', n_rows=-1):
    """Plot the comparison matrix as a raw and as a row-normalized heatmap.

    Args:
        comparison_matrix (DataFrame): numeric matrix to plot, e.g. the
            result of ``make_comparison_matrix``.
        title (str, optional): title of the raw heatmap (capitalized); the
            normalized heatmap is titled ``'Normalized <title>'``.
            Defaults to ''.
        n_rows (int, optional): number of leading rows to plot. A negative
            value or None plots all rows. Defaults to -1.

    Returns:
        fig1, ax1: matplotlib figure and axis of the comparison matrix
        fig2, ax2: matplotlib figure and axis of the normalized comparison matrix
    """
    # Imported here so that the pure-pandas helper above can be used in
    # environments where the plotting stack is not installed.
    import matplotlib.pyplot as plt
    import seaborn as sb

    # A plain ``iloc[:n_rows]`` with the default of -1 would silently drop the
    # last row, so "no limit" is handled explicitly.
    if n_rows is None or n_rows < 0:
        comparison_matrix_slice = comparison_matrix
    else:
        comparison_matrix_slice = comparison_matrix.iloc[:n_rows]

    fig1, ax1 = plt.subplots(figsize=(20, 10))
    # Pass the axes explicitly instead of relying on pyplot's "current axes".
    sb.heatmap(comparison_matrix_slice, annot=True, fmt='g', ax=ax1)
    ax1.set_title(f'{title.capitalize()}')

    fig2, ax2 = plt.subplots(figsize=(20, 20))
    # Scale every row by its own maximum. For a matrix built by
    # make_comparison_matrix that maximum is the diagonal entry (the set's own
    # size). NOTE(review): an all-zero row divides by zero and should come out
    # as NaN cells; confirm that this is the desired rendering.
    comparison_matrix_slice_normalized = comparison_matrix_slice.apply(
        lambda row: row / row.max(), axis=1
    )
    sb.heatmap(comparison_matrix_slice_normalized, annot=True, fmt='.2f', ax=ax2)
    ax2.set_title(f'Normalized {title}')
    return fig1, ax1, fig2, ax2