# Source code for rush.pbsa

"""
PBSA module for the Rush Python client.

Computes solvation energies using the Poisson-Boltzmann Surface Area method.

Usage::

    from rush import pbsa

    result = pbsa.solvation_energy("mol.json", ...).fetch()
    print(result.solvation_energy)
"""

import sys
from dataclasses import asdict, dataclass
from pathlib import Path
from string import Template
from typing import Any, Self

from gql.transport.exceptions import TransportQueryError

from ._rex import float_to_str
from .mol import TRC, Topology
from .objects import (
    RushObject,
    TRCRef,
    _json_content_name,
    _to_topology_vobj,
    save_json,
)
from .runs import Run, RunOpts, RunSpec
from .session import _submit_rex

# ---------------------------------------------------------------------------
# Result types
# ---------------------------------------------------------------------------


@dataclass
class Result:
    """Solvation energies produced by a PBSA run; every field is in Hartrees."""

    # Overall solvation energy.
    solvation_energy: float
    # Polar component of the solvation energy.
    polar_solvation_energy: float
    # Nonpolar component of the solvation energy.
    nonpolar_solvation_energy: float
@dataclass(frozen=True)
class ResultPaths:
    """Location in the workspace where the PBSA output was written."""

    # Path of the saved output file.
    output: Path
[docs] @dataclass(frozen=True) class ResultRef: """Lightweight reference to PBSA output. PBSA results are small enough to be returned inline (three floats), so no object store download is needed. """ solvation_energy: float polar_solvation_energy: float nonpolar_solvation_energy: float
[docs] @classmethod def from_raw_output(cls, res: Any) -> Self: """Parse raw ``collect_run`` output into a ``ResultRef``.""" if isinstance(res, list) and len(res) == 3: return cls( solvation_energy=float(res[0]), polar_solvation_energy=float(res[1]), nonpolar_solvation_energy=float(res[2]), ) raise ValueError( f"pbsa should return exactly 3 float outputs, " f"got {type(res).__name__} with {len(res) if hasattr(res, '__len__') else '?'} items." )
[docs] def fetch(self) -> Result: """Return parsed PBSA results (no download needed — data is inline).""" return Result( solvation_energy=self.solvation_energy, polar_solvation_energy=self.polar_solvation_energy, nonpolar_solvation_energy=self.nonpolar_solvation_energy, )
[docs] def save(self) -> ResultPaths: """Save PBSA results as JSON to the workspace.""" output_json = asdict(self.fetch()) return ResultPaths( output=save_json( output_json, name=_json_content_name("pbsa_output", output_json), ), )
# ---------------------------------------------------------------------------
# Submission
# ---------------------------------------------------------------------------
def solvation_energy(
    mol: TRC | TRCRef | Path | str | RushObject | Topology,
    solute_dielectric: float,
    solvent_dielectric: float,
    solvent_radius: float,
    ion_concentration: float,
    temperature: float,
    spacing: float,
    sasa_gamma: float,
    sasa_beta: float,
    sasa_n_samples: int,
    convergence: float,
    box_size_factor: float,
    # NOTE(review): these defaults are evaluated once, when the function is
    # defined, and shared by every call. That is only safe if RunSpec/RunOpts
    # instances are never mutated downstream — confirm.
    run_spec: RunSpec = RunSpec(gpus=1),
    run_opts: RunOpts = RunOpts(),
) -> Run[ResultRef]:
    """
    Submit a PBSA solvation energy calculation for the topology given by *mol*.

    Returns a :class:`~rush.runs.Run` handle. Call ``.fetch()`` to get the
    parsed result, or ``.save()`` to write it to disk as JSON.

    Args:
        mol: The input topology, in any form accepted by
            ``_to_topology_vobj``.
        solute_dielectric, solvent_dielectric, solvent_radius,
        ion_concentration, temperature, spacing, sasa_gamma, sasa_beta,
        convergence, box_size_factor: Floating-point PBSA parameters. Each
            is rendered with ``float_to_str`` into the matching field of
            ``pbsa_rex::PBSAParameters``. Units are defined by the backend,
            not by this client.
        sasa_n_samples: Integer parameter, inserted into the rex expression
            as ``str(sasa_n_samples)``.
        run_spec: Resource specification, rendered with ``RunSpec._to_rex()``.
        run_opts: Options forwarded to ``_submit_rex``.

    Returns:
        A ``Run`` whose output is parsed into a :class:`ResultRef`.

    Raises:
        TransportQueryError: Re-raised unchanged if submission fails. Any
            server-reported error messages are printed to stderr first.
    """
    # Upload inputs
    topology_vobj = _to_topology_vobj(mol)

    # Build the rex expression: wrap the uploaded topology path as a JSON
    # VirtualObject and apply try_pbsa_rex with the given parameters.
    rex = Template("""let
    obj_j = λ j →
        VirtualObject { path = j, format = ObjectFormat::json, size = 0 },
    pbsa = λ topology →
        try_pbsa_rex
            ($run_spec)
            (pbsa_rex::PBSAParameters {
                solute_dielectric = $solute_dielectric,
                solvent_dielectric = $solvent_dielectric,
                solvent_radius = $solvent_radius,
                ion_concentration = $ion_concentration,
                temperature = $temperature,
                spacing = $spacing,
                sasa_gamma = $sasa_gamma,
                sasa_beta = $sasa_beta,
                sasa_n_samples = $sasa_n_samples,
                convergence = $convergence,
                box_size_factor = $box_size_factor,
            })
            (obj_j topology)
in
    pbsa "$topology_vobj_path"
""").substitute(
        run_spec=run_spec._to_rex(),
        solute_dielectric=float_to_str(solute_dielectric),
        solvent_dielectric=float_to_str(solvent_dielectric),
        solvent_radius=float_to_str(solvent_radius),
        ion_concentration=float_to_str(ion_concentration),
        temperature=float_to_str(temperature),
        spacing=float_to_str(spacing),
        sasa_gamma=float_to_str(sasa_gamma),
        sasa_beta=float_to_str(sasa_beta),
        sasa_n_samples=sasa_n_samples,
        convergence=float_to_str(convergence),
        box_size_factor=float_to_str(box_size_factor),
        topology_vobj_path=topology_vobj["path"],
    )
    try:
        return Run(_submit_rex(rex, run_opts), ResultRef)
    except TransportQueryError as e:
        # Report server-side errors, but never let a malformed error entry
        # (not a dict, or missing "message") raise a new exception that would
        # hide the original TransportQueryError.
        for error in e.errors or ():
            message = error.get("message", error) if isinstance(error, dict) else error
            print(f"Error: {message}", file=sys.stderr)
        raise