"""
EXESS geometry optimization for the Rush Python client.
Quick Links
-----------
- :func:`rush.exess.optimization`
- :class:`rush.exess.OptimizationResult`
"""
import sys
from dataclasses import dataclass
from pathlib import Path
from string import Template
from typing import Any, Literal, Self
from gql.transport.exceptions import TransportQueryError
from .._rex import optional_str
from ..mol import TRC, Residues, Topology
from ..objects import (
RushObject,
TRCRef,
_to_residues_vobj,
_to_topology_vobj,
)
from ..runs import Run, RunOpts, RunSpec
from ..session import _submit_rex
from ._energy import (
AuxBasisT,
BasisT,
KSDFTKeywords,
MethodT,
SCFKeywords,
StandardOrientationT,
System,
_KSDFTDefault,
)
# ---------------------------------------------------------------------------
# Input types
# ---------------------------------------------------------------------------
[docs]
@dataclass
class OptimizationConvergenceCriteria:
metric: str | None = None
gradient_threshold: float | None = None
delta_energy_threshold: float | None = None
step_component_threshold: float | None = None
def _to_rex(self, reference_fragment: int | None = None):
return Template(
"""Some (exess_geo_opt_rex::OptimizationConvergenceCriteria {
metric = $maybe_metric,
gradient_threshold = $maybe_gradient_threshold,
delta_energy_threshold = $maybe_delta_energy_threshold,
step_component_threshold = $maybe_step_component_threshold,
})"""
).substitute(
maybe_metric=optional_str(self.metric), # TODO: enum prefix
maybe_gradient_threshold=optional_str(self.gradient_threshold),
maybe_delta_energy_threshold=optional_str(self.delta_energy_threshold),
maybe_step_component_threshold=optional_str(self.step_component_threshold),
)
type CoordinateSystemT = Literal["Cartesian", "NaturalInternal", "DelocalisedInternal"]
type HessianGuessTypeT = Literal["Identity", "ScaledIdentity", "Schlegel", "Lindh"]
type OptimizationAlgorithmTypeT = Literal[
"EigenvectorFollowing", "TrustRegionAugmentedHessian", "LBFGS"
]
[docs]
@dataclass
class TrustRegionKeywords:
initial_radius: float | None = None
max_radius: float | None = None
min_radius: float | None = None
increase_factor: float | None = None
decrease_factor: float | None = None
constrict_factor: float | None = None
increase_threshold: float | None = None
decrease_threshold: float | None = None
rejection_threshold: float | None = None
def _to_rex(self):
return Template(
"""Some (exess_geo_opt_rex::TrustRegionKeywords {
initial_radius = $maybe_initial_radius,
max_radius = $maybe_max_radius,
min_radius = $maybe_min_radius,
increase_factor = $maybe_increase_factor,
decrease_factor = $maybe_decrease_factor,
constrict_factor = $maybe_constrict_factor,
increase_threshold = $maybe_increase_threshold,
decrease_threshold = $maybe_decrease_threshold,
rejection_threshold = $maybe_rejection_threshold,
})"""
).substitute(
maybe_initial_radius=optional_str(self.initial_radius),
maybe_max_radius=optional_str(self.max_radius),
maybe_min_radius=optional_str(self.min_radius),
maybe_increase_factor=optional_str(self.increase_factor),
maybe_decrease_factor=optional_str(self.decrease_factor),
maybe_constrict_factor=optional_str(self.constrict_factor),
maybe_increase_threshold=optional_str(self.increase_threshold),
maybe_decrease_threshold=optional_str(self.decrease_threshold),
maybe_rejection_threshold=optional_str(self.rejection_threshold),
)
type LBFGSLinesearchT = Literal[
"MoreThuente", "BacktrackingArmijo", "BacktrackingWolfe", "BacktrackingStrongWolfe"
]
[docs]
@dataclass
class LBFGSKeywords:
    """Settings for the L-BFGS optimizer.

    ``linesearch`` always has a value (default ``"BacktrackingStrongWolfe"``).
    The remaining fields are optional.
    """

    linesearch: LBFGSLinesearchT = "BacktrackingStrongWolfe"
    n_corrections: int | None = None
    epsilon: float | None = None
    max_linesearch: int | None = None
    gtol: float | None = None

    def _to_rex(self):
        """Render as a Rex ``Some (exess_geo_opt_rex::LBFGSKeywords {...})`` literal."""
        # The line search gets the ``exess_geo_opt_rex::LBFGSLinesearch::`` prefix,
        # presumably to name a Rex enum variant. The other fields are plain values.
        rendered = {
            "maybe_linesearch": optional_str(
                self.linesearch, "exess_geo_opt_rex::LBFGSLinesearch::"
            )
        }
        for name in ("n_corrections", "epsilon", "max_linesearch", "gtol"):
            rendered[f"maybe_{name}"] = optional_str(getattr(self, name))

        template = Template(
            """Some (exess_geo_opt_rex::LBFGSKeywords {
linesearch = $maybe_linesearch,
n_corrections = $maybe_n_corrections,
epsilon = $maybe_epsilon,
max_linesearch = $maybe_max_linesearch,
gtol = $maybe_gtol,
})"""
        )
        return template.substitute(rendered)
