81 lines
2.5 KiB
Python
81 lines
2.5 KiB
Python
from typing import Dict, List, Optional
|
|
from mlflow.tracking import MlflowClient
|
|
from mlflow.entities import ViewType
|
|
from config import DEFAULT_TRACKING_URI
|
|
|
|
|
|
def _client(tracking_uri=None):
    # type: (Optional[str]) -> MlflowClient
    """Create an ``MlflowClient`` bound to the resolved tracking URI.

    The URI is resolved through ``_tracking_uri`` so that the client and any
    URLs derived from the same argument (e.g. in ``get_mlflow_link``) always
    use the same fallback rule: a falsy *tracking_uri* (``None`` or ``""``)
    means ``DEFAULT_TRACKING_URI`` from ``config``.
    """
    return MlflowClient(tracking_uri=_tracking_uri(tracking_uri))
|
|
|
|
|
|
def _tracking_uri(tracking_uri=None):
|
|
# type: (Optional[str]) -> str
|
|
return tracking_uri or DEFAULT_TRACKING_URI
|
|
|
|
|
|
def _count_active_runs(client, experiment_id):
    # type: (MlflowClient, str) -> int
    """Return the number of active runs in *experiment_id*.

    ``search_runs`` returns a single page per call (bounded by its
    ``max_results`` default), so the ``page_token`` chain is followed until
    no further token is returned. Counting only the first page would cap the
    result at the page size.
    """
    count = 0
    page_token = None
    while True:
        page = client.search_runs(
            experiment_ids=[experiment_id],
            run_view_type=ViewType.ACTIVE_ONLY,
            page_token=page_token,
        )
        count += len(page)
        page_token = page.token
        if not page_token:
            return count


def get_experiments(tracking_uri=None):
    # type: (Optional[str]) -> List[Dict]
    """List active experiments together with their active-run counts.

    Args:
        tracking_uri: MLflow tracking URI; falls back to
            ``DEFAULT_TRACKING_URI`` when falsy.

    Returns:
        One dict per active experiment with keys ``experiment_id``, ``name``,
        ``lifecycle_stage`` and ``run_count``.

    Both experiments and runs are fetched page by page via ``page_token``
    so that neither list is silently truncated to the first page.
    """
    client = _client(tracking_uri)
    results = []
    page_token = None
    while True:
        page = client.search_experiments(
            view_type=ViewType.ACTIVE_ONLY,
            page_token=page_token,
        )
        for exp in page:
            results.append({
                "experiment_id": exp.experiment_id,
                "name": exp.name,
                "lifecycle_stage": exp.lifecycle_stage,
                "run_count": _count_active_runs(client, exp.experiment_id),
            })
        page_token = page.token
        if not page_token:
            return results
|
|
|
|
|
|
def get_runs(tracking_uri=None, experiment_id="0"):
    # type: (Optional[str], str) -> List[Dict]
    """List the active runs of one experiment, newest first.

    Args:
        tracking_uri: MLflow tracking URI; falls back to
            ``DEFAULT_TRACKING_URI`` when falsy.
        experiment_id: Experiment to list runs for (``"0"`` by default).

    Returns:
        One dict per run with keys ``run_id``, ``run_name``,
        ``experiment_id``, ``status``, ``start_time`` and ``end_time``,
        ordered by ``start_time`` descending.

    ``search_runs`` returns one page per call, so the ``page_token`` chain is
    followed to collect every run instead of only the first page. Pages keep
    the requested ``order_by`` ordering, so concatenating them preserves it.
    """
    client = _client(tracking_uri)
    runs = []
    page_token = None
    while True:
        page = client.search_runs(
            experiment_ids=[experiment_id],
            run_view_type=ViewType.ACTIVE_ONLY,
            order_by=["attributes.start_time DESC"],
            page_token=page_token,
        )
        runs.extend(page)
        page_token = page.token
        if not page_token:
            break

    return [
        {
            "run_id": r.info.run_id,
            "run_name": r.info.run_name,
            "experiment_id": r.info.experiment_id,
            "status": r.info.status,
            "start_time": r.info.start_time,
            "end_time": r.info.end_time,
        }
        for r in runs
    ]
|
|
|
|
|
|
def get_run_detail(tracking_uri=None, run_id=""):
    # type: (Optional[str], str) -> Dict
    """Fetch one run and flatten its info and data into a plain dict.

    Args:
        tracking_uri: MLflow tracking URI; falls back to
            ``DEFAULT_TRACKING_URI`` when falsy.
        run_id: ID of the run passed to ``MlflowClient.get_run``.

    Returns:
        A dict with the run's ``run_id``, ``run_name``, ``experiment_id``,
        ``status``, ``start_time`` and ``end_time``, plus ``params``,
        ``metrics`` and ``tags`` copied into new dicts (``{}`` when empty).
    """
    fetched = _client(tracking_uri).get_run(run_id)
    info = fetched.info
    data = fetched.data
    return {
        "run_id": info.run_id,
        "run_name": info.run_name,
        "experiment_id": info.experiment_id,
        "status": info.status,
        "start_time": info.start_time,
        "end_time": info.end_time,
        # Copy into fresh dicts; a missing/empty collection becomes {}.
        "params": dict(data.params or {}),
        "metrics": dict(data.metrics or {}),
        "tags": dict(data.tags or {}),
    }
|
|
|
|
|
|
def get_mlflow_link(tracking_uri=None, run_id=""):
    # type: (Optional[str], str) -> str
    """Build a URL to *run_id*'s page in the MLflow web UI.

    The run is fetched to obtain its experiment id. The resolved tracking
    URI, with trailing slashes stripped, is used as the UI base URL.

    NOTE(review): the result is only browsable when the tracking URI is the
    HTTP(S) address of an MLflow server; file or database URIs yield a
    non-navigable string. Confirm callers only pass server URLs.
    """
    info = _client(tracking_uri).get_run(run_id).info
    base = _tracking_uri(tracking_uri).rstrip("/")
    template = "{base}/#/experiments/{exp}/runs/{run}"
    return template.format(base=base, exp=info.experiment_id, run=info.run_id)
|