Description
Hi,
I was looking at the implementation of MPerClassSampler, and I noticed the following issue: consecutive batches often overlap in the classes they use. For example, with `batch_size=16` and `m=4`, the first batch might consist of classes `[1,5,3,7]`, while the second might be `[1,9,8,2]`. On small datasets this means examples from class `1` could be seen more often than examples from other classes.
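To make the imbalance concrete, here is a minimal simulation (names and numbers are illustrative, not taken from the library) where each batch independently samples its set of classes, as the current behavior described above does:

```python
import random
from collections import Counter

rng = random.Random(0)
num_classes, classes_per_batch, num_batches = 10, 4, 10

seen = Counter()
for _ in range(num_batches):
    # Each batch draws its classes independently of previous batches,
    # so nothing prevents the same class from reappearing right away.
    seen.update(rng.sample(range(num_classes), classes_per_batch))

# 40 class slots over 10 classes: a balanced scheme would use each class
# exactly 4 times per epoch; independent sampling generally does not.
print(sorted(seen.values()))
```

Over one epoch the per-class counts typically spread around the balanced value instead of hitting it exactly, which is the effect described above.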
I think this can be easily overcome by generating `((length_before_new_iter // batch_size) * m) // num_unique_labels + 1` arrays of unique labels, shuffling each of them, and then concatenating them. The sampler can then take the labels from position `i*m` to `(i+1)*m` for the i-th batch, and be certain that after the epoch, examples from any given class have been seen either `(length_before_new_iter // batch_size) * batch_size // num_unique_labels` or `((length_before_new_iter // batch_size) * batch_size // num_unique_labels) + 1` times, minimizing the initial issue.
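A minimal sketch of the proposed label ordering (function and parameter names are mine, not the library's API; it follows the formulas above, with the i-th batch taking labels `i*m` to `(i+1)*m`):

```python
import random
from collections import Counter

def build_label_order(num_unique_labels, batch_size, m,
                      length_before_new_iter, seed=0):
    """Concatenate independently shuffled copies of the label set, so that
    after an epoch each class has appeared either floor(slots / K) or
    floor(slots / K) + 1 times, where K = num_unique_labels."""
    rng = random.Random(seed)
    num_batches = length_before_new_iter // batch_size
    # Enough shuffled copies to cover all num_batches * m label slots:
    num_copies = (num_batches * m) // num_unique_labels + 1
    order = []
    for _ in range(num_copies):
        block = list(range(num_unique_labels))
        rng.shuffle(block)
        order.extend(block)
    return order[: num_batches * m]

order = build_label_order(num_unique_labels=10, batch_size=16, m=4,
                          length_before_new_iter=160)
counts = Counter(order)
floor_count = len(order) // 10
# Every class appears either floor_count or floor_count + 1 times:
assert all(c in (floor_count, floor_count + 1) for c in counts.values())
```

Within a single batch a class could still repeat if the `i*m` slice happens to straddle a boundary between two shuffled copies, but the per-epoch counts come out as balanced as the formula above describes.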
I'm pretty certain the difference in performance would be minimal, if any. Does this make sense?