Quick Start
Architecture
The disent directory structure:
- `disent/dataset`: dataset wrappers, datasets & sampling strategies
  - `disent/dataset/data`: raw datasets
  - `disent/dataset/sampling`: sampling strategies for `DisentDataset`
- `disent/frameworks`: frameworks, including Auto-Encoders and VAEs
- `disent/metrics`: metrics for evaluating disentanglement using ground truth datasets
- `disent/model`: common encoder and decoder models used for VAE research
- `disent/nn`: torch components for building models including layers, transforms, losses and general maths
- `disent/schedule`: annealing schedules that can be registered to a framework
- `disent/util`: helper functions for the rest of the framework
Please Note The API Is Still Unstable ⚠️
Disent is still under active development. Features and APIs are not considered stable, and should be expected to change! A limited set of tests currently exists, which will be expanded upon over time.
Examples
dataset/data
Common and custom data for vision-based AE, VAE and disentanglement research.
- Most data is generated from ground-truth factors, which are necessary for evaluation using disentanglement metrics. Each image generated from ground-truth data has its ground-truth variables available.
Example
from disent.dataset.data import XYObjectData

# create a small ground-truth dataset of coloured squares
data = XYObjectData(grid_size=4, min_square_size=1, max_square_size=2, square_size_spacing=1, palette='rgb_1')

# summarise the dataset and its ground-truth factors
print(f'Number of observations: {len(data)} == {data.size}')
print(f'Observation shape: {data.img_shape}')
print(f'Num Factors: {data.num_factors}')
print(f'Factor Names: {data.factor_names}')
print(f'Factor Sizes: {data.factor_sizes}')

# print every observation alongside its ground-truth factor values,
# obtained by converting the observation index with `idx_to_pos`
for i, obs in enumerate(data):
    print(
        f'i={i}',
        f'pos: ({", ".join(data.factor_names)}) = {tuple(data.idx_to_pos(i))}',
        f'obs={obs.tolist()}',
        sep=' | ',
    )
dataset
The ground-truth variables of the data can be used, via sampling strategies, to generate pairs or ordered sets for each observation in the dataset.
Examples
from disent.dataset.data import XYObjectData
from disent.dataset import DisentDataset

# prepare the data
# - DisentDataset is a generic wrapper around torch Datasets that prepares
#   the data for the various frameworks according to some sampling strategy.
# - by default, this sampling strategy just returns the data at the given idx.
data = XYObjectData(grid_size=4, min_square_size=1, max_square_size=2, square_size_spacing=1, palette='rgb_1')
dataset = DisentDataset(data, transform=None, augment=None)

# iterate over single epoch
for obs in dataset:
    # transform(data[i]) gives 'x_targ', then augment(x_targ) gives 'x'
    # - singles are contained in tuples of size 1 for compatibility with pairs of size 2
    (x0,) = obs['x_targ']
    print(x0.dtype, x0.min(), x0.max(), x0.shape)
from disent.dataset import DisentDataset
from disent.dataset.data import XYObjectData
from disent.dataset.sampling import GroundTruthPairOrigSampler
from disent.dataset.transform import ToImgTensorF32

# prepare the data
# - the pair sampler makes each dataset item contain two observations
data = XYObjectData(grid_size=4, min_square_size=1, max_square_size=2, square_size_spacing=1, palette='rgb_1')
dataset = DisentDataset(data, sampler=GroundTruthPairOrigSampler(), transform=ToImgTensorF32())

# iterate over single epoch
for obs in dataset:
    # pairs are contained in tuples of size 2 (singles use tuples of size 1)
    (x0, x1) = obs['x_targ']
    print(x0.dtype, x0.min(), x0.max(), x0.shape)
from disent.dataset import DisentDataset
from disent.dataset.data import XYObjectData
from disent.dataset.sampling import GroundTruthPairSampler
from disent.dataset.transform import ToImgTensorF32, FftBoxBlur

