cellpin.models.CellPin#

class cellpin.models.CellPin(sc_dataset, config=None, checkpoint=None)#

Bases: LightningModule

CellPin: hybrid two-view VAE for single-cell and spatial transcriptomics.

  • Model construction (CellPinVAE).

  • Two-stage training (pretrain → main).

  • ELBO computation with KL annealing.

  • Invariance regularisation (KL-distillation/MSE + SNN).

  • Inference / imputation API.

Args:
sc_dataset: Training scAnnDataset used to infer gene counts and panel size.

config: Hyper-parameter dict, path to a YAML file, or None for defaults. Key parameters:

  • n_latent (192): latent space dimensionality

  • n_hidden (1024): encoder/decoder hidden width

  • encoder_layers (16): number of residual blocks per encoder

  • reconstruction_loss ("nb"): "nb", "zinb", "poisson", "normal", "zin"

  • distillation_mode ("mse"): "kl" or "mse"

  • reconstruct_panel (True): if False, reconstruction loss on non-panel genes only

  • kl_warmup_epochs (20): epochs for linear KL annealing from 0 to kl_weight

  • lambda_inv (20.0): weight on the invariance (distillation + SNN) loss

  • exclude_panel (False): if True, full encoder sees panel genes zeroed out

  • lr (0.00021): AdamW learning rate

Full defaults in configs/cellpin_config.yaml.

checkpoint: Path to a .pt checkpoint to load weights from.
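The kl_warmup_epochs entry above describes a linear ramp of the KL weight from 0 to kl_weight. A minimal sketch of such a schedule follows; the function name kl_weight_at, the zero-based epoch convention, and the default kl_weight=1.0 are assumptions made for illustration and are not taken from the CellPin source.

```python
def kl_weight_at(epoch: int, kl_weight: float = 1.0, kl_warmup_epochs: int = 20) -> float:
    """Linearly anneal the KL weight from 0 (epoch 0) to kl_weight.

    Hypothetical helper: the real schedule may be step-based or use a
    different epoch offset.
    """
    if kl_warmup_epochs <= 0:
        # No warm-up requested: use the full weight from the start.
        return kl_weight
    # Fraction of the warm-up completed, clamped to [0, 1].
    progress = min(1.0, max(0.0, epoch / kl_warmup_epochs))
    return kl_weight * progress
```

With the documented default of 20 warm-up epochs, the weight under these assumptions reaches half of kl_weight at epoch 10 and stays at kl_weight from epoch 20 onwards.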

compute_losses(batch)#

Main training losses.

Objective:

  • ELBO via panel encoder (imputation-facing objective).

  • Invariance loss: KL-distillation or MSE between full-latent and panel-latent.

  • Soft nearest-neighbour (SNN) alignment.

A single inference pass per view is made; the ELBO, invariance loss, and SNN all share the same sampled z and means.

Args:

batch: Must contain 'full_expr' and 'panel_expr'. Optionally 'local_l_mean', 'local_l_var', 'batch_index'.

Returns:

Dict with scalar tensors: 'loss', 'reconst_loss', 'kl_loss', 'kl_l_loss', 'distill_loss', 'snn_loss', 'inv_loss', 'snn_temperature', 'pearson_loss'.

Return type:

dict[str, Tensor]
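The invariance term compares the full-view and panel-view latents using either KL-distillation or MSE, selected by distillation_mode. The sketch below shows what the two options could look like for diagonal Gaussian posteriors. The function names, the use of plain Python lists, and the KL direction (full view as the reference distribution) are assumptions; the actual CellPin implementation may differ.

```python
import math
from typing import Sequence


def gaussian_kl(mu_q: Sequence[float], var_q: Sequence[float],
                mu_p: Sequence[float], var_p: Sequence[float]) -> float:
    """KL(q || p) between diagonal Gaussians, summed over latent dimensions."""
    total = 0.0
    for mq, vq, mp, vp in zip(mu_q, var_q, mu_p, var_p):
        total += 0.5 * (math.log(vp / vq) + (vq + (mq - mp) ** 2) / vp - 1.0)
    return total


def mse(a: Sequence[float], b: Sequence[float]) -> float:
    """Mean squared error between two latent mean vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def distill_loss(mu_full, var_full, mu_panel, var_panel, mode: str = "mse") -> float:
    """Hypothetical invariance term between full-view and panel-view posteriors."""
    if mode == "kl":
        # Assumed direction: KL(full || panel).
        return gaussian_kl(mu_full, var_full, mu_panel, var_panel)
    if mode == "mse":
        return mse(mu_full, mu_panel)
    raise ValueError(f"unknown distillation mode: {mode!r}")
```

Both variants are zero when the two posteriors coincide. The MSE form ignores the posterior variances, whereas the KL form penalises variance mismatch as well.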

compute_pretrain_losses(batch)#

Pretraining losses (full-gene view only).

Objective: ELBO on full-gene path.

Return type:

dict[str, Tensor]

Args:

batch: Must contain 'full_expr' and 'panel_expr'. Optionally 'local_l_mean', 'local_l_var', 'batch_index'.

Returns:

Dict with scalar tensors: 'loss', 'reconst_loss', 'kl_loss', 'kl_l_loss'.

configure_optimizers()#

AdamW optimiser with cosine-annealing LR schedule.
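For reference, the standard cosine-annealing rule decays the learning rate from its base value to a minimum along half a cosine period. The sketch below evaluates that rule directly; base_lr uses the documented default of 0.00021, while total_steps, min_lr, and the function name are assumptions, since the schedule length and floor used by CellPin are not stated here.

```python
import math


def cosine_lr(step: int, total_steps: int, base_lr: float = 0.00021,
              min_lr: float = 0.0) -> float:
    """Learning rate at `step` under cosine annealing without restarts."""
    # Clamp so that steps beyond the schedule stay at the minimum rate.
    t = min(max(step, 0), total_steps)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * t / total_steps))
```

Under this rule the rate equals base_lr at step 0, is halfway between base_lr and min_lr at the midpoint, and reaches min_lr at total_steps.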

embed_and_impute(dataloader, use_mean=True, mc_impute=False, mc_samples=50, mask_fraction=0.2)#

Embed cells and generate imputed expression values.

fit(dataset, pretrain_epochs=50, train_epochs=60, batch_size=256, gradient_clip_val=0.5, early_stopping_patience=10, freeze_pretrained=False, train_size=0.8, save_checkpoints=False, output_dir='./cellpin_output', decoder_warm_unfreeze_epoch=-1, **trainer_kwargs)#

Train CellPin: Stage 1 (pretrain) followed by Stage 2 (distillation).

This is the recommended entry point for training. It runs both stages sequentially with a single call.

Return type:

None

Args:

dataset: Single-cell dataset returned by cellpin.pp.setup().

pretrain_epochs: Max epochs for Stage 1 (full-gene ELBO pretraining).

train_epochs: Max epochs for Stage 2 (panel distillation).

batch_size: Mini-batch size for both stages.

gradient_clip_val: Gradient clipping value.

early_stopping_patience: Epochs without improvement before stopping.

freeze_pretrained: Freeze the full-gene encoder/decoder during Stage 2.

train_size: Fraction of cells used for training (rest → validation).

save_checkpoints: Save model checkpoints to output_dir. Disabled by default; enable when you need to resume training or load the best epoch after early stopping.

output_dir: Root directory for checkpoints and logs (only used when save_checkpoints=True).

decoder_warm_unfreeze_epoch: Stage 2 epoch at which the frozen decoder is unfrozen for warm fine-tuning. Only active when freeze_pretrained=True. -1 (default) keeps the decoder frozen for the entire Stage 2 run.

**trainer_kwargs: Extra arguments forwarded to CellPinTrainer (e.g. devices, precision, accelerator).
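The interaction between freeze_pretrained and decoder_warm_unfreeze_epoch can be summarised as a small predicate. This is a sketch of the documented behaviour only: the function name, the zero-based epoch index, the choice to unfreeze at the start of the given epoch, and the treatment of every negative value like -1 are assumptions.

```python
def decoder_is_frozen(epoch: int, freeze_pretrained: bool,
                      decoder_warm_unfreeze_epoch: int = -1) -> bool:
    """Return True if the decoder would be frozen at the start of `epoch`."""
    if not freeze_pretrained:
        # Warm unfreezing only applies when Stage 1 weights are frozen.
        return False
    if decoder_warm_unfreeze_epoch < 0:
        # -1 (default): keep the decoder frozen for all of Stage 2.
        return True
    # Assumed boundary: unfrozen from the scheduled epoch onwards.
    return epoch < decoder_warm_unfreeze_epoch
```

For example, under these assumptions freeze_pretrained=True with decoder_warm_unfreeze_epoch=10 trains only the panel path for epochs 0 to 9 and fine-tunes the decoder afterwards.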

Example:

sc_dataset, _ = cellpin.pp.setup_data(sc_adata, st_adata)
model = cellpin.CellPin(sc_dataset)
model.fit(sc_dataset)

get_cell_embedding(dataloader, use_mean=True)#

Encode cells to the latent space via the panel encoder.

Return type:

ndarray

Args:

dataloader: DataLoader over a scAnnDataset or stAnnDataset.

use_mean: Return the posterior mean rather than a sample.

Returns:

Float32 array (n_cells, n_latent).

impute(dataloader, obs_adata=None, mc_samples=50, mask_fraction=0.2, return_norm=False, norm_target_sum=1000.0, area_key=None, nb_count_samples=100, return_int=False, return_sparse=True, table_key='table')#

Impute with MC averaging and optional count-space normalisation.

Return type:

AnnData

Args:

dataloader: DataLoader to run inference on.

obs_adata: Optional AnnData (or spatialdata.SpatialData) whose .obs is copied to the output. If SpatialData, the AnnData is read from obs_adata.tables[table_key] and the result is returned as an updated SpatialData object. Must have the same number of observations.

mc_samples: Number of stochastic forward passes for MC averaging (default 50; more → smoother but slower).

mask_fraction: Fraction of panel genes randomly zeroed per MC pass to simulate missing measurements (default 0.2).

return_norm: If True, add a log-normalised layer layers['imputed_norm'] (total-count or area normalised, then log1p-transformed).

norm_target_sum: Target total counts for normalisation (default 1e3; only used when return_norm=True).

area_key: obs column with cell area for area-based normalisation. Auto-detected as 'cell_area' when present; pass None for total-count normalisation (only used when return_norm=True).

nb_count_samples: Number of NB draws used to compute the MC estimate of E[log1p(norm(X))] when return_norm=True (default 100). Because log1p is concave, Jensen's inequality means log1p(norm(E[X])) > E[log1p(norm(X))]; sampling inside the transform corrects this bias. More samples → lower variance.

return_int: If True, round X to integer counts (int32).

return_sparse: If True (default), store X, layers['imputed'], and layers['imputed_norm'] as scipy.sparse.csr_matrix. Set to False to keep dense numpy arrays.

table_key: Table name to read/write when obs_adata is a SpatialData object (default "table").
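The Jensen's inequality argument behind nb_count_samples can be checked on a toy example. The sketch below compares log1p of the mean with the mean of log1p for a handful of values; the helper names and the two-point distribution are illustrative and unrelated to the NB sampling code in CellPin.

```python
import math
from typing import Sequence


def log1p_of_mean(values: Sequence[float]) -> float:
    """Apply log1p after averaging (the biased shortcut)."""
    return math.log1p(sum(values) / len(values))


def mean_of_log1p(values: Sequence[float]) -> float:
    """Average log1p over individual draws (the quantity being estimated)."""
    return sum(math.log1p(v) for v in values) / len(values)


# Toy draws: a count that is 0 or 10 with equal probability.
draws = [0.0, 10.0]
gap = log1p_of_mean(draws) - mean_of_log1p(draws)
```

The gap is strictly positive whenever the draws vary and vanishes when they are all equal, which is why transforming the mean would overestimate the expected log-normalised expression.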

Returns:

anndata.AnnData with X = imputed (float or int) counts, obsm['X_cellpin'] = embeddings, layers['imputed'] = copy of X, and optionally layers['imputed_norm']. var['is_measured'] marks genes present in obs_adata (all True when obs_adata is None). If obs_adata was a SpatialData object, returns the updated SpatialData with the result stored in sdata.tables[table_key].

Raises:

ValueError: If obs_adata has the wrong number of cells, or if area_key is specified but not found in adata.obs, or if any cell area is ≤ 0.
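The mc_samples and mask_fraction arguments of impute describe averaging predictions over several passes, each with a random subset of panel genes zeroed. A minimal single-cell sketch of that idea follows. The function name, the predict callable standing in for the model, the fixed number of masked genes per pass, and the seed argument are all assumptions; CellPin may instead mask each gene independently or operate on batched tensors.

```python
import random
from typing import Callable, List, Sequence


def mc_masked_average(panel_expr: Sequence[float],
                      predict: Callable[[List[float]], List[float]],
                      mc_samples: int = 50,
                      mask_fraction: float = 0.2,
                      seed: int = 0) -> List[float]:
    """Average `predict` over passes with random panel genes set to zero."""
    if mc_samples < 1:
        raise ValueError("mc_samples must be at least 1")
    rng = random.Random(seed)
    n_genes = len(panel_expr)
    n_mask = int(round(mask_fraction * n_genes))
    total: List[float] = []
    for _ in range(mc_samples):
        masked = list(panel_expr)
        # Zero a fresh random subset of panel genes for this pass.
        for idx in rng.sample(range(n_genes), n_mask):
            masked[idx] = 0.0
        out = predict(masked)
        if not total:
            total = [0.0] * len(out)
        total = [t + o for t, o in zip(total, out)]
    return [t / mc_samples for t in total]
```

Here predict plays the role of one forward pass through the panel encoder and decoder. Increasing mc_samples reduces the variance of the average at a proportional cost in run time.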

on_train_epoch_start()#

Warm-unfreeze decoder parameters when scheduled.

Return type:

None

pretrain_model(dataset, custom_callbacks=None, train_size=0.8, pretrain_epochs=50, **trainer_kwargs)#

Stage-1 pretraining (full-gene view only, ELBO).

Args:

dataset: Training dataset.

custom_callbacks: Extra PyTorch-Lightning callbacks.

train_size: Fraction of data used for training.

pretrain_epochs: Default max epochs (overridden by trainer_kwargs['max_epochs'] if present).

**trainer_kwargs: Forwarded to CellPinTrainer.

Returns:

Fitted CellPinTrainer.

save(path)#

Serialise model weights and hyper-parameters to a .pt file.

Return type:

None

Args:

path: Destination file path.

set_stage_loss_weights(stage, **weights)#

Programmatically update loss weights for a stage.

Useful for sweeps or ablations.

Return type:

None

Args:

stage: 'pretrain' or 'main'.

**weights: Key-value overrides, e.g. inv=2.0, recon=1.5.

Raises:

KeyError: For unknown weight keys.

Example:

model.set_stage_loss_weights("main", inv=2.0, recon=1.5)

train_model(dataset, custom_callbacks=None, train_size=0.8, freeze_pretrained=False, require_pretrained=True, **trainer_kwargs)#

Stage-2 main training (both views, full ELBO + invariance + SNN).

Args:

dataset: Training dataset (scAnnDataset).

custom_callbacks: Extra PyTorch-Lightning callbacks.

train_size: Fraction of data used for training.

freeze_pretrained: If True, freeze the full-gene encoder and decoder (Stage 1 weights) during Stage 2.

require_pretrained: If True (default), raise an error when freeze_pretrained=True but pretrain_model was never called, preventing silent training against a random frozen decoder.

**trainer_kwargs: Forwarded to CellPinTrainer.

Returns:

Fitted CellPinTrainer.

Raises:

RuntimeError: If require_pretrained=True, freeze_pretrained=True, and pretraining has not been completed.

training_step(batch, batch_idx)#

Run a training step for the current stage.

Return type:

Tensor

validation_step(batch, batch_idx)#

Run a validation step for the current stage.

Return type:

Tensor

Methods table#

__init__(sc_dataset[, config, checkpoint])

compute_losses(batch)

Main training losses.

compute_pretrain_losses(batch)

Pretraining losses (full-gene view only).

configure_optimizers()

AdamW optimiser with cosine-annealing LR schedule.

embed_and_impute(dataloader[, use_mean, ...])

Embed cells and generate imputed expression values.

fit(dataset[, pretrain_epochs, ...])

Train CellPin: Stage 1 (pretrain) followed by Stage 2 (distillation).

get_cell_embedding(dataloader[, use_mean])

Encode cells to the latent space via the panel encoder.

impute(dataloader[, obs_adata, mc_samples, ...])

Impute with MC averaging and optional count-space normalisation.

on_train_epoch_start()

Warm-unfreeze decoder parameters when scheduled.

pretrain_model(dataset[, custom_callbacks, ...])

Stage-1 pretraining (full-gene view only, ELBO).

save(path)

Serialise model weights and hyper-parameters to a .pt file.

set_stage_loss_weights(stage, **weights)

Programmatically update loss weights for a stage.

train_model(dataset[, custom_callbacks, ...])

Stage-2 main training (both views, full ELBO + invariance + SNN).

training_step(batch, batch_idx)

Run a training step for the current stage.

validation_step(batch, batch_idx)

Run a validation step for the current stage.

Attributes#

CellPin.training#