[docs]
@dataclass
class OptimizationKeywords:
    """Optimizer settings for an EXESS geometry optimization.

    Every field is optional. Nested keyword objects (``convergence_criteria``,
    ``lbfgs_keywords``, ``trust_region_keywords``) are rendered by their own
    ``_to_rex`` methods. Enum-like string fields get the matching
    ``exess_geo_opt_rex::...::`` prefix. Everything else goes through
    ``optional_str``.

    Note:
        ``constraints`` is not yet forwarded to the server. A non-empty value
        triggers a :class:`UserWarning` when rendered and is otherwise ignored.
    """

    convergence_criteria: OptimizationConvergenceCriteria | None = None
    optimizer_reset_interval: int | None = None
    coordinate_system: CoordinateSystemT | None = None
    # Not serialized yet: rendering always sends ``None`` (with a warning if set).
    constraints: list[list[int]] | None = None
    hessian_guess: HessianGuessTypeT | None = None
    algorithm: OptimizationAlgorithmTypeT | None = None
    lbfgs_keywords: LBFGSKeywords | None = None
    frozen_distance_slippage_tolerance_angstroms: float | None = None
    frozen_angle_slippage_tolerance_degrees: float | None = None
    trust_region_keywords: TrustRegionKeywords | None = None
    fixed_atoms: list[int] | None = None
    free_atoms: list[int] | None = None
    fixed_fragments: list[int] | None = None
    free_fragments: list[int] | None = None
    fix_heavy: bool | None = None

    def _to_rex(self, max_iters: int) -> str:
        """Render as a Rex ``Some (exess_geo_opt_rex::OptimizationKeywords {...})`` literal.

        Args:
            max_iters: Iteration cap, inserted verbatim (via ``str``) as
                ``max_iters``.

        Returns:
            The Rex expression as a string.
        """
        import warnings

        if self.constraints:
            # TODO: serialize constraints once the Rex encoding of atom
            # references is settled. Until then, make the omission visible
            # instead of silently running an unconstrained optimization.
            warnings.warn(
                "OptimizationKeywords.constraints is not yet supported and "
                "will be ignored.",
                UserWarning,
                stacklevel=3,
            )

        def nested(keywords) -> str:
            # Nested keyword objects render themselves. Absent ones become Rex ``None``.
            return keywords._to_rex() if keywords is not None else "None"

        return Template(
            """Some (exess_geo_opt_rex::OptimizationKeywords {
max_iters = $max_iters,
convergence_criteria = $maybe_convergence_criteria,
optimizer_reset_interval = $maybe_optimizer_reset_interval,
coordinate_system = $maybe_coordinate_system,
constraints = $maybe_constraints,
hessian_guess = $maybe_hessian_guess,
algorithm = $maybe_algorithm,
lbfgs_keywords = $maybe_lbfgs_keywords,
frozen_distance_slippage_tolerance_angstroms = $maybe_frozen_distance_slippage_tolerance_angstroms,
frozen_angle_slippage_tolerance_degrees = $maybe_frozen_angle_slippage_tolerance_degrees,
trust_region_keywords = $maybe_trust_region_keywords,
fixed_atoms = $maybe_fixed_atoms,
free_atoms = $maybe_free_atoms,
fixed_fragments = $maybe_fixed_fragments,
free_fragments = $maybe_free_fragments,
fix_heavy = $maybe_fix_heavy,
})"""
        ).substitute(
            max_iters=max_iters,
            maybe_convergence_criteria=nested(self.convergence_criteria),
            maybe_optimizer_reset_interval=optional_str(self.optimizer_reset_interval),
            maybe_coordinate_system=optional_str(
                self.coordinate_system, "exess_geo_opt_rex::CoordinateSystem::"
            ),
            # Constraints are not serialized yet (see the warning above).
            maybe_constraints="None",
            maybe_hessian_guess=optional_str(
                self.hessian_guess, "exess_geo_opt_rex::HessianGuessType::"
            ),
            maybe_algorithm=optional_str(
                self.algorithm, "exess_geo_opt_rex::OptimizationAlgorithmType::"
            ),
            maybe_lbfgs_keywords=nested(self.lbfgs_keywords),
            maybe_frozen_distance_slippage_tolerance_angstroms=optional_str(
                self.frozen_distance_slippage_tolerance_angstroms
            ),
            maybe_frozen_angle_slippage_tolerance_degrees=optional_str(
                self.frozen_angle_slippage_tolerance_degrees
            ),
            maybe_trust_region_keywords=nested(self.trust_region_keywords),
            maybe_fixed_atoms=optional_str(self.fixed_atoms),
            maybe_free_atoms=optional_str(self.free_atoms),
            maybe_fixed_fragments=optional_str(self.fixed_fragments),
            maybe_free_fragments=optional_str(self.free_fragments),
            maybe_fix_heavy=optional_str(self.fix_heavy),
        )