# prepare the data
# - `transform` produces the 'x_targ' observations, `augment` is then applied to produce 'x'
data = XYObjectData(grid_size=4, min_square_size=1, max_square_size=2, square_size_spacing=1, palette='rgb_1')
dataset = DisentDataset(data, sampler=GroundTruthPairSampler(), transform=ToImgTensorF32(), augment=FftBoxBlur(radius=1, p=1.0))

# iterate over single epoch
for obs in dataset:
    # if augment is not specified, then the augmented 'x' key does not exist!
    (x0, x1), (x0_targ, x1_targ) = obs['x'], obs['x_targ']
    print(x0.dtype, x0.min(), x0.max(), x0.shape)
from torch.utils.data import DataLoader
from disent.dataset import DisentDataset
from disent.dataset.data import XYObjectData
from disent.dataset.sampling import GroundTruthPairOrigSampler
from disent.dataset.transform import ToImgTensorF32

# prepare the data
data = XYObjectData(grid_size=4, min_square_size=1, max_square_size=2, square_size_spacing=1, palette='rgb_1')
dataset = DisentDataset(data, sampler=GroundTruthPairOrigSampler(), transform=ToImgTensorF32())
dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True)

# iterate over single epoch
for batch in dataloader:
    # each element of the pair is now a batch collated by the DataLoader
    (x0, x1) = batch['x_targ']
    print(x0.dtype, x0.min(), x0.max(), x0.shape)
framework
PyTorch Lightning modules that contain various AE or VAE implementations.
Examples
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from disent.dataset import DisentDataset
from disent.dataset.data import XYObjectData
from disent.frameworks.ae import Ae
from disent.model import AutoEncoder
from disent.model.ae import DecoderConv64, EncoderConv64
from disent.dataset.transform import ToImgTensorF32
from disent.util import is_test_run  # you can ignore and remove this

# prepare the data: observations are converted to float32 image tensors
data = XYObjectData()
dataset = DisentDataset(data, transform=ToImgTensorF32())
dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True)

# build the encoder/decoder pair and the framework config
model = AutoEncoder(
    encoder=EncoderConv64(x_shape=data.x_shape, z_size=6),
    decoder=DecoderConv64(x_shape=data.x_shape, z_size=6),
)
cfg = Ae.cfg(optimizer='adam', optimizer_kwargs={'lr': 1e-3}, loss_reduction='mean_sum')

# create the pytorch lightning system
module: pl.LightningModule = Ae(model=model, cfg=cfg)

# train the model
trainer = pl.Trainer(logger=False, checkpoint_callback=False, fast_dev_run=is_test_run())
trainer.fit(module, dataloader)
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from disent.dataset import DisentDataset
from disent.dataset.data import XYObjectData
from disent.frameworks.vae import BetaVae
from disent.model import AutoEncoder
from disent.model.ae import DecoderConv64, EncoderConv64
from disent.dataset.transform import ToImgTensorF32
from disent.util import is_test_run  # you can ignore and remove this

# prepare the data: observations are converted to float32 image tensors
data = XYObjectData()
dataset = DisentDataset(data, transform=ToImgTensorF32())
dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True)

# build the encoder/decoder pair and the framework config
# - unlike the plain Ae example, the VAE encoder is created with z_multiplier=2
model = AutoEncoder(
    encoder=EncoderConv64(x_shape=data.x_shape, z_size=6, z_multiplier=2),
    decoder=DecoderConv64(x_shape=data.x_shape, z_size=6),
)
cfg = BetaVae.cfg(optimizer='adam', optimizer_kwargs={'lr': 1e-3}, loss_reduction='mean_sum', beta=4)

# create the pytorch lightning system
module: pl.LightningModule = BetaVae(model=model, cfg=cfg)

