Save and load model parameters with stCluster
In some scenarios, users need to keep the embeddings or model parameters obtained during training, either for later training sessions or to reproduce results on other hardware. stCluster provides a way to save and load model parameters quickly. This tutorial uses the ZESTA dataset to illustrate the process.
First, we load the dataset, train the latent representation with stCluster, and save the model parameters and embedding matrix by setting the save_model and save_embedding arguments of stCluster.train.train().
[1]:
from st_datasets.dataset import get_data, get_zesta_data
from stCluster.train import train
adata, n_cluster = get_data(get_zesta_data)
_ = train(adata, radius=15, save_model='zesta_model.pkl', save_embedding='zesta_embedding.npy')
>>> INFO: Download dataset: 100%|██████████| 234M/234M [04:01<00:00, 1.01MB/s]
>>> INFO: dataset name: zebrafish embryogenesis spatiotemporal transcriptomic atlas (ZESTA), size: (13166, 26628), cluster: 45.(242.497s)
>>> INFO: Input size torch.Size([13166, 3000]).
>>> INFO: Graph contains 41704 edges, average 3.168 edges per node.
>>> INFO: Build graph success!
>>> INFO: Finish generate precluster embedding!
>>> INFO: Finish pre-cluster, result image is saved at "None", begin to prune graph.
>>> INFO: Finish pruning graph, result image is saved at "None".
>>> INFO: Graph contains 367103 edges, average 27.883 edges per node.
>>> INFO: Build graph success!
>>> INFO: Finish model preparations, begin to train model, input data size: (13166, 3000).
>>> INFO: Training: 100%|██████████| 1000/1000 [01:13<00:00, 13.55it/s]
>>> INFO: Successfully save embedding at zesta_embedding.npy.
>>> INFO: Successfully export model at zesta_model.pkl.
>>> INFO: Finish embedding process, total time: 144.739s.
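When the saved files are copied to another machine, comparing checksums is one simple way to confirm that they arrived intact. The sketch below is not part of stCluster. It uses only Python's standard hashlib module, and the helper name file_sha256 is hypothetical.

```python
import hashlib
from pathlib import Path


def file_sha256(path, chunk_size=1 << 20):
    """Return the hex SHA-256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        # Read fixed-size chunks so large model files need not fit in memory.
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


if __name__ == "__main__":
    # File names are the ones passed to train() in the cell above.
    for name in ("zesta_model.pkl", "zesta_embedding.npy"):
        if Path(name).exists():
            print(name, file_sha256(name))
        else:
            print(name, "not found; run the training cell first")
```

Running the same helper on the target device and comparing the digests shows whether the files were transferred without corruption.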
Load model parameters
We can then load the model parameters on another device, regenerate the embedding, and run downstream analysis.
[2]:
from st_datasets.dataset import get_data, get_zesta_data
from stCluster.run import load_and_evaluate
adata, n_cluster = get_data(get_zesta_data)
adata, score = load_and_evaluate(adata, radius=15, n_cluster=n_cluster, cluster_method='mclust', cluster_score_method='ARI', model_paras_path='zesta_model.pkl')
print(score)
>>> INFO: Use locate data.
>>> INFO: dataset name: zebrafish embryogenesis spatiotemporal transcriptomic atlas (ZESTA), size: (13166, 26628), cluster: 45.(0.424s)
>>> INFO: Input size torch.Size([13166, 3000]).
>>> INFO: Graph contains 41704 edges, average 3.168 edges per node.
>>> INFO: Build graph success!
>>> INFO: Finish load model, begin to generate embedding and rebuild gene expression, input data size: (13166, 3000).
>>> INFO: Finish embedding generation process, please use the embedding to do downstream evaluation, total time: 0.340s
R[write to console]:                    __           __
   ____ ___  _____/ /_  _______/ /_
  / __ `__ \/ ___/ / / / / ___/ __/
 / / / / / / /__/ / /_/ (__  ) /_
/_/ /_/ /_/\___/_/\__,_/____/\__/   version 5.4.10
Type 'citation("mclust")' for citing this R package in publications.
fitting ...
|======================================================================| 100%
{'mclust': 0.35190349034463836}
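The printed value is the score requested with cluster_score_method='ARI', which we take here to mean the adjusted Rand index between the mclust cluster labels and the reference annotation. To make clear what that number measures, the following is a minimal standard-library sketch of the textbook ARI formula. It is an independent illustration, not stCluster's own implementation, and the function name adjusted_rand_index is hypothetical.

```python
from collections import Counter
from math import comb


def adjusted_rand_index(labels_true, labels_pred):
    """Adjusted Rand index between two labelings of the same items."""
    if len(labels_true) != len(labels_pred):
        raise ValueError("label sequences must have the same length")

    n = len(labels_true)
    total_pairs = comb(n, 2)
    if total_pairs == 0:
        return 1.0  # fewer than two items: labelings agree trivially

    # Contingency counts and the marginal cluster sizes of each labeling.
    pair_counts = Counter(zip(labels_true, labels_pred))
    true_counts = Counter(labels_true)
    pred_counts = Counter(labels_pred)

    # Number of item pairs placed together in both, in truth, in prediction.
    sum_both = sum(comb(c, 2) for c in pair_counts.values())
    sum_true = sum(comb(c, 2) for c in true_counts.values())
    sum_pred = sum(comb(c, 2) for c in pred_counts.values())

    expected = sum_true * sum_pred / total_pairs  # agreement expected by chance
    maximum = 0.5 * (sum_true + sum_pred)
    if maximum == expected:
        return 1.0  # degenerate case, e.g. a single cluster in both labelings
    return (sum_both - expected) / (maximum - expected)


if __name__ == "__main__":
    # Same partition under different label names still scores 1.0.
    print(adjusted_rand_index([0, 0, 1, 1], ["b", "b", "a", "a"]))
```

A score of 1 means the two partitions are identical up to label names, and values near 0 mean agreement no better than chance.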
The model parameters obtained by training stCluster are available in the containers we provide, at the following location: \root\stCluster_paras. With these files, you can quickly generate latent representations on your own device and run downstream analyses.