# ---------------------------------------------------------------------------
# Result types
# ---------------------------------------------------------------------------
[docs]
@dataclass
class OptimizationStep:
    """Summary of a single geometry-optimization step.

    Built from each entry of the ``steps`` output as
    ``OptimizationStep(**entry)``, so the field names mirror that output's keys.
    """

    # Total energy at this step (units not stated here; presumably Hartree, so confirm).
    total_energy: float
    # Largest gradient component at this step.
    max_gradient_component: float
[docs]
@dataclass
class OptimizationResult:
    """Parsed outputs of a geometry optimization."""

    # Topologies parsed (via ``Topology.from_json``) from the trajectory output.
    trajectory: list[Topology]
    # Per-step summaries parsed from the steps output.
    steps: list[OptimizationStep]
[docs]
@dataclass(frozen=True)
class OptimizationResultPaths:
    """Filesystem locations of saved optimization outputs (immutable)."""

    # Where the trajectory output was written.
    trajectory: Path
    # Where the per-step summaries were written.
    steps: Path
[docs]
@dataclass(frozen=True)
class OptimizationResultRef:
    """Lightweight reference to optimization outputs in the Rush object store."""

    trajectory: RushObject
    steps: RushObject

    @classmethod
    def from_raw_output(cls, res: Any) -> Self:
        """Parse raw ``collect_run`` output into an ``OptimizationResultRef``.

        Args:
            res: Expected to be a two-element list of object descriptors,
                ordered ``[trajectory, steps]``.

        Raises:
            ValueError: If *res* is not a list of exactly two items.
        """
        well_formed = isinstance(res, list) and len(res) == 2
        if not well_formed:
            count = len(res) if hasattr(res, '__len__') else '?'
            raise ValueError(
                "optimization should return exactly 2 outputs (trajectory + steps), "
                f"got {type(res).__name__} with {count} items."
            )
        trajectory_raw, steps_raw = res
        return cls(
            trajectory=RushObject.from_dict(trajectory_raw),
            steps=RushObject.from_dict(steps_raw),
        )

    def fetch(self) -> OptimizationResult:
        """Download optimization outputs and parse into Python objects."""
        # Download both outputs first, then parse them.
        raw_frames = self.trajectory.fetch_list()
        raw_steps = self.steps.fetch_list()
        return OptimizationResult(
            trajectory=list(map(Topology.from_json, raw_frames)),
            steps=[OptimizationStep(**entry) for entry in raw_steps],
        )

    def save(self) -> OptimizationResultPaths:
        """Download optimization outputs and save to the workspace."""
        trajectory_path = self.trajectory.save()
        steps_path = self.steps.save()
        return OptimizationResultPaths(trajectory=trajectory_path, steps=steps_path)
