# Source code for rush.objects

"""Object store references and helpers."""

import json
import tarfile
import uuid
from dataclasses import dataclass
from functools import singledispatch
from io import BytesIO
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Any, Literal, NewType, Self

import requests
import zstandard as zstd
from gql import FileVar, gql

from .convert import from_json
from .mol import TRC, Chains, Residues, Topology
from .session import _get_client, _get_config, _get_project_id

#: UUID identifying an object in the Rush object store. A ``NewType`` over
#: ``str``: plain strings are accepted at runtime, but type checkers treat
#: ``ObjectID`` as distinct so paths are not mixed with arbitrary strings.
ObjectID = NewType("ObjectID", str)


@dataclass(frozen=True)
class RushObject:
    """Reference to an object in the Rush object store.

    Wraps the ``path``/``size``/``format`` triple returned by the GraphQL
    API and provides download helpers (:meth:`fetch`, :meth:`save`).
    """

    #: UUID path in the object store.
    path: ObjectID
    #: Size in bytes.
    size: int
    #: Storage format.
    format: Literal["Json", "Bin"]

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> Self:
        """Construct from a raw GraphQL output dict.

        Requires ``path``, ``size``, and ``format`` keys.

        Raises:
            ValueError: if any of the required keys is missing.
        """
        try:
            return cls(
                path=ObjectID(d["path"]),
                size=d["size"],
                format=d["format"],
            )
        except KeyError as e:
            raise ValueError(
                f"RushObject dict missing required key {e}; got keys: {list(d.keys())}"
            ) from e

    def to_dict(self) -> dict[str, Any]:
        """Serialize back to the raw GraphQL ``path``/``size``/``format`` shape."""
        return {"path": str(self.path), "size": self.size, "format": self.format}

    def fetch(self, extract: bool = False) -> list[Any] | dict[str, Any] | bytes:
        """Download this object into memory.

        Dispatches on :attr:`format`: JSON objects are parsed, binary
        objects are returned as ``bytes`` (optionally *extract*-ed).
        """
        if self.format.lower() == "json":
            return self.fetch_json()
        return self.fetch_bytes(extract=extract)

    def fetch_json(self) -> list[Any] | dict[str, Any]:
        """Download this JSON object into memory.

        Raises:
            TypeError: if this object is not stored in ``Json`` format.
        """
        if self.format.lower() != "json":
            raise TypeError(f"{self.path} is not a Json object")
        return _fetch_json_object(self.path)

    def fetch_dict(self) -> dict[str, Any]:
        """Download this JSON object and require a dictionary payload."""
        if self.format.lower() != "json":
            raise TypeError(f"{self.path} is not a Json object")
        return _fetch_json_object_as_dict(self.path)

    def fetch_list(self) -> list[Any]:
        """Download this JSON object and require a list payload."""
        if self.format.lower() != "json":
            raise TypeError(f"{self.path} is not a Json object")
        return _fetch_json_object_as_list(self.path)

    def fetch_bytes(self, extract: bool = False) -> bytes:
        """Download this binary object into memory.

        Raises:
            TypeError: if this object is not stored in ``Bin`` format.
        """
        if self.format.lower() != "bin":
            raise TypeError(f"{self.path} is not a Bin object")
        return _fetch_bin_object(self.path, extract=extract)

    def save(
        self,
        filepath: Path | str | None = None,
        name: str | None = None,
        ext: str | None = None,
        extract: bool = False,
    ) -> Path:
        """Download this object and save to the workspace.

        The file type is derived from :attr:`format` automatically. Pass
        *ext* to override the file extension (e.g. ``"hdf5"``, ``"a3m"``).

        Args:
            filepath: Explicit destination path; mutually exclusive with *name*.
            name: Basename to use inside the project workspace directory.
            ext: Extension override; defaults to the lowercased format.
            extract: For binary objects, extract the archive before writing.

        Returns:
            The path the object was written to.

        Raises:
            ValueError: if both *filepath* and *name* are given.
        """
        if ext is None:
            ext = self.format.lower()
        workspace_project_dir = _get_config().workspace_dir / _get_project_id()
        if filepath is not None and name is None:
            filepath = Path(filepath) if isinstance(filepath, str) else filepath
        elif filepath is None and name is not None:
            filepath = workspace_project_dir / f"{name}.{ext}"
        elif filepath is None and name is None:
            filepath = workspace_project_dir / f"{self.path}.{ext}"
        else:
            # ValueError is more precise than a bare Exception and remains
            # backward-compatible (it is an Exception subclass).
            raise ValueError("Cannot specify both filepath and name")
        filepath.parent.mkdir(parents=True, exist_ok=True)
        if self.format.lower() == "json":
            d = self.fetch_json()
            with filepath.open("w") as f:
                json.dump(_clean_dict(d), f, indent=2)
        else:
            b = self.fetch_bytes(extract=extract)
            with filepath.open("wb") as f:
                f.write(b)
        return filepath
