[1]:
import os
import sonia
from sonnia.sonia import Sonia
from sonnia.sonia_paired import SoniaPaired
from sonia.plotting import Plotter
import sonia
import numpy as np
import pandas as pd

Define and infer the model

[2]:
# Read the paired repertoire file into an array of strings; the rows shown
# below hold six fields each (CDR3, V gene and J gene for each of two chains).
data_seqs = np.loadtxt('data_seqs_paired.txt', dtype=str)
print(data_seqs[:5])

# Build the paired-chain model on top of the 'humanTCR' model and attach the data.
qm = SoniaPaired(ppost_model='humanTCR', data_seqs=data_seqs)
[['CASSEGGGSSYEQYF' 'TRBV7-2' 'TRBJ2-7' 'CAMRGPFNKFYF' 'TRAV14/DV4'
  'TRAJ21']
 ['CASRGSGGEIGYTF' 'TRBV5-8' 'TRBJ1-2' 'CAMSVKGAGQNFVF' 'TRAV12-3'
  'TRAJ26']
 ['CASSQAWALTYEQYF' 'TRBV4-1' 'TRBJ2-7' 'CLEHLTGGFKTIF' 'TRAV4' 'TRAJ9']
 ['CASSPRTGTVGNTIYF' 'TRBV2' 'TRBJ1-3' 'CAAQGSSNTGKLIF' 'TRAV13-1'
  'TRAJ37']
 ['CASRTGGGGYTF' 'TRBV28' 'TRBJ1-2' 'CAGFKTIF' 'TRAV13-2' 'TRAJ9']]
