cellpin.models.CellPin.fit

CellPin.fit(dataset, pretrain_epochs=50, train_epochs=60, batch_size=256, gradient_clip_val=0.5, early_stopping_patience=10, freeze_pretrained=False, train_size=0.8, save_checkpoints=False, output_dir='./cellpin_output', decoder_warm_unfreeze_epoch=-1, **trainer_kwargs)

Train CellPin: Stage 1 (pretrain) followed by Stage 2 (distillation).

This is the recommended entry point for training. It runs both stages sequentially with a single call.
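
The two-stage control flow can be sketched in plain Python. Each stage is capped by its own epoch budget (pretrain_epochs, train_epochs) and can end early through early_stopping_patience. This is a hypothetical simulation, not CellPin's implementation. The helper epochs_run and the validation-loss sequences are invented for illustration. The sketch assumes early stopping monitors a validation loss where lower is better, and that patience counts consecutive epochs without a new best value.

```python
from typing import Sequence


def epochs_run(val_losses: Sequence[float], max_epochs: int, patience: int) -> int:
    """Return how many epochs a stage runs before it stops.

    Hypothetical helper: training stops after `max_epochs`, or earlier once
    `patience` consecutive epochs pass without a new best validation loss.
    """
    best = float("inf")
    stale = 0
    for epoch, loss in enumerate(val_losses[:max_epochs], start=1):
        if loss < best:
            best, stale = loss, 0  # new best: reset the patience counter
        else:
            stale += 1
            if stale >= patience:
                return epoch  # stopped early
    return min(len(val_losses), max_epochs)


# Stage 1 (pretrain) runs first, then Stage 2 (distillation), each with its own budget.
stage1 = epochs_run([1.0, 0.8, 0.7, 0.71, 0.72, 0.73], max_epochs=50, patience=3)
stage2 = epochs_run([1 / (e + 1) for e in range(100)], max_epochs=60, patience=10)
print(stage1, stage2)  # 6 60
```

In this toy run, Stage 1 stops early after three epochs without improvement. Stage 2 keeps improving, so it runs until it reaches its 60-epoch cap.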

Return type:

None

Args:

    dataset: Single-cell dataset returned by cellpin.pp.setup_data() (the first element of the returned tuple).
    pretrain_epochs: Maximum number of epochs for Stage 1 (full-gene ELBO pretraining).
    train_epochs: Maximum number of epochs for Stage 2 (panel distillation).
    batch_size: Mini-batch size used in both stages.
    gradient_clip_val: Gradient clipping value.
    early_stopping_patience: Number of epochs without improvement before training stops early.
    freeze_pretrained: Freeze the full-gene encoder/decoder during Stage 2.
    train_size: Fraction of cells used for training; the remaining cells form the validation set.
    save_checkpoints: Save model checkpoints to output_dir. Disabled by default; enable it when you need to resume training or load the best epoch after early stopping.
    output_dir: Root directory for checkpoints and logs (only used when save_checkpoints=True).
    decoder_warm_unfreeze_epoch: Stage 2 epoch at which the frozen decoder is unfrozen for warm fine-tuning. Only active when freeze_pretrained=True. The default, -1, keeps the decoder frozen for the entire Stage 2 run.
    **trainer_kwargs: Extra arguments forwarded to CellPinTrainer (e.g. devices, precision, accelerator).
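
The interaction between freeze_pretrained and decoder_warm_unfreeze_epoch can be summarized as a per-epoch schedule. The helper decoder_trainable below is hypothetical and is derived only from the parameter descriptions above. It assumes Stage 2 epochs are counted from 0, that unfreezing takes effect at the start of the given epoch, and that the decoder is trainable throughout Stage 2 when freeze_pretrained=False. The documentation confirms none of these three points.

```python
def decoder_trainable(epoch: int, freeze_pretrained: bool,
                      decoder_warm_unfreeze_epoch: int = -1) -> bool:
    """Return True if the pretrained decoder is updated at this Stage 2 epoch.

    Hypothetical sketch: assumes 0-based Stage 2 epochs and that unfreezing
    applies from the start of `decoder_warm_unfreeze_epoch` onward.
    """
    if not freeze_pretrained:
        # Nothing was frozen, so the decoder trains throughout Stage 2 (assumption).
        return True
    if decoder_warm_unfreeze_epoch < 0:
        # -1 (default): the decoder stays frozen for the whole Stage 2 run.
        return False
    # Frozen until the warm-unfreeze epoch, then fine-tuned.
    return epoch >= decoder_warm_unfreeze_epoch


schedule = [decoder_trainable(e, freeze_pretrained=True, decoder_warm_unfreeze_epoch=3)
            for e in range(6)]
print(schedule)  # [False, False, False, True, True, True]
```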

Example:

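import cellpin

# sc_adata (single-cell) and st_adata (spatial) must already be loaded.
sc_dataset, _ = cellpin.pp.setup_data(sc_adata, st_adata)
model = cellpin.CellPin(sc_dataset)
model.fit(sc_dataset)  # Stage 1 (pretrain), then Stage 2 (distillation)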
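
The train_size argument divides cells between training and validation. The following minimal sketch illustrates one way such a split could work. It assumes a seeded random permutation of cell indices and rounds the training count to the nearest integer. CellPin's actual splitting strategy, seeding, and rounding are not documented here, and split_cells is a hypothetical name.

```python
import random
from typing import List, Tuple


def split_cells(n_cells: int, train_size: float = 0.8,
                seed: int = 0) -> Tuple[List[int], List[int]]:
    """Split cell indices into (train, validation) lists.

    Hypothetical sketch: shuffles indices with a seeded RNG, gives the first
    round(train_size * n_cells) to training and the rest to validation.
    """
    if not 0.0 < train_size <= 1.0:
        raise ValueError("train_size must be in (0, 1]")
    indices = list(range(n_cells))
    random.Random(seed).shuffle(indices)  # reproducible shuffle
    n_train = round(train_size * n_cells)  # rounding rule is an assumption
    return indices[:n_train], indices[n_train:]


train_idx, val_idx = split_cells(1000, train_size=0.8)
print(len(train_idx), len(val_idx))  # 800 200
```

A fixed seed keeps the validation set identical across runs. This matters when comparing early-stopping behavior between configurations.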