@dataclass(frozen=True)
class TRCPaths:
    """Workspace paths for a saved TRC triplet.

    Produced by :meth:`TRCRef.save`, one :class:`~pathlib.Path` per
    component file written to the workspace.
    """

    # Path to the saved topology file.
    topology: Path
    # Path to the saved residues file.
    residues: Path
    # Path to the saved chains file.
    chains: Path
@dataclass(frozen=True)
class TRCRef:
    """Reference to a single TRC triplet in the Rush object store."""

    #: Object-store reference to the topology component.
    topology: RushObject
    #: Object-store reference to the residues component.
    residues: RushObject
    #: Object-store reference to the chains component.
    chains: RushObject

    # The three component names, in constructor order.
    _PARTS = ("topology", "residues", "chains")

    @classmethod
    def upload(cls, trc: TRC) -> Self:
        """Upload each component of *trc* and return references to them."""
        refs = {
            part: RushObject.from_dict(upload_object(getattr(trc, part).to_dict()))
            for part in ("topology", "residues", "chains")
        }
        return cls(**refs)

    def fetch(self) -> TRC:
        """Download and parse into a TRC."""
        payload = {
            part: getattr(self, part).fetch_dict()
            for part in ("topology", "residues", "chains")
        }
        return from_json(payload)

    def save(self) -> TRCPaths:
        """Download and save to the workspace."""
        saved = [ref.save() for ref in (self.topology, self.residues, self.chains)]
        return TRCPaths(*saved)
def upload_object(input: Path | str | dict[str, Any]) -> dict[str, Any]:
    """
    Upload an object at the filepath to the current project.

    Usually not necessary; the module functions should handle this
    automatically.

    Args:
        input: A file path (``Path`` or ``str``) to upload directly, or a
            dict that is first serialized to a temporary JSON file.

    Returns:
        The ``object`` dict (``path``/``size``/``format``) from the
        ``upload_object`` GraphQL mutation.
    """
    mutation = gql("""
    mutation UploadObject($file: Upload!, $typeinfo: Json!, $format: ObjectFormatEnum!, $project_id: String) {
        upload_object(file: $file, typeinfo: $typeinfo, format: $format, project_id: $project_id) {
            id
            object {
                path
                size
                format
            }
            base_url
            url
        }
    }
    """)
    if isinstance(input, dict):
        # Serialize the dict to a temporary JSON file and recurse into the
        # file-upload path. delete=False is required so the file survives
        # until the recursive call reopens it; previously it was never
        # unlinked and leaked on every dict upload — clean it up here.
        t_f = NamedTemporaryFile(mode="w", suffix=".json", delete=False)
        try:
            json.dump(input, t_f)
            t_f.close()
            return upload_object(t_f.name)
        finally:
            t_f.close()  # no-op if already closed
            Path(t_f.name).unlink(missing_ok=True)
    filepath = Path(input) if isinstance(input, str) else input
    with filepath.open(mode="rb") as f:
        project_id = _get_project_id()
        if filepath.suffix == ".json":
            mutation.variable_values = {
                "file": FileVar(f),
                "format": "json",
                # Minimal typeinfo: an empty record.
                "typeinfo": {
                    "k": "record",
                    "t": {},
                },
                "project_id": project_id,
            }
        else:
            mutation.variable_values = {
                "file": FileVar(f),
                "format": "bin",
                # Typeinfo describing a binary Object record (size + path).
                "typeinfo": {
                    "k": "record",
                    "t": {
                        "size": "u32",
                        "path": {
                            "k": "@",
                            "t": "$Bytes",
                        },
                    },
                    "n": "Object",
                },
                "project_id": project_id,
            }
        result = _get_client().execute(mutation, upload_files=True)
    return result["upload_object"]["object"]
