-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathSDM.m
More file actions
63 lines (56 loc) · 2 KB
/
Copy pathSDM.m
File metadata and controls
63 lines (56 loc) · 2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
function [accuracy] = SDM(sourceFile,targetFile, testFile)
% SDM  Train SDM on source + target training data and evaluate it on test data.
%
%   accuracy = SDM(sourceFile, targetFile, testFile)
%
% Input:
%   sourceFile - source domain training data (e.g. source.arff)
%   targetFile - target domain training data (e.g. target.arff)
%   testFile   - target domain testing data (e.g. test.arff)
%
% Output:
%   accuracy - final accuracy, as returned by getAccuracy for the
%              test-set predictions
%
% Requires prepareData, multiClassSDM and getAccuracy, which are expected
% to live in the 'utils' folder next to this file.

%% Set path for utils files
% Build the path from this file's own location with fullfile so it works on
% every platform and from any working directory (a hard-coded '.\utils'
% only resolves on Windows when the cwd is this folder).
utilsDir = fullfile(fileparts(mfilename('fullpath')), 'utils');
addpath(utilsDir);

%% Prepare data by reading *.arff files and converting them into MATLAB format
fprintf('SDM starts....\nPreparing data...\n');
[Xs, Ys,target_data, test_data] = prepareData(sourceFile,targetFile, testFile);
nt=size(test_data.test_labels,1);   % number of test samples (rows of the label array)

%% Set parameters for SVM
param.C = 10;
param.Cu = 1;                 % Cu should be less than C
param.Cu_max = 10*param.Cu;   % upper value for Cu, set to 10x Cu
param.rho = 10;               % add at most rho patterns at each iteration
param.max_iter = 100;
param.max_unl_num = 5;
param.kernel_type ='gaussian'; % 'gaussian' or 'linear';

%% Set options for MMD and Manifold
% Requested subspace dimension (default 20), limited by the size of both
% the source data and the target training features.
dd = 20;
dd = limitSubspaceDim(dd, Xs);
dd = limitSubspaceDim(dd, target_data.target_features);
if dd < 1
    % Clamping can reach 0 or below for degenerate (e.g. single-row or
    % empty) inputs; fail here with a clear message instead of deep inside
    % the downstream solver.
    error('SDM:invalidSubspaceDim', ...
        'Input data too small: derived subspace dimension is %d (must be >= 1).', dd);
end
options.d = dd; % default 20
options.rho = 1.0;
options.p = 10;
options.lambda = 10.0;
options.eta = 0.1;
options.T = 10;

%% Find number of classes and hypotheses
CV=unique(Ys);          % distinct class labels
Nc=length(CV);          % number of classes
H=Nc*(Nc-1)/2;          % number of pairwise (one-vs-one) hypotheses

%% 1-to-1 multiclass classification
fprintf('Running multiClassSDM...\n');
[predictions] = multiClassSDM(Xs, Ys,target_data, test_data, param, options);

%% Calculate classification accuracy on the test data
fprintf('Calculating classification accuracy...\n');
[accuracy] = getAccuracy(test_data.test_labels, predictions, H, nt, CV);
end

function dd = limitSubspaceDim(dd, X)
% limitSubspaceDim  Reduce subspace dimension dd so it fits data matrix X.
% Whenever dd exceeds the number of rows (or columns) of X, it is set to
% that count minus one. The row check runs first, and the column check sees
% the already-reduced value.
% NOTE(review): the rule is not monotonic at the boundary. With 19 rows dd
% becomes 18, but with exactly 20 rows dd stays 20. It is kept unchanged
% here to preserve existing results; confirm whether ">=" was intended.
[numRows, numCols] = size(X);
if dd > numRows
    dd = numRows - 1;
end
if dd > numCols
    dd = numCols - 1;
end
end