# 74 lines, 2.0 KiB, Python
import os
import subprocess
import uuid
from typing import Dict, List

from fastapi import APIRouter, HTTPException

from schemas import ServeRequest, ServeStatus
|
|
|
|
# Router for the model-serving endpoints; mounted by the application elsewhere.
router = APIRouter()

# In-memory store for serving processes.
# Maps serve_id -> {"id", "model_uri", "port", "pid", "process"} where
# "process" is the live subprocess.Popen handle.  NOTE(review): being
# in-memory, entries are lost on restart while the child processes keep
# running — confirm whether that is acceptable for this service.
_serving_processes = {}  # type: Dict[str, Dict]
|
|
|
|
|
|
@router.post("/serve", response_model=ServeStatus)
def start_serve(req: ServeRequest):
    """Launch an ``mlflow models serve`` subprocess for the requested model.

    Generates a short unique id, spawns the MLflow CLI as a child process,
    records it in the in-memory registry, and returns its status.

    Raises:
        HTTPException: 500 if the ``mlflow`` executable is not on PATH.
    """
    serve_id = uuid.uuid4().hex[:8]
    cmd = [
        "mlflow", "models", "serve",
        "-m", req.model_uri,
        "-p", str(req.port),
        "--no-conda",
    ]

    # Forward the tracking URI via the environment: MLflow reads
    # MLFLOW_TRACKING_URI to resolve model URIs.  The previous code only
    # appended "--env-manager local" here, which never carried the URI to
    # the child at all.
    env = os.environ.copy()
    if req.tracking_uri:
        env["MLFLOW_TRACKING_URI"] = req.tracking_uri

    try:
        # DEVNULL rather than PIPE: nothing ever reads these streams, so a
        # chatty child could fill the OS pipe buffer and block forever.
        proc = subprocess.Popen(
            cmd,
            env=env,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
    except FileNotFoundError:
        raise HTTPException(status_code=500, detail="mlflow CLI not found")

    _serving_processes[serve_id] = {
        "id": serve_id,
        "model_uri": req.model_uri,
        "port": req.port,
        "pid": proc.pid,
        "process": proc,
    }

    return ServeStatus(
        id=serve_id,
        model_uri=req.model_uri,
        port=req.port,
        pid=proc.pid,
        status="running",
    )
|
|
|
|
|
|
@router.get("/serve", response_model=List[ServeStatus])
def list_serve():
    """Return the current status of every tracked serving process."""
    # A process is "running" while poll() reports no exit code yet.
    return [
        ServeStatus(
            id=serve_id,
            model_uri=entry["model_uri"],
            port=entry["port"],
            pid=entry["pid"],
            status="running" if entry["process"].poll() is None else "stopped",
        )
        for serve_id, entry in _serving_processes.items()
    ]
|
|
|
|
|
|
@router.delete("/serve/{serve_id}")
def stop_serve(serve_id: str):
    """Terminate a serving subprocess and drop it from the registry.

    Raises:
        HTTPException: 404 if no process with ``serve_id`` is tracked.
    """
    if serve_id not in _serving_processes:
        raise HTTPException(status_code=404, detail="Serve process not found")

    proc = _serving_processes[serve_id]["process"]
    if proc.poll() is None:
        proc.terminate()
        try:
            # Reap the child so it does not linger as a zombie; escalate to
            # SIGKILL if it ignores SIGTERM.  (Previously the code returned
            # right after terminate() without ever wait()ing.)
            proc.wait(timeout=5)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait()

    model_uri = _serving_processes.pop(serve_id)["model_uri"]
    return {"message": f"Stopped serving {model_uri}"}