Source code for rush.boltz

"""
Boltz module for the Rush Python client.

Boltz predicts folded structures from protein sequences, optional ligands, and
MSA inputs. The fetched output is parsed into Python-friendly result objects,
while the saved output writes the model and JSON artifacts into the workspace.

Usage::

    from rush import boltz

    result = boltz.fold([ProteinSequence(...)]).fetch()
    print(next(result).metrics.confidence_score)
"""

import base64
import json
import sys
from collections.abc import Iterator
from dataclasses import dataclass
from pathlib import Path
from string import Template
from typing import Any, Self

import numpy as np
import numpy.typing as npt
from gql.transport.exceptions import TransportQueryError

from ._rex import dict_to_vec_of_tuples_str, optional_str
from .convert import _single_trc, from_json, from_pdb
from .mol import TRC
from .objects import (
    RushObject,
    TRCPaths,
    TRCRef,
    _json_content_name,
    save_json,
    upload_object,
)
from .runs import Run, RunOpts, RunSpec
from .session import _submit_rex

# ---------------------------------------------------------------------------
# Input types
# ---------------------------------------------------------------------------


@dataclass
class Modification:
    """A single modified residue within a protein sequence.

    NOTE(review): ``ProteinSequence._to_rex`` currently renders
    ``modifications = None`` unconditionally, so instances of this class are
    not yet forwarded to the service.
    """

    # Residue position in the sequence (0- vs 1-based not shown here — TODO confirm).
    position: int
    # Modification code; presumably a PDB Chemical Component Dictionary id — verify.
    ccd: str
[docs] @dataclass class ProteinSequence: id: list[str] sequence: str msa: Path | str | RushObject modifications: list[Modification] | None = None cyclic: bool | None = None def _to_rex(self): match self.msa: case Path() | str(): msa_vobj = upload_object(self.msa) case RushObject(): msa_vobj = self.msa.to_dict() return Template( """(boltz2_rex::Sequence::Protein { id = $id, sequence = "$sequence", msa = VirtualObject { path = "$msa", format = ObjectFormat::bin, size = 0 }, modifications = None, cyclic = $cyclic, })""" ).substitute( id=f"[{', '.join([f'"{v}"' for v in self.id])}]", sequence=self.sequence, msa=msa_vobj["path"], cyclic=optional_str(self.cyclic), )
@dataclass
class LigandSequence:
    """A small-molecule ligand described by a SMILES string.

    Rendered into a ``boltz2_rex::Sequence::Ligand`` rex expression by
    :meth:`_to_rex`.
    """

    # Chain identifier(s) assigned to this ligand.
    id: list[str]
    # SMILES string, inserted verbatim between double quotes.
    # NOTE(review): no escaping is applied; SMILES may contain ``\`` for bond
    # stereochemistry — confirm how the rex parser treats backslashes.
    smiles: str

    def _to_rex(self):
        """Return this ligand as a rex ``Sequence::Ligand`` expression string."""
        quoted_ids = ", ".join(f'"{chain_id}"' for chain_id in self.id)
        fields = {"id": f"[{quoted_ids}]", "smiles": self.smiles}
        template = Template(
            '(boltz2_rex::Sequence::Ligand { id = $id, smiles = "$smiles", })'
        )
        return template.substitute(fields)
# --------------------------------------------------------------------------- # Result types # ---------------------------------------------------------------------------
@dataclass
class Metrics:
    """Summary confidence metrics returned by Boltz.

    Field names must match the keys of the raw metrics mapping, because
    results are built with ``Metrics(**metrics)``.
    """

    # Overall confidence score reported for the sample.
    confidence_score: float
    # TM-score style metrics (global, interface, ligand, protein).
    ptm: float
    iptm: float
    ligand_iptm: float
    protein_iptm: float
    # Complex-level pLDDT / PDE aggregates (plain and interface variants).
    complex_plddt: float
    complex_iplddt: float
    complex_pde: float
    complex_ipde: float
@dataclass
class Affinities:
    """Optional affinity predictions returned for binding runs.

    Built with ``Affinities(**affinities)``, so the field names mirror the
    raw keys. The ``1``/``2`` suffixed fields presumably come from individual
    affinity models — confirm against the Boltz output spec.
    """

    # Primary affinity predictions.
    affinity_pred_value: float
    affinity_probability_binary: float
    # First suffixed variant.
    affinity_pred_value1: float
    affinity_probability_binary1: float
    # Second suffixed variant.
    affinity_pred_value2: float
    affinity_probability_binary2: float
