Source code for dplutils.observer.mlflow
from dplutils.observer import Observer
try:
import mlflow
except ImportError:
mlflow = None
[docs]
class MlflowObserver(Observer):
"""Mlflow-based observer
MLflow is an ML experiment tracker with included explorer UI. See
https://github.com/mlflow/mlflow/ for details.
Args:
run: Existing mlflow run object to use for tracking. In this case it is
assumed the tracking_uri is that set globally.
experiment: Name of experiment under which to create run (if run not supplied).
tracking_uri: tracking uri, e.g. ``file://...`` or ``mlflow://``,
etc. See mlflow docs for details.
mlflow_kwargs: In case an existing
run is not supplied, one will be
created, in which case mlflow_kwargs (excluding ``experiment_id``)
will be passed to its instantiation, using
``mlflow.MlflowClient.create_run``.
"""
[docs]
def __init__(self, run=None, experiment=None, tracking_uri=None, **mlflow_kwargs):
if mlflow is None:
raise ImportError("mlflow must be installed to create observer run!")
tracking_uri = tracking_uri or mlflow.get_tracking_uri()
self.mlflow_client = mlflow.MlflowClient(tracking_uri=tracking_uri)
if run is not None:
self.run = run
else:
expid = None
if experiment is not None:
exp = self.mlflow_client.get_experiment_by_name(experiment)
if exp is not None:
expid = exp.experiment_id
else:
expid = self.mlflow_client.create_experiment(experiment)
self.run = self.mlflow_client.create_run(experiment_id=expid, **mlflow_kwargs)
self.run_id = self.run.info.run_id
self._countercache = {}
def observe(self, name, value, **kwargs):
self.mlflow_client.log_metric(self.run_id, name, value)
def increment(self, name, value=1, **kwargs):
val = self._countercache.get(name, 0) + value
self._countercache[name] = val
self.mlflow_client.log_metric(self.run_id, name, val)
def param(self, name, value, **kwargs):
self.mlflow_client.log_param(self.run_id, name, value)