def _extract_object_archive(data: bytes) -> bytes:
    """Decompress a zstd-compressed tar archive and return one member's bytes.

    Selection rule: with two or more members the second one is used (the
    first is often metadata); with a single member, that one. If the chosen
    member is a directory, the first non-directory member is used instead.

    Raises:
        ValueError: if the archive is empty, contains only directories, or
            the chosen member cannot be extracted.
    """
    # Cap decompression at ~1 GB to bound memory use.
    decompressed = zstd.ZstdDecompressor().decompress(data, max_output_size=int(1e9))
    with tarfile.open(fileobj=BytesIO(decompressed)) as archive:
        names = archive.getnames()
        if not names:
            raise ValueError("Tar archive is empty - no files to extract")
        # Prefer the second entry when available (first is often metadata).
        chosen = names[1] if len(names) >= 2 else names[0]
        if archive.getmember(chosen).isdir():
            # Fall back to the first regular (non-directory) member.
            chosen = next(
                (n for n in names if not archive.getmember(n).isdir()),
                None,
            )
            if chosen is None:
                raise ValueError(
                    "Tar archive contains only directories, no files to extract"
                )
        handle = archive.extractfile(chosen)
        if handle is None:
            raise ValueError(
                f"Failed to extract file '{chosen}' from tar archive"
            )
        return handle.read()
def fetch_object(
    path: str, extract: bool = False
) -> list[Any] | dict[str, Any] | bytes:
    """
    Fetch the contents of the given Rush object store path directly into memory.

    Be careful: if the contents are too large, they might not fit into memory.

    Args:
        path: The Rush object store path to fetch.
        extract: Automatically extract tar.zst archives in memory before
            returning.

    Returns:
        the data from the object store path, as a dict (JSON objects) or
        bytes (bin objects)
    """
    # TODO: enforce UUID type
    query = gql("""
    query GetObject($path: String!) {
        object_path(path: $path) {
            url
            object {
                format
                size
            }
        }
    }
    """)
    query.variable_values = {"path": path}
    result = _get_client().execute(query)
    obj_descriptor = result["object_path"]
    # NOTE(review): the query above never selects a `contents` field, so this
    # branch only fires if the server returns it anyway — confirm against the
    # GraphQL schema whether `contents` should be requested explicitly.
    if "contents" in obj_descriptor:
        contents = obj_descriptor["contents"]
        if isinstance(contents, (list, dict)):
            return contents
        raise TypeError(f"Expected JSON contents for {path}, got {type(contents)}")
    # Bin objects are served indirectly via a download URL.
    if "url" in obj_descriptor:
        response = requests.get(obj_descriptor["url"])
        response.raise_for_status()
        data = response.content
        if extract:
            return _extract_object_archive(data)
        return data
    raise Exception(f"Object at path {path} has neither contents nor URL")
def _coerce_json(raw: list[Any] | dict[str, Any] | bytes) -> Any:
    """Decode str/bytes payloads with ``json.loads``; pass others through.

    Shared by the three ``_fetch_json_object*`` helpers below, which
    previously duplicated this decode step.
    """
    if isinstance(raw, (str, bytes, bytearray)):
        return json.loads(raw)
    return raw


def _fetch_json_object(path: str, extract: bool = False) -> list[Any] | dict[str, Any]:
    """Fetch *path* and require a JSON (list or dict) payload.

    Raises:
        TypeError: if the decoded payload is neither a list nor a dict.
    """
    raw = _coerce_json(fetch_object(path, extract))
    if not isinstance(raw, (list, dict)):
        raise TypeError(f"Expected JSON object for {path}, got {type(raw)}")
    return raw


def _fetch_json_object_as_list(path: str, extract: bool = False) -> list[Any]:
    """Fetch *path* and require a JSON list payload.

    Raises:
        TypeError: if the decoded payload is not a list.
    """
    raw = _coerce_json(fetch_object(path, extract))
    if not isinstance(raw, list):
        raise TypeError(f"Expected JSON list object for {path}, got {type(raw)}")
    return raw


def _fetch_json_object_as_dict(path: str, extract: bool = False) -> dict[str, Any]:
    """Fetch *path* and require a JSON dict payload.

    Raises:
        TypeError: if the decoded payload is not a dict.
    """
    raw = _coerce_json(fetch_object(path, extract))
    if not isinstance(raw, dict):
        raise TypeError(f"Expected JSON dict object for {path}, got {type(raw)}")
    return raw


