Source code for rush.nnxtb

"""
NN-xTB module for the Rush Python client.

NN-xTB reparameterizes xTB with a neural network to approach DFT-level accuracy
while keeping xTB-like speed. It supports arbitrary charge and spin states and
is well-suited for large-scale screening where fast per-atom forces or
vibrational frequencies are needed. Note that frequency calculations are more
expensive than energy or force calculations.

Usage::

    from rush import nnxtb

    result = nnxtb.energy("mol.json").fetch()
    print(result.energy_mev)
"""

import sys
from dataclasses import dataclass
from pathlib import Path
from string import Template
from typing import Any, Self

from gql.transport.exceptions import TransportQueryError

from ._rex import optional_str
from .mol import TRC, Topology
from .objects import (
    RushObject,
    TRCRef,
    _to_topology_vobj,
)
from .runs import Run, RunOpts, RunSpec
from .session import _submit_rex

# ---------------------------------------------------------------------------
# Result types
# ---------------------------------------------------------------------------


[docs] @dataclass class Result: """Parsed nn-xTB calculation results.""" energy_mev: float forces_mev_per_angstrom: list[tuple[float, float, float]] | None = None frequencies_inv_cm: list[float] | None = None
@dataclass(frozen=True)
class ResultPaths:
    """Where saved nn-xTB output ended up in the workspace."""

    # Workspace path returned when the output object was saved.
    output: Path
@dataclass(frozen=True)
class ResultRef:
    """Handle to nn-xTB output stored in the Rush object store.

    Use :meth:`fetch` to download and parse the output, or :meth:`save` to
    download it into the workspace.
    """

    # The single output object produced by an nn-xTB run.
    output: RushObject

    @classmethod
    def from_raw_output(cls, res: Any) -> Self:
        """Build a ``ResultRef`` from the raw value returned by ``collect_run``.

        Raises:
            ValueError: if *res* is not a list holding exactly one item.
        """
        if isinstance(res, list) and len(res) == 1:
            return cls(output=RushObject.from_dict(res[0]))
        # Include the item count in the message only when *res* has a length.
        size_note = f" with {len(res)} items" if hasattr(res, "__len__") else ""
        raise ValueError(
            "nnxtb should return a list with exactly 1 output, "
            f"got {type(res).__name__}{size_note}."
        )

    def fetch(self) -> Result:
        """Download the output as a dict and unpack its keys into a :class:`Result`."""
        return Result(**self.output.fetch_dict())

    def save(self) -> ResultPaths:
        """Save the output into the workspace and report where it was written."""
        saved_path = self.output.save()
        return ResultPaths(output=saved_path)


# ---------------------------------------------------------------------------
# Submission
# ---------------------------------------------------------------------------


def energy(
    mol: TRC | TRCRef | Path | str | RushObject | Topology,
    compute_forces: bool | None = None,
    compute_frequencies: bool | None = None,
    multiplicity: int | None = None,
    run_spec: RunSpec = RunSpec(gpus=1, storage=100),
    run_opts: RunOpts = RunOpts(),
    *,
    charge: int | None = 0,
) -> Run[ResultRef]:
    """
    Submit an nn-xTB energy calculation for the molecule *mol*.

    Args:
        mol: Input structure: an in-memory ``TRC``/``Topology``, a path
            (``Path`` or ``str``) to a topology file, or a reference to an
            object already in the Rush object store. It is converted to a
            topology object with ``_to_topology_vobj``.
        compute_forces: Whether to compute per-atom forces. It is rendered
            with ``optional_str``; ``None`` presumably leaves the option unset.
        compute_frequencies: Whether to compute vibrational frequencies
            (more expensive). It is rendered like *compute_forces*.
        multiplicity: Spin multiplicity. It is rendered like *compute_forces*.
        run_spec: Resources requested for the run.
        run_opts: Submission options forwarded to ``_submit_rex``.
        charge: Total molecular charge. It defaults to ``0`` (neutral).
            Passing ``None`` leaves the charge unset in the config.

    Returns:
        A :class:`~rush.runs.Run` handle. Call ``.fetch()`` to get the parsed
        result, or ``.save()`` to write it to disk.

    Raises:
        TransportQueryError: if the submission is rejected. Each reported
            error message is printed to stderr before the error is re-raised.
    """
    # Upload inputs
    topology_vobj = _to_topology_vobj(mol)

    # Run rex
    rex = Template(
        """let
  obj_j = λ j → VirtualObject { path = j, format = ObjectFormat::json, size = 0 },
  nnxtb = λ topology →
    try_nnxtb_rex
      ($run_spec)
      (nnxtb_rex::NnxtbConfig {
        compute_forces = $maybe_compute_forces,
        compute_frequencies = $maybe_compute_frequencies,
        charge = $maybe_charge,
        multiplicity = $maybe_multiplicity,
      })
      (obj_j topology)
in
  nnxtb "$topology_vobj_path"
"""
    ).substitute(
        run_spec=run_spec._to_rex(),
        maybe_compute_forces=optional_str(compute_forces),
        maybe_compute_frequencies=optional_str(compute_frequencies),
        # Charge is wrapped explicitly as an integer option; a None value is
        # substituted as the literal text "None".
        maybe_charge=f"Some (int {charge})" if charge is not None else None,
        maybe_multiplicity=optional_str(multiplicity),
        topology_vobj_path=topology_vobj["path"],
    )

    try:
        return Run(_submit_rex(rex, run_opts), ResultRef)
    except TransportQueryError as e:
        # Surface server-side error messages before propagating the failure.
        if e.errors:
            for error in e.errors:
                print(f"Error: {error['message']}", file=sys.stderr)
        raise