audtorch.samplers
BucketSampler
class audtorch.samplers.BucketSampler(*args: Any, **kwargs: Any)
Creates batches from ordered data sets.
This sampler iterates over the data sets of concat_dataset and samples sequentially from them. All samples of a batch deliberately originate from the same data set. The next data set is sampled from only once the current one is exhausted. In other words, samples from different buckets are never mixed.
In each epoch, num_batches batches of size batch_sizes are extracted from each data set. If a data set cannot provide the requested number of batches, only its available batches are queued. By default, the data sets (and thus their batches) are iterated over in increasing order of their data set id.
Note
The information in batch_sizes and num_batches refers to datasets at the same index, independently of permuted_order.
Simple Use Case: “Train on data with increasing sequence length”
bucket_id: [0, 1, 2, … end]
batch_sizes: [32, 16, 8, … 2]
num_batches: [None, None, None, … None]
Result: “Extract all batches (None) from all data sets, all of different batch size, and queue them in increasing order of their data set id”
batch_sizes controls the batch size for each data set
num_batches controls the number of batches to extract from each data set
permuted_order controls whether the order in which data sets are iterated over is permuted, or in which specific order they are iterated
shuffle_each_bucket controls whether each data set is shuffled
drop_last controls whether to drop the last samples of a bucket which cannot form a mini-batch
- Parameters
concat_dataset (torch.utils.data.ConcatDataset) – ordered concatenated data set
batch_sizes (list) – batch sizes per data set. Permissible values are unsigned integers
num_batches (list or None, optional) – number of batches per data set. Permissible values are non-negative integers and None. If None, as many batches are extracted as the data set provides. Default: None
permuted_order (bool or list, optional) – whether to permute the order of data set ids in which the respective data set’s batches are queued. If True, data set ids are shuffled; if False, they are not. Alternatively, a custom list of permuted data set ids can be specified. Default: False
shuffle_each_bucket (bool, optional) – whether to shuffle samples in each data set. Setting this to True is recommended. Default: True
drop_last (bool, optional) – controls whether the last samples of a bucket which cannot form a mini-batch should be dropped. Default: False
Example
>>> import numpy as np
>>> import torch
>>> from torch.utils.data import (TensorDataset, ConcatDataset)
>>> from audtorch.datasets.utils import defined_split
>>> from audtorch.samplers import BucketSampler, buckets_of_even_size
>>> data = TensorDataset(torch.randn(100))
>>> lengths = np.random.randint(0, 890, (100,))
>>> split_func = buckets_of_even_size(lengths, num_buckets=3)
>>> subsets = defined_split(data, split_func)
>>> concat_dataset = ConcatDataset(subsets)
>>> batch_sampler = BucketSampler(concat_dataset, 3 * [16])
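To illustrate how batch_sizes, num_batches and drop_last interact, the following is a minimal plain-Python sketch, not the library's implementation. The function name sketch_bucket_batches and its bucket_sizes argument are hypothetical. The sketch assumes positive batch sizes, leaves out shuffling and permuted_order, and assumes that drop_last is applied before num_batches limits the number of batches.

```python
def sketch_bucket_batches(bucket_sizes, batch_sizes, num_batches=None,
                          drop_last=False):
    """Hypothetical helper: list the index batches a bucket sampler could yield.

    bucket_sizes holds the number of samples per data set, in the order in
    which the data sets appear in the concatenated data set.
    """
    if num_batches is None:
        num_batches = [None] * len(bucket_sizes)
    batches = []
    offset = 0  # global index of the first sample of the current bucket
    for size, batch_size, limit in zip(bucket_sizes, batch_sizes, num_batches):
        indices = list(range(offset, offset + size))
        offset += size
        # Cut the bucket into consecutive chunks; batches never span buckets.
        bucket_batches = [indices[i:i + batch_size]
                          for i in range(0, size, batch_size)]
        # Assumption: an incomplete last batch is dropped before limiting.
        if drop_last and bucket_batches and len(bucket_batches[-1]) < batch_size:
            bucket_batches.pop()
        # None means "take every available batch of this bucket".
        if limit is not None:
            bucket_batches = bucket_batches[:limit]
        batches.extend(bucket_batches)
    return batches


# Two buckets with 5 and 4 samples, batch sizes 2 and 3.
print(sketch_bucket_batches([5, 4], [2, 3]))
# [[0, 1], [2, 3], [4], [5, 6, 7], [8]]
print(sketch_bucket_batches([5, 4], [2, 3], drop_last=True))
# [[0, 1], [2, 3], [5, 6, 7]]
```

Each batch contains indices of a single bucket only, which mirrors the statement that samples from different buckets are never mixed.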
buckets_by_boundaries
audtorch.samplers.buckets_by_boundaries(key_values, bucket_boundaries)
Split samples into buckets based on key values using bucket boundaries.
Note
A sample is sorted into bucket \(i\) if its key value satisfies:
\(b_{i-1} \leq \text{key value} < b_i\), where \(b_i\) is the bucket boundary at index \(i\)
- Parameters
- Returns
Key function to use for splitting: \(f(\text{item}) = \text{bucket\_id}\)
- Return type
func
Example
>>> lengths = [288, 258, 156, 99, 47, 13]
>>> boundaries = [80, 150]
>>> split_func = buckets_by_boundaries(lengths, boundaries)
>>> [split_func(i) for i in range(len(lengths))]
[2, 2, 2, 1, 0, 0]
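The boundary rule from the note can be reproduced with a short sketch based on the standard library's bisect module. This is an illustration under assumptions, not the library's code: the name sketch_buckets_by_boundaries is hypothetical, and the sketch sorts the boundaries itself rather than validating them.

```python
from bisect import bisect_right


def sketch_buckets_by_boundaries(key_values, bucket_boundaries):
    """Hypothetical re-implementation of the boundary rule for illustration."""
    boundaries = sorted(bucket_boundaries)

    def key_func(item):
        # bisect_right counts the boundaries that are <= the key value,
        # so a key equal to a boundary falls into the higher bucket.
        return bisect_right(boundaries, key_values[item])

    return key_func


lengths = [288, 258, 156, 99, 47, 13]
split_func = sketch_buckets_by_boundaries(lengths, [80, 150])
bucket_ids = [split_func(i) for i in range(len(lengths))]
print(bucket_ids)  # [2, 2, 2, 1, 0, 0]
```

With two boundaries there are three buckets: key values below 80, key values from 80 up to but not including 150, and key values of 150 or more.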
buckets_of_even_size
audtorch.samplers.buckets_of_even_size(key_values, num_buckets, *, reverse=False)
Split samples into buckets of even size based on key values.
The samples are sorted by increasing (or decreasing) key value. If the number of samples cannot be distributed evenly across the buckets, the first buckets each receive one of the remaining samples.
- Parameters
- Returns
Key function to use for splitting: \(f(\text{item}) = \text{bucket\_id}\)
- Return type
func
Example
>>> lengths = [288, 258, 156, 47, 112, 99, 13]
>>> num_buckets = 4
>>> split_func = buckets_of_even_size(lengths, num_buckets)
>>> [split_func(i) for i in range(len(lengths))]
[3, 2, 2, 0, 1, 1, 0]
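The distribution of the remainder can be traced with a plain-Python sketch. It is an assumption-based illustration rather than the library's implementation: the name sketch_buckets_of_even_size is hypothetical, and treating reverse=True as a descending sort is an assumption of this sketch.

```python
def sketch_buckets_of_even_size(key_values, num_buckets, *, reverse=False):
    """Hypothetical re-implementation of even-size bucketing for illustration."""
    # Item indices ordered by their key value (assumed: reverse sorts descending).
    order = sorted(range(len(key_values)),
                   key=lambda i: key_values[i], reverse=reverse)
    base, remainder = divmod(len(order), num_buckets)
    bucket_of = {}
    start = 0
    for bucket_id in range(num_buckets):
        # The first `remainder` buckets take one extra sample each.
        size = base + (1 if bucket_id < remainder else 0)
        for item in order[start:start + size]:
            bucket_of[item] = bucket_id
        start += size
    return lambda item: bucket_of[item]


lengths = [288, 258, 156, 47, 112, 99, 13]
split_func = sketch_buckets_of_even_size(lengths, 4)
bucket_ids = [split_func(i) for i in range(len(lengths))]
print(bucket_ids)  # [3, 2, 2, 0, 1, 1, 0]
```

Seven samples in four buckets give a base size of one with a remainder of three, so the first three buckets hold two samples each and the last bucket holds one.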