
API reference for the labproject development package

This page documents all functions that are part of the public API of the labproject package.

Metrics

Best practices for developing metrics:

  1. Please do everything in torch, and if that is not possible, cast the output to torch.Tensor.
  2. The function should be well-documented, including type hints.
  3. The function should be tested with a simple example.
  4. Add an assert at the beginning for shape checking (N,D), see examples.
  5. Register the function by importing labproject.metrics.utils.register_metric and give your function a meaningful name (a minimal sketch follows below).
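
A minimal sketch following these conventions (the metric name mean_l2_distance and its body are purely illustrative; only the register_metric import path comes from the list above):

import torch
from torch import Tensor
from labproject.metrics.utils import register_metric

@register_metric("mean_l2_distance")
def mean_l2_distance(real_samples: Tensor, fake_samples: Tensor) -> Tensor:
    """Illustrative metric: L2 distance between the sample means."""
    # shape check (N, D), as recommended above
    assert real_samples.ndim == 2, "Real samples must be 2-dimensional, (n,d)"
    assert fake_samples.ndim == 2, "Fake samples must be 2-dimensional, (n,d)"
    # stay in torch end to end and return a torch.Tensor
    return torch.norm(real_samples.mean(dim=0) - fake_samples.mean(dim=0), p=2)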

Gaussian KL divergence

gaussian_kl_divergence(real_samples, fake_samples)

Compute the KL divergence between Gaussian approximations of real and fake samples. Dimensionality of the samples must be the same and >=2 (for covariance calculation).

In detail, for each set of samples, we calculate the mean and covariance matrix.

\[ \mu_{\text{real}} = \frac{1}{n} \sum_{i=1}^{n} x_i \qquad \mu_{\text{fake}} = \frac{1}{n} \sum_{i=1}^{n} y_i \]
\[ \Sigma_{\text{real}} = \frac{1}{n-1} \sum_{i=1}^{n} (x_i - \mu_{\text{real}})(x_i - \mu_{\text{real}})^T \qquad \Sigma_{\text{fake}} = \frac{1}{n-1} \sum_{i=1}^{n} (y_i - \mu_{\text{fake}})(y_i - \mu_{\text{fake}})^T \]

Then we calculate the KL divergence between the two Gaussian approximations:

\[ D_{KL}(N(\mu_{\text{real}}, \Sigma_{\text{real}}) || N(\mu_{\text{fake}}, \Sigma_{\text{fake}})) = \frac{1}{2} \left( \text{tr}(\Sigma_{\text{fake}}^{-1} \Sigma_{\text{real}}) + (\mu_{\text{fake}} - \mu_{\text{real}})^T \Sigma_{\text{fake}}^{-1} (\mu_{\text{fake}} - \mu_{\text{real}}) - k + \log \frac{|\Sigma_{\text{fake}}|}{|\Sigma_{\text{real}}|} \right) \]

Parameters:

  real_samples (Tensor): A tensor representing the real samples. Required.
  fake_samples (Tensor): A tensor representing the fake samples. Required.

Returns:

  torch.Tensor: The KL divergence between the two Gaussian approximations.

Examples:

>>> real_samples = torch.randn(100, 2)  # 100 samples, 2-dimensional
>>> fake_samples = torch.randn(100, 2)  # 100 samples, 2-dimensional
>>> kl_div = gaussian_kl_divergence(real_samples, fake_samples)
>>> print(kl_div)
Source code in labproject/metrics/gaussian_kl.py
@register_metric("gaussian_kl_divergence")
def gaussian_kl_divergence(real_samples: Tensor, fake_samples: Tensor) -> Tensor:
    r"""
    Compute the KL divergence between Gaussian approximations of real and fake samples.
    Dimensionality of the samples must be the same and >=2 (for covariance calculation).

    In detail, for each set of samples, we calculate the mean and covariance matrix.

    $$ \mu_{\text{real}} = \frac{1}{n} \sum_{i=1}^{n} x_i \qquad \mu_{\text{fake}} = \frac{1}{n} \sum_{i=1}^{n} y_i $$


    $$
    \Sigma_{\text{real}} = \frac{1}{n-1} \sum_{i=1}^{n} (x_i - \mu_{\text{real}})(x_i - \mu_{\text{real}})^T \qquad
    \Sigma_{\text{fake}} = \frac{1}{n-1} \sum_{i=1}^{n} (y_i - \mu_{\text{fake}})(y_i - \mu_{\text{fake}})^T
    $$

    Then we calculate the KL divergence between the two Gaussian approximations:

    $$
    D_{KL}(N(\mu_{\text{real}}, \Sigma_{\text{real}}) || N(\mu_{\text{fake}}, \Sigma_{\text{fake}})) =
    \frac{1}{2} \left( \text{tr}(\Sigma_{\text{fake}}^{-1} \Sigma_{\text{real}}) + (\mu_{\text{fake}} - \mu_{\text{real}})^T \Sigma_{\text{fake}}^{-1} (\mu_{\text{fake}} - \mu_{\text{real}})
    - k + \log \frac{|\Sigma_{\text{fake}}|}{|\Sigma_{\text{real}}|} \right)
    $$

    Args:
        real_samples (torch.Tensor): A tensor representing the real samples.
        fake_samples (torch.Tensor): A tensor representing the fake samples.

    Returns:
        torch.Tensor: The KL divergence between the two Gaussian approximations.

    Examples:
        >>> real_samples = torch.randn(100, 2)  # 100 samples, 2-dimensional
        >>> fake_samples = torch.randn(100, 2)  # 100 samples, 2-dimensional
        >>> kl_div = gaussian_kl_divergence(real_samples, fake_samples)
        >>> print(kl_div)
    """

    # check input (n,d only)
    assert len(real_samples.size()) == 2, "Real samples must be 2-dimensional, (n,d)"
    assert len(fake_samples.size()) == 2, "Fake samples must be 2-dimensional, (n,d)"

    # calculate mean and covariance of real and fake samples
    mu_real = real_samples.mean(dim=0)
    mu_fake = fake_samples.mean(dim=0)
    cov_real = torch.cov(real_samples.t())
    cov_fake = torch.cov(fake_samples.t())

    # ensure the covariance matrices are invertible
    eps = 1e-8
    cov_real += torch.eye(cov_real.size(0)) * eps
    cov_fake += torch.eye(cov_fake.size(0)) * eps

    # compute KL divergence
    inv_cov_fake = torch.inverse(cov_fake)
    kl_div = 0.5 * (
        torch.trace(inv_cov_fake @ cov_real)
        + (mu_fake - mu_real).dot(inv_cov_fake @ (mu_fake - mu_real))
        - real_samples.size(1)
        + torch.log(torch.det(cov_fake) / torch.det(cov_real))
    )

    return kl_div

Gaussian Wasserstein

gaussian_squared_w2_distance(real_samples, fake_samples, real_mu=None, real_cov=None)

Compute the squared Wasserstein distance between Gaussian approximations of real and fake samples. Dimensionality of the samples must be the same and >=2 (for covariance calculation).

In detail, for each set of samples, we calculate the mean and covariance matrix.

\[ \mu_{\text{real}} = \frac{1}{n} \sum_{i=1}^{n} x_i \qquad \mu_{\text{fake}} = \frac{1}{n} \sum_{i=1}^{n} y_i \]
\[ \Sigma_{\text{real}} = \frac{1}{n-1} \sum_{i=1}^{n} (x_i - \mu_{\text{real}})(x_i - \mu_{\text{real}})^T \qquad \Sigma_{\text{fake}} = \frac{1}{n-1} \sum_{i=1}^{n} (y_i - \mu_{\text{fake}})(y_i - \mu_{\text{fake}})^T \]

Then we calculate the squared Wasserstein distance between the two Gaussian approximations:

\[ d_{W_2}^2(N(\mu_{\text{real}}, \Sigma_{\text{real}}), N(\mu_{\text{fake}}, \Sigma_{\text{fake}})) = \left\| \mu_{\text{real}} - \mu_{\text{fake}} \right\|^2 + \text{tr}(\Sigma_{\text{real}} + \Sigma_{\text{fake}} - 2 \sqrt{\Sigma_{\text{real}} \Sigma_{\text{fake}}}) \]

Parameters:

  real_samples (Tensor): A tensor representing the real samples. Required.
  fake_samples (Tensor): A tensor representing the fake samples. Required.
  real_mu (Tensor, optional): Precomputed mean of the real samples; computed from real_samples if None. Defaults to None.
  real_cov (Tensor, optional): Precomputed covariance of the real samples; computed from real_samples if None. Defaults to None.

Returns:

  torch.Tensor: The squared Wasserstein distance between the two Gaussian approximations.

References

[1] https://en.wikipedia.org/wiki/Wasserstein_metric
[2] https://arxiv.org/pdf/1706.08500.pdf

Examples:

>>> real_samples = torch.randn(100, 2)  # 100 samples, 2-dimensional
>>> fake_samples = torch.randn(100, 2)  # 100 samples, 2-dimensional
>>> w2 = gaussian_squared_w2_distance(real_samples, fake_samples)
>>> print(w2)
Source code in labproject/metrics/gaussian_squared_wasserstein.py
@register_metric("wasserstein_gauss_squared")
def gaussian_squared_w2_distance(
    real_samples: Tensor, fake_samples: Tensor, real_mu=None, real_cov=None
) -> Tensor:
    r"""
    Compute the squared Wasserstein distance between Gaussian approximations of real and fake samples.
    Dimensionality of the samples must be the same and >=2 (for covariance calculation).

    In detail, for each set of samples, we calculate the mean and covariance matrix.

    $$ \mu_{\text{real}} = \frac{1}{n} \sum_{i=1}^{n} x_i \qquad \mu_{\text{fake}} = \frac{1}{n} \sum_{i=1}^{n} y_i $$


    $$
    \Sigma_{\text{real}} = \frac{1}{n-1} \sum_{i=1}^{n} (x_i - \mu_{\text{real}})(x_i - \mu_{\text{real}})^T \qquad
    \Sigma_{\text{fake}} = \frac{1}{n-1} \sum_{i=1}^{n} (y_i - \mu_{\text{fake}})(y_i - \mu_{\text{fake}})^T
    $$

    Then we calculate the squared Wasserstein distance between the two Gaussian approximations:

    $$
    d_{W_2}^2(N(\mu_{\text{real}}, \Sigma_{\text{real}}), N(\mu_{\text{fake}}, \Sigma_{\text{fake}})) =
    \left\| \mu_{\text{real}} - \mu_{\text{fake}} \right\|^2 + \text{tr}(\Sigma_{\text{real}} + \Sigma_{\text{fake}} - 2 \sqrt{\Sigma_{\text{real}} \Sigma_{\text{fake}}})
    $$

    Args:
        real_samples (torch.Tensor): A tensor representing the real samples.
        fake_samples (torch.Tensor): A tensor representing the fake samples.
        real_mu (torch.Tensor, optional): Precomputed mean of the real samples. Defaults to None.
        real_cov (torch.Tensor, optional): Precomputed covariance of the real samples. Defaults to None.

    Returns:
        torch.Tensor: The squared Wasserstein distance between the two Gaussian approximations.

    References:
        [1] https://en.wikipedia.org/wiki/Wasserstein_metric
        [2] https://arxiv.org/pdf/1706.08500.pdf

    Examples:
        >>> real_samples = torch.randn(100, 2)  # 100 samples, 2-dimensional
        >>> fake_samples = torch.randn(100, 2)  # 100 samples, 2-dimensional
        >>> w2 = gaussian_squared_w2_distance(real_samples, fake_samples)
        >>> print(w2)
    """

    # check input (n,d only)
    assert len(real_samples.size()) == 2, "Real samples must be 2-dimensional, (n,d)"
    assert len(fake_samples.size()) == 2, "Fake samples must be 2-dimensional, (n,d)"

    if real_samples.shape[-1] == 1:
        mu_real = real_samples.mean(dim=0)
        var_real = real_samples.var(dim=0)

        mu_fake = fake_samples.mean(dim=0)
        var_fake = fake_samples.var(dim=0)

        w2_squared_dist = (mu_real - mu_fake) ** 2 + (
            var_real + var_fake - 2 * (var_real * var_fake).sqrt()
        )

        return w2_squared_dist
    else:
        # calculate mean and covariance of real and fake samples
        if real_mu is None:
            mu_real = real_samples.mean(dim=0)
        else:
            mu_real = real_mu
        if real_cov is None:
            cov_real = torch.cov(real_samples.t())
        else:
            cov_real = real_cov

        mu_fake = fake_samples.mean(dim=0)
        cov_fake = torch.cov(fake_samples.t())

        # ensure the covariance matrices are invertible
        eps = 1e-6
        cov_real += torch.eye(cov_real.size(0)) * eps
        cov_fake += torch.eye(cov_fake.size(0)) * eps

        # compute squared Wasserstein distance
        mean_dist = torch.norm(mu_real - mu_fake, p=2)
        cov_sqrt = scipy.linalg.sqrtm((cov_real @ cov_fake).numpy())
        # print(cov_sqrt.real)
        cov_sqrt = torch.from_numpy(cov_sqrt.real)
        cov_dist = torch.trace(cov_real + cov_fake - 2 * cov_sqrt)
        w2_squared_dist = mean_dist**2 + cov_dist

        return w2_squared_dist

Sliced Wasserstein

rand_projections(embedding_dim, num_samples)

This function generates num_samples random samples from the latent space's unit sphere.

Parameters:

  embedding_dim (int): dimension of the embedding. Required.
  num_samples (int): number of samples. Required.

Returns:

  torch.Tensor: tensor of size (num_samples, embedding_dim)

Source code in labproject/metrics/sliced_wasserstein.py
def rand_projections(embedding_dim: int, num_samples: int):
    """
    This function generates num_samples random samples from the latent space's unit sphere.

    Args:
        embedding_dim (int): dimension of the embedding
        num_samples (int): number of samples

    Returns:
        torch.Tensor: tensor of size (num_samples, embedding_dim)
    """

    ws = torch.randn((num_samples, embedding_dim))
    projection = ws / torch.norm(ws, dim=-1, keepdim=True)
    return projection
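
A short usage sketch (the sizes are arbitrary); each row of the returned tensor has unit norm:

>>> proj = rand_projections(embedding_dim=8, num_samples=50)
>>> proj.shape
torch.Size([50, 8])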

sliced_wasserstein_distance(encoded_samples, distribution_samples, num_projections=50, p=2, device='cpu')

Sliced Wasserstein distance between encoded samples and distribution samples. Note that the SWD does not converge to the true Wasserstein distance, but rather it is a different proper distance metric.

Parameters:

  encoded_samples (Tensor): tensor of encoded training samples. Required.
  distribution_samples (Tensor): tensor drawn from the prior distribution. Required.
  num_projections (int): number of projections used to approximate the sliced Wasserstein distance. Defaults to 50.
  p (int): power of the distance metric. Defaults to 2.
  device (str): torch device, 'cpu' or 'cuda'. Defaults to 'cpu'.

Returns:

  torch.Tensor: the sliced Wasserstein distance as a scalar tensor.

Source code in labproject/metrics/sliced_wasserstein.py
@register_metric("sliced_wasserstein")
def sliced_wasserstein_distance(
    encoded_samples: Tensor,
    distribution_samples: Tensor,
    num_projections: int = 50,
    p: int = 2,
    device: str = "cpu",
) -> Tensor:
    """
    Sliced Wasserstein distance between encoded samples and distribution samples.
    Note that the SWD does not converge to the true Wasserstein distance, but rather it is a different proper distance metric.

    Args:
        encoded_samples (torch.Tensor): tensor of encoded training samples
        distribution_samples (torch.Tensor): tensor drawn from the prior distribution
        num_projections (int): number of projections to approximate the sliced Wasserstein distance
        p (int): power of the distance metric
        device (str): torch device, 'cpu' or 'cuda'

    Returns:
        torch.Tensor: the sliced Wasserstein distance as a scalar tensor
    """

    # check input (n,d only)
    assert len(encoded_samples.size()) == 2, "Encoded samples must be 2-dimensional, (n,d)"
    assert len(distribution_samples.size()) == 2, "Distribution samples must be 2-dimensional, (n,d)"

    embedding_dim = distribution_samples.size(-1)

    projections = rand_projections(embedding_dim, num_projections).to(device)

    encoded_projections = encoded_samples.matmul(projections.transpose(-2, -1))

    distribution_projections = distribution_samples.matmul(projections.transpose(-2, -1))

    wasserstein_distance = (
        torch.sort(encoded_projections.transpose(-2, -1), dim=-1)[0]
        - torch.sort(distribution_projections.transpose(-2, -1), dim=-1)[0]
    )

    wasserstein_distance = torch.pow(torch.abs(wasserstein_distance), p)

    return torch.pow(torch.mean(wasserstein_distance, dim=(-2, -1)), 1 / p)
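
A usage sketch in the style of the other metric examples (sample sizes and dimensionality are arbitrary):

>>> import torch
>>> encoded_samples = torch.randn(100, 2)
>>> distribution_samples = torch.randn(100, 2)
>>> swd = sliced_wasserstein_distance(encoded_samples, distribution_samples, num_projections=50, p=2)
>>> print(swd)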

Main Modules

Data

download_file(remote_path, local_path)

Downloads a file from the Hetzner Storage Box.

Parameters:

  remote_path (str): The path to the remote file to be downloaded. Required.
  local_path (str): The path where the file should be saved locally. Required.

Returns:

  bool: True if the download is successful, False otherwise.

Example

  if download_file('path/to/remote/file.txt', 'path/to/save/file.txt'):
      print("Download successful")
  else:
      print("Download failed")

Source code in labproject/data.py
def download_file(remote_path, local_path):
    r"""
    Downloads a file from the Hetzner Storage Box.

    Args:
        remote_path (str): The path to the remote file to be downloaded.
        local_path (str): The path where the file should be saved locally.

    Returns:
        bool: True if the download is successful, False otherwise.

    Example:
        >>> if download_file('path/to/remote/file.txt', 'path/to/save/file.txt'):
        >>>     print("Download successful")
        >>> else:
        >>>     print("Download failed")
    """
    url = f"{STORAGEBOX_URL}/remote.php/dav/files/{HETZNER_STORAGEBOX_USERNAME}/{remote_path}"
    auth = HTTPBasicAuth(HETZNER_STORAGEBOX_USERNAME, HETZNER_STORAGEBOX_PASSWORD)
    response = requests.get(url, auth=auth)
    if response.status_code == 200:
        with open(local_path, "wb") as f:
            f.write(response.content)
        return True
    return False

get_dataset(name)

Get a dataset by name

Parameters:

  name (str): Name of the dataset. Required.

Returns:

  callable: Registered dataset generator function; call it with (n, d) to obtain a torch.Tensor of shape (n, d).

Source code in labproject/data.py
def get_dataset(name: str) -> callable:
    r"""Get a dataset by name

    Args:
        name (str): Name of the dataset

    Returns:
        callable: Registered dataset generator function; call it with (n, d) to obtain a torch.Tensor
    """
    assert name in DATASETS, f"Dataset {name} not found, please register it first "
    return DATASETS[name]
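
A hedged usage sketch, assuming a dataset generator has been registered under the name "random" as in the register_dataset example further below:

>>> random_dataset = get_dataset("random")
>>> samples = random_dataset(n=1000, d=10)
>>> samples.shape
torch.Size([1000, 10])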

get_distribution(name)

Get a distribution by name

Parameters:

  name (str): Name of the distribution. Required.

Returns:

  callable: Registered distribution generator function.

Source code in labproject/data.py
def get_distribution(name: str) -> callable:
    r"""Get a distribution by name

    Args:
        name (str): Name of the distribution

    Returns:
        callable: Registered distribution generator function
    """
    assert name in DISTRIBUTIONS, f"Distribution {name} not found, please register it first "
    return DISTRIBUTIONS[name]

imagenet_conditional_model(n, d=2048, label=None, device='cpu', permute_if_no_label=True, save_path='data')

Get the conditional model embeddings for ImageNet

Parameters:

  n (int): Number of samples. Required.
  d (int): Dimensionality of the embeddings. Defaults to 2048.
  label (int): Label, if None it takes random samples. Defaults to None.
  device (str): Device. Defaults to "cpu".

Returns:

  torch.Tensor: ImageNet embeddings

Source code in labproject/data.py
@register_dataset("imagenet_conditional_model")
def imagenet_conditional_model(
    n, d=2048, label: Optional[int] = None, device="cpu", permute_if_no_label=True, save_path="data"
):
    r"""Get the conditional model embeddings for ImageNet

    Args:
        n (int): Number of samples
        d (int, optional): Dimensionality of the embeddings. Defaults to 2048.
        label (int, optional): Label, if None it takes random samples. Defaults to None.
        device (str, optional): Device. Defaults to "cpu".

    Returns:
        torch.Tensor: ImageNet embeddings
    """
    assert d == 2048, "The dimensionality of the embeddings must be 2048"
    if not os.path.exists("imagenet_conditional_model.pt"):
        import gdown

        gdown.download(IMAGENET_CONDITIONAL_MODEL, "imagenet_conditional_model.pt", quiet=False)
    conditional_embeddings = torch.load("imagenet_conditional_model.pt")

    if label is not None:
        conditional_embeddings = conditional_embeddings[label]
    else:
        conditional_embeddings = conditional_embeddings.flatten(0, 1)
        if permute_if_no_label:
            conditional_embeddings = conditional_embeddings[
                torch.randperm(conditional_embeddings.shape[0])
            ]

    max_n = conditional_embeddings.shape[0]

    assert n <= max_n, f"Requested {n} samples, but only {max_n} are available"

    return conditional_embeddings[:n]
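
A hedged usage sketch (the label index is illustrative and assumes at least 1000 embeddings are stored for that label; the embedding file is downloaded via gdown on first use):

>>> embeddings = imagenet_conditional_model(1000, d=2048, label=3)
>>> embeddings.shape
torch.Size([1000, 2048])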

imagenet_test_embedding(n, d=2048, device='cpu', save_path='data')

Get the test embeddings for ImageNet

Parameters:

  n (int): Number of samples. Required.
  d (int): Dimensionality of the embeddings. Defaults to 2048.
  device (str): Device. Defaults to "cpu".

Returns:

  torch.Tensor: ImageNet embeddings

Source code in labproject/data.py
@register_dataset("imagenet_test_embedding")
def imagenet_test_embedding(n, d=2048, device="cpu", save_path="data"):
    r"""Get the test embeddings for ImageNet

    Args:
        n (int): Number of samples
        d (int, optional): Dimensionality of the embeddings. Defaults to 2048.
        device (str, optional): Device. Defaults to "cpu".

    Returns:
        torch.Tensor: ImageNet embeddings
    """
    assert d == 2048, "The dimensionality of the embeddings must be 2048"
    if not os.path.exists("imagenet_test_embedding.pt"):
        import gdown

        gdown.download(IMAGENET_TEST_EMBEDDING, "imagenet_test_embedding.pt", quiet=False)
    test_embeddings = torch.load("imagenet_test_embedding.pt")

    max_n = test_embeddings.shape[0]

    assert n <= max_n, f"Requested {n} samples, but only {max_n} are available"

    return test_embeddings[:n]

imagenet_unconditional_model_embedding(n, d=2048, device='cpu', save_path='data')

Get the unconditional model embeddings for ImageNet

Parameters:

  n (int): Number of samples. Required.
  d (int): Dimensionality of the embeddings. Defaults to 2048.
  device (str): Device. Defaults to "cpu".

Returns:

  torch.Tensor: ImageNet embeddings

Source code in labproject/data.py
@register_dataset("imagenet_unconditional_model_embedding")
def imagenet_unconditional_model_embedding(n, d=2048, device="cpu", save_path="data"):
    r"""Get the unconditional model embeddings for ImageNet

    Args:
        n (int): Number of samples
        d (int, optional): Dimensionality of the embeddings. Defaults to 2048.
        device (str, optional): Device. Defaults to "cpu".

    Returns:
        torch.Tensor: ImageNet embeddings
    """
    assert d == 2048, "The dimensionality of the embeddings must be 2048"
    if not os.path.exists("imagenet_unconditional_model_embedding.pt"):
        import gdown

        gdown.download(
            IMAGENET_UNCONDITIONAL_MODEL_EMBEDDING,
            "imagenet_unconditional_model_embedding.pt",
            quiet=False,
        )
    unconditional_embeddings = torch.load("imagenet_unconditional_model_embedding.pt")

    max_n = unconditional_embeddings.shape[0]

    assert n <= max_n, f"Requested {n} samples, but only {max_n} are available"

    return unconditional_embeddings[:n]

imagenet_validation_embedding(n, d=2048, device='cpu', save_path='data')

Get the validation embeddings for ImageNet

Parameters:

  n (int): Number of samples. Required.
  d (int): Dimensionality of the embeddings. Defaults to 2048.
  device (str): Device. Defaults to "cpu".

Returns:

  torch.Tensor: ImageNet embeddings

Source code in labproject/data.py
@register_dataset("imagenet_validation_embedding")
def imagenet_validation_embedding(n, d=2048, device="cpu", save_path="data"):
    r"""Get the validation embeddings for ImageNet

    Args:
        n (int): Number of samples
        d (int, optional): Dimensionality of the embeddings. Defaults to 2048.
        device (str, optional): Device. Defaults to "cpu".

    Returns:
        torch.Tensor: ImageNet embeddings
    """
    assert d == 2048, "The dimensionality of the embeddings must be 2048"
    if not os.path.exists("imagenet_validation_embedding.pt"):
        import gdown

        gdown.download(
            IMAGENET_VALIDATION_EMBEDDING, "imagenet_validation_embedding.pt", quiet=False
        )
    validation_embeddings = torch.load("imagenet_validation_embedding.pt")

    max_n = validation_embeddings.shape[0]

    assert n <= max_n, f"Requested {n} samples, but only {max_n} are available"

    return validation_embeddings[:n]

load_cifar10(n, save_path='data', train=True, batch_size=100, shuffle=False, num_workers=1, device='cpu', return_labels=False)

Load a subset of cifar10

Parameters:

  n (int): Number of samples to load. Required.
  save_path (str): Path to save files. Defaults to "data".
  train (bool): Train or test. Defaults to True.
  batch_size (int): Batch size. Defaults to 100.
  shuffle (bool): Shuffle. Defaults to False.
  num_workers (int): Parallel workers. Defaults to 1.
  device (str): Device. Defaults to "cpu".
  return_labels (bool): If True, also return the labels. Defaults to False.

Returns:

  torch.Tensor: Cifar10 embeddings (and labels if return_labels is True)

Source code in labproject/data.py
def load_cifar10(
    n: int,
    save_path="data",
    train=True,
    batch_size=100,
    shuffle=False,
    num_workers=1,
    device="cpu",
    return_labels=False,
) -> torch.Tensor:
    """Load a subset of cifar10

    Args:
        n (int): Number of samples to load
        save_path (str, optional): Path to save files. Defaults to "data".
        train (bool, optional): Train or test. Defaults to True.
        batch_size (int, optional): Batch size. Defaults to 100.
        shuffle (bool, optional): Shuffle. Defaults to False.
        num_workers (int, optional): Parallel workers. Defaults to 1.
        device (str, optional): Device. Defaults to "cpu".
        return_labels (bool, optional): If True, also return the labels. Defaults to False.

    Returns:
        torch.Tensor: Cifar10 embeddings
    """
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
        ]
    )
    cifar10 = CIFAR10(root=save_path, train=train, download=True, transform=transform)
    dataloader = torch.utils.data.DataLoader(
        cifar10, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers
    )
    dataset_subset = Subset(dataloader.dataset, range(n))
    dataloader = DataLoader(
        dataset_subset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers
    )
    net = FIDEmbeddingNet(device=device)
    if return_labels:
        embeddings, labels = net.get_embeddings_with_labels(dataloader)
        return embeddings, labels
    embeddings = net.get_embeddings(dataloader)
    return embeddings
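
A hedged usage sketch (CIFAR10 is downloaded on first use and passed through the FID embedding net, so the first call can take a while):

>>> embeddings = load_cifar10(500, save_path="data", device="cpu")
>>> embeddings.shape[0]
500
>>> embeddings, labels = load_cifar10(500, return_labels=True)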

register_dataset(name)

This decorator wraps a function that should return a dataset and ensures that the dataset is a PyTorch tensor with the correct shape.

Parameters:

  name (str): Name under which the dataset generator is registered. Required.

Returns:

  callable: Dataset generator function wrapper

Example

  @register_dataset("random")
  def random_dataset(n=1000, d=10):
      return torch.randn(n, d)

Source code in labproject/data.py
def register_dataset(name: str) -> callable:
    r"""This decorator wrapps a function that should return a dataset and ensures that the dataset is a PyTorch tensor, with the correct shape.

    Args:
        func (callable): Dataset generator function

    Returns:
        callable: Dataset generator function wrapper

    Example:
        >>> @register_dataset("random")
        >>> def random_dataset(n=1000, d=10):
        >>>     return torch.randn(n, d)
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(n: int, d: Optional[int] = None, **kwargs):

            assert n > 0, "n must be a positive integer"
            if d is not None:
                assert d > 0, "d must be a positive integer"
            else:
                warnings.warn("d is not specified, make sure you know what you're doing!")

            # Call the original function
            if d is not None:
                dataset = func(n, d, **kwargs)
            else:
                dataset = func(n, **kwargs)
            if isinstance(dataset, tuple):
                dataset = tuple(
                    torch.Tensor(data) if not isinstance(data, torch.Tensor) else data
                    for data in dataset
                )
            else:
                dataset = (
                    torch.Tensor(dataset) if not isinstance(dataset, torch.Tensor) else dataset
                )
            if d is not None:
                assert dataset.shape == (n, d), f"Dataset shape must be {(n, d)}"
            else:
                assert dataset.shape[0] == n, f"Dataset shape must be {(n, ...)}"

            return dataset

        DATASETS[name] = wrapper
        return wrapper

    return decorator
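
A hedged sketch of what the registered wrapper enforces (the dataset name "toy_normal" is illustrative):

>>> import torch
>>> @register_dataset("toy_normal")
>>> def toy_normal(n=1000, d=10):
>>>     return torch.randn(n, d)
>>> samples = toy_normal(100, 5)  # the wrapper checks n > 0, d > 0 and the (n, d) shape
>>> samples.shape
torch.Size([100, 5])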

register_distribution(name)

This decorator wraps a function that should return a distribution and registers it under the given name.

Parameters:

  name (str): Name under which the distribution generator is registered. Required.

Returns:

  callable: Distribution generator function wrapper

Example

  @register_distribution("random")
  def random_distribution(n=1000, d=10):
      return torch.randn(n, d)

Source code in labproject/data.py
def register_distribution(name: str) -> callable:
    r"""This decorator wrapps a function that should return a dataset and ensures that the dataset is a PyTorch tensor, with the correct shape.

    Args:
        func (callable): Dataset generator function

    Returns:
        callable: Dataset generator function wrapper

    Example:
        >>> @register_dataset("random")
        >>> def random_dataset(n=1000, d=10):
        >>>     return torch.randn(n, d)
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Call the original function
            distribution = func(*args, **kwargs)
            return distribution

        DISTRIBUTIONS[name] = wrapper
        return wrapper

    return decorator

upload_file(local_path, remote_path)

Uploads a file to the Hetzner Storage Box.

Parameters:

  local_path (str): The path to the local file to be uploaded. Required.
  remote_path (str): The path where the file should be uploaded on the remote server. Required.

Returns:

  bool: True if the upload is successful, False otherwise.

Example

  if upload_file('path/to/your/local/file.txt', 'path/to/remote/file.txt'):
      print("Upload successful")
  else:
      print("Upload failed")

Source code in labproject/data.py
def upload_file(local_path: str, remote_path: str):
    r"""
    Uploads a file to the Hetzner Storage Box.

    Args:
        local_path (str): The path to the local file to be uploaded.
        remote_path (str): The path where the file should be uploaded on the remote server.

    Returns:
        bool: True if the upload is successful, False otherwise.

    Example:
        >>> if upload_file('path/to/your/local/file.txt', 'path/to/remote/file.txt'):
        >>>     print("Upload successful")
        >>> else:
        >>>     print("Upload failed")
    """
    url = f"{STORAGEBOX_URL}/remote.php/dav/files/{HETZNER_STORAGEBOX_USERNAME}/{remote_path}"
    auth = HTTPBasicAuth(HETZNER_STORAGEBOX_USERNAME, HETZNER_STORAGEBOX_PASSWORD)
    with open(local_path, "rb") as f:
        data = f.read()
    response = requests.put(url, data=data, auth=auth)
    return response.status_code == 201

Embeddings

This contains embedding nets and auxiliary functions for extracting (N,D) embeddings from the respective data and models.

Experiments

ScaleDim

Bases: Experiment

Source code in labproject/experiments.py
class ScaleDim(Experiment):
    def __init__(self, metric_name, metric_fn, dim_sizes=None, min_dim=1, max_dim=1000, step=100):
        self.metric_name = metric_name
        self.metric_fn = metric_fn
        if dim_sizes is not None:
            self.dim_sizes = dim_sizes
        else:
            self.dim_sizes = list(range(min_dim, max_dim, step))
        super().__init__()

    def run_experiment(self, dataset1, dataset2, dataset_size, nb_runs=5, dim_sizes=None, **kwargs):
        final_distances = []
        final_errors = []
        n = dataset_size
        if dim_sizes is None:
            dim_sizes = self.dim_sizes
        for idx in range(nb_runs):
            distances = []
            for d in dim_sizes:
                # 3000 x 100
                data1 = dataset1[torch.randperm(dataset1.size(0))[:n], :d]
                data2 = dataset2[
                    torch.randperm(dataset2.size(0))[:n], :d
                ]  # AS: changed from dataset1 to dataset2 in randperm
                distances.append(self.metric_fn(data1, data2, **kwargs))
            final_distances.append(distances)
        final_distances = torch.transpose(torch.tensor(final_distances), 0, 1)
        final_errors = (
            torch.tensor([torch.std(d) for d in final_distances])
            if nb_runs > 1
            else torch.zeros_like(torch.tensor(dim_sizes))
        )
        final_distances = torch.tensor([torch.mean(d) for d in final_distances])
        return dim_sizes, final_distances, final_errors

    def plot_experiment(
        self,
        dim_sizes,
        distances,
        errors,
        dataset_name,
        ax=None,
        color=None,
        label=None,
        linestyle="-",
        **kwargs,
    ):

        plot_scaling_metric_dimensionality(
            dim_sizes,
            distances,
            errors,
            self.metric_name,
            dataset_name,
            ax=ax,
            color=color,
            label=label,
            linestyle=linestyle,
            **kwargs,
        )

    def log_results(self, results, log_path):
        """
        Save the results to a file.
        """
        with open(log_path, "wb") as f:
            pickle.dump(results, f)
log_results(results, log_path)

Save the results to a file.

Source code in labproject/experiments.py
def log_results(self, results, log_path):
    """
    Save the results to a file.
    """
    with open(log_path, "wb") as f:
        pickle.dump(results, f)
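
A hedged usage sketch for ScaleDim (assuming sliced_wasserstein_distance is imported from labproject.metrics.sliced_wasserstein; the tensors and sizes are illustrative, and the datasets must have at least max_dim columns):

>>> import torch
>>> exp = ScaleDim("Sliced Wasserstein", sliced_wasserstein_distance, min_dim=2, max_dim=100, step=10)
>>> dataset1, dataset2 = torch.randn(3000, 100), torch.randn(3000, 100)
>>> dim_sizes, distances, errors = exp.run_experiment(dataset1, dataset2, dataset_size=1000, nb_runs=5)
>>> exp.log_results((dim_sizes, distances, errors), "results.pkl")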

ScaleHyperparameter

Bases: Experiment

Source code in labproject/experiments.py
class ScaleHyperparameter(Experiment):
    def __init__(
        self, metric_name, metric_fn, value_sizes=None, min_value=0.2, max_value=50, step=10
    ):
        self.metric_name = metric_name
        self.metric_fn = metric_fn
        if value_sizes is not None:
            self.value_sizes = value_sizes
        else:
            self.value_sizes = list(np.linspace(min_value, max_value, step))
        super().__init__()

    def run_experiment(self, dataset1, dataset2, nb_runs=5, n=10000, value_sizes=None, **kwargs):
        final_distances = []
        final_errors = []
        # n = 1000 # AS: turned into argument
        if value_sizes is None:
            value_sizes = self.value_sizes
            # print(value_sizes)
        for idx in range(nb_runs):
            distances = []
            for v in value_sizes:
                # print(v)
                # 3000 x 100
                data1 = dataset1[torch.randperm(dataset1.size(0))[:n], :]
                data2 = dataset2[
                    torch.randperm(dataset2.size(0))[:n], :
                ]  # AS: changed from dataset1 to dataset2 in randperm
                distances.append(self.metric_fn(data1, data2, v, **kwargs))

            final_distances.append(distances)

        final_distances = torch.transpose(torch.tensor(final_distances), 0, 1)
        final_errors = (
            torch.tensor([torch.std(d) for d in final_distances])
            if nb_runs > 1
            else torch.zeros_like(torch.tensor(value_sizes))
        )
        final_distances = torch.tensor([torch.mean(d) for d in final_distances])
        return value_sizes, final_distances, final_errors

    def plot_experiment(
        self,
        value_sizes,
        distances,
        errors,
        dataset_name,
        ax=None,
        color=None,
        label=None,
        linestyle="-",
        **kwargs,
    ):

        plot_scaling_metric_dimensionality(
            value_sizes,
            distances,
            errors,
            self.metric_name,
            dataset_name,
            ax=ax,
            color=color,
            label=label,
            linestyle=linestyle,
            **kwargs,
        )

    def log_results(self, results, log_path):
        """
        Save the results to a file.
        """
        with open(log_path, "wb") as f:
            pickle.dump(results, f)
log_results(results, log_path)

Save the results to a file.

Source code in labproject/experiments.py
def log_results(self, results, log_path):
    """
    Save the results to a file.
    """
    with open(log_path, "wb") as f:
        pickle.dump(results, f)

ScaleSampleSize

Bases: Experiment

Source code in labproject/experiments.py
class ScaleSampleSize(Experiment):

    def __init__(
        self, metric_name, metric_fn, min_samples=3, max_samples=2000, step=100, sample_sizes=None
    ):
        assert min_samples > 2, "min_samples must be greater than 2 to compute covariance for KL"
        self.metric_name = metric_name
        self.metric_fn = metric_fn
        # TODO: add logarithmic scale or only keep pass in run experiment
        if sample_sizes is not None:
            self.sample_sizes = sample_sizes
        else:
            self.sample_sizes = list(range(min_samples, max_samples, step))
        super().__init__()

    def run_experiment(self, dataset1, dataset2, nb_runs=5, sample_sizes=None, **kwargs):
        """
        Computes for each subset 5 different random subsets and averages performance across the subsets.
        """
        final_distances = []
        final_errors = []
        if sample_sizes is None:
            sample_sizes = self.sample_sizes
        for idx in range(nb_runs):
            distances = []
            for n in sample_sizes:
                data1 = dataset1[torch.randperm(dataset1.size(0))[:n], :]
                data2 = dataset2[torch.randperm(dataset2.size(0))[:n], :]
                distances.append(self.metric_fn(data1, data2, **kwargs))
            final_distances.append(distances)
        final_distances = torch.transpose(torch.tensor(final_distances), 0, 1)
        final_errors = (
            torch.tensor([torch.std(d) for d in final_distances])
            if nb_runs > 1
            else torch.zeros_like(torch.tensor(sample_sizes))
        )
        final_distances = torch.tensor([torch.mean(d) for d in final_distances])

        return sample_sizes, final_distances, final_errors

    def plot_experiment(
        self,
        sample_sizes,
        distances,
        errors,
        dataset_name,
        ax=None,
        color=None,
        label=None,
        linestyle="-",
        **kwargs,
    ):
        plot_scaling_metric_sample_size(
            sample_sizes,
            distances,
            errors,
            self.metric_name,
            dataset_name,
            ax=ax,
            color=color,
            label=label,
            linestyle=linestyle,
            **kwargs,
        )

    def log_results(self, results, log_path):
        """
        Save the results to a file.
        """
        with open(log_path, "wb") as f:
            pickle.dump(results, f)
log_results(results, log_path)

Save the results to a file.

Source code in labproject/experiments.py
def log_results(self, results, log_path):
    """
    Save the results to a file.
    """
    with open(log_path, "wb") as f:
        pickle.dump(results, f)
run_experiment(dataset1, dataset2, nb_runs=5, sample_sizes=None, **kwargs)

Computes for each subset 5 different random subsets and averages performance across the subsets.

Source code in labproject/experiments.py
def run_experiment(self, dataset1, dataset2, nb_runs=5, sample_sizes=None, **kwargs):
    """
    Computes for each subset 5 different random subsets and averages performance across the subsets.
    """
    final_distances = []
    final_errors = []
    if sample_sizes is None:
        sample_sizes = self.sample_sizes
    for idx in range(nb_runs):
        distances = []
        for n in sample_sizes:
            data1 = dataset1[torch.randperm(dataset1.size(0))[:n], :]
            data2 = dataset2[torch.randperm(dataset2.size(0))[:n], :]
            distances.append(self.metric_fn(data1, data2, **kwargs))
        final_distances.append(distances)
    final_distances = torch.transpose(torch.tensor(final_distances), 0, 1)
    final_errors = (
        torch.tensor([torch.std(d) for d in final_distances])
        if nb_runs > 1
        else torch.zeros_like(torch.tensor(sample_sizes))
    )
    final_distances = torch.tensor([torch.mean(d) for d in final_distances])

    return sample_sizes, final_distances, final_errors

Plotting

place_boxplot(ax, x, y, body_face_color='#8189c9', body_edge_color='k', body_lw=0.25, body_alpha=1.0, body_zorder=0, whisker_color='k', whisker_alpha=1.0, whisker_lw=1, whisker_zorder=1, cap_color='k', cap_lw=0.25, cap_zorder=1, median_color='k', median_alpha=1.0, median_lw=1.5, median_bar_length=1.0, median_zorder=10, width=0.5, scatter_face_color='k', scatter_edge_color='none', scatter_radius=5, scatter_lw=0.25, scatter_alpha=0.35, scatter_width=0.5, scatter=True, scatter_zorder=3, fill_box=True, showcaps=False, showfliers=False, whis=(0, 100), vert=True)

Example

  X = [1, 2]
  Y = [np.random.normal(0.75, 0.12, size=50), np.random.normal(0.8, 0.20, size=25)]
  fig, ax = plt.subplots(figsize=[1, 1])
  for (x, y) in zip(X, Y):
      place_boxplot(ax, x, y)

Source code in labproject/plotting.py
def place_boxplot(
    ax,
    x,
    y,
    body_face_color="#8189c9",
    body_edge_color="k",
    body_lw=0.25,
    body_alpha=1.0,
    body_zorder=0,
    whisker_color="k",
    whisker_alpha=1.0,
    whisker_lw=1,
    whisker_zorder=1,
    cap_color="k",
    cap_lw=0.25,
    cap_zorder=1,
    median_color="k",
    median_alpha=1.0,
    median_lw=1.5,
    median_bar_length=1.0,
    median_zorder=10,
    width=0.5,
    scatter_face_color="k",
    scatter_edge_color="none",
    scatter_radius=5,
    scatter_lw=0.25,
    scatter_alpha=0.35,
    scatter_width=0.5,
    scatter=True,
    scatter_zorder=3,
    fill_box=True,
    showcaps=False,
    showfliers=False,
    whis=(0, 100),
    vert=True,
):
    """
    Example:
        X = [1, 2]
        Y = [np.random.normal(0.75, 0.12, size=50), np.random.normal(0.8, 0.20, size=25)]
        fig, ax = plt.subplots(figsize=[1, 1])
        for (x, y) in zip(X, Y):
            place_boxplot(ax, x, y)
    """
    parts = ax.boxplot(
        y,
        positions=[x],
        widths=width,
        showcaps=showcaps,
        showfliers=showfliers,
        whis=whis,
        vert=vert,
    )

    # polish the body
    b = parts["boxes"][0]
    b.set_color(body_edge_color)
    b.set_alpha(body_alpha)
    b.set_linewidth(body_lw)
    b.set_zorder(body_zorder)
    if fill_box:
        if vert:
            x0, x1 = b.get_xdata()[:2]
            y0, y1 = b.get_ydata()[1:3]
            r = Rectangle(
                [x0, y0],
                x1 - x0,
                y1 - y0,
                facecolor=body_face_color,
                alpha=body_alpha,
                edgecolor="none",
            )
            ax.add_patch(r)
        else:
            x0, x1 = b.get_xdata()[1:3]
            y0, y1 = b.get_ydata()[:2]
            r = Rectangle(
                [x0, y0],
                x1 - x0,
                y1 - y0,
                facecolor=body_face_color,
                alpha=body_alpha,
                edgecolor="none",
            )
            ax.add_patch(r)

    # polish the whiskers
    for w in parts["whiskers"]:
        w.set_color(whisker_color)
        w.set_alpha(whisker_alpha)
        w.set_linewidth(whisker_lw)
        w.set_zorder(whisker_zorder)

    # polish the caps
    for c in parts["caps"]:
        c.set_color(cap_color)
        c.set_linewidth(cap_lw)
        c.set_zorder(cap_zorder)

    # polish the median
    m = parts["medians"][0]
    m.set_color(median_color)
    m.set_linewidth(median_lw)
    m.set_alpha(median_alpha)
    m.set_zorder(median_zorder)
    if median_bar_length is not None:
        if vert:
            x0, x1 = m.get_xdata()
            m.set_xdata(
                [
                    x0 - 1 / 2 * (median_bar_length - 1) * (x1 - x0),
                    x1 + 1 / 2 * (median_bar_length - 1) * (x1 - x0),
                ]
            )
        else:
            y0, y1 = m.get_ydata()
            m.set_ydata(
                [
                    y0 - 1 / 2 * (median_bar_length - 1) * (y1 - y0),
                    y1 + 1 / 2 * (median_bar_length - 1) * (y1 - y0),
                ]
            )

    # scatter data
    if scatter:
        if vert:
            x0, x1 = b.get_xdata()[:2]
            ax.scatter(
                np.random.uniform(
                    x0 + 1 / 2 * (1 - scatter_width) * (x1 - x0),
                    x1 - 1 / 2 * (1 - scatter_width) * (x1 - x0),
                    size=len(y),
                ),
                y,
                facecolor=scatter_face_color,
                edgecolor=scatter_edge_color,
                s=scatter_radius,
                linewidth=scatter_lw,
                zorder=scatter_zorder,
                alpha=scatter_alpha,
            )
        else:
            y0, y1 = b.get_ydata()[:2]
            ax.scatter(
                y,
                np.random.uniform(
                    y0 + 1 / 2 * (1 - scatter_width) * (y1 - y0),
                    y1 - 1 / 2 * (1 - scatter_width) * (y1 - y0),
                    size=len(y),
                ),
                facecolor=scatter_face_color,
                edgecolor=scatter_edge_color,
                s=scatter_radius,
                linewidth=scatter_lw,
                zorder=scatter_zorder,
                alpha=scatter_alpha,
            )

place_violin(ax, x, y, body_face_color='#8189c9', body_edge_color='k', body_lw=0.25, body_alpha=1.0, body_zorder=0, whisker_color='k', whisker_alpha=1.0, whisker_lw=1, whisker_zorder=1, cap_color='k', cap_lw=0.25, cap_zorder=1, median_color='k', median_alpha=1.0, median_lw=1.5, median_bar_length=1.0, median_zorder=10, width=0.5, scatter_face_color='k', scatter_edge_color='none', scatter_radius=5, scatter_lw=0.25, scatter_alpha=0.35, scatter_width=0.5, scatter=True, scatter_zorder=3, showextrema=True, showmedians=True, showmeans=False, vert=True)

Example

  X = [1, 2]
  Y = [np.random.normal(0.75, 0.12, size=50), np.random.normal(0.8, 0.20, size=25)]
  fig, ax = plt.subplots(figsize=[1, 1])
  for (x, y) in zip(X, Y):
      place_violin(ax, x, y)

Source code in labproject/plotting.py
def place_violin(
    ax,
    x,
    y,
    body_face_color="#8189c9",
    body_edge_color="k",
    body_lw=0.25,
    body_alpha=1.0,
    body_zorder=0,
    whisker_color="k",
    whisker_alpha=1.0,
    whisker_lw=1,
    whisker_zorder=1,
    cap_color="k",
    cap_lw=0.25,
    cap_zorder=1,
    median_color="k",
    median_alpha=1.0,
    median_lw=1.5,
    median_bar_length=1.0,
    median_zorder=10,
    width=0.5,
    scatter_face_color="k",
    scatter_edge_color="none",
    scatter_radius=5,
    scatter_lw=0.25,
    scatter_alpha=0.35,
    scatter_width=0.5,
    scatter=True,
    scatter_zorder=3,
    showextrema=True,
    showmedians=True,
    showmeans=False,
    vert=True,
):
    """
    Example:
        X = [1, 2]
        Y = [np.random.normal(0.75, 0.12, size=50), np.random.normal(0.8, 0.20, size=25)]
        fig, ax = plt.subplots(figsize=[1, 1])
        for (x, y) in zip(X, Y):
            place_violin(ax, x, y)
    """
    if not np.any(y):
        return
    parts = ax.violinplot(
        y,
        positions=[x],
        widths=width,
        showmedians=showmedians,
        showmeans=showmeans,
        showextrema=showextrema,
        vert=vert,
    )
    # Color the bodies.
    b = parts["bodies"][0]
    b.set_facecolor(body_face_color)
    b.set_edgecolor(body_edge_color)
    b.set_linewidth(body_lw)
    b.set_alpha(body_alpha)
    b.set_zorder(body_zorder)

    # Color the lines.
    if showextrema:
        parts["cbars"].set_color(whisker_color)
        parts["cbars"].set_alpha(whisker_alpha)
        parts["cbars"].set_linewidth(whisker_lw)
        parts["cbars"].set_zorder(whisker_zorder)
        parts["cmaxes"].set_color(cap_color)
        parts["cmaxes"].set_linewidth(cap_lw)
        parts["cmaxes"].set_zorder(cap_zorder)
        parts["cmins"].set_color(cap_color)
        parts["cmins"].set_linewidth(cap_lw)
        parts["cmins"].set_zorder(cap_zorder)

    if showmeans:
        parts["cmeans"].set_color(median_color)
        parts["cmeans"].set_linewidth(median_lw)
        parts["cmeans"].set_alpha(median_alpha)
        parts["cmeans"].set_zorder(median_zorder)
        if median_bar_length is not None:
            if vert:
                (_, y0), (_, y1) = parts["cmeans"].get_segments()[0]
                parts["cmeans"].set_segments(
                    [
                        [
                            [x - median_bar_length * width / 2, y0],
                            [x + median_bar_length * width / 2, y1],
                        ]
                    ]
                )
            else:
                (x0, _), (x1, _) = parts["cmeans"].get_segments()[0]
                parts["cmeans"].set_segments(
                    [
                        [
                            [x0, x - median_bar_length * width / 2],
                            [x1, x + median_bar_length * width / 2],
                        ]
                    ]
                )

    if showmedians:
        parts["cmedians"].set_color(median_color)
        parts["cmedians"].set_alpha(median_alpha)
        parts["cmedians"].set_linewidth(median_lw)
        parts["cmedians"].set_zorder(median_zorder)
        if median_bar_length is not None:
            if vert:
                (_, y0), (_, y1) = parts["cmedians"].get_segments()[0]
                parts["cmedians"].set_segments(
                    [
                        [
                            [x - median_bar_length * width / 2, y0],
                            [x + median_bar_length * width / 2, y1],
                        ]
                    ]
                )
            else:
                (x0, _), (x1, _) = parts["cmedians"].get_segments()[0]
                parts["cmedians"].set_segments(
                    [
                        [
                            [x0, x - median_bar_length * width / 2],
                            [x1, x + median_bar_length * width / 2],
                        ]
                    ]
                )

    # scatter data
    if scatter:
        if vert:
            ax.scatter(
                np.random.uniform(
                    x - width / 2 + 1 / 2 * (1 - scatter_width) * width,
                    x + width / 2 - 1 / 2 * (1 - scatter_width) * width,
                    size=len(y),
                ),
                y,
                facecolor=scatter_face_color,
                edgecolor=scatter_edge_color,
                s=scatter_radius,
                linewidth=scatter_lw,
                alpha=scatter_alpha,
                zorder=scatter_zorder,
            )
        else:
            ax.scatter(
                y,
                np.random.uniform(
                    x - width / 2 + 1 / 2 * (1 - scatter_width) * width,
                    x + width / 2 - 1 / 2 * (1 - scatter_width) * width,
                    size=len(y),
                ),
                facecolor=scatter_face_color,
                edgecolor=scatter_edge_color,
                s=scatter_radius,
                linewidth=scatter_lw,
                zorder=5,
                alpha=scatter_alpha,
            )

plot_scaling_metric_dimensionality(dim_sizes, distances, errors, metric_name, dataset_name, ax=None, label=None, **kwargs)

Plot the scaling of a metric with increasing dimensionality.

Source code in labproject/plotting.py
def plot_scaling_metric_dimensionality(
    dim_sizes,
    distances,
    errors,
    metric_name,
    dataset_name,
    ax=None,
    label=None,
    **kwargs,
):
    """Plot the scaling of a metric with increasing dimensionality."""
    if ax is None:
        plt.plot(
            dim_sizes,
            distances,
            label=metric_name if label is None else label,
            **kwargs,
        )
        plt.fill_between(
            dim_sizes,
            distances - errors,
            distances + errors,
            alpha=0.2,
            color="black" if kwargs.get("color") is None else kwargs.get("color"),
        )
        plt.xlabel("Dimension")
        plt.ylabel(metric_name)
        plt.title(f"{metric_name} with increasing dimensionality size for {dataset_name}")
        plt.savefig(
            os.path.join(
                PLOT_PATH,
                f"{metric_name.lower().replace(' ', '_')}_dimensionality_size_{dataset_name.lower().replace(' ', '_')}.png",
            )
        )
        plt.close()
    else:
        ax.plot(
            dim_sizes,
            distances,
            label=metric_name if label is None else label,
            **kwargs,
        )
        ax.fill_between(
            dim_sizes,
            distances - errors,
            distances + errors,
            alpha=0.2,
            color="black" if kwargs.get("color") is None else kwargs.get("color"),
        )
        ax.set_xlabel("samples")
        ax.set_ylabel(
            metric_name, color="black" if kwargs.get("color") is None else kwargs.get("color")
        )
        return ax
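
A hedged usage sketch on an existing axis, e.g. with the output of a ScaleDim run (the numbers below are illustrative placeholders):

>>> import torch
>>> import matplotlib.pyplot as plt
>>> fig, ax = plt.subplots()
>>> dim_sizes = [2, 12, 22, 32]
>>> distances = torch.tensor([0.5, 0.8, 1.0, 1.1])
>>> errors = torch.tensor([0.05, 0.04, 0.03, 0.03])
>>> ax = plot_scaling_metric_dimensionality(dim_sizes, distances, errors, "Sliced Wasserstein", "toy data", ax=ax)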

plot_scaling_metric_sample_size(sample_size, distances, errors, metric_name, dataset_name, ax=None, label=None, **kwargs)

Plot the behavior of a metric with number of samples.

Source code in labproject/plotting.py
def plot_scaling_metric_sample_size(
    sample_size,
    distances,
    errors,
    metric_name,
    dataset_name,
    ax=None,
    label=None,
    **kwargs,
):
    """Plot the behavior of a metric with number of samples."""
    if ax is None:
        plt.plot(
            sample_size,
            distances,
            label=metric_name if label is None else label,
            **kwargs,
        )
        plt.fill_between(
            sample_size,
            distances - errors,
            distances + errors,
            alpha=0.2,
            color="black" if kwargs.get("color") is None else kwargs.get("color"),
        )
        plt.xlabel("samples")
        plt.ylabel(metric_name)
        plt.title(f"{metric_name} with increasing sample size for {dataset_name}")
        plt.savefig(
            os.path.join(
                PLOT_PATH,
                f"{metric_name.lower().replace(' ', '_')}_sample_size_{dataset_name.lower().replace(' ', '_')}.png",
            )
        )
        plt.close()
    else:
        ax.plot(
            sample_size,
            distances,
            label=metric_name if label is None else label,
            **kwargs,
        )
        ax.fill_between(
            sample_size,
            distances - errors,
            distances + errors,
            alpha=0.2,
            color="black" if kwargs.get("color") is None else kwargs.get("color"),
        )
        ax.set_xlabel("samples")
        ax.set_ylabel(
            metric_name, color="black" if kwargs.get("color") is None else kwargs.get("color")
        )
        return ax

Utils

get_cfg()

This function returns the configuration file for the current experiment run.

The configuration file is expected to be located at ../configs/conf_{name}.yaml, where name will match the name of the run_{name}.py file.

Raises:

  FileNotFoundError: If the configuration file is not found

Returns:

  OmegaConf: Dictionary with the configuration parameters

Source code in labproject/utils.py
def get_cfg() -> OmegaConf:
    """This function returns the configuration file for the current experiment run.

    The configuration file is expected to be located at ../configs/conf_{name}.yaml, where name will match the name of the run_{name}.py file.

    Raises:
        FileNotFoundError: If the configuration file is not found

    Returns:
        OmegaConf: Dictionary with the configuration parameters
    """
    caller_frame = inspect.currentframe().f_back
    filename = caller_frame.f_code.co_filename
    name = filename.split("/")[-1].split(".")[0].split("_")[-1]
    try:
        config = OmegaConf.load(CONF_PATH + f"/conf_{name}.yaml")
        config.running_user = name
    except FileNotFoundError:
        msg = f"Config file not found for {name}. Please create a config file at ../configs/conf_{name}.yaml"
        raise FileNotFoundError(msg)
    return config

get_cfg_from_file(name)

This function returns the configuration file for the current experiment run.

The configuration file is expected to be located at ../configs/{name}.yaml .

Raises:

  FileNotFoundError: If the configuration file is not found

Returns:

  OmegaConf: Dictionary with the configuration parameters

Source code in labproject/utils.py
def get_cfg_from_file(name: str) -> OmegaConf:
    """This function returns the configuration file for the current experiment run.

    The configuration file is expected to be located at ../configs/{name}.yaml .

    Raises:
        FileNotFoundError: If the configuration file is not found

    Returns:
        OmegaConf: Dictionary with the configuration parameters
    """
    try:
        config = OmegaConf.load(CONF_PATH + f"/{name}.yaml")
    except FileNotFoundError:
        msg = f"Config file not found for {name}. Please create a config file at ../configs/{name}.yaml"
        raise FileNotFoundError(msg)
    return config

get_log_path(cfg, tag='', timestamp=True)

Get the log path for the current experiment run. This log path is then used to save the numerical results of the experiment. Import this function in the run_{name}.py file and call it to get the log path.

Source code in labproject/utils.py
def get_log_path(cfg, tag="", timestamp=True):
    """
    Get the log path for the current experiment run.
    This log path is then used to save the numerical results of the experiment.
    Import this function in the run_{name}.py file and call it to get the log path.
    """

    # get datetime string
    now = datetime.datetime.now()
    if "exp_log_name" not in cfg:
        exp_log_name = now.strftime("%Y-%m-%d_%H-%M-%S")
    else:
        exp_log_name = cfg.exp_log_name
        # add datetime to the name
        add_date = now.strftime("%Y-%m-%d_%H-%M-%S") if timestamp else ""
        exp_log_name = exp_log_name + tag + "_" + add_date
    log_path = os.path.join(f"results/{cfg.running_user}/{exp_log_name}.pkl")
    return log_path
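
A hedged usage sketch inside a run_{name}.py script (the tag is illustrative):

>>> cfg = get_cfg()  # loads ../configs/conf_{name}.yaml for the current run_{name}.py
>>> log_path = get_log_path(cfg, tag="_scale_dim")
>>> # e.g. "results/<running_user>/<exp_log_name>_scale_dim_<timestamp>.pkl"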

load_experiments(cfg, tag='', now='')

Build the log path for a previously saved experiment run so that its results can be loaded.

Source code in labproject/utils.py
def load_experiments(cfg, tag="", now=""):
    """
    load the experiments to run
    """
    exp_log_name = cfg.exp_log_name
    # add datetime to the name
    exp_log_name = exp_log_name + tag + "_" + now
    log_path = os.path.join(f"results/{cfg.running_user}/{exp_log_name}")
    return log_path

set_seed(seed)

Set seed for reproducibility

Parameters:

  seed (int): Integer seed. Required.
Source code in labproject/utils.py
def set_seed(seed: int) -> int:
    """Set seed for reproducibility

    Args:
        seed (int): Integer seed
    """
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    return seed