import json
import sys
from dataclasses import dataclass
from string import Template
from typing import Iterator
from gql.transport.exceptions import TransportQueryError
from rush import TRC, from_json
from rush.client import (
RunError,
RunOpts,
RunSpec,
_get_project_id,
_submit_rex,
collect_run,
download_object,
)
from rush.utils import bool_to_str, float_to_str
[docs]
@dataclass()
class Auto3DStats:
    """Statistics reported by an auto3d run for a single generated conformer."""

    # NOTE(review): presumably the maximum residual force after optimization — confirm.
    f_max: float
    # Whether the optimization for this conformer converged (per the field name).
    converged: bool
    # Relative energy in kcal/mol (per the field name); presumably relative to
    # the lowest-energy conformer of the same input — confirm.
    e_rel_kcal_mol: float
    # Total energy in hartrees (per the field name).
    e_tot_hartrees: float
[docs]
@dataclass()
class Auto3DResult:
    """One conformer produced by an auto3d run, paired with its statistics."""

    # The conformer structure as a TRC (topology / residues / chains).
    conformer: TRC
    # Statistics reported for this conformer.
    stats: Auto3DStats
[docs]
def auto3d(
    smis: list[str],
    k: int = 1,
    batchsize_atoms: int = 1024,
    capacity: int = 40,
    convergence_threshold: float = 0.003,
    enumerate_isomer: bool = True,
    enumerate_tautomer: bool = False,
    max_confs: int | None = None,
    opt_steps: int = 5000,
    patience: int = 1000,
    threshold: float = 0.3,
    run_spec: RunSpec = RunSpec(),
    run_opts: RunOpts = RunOpts(),
    collect: bool = False,
):
    """
    Run Auto3D conformer generation on a list of SMILES strings.

    Builds a rex program that calls ``try_auto3d_rex`` with an
    ``auto3d_rex::Auto3dOptions`` record filled from the arguments below and
    submits it to the current project. The options ``mpi_np = 4``,
    ``optimizing_engine = AIMNET`` and ``verbose = false`` are fixed.

    Args:
        smis: Input SMILES strings.
        k, batchsize_atoms, capacity, convergence_threshold, enumerate_isomer,
        enumerate_tautomer, opt_steps, patience, threshold: Forwarded as
            ``Some <value>`` to the Auto3D option of the same name.
        max_confs: Forwarded as ``Some <value>``, or ``None`` when not set.
        run_spec: NOTE(review): not referenced by the rex template, which
            hard-codes ``default_runspec_gpu``; kept for interface
            compatibility.
        run_opts: Forwarded to ``_submit_rex``.
        collect: If False, return the run ID right after submission; if
            True, fetch the results with ``collect_run``.

    Returns:
        - The run ID returned by ``_submit_rex`` when ``collect`` is False.
        - The ``RunError`` returned by ``collect_run``, if any.
        - Otherwise a list built from the collected output, where each
          single-key ``{"Ok": v}`` / ``{"Err": v}`` dict is unwrapped to ``v``.
        - ``None`` if a ``TransportQueryError`` is raised; its messages are
          printed to stderr.
    """
    # json.dumps quotes each SMILES and escapes backslashes (common in stereo
    # SMILES such as F/C=C\F) and quotes. It also avoids a nested same-quote
    # f-string, which only parses on Python 3.12+.
    # NOTE(review): assumes rex string literals use backslash escapes — confirm.
    smis_rex = "[" + ", ".join(json.dumps(smi, ensure_ascii=False) for smi in smis) + "]"

    # max_confs is an optional rex field (it accepts `None`), so a concrete
    # value must be wrapped in `Some`; a bare integer was emitted before.
    # NOTE(review): most int options use `Some $x` while `k` uses
    # `Some (int $k)`; confirm which form `max_confs` expects.
    max_confs_rex = "None" if max_confs is None else f"Some {max_confs}"

    rex = Template("""let
  auto3d = λ smis →
    try_auto3d_rex
      default_runspec_gpu
      (auto3d_rex::Auto3dOptions {
        k = Some (int $k),
        batchsize_atoms = Some $batchsize_atoms,
        capacity = Some $capacity,
        convergence_threshold = Some $convergence_threshold,
        enumerate_isomer = Some $enumerate_isomer,
        enumerate_tautomer = Some $enumerate_tautomer,
        job_name = None,
        max_confs = $max_confs,
        memory = None,
        mpi_np = Some 4,
        opt_steps = Some $opt_steps,
        optimizing_engine = Some auto3d_rex::Auto3dOptimizingEngines::AIMNET,
        patience = Some $patience,
        threshold = Some $threshold,
        verbose = Some false,
        window = None,
      })
      $smis
in
  auto3d $smis
""").substitute(
        smis=smis_rex,
        k=k,
        batchsize_atoms=batchsize_atoms,
        capacity=capacity,
        convergence_threshold=float_to_str(convergence_threshold),
        enumerate_isomer=bool_to_str(enumerate_isomer),
        enumerate_tautomer=bool_to_str(enumerate_tautomer),
        max_confs=max_confs_rex,
        opt_steps=opt_steps,
        patience=patience,
        threshold=float_to_str(threshold),
        # Not referenced by the template (Template.substitute ignores extra
        # keys); see the NOTE on `run_spec` in the docstring.
        run_spec=run_spec._to_rex(),
    )

    try:
        run_id = _submit_rex(_get_project_id(), rex, run_opts)
        if not collect:
            return run_id
        result = collect_run(run_id)
    except TransportQueryError as e:
        print("Error:", file=sys.stderr)
        if e.errors:
            for error in e.errors:
                print(f" {error['message']}", file=sys.stderr)
        else:
            # Previously this case printed nothing and silently returned None.
            print(f" {e}", file=sys.stderr)
        return None

    if isinstance(result, RunError):
        return result

    def is_result_type(value) -> bool:
        # Treat single-key {"Ok": ...} / {"Err": ...} dicts as rex Result values.
        return (
            isinstance(value, dict)
            and len(value) == 1
            and ("Ok" in value or "Err" in value)
        )

    # TODO: no special cases for Result unwrapping
    return [next(iter(r_i.values())) if is_result_type(r_i) else r_i for r_i in result]
