Training and Fine-Tuning Models¶
Jaeger provides a complete training pipeline for building custom phage detection models from scratch or fine-tuning existing ones on new data.
Table of contents¶
Overview¶
Jaeger’s architecture consists of three components:
Representation learner — A 1D convolutional network with residual blocks that learns sequence embeddings from translated DNA (6-frame codon embeddings).
Classification head — Predicts the class of each window (bacteria, phage, eukarya, archaea, plasmid, virus).
Reliability head — An out-of-distribution detector that estimates prediction confidence.
All three can be trained jointly or independently.
Training workflow¶
1. Prepare FASTA files per class
↓
2. Generate fragments (jaeger utils dataset / fragment)
↓
3. Convert to CSV format (jaeger utils convert)
↓
4. Create training config YAML
↓
5. Run training (jaeger train -c config.yaml)
↓
6. Save model and register (jaeger register-models)
Preparing training data¶
Step 1: Collect reference sequences¶
Organize your reference genomes into separate FASTA files by class:
data/
├── bacteria.fasta
├── phage.fasta
├── eukarya.fasta
├── archaea.fasta
├── plasmid.fasta
└── virus.fasta
Step 2: Generate training fragments¶
Use jaeger utils dataset to create non-redundant fragment databases:
# Fragment bacteria genomes into 2048 bp pieces with 60% identity filtering
jaeger utils dataset \
-i bacteria.fasta \
-o bacteria_fragments.csv \
--itype fasta \
--outtype csv \
--fraglen 2048 \
--overlap 1024 \
--maxiden 0.6 \
--maxcov 0.6 \
--class 0 \
--seq_col 1 \
--class_col 0
Parameters explained:
Parameter |
Description |
Typical value |
|---|---|---|
|
Maximum fragment length |
2048 |
|
Overlap between adjacent fragments |
1024 |
|
Maximum identity between any two fragments |
0.6 |
|
Maximum coverage between any two fragments |
0.6 |
|
Numeric class label |
0=bacteria, 1=phage, etc. |
|
Fraction for validation set |
0.1 |
|
Fraction for training set |
0.8 |
|
Fraction for test set |
0.1 |
Step 3: Simulate metagenome fragments¶
For more realistic training data, simulate variable-length fragments:
jaeger utils fragment \
-i phage.fasta \
-o phage_fragments.fasta \
--minlen 1000 \
--maxlen 5000 \
--overlap 0
Then convert to CSV:
jaeger utils convert \
-i phage_fragments.fasta \
-o phage_fragments.csv \
--itype fasta
Step 4: Generate OOD data for reliability training¶
The reliability head needs out-of-distribution examples. Use shuffled sequences:
jaeger utils ood-data \
-i combined_train.csv \
-o ood_train.csv \
--itype csv \
--otype csv \
--dinuc
Configuration file¶
Training is controlled by a YAML configuration file. A template is provided at train_config/nn_config.yaml.
Minimal example¶
model:
name: "my_jaeger_model"
experiment: 1
seed: 42
classifier_out_dim: 6
reliability_out_dim: 1
base_dir: "/path/to/experiments"
class_label_map:
- class: "bacteria"
label: 0
- class: "phage"
label: 1
- class: "eukarya"
label: 2
- class: "archaea"
label: 3
- class: "plasmid"
label: 4
- class: "virus"
label: 5
embedding:
use_embedding_layer: true
input_type: "translated"
strands: 2
frames: 6
length: null
input_shape: [6, null]
embedding_size: 192
training:
data_dir: "/path/to/training/data"
experiment_root: "exp_001"
epochs: 100
batch_size: 64
learning_rate: 0.001
fragment_classifier_data:
train:
- class: ["bacteria", "phage", "eukarya", "archaea", "plasmid", "virus"]
path:
- "{{ training.data_dir }}/train_data.csv"
label: [0, 1, 2, 3, 4, 5]
validation:
- class: ["bacteria", "phage", "eukarya", "archaea", "plasmid", "virus"]
path:
- "{{ training.data_dir }}/validation_data.csv"
label: [0, 1, 2, 3, 4, 5]
fragment_reliability_data:
train:
- class: [indist, ood]
path:
- "{{ training.data_dir }}/train_data_ood.csv"
label: [1, 0]
validation:
- class: [indist, ood]
path:
- "{{ training.data_dir }}/validation_data_ood.csv"
label: [1, 0]
Key configuration sections¶
Section |
Purpose |
|---|---|
|
Architecture, embedding, class labels |
|
CNN layers, residual blocks, attention |
|
Classification head architecture |
|
Reliability (OOD) head architecture |
|
Optimizer, batch size, epochs, callbacks |
|
Paths to classification training data |
|
Paths to reliability training data |
Running training¶
From scratch¶
jaeger train -c train_config/nn_config.yaml
With mixed precision (faster on modern GPUs)¶
jaeger train -c train_config/nn_config.yaml --mixed_precision
Resume from checkpoint¶
jaeger train -c train_config/nn_config.yaml --from_last_checkpoint
Save model without training¶
If you already have checkpoints and just want to export a SavedModel:
jaeger train -c train_config/nn_config.yaml --only_save
Fine-tuning¶
Fine-tuning allows you to adapt a pre-trained model to new data without training from scratch.
Freeze the representation learner, train only heads¶
jaeger train -c fine_tune_config.yaml --only_heads
Train only the classification head¶
jaeger train -c fine_tune_config.yaml --only_classification_head
Train only the reliability head¶
jaeger train -c fine_tune_config.yaml --only_reliability_head
Tips for fine-tuning¶
Use a lower learning rate (e.g., 1e-4 vs 1e-3) to avoid catastrophic forgetting.
Start from the pre-trained model’s checkpoints by setting the correct paths in your config.
Use
--from_last_checkpointto resume interrupted fine-tuning runs.
Self-supervised pretraining¶
You can pretrain the representation learner with self-supervised learning before supervised classification:
jaeger train -c pretrain_config.yaml --self_supervised_pretraining
This is useful when you have large amounts of unlabeled sequence data.
Creating ensembles¶
Combine multiple trained models into an ensemble for improved robustness:
jaeger utils combine-models \
-i /path/to/model1 \
-i /path/to/model2 \
-i /path/to/model3 \
-o /path/to/ensemble \
-c mean
Aggregation methods:
Method |
Description |
|---|---|
|
Majority voting |
|
Sum of logits |
|
Mean of logits (recommended) |
|
No aggregation (returns all outputs) |
Command reference¶
jaeger train¶
Usage: jaeger train [OPTIONS]
Options:
-c, --config PATH Training config YAML [required]
--only_classification_head Train only classification head
--only_reliability_head Train only reliability head
--self_supervised_pretraining Self-supervised pretraining
--only_heads Train both heads, freeze representation
--from_last_checkpoint Resume from last checkpoint
--force Delete existing checkpoints and restart
--save_model Save model from last checkpoint
--only_save Save model without training
--mixed_precision Use mixed-precision floats
--meta PATH Write container metadata
-v, --verbose Verbosity: -vv debug, -v info
--help Show this message and exit.
jaeger register-models¶
After training, register your model so jaeger predict can find it:
jaeger register-models --path /path/to/my_model
Data augmentation tips¶
Sequence masking¶
Gradually mask random positions to improve robustness:
jaeger utils mask \
-i train.fasta \
-o train_masked.fasta \
--minperc 0.0 \
--maxperc 0.3 \
--step 0.05
Sequence mutation¶
Introduce random mutations instead of masking:
jaeger utils mask \
-i train.fasta \
-o train_mutated.fasta \
--mutate \
--minperc 0.0 \
--maxperc 0.1 \
--step 0.01