@dataclass
class Result:
    """Fully downloaded Boltz output for a single diffusion sample.

    ``ResultRef.fetch()`` yields one of these per diffusion sample.
    """

    # Predicted structure.
    model: TRC
    # Summary confidence metrics for this sample.
    metrics: Metrics
    # pLDDT and PAE arrays decoded from their stored float32 payloads.
    plddt: npt.NDArray[np.float32]
    pae: npt.NDArray[np.float32]
    # Only present when the run produced affinity predictions.
    affinities: Affinities | None = None
@dataclass(frozen=True)
class ResultPaths:
    """Workspace locations of one saved Boltz diffusion sample (immutable)."""

    # Saved structure files.
    model: TRCPaths
    # Saved metrics JSON.
    metrics: Path
    # Saved pLDDT / PAE objects.
    plddt: Path
    pae: Path
    # Saved affinities JSON, or ``None`` when the sample had no affinities.
    affinities: Path | None = None
def _decode_float_array(output: dict[str, Any]) -> npt.NDArray[np.float32]:
    """Decode a base64 little-endian float32 payload into a NumPy array.

    Args:
        output: mapping with ``"data"`` (base64 of raw little-endian float32
            bytes) and ``"shape"`` (dimension sizes, each ``int``-convertible).

    Returns:
        A writable, native-byte-order ``float32`` array with the given shape.

    Raises:
        ValueError: if the decoded values do not fit ``shape`` (or the byte
            count is not a multiple of 4).
    """
    raw = base64.b64decode(output["data"])
    shape = tuple(int(dim) for dim in output["shape"])
    values = np.frombuffer(raw, dtype=np.dtype("<f4"))
    try:
        array = values.reshape(shape)
    except ValueError as err:
        raise ValueError(
            f"float array payload has {values.size} values, "
            f"which does not fit shape {shape}"
        ) from err
    # frombuffer returns a read-only view over immutable bytes; copy into a
    # writable native-order float32 array so callers can modify it in place.
    return array.astype(np.float32)
@dataclass(frozen=True)
class DiffusionSampleRef:
    """Object-store references for one Boltz diffusion sample (not yet downloaded)."""

    # References to the predicted structure's topology/residues/chains.
    model: TRCRef
    # Raw metrics mapping, later expanded into ``Metrics``.
    metrics: dict[str, Any]
    # Stored pLDDT and PAE payloads.
    plddt: RushObject
    pae: RushObject
    # Raw affinities mapping, or ``None`` when absent.
    affinities: dict[str, Any] | None
