import functools
import inspect
import logging
import os
import warnings
from collections.abc import Callable
from functools import wraps
from importlib import import_module
from inspect import isclass
from inspect import ismethod
from pkgutil import walk_packages
from types import MethodType
from typing import List
from typing import Set
from typing import Type
from typing import Union
from sklearn.base import BaseEstimator
from sklearn.utils.metaestimators import _IffHasAttrDescriptor
from sklearn_instrumentation.types import Estimator
logger = logging.getLogger(__name__)
[docs]def compose_decorators(decorators: List[Callable]) -> Callable:
"""Compose multiple decorators into one.
Helper function for combining multiple instrumentation decorators into one.
:param list(Callable) decorators: A list of instrumentation decorators to be
combined into a single decorator.
"""
def composed(estimator: Estimator, func: Callable, **dkwargs) -> Callable:
@wraps(func)
def wrapper(*args, **kwargs):
wrapped_func = func
for decorator in decorators:
wrapped_func = decorator(estimator, wrapped_func, **dkwargs)
return wrapped_func(*args, **kwargs)
return wrapper
return composed
[docs]def get_sklearn_estimator_from_method(func: Callable) -> BaseEstimator:
"""Get the estimator of a method or delegate.
Raises TypeError is the instance is not a BaseEstimator.
:param Callable func: A bound method or delegator function of a BaseEstimator
instance
:return: The BaseEstimator instance of the method or delegator.
"""
err = "Passed function is not a method or delegate of a BaseEstimator"
if ismethod(func):
obj = func.__self__
if isinstance(obj, BaseEstimator):
return func.__self__
else:
try:
for cell in func.__closure__:
obj = cell.cell_contents
if isinstance(obj, BaseEstimator):
return obj
except TypeError as exc:
raise TypeError(err)
raise TypeError(err)
def get_method_class_name(method: Callable) -> str:
if isinstance(method, property):
return method.fget.__qualname__.split(".")[0]
else:
return method.__qualname__.split(".")[0]
[docs]def get_method_class(estimator: Type[BaseEstimator], method_name: str) -> Type:
"""Get the class owner of the (possibly inherited) method."""
method = getattr(estimator, method_name)
method_class_name = get_method_class_name(method=method)
if estimator.__name__ == method_class_name:
return estimator
for class_ in estimator.mro():
if class_.__name__ == method_class_name:
return class_
raise AttributeError("Unable to determine method's class.")
[docs]def is_class_method(func: Callable) -> bool:
"""Indicate if the method belongs to a class (opposed to an instance)."""
if list(inspect.signature(func).parameters.keys())[0] == "self":
return True
[docs]def is_instance_method(func: Callable) -> bool:
"""Indicate if the method belongs to an instance of a class (opposed to the class)."""
return not is_class_method(func)
[docs]def get_descriptor(func: Callable) -> _IffHasAttrDescriptor:
"""Get the corresponding ``_IffHasAttrDescriptor``."""
for cell in func.__closure__:
obj = cell.cell_contents
if isinstance(obj, _IffHasAttrDescriptor):
return obj
[docs]def is_delegator(func: Callable) -> bool:
"""Indicate if the method is delegated using ``_IffHasAttrDescriptor``."""
try:
for cell in getattr(func, "__closure__", []):
obj = cell.cell_contents
if isinstance(obj, _IffHasAttrDescriptor):
return True
except TypeError as exc:
pass
return False
[docs]def method_is_inherited(estimator: Estimator, method: Callable) -> bool:
"""Indicate if the estimator's method is inherited from a parent class."""
method_class_name = get_method_class_name(method=method)
try:
estimator_class_name = estimator.__name__
except AttributeError:
estimator_class_name = estimator.__class__.__name__
return method_class_name != estimator_class_name
[docs]def has_instrumentation(
estimator: Union[BaseEstimator, Type[BaseEstimator]], method_name: str
) -> bool:
"""Indicate if the estimator's method is instrumented."""
method = getattr(estimator, method_name)
instr_attrib_name = f"_skli_{method_name}"
if hasattr(estimator, instr_attrib_name):
return True
if is_delegator(method):
descriptor = get_descriptor(method)
if hasattr(descriptor, instr_attrib_name):
return True
return False
[docs]def non_self_arg(func: Callable, args: tuple, idx: int):
"""Get the value of a corresponding arg index ignoring self for class methods."""
if is_class_method(func):
return args[idx + 1]
else:
return args[idx]
[docs]def get_arg_by_key(func: Callable, args: tuple, key: str):
"""Get the value of a corresponding arg name as found in a function's signature."""
keys = list(inspect.signature(func).parameters.keys())
idx = keys.index(key)
if is_delegator(func):
return args[idx - 1]
return args[idx]
[docs]def get_estimators_in_packages(
package_names: List[str],
) -> Set[Type[BaseEstimator]]:
"""Get all BaseEstimators from a list of packages.
:param list(str) package_names: a list of package names from which to get BaseEstimators
:return: A dictionary of fully qualified class names as keys and classes as values
"""
base_estimators = set()
for package_name in package_names:
base_estimators = base_estimators.union(
get_estimators_in_package(package_name=package_name)
)
return base_estimators
[docs]def get_estimators_in_package(
package_name: str = "sklearn",
) -> Set[Type[BaseEstimator]]:
"""Get all BaseEstimators from a package.
:param str package_name: a package name from which to get BaseEstimators
:return: A dictionary of fully qualified class names as keys and classes as values
"""
base_estimators = set()
package = import_module(package_name)
package_dir = os.path.dirname(package.__file__)
for (_, module_name, _) in walk_packages(
[package_dir], prefix=package.__name__ + "."
):
if "test" in module_name:
continue
try:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
module = import_module(module_name)
except ImportError:
logger.warning(f"Unable to import {package_name}.{module_name}")
continue
for module_attribute_name in dir(module):
module_attribute = getattr(module, module_attribute_name)
if isclass(module_attribute) and issubclass(
module_attribute, BaseEstimator
):
base_estimators.add(module_attribute)
return base_estimators
def get_name(estimator: Estimator, func: Union[Callable, MethodType]) -> str:
if isinstance(func, MethodType):
self_qname = f"{func.__self__.__class__.__qualname__}.{func.__name__}"
if self_qname == func.__qualname__:
name = func.__qualname__
else:
name = f"{self_qname} ({func.__qualname__})"
else:
module = inspect.getmodule(func)
cls = getattr(module, func.__qualname__.split(".")[0], None)
if isinstance(estimator, type):
obj_name = estimator.__qualname__
else:
obj_name = estimator.__class__.__name__
if cls:
cls_name = cls.__qualname__
else:
cls_name = None
if obj_name == cls_name or cls_name is None:
name = func.__qualname__
else:
name = f"{obj_name}.{func.__name__} ({cls_name}.{func.__name__})"
return name