# train the model
trainer = pl.Trainer(logger=False, checkpoint_callback=False, fast_dev_run=is_test_run())
trainer.fit(module, dataloader)
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from disent.dataset import DisentDataset
from disent.dataset.data import XYObjectData
from disent.dataset.sampling import GroundTruthPairOrigSampler
from disent.frameworks.vae import AdaVae
from disent.model import AutoEncoder
from disent.model.ae import DecoderConv64, EncoderConv64
from disent.dataset.transform import ToImgTensorF32
from disent.util import is_test_run  # you can ignore and remove this

# prepare the data
# - AdaVae is trained on pairs, so a pair sampler is attached to the dataset
data = XYObjectData()
dataset = DisentDataset(data, sampler=GroundTruthPairOrigSampler(), transform=ToImgTensorF32())
dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True)

# build the encoder/decoder pair and the framework config
model = AutoEncoder(
    encoder=EncoderConv64(x_shape=data.x_shape, z_size=6, z_multiplier=2),
    decoder=DecoderConv64(x_shape=data.x_shape, z_size=6),
)
cfg = AdaVae.cfg(
    optimizer='adam',
    optimizer_kwargs={'lr': 1e-3},
    loss_reduction='mean_sum',
    beta=4,
    ada_average_mode='gvae',
    ada_thresh_mode='kl',
)

# create the pytorch lightning system
module: pl.LightningModule = AdaVae(model=model, cfg=cfg)

# train the model
trainer = pl.Trainer(logger=False, checkpoint_callback=False, fast_dev_run=is_test_run())
trainer.fit(module, dataloader)
metrics
Various metrics used to evaluate representations learnt by AEs and VAEs.
Example
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from disent.dataset import DisentDataset
from disent.dataset.data import XYObjectData
from disent.frameworks.vae import BetaVae
from disent.model import AutoEncoder
from disent.model.ae import DecoderConv64, EncoderConv64
from disent.dataset.transform import ToImgTensorF32
from disent.metrics import metric_dci, metric_mig
from disent.util import is_test_run

# prepare the data
data = XYObjectData()
dataset = DisentDataset(data, transform=ToImgTensorF32(), augment=None)
dataloader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)


def make_vae(beta):
    """Create a BetaVae with the given `beta`, conv encoder/decoder and a latent size of 6."""
    return BetaVae(
        model=AutoEncoder(
            encoder=EncoderConv64(x_shape=data.x_shape, z_size=6, z_multiplier=2),
            decoder=DecoderConv64(x_shape=data.x_shape, z_size=6),
        ),
        cfg=BetaVae.cfg(optimizer='adam', optimizer_kwargs=dict(lr=1e-3), beta=beta),
    )


def train(module):
    """Train `module` on `dataloader`, then return its DCI and MIG scores merged into one dict."""
    trainer = pl.Trainer(logger=False, checkpoint_callback=False, max_steps=256, fast_dev_run=is_test_run())
    trainer.fit(module, dataloader)

    # the module may live on a different device than the dataset's observations,
    # so move each input to the module's device before encoding it
    def get_repr(x):
        return module.encode(x.to(module.device))

    # evaluate (far fewer samples are used during test runs)
    return {
        **metric_dci(dataset, get_repr, num_train=10 if is_test_run() else 1000, num_test=5 if is_test_run() else 500, boost_mode='sklearn'),
        **metric_mig(dataset, get_repr, num_train=20 if is_test_run() else 2000),
    }


a_results = train(make_vae(beta=4))
b_results = train(make_vae(beta=0.01))

print('beta=4: ', a_results)
print('beta=0.01:', b_results)
schedules
Hyper-parameter schedules can be applied if models reference their config values, such as beta (`cfg.beta`) in all the BetaVAE-derived classes.
If the hyper-parameter does not exist in the config, a warning will be printed instead of crashing.
Example
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from disent.dataset import DisentDataset
from disent.dataset.data import XYObjectData
from disent.frameworks.vae import BetaVae
from disent.model import AutoEncoder
from disent.model.ae import DecoderConv64, EncoderConv64
from disent.dataset.transform import ToImgTensorF32
from disent.schedule import CyclicSchedule
from disent.util import is_test_run  # you can ignore and remove this

