GitHub - mansoor181/ag-force: AgForce generates realistic and diverse antigen-conditioned antibody CDR design via multiple choice learning · GitHub
Cycle-Consistent Antigen Reconstruction + Direct Attention + ESM-2

AgForce Enables Antigen-conditioned Generative Antibody Design

1. Attention Modes (Three Options)

In an ablation, forcing the gate to 0 had no measurable impact, so the gate is effectively dead. Three modes are available:

```python
# 'no_gate' (RECOMMENDED): remove the dead gate, keep the concatenation
combined = torch.cat([cdr_h, attn_out], dim=-1)  # (L, 512)

# 'direct': aggressive, uses attn_out only (lost too much context in testing)
combined = attn_out + pos_emb  # (L, 256); needs position information

# False: original gated bottleneck (the gate does nothing empirically)
combined = self.antigen_gate(cdr_h, attn_out)  # (L, 512)
```

The no_gate mode preserves the original (L, 512) capacity while removing the dead gate mechanism.
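The three modes differ only in how the CDR representation is assembled before the sequence head. The NumPy sketch below reproduces that shape bookkeeping under stated assumptions: `combine_features` is a hypothetical helper (not the repository's API), the hidden size is assumed to be 256 (inferred from the (L, 256) and (L, 512) shape comments), and the gated mode simply delegates to a caller-supplied `antigen_gate` because its internals are not described here.

```python
import numpy as np


def combine_features(cdr_h, attn_out, pos_emb, mode="no_gate", antigen_gate=None):
    """Hypothetical helper mirroring the three attention modes.

    cdr_h, attn_out, pos_emb: arrays of shape (L, D), where L is the CDR length.
    """
    if mode == "no_gate":
        # Keep the GNN view and the antigen cross-attention view side by side.
        return np.concatenate([cdr_h, attn_out], axis=-1)  # (L, 2D)
    if mode == "direct":
        # Drop cdr_h; add positional embeddings so position information survives.
        return attn_out + pos_emb  # (L, D)
    if mode is False:
        # Original gated bottleneck; the gate itself is supplied by the caller.
        if antigen_gate is None:
            raise ValueError("mode=False requires an antigen_gate callable")
        return antigen_gate(cdr_h, attn_out)
    raise ValueError(f"unknown attention mode: {mode!r}")


L, D = 12, 256
rng = np.random.default_rng(0)
cdr_h, attn_out, pos_emb = (rng.standard_normal((L, D)) for _ in range(3))
print(combine_features(cdr_h, attn_out, pos_emb, "no_gate").shape)  # (12, 512)
print(combine_features(cdr_h, attn_out, pos_emb, "direct").shape)   # (12, 256)
```

The width difference is the whole trade-off: no_gate keeps both views at 2D, while direct halves the width and relies on pos_emb to restore position information.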

2. Antigen Classification Head (Cycle Consistency)

A new head takes the predicted CDR softmax probabilities and classifies which antigen they were designed for:

```python
import torch
import torch.nn as nn


class AntigenClassificationHead(nn.Module):
    """Cycle consistency: if the model ignores the antigen, all CDRs look similar,
    and antigen classification becomes impossible."""

    # __init__ (omitted) defines self.proj and the amino-acid embedding table
    # used by self.encode_soft_sequence.

    def forward(self, cdr_probs_list, antigen_embs):
        # Encode soft sequences via a weighted sum of AA embeddings
        cdr_embs = [self.encode_soft_sequence(p) for p in cdr_probs_list]
        # Project and compute similarity to all antigens in the batch
        logits = torch.matmul(self.proj(torch.stack(cdr_embs)), antigen_embs.T)
        return logits  # (B, B) contrastive classification
```

The loss is CrossEntropy(logits, arange(B)): each CDR should be classified as belonging to its own antigen.

Key insight: this loss backpropagates through the seq_head output. If the model collapses to identical predictions for all antigens, classification accuracy drops to chance (1/B). The only way to minimize this loss is to produce antigen-specific CDRs.
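The collapse argument can be checked numerically. Below is a NumPy sketch of the contrastive loss under assumptions: `encode_soft_sequence` here mean-pools the per-position expected embeddings (the repository's pooling is not specified), `aa_emb` and `proj` are random stand-ins for learned weights, and all sizes are illustrative. When every antigen receives the same CDR distribution, every row of the (B, B) logit matrix is identical, and the averaged cross-entropy can never fall below log B, the loss of a uniform guess.

```python
import numpy as np


def encode_soft_sequence(probs, aa_emb):
    """probs: (L, 20) per-position softmax; aa_emb: (20, D) amino-acid embeddings.

    Each position becomes the probability-weighted sum of AA embeddings; the
    positions are then mean-pooled into one (D,) vector (pooling is an assumption).
    """
    return (probs @ aa_emb).mean(axis=0)


def antigen_cls_loss(cdr_probs_list, antigen_embs, aa_emb, proj):
    """Contrastive cross-entropy with targets arange(B): CDR i should pick antigen i."""
    cdr_embs = np.stack([encode_soft_sequence(p, aa_emb) for p in cdr_probs_list])  # (B, D)
    logits = (cdr_embs @ proj) @ antigen_embs.T  # (B, B)
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    idx = np.arange(len(cdr_probs_list))
    return -log_probs[idx, idx].mean()


rng = np.random.default_rng(0)
B, L, D = 4, 10, 16
aa_emb = rng.standard_normal((20, D))
proj = rng.standard_normal((D, D))
antigen_embs = rng.standard_normal((B, D))

# Collapsed model: the same CDR distribution regardless of the antigen.
same = rng.dirichlet(np.ones(20), size=L)
collapsed = antigen_cls_loss([same] * B, antigen_embs, aa_emb, proj)
print(collapsed >= np.log(B))  # True: a collapsed model cannot beat chance
```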

3. Disable Inert Losses

MINE, InfoNCE, and latent coupling (zeta) are disabled by setting their weights to 0. Empirically:

  • The MINE loss was optimized (to -1.13 nats), but val_aar did not change.
  • The InfoNCE loss was optimized (to 0.88), but val_aar did not change.
  • These losses operate on the GNN embeddings, whereas prediction reads from the cross-attention output, so they act on a different pathway.
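With a weight of 0 these terms contribute nothing to the gradient, so there is also no reason to compute them. A minimal sketch of a weighted loss sum that skips zero-weight terms; `total_loss`, the term names, and the constant loss values are hypothetical placeholders, not the trainer's actual loss handling.

```python
def total_loss(terms, weights):
    """Weighted sum of loss terms, skipping any term whose weight is 0.

    terms: dict mapping a term name to a zero-argument callable that computes it.
    weights: dict mapping a term name to its lambda; missing names default to 1.0.
    """
    total = 0.0
    for name, compute in terms.items():
        weight = weights.get(name, 1.0)
        if weight == 0.0:
            continue  # disabled term: never evaluated, so it costs no compute
        total += weight * compute()
    return total


def mine_loss():
    # Hypothetical placeholder that must not run when its weight is 0.
    raise RuntimeError("should not be evaluated when lambda_mine = 0")


terms = {"seq": lambda: 2.0, "antigen_cls": lambda: 1.0, "mine": mine_loss}
weights = {"antigen_cls": 0.5, "mine": 0.0}
print(total_loss(terms, weights))  # 2.5
```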

4. Slower LR Decay

anneal_base: 0.97 (was 0.9). At epoch 30:

  • Old: lr = 1e-3 * 0.9^30 ≈ 4e-5 (too small to escape local optima)
  • New: lr = 1e-3 * 0.97^30 ≈ 4e-4 (about 10x higher; the exact ratio is (0.97/0.9)^30 ≈ 9.5)
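The arithmetic above can be reproduced directly. This sketch assumes the schedule is plain per-epoch exponential decay, lr = base_lr * anneal_base ** epoch, as the formulas imply; `annealed_lr` is a hypothetical name.

```python
def annealed_lr(base_lr, anneal_base, epoch):
    """Exponential decay: lr = base_lr * anneal_base ** epoch (hypothetical helper)."""
    return base_lr * anneal_base ** epoch


old = annealed_lr(1e-3, 0.90, 30)
new = annealed_lr(1e-3, 0.97, 30)
print(f"old={old:.2e} new={new:.2e} ratio={new / old:.1f}")
# old=4.24e-05 new=4.01e-04 ratio=9.5
```

If the trainer applies anneal_base once per epoch, this matches `torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=anneal_base)` stepped at the end of each epoch; whether chimera_trainer.py does exactly this is an assumption.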

5. Precomputed ESM-2 Embeddings

Adds frozen ESM-2 (650M) embeddings for CDR positions:

```bash
# Precompute embeddings (run once)
python precompute_esm.py --esm_model esm2_t33_650M_UR50D
# Embeddings saved to: ./data/processed/esm_embeddings/esm2_t33_650M_UR50D/esm_cdr{1,2,3}.pt
```

The embeddings are computed with the CDR positions masked (each residue replaced by the <mask> token) and then extracted at the CDR positions. This provides rich evolutionary context from the surrounding sequence without leaking the ground-truth CDR residues.
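A sketch of the masking step, assuming the fair-esm package (`esm.pretrained.esm2_t33_650M_UR50D`, `alphabet.get_batch_converter`) and 0-based CDR positions. `mask_cdr`, `cdr_token_indices`, and `esm_cdr_embeddings` are hypothetical helpers, and precompute_esm.py may batch, index, or store results differently. The key detail is the one-token offset: ESM-2 prepends a beginning-of-sequence token, so residue i sits at token i + 1. Loading the model inside the function is for brevity only; a real precomputation script would load it once.

```python
def mask_cdr(seq, cdr_positions, mask_token="<mask>"):
    """Replace every residue at a CDR position (0-based) with the mask token."""
    cdr = set(cdr_positions)
    return "".join(mask_token if i in cdr else aa for i, aa in enumerate(seq))


def cdr_token_indices(cdr_positions, bos_offset=1):
    """Map 0-based residue positions to token positions (ESM-2 prepends a BOS token)."""
    return [i + bos_offset for i in sorted(cdr_positions)]


def esm_cdr_embeddings(seq, cdr_positions, layer=33):
    """Per-CDR-residue ESM-2 embeddings from a CDR-masked sequence, shape (n_cdr, 1280).

    Requires the fair-esm package; the 650M weights are downloaded on first use.
    """
    import esm
    import torch

    model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
    model.eval()
    batch_converter = alphabet.get_batch_converter()
    _, _, tokens = batch_converter([("antibody", mask_cdr(seq, cdr_positions))])
    with torch.no_grad():
        out = model(tokens, repr_layers=[layer])
    return out["representations"][layer][0, cdr_token_indices(cdr_positions)]


print(mask_cdr("EVQLVESGG", [3, 4, 5]))  # EVQ<mask><mask><mask>SGG
```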

Integration:

  • esm_proj: Linear(1280, hidden_size) projects ESM embeddings
  • Combined features: concat(cdr_h, attn_out, esm_h) -> seq_head
  • Memory efficient: saves ~2.5 GB of GPU memory compared with running the ESM forward pass at training time (the 650M parameters alone take about 2.6 GB in fp32)
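The shape arithmetic behind the (L, 768) input to seq_head in the architecture diagram, as a NumPy sketch. It assumes hidden_size = 256 (consistent with 3 × 256 = 768) and uses random stand-ins for the learned esm_proj weights; in training only esm_proj would be updated, since the ESM-2 features are precomputed and frozen.

```python
import numpy as np

rng = np.random.default_rng(0)
L, hidden_size, esm_dim = 12, 256, 1280

cdr_h = rng.standard_normal((L, hidden_size))     # GNN features at CDR positions
attn_out = rng.standard_normal((L, hidden_size))  # antigen cross-attention output
esm_cdr_emb = rng.standard_normal((L, esm_dim))   # precomputed, frozen ESM-2 features

# esm_proj: Linear(1280, hidden_size), i.e. x @ W + b with random stand-in weights
W = rng.standard_normal((esm_dim, hidden_size)) / np.sqrt(esm_dim)
b = np.zeros(hidden_size)
esm_h = esm_cdr_emb @ W + b  # (L, 256)

combined = np.concatenate([cdr_h, attn_out, esm_h], axis=-1)
print(combined.shape)  # (12, 768)
```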

Architecture

```
Input: Ab-Ag complex
  |
  v
ProteinFeatureEncoder (105D + 3D segment) -> H_0
  |
  v
Framework Dropout (p=0.3, training only)
  |
  v
VirtualNodeEGNN (4 layers, 3 virtual nodes)
  |
  v
aa_embd (GNN output for all residues)
  |
  +---> cdr_h (CDR positions from GNN)
  |
  +---> LearnedCDRQuery (pos_emb + cdr_type + epitope_context)
  |       |
  |       v
  |     Cross-Attention(query, ag_h, ag_h) -> attn_out
  |
  +---> ESM-2 (frozen, precomputed) -> esm_cdr_emb -> esm_proj -> esm_h
  |
  v
[no_gate mode] concat(cdr_h, attn_out, esm_h) -> (L, 768)
  |
  v
seq_head(combined) -> logits
  |
  +---> softmax(logits) -> probs
  |       |
  |       v
  |     AntigenClassificationHead(probs, antigen_embs) -> ag_cls_loss [NEW]
  |
  v
Losses: seq + coord + pairing + dock + shadow + gdpp + aux + antigen_cls
```
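The Cross-Attention(query, ag_h, ag_h) step is where antigen information enters the CDR representation: each CDR query reads a weighted average of antigen residue features. The sketch below is a simplified single-head version without learned query/key/value projections (the repository may use multi-head attention with projections); `cross_attention` is a hypothetical name and the sizes are illustrative.

```python
import numpy as np


def cross_attention(query, ag_h):
    """Single-head scaled dot-product attention: CDR queries attend over antigen residues.

    query: (L, D) learned CDR queries; ag_h: (N_ag, D) antigen features used as keys and values.
    """
    scores = query @ ag_h.T / np.sqrt(query.shape[-1])   # (L, N_ag)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)  # each row sums to 1
    return weights @ ag_h, weights  # attn_out: (L, D)


rng = np.random.default_rng(0)
query = rng.standard_normal((12, 256))
ag_h = rng.standard_normal((80, 256))
attn_out, weights = cross_attention(query, ag_h)
print(attn_out.shape, np.allclose(weights.sum(axis=-1), 1.0))  # (12, 256) True
```

Because attn_out is built only from antigen features, swapping the antigen changes it directly, which is what the cycle-consistency head relies on.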

If the swap similarity drops significantly, antigen conditioning is working, and AAR should improve accordingly.

Usage

```bash
# Step 1: Precompute ESM embeddings (run once, takes ~30 min)
python precompute_esm.py --esm_model esm2_t33_650M_UR50D --cdr_types 1,2,3

# Step 2: Train
# Quick test
python chimera_trainer.py --max_epoch 2 --no-wandb
# Full training
python chimera_trainer.py --split epitope_group --cdr_type 3

# Test only (load checkpoint)
python chimera_trainer.py --test_only --checkpoint /path/to/best.pt
```

Config

Key settings in config.yaml:

```yaml
dataset:
  esm_model: esm2_t33_650M_UR50D
  esm_embeddings_dir: ./data/processed/esm_embeddings

model:
  # Attention mode: 'no_gate' (recommended), 'direct', or false
  use_direct_attn: no_gate    # remove dead gate, keep concat(cdr_h, attn_out)
  lambda_antigen_cls: 0.5     # antigen classification weight (cycle consistency)
  lambda_infonce: 0.0         # disabled (empirically inert)
  lambda_mine: 0.0            # disabled (empirically inert)
  zeta: 0.0                   # latent coupling disabled
  # ESM-2 embeddings
  use_esm: true               # enable precomputed ESM embeddings
  esm_dim: 1280               # ESM-2 650M embedding dimension

training:
  anneal_base: 0.97           # slower LR decay (0.9 -> 0.97)
```
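Note that use_direct_attn takes either a string or YAML's boolean false. A minimal sketch of loading and sanity-checking these keys, assuming the file is read with PyYAML's `yaml.safe_load` and that the `model:` block is passed in as a dict; `check_model_config` and `load_config` are hypothetical helpers, and treating a missing `use_direct_attn` as false is an assumption.

```python
VALID_ATTN_MODES = ("no_gate", "direct")


def check_model_config(model_cfg):
    """Validate the attention mode and report which auxiliary losses are switched off."""
    # Assumption: a missing key falls back to the original gated mode (false).
    mode = model_cfg.get("use_direct_attn", False)
    # YAML parses an unquoted `false` as the bool False, which selects the gated mode.
    if not (mode is False or mode in VALID_ATTN_MODES):
        raise ValueError(f"use_direct_attn must be 'no_gate', 'direct', or false; got {mode!r}")
    disabled = [k for k in ("lambda_infonce", "lambda_mine", "zeta") if model_cfg.get(k) == 0.0]
    return mode, disabled


def load_config(path):
    import yaml  # PyYAML

    with open(path) as f:
        return yaml.safe_load(f)


model_cfg = {"use_direct_attn": "no_gate", "lambda_antigen_cls": 0.5,
             "lambda_infonce": 0.0, "lambda_mine": 0.0, "zeta": 0.0}
print(check_model_config(model_cfg))  # ('no_gate', ['lambda_infonce', 'lambda_mine', 'zeta'])
```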

Files

  • code/models.py - DockDesigner with use_direct_attn, AntigenClassificationHead, and ESM projection
  • code/v2_dataset.py - Dataset with ESM embeddings loading
  • config.yaml - defaults (use_direct_attn: no_gate, MI losses off, ESM on)
  • chimera_trainer.py - 13-tuple loss handling, ESM integration
  • precompute_esm.py - One-time ESM-2 embedding precomputation
