forked from LAION-AI/CLAP
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrenshu.py
More file actions
60 lines (49 loc) · 1.96 KB
/
renshu.py
File metadata and controls
60 lines (49 loc) · 1.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import glob
import json
import os

import laion_clap
import numpy as np
import torch
# Zero-shot ESC-50 classification with a CLAP checkpoint: every clip is scored
# against one text prompt per class and rank-based metrics are printed.

device = torch.device("cuda:1")
# Download https://drive.google.com/drive/folders/1scyH43eQAcrBz-5fAw44C6RNBhC3ejvX?usp=sharing
# and extract ./ESC50_1/test/0.tar to ./ESC50_1/test/
esc50_test_dir = (
    "/mnt/kiso-qnap3/yuabe/m1/CLAP/data/test/fsx/home-marianna/ESC50_folds/ESC50_1"
)
class_index_dict_path = "./class_labels/ESC50_class_labels_indices_space.json"
checkpoint_path = "/home/yuabe/checkpoints/music_audioset_epoch_15_esc_90.14.pt"


def read_json(path):
    """Return the parsed content of the JSON file at *path*.

    The file is opened in a ``with`` block so the handle is closed right away;
    the script reads one JSON file per sample, so leaked handles add up.
    """
    with open(path, encoding="utf-8") as handle:
        return json.load(handle)


def order_class_names(class_index):
    """Return the class names of *class_index* sorted by their index value.

    Args:
        class_index: Mapping of class name -> integer class index.

    Returns:
        A list of class names whose position follows the index order, so that
        position ``i`` holds the class with the ``i``-th smallest index. This
        removes the dependency on the key order of the label JSON.
    """
    return sorted(class_index, key=class_index.get)


def compute_rank_metrics(ranks):
    """Compute retrieval metrics from 0-based ranks of the correct class.

    Args:
        ranks: Array-like of non-negative integers, one per sample; ``0`` means
            the correct class was ranked first.

    Returns:
        Dict with ``mean_rank`` and ``median_rank`` (both 1-based), ``R@1``,
        ``R@5``, ``R@10`` (fraction of samples ranked inside the top k) and
        ``mAP@10`` (mean reciprocal rank, counting ranks outside the top 10
        as 0), in that insertion order.

    Raises:
        ValueError: If *ranks* is empty, since every metric would be NaN.
    """
    ranks = np.asarray(ranks)
    if ranks.size == 0:
        raise ValueError("cannot compute metrics from an empty rank array")

    metrics = {
        "mean_rank": ranks.mean() + 1,
        "median_rank": np.floor(np.median(ranks)) + 1,
    }
    for k in (1, 5, 10):
        metrics[f"R@{k}"] = np.mean(ranks < k)
    # map@10
    metrics["mAP@10"] = np.mean(np.where(ranks < 10, 1 / (ranks + 1), 0.0))
    return metrics


def format_metrics(metrics):
    """Render *metrics* as the tab-separated, 4-decimal result line."""
    rendered = "\t".join(f"{name}: {value:.4f}" for name, value in metrics.items())
    return "Zeroshot Classification Results: " + rendered


def main():
    """Run the zero-shot evaluation and print the metrics.

    Raises:
        FileNotFoundError: If no ``.flac`` file is found under ``esc50_test_dir``.
        ValueError: If the numbers of ``.flac`` and ``.json`` files differ.
    """
    # Load the model
    model = laion_clap.CLAP_Module(
        enable_fusion=False, amodel="HTSAT-base", device=device
    )
    model.load_ckpt(checkpoint_path)

    # Get the class index dict (class name -> class index)
    class_index_dict = read_json(class_index_dict_path)

    # Get all the data
    audio_files = sorted(
        glob.glob(os.path.join(esc50_test_dir, "**", "*.flac"), recursive=True)
    )
    json_files = sorted(
        glob.glob(os.path.join(esc50_test_dir, "**", "*.json"), recursive=True)
    )
    if not audio_files:
        raise FileNotFoundError(f"no .flac files found under {esc50_test_dir!r}")
    # Labels are matched to clips purely by position in the two sorted lists.
    # NOTE(review): assumes every clip has a same-named .json sidecar - confirm.
    if len(audio_files) != len(json_files):
        raise ValueError(
            f"found {len(audio_files)} audio files but {len(json_files)} json files"
        )
    # The first entry of each sidecar's "tag" list is looked up as the class name.
    ground_truth_idx = [
        class_index_dict[read_json(json_file)["tag"][0]] for json_file in json_files
    ]

    with torch.no_grad():
        ground_truth = torch.tensor(ground_truth_idx).view(-1, 1)

        # Get text features; prompts follow class-index order so that a column
        # position in the similarity matrix can be compared with a class index.
        all_texts = [
            "This is a sound of " + name for name in order_class_names(class_index_dict)
        ]
        text_embed = model.get_text_embedding(all_texts)
        audio_embed = model.get_audio_embedding_from_filelist(x=audio_files)

        # Each row lists the prompt positions from best to worst match.
        ranking = torch.argsort(
            torch.tensor(audio_embed) @ torch.tensor(text_embed).t(), descending=True
        )
        # Column at which the ground-truth index appears in each row, i.e. the
        # 0-based rank of the correct class.
        preds = torch.where(ranking == ground_truth)[1]
        preds = preds.cpu().numpy()

    print(format_metrics(compute_rank_metrics(preds)))


if __name__ == "__main__":
    main()