-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplot_results.py
More file actions
88 lines (71 loc) · 2.76 KB
/
plot_results.py
File metadata and controls
88 lines (71 loc) · 2.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt
# All artefacts live under a single "outputs" root.
_OUTPUT_ROOT = Path("outputs")

# Directory that receives the sorted results table and every plot; it is
# created eagerly, as soon as this module is imported.
PLOT_DIR = _OUTPUT_ROOT / "plots"
PLOT_DIR.mkdir(exist_ok=True, parents=True)

# Grid-search summary CSV that main() reads.
SUMMARY_FILE = _OUTPUT_ROOT / "grid_search_summary.csv"
# Columns the summary CSV must provide. The order is also the column order of
# the "Top 5" table printed by main().
_REQUIRED_COLUMNS = (
    "exp_name",
    "lambda1_spec",
    "lambda2_smooth",
    "mean_stoi",
    "mean_si_sdr",
)


def _fmt_param(value):
    """Format a hyperparameter value compactly (``0.3`` rather than ``0.30000000000000004``).

    Non-numeric values fall back to ``str(value)``.
    """
    try:
        return f"{value:g}"
    except (TypeError, ValueError):
        return str(value)


def _load_summary(summary_file):
    """Read the grid-search summary CSV and check that it is usable.

    Raises:
        FileNotFoundError: if *summary_file* does not exist.
        ValueError: if the file has no data rows or lacks a required column.
    """
    summary_file = Path(summary_file)
    if not summary_file.exists():
        raise FileNotFoundError(f"Missing summary file: {summary_file}")
    try:
        df = pd.read_csv(summary_file)
    except pd.errors.EmptyDataError as err:
        # A zero-byte file has no header at all. Report it the same way as a
        # header-only file (EmptyDataError is itself a ValueError subclass).
        raise ValueError("Summary CSV is empty.") from err
    if df.empty:
        raise ValueError("Summary CSV is empty.")
    missing = [col for col in _REQUIRED_COLUMNS if col not in df.columns]
    if missing:
        raise ValueError(
            f"Summary CSV is missing required column(s): {', '.join(missing)}"
        )
    return df


def _save_heatmap(df, value_col, title, colorbar_label, out_path):
    """Save a heatmap of *value_col* with lambda2_smooth rows and lambda1_spec columns."""
    # Use pivot_table(aggfunc="mean") rather than pivot(). pivot() raises
    # ValueError when the same (lambda1, lambda2) pair occurs more than once
    # in the CSV (e.g. repeated runs); here such duplicates are averaged
    # into a single cell.
    grid = df.pivot_table(
        index="lambda2_smooth",
        columns="lambda1_spec",
        values=value_col,
        aggfunc="mean",
    )
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        image = ax.imshow(grid.to_numpy(), aspect="auto")
        ax.set_xticks(range(len(grid.columns)))
        ax.set_xticklabels([_fmt_param(v) for v in grid.columns])
        ax.set_yticks(range(len(grid.index)))
        ax.set_yticklabels([_fmt_param(v) for v in grid.index])
        ax.set_xlabel("lambda1 (spec_w)")
        ax.set_ylabel("lambda2 (smooth_w)")
        ax.set_title(title)
        fig.colorbar(image, ax=ax, label=colorbar_label)
        fig.tight_layout()
        fig.savefig(out_path, dpi=300)
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)


def _save_tradeoff_scatter(df, out_path):
    """Save a SI-SDR (x) vs STOI (y) scatter plot, labelling each point with its lambdas."""
    # Rows with a missing metric cannot be placed on the axes.
    points = df.dropna(subset=["mean_si_sdr", "mean_stoi"])
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        ax.scatter(points["mean_si_sdr"], points["mean_stoi"])
        for l1, l2, x, y in zip(
            points["lambda1_spec"],
            points["lambda2_smooth"],
            points["mean_si_sdr"],
            points["mean_stoi"],
        ):
            ax.annotate(
                f"l1={_fmt_param(l1)}, l2={_fmt_param(l2)}",
                (x, y),
                fontsize=7,
                xytext=(5, 5),
                textcoords="offset points",
            )
        ax.set_xlabel("Mean SI-SDR")
        ax.set_ylabel("Mean STOI")
        ax.set_title("Pareto-style Trade-off Plot")
        fig.tight_layout()
        fig.savefig(out_path, dpi=300)
    finally:
        plt.close(fig)


def main(summary_file=None, plot_dir=None):
    """Summarise grid-search results and write a sorted table plus plots.

    The function writes these files to *plot_dir*:

    * ``sorted_results.csv`` - rows sorted by mean STOI, then mean SI-SDR,
      both descending;
    * ``heatmap_stoi.png`` / ``heatmap_sisdr.png`` - metric heatmaps over the
      (lambda2_smooth, lambda1_spec) grid;
    * ``pareto_stoi_sisdr.png`` - SI-SDR vs STOI scatter plot.

    It also prints the top 5 rows of the sorted table.

    Args:
        summary_file: Path of the summary CSV. Defaults to the module-level
            ``SUMMARY_FILE``.
        plot_dir: Output directory, created if needed. Defaults to the
            module-level ``PLOT_DIR``.

    Raises:
        FileNotFoundError: if the summary CSV does not exist.
        ValueError: if the CSV is empty or lacks a required column.
    """
    summary_file = SUMMARY_FILE if summary_file is None else Path(summary_file)
    plot_dir = PLOT_DIR if plot_dir is None else Path(plot_dir)

    df = _load_summary(summary_file)
    plot_dir.mkdir(parents=True, exist_ok=True)

    # Sort by STOI, then SI-SDR (best first).
    df_sorted = df.sort_values(by=["mean_stoi", "mean_si_sdr"], ascending=False)
    df_sorted.to_csv(plot_dir / "sorted_results.csv", index=False)
    print("\nTop 5 results:")
    print(df_sorted.head(5)[list(_REQUIRED_COLUMNS)])

    _save_heatmap(
        df, "mean_stoi", "Heatmap of Mean STOI", "Mean STOI",
        plot_dir / "heatmap_stoi.png",
    )
    _save_heatmap(
        df, "mean_si_sdr", "Heatmap of Mean SI-SDR", "Mean SI-SDR",
        plot_dir / "heatmap_sisdr.png",
    )
    _save_tradeoff_scatter(df, plot_dir / "pareto_stoi_sisdr.png")

    print("\nPlots saved in:", plot_dir)
if __name__ == "__main__":
    # Generate the table and plots only when executed as a script, not on import.
    main()