2025-11-17 13:03:45,011: Adding data seqs.
2025-11-17 13:03:45,045: Using default index (0) for amino acid CDR3 sequences.
2025-11-17 13:03:45,045: Using default index (1) for V genes.
2025-11-17 13:03:45,051: Using default index (2) for J genes.
2025-11-17 13:03:45,103: 100000 sequences before filtering. Using /home/ubuntu/soNNia/sonnia/default_models/human_T_beta_alpha/heavy_chain for filtering.
2025-11-17 13:03:45,286: 100000 sequences remain after removing sequences with V genes inconsistent with the model.
2025-11-17 13:03:45,294: 100000 sequences remain after removing sequences with J genes inconsistent with the model.
2025-11-17 13:03:45,330: 100000 sequences remain after removing data which are unproductive amino acid sequences.
2025-11-17 13:03:45,404: 100000 sequences remain after removing sequences that do not begin with a 'C' or end in a ['A', 'B', 'C', 'E', 'D', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'].
2025-11-17 13:03:45,437: 100000 sequences remain after removing sequences with CDR3 length larger than 30.
2025-11-17 13:03:45,438: 100000 sequences remain. Filtering completed.
2025-11-17 13:03:45,640: Using default index (0) for amino acid CDR3 sequences.
2025-11-17 13:03:45,641: Using default index (1) for V genes.
2025-11-17 13:03:45,647: Using default index (2) for J genes.
2025-11-17 13:03:45,701: 100000 sequences before filtering. Using /home/ubuntu/soNNia/sonnia/default_models/human_T_beta_alpha/light_chain for filtering.
2025-11-17 13:03:45,901: 100000 sequences remain after removing sequences with V genes inconsistent with the model.
2025-11-17 13:03:45,910: 99986 sequences remain after removing sequences with J genes inconsistent with the model.
2025-11-17 13:03:45,947: 99986 sequences remain after removing data which are unproductive amino acid sequences.
2025-11-17 13:03:46,023: 99986 sequences remain after removing sequences that do not begin with a 'C' or end in a ['A', 'B', 'C', 'E', 'D', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'].
2025-11-17 13:03:46,054: 99986 sequences remain after removing sequences with CDR3 length larger than 30.
2025-11-17 13:03:46,055: 99986 sequences remain. Filtering completed.
2025-11-17 13:03:46,318: Encode data seqs.
Encoding sequence features: 100%|██████████| 99986/99986 [00:03<00:00, 32628.96it/s]
[3]:
# Add generated sequences to the model (they can also be loaded from a file;
# more is better).
n_generated_seqs = int(2e5)
qm.add_generated_seqs(n_generated_seqs)
2025-11-17 13:03:51,284: Generating 200000 using the light pgen model in /home/ubuntu/soNNia/sonnia/default_models/human_T_beta_alpha/light_chain and the heavy pgen model in /home/ubuntu/soNNia/sonnia/default_models/human_T_beta_alpha/heavy_chain.
2025-11-17 13:03:55,400: Adding gen seqs.
2025-11-17 13:03:55,496: Using default index (0) for amino acid CDR3 sequences.
2025-11-17 13:03:55,497: Using default index (1) for V genes.
2025-11-17 13:03:55,512: Using default index (2) for J genes.
2025-11-17 13:03:55,615: 200000 sequences before filtering. Using /home/ubuntu/soNNia/sonnia/default_models/human_T_beta_alpha/heavy_chain for filtering.
2025-11-17 13:03:56,010: 200000 sequences remain after removing sequences with V genes inconsistent with the model.
2025-11-17 13:03:56,028: 200000 sequences remain after removing sequences with J genes inconsistent with the model.
2025-11-17 13:03:56,099: 200000 sequences remain after removing data which are unproductive amino acid sequences.
2025-11-17 13:03:56,250: 200000 sequences remain after removing sequences that do not begin with a 'C' or end in a ['A', 'B', 'C', 'E', 'D', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'].
2025-11-17 13:03:56,318: 200000 sequences remain after removing sequences with CDR3 length larger than 30.
2025-11-17 13:03:56,319: 200000 sequences remain. Filtering completed.
2025-11-17 13:03:56,727: Using default index (0) for amino acid CDR3 sequences.
2025-11-17 13:03:56,728: Using default index (1) for V genes.
2025-11-17 13:03:56,740: Using default index (2) for J genes.
2025-11-17 13:03:56,841: 200000 sequences before filtering. Using /home/ubuntu/soNNia/sonnia/default_models/human_T_beta_alpha/light_chain for filtering.
2025-11-17 13:03:57,211: 200000 sequences remain after removing sequences with V genes inconsistent with the model.
2025-11-17 13:03:57,225: 200000 sequences remain after removing sequences with J genes inconsistent with the model.
2025-11-17 13:03:57,297: 200000 sequences remain after removing data which are unproductive amino acid sequences.
2025-11-17 13:03:57,446: 200000 sequences remain after removing sequences that do not begin with a 'C' or end in a ['A', 'B', 'C', 'E', 'D', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'].
2025-11-17 13:03:57,512: 200000 sequences remain after removing sequences with CDR3 length larger than 30.
2025-11-17 13:03:57,512: 200000 sequences remain. Filtering completed.
2025-11-17 13:03:58,066: Encode gen seqs.
Encoding sequence features: 100%|██████████| 200000/200000 [00:06<00:00, 30997.91it/s]
Computing energies: 100%|██████████| 2/2 [00:01<00:00,  1.13it/s]
[4]:
# Define and train the model: 30 epochs with batches of 10,000 sequences.
qm.infer_selection(
    epochs=30,
    batch_size=int(1e4),
)
2025-11-17 13:05:31,661: Finished training.
Computing energies: 100%|██████████| 2/2 [00:01<00:00,  1.09it/s]
2025-11-17 13:05:33,539: Updating marginals.
Computing energies: 100%|██████████| 2/2 [00:01<00:00,  1.08it/s]
2025-11-17 13:05:35,530: Finished updating marginals.

Do some plotting

[9]:
# Wrap the trained model in a Plotter; `plot_sonia` is reused by later cells.
plot_sonia = Plotter(qm)

# Show how the training went (presumably the learning curves — confirm
# against the Plotter documentation).
plot_sonia.plot_model_learning()
_images/sonnia_paired_tutorial_6_0.png

Generate sequences

[10]:
# Draw 100,000 sequences before selection, as amino acids (nucleotide=False).
n_pre_seqs = int(1e5)
pre_seqs = qm.generate_sequences_pre(n_pre_seqs, nucleotide=False)
pre_seqs[:3]
[10]:
array([['CASPRQPQPQHF', 'TRBV14', 'TRBJ1-5', 'CAGAVTNAGKSTF', 'TRAV27',
        'TRAJ27'],
       ['CSASHDHILYEKLFF', 'TRBV20-1', 'TRBJ1-4', 'CAAGVMDSSYKLIF',
        'TRAV13-1', 'TRAJ12'],
       ['CASSWDYQGILNYEQYF', 'TRBV5-4', 'TRBJ2-7', 'CGTNAGKSTF',
        'TRAV27', 'TRAJ27']], dtype='<U93')
[7]:
# Draw 100,000 sequences after selection, as amino acids (nucleotide=False).
n_post_seqs = int(1e5)
post_seqs = qm.generate_sequences_post(n_post_seqs, nucleotide=False)
post_seqs[:5]
Encoding sequence features: 100%|██████████| 1100000/1100000 [00:39<00:00, 27907.16it/s]
Computing energies: 100%|██████████| 11/11 [00:09<00:00,  1.12it/s]
Encoding sequence features: 100%|██████████| 1100000/1100000 [00:39<00:00, 27659.77it/s]
Computing energies: 100%|██████████| 11/11 [00:10<00:00,  1.09it/s]
[7]:
array([['CASSPPPDTQYF', 'TRBV7-2', 'TRBJ2-3', 'CADSRGSQGNLIF', 'TRAV41',
        'TRAJ42'],
       ['CATSSGNTIYF', 'TRBV15', 'TRBJ1-3', 'CAANRGGKLIF', 'TRAV13-1',
        'TRAJ23'],
       ['CASSPWPGQGHTIYF', 'TRBV2', 'TRBJ1-3', 'CAVRPGNTPLVF',
        'TRAV12-2', 'TRAJ29'],
       ['CASSLGQFNTEAFF', 'TRBV11-3', 'TRBJ1-1', 'CAATKGGGNKLTF',
        'TRAV13-1', 'TRAJ10'],
       ['CASSLRPGQAFF', 'TRBV5-1', 'TRBJ1-1', 'CLLGDGGGSNYKLTF',
        'TRAV40', 'TRAJ53']], dtype='<U93')

Evaluate sequences

[8]:
# Evaluate three sets of sequences with the trained model. Each call returns a
# triple that is unpacked as (Q, pgen, ppost):
#   *_data  - the first 1e5 data sequences stored on the model,
#   *_gen   - sequences sampled before selection (pre_seqs),
#   *_model - sequences sampled after selection (post_seqs).
Q_data, pgen_data, ppost_data = qm.evaluate_seqs(qm.data_seqs[:int(1e5)].values)
Q_gen, pgen_gen, ppost_gen = qm.evaluate_seqs(pre_seqs)
Q_model, pgen_model, ppost_model = qm.evaluate_seqs(post_seqs)

# Print the first three values of each quantity for the post-selection sample.
print(Q_model[:3])
print(pgen_model[:3])
print(ppost_model[:3])
Encoding sequence features:   3%|▎         | 3242/99986 [00:00<00:02, 32416.48it/s]Encoding sequence features: 100%|██████████| 99986/99986 [00:03<00:00, 32209.16it/s]
Computing energies: 100%|██████████| 1/1 [00:00<00:00,  1.21it/s]
Encoding sequence features: 100%|██████████| 100000/100000 [00:03<00:00, 25189.45it/s]
Computing energies: 100%|██████████| 1/1 [00:00<00:00,  1.21it/s]
Encoding sequence features: 100%|██████████| 100000/100000 [00:03<00:00, 28078.54it/s]
Computing energies: 100%|██████████| 1/1 [00:00<00:00,  1.18it/s]
[0.5769772 3.9851768 3.6605735]
[1.26058019e-15 2.72032238e-15 3.89377588e-19]
[7.27326022e-16 1.08409657e-14 1.42534527e-18]
[11]:
# Compare the pgen values of the data, the pre-selection sample and the
# post-selection sample (bin limits presumably in log10 units — confirm).
plot_sonia.plot_prob(
    data=pgen_data,
    gen=pgen_gen,
    model=pgen_model,
    ptype='P_{pre}',
    bin_min=-40,
    bin_max=-10,
)
_images/sonnia_paired_tutorial_12_0.png
[12]:
# Same comparison for the ppost values of the three sequence sets.
plot_sonia.plot_prob(
    ppost_data,
    ppost_gen,
    ppost_model,
    ptype='P_{post}',
    bin_min=-40,
    bin_max=-10,
)
_images/sonnia_paired_tutorial_13_0.png
[15]:
# Same comparison for the Q values of the three sequence sets.
plot_sonia.plot_prob(
    Q_data,
    Q_gen,
    Q_model,
    ptype='Q',
    bin_min=-4,
    bin_max=2,
)
_images/sonnia_paired_tutorial_14_0.png

Some utilities inherited from OLGA

[16]:
# OLGA functionality can be accessed directly through the main SoNNia model
[17]:
# Draw one random productive CDR3 from each chain's sequence generator
# (light first, then heavy) and display the pair.
light_cdr3 = qm.seqgen_model_light.gen_rnd_prod_CDR3()
heavy_cdr3 = qm.seqgen_model_heavy.gen_rnd_prod_CDR3()
light_cdr3, heavy_cdr3
[17]:
(('TGCATCTCCCGGCCCCTCACCTTT', 'CISRPLTF', 43, 2),
 ('TGTGCCAGCAGCCATGTACCAACGGACGCTTATAGCAATCAGCCCCAGCATTTT',
  'CASSHVPTDAYSNQPQHF',
  46,
  4))
[18]:
# Look up entry 1 of the J-gene table for each chain and display both.
light_j_entry = qm.genomic_data_light.genJ[1]
heavy_j_entry = qm.genomic_data_heavy.genJ[1]
light_j_entry, heavy_j_entry
[18]:
(['TRAJ10*01',
  'ATACTCACGGGAGGAGGAAACAAACTCACCTTT',
  'ATACTCACGGGAGGAGGAAACAAACTCACCTTTGGGACAGGCACTCAGCTAAAAGTGGAACTCA'],
 ['TRBJ1-2*01',
  'CTAACTATGGCTACACCTTC',
  'CTAACTATGGCTACACCTTCGGTTCGGGGACCAGGTTAACCGTTGTAG'])
[19]:
# Display the PinsDJ array of the heavy-chain pgen model next to the PinsVJ
# array of the light-chain pgen model.
heavy_pins_dj = qm.pgen_model_heavy.PinsDJ
light_pins_vj = qm.pgen_model_light.PinsVJ
heavy_pins_dj, light_pins_vj
[19]:
(array([7.19554e-02, 5.63345e-02, 8.55143e-02, 1.27028e-01, 1.21860e-01,
        1.08232e-01, 8.59492e-02, 7.45618e-02, 5.07897e-02, 3.88103e-02,
        3.57069e-02, 2.68135e-02, 2.15603e-02, 1.95199e-02, 1.66126e-02,
        1.44175e-02, 1.02015e-02, 9.54077e-03, 7.30822e-03, 5.44689e-03,
        4.32498e-03, 2.46807e-03, 1.75657e-03, 9.34188e-04, 5.72187e-04,
        7.20642e-04, 4.44661e-04, 1.06769e-04, 7.10554e-05, 3.08047e-04,
        1.29150e-04, 0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00]),
 array([2.80720e-02, 3.70881e-02, 7.18272e-02, 1.11361e-01, 1.45677e-01,
        1.17451e-01, 9.39653e-02, 8.25969e-02, 7.71960e-02, 5.37416e-02,
        4.99160e-02, 3.79668e-02, 2.58719e-02, 1.95766e-02, 1.36449e-02,
        8.74302e-03, 5.57685e-03, 3.94724e-03, 2.88705e-03, 1.72847e-03,
        1.77234e-03, 1.44091e-03, 1.39968e-03, 1.40263e-03, 1.25466e-03,
        6.01223e-04, 1.11938e-03, 7.01411e-04, 3.20548e-04, 2.40517e-04,
        3.74385e-05, 8.82230e-05, 8.38335e-05, 1.07031e-07, 3.11986e-05,
        5.83995e-06, 2.36875e-05, 1.06530e-10, 7.64546e-09, 2.24623e-04,
        4.17104e-04, 0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00]))

Save and Load Model

[20]:
# Save the trained paired model under 'test_paired'.
save_target = 'test_paired'
qm.save_model(save_target)
[21]:
# Create a new paired model from what was saved under 'test_paired'.
qm_new = SoniaPaired(ppost_model='test_paired')
[22]:
# Sequences are not added by default when a model is loaded, so pass the data
# and generated sequences of the original model to the new one.
qm_new.update_model(
    add_data_seqs=qm.data_seqs,
    add_gen_seqs=qm.gen_seqs,
)
2025-11-17 13:11:05,081: Adding data seqs.
2025-11-17 13:11:05,144: Using default index (0) for amino acid CDR3 sequences.
2025-11-17 13:11:05,145: Using default index (1) for V genes.
2025-11-17 13:11:05,151: Using default index (2) for J genes.
2025-11-17 13:11:05,205: 99986 sequences before filtering. Using test_paired/heavy_chain for filtering.
2025-11-17 13:11:05,391: 99986 sequences remain after removing sequences with V genes inconsistent with the model.
2025-11-17 13:11:05,399: 99986 sequences remain after removing sequences with J genes inconsistent with the model.
2025-11-17 13:11:05,436: 99986 sequences remain after removing data which are unproductive amino acid sequences.
2025-11-17 13:11:05,508: 99986 sequences remain after removing sequences that do not begin with a 'C' or end in a ['A', 'B', 'C', 'E', 'D', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'].
2025-11-17 13:11:05,540: 99986 sequences remain after removing sequences with CDR3 length larger than 30.
2025-11-17 13:11:05,540: 99986 sequences remain. Filtering completed.
2025-11-17 13:11:05,709: Using default index (0) for amino acid CDR3 sequences.
2025-11-17 13:11:05,710: Using default index (1) for V genes.
2025-11-17 13:11:05,717: Using default index (2) for J genes.
2025-11-17 13:11:05,769: 99986 sequences before filtering. Using test_paired/light_chain for filtering.
2025-11-17 13:11:05,976: 99986 sequences remain after removing sequences with V genes inconsistent with the model.
2025-11-17 13:11:05,984: 99986 sequences remain after removing sequences with J genes inconsistent with the model.
2025-11-17 13:11:06,021: 99986 sequences remain after removing data which are unproductive amino acid sequences.
2025-11-17 13:11:06,096: 99986 sequences remain after removing sequences that do not begin with a 'C' or end in a ['A', 'B', 'C', 'E', 'D', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'].
2025-11-17 13:11:06,126: 99986 sequences remain after removing sequences with CDR3 length larger than 30.
2025-11-17 13:11:06,126: 99986 sequences remain. Filtering completed.
2025-11-17 13:11:06,388: Adding gen seqs.
2025-11-17 13:11:06,541: Using default index (0) for amino acid CDR3 sequences.
2025-11-17 13:11:06,542: Using default index (1) for V genes.
2025-11-17 13:11:06,555: Using default index (2) for J genes.
2025-11-17 13:11:06,659: 200000 sequences before filtering. Using test_paired/heavy_chain for filtering.
2025-11-17 13:11:07,042: 200000 sequences remain after removing sequences with V genes inconsistent with the model.
2025-11-17 13:11:07,059: 200000 sequences remain after removing sequences with J genes inconsistent with the model.
2025-11-17 13:11:07,131: 200000 sequences remain after removing data which are unproductive amino acid sequences.
2025-11-17 13:11:07,281: 200000 sequences remain after removing sequences that do not begin with a 'C' or end in a ['A', 'B', 'C', 'E', 'D', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'].
2025-11-17 13:11:07,344: 200000 sequences remain after removing sequences with CDR3 length larger than 30.
2025-11-17 13:11:07,345: 200000 sequences remain. Filtering completed.
2025-11-17 13:11:07,660: Using default index (0) for amino acid CDR3 sequences.
2025-11-17 13:11:07,661: Using default index (1) for V genes.
2025-11-17 13:11:07,676: Using default index (2) for J genes.
2025-11-17 13:11:07,782: 200000 sequences before filtering. Using test_paired/light_chain for filtering.
2025-11-17 13:11:08,223: 200000 sequences remain after removing sequences with V genes inconsistent with the model.
2025-11-17 13:11:08,247: 200000 sequences remain after removing sequences with J genes inconsistent with the model.
2025-11-17 13:11:08,319: 200000 sequences remain after removing data which are unproductive amino acid sequences.
2025-11-17 13:11:08,463: 200000 sequences remain after removing sequences that do not begin with a 'C' or end in a ['A', 'B', 'C', 'E', 'D', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'].
2025-11-17 13:11:08,560: 200000 sequences remain after removing sequences with CDR3 length larger than 30.
2025-11-17 13:11:08,561: 200000 sequences remain. Filtering completed.
2025-11-17 13:11:09,165: Encode data seqs.
Encoding sequence features: 100%|██████████| 99986/99986 [00:03<00:00, 30649.70it/s]
2025-11-17 13:11:13,070: Encode gen seqs.
Encoding sequence features: 100%|██████████| 200000/200000 [00:06<00:00, 30692.48it/s]
Computing energies: 100%|██████████| 2/2 [00:01<00:00,  1.10it/s]
[23]:
# Continue inference on the reloaded model for five more epochs, then display
# the metrics recorded for this run.
qm_new.infer_selection(epochs=5)
training_history = qm_new.learning_history.history
training_history
2025-11-17 13:11:35,447: Finished training.
Computing energies: 100%|██████████| 2/2 [00:01<00:00,  1.07it/s]
2025-11-17 13:11:37,360: Updating marginals.
Computing energies: 100%|██████████| 2/2 [00:01<00:00,  1.07it/s]
2025-11-17 13:11:39,346: Finished updating marginals.
[23]:
{'_likelihood': [-1.7416709661483765,
  -1.7556878328323364,
  -1.7678178548812866,
  -1.7807191610336304,
  -1.7920759916305542],
 'binary_crossentropy': [0.3523142635822296,
  0.3505375385284424,
  0.34897327423095703,
  0.3475055992603302,
  0.3461282253265381],
 'loss': [0.3523143529891968,
  0.3505374491214752,
  0.3489736020565033,
  0.34750598669052124,
  0.3461282551288605],
 'val__likelihood': [-1.761418342590332,
  -1.7723807096481323,
  -1.7818650007247925,
  -1.7898236513137817,
  -1.7986449003219604],
 'val_binary_crossentropy': [0.3517608642578125,
  0.35050806403160095,
  0.34941861033439636,
  0.34838756918907166,
  0.34734925627708435],
 'val_loss': [0.3517608344554901,
  0.35050803422927856,
  0.349418580532074,
  0.3483875095844269,
  0.34734922647476196]}

Compute Diversity and Distance from Reference Distribution

[24]:
# Compute the entropy of the model and report it in bits.
model_entropy = qm_new.entropy()
print('model entropy is: ', model_entropy, '[bits]')
Computing energies:   0%|          | 0/1 [00:00<?, ?it/s]Computing energies: 100%|██████████| 1/1 [00:00<00:00,  5.63it/s]
2025-11-17 13:12:18,373: 2 sequences have zero Pgen, we remove them in the evaluation of the entropy
model entropy is:  56.75602773959275 [bits]
[25]:
# Compute the post/gen Dkl of the model (per the method name) and report it in bits.
dkl_value = qm_new.dkl_post_gen()
print('Dkl post gen is: ', dkl_value, '[bits]')
Dkl post gen is:  2.0906377 [bits]

Load default model

[27]:
# Load the default 'humanTCR' paired model without attaching any sequences.
default_model = SoniaPaired(ppost_model='humanTCR')