from typing import Any, Callable, Dict, Iterator, List, Optional

from mlflow.tracking import MlflowClient
from mlflow.entities import ViewType

from config import DEFAULT_TRACKING_URI


def _tracking_uri(tracking_uri=None):
    # type: (Optional[str]) -> str
    """Return *tracking_uri*, falling back to ``DEFAULT_TRACKING_URI`` when it is falsy."""
    return tracking_uri or DEFAULT_TRACKING_URI


def _client(tracking_uri=None):
    # type: (Optional[str]) -> MlflowClient
    """Build an ``MlflowClient`` for *tracking_uri* (or the configured default)."""
    # Resolve through _tracking_uri so the client and get_mlflow_link always
    # agree on which server is being talked to.
    return MlflowClient(tracking_uri=_tracking_uri(tracking_uri))


def _iter_all_pages(search_fn, **kwargs):
    # type: (Callable[..., Any], **Any) -> Iterator[Any]
    """Yield every result of a paginated ``MlflowClient`` search call.

    MlflowClient search methods return a single page of at most
    ``max_results`` items together with a continuation ``token``; calling them
    once silently truncates large result sets. This keeps requesting pages
    until the store reports no further token.

    :param search_fn: A bound search method such as ``client.search_runs``. It
        must accept a ``page_token`` keyword and return an iterable page that
        exposes a ``.token`` attribute.
    :param kwargs: Forwarded unchanged to every call, so filters and ordering
        stay identical across pages.
    """
    page_token = None
    while True:
        page = search_fn(page_token=page_token, **kwargs)
        for item in page:
            yield item
        page_token = page.token
        if not page_token:
            return


def _run_summary(info):
    # type: (Any) -> Dict
    """Return the identifying and status fields of a run's ``RunInfo`` as a dict."""
    return {
        "run_id": info.run_id,
        "run_name": info.run_name,
        "experiment_id": info.experiment_id,
        "status": info.status,
        "start_time": info.start_time,
        "end_time": info.end_time,
    }


def get_experiments(tracking_uri=None):
    # type: (Optional[str]) -> List[Dict]
    """List all active experiments together with their active-run counts.

    :param tracking_uri: MLflow tracking URI; ``DEFAULT_TRACKING_URI`` if falsy.
    :returns: One dict per active experiment with keys ``experiment_id``,
        ``name``, ``lifecycle_stage`` and ``run_count``.
    """
    client = _client(tracking_uri)
    results = []
    for exp in _iter_all_pages(client.search_experiments, view_type=ViewType.ACTIVE_ONLY):
        # Follow every page: a single search_runs call is capped at
        # max_results, which made run_count saturate for large experiments.
        # NOTE(review): this still issues at least one search per experiment
        # and downloads each run just to count it.
        run_count = sum(
            1
            for _ in _iter_all_pages(
                client.search_runs,
                experiment_ids=[exp.experiment_id],
                run_view_type=ViewType.ACTIVE_ONLY,
            )
        )
        results.append({
            "experiment_id": exp.experiment_id,
            "name": exp.name,
            "lifecycle_stage": exp.lifecycle_stage,
            "run_count": run_count,
        })
    return results


def get_runs(tracking_uri=None, experiment_id="0"):
    # type: (Optional[str], str) -> List[Dict]
    """Return summaries of every active run in *experiment_id*, newest first.

    :param tracking_uri: MLflow tracking URI; ``DEFAULT_TRACKING_URI`` if falsy.
    :param experiment_id: Experiment to search; defaults to ``"0"``.
    :returns: One dict per run with keys ``run_id``, ``run_name``,
        ``experiment_id``, ``status``, ``start_time`` and ``end_time``, ordered
        by start time descending.
    """
    client = _client(tracking_uri)
    runs = _iter_all_pages(
        client.search_runs,
        experiment_ids=[experiment_id],
        run_view_type=ViewType.ACTIVE_ONLY,
        order_by=["attributes.start_time DESC"],
    )
    return [_run_summary(run.info) for run in runs]


def get_run_detail(tracking_uri=None, run_id=""):
    # type: (Optional[str], str) -> Dict
    """Return a run's summary fields plus its params, metrics and tags.

    :param tracking_uri: MLflow tracking URI; ``DEFAULT_TRACKING_URI`` if falsy.
    :param run_id: ID of the run to fetch.
    :returns: The same keys as ``get_runs`` entries, plus ``params``,
        ``metrics`` and ``tags`` as plain dicts (empty when absent).

    Errors raised by ``MlflowClient.get_run`` (e.g. for an unknown run) are
    not caught here and propagate to the caller.
    """
    run = _client(tracking_uri).get_run(run_id)
    detail = _run_summary(run.info)
    # `or {}` turns a missing/None mapping into an empty dict.
    detail["params"] = dict(run.data.params or {})
    detail["metrics"] = dict(run.data.metrics or {})
    detail["tags"] = dict(run.data.tags or {})
    return detail


def get_mlflow_link(tracking_uri=None, run_id=""):
    # type: (Optional[str], str) -> str
    """Build a browser URL for *run_id* in the MLflow UI.

    The run is fetched because the UI route needs its experiment id.

    :param tracking_uri: MLflow tracking URI; ``DEFAULT_TRACKING_URI`` if falsy.
        It is also used as the base of the returned URL.
    :param run_id: ID of the run to link to.
    :returns: ``<base>/#/experiments/<experiment_id>/runs/<run_id>``.
    """
    run = _client(tracking_uri).get_run(run_id)
    base_url = _tracking_uri(tracking_uri).rstrip("/")
    # NOTE(review): only a usable link when the tracking URI is the HTTP(S)
    # address of the MLflow server; file/database URIs yield a bogus URL.
    return "{base}/#/experiments/{exp}/runs/{run}".format(
        base=base_url,
        exp=run.info.experiment_id,
        run=run.info.run_id,
    )