def _fetch_bin_object(path: str, extract: bool = False) -> bytes:
    """Fetch *path* and require a binary payload, normalized to ``bytes``.

    Raises:
        TypeError: if the payload is not str/bytes/bytearray.
    """
    raw = fetch_object(path, extract)
    if not isinstance(raw, (str, bytes, bytearray)):
        raise TypeError(f"Expected Bin object for {path}, got {type(raw)}")
    if isinstance(raw, str):
        return raw.encode()
    return bytes(raw)
def save_object(
    path: str,
    filepath: Path | str | None = None,
    name: str | None = None,
    type: Literal["json", "bin"] | None = None,
    ext: str | None = None,
    extract: bool = False,
) -> Path:
    """Save a Rush object store path to the workspace.

    Prefer :meth:`RushObject.save` when you have a ``RushObject``. This
    function infers the format from the *type* parameter.
    """
    # Default to JSON when the caller does not specify a type.
    if type is None:
        type = "json"
    fmt: Literal["Json", "Bin"] = "Json" if type == "json" else "Bin"
    # size is unknown here; 0 is a placeholder — save() does not use it.
    ref = RushObject(path=ObjectID(path), size=0, format=fmt)
    return ref.save(filepath=filepath, name=name, ext=ext, extract=extract)
def _clean_dict(d: Any) -> Any: if isinstance(d, dict): return {k: _clean_dict(v) for k, v in d.items() if v is not None} if isinstance(d, list): return [_clean_dict(v) for v in d] return d def _json_content_name(prefix: str, d: dict[str, Any]) -> str: payload = json.dumps(_clean_dict(d), sort_keys=True, separators=(",", ":")) return f"{prefix}_{uuid.uuid5(uuid.NAMESPACE_OID, payload)}"
def save_json(
    d: dict[str, Any],
    filepath: Path | str | None = None,
    name: str | None = None,
) -> Path:
    """
    Save a JSON file into the workspace folder.

    Convenient for saving non-object JSON output from a module run
    alongside the object outputs. Exactly one of *filepath* or *name*
    must be given.
    """
    have_filepath = filepath is not None
    have_name = name is not None
    if have_filepath and not have_name:
        target = Path(filepath)
    elif have_name and not have_filepath:
        target = _get_config().workspace_dir / _get_project_id() / f"{name}.json"
    else:
        raise Exception("Must specify either filepath or name")
    target.parent.mkdir(parents=True, exist_ok=True)
    with target.open("w") as out:
        # Strip None-valued keys before writing, matching RushObject.save.
        json.dump(_clean_dict(d), out, indent=2)
    return target
def _make_vobj_converter(field: str, mol_cls: type) -> Any:
    """Build a singledispatch converter that produces a vobj dict for one
    TRC component (*field* is ``"topology"``, ``"residues"`` or ``"chains"``;
    *mol_cls* the matching molecule class).

    The returned dispatcher accepts: a ``TRC`` (uploads that component), a
    ``TRCRef`` (reuses the stored reference), a ``Path``/``str`` file path
    (uploads the file), a ``RushObject`` (serializes the reference), or a
    bare *mol_cls* instance (uploads its dict form). This replaces three
    near-identical hand-written dispatcher families.
    """

    @singledispatch
    def convert(item: Any) -> dict[str, Any]:
        raise NotImplementedError(
            f"Cannot convert {type(item)} to a {mol_cls.__name__} vobj!"
        )

    @convert.register
    def _(trc: TRC) -> dict[str, Any]:
        # Upload only the requested component of the in-memory TRC.
        return upload_object(getattr(trc, field).to_dict())

    @convert.register
    def _(trc_ref: TRCRef) -> dict[str, Any]:
        # Already in the object store: reuse the existing reference.
        return getattr(trc_ref, field).to_dict()

    @convert.register
    def _(path: Path | str) -> dict[str, Any]:
        return upload_object(path)

    @convert.register
    def _(obj: RushObject) -> dict[str, Any]:
        return obj.to_dict()

    @convert.register(mol_cls)
    def _(value: Any) -> dict[str, Any]:
        return upload_object(value.to_dict())

    return convert


_to_topology_vobj = _make_vobj_converter("topology", Topology)
_to_residues_vobj = _make_vobj_converter("residues", Residues)
_to_chains_vobj = _make_vobj_converter("chains", Chains)