[docs]
def _iter_auto3d_results(res_i) -> Iterator[Auto3DResult]:
    """Lazily download and package every conformer of one successful input.

    ``res_i`` is iterated as ``(trc_obj, stats)`` pairs. ``trc_obj[0]``,
    ``trc_obj[1]`` and ``trc_obj[2]`` are object-store pointers (dicts with a
    ``"path"`` key) to the topology, residues and chains JSON. ``stats`` is a
    dict with ``f_max``, ``converged``, ``e_rel_kcal_mol`` and
    ``e_tot_hartrees``.
    """
    for trc_obj, stats in res_i:
        # Downloads happen only when the caller advances the iterator.
        trc_dict = {
            "topology": json.loads(download_object(trc_obj[0]["path"])),
            "residues": json.loads(download_object(trc_obj[1]["path"])),
            "chains": json.loads(download_object(trc_obj[2]["path"])),
        }
        yield Auto3DResult(
            conformer=from_json(trc_dict),
            stats=Auto3DStats(
                f_max=stats["f_max"],
                converged=stats["converged"],
                e_rel_kcal_mol=stats["e_rel_kcal_mol"],
                e_tot_hartrees=stats["e_tot_hartrees"],
            ),
        )


def save_outputs(res) -> list[Iterator[Auto3DResult] | RunError] | str | RunError:
    """
    Download output files from an auto3d run.

    For each input, the auto3d rex computation returns either an error string
    or up to k ``(trc_objects, stats)`` pairs, one per generated conformer.
    ``trc_objects`` are Rush object-store pointers to the conformer's
    topology, residues and chains. Each successful input is mapped to an
    ``Iterator[Auto3DResult]`` that downloads and packages its conformers on
    the fly. Because it is a generator, it can be consumed only once. Each
    failed input is mapped to a ``RunError`` wrapping its error string.

    If ``collect=False`` was used, the input is a run ID string, which is
    returned as-is for later collection by the caller.

    Args:
        res: Either:
            - A run ID string (if collect=False was used)
            - The successful output list from auto3d()
            - A RunError

    Returns:
        Either:
            - The run ID string (if the input was a run ID)
            - list[Iterator[Auto3DResult] | RunError], if the run succeeded
            - The input RunError, if the input is an error
            - A RunError describing the problem, for any other input type
    """
    # Pass through errors and run IDs (the collect=False case) unchanged.
    if isinstance(res, (RunError, str)):
        return res

    # Handle the run output: one entry per input SMILES.
    if isinstance(res, list):
        return [
            RunError(res_i) if isinstance(res_i, str) else _iter_auto3d_results(res_i)
            for res_i in res
        ]

    # Anything else is unexpected; report it as an error rather than guessing.
    return RunError(
        f"Error: auto3d save_outputs received unexpected format: {type(res)}"
    )