30 changes: 25 additions & 5 deletions cebra/data/base.py
@@ -22,6 +22,7 @@
"""Base classes for datasets and loaders."""

import abc
from typing import Iterator

import literate_dataclasses as dataclasses
import torch
@@ -239,6 +240,12 @@ class Loader(abc.ABC, cebra.io.HasDevice):
batch_size: int = dataclasses.field(default=None,
doc="""The total batch size.""")

num_negatives: int = dataclasses.field(
default=None,
doc=("The number of negative samples to draw for each reference. "
"If not specified, the batch size is used."),
)

def __post_init__(self):
if self.num_steps is None or self.num_steps <= 0:
raise ValueError(
@@ -248,28 +255,41 @@ def __post_init__(self):
raise ValueError(
f"Batch size has to be None, or a non-negative value. Got {self.batch_size}."
)
if self.num_negatives is not None and self.num_negatives <= 0:
raise ValueError(
f"Number of negatives has to be None, or a non-negative value. Got {self.num_negatives}."
)

if self.num_negatives is None:
self.num_negatives = self.batch_size

def __len__(self):
"""The number of batches returned when calling as an iterator."""
return self.num_steps

def __iter__(self) -> Batch:
def __iter__(self) -> Iterator[Batch]:
for _ in range(len(self)):
index = self.get_indices(num_samples=self.batch_size)
index = self.get_indices()
yield self.dataset.load_batch(index)

@abc.abstractmethod
def get_indices(self, num_samples: int):
def get_indices(self, *, num_samples: int = None):
"""Sample and return the specified number of indices.

The elements of the returned `BatchIndex` will be used to index the
`dataset` of this data loader.

Args:
num_samples: The size of each of the reference, positive and
negative samples.
num_samples: Deprecated. Set ``batch_size`` on the loader
instance instead.

Returns:
Batch indices for the reference, positive and negative samples.

Note:
From version 0.7.0 onwards, specifying ``num_samples`` directly
is deprecated; the argument will be removed in version 0.8.0.
Set ``batch_size`` and ``num_negatives`` on the loader instance
instead.
"""
raise NotImplementedError()
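
For illustration, a minimal usage sketch of the decoupled settings this diff introduces. It is not part of the change; ``ContinuousDataLoader`` and ``TensorDataset`` stand in for any compatible loader/dataset pair, and the keyword names are assumed to match the dataclass fields shown above.

import torch
import cebra.data

# Hypothetical data: 1000 time steps, 30 neurons, 2 continuous labels.
neural = torch.randn(1000, 30)
dataset = cebra.data.TensorDataset(neural, continuous=torch.randn(1000, 2))

loader = cebra.data.ContinuousDataLoader(
    dataset=dataset,
    num_steps=10,        # number of batches yielded by the iterator
    batch_size=512,      # reference/positive samples per batch
    num_negatives=1024,  # new field; defaults to batch_size when None
)
for batch in loader:     # get_indices() is now called without num_samples
    pass                 # batch.reference, batch.positive, batch.negative
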
22 changes: 18 additions & 4 deletions cebra/data/multi_session.py
@@ -155,10 +155,14 @@ def __post_init__(self):
super().__post_init__()
self.sampler = cebra.distributions.MultisessionSampler(
self.dataset, self.time_offset)
if self.num_negatives is None:
self.num_negatives = self.batch_size

def get_indices(self, num_samples: int) -> List[BatchIndex]:
# NOTE(stes): In the longer run, we need to unify the API here; the num_samples
# argument is not used in the multi-session case, which differs from the
# single-session case.
def get_indices(self) -> List[BatchIndex]:
ref_idx = self.sampler.sample_prior(self.batch_size)
neg_idx = self.sampler.sample_prior(self.batch_size)
neg_idx = self.sampler.sample_prior(self.num_negatives)
pos_idx, idx, idx_rev = self.sampler.sample_conditional(ref_idx)

ref_idx = torch.from_numpy(ref_idx)
@@ -192,8 +196,11 @@ class DiscreteMultiSessionDataLoader(MultiSessionLoader):
# Overwrite sampler with the discrete implementation
# Generalize MultisessionSampler to avoid doing this?
def __post_init__(self):
# NOTE(stes): __post_init__ from superclass is intentionally not called.
self.sampler = cebra.distributions.DiscreteMultisessionSampler(
self.dataset)
if self.num_negatives is None:
self.num_negatives = self.batch_size

@property
def index(self):
@@ -229,7 +236,14 @@ def __post_init__(self):
self.sampler = cebra.distributions.UnifiedSampler(
self.dataset, self.time_offset)

def get_indices(self, num_samples: int) -> BatchIndex:
if self.batch_size is not None and self.batch_size < 2:
raise ValueError("UnifiedLoader does not support batch_size < 2.")

if self.num_negatives is not None and self.num_negatives < 2:
raise ValueError(
"UnifiedLoader does not support num_negatives < 2.")

def get_indices(self) -> BatchIndex:
"""Sample and return the specified number of indices.

The elements of the returned ``BatchIndex`` will be used to index the
@@ -251,7 +265,7 @@ def get_indices(self, num_samples: int) -> BatchIndex:
Batch indices for the reference, positive and negative samples.
"""
ref_idx = self.sampler.sample_prior(self.batch_size)
neg_idx = self.sampler.sample_prior(self.batch_size)
neg_idx = self.sampler.sample_prior(self.num_negatives)

pos_idx = self.sampler.sample_conditional(ref_idx)

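
The multi-session loaders draw references and negatives in two independent prior calls, rather than slicing one shared draw as the single-session loaders do. A toy sketch of that pattern, with a hypothetical ``ToySampler`` standing in for the samplers used above:

import numpy as np

class ToySampler:
    """Hypothetical stand-in exposing the sample_prior() call used above."""

    def __init__(self, num_timesteps: int, seed: int = 0):
        self._rng = np.random.default_rng(seed)
        self._num_timesteps = num_timesteps

    def sample_prior(self, num_samples: int) -> np.ndarray:
        # Uniform prior over all time steps.
        return self._rng.integers(0, self._num_timesteps, size=num_samples)

sampler = ToySampler(num_timesteps=10_000)
batch_size, num_negatives = 512, 2048
ref_idx = sampler.sample_prior(batch_size)     # one draw for references
neg_idx = sampler.sample_prior(num_negatives)  # an independent, larger draw for negatives
assert ref_idx.shape == (batch_size,)
assert neg_idx.shape == (num_negatives,)
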
22 changes: 13 additions & 9 deletions cebra/data/multiobjective.py
@@ -20,10 +20,13 @@
# limitations under the License.
#

from typing import Iterator

import literate_dataclasses as dataclasses

import cebra.data as cebra_data
import cebra.distributions
from cebra.data.datatypes import Batch
from cebra.data.datatypes import BatchIndex
from cebra.distributions.continuous import Prior

@@ -71,9 +74,9 @@ def __post_init__(self):
def add_config(self, config):
self.labels.append(config['label'])

def get_indices(self, num_samples: int):
def get_indices(self) -> BatchIndex:
if self.sampling_mode_supervised == "ref_shared":
reference_idx = self.prior.sample_prior(num_samples)
reference_idx = self.prior.sample_prior(self.batch_size)
else:
raise ValueError(
f"Sampling mode {self.sampling_mode_supervised} is not implemented."
@@ -87,9 +90,9 @@ def get_indices(self, num_samples: int):

return batch_index

def __iter__(self):
def __iter__(self) -> Iterator[Batch]:
for _ in range(len(self)):
index = self.get_indices(num_samples=self.batch_size)
index = self.get_indices()
yield self.dataset.load_batch_supervised(index, self.labels)


@@ -142,13 +145,14 @@ def add_config(self, config):

self.distributions.append(distribution)

def get_indices(self, num_samples: int):
def get_indices(self) -> BatchIndex:
"""Sample and return the specified number of indices."""

if self.sampling_mode_contrastive == "refneg_shared":
ref_and_neg = self.prior.sample_prior(num_samples * 2)
reference_idx = ref_and_neg[:num_samples]
negative_idx = ref_and_neg[num_samples:]
ref_and_neg = self.prior.sample_prior(self.batch_size +
self.num_negatives)
reference_idx = ref_and_neg[:self.batch_size]
negative_idx = ref_and_neg[self.batch_size:]

positives_idx = []
for distribution in self.distributions:
@@ -169,5 +173,5 @@ def get_indices(self, num_samples: int):

def __iter__(self):
for _ in range(len(self)):
index = self.get_indices(num_samples=self.batch_size)
index = self.get_indices()
yield self.dataset.load_batch_contrastive(index)
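
In the ``refneg_shared`` mode above, references and negatives share one prior, so a single draw of ``batch_size + num_negatives`` indices is split by slicing. A self-contained sketch of that split; the random permutation is a stand-in for ``prior.sample_prior``:

import torch

batch_size, num_negatives = 4, 6
# Stand-in for prior.sample_prior(batch_size + num_negatives).
ref_and_neg = torch.randperm(100)[:batch_size + num_negatives]
reference_idx = ref_and_neg[:batch_size]  # first batch_size entries
negative_idx = ref_and_neg[batch_size:]   # remaining num_negatives entries
assert len(reference_idx) == batch_size
assert len(negative_idx) == num_negatives
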
78 changes: 36 additions & 42 deletions cebra/data/single_session.py
@@ -27,6 +27,7 @@

import abc
import warnings
from typing import Iterator

import literate_dataclasses as dataclasses
import torch
@@ -138,7 +139,7 @@ def _init_distribution(self):
f"Invalid choice of prior distribution. Got '{self.prior}', but "
f"only accept 'uniform' or 'empirical' as potential values.")

def get_indices(self, num_samples: int) -> BatchIndex:
def get_indices(self) -> BatchIndex:
"""Samples indices for reference, positive and negative examples.

The reference samples will be sampled from the empirical or uniform prior
@@ -151,16 +152,13 @@ def get_indices(self, num_samples: int) -> BatchIndex:
The negative samples will be sampled from the same distribution as the
reference examples.

Args:
num_samples: The number of samples (batch size) of the returned
:py:class:`cebra.data.datatypes.BatchIndex`.

Returns:
Indices for reference, positive and negative samples.
"""
reference_idx = self.distribution.sample_prior(num_samples * 2)
negative_idx = reference_idx[num_samples:]
reference_idx = reference_idx[:num_samples]
reference_idx = self.distribution.sample_prior(self.batch_size +
self.num_negatives)
negative_idx = reference_idx[self.batch_size:]
reference_idx = reference_idx[:self.batch_size]
reference = self.index[reference_idx]
positive_idx = self.distribution.sample_conditional(reference)
return BatchIndex(reference=reference_idx,
@@ -246,7 +244,7 @@ def _init_distribution(self):
else:
raise ValueError(self.conditional)

def get_indices(self, num_samples: int) -> BatchIndex:
def get_indices(self) -> BatchIndex:
"""Samples indices for reference, positive and negative examples.

The reference and negative samples will be sampled uniformly from
@@ -255,16 +253,13 @@ def get_indices(self, num_samples: int) -> BatchIndex:
The positive samples will be sampled conditional on the reference
samples according to the specified ``conditional`` distribution.

Args:
num_samples: The number of samples (batch size) of the returned
:py:class:`cebra.data.datatypes.BatchIndex`.

Returns:
Indices for reference, positive and negative samples.
"""
reference_idx = self.distribution.sample_prior(num_samples * 2)
negative_idx = reference_idx[num_samples:]
reference_idx = reference_idx[:num_samples]
reference_idx = self.distribution.sample_prior(self.batch_size +
self.num_negatives)
negative_idx = reference_idx[self.batch_size:]
reference_idx = reference_idx[:self.batch_size]
positive_idx = self.distribution.sample_conditional(reference_idx)
return BatchIndex(reference=reference_idx,
positive=positive_idx,
@@ -305,7 +300,7 @@ def __post_init__(self):
continuous=self.cindex,
time_delta=self.time_offset)

def get_indices(self, num_samples: int) -> BatchIndex:
def get_indices(self) -> BatchIndex:
"""Samples indices for reference, positive and negative examples.

The reference and negative samples will be sampled uniformly from
@@ -316,10 +311,6 @@ def get_indices(self, num_samples: int) -> BatchIndex:
:py:class:`ContinuousDataLoader`, or just sampled based on the
conditional variable.

Args:
num_samples: The number of samples (batch size) of the returned
:py:class:`cebra.data.datatypes.BatchIndex`.

Returns:
Indices for reference, positive and negative samples.

@@ -328,10 +319,13 @@ def get_indices(self, num_samples: int) -> BatchIndex:
class.
- Sample the negatives with matching discrete variable
"""
reference_idx = self.distribution.sample_prior(num_samples)
reference_idx = self.distribution.sample_prior(self.batch_size +
self.num_negatives)
negative_idx = reference_idx[self.batch_size:]
reference_idx = reference_idx[:self.batch_size]
return BatchIndex(
reference=reference_idx,
negative=self.distribution.sample_prior(num_samples),
negative=negative_idx,
positive=self.distribution.sample_conditional(reference_idx),
)

@@ -421,32 +415,29 @@ def _init_time_distribution(self):
else:
raise ValueError

def get_indices(self, num_samples: int) -> BatchIndex:
def get_indices(self) -> BatchIndex:
"""Samples indices for reference, positive and negative examples.

The reference and negative samples will be sampled uniformly from
all available time steps, and a total of ``2*num_samples`` will be
returned for both.
all available time steps, with a total of ``self.batch_size + self.num_negatives``
indices drawn and then split between them.

For the positive samples, ``num_samples`` are sampled according to the
behavior conditional distribution, and another ``num_samples`` are
sampled according to the dime contrastive distribution. The indices
For the positive samples, ``self.batch_size`` samples are sampled according to the
behavior conditional distribution, and another ``self.batch_size`` samples are
sampled according to the time contrastive distribution. The indices
for the positive samples are concatenated across the first dimension.

Args:
num_samples: The number of samples (batch size) of the returned
:py:class:`cebra.data.datatypes.BatchIndex`.

Returns:
Indices for reference, positive and negative samples.

Todo:
Add the ``empirical`` vs. ``discrete`` sampling modes to this
class.
"""
reference_idx = self.time_distribution.sample_prior(num_samples * 2)
negative_idx = reference_idx[num_samples:]
reference_idx = reference_idx[:num_samples]
reference_idx = self.time_distribution.sample_prior(self.batch_size +
self.num_negatives)
negative_idx = reference_idx[self.batch_size:]
reference_idx = reference_idx[:self.batch_size]
behavior_positive_idx = self.behavior_distribution.sample_conditional(
reference_idx)
time_positive_idx = self.time_distribution.sample_conditional(
@@ -464,13 +455,18 @@ class FullDataLoader(ContinuousDataLoader):

def __post_init__(self):
super().__post_init__()
self.batch_size = None

if self.batch_size is not None:
raise ValueError("Batch size cannot be set for FullDataLoader.")
if self.num_negatives is not None:
raise ValueError(
"Number of negatives cannot be set for FullDataLoader.")

@property
def offset(self):
return self.dataset.offset

def get_indices(self, num_samples=None) -> BatchIndex:
def get_indices(self) -> BatchIndex:
"""Samples indices for reference, positive and negative examples.

The reference indices are all available (valid, according to the
@@ -490,7 +486,6 @@ def get_indices(self, num_samples=None) -> BatchIndex:
Add the ``empirical`` vs. ``discrete`` sampling modes to this
class.
"""
assert num_samples is None

reference_idx = torch.arange(
self.offset.left,
@@ -504,7 +499,6 @@ def __iter__(self):
positive=positive_idx,
negative=negative_idx)

def __iter__(self):
def __iter__(self) -> Iterator[BatchIndex]:
for _ in range(len(self)):
index = self.get_indices(num_samples=self.batch_size)
yield index
yield self.get_indices()
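
``FullDataLoader`` now rejects explicit ``batch_size`` and ``num_negatives``, since each iteration covers all valid time steps, and its iterator yields ``BatchIndex`` objects rather than loaded batches. A sketch of the intended usage, under the same dataset assumptions as the earlier example:

import torch
import cebra.data

neural = torch.randn(1000, 30)
dataset = cebra.data.TensorDataset(neural, continuous=torch.randn(1000, 2))

# batch_size and num_negatives stay None; setting either now raises ValueError.
loader = cebra.data.FullDataLoader(dataset=dataset, num_steps=5)
for batch_index in loader:
    # batch_index spans all valid time steps (a BatchIndex, not a Batch).
    pass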