diff --git a/calc_metrics.py b/calc_metrics.py index b2e2e55cdafd318dfea72d6a22236b002f6a60a6..3f6a352e0501e361ea2a004da37dece4998f14ee 100644 --- a/calc_metrics.py +++ b/calc_metrics.py @@ -13,57 +13,48 @@ from sgmse.util.other import energy_ratios, mean_std if __name__ == '__main__': parser = ArgumentParser() - parser.add_argument("--test_dir", type=str, required=True, help='Directory containing the original test data (must have subdirectories clean/ and noisy/)') + parser.add_argument("--clean_dir", type=str, required=True, help='Directory containing the clean data') + parser.add_argument("--noisy_dir", type=str, required=True, help='Directory containing the noisy data') parser.add_argument("--enhanced_dir", type=str, required=True, help='Directory containing the enhanced data') args = parser.parse_args() - test_dir = args.test_dir - clean_dir = join(test_dir, "clean/") - noisy_dir = join(test_dir, "noisy/") - enhanced_dir = args.enhanced_dir - data = {"filename": [], "pesq": [], "estoi": [], "si_sdr": [], "si_sir": [], "si_sar": []} sr = 16000 # Evaluate standard metrics - noisy_files = sorted(glob('{}/*.wav'.format(noisy_dir))) + noisy_files = [] + noisy_files += sorted(glob(join(args.noisy_dir, '*.wav'))) + noisy_files += sorted(glob(join(args.test_dir, '**', '*.wav'))) for noisy_file in tqdm(noisy_files): filename = noisy_file.split('/')[-1] - x, _ = read(join(clean_dir, filename)) + x, _ = read(join(args.clean_dir, filename)) y, _ = read(noisy_file) n = y - x - x_method, _ = read(join(enhanced_dir, filename)) - + x_hat, _ = read(join(args.enhanced_dir, filename)) data["filename"].append(filename) - data["pesq"].append(pesq(sr, x, x_method, 'wb')) - data["estoi"].append(stoi(x, x_method, sr, extended=True)) - data["si_sdr"].append(energy_ratios(x_method, x, n)[0]) - data["si_sir"].append(energy_ratios(x_method, x, n)[1]) - data["si_sar"].append(energy_ratios(x_method, x, n)[2]) + data["pesq"].append(pesq(sr, x, x_hat, 'wb')) + data["estoi"].append(stoi(x, x_hat, sr, extended=True)) + data["si_sdr"].append(energy_ratios(x_hat, x, n)[0]) + data["si_sir"].append(energy_ratios(x_hat, x, n)[1]) + data["si_sar"].append(energy_ratios(x_hat, x, n)[2]) # Save results as DataFrame df = pd.DataFrame(data) - # POLQA evaluation - requires POLQA license and server, uncomment at your own peril. - # This is batch processed for speed reasons and thus runs outside the for loop. - # if not basic: - # clean_files = sorted(glob('{}/*.wav'.format(clean_dir))) - # enhanced_files = sorted(glob('{}/*.wav'.format(enhanced_dir))) - # clean_audios = [read(clean_file)[0] for clean_file in clean_files] - # enhanced_audios = [read(enhanced_file)[0] for enhanced_file in enhanced_files] - # polqa_vals = polqa(clean_audios, enhanced_audios, 16000, save_to=None) - # polqa_vals = [val[1] for val in polqa_vals] - # # Add POLQA column to DataFrame - # df['polqa'] = polqa_vals - # Print results - print(enhanced_dir) - #print("POLQA: {:.2f} ± {:.2f}".format(*mean_std(df["polqa"].to_numpy()))) print("PESQ: {:.2f} ± {:.2f}".format(*mean_std(df["pesq"].to_numpy()))) print("ESTOI: {:.2f} ± {:.2f}".format(*mean_std(df["estoi"].to_numpy()))) print("SI-SDR: {:.1f} ± {:.1f}".format(*mean_std(df["si_sdr"].to_numpy()))) print("SI-SIR: {:.1f} ± {:.1f}".format(*mean_std(df["si_sir"].to_numpy()))) print("SI-SAR: {:.1f} ± {:.1f}".format(*mean_std(df["si_sar"].to_numpy()))) + # Save average results to file + log = open(join(args.enhanced_dir, "_avg_results.txt"), "w") + log.write("PESQ: {:.2f} ± {:.2f}".format(*mean_std(df["pesq"].to_numpy())) + "\n") + log.write("ESTOI: {:.2f} ± {:.2f}".format(*mean_std(df["estoi"].to_numpy())) + "\n") + log.write("SI-SDR: {:.1f} ± {:.2f}".format(*mean_std(df["si_sdr"].to_numpy())) + "\n") + log.write("SI-SIR: {:.1f} ± {:.1f}".format(*mean_std(df["si_sir"].to_numpy())) + "\n") + log.write("SI-SAR: {:.1f} ± {:.1f}".format(*mean_std(df["si_sar"].to_numpy())) + "\n") + # Save DataFrame as csv file - df.to_csv(join(enhanced_dir, "_results.csv"), index=False) + df.to_csv(join(args.enhanced_dir, "_results.csv"), index=False)