feat(metrics): add metric collector and support tree model training (#969)

* feat(xgli): add metric collector and apply to tree model

* fix(metrics): fix format

* fix(metrics): fix format

* fix(metrics): clear log

* feat(metrics): add more apis

* fix(metrics): fix lint

* fix(metrics): fix comments

* fix(metrics): fix comments

* fix(metrics): fix spaceline

* fix(metrics): fix comments

* fix(metrics): fix lint

* fix(metrics): fix cluster name

* feat(metrics): support custom global label

* fix(metrics): fix format
lixiaoguang01 authored May 30, 2022
1 parent dece6ef commit bb5e6d4
Showing 4 changed files with 258 additions and 2 deletions.
246 changes: 246 additions & 0 deletions fedlearner/common/metric_collector.py
@@ -0,0 +1,246 @@
# Copyright 2022 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# coding: utf-8
import logging
from abc import ABC, abstractmethod

from os import environ
from threading import Lock
from typing import Optional, Union, Dict, Iterator

from opentelemetry import trace, _metrics as metrics
from opentelemetry._metrics.instrument import UpDownCounter
from opentelemetry._metrics.measurement import Measurement
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider, Span
from opentelemetry.sdk._metrics import MeterProvider
from opentelemetry.sdk._metrics.export import \
ConsoleMetricExporter, PeriodicExportingMetricReader
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter \
import OTLPSpanExporter
from opentelemetry.exporter.otlp.proto.grpc._metric_exporter \
import OTLPMetricExporter
from opentelemetry.sdk.trace.export import \
BatchSpanProcessor, ConsoleSpanExporter

_logger = logging.getLogger(__name__)


class AbstractCollector(ABC):

    @abstractmethod
    def emit_single_point(self,
                          name: str,
                          value: Union[int, float],
                          tags: Dict[str, str] = None):
        pass

    @abstractmethod
    def emit_timing(self,
                    name: str,
                    tags: Dict[str, str] = None):
        pass

    @abstractmethod
    def emit_counter(self,
                     name: str,
                     value: Union[int, float],
                     tags: Dict[str, str] = None):
        pass

    @abstractmethod
    def emit_store(self,
                   name: str,
                   value: Union[int, float],
                   tags: Dict[str, str] = None):
        pass


class StubCollector(AbstractCollector):

    class EmptyTrace(object):
        def __init__(self):
            pass

        def __enter__(self):
            pass

        def __exit__(self, *a):
            pass

    def emit_single_point(self,
                          name: str,
                          value: Union[int, float],
                          tags: Dict[str, str] = None):
        pass

    def emit_timing(self,
                    name: str,
                    tags: Dict[str, str] = None):
        return self.EmptyTrace()

    def emit_counter(self,
                     name: str,
                     value: Union[int, float],
                     tags: Dict[str, str] = None):
        pass

    def emit_store(self,
                   name: str,
                   value: Union[int, float],
                   tags: Dict[str, str] = None):
        pass


class MetricCollector(AbstractCollector):
    _DEFAULT_EXPORT_INTERVAL = 60000

    class Callback:

        def __init__(self) -> None:
            self._measurement_list = []

        def record(self, value: Union[int, float], tags: dict):
            self._measurement_list.append(
                Measurement(value=value, attributes=tags))

        def __iter__(self):
            return self

        def __next__(self):
            if len(self._measurement_list) == 0:
                raise StopIteration
            return self._measurement_list.pop(0)

        def __call__(self):
            return iter(self)

    def __init__(
        self,
        service_name: Optional[str] = None,
        export_interval_millis: Optional[float] = None,
        custom_service_label: Optional[dict] = None,
    ):
        if service_name is None:
            service_name = environ.get('METRIC_COLLECTOR_SERVICE_NAME',
                                       'default_metric_service')
        cluster_name = environ.get('CLOUDNATIVE_CLUSTER', 'default_cluster')
        if export_interval_millis is None:
            try:
                export_interval_millis = float(
                    environ.get('METRIC_COLLECTOR_EXPORT_INTERVAL_MILLIS',
                                self._DEFAULT_EXPORT_INTERVAL)
                )
            except ValueError:
                _logger.error(
                    'Invalid value for export interval, using default %s ms',
                    self._DEFAULT_EXPORT_INTERVAL)
                export_interval_millis = self._DEFAULT_EXPORT_INTERVAL

        # for example, 'http://apm-server-apm-server:8200'
        endpoint = environ.get('METRIC_COLLECTOR_EXPORT_ENDPOINT')
        if endpoint is not None:
            exporter = OTLPMetricExporter(endpoint=endpoint, insecure=True)
        else:
            exporter = ConsoleMetricExporter()

        reader = PeriodicExportingMetricReader(
            exporter=exporter,
            export_interval_millis=export_interval_millis)
        # dict.update() returns None, so build the attribute dict first.
        resource_attributes = {
            'service.name': service_name,
            'deployment.environment': cluster_name
        }
        if custom_service_label:
            resource_attributes.update(custom_service_label)
        resource = Resource.create(resource_attributes)
        self._meter_provider = MeterProvider(
            metric_readers=[reader],
            resource=resource
        )
        metrics.set_meter_provider(self._meter_provider)
        self._meter = metrics.get_meter_provider().get_meter(service_name)

        if endpoint is not None:
            exporter = OTLPSpanExporter(endpoint=endpoint)
        else:
            exporter = ConsoleSpanExporter()
        tracer_provider = TracerProvider(resource=resource)
        tracer_provider.add_span_processor(
            BatchSpanProcessor(exporter)
        )
        trace.set_tracer_provider(tracer_provider)
        self._tracer = trace.get_tracer_provider().get_tracer(service_name)

        self._lock = Lock()
        self._cache: \
            Dict[str, Union[UpDownCounter, MetricCollector.Callback]] = {}

    def emit_single_point(self,
                          name: str,
                          value: Union[int, float],
                          tags: Dict[str, str] = None):
        cb = self.Callback()
        self._meter.create_observable_gauge(
            name=f'values.{name}', callback=cb
        )
        cb.record(value=value, tags=tags)

    def emit_timing(self,
                    name: str,
                    tags: Dict[str, str] = None) -> Iterator[Span]:
        return self._tracer.start_as_current_span(name=name, attributes=tags)

    def emit_counter(self,
                     name: str,
                     value: Union[int, float],
                     tags: Dict[str, str] = None):
        if name not in self._cache:
            with self._lock:
                # Double check `self._cache` content.
                if name not in self._cache:
                    counter = self._meter.create_up_down_counter(
                        name=f'values.{name}'
                    )
                    self._cache[name] = counter
        assert isinstance(self._cache[name], UpDownCounter)
        self._cache[name].add(value, attributes=tags)

    def emit_store(self,
                   name: str,
                   value: Union[int, float],
                   tags: Dict[str, str] = None):
        if name not in self._cache:
            with self._lock:
                # Double check `self._cache` content.
                if name not in self._cache:
                    cb = self.Callback()
                    self._meter.create_observable_gauge(
                        name=f'values.{name}', callback=cb
                    )
                    self._cache[name] = cb
        assert isinstance(self._cache[name], self.Callback)
        self._cache[name].record(value=value, tags=tags)


enable = True
enable_env = environ.get('METRIC_COLLECTOR_ENABLE')
if enable_env is None:
    enable = False
elif enable_env.lower() in ['false', 'f']:
    enable = False

k8s_job_name = environ.get('APPLICATION_ID',
                           'default_k8s_job_name')
service_label = {'k8s_job_name': k8s_job_name}
metric_collector = MetricCollector(
    custom_service_label=service_label) if enable else StubCollector()
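The module creates a single metric_collector at import time, switching between the real collector and the no-op stub based on METRIC_COLLECTOR_ENABLE. Below is a minimal usage sketch, not part of the commit; the service name and metric names are illustrative, the endpoint is the example value from the comment above, and the environment variables must be set before the module is first imported because the collector is instantiated at import time.

# Usage sketch (illustrative only, not part of the commit).
from os import environ

# These must be set before the first import of the module.
environ['METRIC_COLLECTOR_ENABLE'] = 'true'
environ['METRIC_COLLECTOR_SERVICE_NAME'] = 'demo_service'  # hypothetical name
# Optional; without it metrics and spans fall back to the console exporters.
environ['METRIC_COLLECTOR_EXPORT_ENDPOINT'] = 'http://apm-server-apm-server:8200'

from fedlearner.common.metric_collector import metric_collector

# Counter: accumulates under the instrument name 'values.example.requests'.
metric_collector.emit_counter('example.requests', 1, tags={'phase': 'train'})

# Store: records the latest value through an observable gauge callback.
metric_collector.emit_store('example.auc', 0.87, tags={'iteration': '3'})

# Timing: wraps a block in a span; the stub returns an empty context manager,
# so the same call site works whether or not collection is enabled.
with metric_collector.emit_timing('example.fit', tags={'mode': 'train'}):
    pass  # timed work goes here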
2 changes: 1 addition & 1 deletion fedlearner/model/tree/loss.py
@@ -14,9 +14,9 @@

# coding: utf-8

import pandas as pd
import numpy as np
from scipy import special as sp_special
import pandas as pd


def _roc_auc_score(label, pred):
8 changes: 8 additions & 0 deletions fedlearner/model/tree/tree.py
@@ -24,6 +24,7 @@
import numpy as np
from google.protobuf import text_format
import tensorflow.compat.v1 as tf
from fedlearner.common.metric_collector import metric_collector
from fedlearner.model.tree.packing import GradHessPacker
from fedlearner.model.tree.loss import LogisticLoss, MSELoss
from fedlearner.model.crypto import paillier, fixed_point_number
@@ -1292,8 +1293,15 @@ def _write_training_log(self, filename, header, metrics, pred):

    def iter_metrics_handler(self, metrics, mode):
        for name, value in metrics.items():
            # TODO @lixiaoguang.01 old version, to be deleted
            emit_store(name=name, value=value,
                       tags={'iteration': len(self._trees), 'mode': mode})
            # new version
            metrics_name = f'model.{mode}.tree_vertical.{name}'
            metrics_label = {
                'iteration': len(self._trees)
            }
            metric_collector.emit_store(metrics_name, value, metrics_label)

    def fit(self,
            features,
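For reference, a sketch of what the new call records: emit_store() prefixes the instrument name with 'values.', so a training-mode 'auc' metric at iteration 5 lands on 'values.model.train.tree_vertical.auc' labelled with the iteration. The mode, metric names, and values below are examples only, not output of the commit.

# Illustration of the naming scheme used by iter_metrics_handler (example values only).
from fedlearner.common.metric_collector import metric_collector

mode = 'train'
iteration = 5
metrics = {'auc': 0.91, 'loss': 0.23}  # hypothetical per-iteration metrics

for name, value in metrics.items():
    # Records e.g. 'values.model.train.tree_vertical.auc' with an 'iteration' label.
    metric_collector.emit_store(f'model.{mode}.tree_vertical.{name}',
                                value, {'iteration': str(iteration)})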
4 changes: 3 additions & 1 deletion requirements.txt
@@ -30,4 +30,6 @@ matplotlib
flatten_dict
pyspark==3.0.2
pandas==1.1.5

opentelemetry-api==1.10.0
opentelemetry-sdk==1.10.0
opentelemetry-exporter-otlp==1.10.0
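The collector relies on the experimental _metrics API that ships with OpenTelemetry 1.10.0, which is why all three packages are pinned to that exact version. A small sanity-check sketch, assuming setuptools' pkg_resources is available in the environment:

# Sketch: confirm the pinned OpenTelemetry packages are actually installed.
import pkg_resources

for pkg in ('opentelemetry-api', 'opentelemetry-sdk', 'opentelemetry-exporter-otlp'):
    print(pkg, pkg_resources.get_distribution(pkg).version)  # expected: 1.10.0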
