Files
mlflow-dashboard/services/mlflow_service.py

81 lines
2.5 KiB
Python
Raw Permalink Normal View History

from typing import Dict, List, Optional
from mlflow.tracking import MlflowClient
from mlflow.entities import ViewType
from config import DEFAULT_TRACKING_URI
def _client(tracking_uri=None):
    # type: (Optional[str]) -> MlflowClient
    """Create an MlflowClient bound to *tracking_uri*.

    A ``None`` or empty *tracking_uri* falls back to ``DEFAULT_TRACKING_URI``.
    """
    effective_uri = tracking_uri if tracking_uri else DEFAULT_TRACKING_URI
    return MlflowClient(tracking_uri=effective_uri)


def _tracking_uri(tracking_uri=None):
    # type: (Optional[str]) -> str
    """Resolve the tracking URI to use.

    Returns *tracking_uri* unchanged when it is truthy; otherwise returns
    ``DEFAULT_TRACKING_URI``.
    """
    if tracking_uri:
        return tracking_uri
    return DEFAULT_TRACKING_URI


def _iter_all_pages(search, **kwargs):
    """Yield every item produced by a paginated ``MlflowClient.search_*`` call.

    ``search`` is called repeatedly with ``page_token`` plus ``kwargs``. Each
    call is expected to return an MLflow ``PagedList`` whose ``token`` is
    falsy once the last page has been returned.
    """
    page_token = None
    while True:
        page = search(page_token=page_token, **kwargs)
        for item in page:
            yield item
        page_token = page.token
        if not page_token:
            return


def get_experiments(tracking_uri=None):
    # type: (Optional[str]) -> List[Dict]
    """Summarise every active experiment on the tracking server.

    Args:
        tracking_uri: MLflow tracking URI. ``DEFAULT_TRACKING_URI`` is used
            when it is ``None`` or empty.

    Returns:
        One dict per active experiment with the keys ``experiment_id``,
        ``name``, ``lifecycle_stage`` and ``run_count``. ``run_count`` is the
        number of active runs in that experiment.

    ``search_experiments`` and ``search_runs`` each return a single page of
    at most ``max_results`` items (1000 by default). All pages are followed
    here; otherwise experiments beyond the first page would be dropped and
    ``run_count`` would silently saturate at the page size.
    """
    client = _client(tracking_uri)
    results = []
    for exp in _iter_all_pages(
        client.search_experiments, view_type=ViewType.ACTIVE_ONLY
    ):
        # MLflow has no "count runs" API, so the runs themselves are fetched
        # and counted. This costs one paginated search per experiment.
        run_count = sum(
            1
            for _ in _iter_all_pages(
                client.search_runs,
                experiment_ids=[exp.experiment_id],
                run_view_type=ViewType.ACTIVE_ONLY,
            )
        )
        results.append({
            "experiment_id": exp.experiment_id,
            "name": exp.name,
            "lifecycle_stage": exp.lifecycle_stage,
            "run_count": run_count,
        })
    return results


def get_runs(tracking_uri=None, experiment_id="0"):
    # type: (Optional[str], str) -> List[Dict]
    """List every active run of one experiment, most recently started first.

    Args:
        tracking_uri: MLflow tracking URI. ``DEFAULT_TRACKING_URI`` is used
            when it is ``None`` or empty.
        experiment_id: ID of the experiment whose runs are listed.

    Returns:
        One dict per run with the keys ``run_id``, ``run_name``,
        ``experiment_id``, ``status``, ``start_time`` and ``end_time``.
        The values are copied from ``run.info``.

    ``search_runs`` returns at most ``max_results`` runs (1000 by default)
    per call. The next-page token is therefore followed until it is
    exhausted, so large experiments are not silently truncated.
    """
    client = _client(tracking_uri)
    summaries = []
    page_token = None
    while True:
        page = client.search_runs(
            experiment_ids=[experiment_id],
            run_view_type=ViewType.ACTIVE_ONLY,
            order_by=["attributes.start_time DESC"],
            page_token=page_token,
        )
        summaries.extend(
            {
                "run_id": r.info.run_id,
                "run_name": r.info.run_name,
                "experiment_id": r.info.experiment_id,
                "status": r.info.status,
                "start_time": r.info.start_time,
                "end_time": r.info.end_time,
            }
            for r in page
        )
        page_token = page.token
        if not page_token:
            break
    return summaries


def get_run_detail(tracking_uri=None, run_id=""):
    # type: (Optional[str], str) -> Dict
    """Fetch one run and flatten its info and data into a plain dict.

    Args:
        tracking_uri: MLflow tracking URI. ``DEFAULT_TRACKING_URI`` is used
            when it is ``None`` or empty.
        run_id: ID of the run passed to ``MlflowClient.get_run``.

    Returns:
        A dict with the run's identity, status and timing fields from
        ``run.info``, plus ``params``, ``metrics`` and ``tags``. Each of the
        last three is copied into a fresh dict, or is ``{}`` when the source
        is empty or missing.
    """
    run = _client(tracking_uri).get_run(run_id)
    info = run.info
    data = run.data
    detail = {
        "run_id": info.run_id,
        "run_name": info.run_name,
        "experiment_id": info.experiment_id,
        "status": info.status,
        "start_time": info.start_time,
        "end_time": info.end_time,
    }
    for key, mapping in (
        ("params", data.params),
        ("metrics", data.metrics),
        ("tags", data.tags),
    ):
        # Copy so callers can mutate the result without touching the entity.
        detail[key] = dict(mapping or {})
    return detail


def get_mlflow_link(tracking_uri=None, run_id=""):
    # type: (Optional[str], str) -> str
    """Build a URL pointing at a run's page in the MLflow web UI.

    The run is fetched so its experiment ID can be embedded in the path.
    The URL base is the resolved tracking URI with any trailing slashes
    stripped.

    NOTE(review): this assumes the tracking URI is also the HTTP address of
    the MLflow UI. A ``file:``, database or ``databricks`` tracking URI would
    produce an unusable link. Confirm against the deployment.
    """
    run_info = _client(tracking_uri).get_run(run_id).info
    parts = {
        "base": _tracking_uri(tracking_uri).rstrip("/"),
        "exp": run_info.experiment_id,
        "run": run_info.run_id,
    }
    return "{base}/#/experiments/{exp}/runs/{run}".format(**parts)