@dataclass(frozen=True)
class ResultRef:
    """Lightweight reference to Boltz outputs in the Rush object store.

    Each element of *diffusion_samples* is a parsed
    :class:`DiffusionSampleRef` for one diffusion sample. The ref also
    supports ``len()``, indexing and iteration over those samples.
    """

    diffusion_samples: list[DiffusionSampleRef]

    def __getitem__(self, index: int) -> DiffusionSampleRef:
        return self.diffusion_samples[index]

    def __len__(self) -> int:
        return len(self.diffusion_samples)

    def __iter__(self) -> Iterator[DiffusionSampleRef]:
        return iter(self.diffusion_samples)

    @classmethod
    def from_raw_output(cls, res: Any) -> Self:
        """Parse raw ``collect_run`` output into a ``ResultRef``.

        Raises:
            ValueError: if *res* is not a non-empty list, its first element is
                not a list of samples, or a sample does not unpack as
                ``(model, metrics, plddt, pae, affinities)`` with a
                ``(topology, residues, chains)`` model triple.
        """
        if not isinstance(res, list):
            raise ValueError(
                f"boltz output received unexpected format: {type(res).__name__}"
            )
        if not res:
            raise ValueError("boltz output received an empty list")

        # collect_run returns [[sample0, sample1, ...]] — outer list wraps
        # the single run, inner list contains one tuple per diffusion sample.
        out = res[0]
        if not isinstance(out, (list, tuple)):
            raise ValueError(
                "boltz output received unexpected sample container: "
                f"{type(out).__name__}"
            )

        diffusion_samples: list[DiffusionSampleRef] = []
        for index, item in enumerate(out):
            try:
                model_obj, metrics, plddt_obj, pae_obj, affinities = item
                topo, resid, chain = model_obj
            except (TypeError, ValueError) as err:
                raise ValueError(
                    f"boltz output sample {index} has an unexpected layout"
                ) from err
            diffusion_samples.append(
                DiffusionSampleRef(
                    model=TRCRef(
                        topology=RushObject.from_dict(topo),
                        residues=RushObject.from_dict(resid),
                        chains=RushObject.from_dict(chain),
                    ),
                    metrics=metrics,
                    plddt=RushObject.from_dict(plddt_obj),
                    pae=RushObject.from_dict(pae_obj),
                    affinities=affinities,
                )
            )
        return cls(diffusion_samples=diffusion_samples)

    def fetch(self) -> Iterator[Result]:
        """Download Boltz outputs and parse into Python objects.

        Yields one :class:`Result` per diffusion sample. Each sample is
        downloaded lazily on iteration — stop early to skip downloads.
        """
        for sample in self.diffusion_samples:
            yield Result(
                model=sample.model.fetch(),
                metrics=Metrics(**sample.metrics),
                plddt=_decode_float_array(sample.plddt.fetch_dict()),
                pae=_decode_float_array(sample.pae.fetch_dict()),
                affinities=(
                    Affinities(**sample.affinities)
                    if sample.affinities is not None
                    else None
                ),
            )

    def save(self) -> Iterator[ResultPaths]:
        """Download Boltz outputs and save to the workspace.

        Yields one :class:`ResultPaths` per diffusion sample. Each sample is
        downloaded lazily on iteration — stop early to skip downloads.
        """
        for sample in self.diffusion_samples:
            yield ResultPaths(
                model=sample.model.save(),
                metrics=save_json(
                    sample.metrics,
                    name=_json_content_name("boltz_metrics", sample.metrics),
                ),
                plddt=sample.plddt.save(),
                pae=sample.pae.save(),
                affinities=(
                    save_json(
                        sample.affinities,
                        name=_json_content_name("boltz_affinities", sample.affinities),
                    )
                    if sample.affinities is not None
                    else None
                ),
            )
