Dataset and Database Loaders
-
class radbm.loaders.base.IRLoader(mode, which, backend, device, rng=None)
A subclass of Loader meant for Information Retrieval. It introduces the notion of a mode, which governs the way batches are given.
Parameters: mode (str) – Should be one of the values returned by IRLoader.get_available_modes().
-
class radbm.loaders.base.Loader(which, backend, device, rng=None)
An abstract class managing numpy vs torch, cpu vs gpu and train vs valid vs test. It should be subclassed for a particular dataset (e.g. MNIST).
Parameters: - which (str) – Which version of the dataset to use. Should be ‘train’, ‘valid’ or ‘test’.
- backend (str) – Which backend to use, should be ‘numpy’ or ‘torch’.
- device (str) – Which device to use, should be ‘cpu’ or ‘cuda’. ‘cuda’ is only available if backend == ‘torch’.
- rng (numpy.random.RandomState) – A random number generator for reproducibility.
-
cpu()
Transfers all data registered with register_switch to the CPU.
-
cuda()
Transfers all data registered with register_switch to the GPU.
Raises: ValueError – If backend == ‘numpy’.
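The two documented ValueError rules (cuda() with a numpy backend, numpy() while on the GPU) can be pictured with a small state machine. The class below is a hypothetical stand-in that tracks only the backend and device strings; it is not radbm's implementation and moves no actual data.

```python
class StateSketch:
    """Hypothetical stand-in: tracks backend/device only, holds no data."""

    def __init__(self, backend="numpy", device="cpu"):
        if backend not in ("numpy", "torch"):
            raise ValueError(f"unknown backend: {backend!r}")
        if device not in ("cpu", "cuda"):
            raise ValueError(f"unknown device: {device!r}")
        if backend == "numpy" and device == "cuda":
            raise ValueError("'cuda' requires backend == 'torch'")
        self.backend = backend
        self.device = device

    def cpu(self):
        self.device = "cpu"

    def cuda(self):
        # Documented rule: cuda() raises ValueError if backend == 'numpy'.
        if self.backend == "numpy":
            raise ValueError("cannot move numpy data to the GPU")
        self.device = "cuda"

    def torch(self):
        self.backend = "torch"

    def numpy(self):
        # Documented rule: numpy() raises ValueError if device == 'cuda'.
        if self.device == "cuda":
            raise ValueError("call cpu() before numpy()")
        self.backend = "numpy"
```

Under this reading, a loader that starts on the numpy backend has to call torch() before cuda(), and cpu() before going back to numpy().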
-
dynamic_cast(data)
Casts data according to the current state of the loader. For example, when backend == ‘torch’ and device == ‘cuda’ and the input is a numpy.ndarray, the array is converted to a torch.Tensor and transferred to the GPU.
Parameters: data (numpy.ndarray or torch.Tensor) – The data to cast.
Returns: casted_data – The cast data.
Return type: numpy.ndarray or torch.Tensor
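A minimal sketch of the casting rule described for dynamic_cast, written as a free function so it can run without radbm. The name dynamic_cast_sketch and its explicit backend and device arguments are assumptions made for illustration; the real method reads that state from the loader, and its internals may differ.

```python
import numpy as np

try:
    import torch  # optional; only needed when backend == 'torch'
except ImportError:
    torch = None


def dynamic_cast_sketch(data, backend, device):
    """Illustrative only: cast data to match the given (backend, device)."""
    if backend == "numpy":
        if device != "cpu":
            raise ValueError("numpy backend supports only device == 'cpu'")
        if torch is not None and isinstance(data, torch.Tensor):
            # Tensors are detached and moved to host memory first.
            return data.detach().cpu().numpy()
        return np.asarray(data)
    if backend == "torch":
        if torch is None:
            raise RuntimeError("torch is not installed")
        if isinstance(data, torch.Tensor):
            tensor = data
        else:
            tensor = torch.from_numpy(np.asarray(data))
        # Tensor.to accepts a device string such as 'cpu' or 'cuda'.
        return tensor.to(device)
    raise ValueError(f"unknown backend: {backend!r}")
```

With backend='numpy' and device='cpu', an ndarray passes through unchanged; the torch branch only runs when torch is installed.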
-
get_rng()
Utility method to get the rng.
Returns: rng – The rng used inside the utility class TorchNumpyRNG.
Return type: numpy.random.RandomState
-
numpy()
Converts all data registered with register_switch into numpy.ndarray.
Raises: ValueError – If device == ‘cuda’.
-
register_switch(name, data)
This method should only be used when subclassing. It registers data so that a later call to numpy(), torch(), cpu() or cuda() converts each registered value to the appropriate format.
Parameters: - name (str) – The name of the data. setattr is used, so the data can later be reached as self.<name>.
- data (numpy.ndarray or torch.Tensor) – The data to register.
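The registration mechanism can be sketched as a list of names plus setattr. RegistrySketch, ToyDataset and the helper _convert_all below are hypothetical names used only for illustration; they mimic the documented behaviour and are not radbm's code.

```python
import numpy as np


class RegistrySketch:
    """Hypothetical stand-in for the register_switch bookkeeping."""

    def __init__(self):
        self._switch_names = []

    def register_switch(self, name, data):
        # Remember the name, then expose the data as self.<name>.
        self._switch_names.append(name)
        setattr(self, name, data)

    def _convert_all(self, convert):
        # Apply one conversion to every registered attribute, which is
        # what numpy(), torch(), cpu() and cuda() are described as doing.
        for name in self._switch_names:
            setattr(self, name, convert(getattr(self, name)))


class ToyDataset(RegistrySketch):
    """A subclass registers its arrays once, in its constructor."""

    def __init__(self):
        super().__init__()
        self.register_switch("features", np.zeros((4, 2), dtype=np.float32))
        self.register_switch("labels", np.arange(4))
```

Here a dtype change stands in for a backend or device switch: ToyDataset()._convert_all(lambda a: a.astype(np.float64)) updates both features and labels in place on the instance.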
-
set_rng(rng)
Utility method to set the rng.
Parameters: rng (numpy.random.RandomState or TorchNumpyRNG) – The rng to use going forward.
Returns: self
Return type: Loader
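Because set_rng returns self, calls can be chained. The stand-in below illustrates that pattern and the reproducibility it enables; RngHolderSketch is a hypothetical class, and it stores the RandomState directly rather than wrapping it in TorchNumpyRNG as the real loader is described as doing.

```python
import numpy as np


class RngHolderSketch:
    """Hypothetical stand-in exposing only set_rng / get_rng."""

    def __init__(self, rng=None):
        self._rng = np.random.RandomState() if rng is None else rng

    def set_rng(self, rng):
        self._rng = rng
        return self  # returning self allows chaining

    def get_rng(self):
        return self._rng


# Two holders seeded identically draw the same permutation.
seeded = RngHolderSketch().set_rng(np.random.RandomState(0))
first = seeded.get_rng().permutation(5)
again = RngHolderSketch().set_rng(np.random.RandomState(0)).get_rng().permutation(5)
```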
-
test()
Switch to the testing dataset.
-
torch()
Converts all data registered with register_switch into torch.Tensor.
-
train()
Switch to the training dataset.
-
valid()
Switch to the validation dataset.