Cycle-Consistent Antigen Reconstruction + Direct Attention + ESM-2
AgForce Enables Antigen-conditioned Generative Antibody Design
1. Attention Modes (Three Options)
In the ablation, forcing gate=0 had zero impact, i.e. the gate is dead. Three modes are available:
# 'no_gate' (RECOMMENDED): Remove dead gate, keep concat
combined = torch.cat([cdr_h, attn_out], dim=-1) # (L, 512)
# 'direct': Aggressive, attn_out only (lost too much context in testing)
combined = attn_out + pos_emb # (L, 256) - needs position info
# False: Original gated bottleneck (gate does nothing empirically)
combined = self.antigen_gate(cdr_h, attn_out) # (L, 512)
no_gate mode preserves the original capacity while removing the dead gate mechanism.
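The three modes can be sketched as a single dispatch function. This is a minimal illustration, assuming a hidden size of 256 (inferred from the shape comments above). `combine_features` and `PlaceholderGate` are hypothetical names, not the repository's API. The real gate is not shown in this document, so a simple sigmoid gate stands in for it.

```python
import torch
import torch.nn as nn

HIDDEN = 256  # assumed from the (L, 256) / (L, 512) shape comments


class PlaceholderGate(nn.Module):
    """Stand-in for the original antigen gate (its real form is not shown here)."""

    def __init__(self, hidden: int = HIDDEN):
        super().__init__()
        self.gate = nn.Linear(2 * hidden, hidden)

    def forward(self, cdr_h, attn_out):
        g = torch.sigmoid(self.gate(torch.cat([cdr_h, attn_out], dim=-1)))
        return torch.cat([cdr_h, g * attn_out], dim=-1)  # (L, 512)


def combine_features(mode, cdr_h, attn_out, pos_emb=None, gate=None):
    """Hypothetical dispatcher over the three use_direct_attn settings."""
    if mode == "no_gate":
        return torch.cat([cdr_h, attn_out], dim=-1)  # (L, 512)
    if mode == "direct":
        return attn_out + pos_emb                    # (L, 256)
    if mode is False:
        return gate(cdr_h, attn_out)                 # (L, 512)
    raise ValueError(f"unknown attention mode: {mode!r}")


L = 12
cdr_h, attn_out, pos_emb = (torch.randn(L, HIDDEN) for _ in range(3))
shapes = {
    "no_gate": tuple(combine_features("no_gate", cdr_h, attn_out).shape),
    "direct": tuple(combine_features("direct", cdr_h, attn_out, pos_emb=pos_emb).shape),
    "gated": tuple(combine_features(False, cdr_h, attn_out, gate=PlaceholderGate()).shape),
}
```

Because 'no_gate' and the gated mode both emit width 512, the downstream seq_head input size would stay the same when switching between them. Only 'direct' would need a narrower head.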
2. Antigen Classification Head (Cycle Consistency)
New head that takes predicted CDR softmax probabilities and classifies which antigen they were designed for:
class AntigenClassificationHead(nn.Module):
    """Cycle consistency: if the model ignores the antigen, all CDRs look similar,
    and antigen classification becomes impossible."""
    def forward(self, cdr_probs_list, antigen_embs):
        # Encode soft sequences via a weighted sum of AA embeddings
        cdr_embs = [self.encode_soft_sequence(p) for p in cdr_probs_list]
        # Project and compute similarity to all antigens in the batch
        logits = torch.matmul(self.proj(torch.stack(cdr_embs)), antigen_embs.T)
        return logits  # (B, B) contrastive classification
The loss is CrossEntropy(logits, arange(B)) -- each CDR should classify its own antigen.
Key insight: This loss flows through the seq_head output. If the model collapses to identical predictions for all antigens, classification accuracy = 1/B (random). The only way to minimize this loss is to produce antigen-specific CDRs.
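The collapse argument can be checked numerically with a toy version of the head. This is a sketch under stated assumptions: `AntigenClsSketch`, the embedding sizes, and mean-pooling over CDR positions are illustrative choices, not the repository's implementation. It shows that the loss back-propagates into the sequence logits, and that identical predictions for every antigen cannot push the loss below log(B).

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class AntigenClsSketch(nn.Module):
    """Illustrative stand-in for AntigenClassificationHead (dimensions assumed)."""

    def __init__(self, n_aa: int = 20, hidden: int = 64):
        super().__init__()
        self.aa_emb = nn.Parameter(torch.randn(n_aa, hidden) * 0.1)
        self.proj = nn.Linear(hidden, hidden)

    def encode_soft_sequence(self, probs):
        # probs: (L, n_aa). Weighted sum of AA embeddings, then mean over L
        # (the pooling choice is an assumption of this sketch).
        return (probs @ self.aa_emb).mean(dim=0)

    def forward(self, cdr_probs_list, antigen_embs):
        cdr_embs = torch.stack([self.encode_soft_sequence(p) for p in cdr_probs_list])
        return self.proj(cdr_embs) @ antigen_embs.T  # (B, B)


torch.manual_seed(0)
B, L, n_aa, hidden = 4, 10, 20, 64
head = AntigenClsSketch(n_aa, hidden)
antigen_embs = torch.randn(B, hidden)
targets = torch.arange(B)  # CDR i should pick antigen i

# Antigen-specific case: one distribution per complex, derived from logits
# so that the gradient path through softmax can be inspected.
seq_logits = torch.randn(B, L, n_aa, requires_grad=True)
specific = list(torch.softmax(seq_logits, dim=-1).unbind(0))
logits = head(specific, antigen_embs)
loss_specific = F.cross_entropy(logits, targets)
loss_specific.backward()
grad_reaches_seq_logits = seq_logits.grad is not None

# Collapsed case: the same CDR distribution for every antigen.
collapsed = [specific[0].detach()] * B
loss_collapsed = F.cross_entropy(head(collapsed, antigen_embs), targets).item()

chance_level = math.log(B)  # loss of a uniform guess over B antigens
```

With identical rows in the logit matrix, the mean cross-entropy equals logsumexp(v) - mean(v), which is at least log(B) by Jensen's inequality, so `loss_collapsed` never falls below `chance_level`.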
3. Mutual-Information Losses Disabled
MINE, InfoNCE, and latent coupling are disabled (lambda=0). Empirically:
- MINE loss optimized (-1.13 nats) but val_aar unchanged
- InfoNCE loss optimized (0.88) but val_aar unchanged
- These losses operate on GNN embeddings, but prediction reads cross-attention -- different pathways
4. Slower Learning-Rate Decay
anneal_base: 0.97 (was 0.9). At epoch 30:
- Old: lr = 1e-3 * 0.9^30 ≈ 4e-5 (too small to escape local optima)
- New: lr = 1e-3 * 0.97^30 ≈ 4e-4 (roughly 10x higher)
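The arithmetic above can be reproduced in a few lines. The helper name `annealed_lr` is hypothetical. The formula is the per-epoch exponential decay used in the two bullets.

```python
def annealed_lr(base_lr: float, anneal_base: float, epoch: int) -> float:
    """Exponential decay: lr = base_lr * anneal_base ** epoch."""
    return base_lr * anneal_base ** epoch


old_lr = annealed_lr(1e-3, 0.90, 30)
new_lr = annealed_lr(1e-3, 0.97, 30)
ratio = new_lr / old_lr

print(f"old={old_lr:.2e}  new={new_lr:.2e}  ratio={ratio:.1f}x")
```

The ratio comes out slightly under 10, which is why the comparison is stated as approximate.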
5. Precomputed ESM-2 Embeddings
Adds frozen ESM-2 (650M) embeddings for CDR positions:
# Precompute embeddings (run once)
python precompute_esm.py --esm_model esm2_t33_650M_UR50D
# Embeddings saved to: ./data/processed/esm_embeddings/esm2_t33_650M_UR50D/esm_cdr{1,2,3}.pt
The embeddings are computed with CDR positions masked (replaced with <mask> token), then extracted at CDR positions. This provides rich evolutionary context without information leakage.
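The masking step can be illustrated without loading ESM-2. The sketch below only builds the masked input. The sequence fragment and the CDR indices are toy values, and `mask_cdr_positions` is a hypothetical helper, not a function from precompute_esm.py.

```python
MASK_TOKEN = "<mask>"  # mask token named in the text above


def mask_cdr_positions(sequence: str, cdr_positions: list[int]) -> list[str]:
    """Replace residues at CDR positions with the mask token; keep the rest."""
    cdr = set(cdr_positions)
    return [MASK_TOKEN if i in cdr else aa for i, aa in enumerate(sequence)]


heavy_chain = "EVQLVESGGGLVQPGGSLRLSCAAS"  # toy fragment, not a dataset entry
cdr_positions = [5, 6, 7]                  # hypothetical 0-based CDR indices

masked_tokens = mask_cdr_positions(heavy_chain, cdr_positions)
masked_input = "".join(masked_tokens)      # string form that would be tokenized
```

Because the true residues at the CDR indices never appear in `masked_input`, the embedding read out at those positions can depend only on the surrounding framework context.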
Integration:
- esm_proj: Linear(1280, hidden_size) projects ESM embeddings
- Combined features: concat(cdr_h, attn_out, esm_h) -> seq_head
- Memory efficient: ~2.5GB GPU savings vs runtime ESM forward pass
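A shape-level sketch of the integration, assuming hidden_size = 256 and using random tensors in place of the GNN features, the attention output, and an embedding loaded from one of the precomputed .pt files:

```python
import torch
import torch.nn as nn

hidden_size, esm_dim, L = 256, 1280, 12  # hidden_size is assumed; esm_dim is ESM-2 650M

esm_proj = nn.Linear(esm_dim, hidden_size)

# Random stand-ins; in training these would come from the GNN, the
# cross-attention block, and the precomputed ESM file respectively.
cdr_h = torch.randn(L, hidden_size)
attn_out = torch.randn(L, hidden_size)
esm_cdr_emb = torch.randn(L, esm_dim)    # frozen: no gradient flows into ESM-2

esm_h = esm_proj(esm_cdr_emb)                            # (L, 256)
combined = torch.cat([cdr_h, attn_out, esm_h], dim=-1)   # (L, 768)
combined_shape = tuple(combined.shape)
```

Only `esm_proj` is trainable on the ESM branch, so adding the embeddings costs one linear layer of parameters rather than a 650M-parameter forward pass per step.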
Input: Ab-Ag complex
    |
    v
ProteinFeatureEncoder (105D + 3D segment) -> H_0
    |
    v
Framework Dropout (p=0.3, training only)
    |
    v
VirtualNodeEGNN (4 layers, 3 virtual nodes)
    |
    v
aa_embd (GNN output for all residues)
    |
    +---> cdr_h (CDR positions from GNN)
    |
    +---> LearnedCDRQuery (pos_emb + cdr_type + epitope_context)
    |         |
    |         v
    |     Cross-Attention(query, ag_h, ag_h) -> attn_out
    |
    +---> ESM-2 (frozen, precomputed) -> esm_cdr_emb -> esm_proj -> esm_h
    |
    v
[no_gate mode] concat(cdr_h, attn_out, esm_h) -> (L, 768)
    |
    v
seq_head(combined) -> logits
    |
    +---> softmax(logits) -> probs
    |         |
    |         v
    |     AntigenClassificationHead(probs, antigen_embs) -> ag_cls_loss [NEW]
    |
    v
Losses: seq + coord + pairing + dock + shadow + gdpp + aux + antigen_cls
If swap similarity drops significantly, conditioning works and AAR should follow.
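The metric is not defined in this document. One plausible reading, used here purely as an assumption, is the fraction of CDR positions that stay identical when the same antibody is redesigned against a swapped antigen. `swap_similarity` and the sequences below are hypothetical illustrations, not model output.

```python
def swap_similarity(pred_native: str, pred_swapped: str) -> float:
    """Fraction of positions where the two predicted CDRs agree (assumed definition)."""
    if len(pred_native) != len(pred_swapped):
        raise ValueError("predictions must have equal length")
    matches = sum(a == b for a, b in zip(pred_native, pred_swapped))
    return matches / len(pred_native)


# Toy predictions: the first pair mimics a model that ignores the antigen,
# the second a model whose output changes when the antigen is swapped.
ignores_antigen = swap_similarity("ARDYYGSSYFDY", "ARDYYGSSYFDY")
uses_antigen = swap_similarity("ARDYYGSSYFDY", "ARGLWGNSYMDV")
```

Under this reading, a value near 1.0 means the antigen has no effect on the design, and lower values indicate that conditioning changes the predicted sequence.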
# Step 1: Precompute ESM embeddings (run once, takes ~30 min)
python precompute_esm.py --esm_model esm2_t33_650M_UR50D --cdr_types 1,2,3
# Step 2: Train
# Quick test
python chimera_trainer.py --max_epoch 2 --no-wandb
# Full training
python chimera_trainer.py --split epitope_group --cdr_type 3
# Test only (load checkpoint)
python chimera_trainer.py --test_only --checkpoint /path/to/best.pt
Key settings in config.yaml:
dataset:
  esm_model: esm2_t33_650M_UR50D
  esm_embeddings_dir: ./data/processed/esm_embeddings
model:
  # Attention mode: 'no_gate' (recommended), 'direct', or false
  use_direct_attn: no_gate   # Remove dead gate, keep concat(cdr_h, attn_out)
  lambda_antigen_cls: 0.5    # antigen classification weight (cycle consistency)
  lambda_infonce: 0.0        # disabled (empirically inert)
  lambda_mine: 0.0           # disabled (empirically inert)
  zeta: 0.0                  # latent coupling disabled
  # ESM-2 embeddings
  use_esm: true              # Enable precomputed ESM embeddings
  esm_dim: 1280              # ESM-2 650M dimension
training:
  anneal_base: 0.97          # slower LR decay (0.9 -> 0.97)
- code/models.py - DockDesigner with use_direct_attn, AntigenClassificationHead, and ESM projection
- code/v2_dataset.py - Dataset with ESM embeddings loading
- config.yaml - defaults (use_direct_attn: no_gate, MI losses off, ESM on)
- chimera_trainer.py - 13-tuple loss handling, ESM integration
- precompute_esm.py - One-time ESM-2 embedding precomputation