# prepare the data: observations are converted to float32 image tensors
data = XYObjectData()
dataset = DisentDataset(data, transform=ToImgTensorF32())
dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True)

# build the encoder/decoder pair and the framework config
# - `beta` in the config is the value targeted by the schedule registered below
model = AutoEncoder(
    encoder=EncoderConv64(x_shape=data.x_shape, z_size=6, z_multiplier=2),
    decoder=DecoderConv64(x_shape=data.x_shape, z_size=6),
)
cfg = BetaVae.cfg(optimizer='adam', optimizer_kwargs={'lr': 1e-3}, loss_reduction='mean_sum', beta=4)

# create the pytorch lightning system
module: pl.LightningModule = BetaVae(model=model, cfg=cfg)

# register the scheduler with the DisentFramework
# - cyclic scheduler from: https://arxiv.org/abs/1903.10145
beta_schedule = CyclicSchedule(
    period=1024,  # repeat every: trainer.global_step % period
)
module.register_schedule('beta', beta_schedule)

# train the model
trainer = pl.Trainer(logger=False, checkpoint_callback=False, fast_dev_run=is_test_run())
trainer.fit(module, dataloader)
Datasets Without Ground-Truth Factors
You can use datasets that do not have ground-truth factors by changing the sampling strategy of `DisentDataset`; however, metrics cannot be computed.
The following MNIST example uses the builtin `RandomSampler`.
Example
import os
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from tqdm import tqdm
from disent.dataset import DisentDataset
from disent.dataset.sampling import RandomSampler
from disent.frameworks.vae import AdaVae
from disent.model import AutoEncoder
from disent.model.ae import DecoderFC, EncoderFC
from disent.dataset.transform import ToImgTensorF32


# modify the mnist dataset to only return images, not labels
class MNIST(datasets.MNIST):
    def __getitem__(self, index):
        img, _ = super().__getitem__(index)
        return img


# make mnist dataset -- adjust num_samples here to match framework.
# TODO: add tests that can fail with a warning -- dataset downloading is not always reliable
data_folder = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data', 'dataset')
dataset_train = DisentDataset(MNIST(data_folder, train=True, download=True, transform=ToImgTensorF32()), sampler=RandomSampler(num_samples=2))
dataset_test = MNIST(data_folder, train=False, download=True, transform=ToImgTensorF32())

# create the dataloaders
# - os.cpu_count() may return None when the CPU count cannot be determined, fall back to 0 workers
num_workers = os.cpu_count() or 0
dataloader_train = DataLoader(dataset=dataset_train, batch_size=128, shuffle=True, num_workers=num_workers)
dataloader_test = DataLoader(dataset=dataset_test, batch_size=128, shuffle=True, num_workers=num_workers)

# create the model
module = AdaVae(
    model=AutoEncoder(
        encoder=EncoderFC(x_shape=(1, 28, 28), z_size=9, z_multiplier=2),
        decoder=DecoderFC(x_shape=(1, 28, 28), z_size=9),
    ),
    cfg=AdaVae.cfg(
        optimizer='adam', optimizer_kwargs=dict(lr=1e-3),
        beta=4, recon_loss='mse', loss_reduction='mean_sum',  # "mean_sum" is the traditional loss reduction mode, rather than "mean"
    ),
)

# train the model
trainer = pl.Trainer(logger=False, checkpoint_callback=False, max_steps=2048)  # callbacks=[VaeLatentCycleLoggingCallback(every_n_steps=250, plt_show=True)]
trainer.fit(module, dataloader_train)

# manually encode some observations
# - only inference is performed, so switch to evaluation mode and disable gradient tracking
# - each batch is moved to whichever device the module is currently on
module.eval()
with torch.no_grad():
    for xs in tqdm(dataloader_test, desc='Custom Evaluation'):
        zs = module.encode(xs.to(module.device))