# Source code for rush.exess._optimization

"""
EXESS geometry optimization for the Rush Python client.

Quick Links
-----------

- :func:`rush.exess.optimization`
- :class:`rush.exess.OptimizationResult`
"""

import sys
import warnings
from dataclasses import dataclass, fields
from pathlib import Path
from string import Template
from typing import Any, Literal, Self

from gql.transport.exceptions import TransportQueryError

from .._rex import optional_str
from ..mol import TRC, Residues, Topology
from ..objects import (
    RushObject,
    TRCRef,
    _to_residues_vobj,
    _to_topology_vobj,
)
from ..runs import Run, RunOpts, RunSpec
from ..session import _submit_rex
from ._energy import (
    AuxBasisT,
    BasisT,
    KSDFTKeywords,
    MethodT,
    SCFKeywords,
    StandardOrientationT,
    System,
    _KSDFTDefault,
)

# ---------------------------------------------------------------------------
# Input types
# ---------------------------------------------------------------------------


[docs] @dataclass class OptimizationConvergenceCriteria: metric: str | None = None gradient_threshold: float | None = None delta_energy_threshold: float | None = None step_component_threshold: float | None = None def _to_rex(self, reference_fragment: int | None = None): return Template( """Some (exess_geo_opt_rex::OptimizationConvergenceCriteria { metric = $maybe_metric, gradient_threshold = $maybe_gradient_threshold, delta_energy_threshold = $maybe_delta_energy_threshold, step_component_threshold = $maybe_step_component_threshold, })""" ).substitute( maybe_metric=optional_str(self.metric), # TODO: enum prefix maybe_gradient_threshold=optional_str(self.gradient_threshold), maybe_delta_energy_threshold=optional_str(self.delta_energy_threshold), maybe_step_component_threshold=optional_str(self.step_component_threshold), )
type CoordinateSystemT = Literal["Cartesian", "NaturalInternal", "DelocalisedInternal"] type HessianGuessTypeT = Literal["Identity", "ScaledIdentity", "Schlegel", "Lindh"] type OptimizationAlgorithmTypeT = Literal[ "EigenvectorFollowing", "TrustRegionAugmentedHessian", "LBFGS" ]
[docs] @dataclass class TrustRegionKeywords: initial_radius: float | None = None max_radius: float | None = None min_radius: float | None = None increase_factor: float | None = None decrease_factor: float | None = None constrict_factor: float | None = None increase_threshold: float | None = None decrease_threshold: float | None = None rejection_threshold: float | None = None def _to_rex(self): return Template( """Some (exess_geo_opt_rex::TrustRegionKeywords { initial_radius = $maybe_initial_radius, max_radius = $maybe_max_radius, min_radius = $maybe_min_radius, increase_factor = $maybe_increase_factor, decrease_factor = $maybe_decrease_factor, constrict_factor = $maybe_constrict_factor, increase_threshold = $maybe_increase_threshold, decrease_threshold = $maybe_decrease_threshold, rejection_threshold = $maybe_rejection_threshold, })""" ).substitute( maybe_initial_radius=optional_str(self.initial_radius), maybe_max_radius=optional_str(self.max_radius), maybe_min_radius=optional_str(self.min_radius), maybe_increase_factor=optional_str(self.increase_factor), maybe_decrease_factor=optional_str(self.decrease_factor), maybe_constrict_factor=optional_str(self.constrict_factor), maybe_increase_threshold=optional_str(self.increase_threshold), maybe_decrease_threshold=optional_str(self.decrease_threshold), maybe_rejection_threshold=optional_str(self.rejection_threshold), )
type LBFGSLinesearchT = Literal[ "MoreThuente", "BacktrackingArmijo", "BacktrackingWolfe", "BacktrackingStrongWolfe" ]
[docs] @dataclass class LBFGSKeywords: linesearch: LBFGSLinesearchT = "BacktrackingStrongWolfe" n_corrections: int | None = None epsilon: float | None = None max_linesearch: int | None = None gtol: float | None = None def _to_rex(self): return Template( """Some (exess_geo_opt_rex::LBFGSKeywords { linesearch = $maybe_linesearch, n_corrections = $maybe_n_corrections, epsilon = $maybe_epsilon, max_linesearch = $maybe_max_linesearch, gtol = $maybe_gtol, })""" ).substitute( maybe_linesearch=optional_str( self.linesearch, "exess_geo_opt_rex::LBFGSLinesearch::" ), maybe_n_corrections=optional_str(self.n_corrections), maybe_epsilon=optional_str(self.epsilon), maybe_max_linesearch=optional_str(self.max_linesearch), maybe_gtol=optional_str(self.gtol), )
@dataclass
class OptimizationKeywords:
    """Optimizer settings rendered into ``exess_geo_opt_rex::OptimizationKeywords``.

    Every field is optional and defaults to ``None``. ``max_iters`` is not
    stored on this object; the caller supplies it to :meth:`_to_rex`.

    NOTE: ``constraints`` is not serialized yet. It is always rendered as
    ``None``, and a warning is emitted when a non-empty value is supplied.
    """

    convergence_criteria: OptimizationConvergenceCriteria | None = None
    optimizer_reset_interval: int | None = None
    coordinate_system: CoordinateSystemT | None = None
    constraints: list[list[int]] | None = None
    hessian_guess: HessianGuessTypeT | None = None
    algorithm: OptimizationAlgorithmTypeT | None = None
    lbfgs_keywords: LBFGSKeywords | None = None
    frozen_distance_slippage_tolerance_angstroms: float | None = None
    frozen_angle_slippage_tolerance_degrees: float | None = None
    trust_region_keywords: TrustRegionKeywords | None = None
    fixed_atoms: list[int] | None = None
    free_atoms: list[int] | None = None
    fixed_fragments: list[int] | None = None
    free_fragments: list[int] | None = None
    fix_heavy: bool | None = None

    def _to_rex(self, max_iters):
        """Render these keywords as a Rex ``Some (...)`` expression.

        Args:
            max_iters: Value substituted for the ``max_iters`` field.

        Returns:
            The Rex source text for ``exess_geo_opt_rex::OptimizationKeywords``.

        Warns:
            UserWarning: If ``constraints`` is non-empty, because it is not
                serialized and therefore has no effect on the run.
        """
        if self.constraints:
            # Dropping user-supplied constraints silently would let a run
            # proceed unconstrained without the caller noticing.
            warnings.warn(
                "OptimizationKeywords.constraints is not serialized yet and "
                "will be ignored.",
                UserWarning,
                stacklevel=3,  # presumably reached via optimization(); point at its caller
            )

        def nested(keywords):
            # Nested keyword objects render themselves; absent ones become "None".
            return keywords._to_rex() if keywords is not None else "None"

        template = Template(
            "Some (exess_geo_opt_rex::OptimizationKeywords { "
            "max_iters = $max_iters, "
            "convergence_criteria = $maybe_convergence_criteria, "
            "optimizer_reset_interval = $maybe_optimizer_reset_interval, "
            "coordinate_system = $maybe_coordinate_system, "
            "constraints = $maybe_constraints, "
            "hessian_guess = $maybe_hessian_guess, "
            "algorithm = $maybe_algorithm, "
            "lbfgs_keywords = $maybe_lbfgs_keywords, "
            "frozen_distance_slippage_tolerance_angstroms = "
            "$maybe_frozen_distance_slippage_tolerance_angstroms, "
            "frozen_angle_slippage_tolerance_degrees = "
            "$maybe_frozen_angle_slippage_tolerance_degrees, "
            "trust_region_keywords = $maybe_trust_region_keywords, "
            "fixed_atoms = $maybe_fixed_atoms, "
            "free_atoms = $maybe_free_atoms, "
            "fixed_fragments = $maybe_fixed_fragments, "
            "free_fragments = $maybe_free_fragments, "
            "fix_heavy = $maybe_fix_heavy, "
            "})"
        )
        return template.substitute(
            max_iters=max_iters,
            maybe_convergence_criteria=nested(self.convergence_criteria),
            maybe_optimizer_reset_interval=optional_str(self.optimizer_reset_interval),
            maybe_coordinate_system=optional_str(
                self.coordinate_system, "exess_geo_opt_rex::CoordinateSystem::"
            ),
            # TODO: serialize ``constraints``. The intended shape (from the
            # earlier sketch) is an optional list in which each constraint
            # becomes ``vec![exess_geo_opt_rex::AtomRef (<atom>), ...]``.
            maybe_constraints="None",
            maybe_hessian_guess=optional_str(
                self.hessian_guess, "exess_geo_opt_rex::HessianGuessType::"
            ),
            maybe_algorithm=optional_str(
                self.algorithm, "exess_geo_opt_rex::OptimizationAlgorithmType::"
            ),
            maybe_lbfgs_keywords=nested(self.lbfgs_keywords),
            maybe_frozen_distance_slippage_tolerance_angstroms=optional_str(
                self.frozen_distance_slippage_tolerance_angstroms
            ),
            maybe_frozen_angle_slippage_tolerance_degrees=optional_str(
                self.frozen_angle_slippage_tolerance_degrees
            ),
            maybe_trust_region_keywords=nested(self.trust_region_keywords),
            maybe_fixed_atoms=optional_str(self.fixed_atoms),
            maybe_free_atoms=optional_str(self.free_atoms),
            maybe_fixed_fragments=optional_str(self.fixed_fragments),
            maybe_free_fragments=optional_str(self.free_fragments),
            maybe_fix_heavy=optional_str(self.fix_heavy),
        )
# ---------------------------------------------------------------------------
# Result types
# ---------------------------------------------------------------------------
@dataclass
class OptimizationStep:
    """Values recorded for a single geometry-optimization step."""

    total_energy: float
    max_gradient_component: float
@dataclass
class OptimizationResult:
    """Parsed optimization output: a list of topologies and a list of steps."""

    trajectory: list[Topology]
    steps: list[OptimizationStep]
@dataclass(frozen=True)
class OptimizationResultPaths:
    """Immutable pair of paths to the saved trajectory and steps outputs."""

    trajectory: Path
    steps: Path
@dataclass(frozen=True)
class OptimizationResultRef:
    """Lightweight reference to optimization outputs in the Rush object store."""

    trajectory: RushObject
    steps: RushObject

    @classmethod
    def from_raw_output(cls, res: Any) -> Self:
        """Parse raw ``collect_run`` output into an ``OptimizationResultRef``.

        Args:
            res: Raw run output; must be a ``list`` of exactly two items,
                the trajectory entry followed by the steps entry.

        Returns:
            A reference built by passing each entry to
            ``RushObject.from_dict``.

        Raises:
            ValueError: If ``res`` is not a list or does not hold two items.
        """
        if not isinstance(res, list) or len(res) != 2:
            # Report a length only when the object actually has one.
            count = len(res) if hasattr(res, "__len__") else "?"
            raise ValueError(
                "optimization should return exactly 2 outputs (trajectory + steps), "
                f"got {type(res).__name__} with {count} items."
            )
        trajectory_raw, steps_raw = res
        return cls(
            trajectory=RushObject.from_dict(trajectory_raw),
            steps=RushObject.from_dict(steps_raw),
        )

    def fetch(self) -> OptimizationResult:
        """Download optimization outputs and parse into Python objects.

        Each trajectory entry is parsed with ``Topology.from_json`` and each
        steps entry is expanded as keyword arguments to ``OptimizationStep``.
        """
        trajectory_raw = self.trajectory.fetch_list()
        steps_raw = self.steps.fetch_list()
        return OptimizationResult(
            trajectory=[Topology.from_json(frame) for frame in trajectory_raw],
            steps=[OptimizationStep(**record) for record in steps_raw],
        )

    def save(self) -> OptimizationResultPaths:
        """Download optimization outputs and save to the workspace.

        Returns:
            The values returned by ``save()`` on each underlying object.
        """
        trajectory_path = self.trajectory.save()
        steps_path = self.steps.save()
        return OptimizationResultPaths(trajectory=trajectory_path, steps=steps_path)
# ---------------------------------------------------------------------------
# Submission
# ---------------------------------------------------------------------------
[docs] def optimization( mol: TRC | TRCRef | tuple[Path | str | RushObject | Topology, Path | str | RushObject | Residues] | Path | str | RushObject | Topology, max_iters: int, optimization_keywords: OptimizationKeywords = OptimizationKeywords(), method: MethodT = "RestrictedKSDFT", basis: BasisT = "cc-pVDZ", aux_basis: AuxBasisT | None = None, standard_orientation: StandardOrientationT | None = None, force_cartesian_basis_sets: bool | None = None, scf_keywords: SCFKeywords | None = None, ksdft_keywords: KSDFTKeywords | _KSDFTDefault | None = _KSDFTDefault.DEFAULT, qm_fragments: list[int] | None = None, mm_fragments: list[int] | None = None, system: System | None = None, run_spec: RunSpec = RunSpec(gpus=1), run_opts: RunOpts = RunOpts(), ) -> Run[OptimizationResultRef]: """ Submit a geometry optimization for the topology at *topology_path*. Returns a :class:`~rush.runs.Run` handle. Call ``.fetch()`` to get the parsed trajectory and optimization steps, or ``.save()`` to write them to disk. 
""" ksdft_keywords = KSDFTKeywords.resolve(ksdft_keywords, method) # Upload inputs residues_vobj = None match mol: case TRC() | TRCRef(): topology_vobj = _to_topology_vobj(mol.topology) residues_vobj = _to_residues_vobj(mol.residues) case (t, r): topology_vobj = _to_topology_vobj(t) residues_vobj = _to_residues_vobj(r) case _: topology_vobj = _to_topology_vobj(mol) # Run rex rex = Template("""let obj_j = λ j → VirtualObject { path = j, format = ObjectFormat::json, size = 0 }, exess = λ topology residues → try_exess_geo_opt_rex ($run_spec) (exess_geo_opt_rex::OptimizationParams { schema_version = "0.2.0", external_charges = None, model = Some (exess_geo_opt_rex::Model { method = exess_geo_opt_rex::Method::$method, basis = "$basis", aux_basis = $maybe_aux_basis, standard_orientation = $maybe_standard_orientation, force_cartesian_basis_sets = $maybe_force_cartesian_basis_sets, }), system = $maybe_system, keywords = exess_geo_opt_rex::Keywords { scf = $maybe_scf_keywords, ks_dft = $maybe_ks_keywords, rtat = None, frag = None, boundary = None, log = None, dynamics = None, integrals = None, debug = None, export = None, guess = None, force_field = None, optimization = $maybe_optimization_keywords, hessian = None, gradient = None, qmmm = $maybe_qmmm_keywords, machine_learning = None, regions = $maybe_regions, }, }) [ (obj_j topology) ] $residues_expr in exess "$topology_vobj_path" "$residues_vobj_path" """).substitute( run_spec=run_spec._to_rex(), method=method, basis=basis, maybe_aux_basis=optional_str(aux_basis), maybe_standard_orientation=optional_str( standard_orientation, "exess_rex::StandardOrientation::" ), maybe_force_cartesian_basis_sets=optional_str(force_cartesian_basis_sets), maybe_system=system._to_rex() if system is not None else "None", maybe_scf_keywords=( scf_keywords._to_rex() if scf_keywords is not None else "None" ), maybe_ks_keywords=( ksdft_keywords._to_rex() if ksdft_keywords is not None else "None" ), maybe_optimization_keywords=( 
optimization_keywords._to_rex(max_iters) if optimization_keywords is not None else "None" ), maybe_qmmm_keywords=( """Some (exess_qmmm_rex::QMMMKeywords { n_timesteps = 1, dt_ps = 0.002, temperature_kelvin = 290.0, pressure_atm = None, minimisation = None, trajectory = None, restraints = None, energy_csv = None, })""" if mm_fragments or (qm_fragments is not None) else "None" ), maybe_regions=( Template( """Some (exess_qmmm_rex::RegionKeywords { qm_fragments = $maybe_qm_fragments, mm_fragments = $maybe_mm_fragments, ml_fragments = Some [], })""" ).substitute( maybe_qm_fragments=optional_str(qm_fragments), maybe_mm_fragments=optional_str(mm_fragments), ) if not (qm_fragments is None and mm_fragments is None) else "None" ), residues_expr=( "(Some [ (obj_j residues) ])" if residues_vobj is not None else "None" ), topology_vobj_path=topology_vobj["path"], residues_vobj_path=residues_vobj["path"] if residues_vobj is not None else "", ) try: return Run(_submit_rex(rex, run_opts), OptimizationResultRef) except TransportQueryError as e: if e.errors: for error in e.errors: print(f"Error: {error['message']}", file=sys.stderr) raise