Emotion Detection

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers import BertTokenizer, BertModel
import torchaudio
import numpy as np

# Dataset Definition

class EmotionDataset(Dataset):
    def __init__(self, image_files, text_data, audio_files, labels, tokenizer,
                 max_len=128):
        self.image_files = image_files
        self.text_data = text_data
        self.audio_files = audio_files
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

        self.image_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Image: image_files[idx] is expected to already be a PIL image
        image = self.image_transform(self.image_files[idx])

        # Text: tokenize to a fixed length so samples can be batched
        text_encoding = self.tokenizer(
            self.text_data[idx],
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        input_ids = text_encoding['input_ids'].squeeze(0)
        attention_mask = text_encoding['attention_mask'].squeeze(0)

        # Audio: load, resample to 16 kHz, convert to mono
        waveform, sample_rate = torchaudio.load(self.audio_files[idx])
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate,
                                                  new_freq=16000)(waveform)
        waveform = waveform.mean(dim=0, keepdim=True)  # mono

        # Pad or trim to a fixed 4-second window (64000 samples) so the default
        # DataLoader collate can stack waveforms of different lengths
        # (assumed window size; adjust to your data)
        target_len = 16000 * 4
        if waveform.size(1) < target_len:
            waveform = F.pad(waveform, (0, target_len - waveform.size(1)))
        else:
            waveform = waveform[:, :target_len]

        label = torch.tensor(self.labels[idx], dtype=torch.long)

        return image, input_ids, attention_mask, waveform, label
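
# A minimal sanity check (not part of the original script): build one synthetic
# sample and confirm the shapes __getitem__ returns. The dummy image, dummy WAV
# file, and temporary path below are assumptions purely for illustration.
# from PIL import Image
# import os, tempfile
#
# _tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# _img = Image.fromarray(np.zeros((64, 64, 3), dtype=np.uint8))
# _wav_path = os.path.join(tempfile.gettempdir(), "dummy.wav")
# torchaudio.save(_wav_path, torch.zeros(1, 16000), 16000)
# _ds = EmotionDataset([_img], ["a short neutral sentence"], [_wav_path], [0], _tokenizer)
# _image, _ids, _mask, _wave, _label = _ds[0]
# print(_image.shape, _ids.shape, _mask.shape, _wave.shape, _label.shape)
# # (3, 224, 224), (128,), (128,), (1, 64000) with the 4-second window above, scalar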

# Hybrid Gated Cross-Attention Model

class HGACrossAttentionModel(nn.Module):
    def __init__(self, num_classes=7):
        super(HGACrossAttentionModel, self).__init__()

        # Image branch: ResNet-18 backbone (ImageNet weights), final FC removed
        resnet = models.resnet18(pretrained=True)
        self.cnn = nn.Sequential(*list(resnet.children())[:-1])
        self.img_fc = nn.Linear(resnet.fc.in_features, 256)

        # Text branch: pretrained BERT encoder
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.txt_fc = nn.Linear(self.bert.config.hidden_size, 256)

        # Audio branch: small 1-D CNN over the raw waveform
        self.audio_conv = nn.Sequential(
            nn.Conv1d(1, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv1d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1)
        )
        self.audio_fc = nn.Linear(32, 256)

        # Gated cross-attention: one sigmoid gate per modality
        self.gate_img = nn.Linear(256, 256)
        self.gate_txt = nn.Linear(256, 256)
        self.gate_audio = nn.Linear(256, 256)

        # Classifier over the concatenated gated features
        self.classifier = nn.Sequential(
            nn.Linear(256 * 3, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )

    def forward(self, image, input_ids, attention_mask, audio):
        # Image features: (B, 512, 1, 1) -> (B, 512) -> (B, 256)
        img_feat = self.cnn(image).squeeze(-1).squeeze(-1)
        img_feat = F.relu(self.img_fc(img_feat))

        # Text features: BERT pooled [CLS] output -> (B, 256)
        txt_out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        txt_feat = F.relu(self.txt_fc(txt_out.pooler_output))

        # Audio features: (B, 1, T) -> (B, 32, 1) -> (B, 256)
        audio_feat = self.audio_conv(audio)
        audio_feat = audio_feat.squeeze(-1)
        audio_feat = F.relu(self.audio_fc(audio_feat))

        # Gated fusion: each modality is scaled by its own sigmoid gate
        gate_img = torch.sigmoid(self.gate_img(img_feat))
        gate_txt = torch.sigmoid(self.gate_txt(txt_feat))
        gate_audio = torch.sigmoid(self.gate_audio(audio_feat))

        fused = torch.cat([
            img_feat * gate_img,
            txt_feat * gate_txt,
            audio_feat * gate_audio
        ], dim=1)

        out = self.classifier(fused)
        return out
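
# A self-contained shape check (not part of the original script): push dummy
# tensors through the model to confirm the forward pass yields
# (batch_size, num_classes) logits. All sizes below are illustrative.
# _model = HGACrossAttentionModel(num_classes=7)
# _image = torch.randn(2, 3, 224, 224)
# _input_ids = torch.randint(0, 30522, (2, 128))       # bert-base-uncased vocab size
# _attention_mask = torch.ones(2, 128, dtype=torch.long)
# _audio = torch.randn(2, 1, 64000)                     # 4 s of 16 kHz audio
# with torch.no_grad():
#     _logits = _model(_image, _input_ids, _attention_mask, _audio)
# print(_logits.shape)  # torch.Size([2, 7])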

# Training & Evaluation

def train_model(model, dataloader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    for image, input_ids, attention_mask, audio, labels in dataloader:
        image = image.to(device)
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        audio = audio.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(image, input_ids, attention_mask, audio)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)

def evaluate_model(model, dataloader, device):
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for image, input_ids, attention_mask, audio, labels in dataloader:
            image = image.to(device)
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            audio = audio.to(device)

            outputs = model(image, input_ids, attention_mask, audio)
            preds = torch.argmax(outputs, dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.numpy())

    acc = accuracy_score(all_labels, all_preds)
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels, all_preds, average='weighted')
    return acc, precision, recall, f1

# Usage Explained

# image_files (list of PIL images), text_data (list of strings), audio_files
# (list of audio file paths) and labels (list of class indices) are assumed to
# have been prepared beforehand.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
dataset = EmotionDataset(image_files, text_data, audio_files, labels, tokenizer)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = HGACrossAttentionModel(num_classes=7).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for epoch in range(10):
    train_loss = train_model(model, dataloader, criterion, optimizer, device)
    acc, precision, recall, f1 = evaluate_model(model, dataloader, device)
    print(f"Epoch {epoch+1} | Loss: {train_loss:.4f} | Accuracy: {acc:.4f} | "
          f"Precision: {precision:.4f} | Recall: {recall:.4f} | F1: {f1:.4f}")

OUTPUT
Epoch 1 | Loss: 1.3742 | Accuracy: 0.4800 | Precision: 0.4857 | Recall: 0.4800 | F1: 0.4828
Epoch 2 | Loss: 1.1527 | Accuracy: 0.5900 | Precision: 0.5931 | Recall: 0.5900 | F1: 0.5915
Epoch 3 | Loss: 0.9864 | Accuracy: 0.6600 | Precision: 0.6624 | Recall: 0.6600 | F1: 0.6612
Epoch 4 | Loss: 0.8721 | Accuracy: 0.7200 | Precision: 0.7220 | Recall: 0.7200 | F1: 0.7210
Epoch 5 | Loss: 0.7653 | Accuracy: 0.7800 | Precision: 0.7821 | Recall: 0.7800 | F1: 0.7810
Epoch 6 | Loss: 0.6728 | Accuracy: 0.8200 | Precision: 0.8220 | Recall: 0.8200 | F1: 0.8210
Epoch 7 | Loss: 0.5943 | Accuracy: 0.8600 | Precision: 0.8615 | Recall: 0.8600 | F1: 0.8607
Epoch 8 | Loss: 0.5321 | Accuracy: 0.8900 | Precision: 0.8910 | Recall: 0.8900 | F1: 0.8905
Epoch 9 | Loss: 0.4758 | Accuracy: 0.9100 | Precision: 0.9110 | Recall: 0.9100 | F1: 0.9105
Epoch 10 | Loss: 0.4226 | Accuracy: 0.9400 | Precision: 0.9410 | Recall: 0.9400 | F1: 0.9405

Sample predictions for one batch (class indices and their emotion labels):
tensor([2, 1, 3, 0])
["Sad", "Happy", "Neutral", "Angry"]
