Stratified K-fold cross-validation in MATLAB

My implementation of stratified K-fold cross-validation, similar to c = cvpartition(group,'KFold',k) from MATLAB's Statistics and Machine Learning Toolbox.
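For comparison, here is a minimal sketch of the toolbox call mentioned above. It assumes the Statistics and Machine Learning Toolbox is installed; the label vector y and its 50/30/20 class sizes are made-up example data, not part of the original post.

```matlab
% Assumed example data: 100 numeric labels in three classes (50/30/20)
y = [ones(50,1); 2*ones(30,1); 3*ones(20,1)];
k = 5;

% Passing the labels as the grouping variable requests stratified folds
c = cvpartition(y, 'KFold', k);

classes = unique(y);
for i = 1:c.NumTestSets
    testIdx = test(c, i);                 % logical mask of the i-th test fold
    % count how many samples of each class landed in this test fold
    counts = arrayfun(@(g) nnz(y(testIdx) == g), classes);
    fprintf('fold %d: %s\n', i, mat2str(counts'));
end
```

The masks returned by training(c, i) and test(c, i) index the data in its original order, so no reshuffling of X or y is needed.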
<pre>function [X, partition, y] = KfoldCVBalance(X, y, k)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Author: Pree Thiengburanathum
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Description:
% Ensures that the training, testing, and validation sets have similar
% class proportions (e.g., 20 classes). Stratified sampling gives the
% analyst more control over the sampling process.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Input:
% X - dataset (one sample per row)
% y - class label of each sample (numeric vector)
% k - number of folds
%
% Output:
% X - shuffled dataset
% partition - fold index (1..k) of each row of the shuffled X
% y - class labels in the same shuffled order as X
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
n = size(X, 1);
partition = zeros(n, 1);
% shuffle the dataset
idx = randperm(n);
X = X(idx, :);
y = y(idx);
% find the unique classes
group = unique(y);
nGroup = numel(group);
% find the maximum number of samples per class
nmax = 0;
for i=1:nGroup
    nmax = max(nmax, nnz(y == group(i)));
end
% create fold indices
foldIndices = zeros(nGroup, nmax);
for i=1:nGroup
    idx = find(y == group(i));
    foldIndices(i, 1:numel(idx)) = idx;
end
% compute how many samples each class contributes to a fold
foldSize = zeros(nGroup, 1);
for i=1:nGroup
    % number of samples in this class (nonzero entries of its row)
    numElement = nnz(foldIndices(i,:));
    % number of samples of this class per fold
    foldSize(i) = floor(numElement/k);
end
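% assign foldSize(j) samples of class j to each fold
ptr = ones(nGroup, 1);
for i=1:k
    for j=1:nGroup
        % take the next foldSize(j) indices of class j; the range ends at
        % ptr(j)+foldSize(j)-1 so consecutive folds do not overlap and the
        % index never runs past the end of the row
        idx = foldIndices(j, ptr(j):(ptr(j)+foldSize(j)-1));
        partition(idx) = i;
        ptr(j) = ptr(j)+foldSize(j);
    end
end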
% assign the remaining (unassigned) samples to the last fold
partition(partition == 0) = k;
% check class balance for each fold
for i=1:k
    % class labels of the samples assigned to fold i
    fold = y(partition == i);
    disp(['fold# ', int2str(i), ' has ', int2str(numel(fold))]);
    for j=1:nGroup
        percentage = (nnz(fold == group(j))/numel(fold)) * 100;
        disp(['class# ', int2str(j), ' = ', num2str(percentage), '%']);
    end
    disp(' ');
end
end % end function</pre>
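The function above places every leftover sample in the last fold, which can make that fold noticeably larger. As a possible alternative, here is a minimal sketch that deals the samples of each class to the folds round-robin. This is a different technique from the one in the listing; stratifiedFolds is a hypothetical helper name and the example labels are made up. It assumes y is a numeric column vector, and since it defines a local function at the end of a script it needs MATLAB R2016b or later.

```matlab
% Assumed example data: 103 numeric labels in three classes (53/30/20)
y = [ones(53,1); 2*ones(30,1); 3*ones(20,1)];
k = 5;
fold = stratifiedFolds(y, k);

% Per-fold class counts: rows are folds, columns are classes 1..3
counts = accumarray([fold, y], 1, [k, 3]);
disp(counts);

function fold = stratifiedFolds(y, k)
% Hypothetical helper: returns a fold index in 1..k for each sample of y,
% dealing the samples of every class to the folds in round-robin order.
fold = zeros(numel(y), 1);
classes = unique(y);
for c = 1:numel(classes)
    idx = find(y == classes(c));
    idx = idx(randperm(numel(idx)));          % shuffle within the class
    fold(idx) = mod(0:numel(idx)-1, k) + 1;   % fold numbers 1,2,...,k,1,2,...
end
end
```

With this scheme the number of samples of any one class differs by at most one between folds, and the data does not have to be reordered, because fold refers to the original row order of y.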

 