# --------------------------------------------------------------------------- # Submission # ---------------------------------------------------------------------------
def fold(
    sequences: list[ProteinSequence | LigandSequence],
    recycling_steps: int | None = None,
    sampling_steps: int | None = None,
    diffusion_samples: int | None = None,
    step_scale: float | None = None,
    affinity_binder_chain_id: str | None = None,
    affinity_mw_correction: bool | None = None,
    sampling_steps_affinity: int | None = None,
    diffusion_samples_affinity: int | None = None,
    max_msa_seqs: int | None = None,
    subsample_msa: bool | None = None,
    num_subsampled_msa: int | None = None,
    use_potentials: bool | None = None,
    seed: int | None = None,
    template_path: Path | str | None = None,
    template_threshold_angstroms: float | None = None,
    template_chain_mapping: dict[str, str] | None = None,
    # NOTE: these defaults are built once at definition time and shared across
    # calls; fold() only reads them.
    run_spec: RunSpec = RunSpec(gpus=1),
    run_opts: RunOpts = RunOpts(),
) -> Run[ResultRef]:
    """
    Submit a Boltz fold job for the given protein/ligand *sequences*.

    Each optional numeric/boolean argument maps to the ``Boltz2Config`` field
    of the same name and is rendered via ``optional_str`` (``None``
    presumably leaves the service default in place — confirm).

    Args:
        sequences: protein and/or ligand inputs; must be non-empty.
        diffusion_samples_affinity: number of diffusion samples used for
            affinity prediction (previously mis-annotated as ``bool``).
        template_path: optional structure template. A ``.pdb`` file
            (case-insensitive) is parsed with ``from_pdb``; any other file is
            parsed as JSON with ``from_json``. It is reduced to a single TRC
            and uploaded before submission.
        template_chain_mapping: optional chain-id mapping rendered as
            ``Some <vec of tuples>``.
        run_spec: resources for the run.
        run_opts: submission options.

    Returns:
        A :class:`~rush.runs.Run` handle. Call ``.collect()`` to get a
        :class:`ResultRef`, then ``.fetch()`` or ``.save()`` on that ref.

    Raises:
        ValueError: if *sequences* is empty.
        TransportQueryError: re-raised after printing its error messages to
            stderr when submission fails.
    """
    if not sequences:
        # An empty list would render as a bare separator (``[ , ]``) in the
        # rex program below.
        raise ValueError("fold() requires at least one protein or ligand sequence")

    # If necessary, upload template TRC inputs
    topology_path = residues_path = chains_path = ""
    template_trc_expr = "None"
    if template_path is not None:
        template_path = Path(template_path)
        with open(template_path, encoding="utf-8") as f:
            if template_path.suffix.lower() == ".pdb":
                trc = from_pdb(f.read())
            else:
                trc = from_json(json.load(f))
        trc = _single_trc(trc, template_path)
        trc_ref = TRCRef.upload(trc)
        topology_path = trc_ref.topology.path
        residues_path = trc_ref.residues.path
        chains_path = trc_ref.chains.path
        template_trc_expr = "(Some ((obj_j topology), (obj_j residues), (obj_j chains)) )"

    sequences_rex = ",\n        ".join(seq._to_rex() for seq in sequences)

    # Run rex
    rex = Template("""let
    obj_j = λ j → VirtualObject { path = j, format = ObjectFormat::json, size = 0 },
    boltz = λ topology residues chains →
        try_boltz2_rex
            ($run_spec)
            (boltz2_rex::Boltz2Config {
                recycling_steps = $maybe_recycling_steps,
                sampling_steps = $maybe_sampling_steps,
                diffusion_samples = $maybe_diffusion_samples,
                step_scale = $maybe_step_scale,
                affinity_binder_chain_id = $maybe_affinity_binder_chain_id,
                affinity_mw_correction = $maybe_affinity_mw_correction,
                sampling_steps_affinity = $maybe_sampling_steps_affinity,
                diffusion_samples_affinity = $maybe_diffusion_samples_affinity,
                max_msa_seqs = $maybe_max_msa_seqs,
                subsample_msa = $maybe_subsample_msa,
                num_subsampled_msa = $maybe_num_subsampled_msa,
                use_potentials = $maybe_use_potentials,
                seed = $maybe_seed,
                template_threshold_angstroms = $maybe_template_threshold_angstroms,
                template_chain_mapping = $maybe_template_chain_mapping,
            })
            $sequences
            $template_trc_expr
in
    boltz "$topology_vobj_path" "$residues_vobj_path" "$chains_vobj_path"
""").substitute(
        run_spec=run_spec._to_rex(),
        maybe_recycling_steps=optional_str(recycling_steps),
        maybe_sampling_steps=optional_str(sampling_steps),
        maybe_diffusion_samples=optional_str(diffusion_samples),
        maybe_step_scale=optional_str(step_scale),
        maybe_affinity_binder_chain_id=optional_str(affinity_binder_chain_id),
        maybe_affinity_mw_correction=optional_str(affinity_mw_correction),
        maybe_sampling_steps_affinity=optional_str(sampling_steps_affinity),
        maybe_diffusion_samples_affinity=optional_str(diffusion_samples_affinity),
        maybe_max_msa_seqs=optional_str(max_msa_seqs),
        maybe_subsample_msa=optional_str(subsample_msa),
        maybe_num_subsampled_msa=optional_str(num_subsampled_msa),
        maybe_use_potentials=optional_str(use_potentials),
        maybe_seed=optional_str(seed),
        maybe_template_threshold_angstroms=optional_str(template_threshold_angstroms),
        maybe_template_chain_mapping=(
            f"(Some {dict_to_vec_of_tuples_str(template_chain_mapping)})"
            if template_chain_mapping is not None
            else "None"
        ),
        sequences=f"[\n        {sequences_rex},\n    ]",
        template_trc_expr=template_trc_expr,
        topology_vobj_path=topology_path,
        residues_vobj_path=residues_path,
        chains_vobj_path=chains_path,
    )
    try:
        return Run(_submit_rex(rex, run_opts), ResultRef)
    except TransportQueryError as e:
        if e.errors:
            for error in e.errors:
                print(f"Error: {error['message']}", file=sys.stderr)
        raise