cellpin.models.CellPin
- class cellpin.models.CellPin(sc_dataset, config=None, checkpoint=None)
Bases: LightningModule
CellPin: hybrid two-view VAE for single-cell and spatial transcriptomics. It covers:
- Model construction (CellPinVAE).
- Two-stage training (pretrain → main).
- ELBO computation with KL annealing.
- Invariance regularisation (KL-distillation/MSE + SNN).
- Inference / imputation API.
Args:
  - sc_dataset: Training scAnnDataset, used to infer gene counts and panel size.
  - config: Hyper-parameter dict, path to a YAML file, or None for defaults. Key parameters (defaults in parentheses):
    - n_latent (192): latent space dimensionality
    - n_hidden (1024): encoder/decoder hidden width
    - encoder_layers (16): number of residual blocks per encoder
    - reconstruction_loss ("nb"): one of "nb", "zinb", "poisson", "normal", "zin"
    - distillation_mode ("mse"): "kl" or "mse"
    - reconstruct_panel (True): if False, the reconstruction loss is computed on non-panel genes only
    - kl_warmup_epochs (20): number of epochs for linear KL annealing from 0 to kl_weight
    - lambda_inv (20.0): weight on the invariance (distillation + SNN) loss
    - exclude_panel (False): if True, the full encoder sees panel genes zeroed out
    - lr (0.00021): AdamW learning rate
    Full defaults are listed in configs/cellpin_config.yaml.
  - checkpoint: Path to a .pt checkpoint to load weights from.
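The kl_warmup_epochs entry describes a linear KL annealing schedule. Below is a minimal sketch of such a schedule. It assumes epochs are counted from 0 and uses a hypothetical helper, kl_weight_at. The target kl_weight=1.0 is a placeholder, not a documented CellPin default.

```python
def kl_weight_at(epoch: int, kl_warmup_epochs: int = 20, kl_weight: float = 1.0) -> float:
    """Hypothetical helper: linear KL annealing from 0 up to kl_weight.

    Assumes epochs are counted from 0; the weight reaches kl_weight at
    epoch == kl_warmup_epochs and stays there afterwards.
    """
    if kl_warmup_epochs <= 0:
        # No warm-up requested: apply the full KL weight from the start.
        return kl_weight
    return kl_weight * min(1.0, epoch / kl_warmup_epochs)


# The annealed ELBO would then be: reconst_loss + kl_weight_at(epoch) * kl_loss
schedule = [kl_weight_at(e) for e in (0, 5, 10, 20, 30)]
print(schedule)  # [0.0, 0.25, 0.5, 1.0, 1.0]
```

A linear ramp keeps the KL term from collapsing the posterior before the decoder has learned to reconstruct anything useful.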
- compute_losses(batch)
  Main training losses.
  Objective:
  - ELBO via the panel encoder (imputation-facing objective).
  - Invariance loss: KL-distillation or MSE between the full-gene latent and the panel latent.
  - Soft nearest-neighbour (SNN) alignment.
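The invariance term compares the full-gene posterior with the panel posterior. Below is a minimal sketch of the two documented distillation_mode options. It rests on three assumptions:
- both encoders output diagonal Gaussians (mean and log-variance);
- "kl" means KL(q_full ‖ q_panel), with the full-gene view acting as teacher;
- the KL is summed over latent dimensions.

The helper names are hypothetical, not CellPin's API.

```python
import math
from typing import Sequence


def gaussian_kl(mu_t: Sequence[float], logvar_t: Sequence[float],
                mu_s: Sequence[float], logvar_s: Sequence[float]) -> float:
    """KL(teacher || student) between diagonal Gaussians, summed over latent dims."""
    total = 0.0
    for m_t, lv_t, m_s, lv_s in zip(mu_t, logvar_t, mu_s, logvar_s):
        total += 0.5 * (lv_s - lv_t + (math.exp(lv_t) + (m_t - m_s) ** 2) / math.exp(lv_s) - 1.0)
    return total


def mse(mu_t: Sequence[float], mu_s: Sequence[float]) -> float:
    """Mean squared error between teacher and student latent means."""
    return sum((a - b) ** 2 for a, b in zip(mu_t, mu_s)) / len(mu_t)


def invariance_loss(mode: str, mu_full, logvar_full, mu_panel, logvar_panel) -> float:
    # Hypothetical dispatch mirroring the documented distillation_mode values.
    if mode == "kl":
        return gaussian_kl(mu_full, logvar_full, mu_panel, logvar_panel)
    if mode == "mse":
        return mse(mu_full, mu_panel)
    raise ValueError(f"unknown distillation_mode: {mode!r}")
```

In a real training loop the teacher statistics would typically be detached so that gradients flow only into the panel encoder. Whether CellPin does this is not stated here.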
  A single inference pass is made per view; the ELBO, invariance loss, and SNN all share the same sampled z and means.
  Args:
    - batch: Must contain 'full_expr' and 'panel_expr'. Optionally 'local_l_mean', 'local_l_var', 'batch_index'.
  Returns:
    Dict with scalar tensors: 'loss', 'reconst_loss', 'kl_loss', 'kl_l_loss', 'distill_loss', 'snn_loss', 'inv_loss', 'snn_temperature', 'pearson_loss'.
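The SNN term can be illustrated with the soft nearest-neighbour loss of Frosst, Papernot and Hinton (2019). Below is a minimal sketch. It assumes the full-gene and panel latents of a batch are pooled and each pair is labelled by cell identity, so a panel latent is rewarded for lying closest to its own full-gene latent. CellPin's exact pairing, distance and temperature handling are not documented here; it reports 'snn_temperature', which may be learned. All names below are hypothetical.

```python
import math
from typing import List, Sequence


def _logsumexp(xs: List[float]) -> float:
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def snn_loss(points: List[Sequence[float]], labels: List[int], temperature: float = 1.0) -> float:
    """Soft nearest-neighbour loss, averaged over points.

    Every label must occur at least twice so each point has a same-label neighbour.
    """
    n = len(points)
    total = 0.0
    for i in range(n):
        same: List[float] = []
        every: List[float] = []
        for j in range(n):
            if j == i:
                continue
            # Similarity logit: negative squared Euclidean distance over temperature.
            d2 = sum((a - b) ** 2 for a, b in zip(points[i], points[j]))
            logit = -d2 / temperature
            every.append(logit)
            if labels[j] == labels[i]:
                same.append(logit)
        # -log(sum_same / sum_all), computed in log space.
        total += _logsumexp(every) - _logsumexp(same)
    return total / n


def cross_view_snn(z_full, z_panel, temperature: float = 1.0) -> float:
    """Hypothetical cross-view use: a cell's full-gene and panel latents share a label."""
    points = list(z_full) + list(z_panel)
    labels = list(range(len(z_full))) * 2
    return snn_loss(points, labels, temperature)


z_full = [[0.0, 0.0], [10.0, 0.0], [0.0, 10.0]]
aligned = cross_view_snn(z_full, z_full)
shuffled = cross_view_snn(z_full, [z_full[1], z_full[2], z_full[0]])
print(aligned < shuffled)  # True
```

In CellPin the distillation and SNN terms together form the invariance loss, which lambda_inv weights.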
- compute_pretrain_losses(batch)
  Pretraining losses (full-gene view only).
  Objective: ELBO on the full-gene path.
  Args:
    - batch: Must contain 'full_expr' and 'panel_expr'. Optionally 'local_l_mean', 'local_l_var', 'batch_index'.
  Returns:
    Dict with scalar tensors: 'loss', 'reconst_loss', 'kl_loss', 'kl_l_loss'.
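With the default reconstruction_loss="nb", the reconstruction part of the ELBO is a negative-binomial likelihood. Below is a minimal sketch of the NB log-probability in the mean / inverse-dispersion parameterisation common in scVI-style models. Two details are assumptions: that CellPin uses exactly this parameterisation, and that it sums over genes.

```python
import math
from typing import Sequence


def nb_log_prob(x: float, mu: float, theta: float) -> float:
    """Log-pmf of NB(mean=mu, inverse dispersion=theta); requires mu > 0, theta > 0."""
    log_theta_mu = math.log(theta + mu)
    return (
        math.lgamma(x + theta)
        - math.lgamma(theta)
        - math.lgamma(x + 1.0)
        + theta * (math.log(theta) - log_theta_mu)
        + x * (math.log(mu) - log_theta_mu)
    )


def nb_reconstruction_loss(counts: Sequence[float], means: Sequence[float],
                           thetas: Sequence[float]) -> float:
    """Negative NB log-likelihood of one cell's counts, summed over genes."""
    return -sum(nb_log_prob(x, m, t) for x, m, t in zip(counts, means, thetas))


# P(X = 0) for NB(mean=2, theta=3) is (3 / 5) ** 3 = 0.216.
print(round(math.exp(nb_log_prob(0, 2.0, 3.0)), 6))  # 0.216
```

The variance of this distribution is mu + mu**2 / theta, so a small theta models the overdispersion typical of UMI counts.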
- configure_optimizers()
  AdamW optimiser with cosine-annealing LR schedule.
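Below is a minimal sketch of the cosine-annealing schedule. It uses the closed form documented for PyTorch's CosineAnnealingLR within one half-period, with the documented default lr=0.00021 as the starting rate. The schedule length t_max, the floor eta_min, and holding the rate at eta_min after t_max are all assumptions; this page does not document them.

```python
import math


def cosine_lr(step: int, base_lr: float = 0.00021, t_max: int = 100,
              eta_min: float = 0.0) -> float:
    """Cosine-annealed learning rate at `step`, from base_lr down to eta_min."""
    step = min(step, t_max)  # assumption: hold at eta_min once the schedule ends
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * step / t_max))


print([round(cosine_lr(s), 7) for s in (0, 50, 100)])
```

In PyTorch this pattern is usually built by combining torch.optim.AdamW with torch.optim.lr_scheduler.CosineAnnealingLR, rather than computing the rate by hand.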
- embed_and_impute(dataloader, use_mean=True, mc_impute=False, mc_samples=50, mask_fraction=0.2)
  Embed cells and generate imputed expression values.
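The mc_impute, mc_samples and mask_fraction parameters refer to Monte Carlo averaging over randomly masked panels. Below is a minimal sketch of that loop for one cell. `decode` is a hypothetical stand-in for the model's panel encode/decode pass. The sketch assumes each pass zeroes exactly round(mask_fraction * n_panel) randomly chosen panel genes; CellPin may instead mask genes independently.

```python
import random
from typing import Callable, List, Optional, Sequence


def mc_impute(panel_expr: Sequence[float],
              decode: Callable[[List[float]], List[float]],
              mc_samples: int = 50, mask_fraction: float = 0.2,
              seed: int = 0) -> List[float]:
    """Average decoder outputs over mc_samples randomly masked copies of the panel."""
    rng = random.Random(seed)
    n_panel = len(panel_expr)
    n_mask = round(mask_fraction * n_panel)
    total: Optional[List[float]] = None
    for _ in range(mc_samples):
        masked = list(panel_expr)
        for g in rng.sample(range(n_panel), n_mask):
            masked[g] = 0.0  # simulate a missing measurement for this pass
        out = decode(masked)
        total = list(out) if total is None else [a + b for a, b in zip(total, out)]
    return [v / mc_samples for v in total]


# Toy decoder that returns the panel's total count.
print(mc_impute([1.0] * 10, lambda x: [sum(x)]))  # [8.0]
```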
- fit(dataset, pretrain_epochs=50, train_epochs=60, batch_size=256, gradient_clip_val=0.5, early_stopping_patience=10, freeze_pretrained=False, train_size=0.8, save_checkpoints=False, output_dir='./cellpin_output', decoder_warm_unfreeze_epoch=-1, **trainer_kwargs)
  Train CellPin: Stage 1 (pretrain) followed by Stage 2 (distillation).
  This is the recommended entry point for training: a single call runs both stages sequentially.
  Args:
    - dataset: Single-cell dataset returned by cellpin.pp.setup().
    - pretrain_epochs: Max epochs for Stage 1 (full-gene ELBO pretraining).
    - train_epochs: Max epochs for Stage 2 (panel distillation).
    - batch_size: Mini-batch size for both stages.
    - gradient_clip_val: Gradient clipping value.
    - early_stopping_patience: Number of epochs without improvement before stopping.
    - freeze_pretrained: Freeze the full-gene encoder/decoder during Stage 2.
    - train_size: Fraction of cells used for training (the rest is used for validation).
    - save_checkpoints: Save model checkpoints to output_dir. Disabled by default; enable it when you need to resume training or load the best epoch after early stopping.
    - output_dir: Root directory for checkpoints and logs (only used when save_checkpoints=True).
    - decoder_warm_unfreeze_epoch: Stage 2 epoch at which the frozen decoder is unfrozen for warm fine-tuning. Only active when freeze_pretrained=True. -1 (default) keeps the decoder frozen for the entire Stage 2 run.
    - **trainer_kwargs: Extra arguments forwarded to CellPinTrainer (e.g. devices, precision, accelerator).
  Example:
      sc_dataset, _ = cellpin.pp.setup_data(sc_adata, st_adata)
      model = cellpin.CellPin(sc_dataset)
      model.fit(sc_dataset)
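The interaction between freeze_pretrained and decoder_warm_unfreeze_epoch can be summarised as a small predicate. The sketch below assumes Stage 2 epochs are counted from 0 and that the decoder stays trainable once the scheduled epoch is reached. decoder_is_trainable is a hypothetical name, not part of the CellPin API.

```python
def decoder_is_trainable(stage2_epoch: int, freeze_pretrained: bool,
                         decoder_warm_unfreeze_epoch: int = -1) -> bool:
    """Hypothetical: whether the full-gene decoder receives gradients in Stage 2."""
    if not freeze_pretrained:
        return True  # nothing was frozen, so the decoder trains throughout
    if decoder_warm_unfreeze_epoch < 0:
        return False  # -1 (default): keep the decoder frozen for all of Stage 2
    return stage2_epoch >= decoder_warm_unfreeze_epoch


print([decoder_is_trainable(e, True, 3) for e in range(5)])  # [False, False, False, True, True]
```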
- get_cell_embedding(dataloader, use_mean=True)
  Encode cells to the latent space via the panel encoder.
  Args:
    - dataloader: DataLoader over a
    - use_mean: Return the posterior mean rather than a sample.
  Returns:
    Float32 array of shape (n_cells, n_latent).
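The use_mean flag chooses between the posterior mean and a random draw. Below is a sketch of that choice for one cell, using the reparameterisation trick for a diagonal Gaussian posterior. It assumes the panel encoder returns a mean and a log-variance per latent dimension. embed_cell is a hypothetical helper; the real method returns a float32 array for all cells at once.

```python
import math
import random
from typing import List, Optional, Sequence


def embed_cell(mu: Sequence[float], logvar: Sequence[float], use_mean: bool = True,
               rng: Optional[random.Random] = None) -> List[float]:
    """Posterior mean, or a reparameterised draw z = mu + exp(logvar / 2) * eps."""
    if use_mean:
        return [float(m) for m in mu]
    rng = rng or random.Random()
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]


print(embed_cell([1.0, 2.0], [0.0, 0.0]))  # [1.0, 2.0]
```

The mean gives deterministic embeddings, which suits clustering and visualisation. Samples expose posterior uncertainty.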
- impute(dataloader, obs_adata=None, mc_samples=50, mask_fraction=0.2, return_norm=False, norm_target_sum=1000.0, area_key=None, nb_count_samples=100, return_int=False, return_sparse=True, table_key='table')
  Impute with MC averaging and optional count-space normalisation.
  Args:
    - dataloader: DataLoader to run inference on.
    - obs_adata: Optional AnnData (or spatialdata.SpatialData) whose .obs is copied to the output. If SpatialData, the AnnData is read from obs_adata.tables[table_key] and the result is returned as an updated SpatialData object. Must have the same number of observations as there are cells in dataloader.
    - mc_samples: Number of stochastic forward passes for MC averaging (default 50; more passes give smoother estimates but run slower).
    - mask_fraction: Fraction of panel genes randomly zeroed per MC pass to simulate missing measurements (default 0.2).
    - return_norm: If True, add a log-normalised layer layers['imputed_norm'] (total-count or area normalised, then log1p-transformed).
    - norm_target_sum: Target total counts for normalisation (default 1e3; only used when return_norm=True).
    - area_key: obs column with cell area for area-based normalisation. Auto-detected as 'cell_area' when present; pass None for total-count normalisation (only used when return_norm=True).
    - nb_count_samples: Number of NB draws used to compute the MC estimate of E[log1p(norm(X))] when return_norm=True (default 100). Because log1p is concave, Jensen's inequality means log1p(norm(E[X])) > E[log1p(norm(X))]: transforming the mean overestimates the expected log-normalised value, and sampling counts inside the transform corrects this bias. More samples give lower variance.
    - return_int: If True, round X to integer counts (int32).
    - return_sparse: If True (default), store X, layers['imputed'], and layers['imputed_norm'] as scipy.sparse.csr_matrix. Set to False to keep dense numpy arrays.
    - table_key: Table name to read/write when obs_adata is a SpatialData object (default "table").
  Returns:
    anndata.AnnData with X = imputed (float or int) counts, obsm['X_cellpin'] = embeddings, layers['imputed'] = copy of X, and optionally layers['imputed_norm']. var['is_measured'] marks genes present in obs_adata (all True when obs_adata is None). If obs_adata was a SpatialData object, returns the updated SpatialData with the result stored in sdata.tables[table_key].
  Raises:
    - ValueError: If obs_adata has the wrong number of cells, if area_key is specified but not found in adata.obs, or if any cell area is ≤ 0.
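With return_norm=True, the normalised layer estimates E[log1p(norm(X))] by sampling counts rather than transforming the mean. Below is a sketch for a single cell, resting on these assumptions:
- normalisation is total-count to norm_target_sum; area normalisation would divide by the cell area instead;
- NB draws are generated as a gamma-Poisson mixture;
- Poisson draws use Knuth's sampler, which is suitable only for small means;
- all-zero draws are left at zero.

All helper names are hypothetical.

```python
import math
import random
from typing import List, Sequence


def sample_poisson(lam: float, rng: random.Random) -> int:
    """Knuth's multiplication method; adequate for the small means in this sketch."""
    threshold = math.exp(-lam)
    k, p = 0, 1.0
    while True:
        p *= rng.random()
        if p <= threshold:
            return k
        k += 1


def sample_nb(mu: float, theta: float, rng: random.Random) -> int:
    """Draw from NB(mean=mu, inverse dispersion=theta) as a gamma-Poisson mixture."""
    rate = rng.gammavariate(theta, mu / theta)  # shape theta, scale mu/theta -> mean mu
    return sample_poisson(rate, rng)


def log1p_norm(counts: Sequence[float], target_sum: float = 1000.0) -> List[float]:
    """Total-count normalise to target_sum, then log1p (all-zero cells stay at zero)."""
    total = sum(counts)
    if total <= 0:
        return [0.0] * len(counts)
    return [math.log1p(c / total * target_sum) for c in counts]


def expected_log1p_norm(means: Sequence[float], thetas: Sequence[float],
                        target_sum: float = 1000.0, n_samples: int = 100,
                        seed: int = 0) -> List[float]:
    """MC estimate of E[log1p(norm(X))] for one cell with X ~ NB(means, thetas)."""
    rng = random.Random(seed)
    acc = [0.0] * len(means)
    for _ in range(n_samples):
        draw = [sample_nb(m, t, rng) for m, t in zip(means, thetas)]
        acc = [a + v for a, v in zip(acc, log1p_norm(draw, target_sum))]
    return [a / n_samples for a in acc]


means, thetas = [5.0, 5.0], [0.5, 0.5]
plug_in = log1p_norm(means)  # transform applied to the mean
mc = expected_log1p_norm(means, thetas, n_samples=2000)
print(plug_in[0] > mc[0])  # True: for these overdispersed genes the plug-in value is biased upwards
```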
- pretrain_model(dataset, custom_callbacks=None, train_size=0.8, pretrain_epochs=50, **trainer_kwargs)
  Stage-1 pretraining (full-gene view only, ELBO).
  Args:
    - dataset: Training dataset.
    - custom_callbacks: Extra PyTorch Lightning callbacks.
    - train_size: Fraction of data used for training.
    - pretrain_epochs: Default max epochs (overridden by trainer_kwargs['max_epochs'] if present).
    - **trainer_kwargs: Forwarded to CellPinTrainer.
  Returns:
    Fitted CellPinTrainer.
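Two conventions described above are sketched below:
- splitting cells into training and validation sets by train_size;
- letting trainer_kwargs['max_epochs'] override pretrain_epochs.

Both helpers are hypothetical. CellPin's actual shuffling, seeding and rounding of the split are not documented here.

```python
import random
from typing import Any, Dict, List, Tuple


def split_indices(n_cells: int, train_size: float = 0.8,
                  seed: int = 0) -> Tuple[List[int], List[int]]:
    """Shuffle cell indices and split them into train / validation lists."""
    idx = list(range(n_cells))
    random.Random(seed).shuffle(idx)
    n_train = int(round(train_size * n_cells))
    return idx[:n_train], idx[n_train:]


def resolve_max_epochs(pretrain_epochs: int, trainer_kwargs: Dict[str, Any]) -> int:
    """An explicit trainer_kwargs['max_epochs'] takes precedence over pretrain_epochs."""
    return trainer_kwargs.get("max_epochs", pretrain_epochs)


train_idx, val_idx = split_indices(10)
print(len(train_idx), len(val_idx))  # 8 2
```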
- save(path)
  Serialise model weights and hyper-parameters to a .pt file.
  Args:
    - path: Destination file path.
- set_stage_loss_weights(stage, **weights)
  Programmatically update loss weights for a stage; useful for sweeps or ablations.
  Args:
    - stage: 'pretrain' or 'main'.
    - **weights: Key-value overrides, e.g. inv=2.0, recon=1.5.
  Raises:
    - KeyError: For unknown weight keys.
  Example:
      model.set_stage_loss_weights("main", inv=2.0, recon=1.5)
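Below is a sketch of per-stage loss-weight bookkeeping that rejects unknown keys with KeyError, as documented. Only inv and recon appear in the example above. The kl key and all default values are hypothetical placeholders, not CellPin's actual weight table, and raising KeyError for an unknown stage is also an assumption.

```python
from typing import Dict

# Hypothetical defaults; only "inv" and "recon" are named in the docs.
DEFAULT_WEIGHTS: Dict[str, Dict[str, float]] = {
    "pretrain": {"recon": 1.0, "kl": 1.0},
    "main": {"recon": 1.0, "kl": 1.0, "inv": 1.0},
}


class LossWeights:
    """Hypothetical container mirroring set_stage_loss_weights semantics."""

    def __init__(self) -> None:
        # Copy so that instances never mutate the module-level defaults.
        self.weights = {stage: dict(w) for stage, w in DEFAULT_WEIGHTS.items()}

    def set_stage_loss_weights(self, stage: str, **weights: float) -> None:
        if stage not in self.weights:
            raise KeyError(f"unknown stage: {stage!r}")  # assumption
        unknown = set(weights) - set(self.weights[stage])
        if unknown:
            # Validate every key before updating, so a bad call changes nothing.
            raise KeyError(f"unknown weight keys for {stage!r}: {sorted(unknown)}")
        self.weights[stage].update(weights)


lw = LossWeights()
lw.set_stage_loss_weights("main", inv=2.0, recon=1.5)
print(lw.weights["main"])  # {'recon': 1.5, 'kl': 1.0, 'inv': 2.0}
```

Validating all keys before applying any of them keeps a sweep from leaving the model in a half-updated state after a typo.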
- train_model(dataset, custom_callbacks=None, train_size=0.8, freeze_pretrained=False, require_pretrained=True, **trainer_kwargs)
  Stage-2 main training (both views, full ELBO + invariance + SNN).
  Args:
    - dataset: Training dataset (scAnnDataset).
    - custom_callbacks: Extra PyTorch Lightning callbacks.
    - train_size: Fraction of data used for training.
    - freeze_pretrained: If True, freeze the full-gene encoder and decoder (Stage 1 weights) during Stage 2.
    - require_pretrained: If True (default), raise an error when freeze_pretrained=True but pretrain_model was never called. This prevents silently training against a randomly initialised frozen decoder.
    - **trainer_kwargs: Forwarded to CellPinTrainer.
  Returns:
    Fitted CellPinTrainer.
  Raises:
    - RuntimeError: If require_pretrained=True, freeze_pretrained=True, and pretraining has not been completed.
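Below is a sketch of the documented require_pretrained guard. The `pretrained` flag and the TwoStageGuard class are hypothetical stand-ins for however CellPin records that pretrain_model has run. The real methods also build and run a CellPinTrainer, which is omitted here.

```python
class TwoStageGuard:
    """Hypothetical stand-in that tracks whether Stage 1 has completed."""

    def __init__(self) -> None:
        self.pretrained = False

    def pretrain_model(self) -> None:
        # ... Stage 1 training would run here ...
        self.pretrained = True

    def train_model(self, freeze_pretrained: bool = False,
                    require_pretrained: bool = True) -> None:
        if require_pretrained and freeze_pretrained and not self.pretrained:
            raise RuntimeError(
                "freeze_pretrained=True but pretrain_model() has not been run; "
                "the frozen encoder/decoder would keep their random initialisation."
            )
        # ... Stage 2 training would run here ...


guard = TwoStageGuard()
guard.train_model(freeze_pretrained=False)  # allowed: nothing is frozen
guard.pretrain_model()
guard.train_model(freeze_pretrained=True)   # allowed: Stage 1 weights exist
```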
- training_step(batch, batch_idx)
  Run a training step for the current stage.
  Return type: Tensor
- validation_step(batch, batch_idx)
  Run a validation step for the current stage.
  Return type: Tensor
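training_step and validation_step both run "for the current stage". Below is a sketch of that dispatch. It assumes a stage attribute holding 'pretrain' or 'main', the stage names used by set_stage_loss_weights. It also assumes, consistent with compute_losses and compute_pretrain_losses, that each loss function returns a dict with a 'loss' entry. The class and its stub losses are hypothetical.

```python
from typing import Callable, Dict

LossFn = Callable[[dict], Dict[str, float]]


class StageDispatcher:
    """Hypothetical: route a batch to the loss function of the active stage."""

    def __init__(self, pretrain_fn: LossFn, main_fn: LossFn, stage: str = "pretrain") -> None:
        self.pretrain_fn = pretrain_fn
        self.main_fn = main_fn
        self.stage = stage

    def step(self, batch: dict) -> float:
        """Pick the loss for the current stage and return the scalar to optimise."""
        if self.stage == "pretrain":
            losses = self.pretrain_fn(batch)
        elif self.stage == "main":
            losses = self.main_fn(batch)
        else:
            raise ValueError(f"unknown stage: {self.stage!r}")
        return losses["loss"]


d = StageDispatcher(lambda b: {"loss": 1.0}, lambda b: {"loss": 2.0})
print(d.step({}))  # 1.0
d.stage = "main"
print(d.step({}))  # 2.0
```

In a LightningModule the returned value would be the loss tensor that Lightning backpropagates.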
Methods table
| Method | Description |
|---|---|
| compute_losses | Main training losses. |
| compute_pretrain_losses | Pretraining losses (full-gene view only). |
| configure_optimizers | AdamW optimiser with cosine-annealing LR schedule. |
| embed_and_impute | Embed cells and generate imputed expression values. |
| fit | Train CellPin: Stage 1 (pretrain) followed by Stage 2 (distillation). |
| get_cell_embedding | Encode cells to the latent space via the panel encoder. |
| impute | Impute with MC averaging and optional count-space normalisation. |
| | Warm-unfreeze decoder parameters when scheduled. |
| pretrain_model | Stage-1 pretraining (full-gene view only, ELBO). |
| save | Serialise model weights and hyper-parameters to a .pt file. |
| set_stage_loss_weights | Programmatically update loss weights for a stage. |
| train_model | Stage-2 main training (both views, full ELBO + invariance + SNN). |
| training_step | Run a training step for the current stage. |
| validation_step | Run a validation step for the current stage. |
Attributes
- CellPin.training