audtorch.datasets

Audio data sets.
AudioSet

class audtorch.datasets.AudioSet(*args: Any, **kwargs: Any)

A large-scale data set of manually annotated audio events.
Open and publicly available data set of audio events from Google: https://research.google.com/audioset/
License: CC BY 4.0
The categories corresponding to an audio signal are returned as a list, starting with those included in the top hierarchy of the AudioSet ontology, followed by those from the second hierarchy and then all other categories in a random order.
The signals to be returned can be limited by excluding or including only certain categories. This is achieved by first including only the desired categories, estimating all of their parent categories, and then applying the exclusion.
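For instance, a hedged sketch (root path illustrative) that keeps all animal sounds except pets:

>>> # include the Animal category, then exclude one of its subcategories
>>> data = AudioSet(root='/data/AudioSet', include=['Animal'],
...                 exclude=['Domestic animals, pets'])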
transform – controls the input transform
target_transform – controls the target transform
files – controls the audio files of the data set
targets – controls the corresponding targets
sampling_rate – holds the sampling rate of the returned data
original_sampling_rate – holds the sampling rate of the audio files of the data set
- Parameters
root (str) – root directory of dataset
csv_file (str, optional) – name of a CSV file located in root. Can be one of balanced_train_segments.csv, unbalanced_train_segments.csv, eval_segments.csv. Default: balanced_train_segments.csv
include (list of str, optional) – list of categories to include. If None all categories are included. Default: None
exclude (list of str, optional) – list of categories to exclude. If None no category is excluded. Default: None
transform (callable, optional) – function/transform applied on the signal. Default: None
target_transform (callable, optional) – function/transform applied on the target. Default: None
AudioSet ontology categories of the two top hierarchies:
Human sounds
  |-Human voice
  |-Whistling
  |-Respiratory sounds
  |-Human locomotion
  |-Digestive
  |-Hands
  |-Heart sounds, heartbeat
  |-Otoacoustic emission
  \-Human group actions
Animal
  |-Domestic animals, pets
  |-Livestock, farm animals, working animals
  \-Wild animals
Music
  |-Musical instrument
  |-Music genre
  |-Musical concepts
  |-Music role
  \-Music mood
Sounds of things
  |-Vehicle
  |-Engine
  |-Domestic sounds, home sounds
  |-Bell
  |-Alarm
  |-Mechanisms
  |-Tools
  |-Explosion
  |-Wood
  |-Glass
  |-Liquid
  |-Miscellaneous sources
  \-Specific impact sounds
Natural sounds
  |-Wind
  |-Thunderstorm
  |-Water
  \-Fire
Source-ambiguous sounds
  |-Generic impact sounds
  |-Surface contact
  |-Deformable shell
  |-Onomatopoeia
  |-Silence
  \-Other sourceless
Channel, environment and background
  |-Acoustic environment
  |-Noise
  \-Sound reproduction
Warning
Some of the recordings in AudioSet were captured with mono and others with stereo input. The user must be careful to handle this, e.g. by using a transform to adjust the number of channels.
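For example, a minimal sketch that forces mono signals, assuming the Downmix transform from audtorch.transforms takes the desired number of channels (root path illustrative):

>>> from audtorch import transforms
>>> data = AudioSet(root='/data/AudioSet',
...                 transform=transforms.Downmix(1))  # downmix to one channel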
Example
>>> import sounddevice as sd
>>> data = AudioSet(root='/data/AudioSet', include=['Thunderstorm'])
>>> print(data)
Dataset AudioSet
    Number of data points: 73
    Root Location: /data/AudioSet
    Sampling Rate: 16000Hz
    CSV file: balanced_train_segments.csv
    Included categories: ['Thunderstorm']
>>> signal, target = data[4]
>>> target
['Natural sounds', 'Thunderstorm', 'Water', 'Rain', 'Thunder']
>>> sd.play(signal.transpose(), data.sampling_rate)
EmoDB

class audtorch.datasets.EmoDB(*args: Any, **kwargs: Any)

EmoDB data set.
Open and publicly available data set of acted emotions: http://www.emodb.bilderbar.info/navi.html
EmoDB is a small audio data set collected in an anechoic chamber at the Berlin Institute of Technology. It contains 5 male and 5 female speakers, consists of 10 unique sentences, and is annotated for 6 emotions plus a neutral state. The spoken language is German.
- Parameters
root – root directory of dataset
transform – function/transform applied on the signal
target_transform – function/transform applied on the target
Note
When using the EmoDB data set in your research, please cite the following publication: [BPR+05].
Example
>>> import sounddevice as sd
>>> data = EmoDB('/data/emodb')
>>> print(data)
Dataset EmoDB
    Number of data points: 465
    Root Location: /data/emodb
    Sampling Rate: 16000Hz
    Labels: emotion
>>> signal, target = data[0]
>>> target
'A'
>>> sd.play(signal.transpose(), data.sampling_rate)
LibriSpeech

class audtorch.datasets.LibriSpeech(*args: Any, **kwargs: Any)

LibriSpeech speech data set.
Open and publicly available data set of voices from OpenSLR: http://www.openslr.org/12/
License: CC BY 4.0.
LibriSpeech contains several hundred hours of English speech with corresponding transcriptions in capital letters without punctuation.
It is split into different subsets according to the WER level achieved when performing speech recognition on the speakers. The subsets are: train-clean-100, train-clean-360, train-other-500, dev-clean, dev-other, test-clean, test-other.
root – holds the data set's location
transform – controls the input transform
target_transform – controls the target transform
files – controls the audio files of the data set
labels – controls the corresponding labels
sampling_rate – holds the sampling rate of the data set
In addition, the following class attributes are available
all_sets – holds the names of the different pre-defined sets
urls – holds the download links of the different sets
- Parameters
root (str) – root directory of data set
sets (str or list, optional) – desired sets of LibriSpeech. Mutually exclusive with dataframe. Default: None
dataframe (pandas.DataFrame, optional) – pandas data frame containing columns audio_path (relative to root) and transcription. It can be used to pre-select files based on meta information, e.g. sequence length. Mutually exclusive with sets. Default: None
transform (callable, optional) – function/transform applied on the signal. Default: None
target_transform (callable, optional) – function/transform applied on the target. Default: None
download (bool, optional) – download data set to root directory if not present. Default: False
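A hedged sketch of the dataframe option, assuming df is a pre-existing data frame with the audio_path and transcription columns described above:

>>> # keep only short utterances, then build the data set from the selection
>>> short = df[df['transcription'].str.len() < 100]
>>> data = LibriSpeech(root='/data/LibriSpeech', dataframe=short)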
Example
>>> import sounddevice as sd
>>> data = LibriSpeech(root='/data/LibriSpeech', sets='dev-clean')
>>> print(data)
Dataset LibriSpeech
    Number of data points: 2703
    Root Location: /data/LibriSpeech
    Sampling Rate: 16000Hz
    Sets: dev-clean
>>> signal, label = data[8]
>>> label
AS FOR ETCHINGS THEY ARE OF TWO KINDS BRITISH AND FOREIGN
>>> sd.play(signal.transpose(), data.sampling_rate)
MozillaCommonVoice

class audtorch.datasets.MozillaCommonVoice(*args: Any, **kwargs: Any)

Mozilla Common Voice speech data set.
Open and publicly available data set of voices from Mozilla: https://voice.mozilla.org/en/datasets
License: CC-0 (public domain)
Mozilla Common Voice includes the labels text, up_votes, down_votes, age, gender, accent, and duration. You can select one of those labels, which is then returned as a string by the data set, or you can specify a list of labels, in which case the data set returns a dictionary containing them. The default label is text.
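A hedged sketch of selecting several labels at once (root path as in the example below):

>>> data = MozillaCommonVoice('/data/MozillaCommonVoice/cv_corpus_v1',
...                           label_type=['text', 'age'])
>>> signal, target = data[0]  # target is a dictionary keyed by label name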
root – holds the data set's location
transform – controls the input transform
target_transform – controls the target transform
files – controls the audio files of the data set
targets – controls the corresponding targets
sampling_rate – holds the sampling rate of the returned data
original_sampling_rate – holds the sampling rate of the audio files of the data set
In addition, the following class attribute is available
url – holds the download link of the data set
- Parameters
root (str) – root directory of data set, where the CSV files are located, e.g. /data/MozillaCommonVoice/cv_corpus_v1
csv_file (str, optional) – name of a CSV file from the root folder. No absolute path is possible. You are most probably interested in cv-valid-train.csv, cv-valid-dev.csv, and cv-valid-test.csv. Default: cv-valid-train.csv.
label_type (str or list of str, optional) – one of text, up_votes, down_votes, age, gender, accent, duration. Or a list of any combination of those. Default: text
transform (callable, optional) – function/transform applied on the signal. Default: None
target_transform (callable, optional) – function/transform applied on the target. Default: None
download (bool, optional) – download data set if not present. Default: False
Note
The Mozilla Common Voice data set is constantly growing. If you choose to download it, you will always get the latest version. If you require reproducibility of your results, make sure to store a snapshot of the version you used.
Example
>>> import sounddevice as sd
>>> data = MozillaCommonVoice('/data/MozillaCommonVoice/cv_corpus_v1')
>>> print(data)
Dataset MozillaCommonVoice
    Number of data points: 195776
    Root Location: /data/MozillaCommonVoice/cv_corpus_v1
    Sampling Rate: 48000Hz
    Labels: text
    CSV file: cv-valid-train.csv
>>> signal, target = data[0]
>>> target
'learn to recognize omens and follow them the old king had said'
>>> sd.play(signal.transpose(), data.sampling_rate)
SpeechCommands

class audtorch.datasets.SpeechCommands(*args: Any, **kwargs: Any)

Data set of spoken words designed for keyword spotting tasks.
Speech Commands V2 publicly available from Google: http://download.tensorflow.org/data/speech_commands_v0.02.tar.gz
License: CC BY 4.0
- Parameters
root (str) – root directory of data set, where the CSV files are located, e.g. /data/speech_commands_v0.02
train (bool, optional) – if True, use the training partition of the data set; if False, the test split. Default: False
download (bool, optional) – download the data set to root if it is not already available. Default: False
include (str or list of str, optional) – commands to include as 'recognised' words. Options: '10cmd', 'full'. A custom data set can be defined using a list of command words, e.g. ['stop', 'go'] (see the sketch after this list). Words that are not in the include list are treated as unknown words. Default: '10cmd'
silence (bool, optional) – include a 'silence' class composed of background noise (note: use a RandomCrop transform when training). Default: True
transform (callable, optional) – function/transform applied on the signal. Default: None
target_transform (callable, optional) – function/transform applied on the target. Default: None
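As referenced above, a minimal sketch of a custom keyword set (root path as in the example below); all other words are then treated as unknown:

>>> data = SpeechCommands(root='/data/speech_commands_v0.02',
...                       include=['stop', 'go'])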
Example
>>> import sounddevice as sd
>>> data = SpeechCommands(root='/data/speech_commands_v0.02')
>>> print(data)
Dataset SpeechCommands
    Number of data points: 97524
    Root Location: /data/speech_commands_v0.02
    Sampling Rate: 16000Hz
>>> signal, target = data[4]
>>> target
'right'
>>> sd.play(signal.transpose(), data.sampling_rate)
VoxCeleb1

class audtorch.datasets.VoxCeleb1(*args: Any, **kwargs: Any)

VoxCeleb1 data set.
Open and publicly available data set of voices from University of Oxford: http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1.html
VoxCeleb1 is a large audio-visual data set consisting of short clips of human speech extracted from YouTube interviews with celebrities. It is free for commercial and research purposes.
Licence: CC BY-SA 4.0
transform – controls the input transform
target_transform – controls the target transform
files – controls the audio files of the data set
targets – controls the corresponding targets
sampling_rate – holds the sampling rate of the data set
In addition, the following class attributes are available:
url – holds its URL
- Parameters
root (str) – root directory of dataset
partition (str, optional) – name of the data partition to use. Choose one of train, dev, test or None. If None is given, then the whole data set will be returned. Default: train
transform (callable, optional) – function/transform applied on the signal. Default: None
target_transform (callable, optional) – function/transform applied on the target. Default: None
Note
This data set will only work if the identification file is downloaded as is from the official homepage. Please open it in your browser and copy and paste its contents into a file on your computer.
To download the data set, go to http://www.robots.ox.ac.uk/~vgg/data/voxceleb/ and fill in the form to request a password. Get the Audio Files that the owners provide.
When using the VoxCeleb1 data set in your research, please cite the following publication: [NCZ17].
Example
>>> import sounddevice as sd
>>> data = VoxCeleb1('/data/voxceleb1')
>>> print(data)
Dataset VoxCeleb1
    Number of data points: 138361
    Root Location: /data/voxceleb1
    Sampling Rate: 16000Hz
    Labels: speaker ID
>>> signal, target = data[0]
>>> target
'id10003'
>>> sd.play(signal.transpose(), data.sampling_rate)
WhiteNoise

class audtorch.datasets.WhiteNoise(*args: Any, **kwargs: Any)

White noise data set.
The white noise is generated by numpy.random.standard_normal.
duration – controls the duration of the noise signal
sampling_rate – holds the sampling rate of the returned data
mean – controls the mean of the underlying distribution
stdev – controls the standard deviation of the underlying distribution
transform – controls the input transform
target_transform – controls the target transform
As white noise does not really have a sampling rate, you can use the following attribute to change it instead of resampling:

original_sampling_rate – controls the sampling rate of the data set
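A hedged sketch, assuming the attribute can simply be assigned, as "controls" suggests:

>>> data = WhiteNoise(duration=1, sampling_rate=44100)
>>> data.original_sampling_rate = 16000  # change the rate instead of resampling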
- Parameters
duration (float) – duration of the noise signal in seconds
sampling_rate (int, optional) – sampling rate in Hz. Default: 44100
mean (float, optional) – mean of underlying distribution. Default: 0
stdev (float, optional) – standard deviation of underlying distribution. Default: 1
transform (callable, optional) – function/transform applied on the signal. Default: None
target_transform (callable, optional) – function/transform applied on the target. Default: None
Note
Even though WhiteNoise has an infinite number of entries, its length is 1, as repeated calls with the same index return different signals.
Example
>>> import sounddevice as sd
>>> data = WhiteNoise(duration=1, sampling_rate=44100)
>>> print(data)
Dataset WhiteNoise
    Number of data points: Inf
    Signal length: 1s
    Sampling Rate: 44100Hz
    Label (str): noise type
>>> signal, target = data[0]
>>> target
'white noise'
>>> sd.play(signal.transpose(), data.sampling_rate)
Base
This section contains a mix of generic data sets that are useful for a wide variety of cases and can be used as base classes for other data sets.
AudioDataset

class audtorch.datasets.AudioDataset(*args: Any, **kwargs: Any)

Basic audio signal data set.
This data set can be used if you have a list of files and a list of corresponding targets.
In addition, this class is a great starting point to inherit from if you wish to build your own data set.
transform – controls the input transform
target_transform – controls the target transform
files – controls the audio files of the data set
targets – controls the corresponding targets
duration – controls audio duration for every file in seconds
offset – controls audio offset for every file in seconds
sampling_rate – holds the sampling rate of the returned data
original_sampling_rate – holds the sampling rate of the audio files of the data set
- Parameters
files (list) – list of files
targets (list) – list of targets
sampling_rate (int) – sampling rate in Hz of the data set
root (str, optional) – root directory of dataset. Default: None
transform (callable, optional) – function/transform applied on the signal. Default: None
target_transform (callable, optional) – function/transform applied on the target. Default: None
Example
>>> data = AudioDataset(files=['speech.wav', 'noise.wav'],
...                     targets=['speech', 'noise'],
...                     sampling_rate=8000,
...                     root='/data')
>>> print(data)
Dataset AudioDataset
    Number of data points: 2
    Root Location: /data
    Sampling Rate: 8000Hz
>>> signal, target = data[0]
>>> target
'speech'
extra_repr()

Set the extra representation of the data set.
To print customized extra information, you should reimplement this method in your own data set. Both single-line and multi-line strings are acceptable.
The extra information will be shown after the sampling rate entry.
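A hedged sketch of both ideas, inheriting from AudioDataset and reimplementing extra_repr (class name, files, and targets are made up):

>>> from audtorch.datasets import AudioDataset
>>> class YesNoDataset(AudioDataset):
...     def __init__(self, root):
...         files = ['yes.wav', 'no.wav']  # placeholder file list
...         targets = ['yes', 'no']        # placeholder targets
...         super().__init__(files=files, targets=targets,
...                          sampling_rate=16000, root=root)
...     def extra_repr(self):
...         return 'Labels: yes/no'  # shown after the sampling rate entry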
PandasDataset

class audtorch.datasets.PandasDataset(*args: Any, **kwargs: Any)

Data set from pandas.DataFrame.
Create a data set by accessing the file locations and corresponding labels through a pandas.DataFrame.
You have to specify which labels of the data set you want as target by the names of the corresponding columns in the data frame. If you select a single column, the label is returned directly in its corresponding data type; if you specify a list of columns, the data set returns a dictionary containing the labels.
The filenames of the corresponding audio files have to be specified with absolute paths. If they are relative to a folder, you have to use the root argument to specify that folder.

transform – controls the input transform
target_transform – controls the target transform
files – controls the audio files of the data set
targets – controls the corresponding targets
sampling_rate – holds the sampling rate of the returned data
original_sampling_rate – holds the sampling rate of the audio files of the data set
column_labels – holds the name of the label columns
- Parameters
df (pandas.DataFrame) – data frame with filenames and labels
sampling_rate (int) – sampling rate in Hz of the data set
root (str, optional) – root directory added before the files listed in the data frame. Default: None
column_labels (str or list of str, optional) – name of data frame column(s) containing the desired labels. Default: label
column_filename (str, optional) – name of column holding the file names. Default: file
column_start (str, optional) – name of column holding start of audio in the corresponding file in seconds. Default: None
column_end (str, optional) – name of column holding end of audio in the corresponding file in seconds. Default: None
transform (callable, optional) – function/transform applied on the signal. Default: None
target_transform (callable, optional) – function/transform applied on the target. Default: None
Example
>>> data = PandasDataset(root='/data',
...                      df=dataset_dataframe,
...                      sampling_rate=44100,
...                      column_labels='age')
>>> print(data)
Dataset AudioDataset
    Number of data points: 120
    Root Location: /data
    Sampling Rate: 44100Hz
    Label: age
>>> signal, target = data[0]
>>> target
31
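A hedged variant of the example with a list of label columns, which turns the target into a dictionary (assumes the frame also holds a gender column):

>>> data = PandasDataset(root='/data',
...                      df=dataset_dataframe,
...                      sampling_rate=44100,
...                      column_labels=['age', 'gender'])
>>> signal, target = data[0]  # e.g. {'age': 31, 'gender': ...}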
CsvDataset

class audtorch.datasets.CsvDataset(*args: Any, **kwargs: Any)

Data set from CSV files.
Create a data set by reading the file locations and corresponding labels from a CSV file.
You have to specify which labels you want as the target of the data set by the names of the corresponding columns in the CSV file. If you select a single column, the target is returned directly in its corresponding data type; if you specify a list of columns, the data set returns a dictionary containing the targets.
The filenames of the corresponding audio files have to be specified with absolute paths. If they are relative to a folder, you have to use the root argument to specify that folder.

transform – controls the input transform
target_transform – controls the target transform
files – controls the audio files of the data set
targets – controls the corresponding targets
sampling_rate – holds the sampling rate of the returned data
original_sampling_rate – holds the sampling rate of the audio files of the data set
csv_file – holds the path to the used CSV file
- Parameters
csv_file (str) – CSV file with filenames and labels
sampling_rate (int) – sampling rate in Hz of the data set
root (str, optional) – root directory added before the files listed in the CSV file. Default: None
column_labels (str or list of str, optional) – name of CSV column(s) containing the desired labels. Default: label
column_filename (str, optional) – name of CSV column holding the file names. Default: file
column_start (str, optional) – name of column holding start of audio in the corresponding file in seconds. Default: None
column_end (str, optional) – name of column holding end of audio in the corresponding file in seconds. Default: None
sep (str, optional) – CSV delimiter. Default: ,
transform (callable, optional) – function/transform applied on the signal. Default: None
target_transform (callable, optional) – function/transform applied on the target. Default: None
Example
>>> data = CsvDataset(csv_file='/data/train.csv',
...                   sampling_rate=44100,
...                   column_labels='age')
>>> print(data)
Dataset AudioDataset
    Number of data points: 120
    Sampling Rate: 44100Hz
    Label: age
    CSV file: /data/train.csv
>>> signal, target = data[0]
>>> target
31
AudioConcatDataset

class audtorch.datasets.AudioConcatDataset(*args: Any, **kwargs: Any)

Concatenation data set of multiple audio data sets.
This data set checks that all audio data sets are compatible with respect to the sampling rate with which they are processed.
sampling_rate – holds the consistent sampling rate of the concatenated data set
datasets – holds a list of all audio data sets
cumulative_sizes – holds a list of sizes accumulated over all audio data sets, i.e. [len(data1), len(data1) + len(data2), …]
- Parameters
datasets (list of audtorch.AudioDataset) – Audio data sets with property sampling_rate.
Example
>>> import sounddevice as sd
>>> from audtorch.datasets import LibriSpeech
>>> dev_clean = LibriSpeech(root='/data/LibriSpeech', sets='dev-clean')
>>> dev_other = LibriSpeech(root='/data/LibriSpeech', sets='dev-other')
>>> data = AudioConcatDataset([dev_clean, dev_other])
>>> print(data)
Data set AudioConcatDataset
    Number of data points: 5567
    Sampling Rate: 16000Hz
    data sets    data points  extra
    -----------  -----------  ---------------
    LibriSpeech         2703  Sets: dev-clean
    LibriSpeech         2864  Sets: dev-other
>>> signal, label = data[8]
>>> label
AS FOR ETCHINGS THEY ARE OF TWO KINDS BRITISH AND FOREIGN
>>> sd.play(signal.transpose(), data.sampling_rate)
extra_repr()

Set the extra representation of the data set.
To print customized extra information, you should reimplement this method in your own data set. Both single-line and multi-line strings are acceptable.
The extra information will be shown after the sampling rate entry.
Mixture
This section contains data sets that are primarily used for mixing different data sets.
SpeechNoiseMix

class audtorch.datasets.SpeechNoiseMix(*args: Any, **kwargs: Any)

Mix speech and noise with speech as target.
Add noise to each speech sample from the provided data set with a mix transform. The mix is returned as input and the speech signal as corresponding target. In addition, some of the mixes can be randomly replaced by noise as input and silence as target. This helps to train a speech enhancement algorithm to deal with input signals that contain only background noise [RPS18].
speech_dataset – controls the speech data set
mix_transform – controls the transform that adds noise
transform – controls the transform applied on the mix
target_transform – controls the transform applied on the target clean speech
joint_transform – controls the transform applied jointly on the mixture and the target clean speech
percentage_silence – controls the amount of noise-silent data augmentation
- Parameters
speech_dataset (Dataset) – speech data set
mix_transform (callable) – function/transform that can augment a signal with noise
transform (callable, optional) – function/transform applied on the speech-noise-mixture (input) only. Default: None
target_transform (callable, optional) – function/transform applied on the speech (target) only. Default: None
joint_transform (callable, optional) – function/transform applied on the mixture (input) and speech (target) simultaneously. If the transform includes randomization, it is applied with the same random parameters during both calls
percentage_silence (float, optional) – value between 0 and 1, which controls the percentage of randomly inserted noise input, silent target pairs. Default: 0
Examples
>>> import sounddevice as sd
>>> from audtorch import datasets, transforms
>>> noise = datasets.WhiteNoise(duration=10, sampling_rate=48000)
>>> mix = transforms.RandomAdditiveMix(noise)
>>> normalize = transforms.Normalize()
>>> speech = datasets.MozillaCommonVoice(root='/data/MozillaCommonVoice/cv_corpus_v1')
>>> data = SpeechNoiseMix(speech, mix, transform=normalize)
>>> noisy, clean = data[0]
>>> sd.play(noisy.transpose(), data.sampling_rate)
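A hedged extension of the example above that replaces 10% of the mixes by noise-only inputs with silent targets:

>>> data = SpeechNoiseMix(speech, mix, transform=normalize,
...                       percentage_silence=0.1)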
Utils
Utility functions for handling audio data sets.
load

audtorch.datasets.load(filename, *, duration=None, offset=0)

Load audio file.
If an error occurs during loading, e.g. because the file could not be found, is empty, or has the wrong format, an empty signal is returned and a warning is shown.
- Parameters
filename (str) – filename of the audio file
duration (float, optional) – return only the specified duration in seconds. Default: None
offset (float, optional) – start reading at offset in seconds. Default: 0
- Returns
numpy.ndarray: two-dimensional array with shape (channels, samples)
int: sample rate of the audio file
- Return type
tuple
Example
>>> signal, sampling_rate = load('speech.wav')
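Following the signature above, a hedged sketch of loading an excerpt (assuming duration and offset are given in seconds):

>>> excerpt, sampling_rate = load('speech.wav', duration=2.0, offset=0.5)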
download_url

audtorch.datasets.download_url(url, root, *, filename=None, md5=None)

Download a file from a URL to a specified directory.
- Parameters
url (str) – URL to download the file from
root (str) – directory to place the downloaded file in
filename (str, optional) – name to save the file under. Uses basename of URL if None. Default: None
md5 (str, optional) – MD5 checksum of the download. Performs no check if None. Default: None
- Returns
path to downloaded file
- Return type
str
download_url_list

audtorch.datasets.download_url_list(urls, root, *, num_workers=0)

Download files from a list of URLs to a specified directory.
- Parameters
urls (list of str or dict) – either list of URLs or dictionary with URLs as keys and with either filenames or tuples of filename and MD5 checksum as values. Uses basename of URL if filename is None. Performs no check if MD5 checksum is None
root (str) – directory to place downloaded files in
num_workers (int, optional) – number of worker threads (0 = len(urls)). Default: 0
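A minimal sketch (URL and directory are placeholders); the dictionary form maps each URL to a filename:

>>> urls = {'https://example.com/noise.wav': 'noise.wav'}
>>> download_url_list(urls, '/data/downloads')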
extract_archive

audtorch.datasets.extract_archive(filename, *, out_path=None, remove_finished=False)

Extract archive.
Currently tar.gz and tar archives are supported.
sampling_rate_after_transform

audtorch.datasets.sampling_rate_after_transform(dataset)

Sampling rate of a data set after all transforms are applied.

A change of sampling rate by a transform is only recognized if that transform has the attribute output_sampling_rate.

- Parameters
dataset (torch.utils.data.Dataset) – data set with sampling_rate attribute or property
- Returns
sampling rate in Hz after all transforms are applied
- Return type
int
Example
>>> from audtorch import datasets, transforms
>>> t = transforms.Resample(input_sampling_rate=16000,
...                         output_sampling_rate=8000)
>>> data = datasets.WhiteNoise(sampling_rate=16000, transform=t)
>>> sampling_rate_after_transform(data)
8000
ensure_same_sampling_rate

audtorch.datasets.ensure_same_sampling_rate(datasets)

Raise an error if the provided data sets differ in sampling rate.
All data sets that are checked need to have a sampling_rate attribute or property.
- Parameters
datasets (list of torch.utils.data.Dataset) – list of at least two audio data sets.
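A hedged sketch using two WhiteNoise sets with different rates (the exact error type is not specified here):

>>> from audtorch import datasets
>>> a = datasets.WhiteNoise(duration=1, sampling_rate=16000)
>>> b = datasets.WhiteNoise(duration=1, sampling_rate=8000)
>>> ensure_same_sampling_rate([a, b])  # raises an error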
ensure_df_columns_contain

audtorch.datasets.ensure_df_columns_contain(df, labels)

Raise an error if the given labels are not contained in the data frame columns.
- Parameters
df (pandas.DataFrame) – data frame
labels (list of str) – labels to be expected in df.columns
Example
>>> import pandas as pd
>>> df = pd.DataFrame(data=[(1, 2)], columns=['a', 'b'])
>>> ensure_df_columns_contain(df, ['a', 'c'])
Traceback (most recent call last):
RuntimeError: Dataframe contains only these columns: 'a, b'
ensure_df_not_empty

audtorch.datasets.ensure_df_not_empty(df, labels=None)

Raise an error if the data frame is empty.
- Parameters
df (pandas.DataFrame) – data frame
labels (list of str, optional) – list of labels used to shrink data set. Default: None
Example
>>> import pandas as pd
>>> df = pd.DataFrame()
>>> ensure_df_not_empty(df)
Traceback (most recent call last):
RuntimeError: No valid data points found in data set
files_and_labels_from_df

audtorch.datasets.files_and_labels_from_df(df, *, column_labels=None, column_filename='filename')

Extract list of files and labels from data frame columns.
- Parameters
df (pandas.DataFrame) – data frame with filenames and labels
column_labels (str or list of str, optional) – name of data frame column(s) containing the desired labels. Default: None
column_filename (str, optional) – name of the column holding the file names. Default: filename
- Returns
list of str: list of files
list of str or list of dicts: list of labels
- Return type
tuple
Example
>>> import os
>>> import pandas as pd
>>> df = pd.DataFrame(data=[('speech.wav', 'speech')],
...                   columns=['filename', 'label'])
>>> files, labels = files_and_labels_from_df(df, column_labels='label')
>>> os.path.relpath(files[0]), labels[0]
('speech.wav', 'speech')
defined_split

audtorch.datasets.defined_split(dataset, split_func)

Split data set into desired non-overlapping subsets.
- Parameters
dataset (torch.utils.data.Dataset) – data set to be split
split_func (func) – function mapping from data set index to subset id, \(f(\text{index}) = \text{subset\_id}\). The target domain of subset ids does not need to cover the complete range [0, 1, …, (num_subsets - 1)]
- Returns
desired subsets according to split_func
- Return type
list of Subsets
Example
>>> import numpy as np
>>> import torch
>>> from torch.utils.data import TensorDataset
>>> from audtorch.samplers import buckets_of_even_size
>>> data = TensorDataset(torch.randn(100))
>>> lengths = np.random.randint(0, 1000, (100,))
>>> split_func = buckets_of_even_size(lengths, 5)
>>> subsets = defined_split(data, split_func)
>>> [len(subset) for subset in subsets]
[20, 20, 20, 20, 20]