"""Miscallaneous utility functions for the psyplot package."""
# SPDX-FileCopyrightText: 2016-2024 University of Lausanne
# SPDX-FileCopyrightText: 2020-2021 Helmholtz-Zentrum Geesthacht
# SPDX-FileCopyrightText: 2021-2024 Helmholtz-Zentrum hereon GmbH
#
# SPDX-License-Identifier: LGPL-3.0-only
import inspect
import re
import sys
from difflib import get_close_matches
from itertools import chain, filterfalse
import six
from psyplot.docstring import dedent, docstrings
[docs]
def get_default_value(func, arg):
argspec = inspect.getfullargspec(func)
return next(
default
for a, default in zip(reversed(argspec[0]), reversed(argspec.defaults))
if a == arg
)
[docs]
def isstring(s):
return isinstance(s, str)
[docs]
def plugin_entrypoints(group="psyplot", name="name"):
"""This utility function gets the entry points of the psyplot plugins"""
if sys.version_info[:2] > (3, 7):
from importlib.metadata import entry_points
try:
eps = entry_points(group=group, name=name)
except TypeError: # python<3.10
eps = [
ep for ep in entry_points().get(group, []) if ep.name == name
]
else:
from pkg_resources import iter_entry_points
eps = iter_entry_points(group=group, name=name)
return eps
[docs]
class Defaultdict(dict):
"""An ordered :class:`collections.defaultdict`
Taken from http://stackoverflow.com/a/6190500/562769"""
def __init__(self, default_factory=None, *a, **kw):
if default_factory is not None and not callable(default_factory):
raise TypeError("first argument must be callable")
dict.__init__(self, *a, **kw)
self.default_factory = default_factory
def __getitem__(self, key):
try:
return dict.__getitem__(self, key)
except KeyError:
return self.__missing__(key)
def __missing__(self, key):
if self.default_factory is None:
raise KeyError(key)
self[key] = value = self.default_factory()
return value
def __reduce__(self):
if self.default_factory is None:
args = tuple()
else:
args = (self.default_factory,)
return type(self), args, None, None, self.items()
[docs]
def copy(self):
"""Return a shallow copy of the dictionary"""
return self.__copy__()
def __copy__(self):
return type(self)(self.default_factory, self)
def __deepcopy__(self, memo):
import copy
return type(self)(self.default_factory, copy.deepcopy(self.items()))
def __repr__(self):
return "Defaultdict(%s, %s)" % (
self.default_factory,
dict.__repr__(self),
)
class _TempBool(object):
"""Wrapper around a boolean defining an __enter__ and __exit__ method
Notes
-----
If you want to use this class as an instance property, rather use the
:func:`_temp_bool_prop` because this class as a descriptor is ment to be a
class descriptor"""
#: default boolean value for the :attr:`value` attribute
default = False
#: boolean value indicating whether there shall be a validation or not
value = False
def __init__(self, default=False):
"""
Parameters
----------
default: bool
value of the object"""
self.default = default
self.value = default
self._entered = []
def __enter__(self):
self.value = not self.default
self._entered.append(1)
def __exit__(self, type, value, tb):
self._entered.pop(-1)
if not self._entered:
self.value = self.default
if six.PY2:
def __nonzero__(self):
return self.value
else:
def __bool__(self):
return self.value
def __repr__(self):
return repr(bool(self))
def __str__(self):
return str(bool(self))
def __call__(self, value=None):
"""
Parameters
----------
value: bool or None
If None, the current value will be negated. Otherwise the current
value of this instance is set to the given `value`"""
if value is None:
self.value = not self.value
else:
self.value = value
def __get__(self, instance, owner):
return self
def __set__(self, instance, value):
self.value = value
def _temp_bool_prop(propname, doc="", default=False):
"""Creates a property that uses the :class:`_TempBool` class
Parameters
----------
propname: str
The attribute name to use. The _TempBool instance will be stored in the
``'_' + propname`` attribute of the corresponding instance
doc: str
The documentation of the property
default: bool
The default value of the _TempBool class"""
def getx(self):
if getattr(self, "_" + propname, None) is not None:
return getattr(self, "_" + propname)
else:
setattr(self, "_" + propname, _TempBool(default))
return getattr(self, "_" + propname)
def setx(self, value):
getattr(self, propname).value = bool(value)
def delx(self):
getattr(self, propname).value = default
return property(getx, setx, delx, doc)
[docs]
def unique_everseen(iterable, key=None):
"""List unique elements, preserving order. Remember all elements ever seen.
Function taken from https://docs.python.org/2/library/itertools.html"""
# unique_everseen('AAAABBBCCDAABBB') --> A B C D
# unique_everseen('ABBCcAD', str.lower) --> A B C D
seen = set()
seen_add = seen.add
if key is None:
for element in filterfalse(seen.__contains__, iterable):
seen_add(element)
yield element
else:
for element in iterable:
k = key(element)
if k not in seen:
seen_add(k)
yield element
[docs]
def is_remote_url(path):
patt = re.compile(r"^https?\://")
if not isinstance(path, six.string_types):
return all(map(patt.search, (s or "" for s in path)))
return bool(re.search(r"^https?\://", path))
[docs]
@docstrings.get_sections(
base="check_key", sections=["Parameters", "Returns", "Raises"]
)
@dedent
def check_key(
key,
possible_keys,
raise_error=True,
name="formatoption keyword",
msg=("See show_fmtkeys function for possible formatopion " "keywords"),
*args,
**kwargs,
):
"""
Checks whether the key is in a list of possible keys
This function checks whether the given `key` is in `possible_keys` and if
not looks for similar sounding keys
Parameters
----------
key: str
Key to check
possible_keys: list of strings
a list of possible keys to use
raise_error: bool
If not True, a list of similar keys is returned
name: str
The name of the key that shall be used in the error message
msg: str
The additional message that shall be used if no close match to
key is found
*args, **kwargs
They are passed to the :func:`difflib.get_close_matches` function
(i.e. `n` to increase the number of returned similar keys and
`cutoff` to change the sensibility)
Returns
-------
str
The `key` if it is a valid string, else an empty string
list
A list of similar formatoption strings (if found)
str
An error message which includes
Raises
------
KeyError
If the key is not a valid formatoption and `raise_error` is True"""
if key not in possible_keys:
similarkeys = get_close_matches(key, possible_keys, *args, **kwargs)
if similarkeys:
msg = ("Unknown %s %s! Possible similiar " "frasings are %s.") % (
name,
key,
", ".join(similarkeys),
)
else:
msg = ("Unknown %s %s! ") % (name, key) + msg
if not raise_error:
return "", similarkeys, msg
raise KeyError(msg)
else:
return key, [key], ""
[docs]
def sort_kwargs(kwargs, *param_lists):
"""Function to sort keyword arguments and sort them into dictionaries
This function returns dictionaries that contain the keyword arguments
from `kwargs` corresponding given iterables in ``*params``
Parameters
----------
kwargs: dict
Original dictionary
``*param_lists``
iterables of strings, each standing for a possible key in kwargs
Returns
-------
list
len(params) + 1 dictionaries. Each dictionary contains the items of
`kwargs` corresponding to the specified list in ``*param_lists``. The
last dictionary contains the remaining items"""
return chain(
(
{key: kwargs.pop(key) for key in params.intersection(kwargs)}
for params in map(set, param_lists)
),
[kwargs],
)
[docs]
def hashable(val):
"""Test if `val` is hashable and if not, get it's string representation
Parameters
----------
val: object
Any (possibly not hashable) python object
Returns
-------
val or string
The given `val` if it is hashable or it's string representation"""
if val is None:
return val
try:
hash(val)
except TypeError:
return repr(val)
else:
return val
[docs]
@docstrings.get_sections(base="join_dicts")
def join_dicts(dicts, delimiter=None, keep_all=False):
"""Join multiple dictionaries into one
Parameters
----------
dicts: list of dict
A list of dictionaries
delimiter: str
The string that shall be used as the delimiter in case that there
are multiple values for one attribute in the arrays. If None, they
will be returned as sets
keep_all: bool
If True, all formatoptions are kept. Otherwise only the intersection
Returns
-------
dict
The combined dictionary"""
if not dicts:
return {}
if keep_all:
all_keys = set(chain(*(d.keys() for d in dicts)))
else:
all_keys = set(dicts[0])
for d in dicts[1:]:
all_keys.intersection_update(d)
ret = {}
for key in all_keys:
vals = {hashable(d.get(key, None)) for d in dicts} - {None}
if len(vals) == 1:
ret[key] = next(iter(vals))
elif delimiter is None:
ret[key] = vals
else:
ret[key] = delimiter.join(map(str, vals))
return ret
[docs]
def is_iterable(iterable):
"""Test if an object is iterable
Parameters
----------
iterable: object
The object to test
Returns
-------
bool
True, if the object is an iterable object"""
try:
iter(iterable)
except TypeError:
return False
else:
return True