# ---------------------------------------------------------------------------
# Submission
# ---------------------------------------------------------------------------
[docs]
def optimization(
mol: TRC
| TRCRef
| tuple[Path | str | RushObject | Topology, Path | str | RushObject | Residues]
| Path
| str
| RushObject
| Topology,
max_iters: int,
optimization_keywords: OptimizationKeywords = OptimizationKeywords(),
method: MethodT = "RestrictedKSDFT",
basis: BasisT = "cc-pVDZ",
aux_basis: AuxBasisT | None = None,
standard_orientation: StandardOrientationT | None = None,
force_cartesian_basis_sets: bool | None = None,
scf_keywords: SCFKeywords | None = None,
ksdft_keywords: KSDFTKeywords | _KSDFTDefault | None = _KSDFTDefault.DEFAULT,
qm_fragments: list[int] | None = None,
mm_fragments: list[int] | None = None,
system: System | None = None,
run_spec: RunSpec = RunSpec(gpus=1),
run_opts: RunOpts = RunOpts(),
) -> Run[OptimizationResultRef]:
"""
Submit a geometry optimization for the topology at *topology_path*.
Returns a :class:`~rush.runs.Run` handle. Call ``.fetch()`` to get the
parsed trajectory and optimization steps, or ``.save()`` to write them to disk.
"""
ksdft_keywords = KSDFTKeywords.resolve(ksdft_keywords, method)
# Upload inputs
residues_vobj = None
match mol:
case TRC() | TRCRef():
topology_vobj = _to_topology_vobj(mol.topology)
residues_vobj = _to_residues_vobj(mol.residues)
case (t, r):
topology_vobj = _to_topology_vobj(t)
residues_vobj = _to_residues_vobj(r)
case _:
topology_vobj = _to_topology_vobj(mol)
# Run rex
rex = Template("""let
obj_j = λ j →
VirtualObject { path = j, format = ObjectFormat::json, size = 0 },
exess = λ topology residues →
try_exess_geo_opt_rex
($run_spec)
(exess_geo_opt_rex::OptimizationParams {
schema_version = "0.2.0",
external_charges = None,
model = Some (exess_geo_opt_rex::Model {
method = exess_geo_opt_rex::Method::$method,
basis = "$basis",
aux_basis = $maybe_aux_basis,
standard_orientation = $maybe_standard_orientation,
force_cartesian_basis_sets = $maybe_force_cartesian_basis_sets,
}),
system = $maybe_system,
keywords = exess_geo_opt_rex::Keywords {
scf = $maybe_scf_keywords,
ks_dft = $maybe_ks_keywords,
rtat = None,
frag = None,
boundary = None,
log = None,
dynamics = None,
integrals = None,
debug = None,
export = None,
guess = None,
force_field = None,
optimization = $maybe_optimization_keywords,
hessian = None,
gradient = None,
qmmm = $maybe_qmmm_keywords,
machine_learning = None,
regions = $maybe_regions,
},
})
[ (obj_j topology) ]
$residues_expr
in
exess "$topology_vobj_path" "$residues_vobj_path"
""").substitute(
run_spec=run_spec._to_rex(),
method=method,
basis=basis,
maybe_aux_basis=optional_str(aux_basis),
maybe_standard_orientation=optional_str(
standard_orientation, "exess_rex::StandardOrientation::"
),
maybe_force_cartesian_basis_sets=optional_str(force_cartesian_basis_sets),
maybe_system=system._to_rex() if system is not None else "None",
maybe_scf_keywords=(
scf_keywords._to_rex() if scf_keywords is not None else "None"
),
maybe_ks_keywords=(
ksdft_keywords._to_rex() if ksdft_keywords is not None else "None"
),
maybe_optimization_keywords=(
optimization_keywords._to_rex(max_iters)
if optimization_keywords is not None
else "None"
),
maybe_qmmm_keywords=(
"""Some (exess_qmmm_rex::QMMMKeywords {
n_timesteps = 1,
dt_ps = 0.002,
temperature_kelvin = 290.0,
pressure_atm = None,
minimisation = None,
trajectory = None,
restraints = None,
energy_csv = None,
})"""
if mm_fragments or (qm_fragments is not None)
else "None"
),
maybe_regions=(
Template(
"""Some (exess_qmmm_rex::RegionKeywords {
qm_fragments = $maybe_qm_fragments,
mm_fragments = $maybe_mm_fragments,
ml_fragments = Some [],
})"""
).substitute(
maybe_qm_fragments=optional_str(qm_fragments),
maybe_mm_fragments=optional_str(mm_fragments),
)
if not (qm_fragments is None and mm_fragments is None)
else "None"
),
residues_expr=(
"(Some [ (obj_j residues) ])" if residues_vobj is not None else "None"
),
topology_vobj_path=topology_vobj["path"],
residues_vobj_path=residues_vobj["path"] if residues_vobj is not None else "",
)
try:
return Run(_submit_rex(rex, run_opts), OptimizationResultRef)
except TransportQueryError as e:
if e.errors:
for error in e.errors:
print(f"Error: {error['message']}", file=sys.stderr)
raise