[2]:
import os
from sonnia.sonnia import SoNNia
from sonnia.sonia import Sonia
from sonnia.plotting import Plotter
import numpy as np
import pandas as pd
from sonnia.utils import filter_seqs

Load lists of sequences with gene specification

[4]:
# This assumes the data sequences are stored in a gzip-compressed, comma-separated
# text file with a header row naming the CDR3 amino acid sequence and the V and J
# gene columns (pd.read_csv is called with its default separator).
data_seqs = pd.read_csv('data_seqs.csv.gz')
data_seqs.head()
[4]:
junction_aa v_gene j_gene
0 CASSKQGASEAFF TRBV7-8 TRBJ1-1
1 CASSPPPNYGYTF TRBV6-1 TRBJ1-2
2 CASSTDTTEAFF TRBV6-5 TRBJ1-1
3 CATERGGAPGFADTQYF TRBV11-2 TRBJ2-3
4 CASSLITGENTEAFF TRBV5-4 TRBJ1-1

Define and infer the model

[5]:
# Build the SoNNia model on the loaded data, using the default human TRB model
# as the generation (pgen) model. The data sequences are filtered and encoded here.
qm = SoNNia(data_seqs=data_seqs,pgen_model='humanTRB')
2025-11-17 12:48:40,022: Adding data seqs.
2025-11-17 12:48:40,123: 200000 sequences before filtering. Using /home/ubuntu/soNNia/sonnia/default_models/human_T_beta for filtering.
2025-11-17 12:48:40,463: 200000 sequences remain after removing sequences with V genes inconsistent with the model.
2025-11-17 12:48:40,478: 200000 sequences remain after removing sequences with J genes inconsistent with the model.
2025-11-17 12:48:40,553: 200000 sequences remain after removing data which are unproductive amino acid sequences.
2025-11-17 12:48:40,704: 200000 sequences remain after removing sequences that do not begin with a 'C' or end in a ['A', 'B', 'C', 'E', 'D', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'].
2025-11-17 12:48:40,770: 200000 sequences remain after removing sequences with CDR3 length larger than 30.
2025-11-17 12:48:40,771: 200000 sequences remain. Filtering completed.
2025-11-17 12:48:41,254: Encode data seqs.
Encoding sequence features: 100%|██████████| 200000/200000 [00:03<00:00, 61009.77it/s]
[6]:
# Add 500,000 sequences generated from the pgen model
# (generated sequences can also be loaded from a file instead).
qm.add_generated_seqs(int(5e5))
2025-11-17 12:48:48,167: Generating 500000 using the pgen model in /home/ubuntu/soNNia/sonnia/default_models/human_T_beta.
2025-11-17 12:48:54,111: Adding gen seqs.
2025-11-17 12:48:54,307: Using default index (0) for amino acid CDR3 sequences.
2025-11-17 12:48:54,308: Using default index (1) for V genes.
2025-11-17 12:48:54,337: Using default index (2) for J genes.
2025-11-17 12:48:54,599: 500000 sequences before filtering. Using /home/ubuntu/soNNia/sonnia/default_models/human_T_beta for filtering.
2025-11-17 12:48:55,609: 500000 sequences remain after removing sequences with V genes inconsistent with the model.
2025-11-17 12:48:55,655: 500000 sequences remain after removing sequences with J genes inconsistent with the model.
2025-11-17 12:48:55,840: 500000 sequences remain after removing data which are unproductive amino acid sequences.
2025-11-17 12:48:56,220: 500000 sequences remain after removing sequences that do not begin with a 'C' or end in a ['A', 'B', 'C', 'E', 'D', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'].
2025-11-17 12:48:56,377: 499998 sequences remain after removing sequences with CDR3 length larger than 30.
2025-11-17 12:48:56,378: 499998 sequences remain. Filtering completed.
2025-11-17 12:48:57,250: Encode gen seqs.
Encoding sequence features: 100%|██████████| 499998/499998 [00:08<00:00, 56300.43it/s]
Computing energies: 100%|██████████| 5/5 [00:02<00:00,  2.45it/s]
[7]:
# Train the model defined above: infer selection for 50 epochs with a batch size of 5000.
qm.infer_selection(epochs=50,batch_size=int(5e3))
2025-11-17 12:57:43,095: Finished training.
Computing energies: 100%|██████████| 5/5 [00:02<00:00,  2.14it/s]
2025-11-17 12:57:45,472: Updating marginals.
Computing energies: 100%|██████████| 5/5 [00:02<00:00,  2.16it/s]
2025-11-17 12:57:47,916: Finished updating marginals.

Do some plotting

[8]:
# Wrap the trained model in a Plotter (reused by the plotting cells below)
# and plot the model's learning history.
plot_sonia=Plotter(qm)
plot_sonia.plot_model_learning()
_images/sonnia_tutorial_8_0.png
[9]:
# Produces two figures; presumably V/J gene usage and CDR3 length marginals
# -- confirm against the Plotter documentation.
plot_sonia.plot_vjl()
_images/sonnia_tutorial_9_0.png
_images/sonnia_tutorial_9_1.png

Generate sequences

[10]:
# Generate 100,000 sequences with generate_sequences_pre (pre-selection).
# Each row of the result is [CDR3 amino acid sequence, V gene, J gene].
pre_seqs=qm.generate_sequences_pre(int(1e5))
pre_seqs[:3]
[10]:
array([['CASSLEKTRTGGDTGELFF', 'TRBV11-2', 'TRBJ2-2'],
       ['CASSSSAAGVNTEAFF', 'TRBV28', 'TRBJ1-1'],
       ['CASSNLSDNYEQYF', 'TRBV5-1', 'TRBJ2-7']], dtype='<U93')
[11]:
# Generate 100,000 sequences with generate_sequences_post (post-selection).
# Each row of the result is [CDR3 amino acid sequence, V gene, J gene].
post_seqs=qm.generate_sequences_post(int(1e5))
post_seqs[:3]
Encoding sequence features: 100%|██████████| 1100000/1100000 [00:20<00:00, 53482.96it/s]
Computing energies: 100%|██████████| 11/11 [00:05<00:00,  2.19it/s]
[11]:
array([['CASSLRGLYNEQFF', 'TRBV5-4', 'TRBJ2-1'],
       ['CASSAGVTYEQYF', 'TRBV2', 'TRBJ2-7'],
       ['CASSLFPSLETQYF', 'TRBV7-2', 'TRBJ2-5']], dtype='<U96')

Evaluate sequences

[12]:
# Evaluate three sets of sequences with the trained model. For each set,
# evaluate_seqs returns three arrays, unpacked here as the selection factor (Q),
# the generation probability (pgen) and the post-selection probability (ppost).
#   *_data  : the first 100,000 data sequences
#   *_gen   : sequences from generate_sequences_pre
#   *_model : sequences from generate_sequences_post
Q_data,pgen_data,ppost_data=qm.evaluate_seqs(qm.data_seqs[:int(1e5)].values)
Q_gen,pgen_gen,ppost_gen=qm.evaluate_seqs(pre_seqs)
Q_model,pgen_model,ppost_model=qm.evaluate_seqs(post_seqs)
# Show the first three values of each quantity for the post-selection samples.
print(Q_model[:3])
print(pgen_model[:3])
print(ppost_model[:3])
Encoding sequence features: 100%|██████████| 100000/100000 [00:01<00:00, 62793.12it/s]
Computing energies: 100%|██████████| 1/1 [00:00<00:00,  2.29it/s]
Encoding sequence features: 100%|██████████| 100000/100000 [00:01<00:00, 51424.73it/s]
Computing energies: 100%|██████████| 1/1 [00:00<00:00,  1.75it/s]
Encoding sequence features: 100%|██████████| 100000/100000 [00:01<00:00, 53724.20it/s]
Computing energies: 100%|██████████| 1/1 [00:00<00:00,  1.78it/s]
[1.7047786 4.289785  3.5878801]
[2.50054270e-08 1.93495475e-08 3.02493194e-10]
[4.26287157e-08 8.30053967e-08 1.08530932e-09]
[14]:
# Compare the pgen values of the data, pre-selection and post-selection sequences (50 bins).
plot_sonia.plot_prob(data=pgen_data,gen=pgen_gen,model=pgen_model,ptype='P_{pre}',n_bins=50)
_images/sonnia_tutorial_15_0.png
[15]:
# Same comparison for the ppost values (arguments passed positionally: data, gen, model).
plot_sonia.plot_prob(ppost_data,ppost_gen,ppost_model,ptype='P_{post}',n_bins=50)
_images/sonnia_tutorial_16_0.png
[17]:
# Same comparison for the selection factors Q, with explicit bin limits (bin_min=-4, bin_max=2).
plot_sonia.plot_prob(Q_data,Q_gen,Q_model,ptype='Q',bin_min=-4,bin_max=2,n_bins=50)
_images/sonnia_tutorial_17_0.png

Some utilities inherited from OLGA

[18]:
# OLGA functionality can be accessed directly through the main SoNNia model
[19]:
# Generate one random productive CDR3. The result is a tuple of
# (nucleotide sequence, amino acid sequence, int, int); the two integers are
# presumably the V and J gene indices -- confirm against the OLGA documentation.
qm.seqgen_model.gen_rnd_prod_CDR3()
[19]:
('TGTGCCAGCAGCTTAGGGGCAATCCGGGACAGGGTTTTGACCTACAATGAGCAGTTCTTC',
 'CASSLGAIRDRVLTYNEQFF',
 84,
 7)
[20]:
# Inspect the J gene entry at index 1 of the genomic data:
# a list holding the allele name followed by two nucleotide sequences.
qm.genomic_data.genJ[1]
[20]:
['TRBJ1-2*01',
 'CTAACTATGGCTACACCTTC',
 'CTAACTATGGCTACACCTTCGGTTCGGGGACCAGGTTAACCGTTGTAG']
[21]:
# Inspect a parameter array of the generation model; presumably the probability
# distribution of insertion lengths at the DJ junction -- confirm against OLGA docs.
qm.pgen_model.PinsDJ
[21]:
array([6.17437e-02, 3.61889e-02, 9.09608e-02, 1.05828e-01, 1.37586e-01,
       1.14643e-01, 9.60481e-02, 8.14864e-02, 6.38634e-02, 4.92164e-02,
       3.93751e-02, 2.90524e-02, 2.30059e-02, 1.64381e-02, 1.45157e-02,
       1.13759e-02, 5.79127e-03, 5.97164e-03, 3.92779e-03, 2.96191e-03,
       2.04381e-03, 2.48417e-03, 9.09996e-04, 1.35102e-03, 2.44798e-04,
       4.52171e-04, 9.17052e-04, 6.28282e-04, 0.00000e+00, 1.41295e-05,
       9.74155e-04, 0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00])

Save and Load Model

[22]:
# Save the trained model under the name 'test'.
qm.save_model('test')
[23]:
# Load the model saved above into a new SoNNia instance.
# The Keras optimizer warning shown in this cell's output is emitted while loading.
qm_new=SoNNia(ppost_model='test')
/home/ubuntu/miniforge3/lib/python3.12/site-packages/keras/src/saving/saving_lib.py:797: UserWarning: Skipping variable loading for optimizer 'rmsprop', because it has 2 variables whereas the saved optimizer has 15 variables.
  saveable.load_own_variables(weights_store.get(inner_path))
[24]:
# By default, sequences are not added when a saved model is loaded.
# Add the data and generated sequences of the original model now.
qm_new.update_model(add_data_seqs=qm.data_seqs,add_gen_seqs=qm.gen_seqs)
2025-11-17 13:00:18,774: Adding data seqs.
2025-11-17 13:00:18,877: 200000 sequences before filtering. Using test for filtering.
2025-11-17 13:00:19,314: 200000 sequences remain after removing sequences with V genes inconsistent with the model.
2025-11-17 13:00:19,335: 200000 sequences remain after removing sequences with J genes inconsistent with the model.
2025-11-17 13:00:19,408: 200000 sequences remain after removing data which are unproductive amino acid sequences.
2025-11-17 13:00:19,552: 200000 sequences remain after removing sequences that do not begin with a 'C' or end in a ['A', 'B', 'C', 'E', 'D', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'].
2025-11-17 13:00:19,614: 200000 sequences remain after removing sequences with CDR3 length larger than 30.
2025-11-17 13:00:19,614: 200000 sequences remain. Filtering completed.
2025-11-17 13:00:19,926: Adding gen seqs.
2025-11-17 13:00:20,191: 499998 sequences before filtering. Using test for filtering.
2025-11-17 13:00:21,269: 499998 sequences remain after removing sequences with V genes inconsistent with the model.
2025-11-17 13:00:21,319: 499998 sequences remain after removing sequences with J genes inconsistent with the model.
2025-11-17 13:00:21,499: 499998 sequences remain after removing data which are unproductive amino acid sequences.
2025-11-17 13:00:21,871: 499998 sequences remain after removing sequences that do not begin with a 'C' or end in a ['A', 'B', 'C', 'E', 'D', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'].
2025-11-17 13:00:22,023: 499998 sequences remain after removing sequences with CDR3 length larger than 30.
2025-11-17 13:00:22,024: 499998 sequences remain. Filtering completed.
2025-11-17 13:00:23,042: Encode data seqs.
Encoding sequence features: 100%|██████████| 200000/200000 [00:03<00:00, 59394.06it/s]
2025-11-17 13:00:27,049: Encode gen seqs.
Encoding sequence features: 100%|██████████| 499998/499998 [00:08<00:00, 56962.16it/s]
Computing energies: 100%|██████████| 5/5 [00:02<00:00,  2.28it/s]
[25]:
# Continue inference on the reloaded model for 5 more epochs,
# then display the per-epoch training and validation metrics.
qm_new.infer_selection(epochs=5)
qm_new.learning_history.history
2025-11-17 13:01:23,226: Finished training.
Computing energies: 100%|██████████| 5/5 [00:02<00:00,  2.13it/s]
2025-11-17 13:01:25,589: Updating marginals.
Computing energies: 100%|██████████| 5/5 [00:02<00:00,  2.16it/s]
2025-11-17 13:01:28,041: Finished updating marginals.
[25]:
{'_likelihood': [-0.8920620083808899,
  -0.8960410952568054,
  -0.899197518825531,
  -0.901885449886322,
  -0.9051015973091125],
 'binary_crossentropy': [0.4252631962299347,
  0.42459744215011597,
  0.42381709814071655,
  0.4235226511955261,
  0.42295804619789124],
 'loss': [0.42526325583457947,
  0.42459729313850403,
  0.42381709814071655,
  0.4235227406024933,
  0.422958105802536],
 'val__likelihood': [-0.904228687286377,
  -0.9076792597770691,
  -0.9064220190048218,
  -0.9084171652793884,
  -0.9089504480361938],
 'val_binary_crossentropy': [0.4258957803249359,
  0.42045947909355164,
  0.4227898418903351,
  0.4201054871082306,
  0.42125770449638367],
 'val_loss': [0.4258957505226135,
  0.4204593896865845,
  0.42278987169265747,
  0.4201054573059082,
  0.4212576448917389]}

Compute Diversity and Distance from Reference Distribution

[26]:
# Report the entropy of the model distribution (labelled in bits).
print('model entropy is: ', qm_new.entropy(), '[bits]')
Computing energies: 100%|██████████| 1/1 [00:00<00:00, 11.50it/s]
model entropy is:  30.451140341671806 [bits]
[27]:
# Report the value of dkl_post_gen (labelled in bits); presumably the Kullback-Leibler
# divergence between the post-selection and generation distributions -- confirm.
print('Dkl post gen is: ', qm_new.dkl_post_gen(), '[bits]')
Dkl post gen is:  1.283561 [bits]

Load Default Sonia Models

[28]:
# Load the default Sonia model for the human TRB (beta chain) repertoire.
beta_default_model=Sonia(ppost_model='humanTRB')