audtorch.collate
Collate functions manipulate and merge a list
of samples to form a mini-batch, see torch.utils.data.DataLoader.
An example use case is batching variable-length sequences,
which requires padding each sample to the maximum length in the batch.
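The padding step mentioned above can be illustrated without any dependencies. The following is a minimal sketch on plain Python lists; `pad_batch` is a hypothetical helper written for this illustration and is not part of audtorch.

```python
def pad_batch(sequences, pad_value=0):
    """Pad every sequence to the length of the longest one in the batch.

    Hypothetical helper for illustration only, not an audtorch function.
    """
    # Record the original lengths before padding
    lengths = [len(seq) for seq in sequences]
    max_len = max(lengths)
    # Append pad_value until each sequence reaches max_len
    padded = [list(seq) + [pad_value] * (max_len - len(seq)) for seq in sequences]
    return padded, lengths


batch = [[1, 2, 3], [4, 5], [6]]
padded, lengths = pad_batch(batch)
print(padded)   # [[1, 2, 3], [4, 5, 0], [6, 0, 0]]
print(lengths)  # [3, 2, 1]
```

Returning the original lengths alongside the padded data lets later processing ignore the padded positions.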
Collation
class audtorch.collate.Collation
Abstract interface for collation classes.
All other collation classes should subclass it. All subclasses should override __call__, which executes the actual collate function.
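As a rough sketch of the subclassing pattern described above, the snippet below defines a stand-in base class and one subclass that overrides `__call__`. The stand-in `Collation` only mirrors the documented interface so the example runs on its own; it is not audtorch's source code, and `PairUp` is a hypothetical name.

```python
class Collation:
    """Stand-in for the abstract interface (assumption, not audtorch's code)."""

    def __call__(self, batch):
        raise NotImplementedError("Subclasses must override __call__")


class PairUp(Collation):
    """Hypothetical subclass: groups features and targets into two lists."""

    def __call__(self, batch):
        # Each sample is assumed to be a (features, target) pair
        features = [sample[0] for sample in batch]
        targets = [sample[1] for sample in batch]
        return features, targets


collate_fn = PairUp()
features, targets = collate_fn([("a", 0), ("b", 1)])
print(features, targets)  # ['a', 'b'] [0, 1]
```

Because instances are callable, such an object can be passed as the `collate_fn` argument of torch.utils.data.DataLoader.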
Seq2Seq
class audtorch.collate.Seq2Seq(sequence_dimensions, *, batch_first=None, pad_values=[0, 0], sort_sequences=True)
Pads mini-batches to the longest contained sequence for seq2seq purposes.
This class pads features and targets to the longest sequence in the batch. Before padding, length information is extracted from them.
Note
The tensors can be sorted in descending order of features’ lengths by enabling sort_sequences. This anticipates the requirements of torch.nn.utils.rnn.pack_padded_sequence(), which is used by recurrent layers.
- sequence_dimensions holds the sequence dimension of features and targets
- batch_first controls the output shape of features and targets
- pad_values controls the values to pad features (targets) with
- sort_sequences controls whether sequences are sorted in descending order of features’ lengths
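The sorting behaviour described in the note can be sketched on plain Python data. This is only an illustration of ordering samples by descending feature length; `sort_by_feature_length` is a hypothetical helper, and how Seq2Seq implements the sorting internally is not shown here.

```python
def sort_by_feature_length(batch):
    """Sort (features, targets) pairs by descending feature length.

    Hypothetical helper for illustration only, not an audtorch function.
    """
    return sorted(batch, key=lambda sample: len(sample[0]), reverse=True)


batch = [([1, 2], "ab"), ([1, 2, 3, 4], "abcd"), ([1, 2, 3], "abc")]
sorted_batch = sort_by_feature_length(batch)
# Lengths are extracted after sorting so they line up with the samples
feats_lengths = [len(features) for features, _ in sorted_batch]
print(feats_lengths)  # [4, 3, 2]
```

Features and targets are sorted together so that each target stays paired with its features. With its default `enforce_sorted=True`, torch.nn.utils.rnn.pack_padded_sequence() expects sequences in exactly this decreasing-length order.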
- Parameters
sequence_dimensions (list of ints) – indices representing dimension of sequence in feature and target tensors. Position 0 represents sequence dimension of features, position 1 represents sequence dimension of targets. Negative indexing is permitted
batch_first (bool or None, optional) – determines output shape of collate function. If None, original shape of features and targets is kept with dimension of batch size prepended. See Shape for more information. Default: None
pad_values (list, optional) – values to pad shorter sequences with. Position 0 represents value of features, position 1 represents value of targets. Default: [0, 0]
sort_sequences (bool, optional) – option whether to sort sequences in descending order of features’ lengths. Default: True
- Shape:
Input: \((*, S, *)\), where \(*\) can be any number of further dimensions except \(N\) which is the batch size, and where \(S\) is the sequence dimension.
Output:
features:
- \((N, *, S, *)\) if batch_first is None, i.e. the original input shape with \(N\) prepended which is the batch size
- \((N, S, *, *)\) if batch_first is True
- \((S, N, *, *)\) if batch_first is False
feats_lengths: \((N,)\)
targets: analogous to features
tgt_lengths: analogous to feats_lengths
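To make the shape rules above concrete, here is a small sketch that computes the expected output shape from a padded input shape. `output_shape` is a hypothetical helper, and it assumes that the remaining dimensions keep their original relative order when the sequence dimension is moved.

```python
def output_shape(input_shape, seq_dim, batch_size, batch_first=None):
    """Return the collated shape for one padded tensor of the batch.

    Hypothetical helper for illustration only, not an audtorch function.
    """
    seq_dim = seq_dim % len(input_shape)  # resolve negative indices
    if batch_first is None:
        # Original shape with the batch size N prepended: (N, *, S, *)
        return (batch_size, *input_shape)
    seq_len = input_shape[seq_dim]
    # Remaining dimensions, assumed to keep their original order
    rest = tuple(d for i, d in enumerate(input_shape) if i != seq_dim)
    if batch_first:
        return (batch_size, seq_len, *rest)  # (N, S, *, *)
    return (seq_len, batch_size, *rest)      # (S, N, *, *)


print(output_shape((161, 223), -1, 2))                     # (2, 161, 223)
print(output_shape((161, 223), -1, 2, batch_first=True))   # (2, 223, 161)
print(output_shape((161, 223), -1, 2, batch_first=False))  # (223, 2, 161)
```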
Example
>>> import torch
>>> from audtorch.collate import Seq2Seq
>>> # data format: FS = (feature dimension, sequence dimension)
>>> batch = [[torch.zeros(161, 108), torch.zeros(10)],
...          [torch.zeros(161, 223), torch.zeros(12)]]
>>> collate_fn = Seq2Seq([-1, -1], batch_first=None)
>>> features = collate_fn(batch)[0]
>>> list(features.shape)
[2, 161, 223]