Source code for psy_view.plotmethods
"""Plotmethod widgets.
This module defines the widgets to interface with the mapplot, plot2d and
lineplot plotmethods.
"""
# SPDX-FileCopyrightText: 2020-2021 Helmholtz-Zentrum Geesthacht
# SPDX-FileCopyrightText: 2021-2024 Helmholtz-Zentrum hereon GmbH
#
# SPDX-License-Identifier: LGPL-3.0-only
from __future__ import annotations
import contextlib
import dataclasses
import textwrap
from enum import Enum
from functools import partial
from itertools import chain, cycle
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Dict,
Hashable,
Iterator,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
import matplotlib.colors as mcol
import psy_simple.widgets.colors as pswc
import psyplot.data as psyd
import xarray as xr
from psyplot.utils import unique_everseen
from psyplot_gui.common import get_icon as get_psy_icon
from PyQt5 import QtCore, QtGui, QtWidgets
import psy_view.dialogs as dialogs
import psy_view.utils as utils
from psy_view.rcsetup import rcParams
if TYPE_CHECKING:
from psyplot.data import InteractiveList
from psyplot.plotter import Plotter
from psyplot.project import Project
from xarray import DataArray, Dataset
T = TypeVar("T", bound="GridCell")
[docs]
class PlotType(str, Enum):
"""A value for the 2D ``plot`` formatoption.
See Also
--------
psy_simple.plotters.Simple2DPlotter.plot
psy_maps.plotters.FieldPlotter.plot
"""
mesh = "mesh"
contourf = "contourf"
contour = "contour"
poly = "poly"
#: A mapping from projection identifier to the text we want to use in the GUI
projection_map = {
"cf": "Default",
"cyl": "Cylindric",
"robin": "Robinson",
"ortho": "Orthographic",
"moll": "Mollweide",
"northpole": "Arctic (Northpole)",
"southpole": "Antarctic (Southpole)",
}
[docs]
@dataclasses.dataclass
class GridCell:
"""A grid cell within a QGridLayout managing one QWidget."""
#: the :class:`PyQt5.QtWidgets.QWidget` instance (i.e. the widget) or a
#: layout (:class:`PyQt5.QtWidgets.QLayout`)
qobject: Union[QtWidgets.QWidget, QtWidgets.QLayout]
#: The starting column for the widget. If None, it will be estimated based
#: on the other widgets in the row
column: Optional[int] = None
#: The number of columns to cover
colspan: int = 1
#: A boolean whether to add stretch or not
stretch: bool = False
[docs]
@classmethod
def from_alias(
cls: Type[T],
o: Union[QtWidgets.QWidget, QtWidgets.QLayout],
c: Optional[int] = None,
cs: int = 1,
s: bool = False,
) -> T:
"""Create a :class:`GridCell` from shorter kws.
Parameters
----------
o: QWidget or QLayout
The alias for :attr:`qobject`
c: int or None, optional
The alias for :attr:`column`
cs: int, optional
The alias for :attr:`colspan`
s: bool, optional
The alias for :attr:`stretch`
Returns
-------
GridCell
The generated instance of :class:`GridCell`
"""
return cls(qobject=o, column=c, colspan=cs, stretch=s)
[docs]
class PlotMethodWidget(QtWidgets.QWidget):
"""Base class for interfacing a psyplot plotmethod.
This method serves as a base class for interfacing any of the psyplot
plot methods registered via :func:`psyplot.project.register_plotter`.
The name of the plotmethod should be implemented as the :attr:`plotmethod`
attribute.
"""
plotmethod: ClassVar[str] = ""
#: trigger a replot of this widget. This can be emitted with the
#: :meth:`trigger_replot` method
replot = QtCore.pyqtSignal(str)
#: trigger a replot of this widget. This can be emitted with the
#: :meth:`trigger_reset` method
reset = QtCore.pyqtSignal(str)
#: signalize that the widget has been changed but not plot changes are
#: needed
changed = QtCore.pyqtSignal(str)
array_info = None
layout: QtWidgets.QGridLayout = None
def __init__(
self, get_sp: Callable[[], Optional[Project]], ds: Optional[Dataset]
):
super().__init__()
self._get_sp = get_sp
self.setup()
if hasattr(self, "layout"):
self.setLayout(self.layout)
self.refresh(ds)
[docs]
def setup(self) -> None:
"""Set up the widget during initialization."""
self.layout = QtWidgets.QGridLayout()
self.setup_widgets()
self.setup_widget_grid()
@property
def formatoption_rows(self) -> List[List[GridCell]]:
"""Get a mapping from row name to a row of :class:`GridCells`."""
rows: List[List[GridCell]] = []
for func in self.fmt_setup_functions:
rows.extend(self.get_rows(func))
return rows
[docs]
def get_rows(self, func: Callable) -> List[List[GridCell]]:
"""Get the rows of the formatoption widgets.
This method should take callable from the :attr:`fmt_setup_functions`
list and return the rows corresponding to :attr:`formatoption_rows`.
"""
return [[]]
@property
def fmt_setup_functions(self) -> List[Callable]:
"""Get a list of rows for formatoptions.
This property returns a list of callable. Each callable should setup
a horizonal (or widget) that is added to the :attr:`layout` vbox.
"""
return []
[docs]
def setup_widgets(self) -> None:
"""Set up the widgets for this plotmethod."""
for func in self.fmt_setup_functions:
func()
[docs]
def setup_separation_line(self) -> None:
"""Just a convenience function to create a separation line.
This method does nothing but tells :meth:`get_rows` to add an instance
of :class:`QHline`.
"""
pass
[docs]
def setup_widget_grid(self) -> None:
"""Setup the widget grid based on :attr:`formatoption_rows`."""
rows = self.formatoption_rows
layout = self.layout
for i, row in enumerate(rows):
col: int = 0
for gc in row:
col = gc.column if gc.column is not None else col
if isinstance(gc.qobject, QtWidgets.QLayout):
layout.addLayout(gc.qobject, i, col, 1, gc.colspan)
else:
layout.addWidget(gc.qobject, i, col, 1, gc.colspan)
col += gc.colspan
layout.setRowStretch(len(rows), 1)
@property
def sp(self) -> Optional[Project]:
"""Get the subproject of this plotmethod interface."""
return getattr(self._get_sp(), self.plotmethod, None)
@property
def data(self) -> Union[DataArray, InteractiveList]:
"""Get the data of this plotmethod interface."""
if self.sp is None:
raise ValueError("No plot has yet been initialized")
else:
return self.sp[0]
@property
def plotter(self) -> Optional[Plotter]:
"""Get the first plotter of the :attr:`sp` project."""
if self.sp and self.sp.plotters:
return self.sp.plotters[0]
else:
return None
@property
def formatoptions(self) -> List[str]:
"""Get the formatoption keys of this plotmethod."""
if self.plotter is not None:
return list(self.plotter)
else:
import psyplot.project as psy
return list(getattr(psy.plot, self.plotmethod).plotter_cls())
[docs]
def get_fmts(
self, var: DataArray, init: bool = False
) -> Dict[Union[Hashable, str, Any], Any]:
"""Get the formatoptions for a new plot.
Parameters
----------
var: xarray.Variable
The variable in the base dataset
init: bool
If True, call the :meth:`init_dims` method to inject necessary
formatoptions and dimensions for the initialization.
Returns
-------
dict
A mapping from formatoption or dimension to the corresponding value
for the plotmethod.
"""
ret = {}
if init:
ret.update(self.init_dims(var))
return ret
[docs]
def init_dims(
self, var: DataArray
) -> Dict[Union[Hashable, str, Any], Any]:
"""Get the formatoptions for a new plot.
Parameters
----------
var: xarray.Variable
The variable in the base dataset
Returns
-------
dict
A mapping from formatoption or dimension to the corresponding value
for the plotmethod.
"""
return {}
[docs]
def refresh(self, ds: Optional[Dataset]) -> None:
"""Refresh this widget from the given dataset."""
self.setEnabled(bool(self.sp))
[docs]
def trigger_replot(self) -> None:
"""Emit the :attr:`replot` signal to replot the project."""
self.replot.emit(self.plotmethod)
[docs]
def trigger_reset(self):
"""Emit the :attr:`reset` signal to reinitialize the project."""
self.array_info = self.sp.array_info(standardize_dims=False)[
self.sp[0].psy.arr_name
]
self.reset.emit(self.plotmethod)
[docs]
def trigger_refresh(self):
"""Emit the :attr:`changed` signal to notify changes in the plot."""
self.changed.emit(self.plotmethod)
[docs]
def get_slice(
self, x: float, y: float
) -> Optional[Dict[Hashable, Union[int, slice]]]:
"""Get the slice for the selected coordinates.
This method is called when the user clicks on the coordinates in the
plot.
See Also
--------
psy_view.ds_widget.DatasetWidget.display_line
Notes
-----
This is reimplemented in the :class:`MapPlotWidget`.
"""
return None
[docs]
def valid_variables(self, ds: Dataset) -> List[Hashable]:
"""Get a list of variables that can be visualized with this plotmethod.
Parameters
----------
ds: xarray.Dataset
The dataset to use
Returns
-------
list of str
List of variable names to plot
"""
ret = []
plotmethod = getattr(ds.psy.plot, self.plotmethod)
for v in list(ds):
init_kws = self.init_dims(ds[v]) # type: ignore
dims = init_kws.get("dims", {})
decoder = init_kws.get("decoder")
if plotmethod.check_data(ds, v, dims, decoder)[0][0]:
ret.append(v)
return ret
[docs]
class QHLine(QtWidgets.QFrame):
"""A horizontal seperation line."""
def __init__(self):
super().__init__()
self.setMinimumWidth(1)
self.setFixedHeight(20)
self.setFrameShape(QtWidgets.QFrame.HLine)
self.setFrameShadow(QtWidgets.QFrame.Sunken)
self.setSizePolicy(
QtWidgets.QSizePolicy.Preferred, QtWidgets.QSizePolicy.Minimum
)
[docs]
class MapPlotWidget(PlotMethodWidget):
"""A widget to control the mapplot plotmethod."""
plotmethod = "mapplot"
@property
def sp(self) -> Optional[Project]:
sp = super().sp
if sp:
arrays: List[str] = [
data.psy.arr_name
for data in sp
if not isinstance(data, psyd.InteractiveList)
]
return sp(arr_name=arrays)
return sp
[docs]
def get_rows(self, func: Callable) -> List[List[GridCell]]:
"""Get the rows of the formatoption widgets.
This method should take callable from the :attr:`fmt_setup_functions`
list and return the rows corresponding to :attr:`formatoption_rows`.
"""
if func == self.setup_color_buttons:
row = [
GridCell(QtWidgets.QLabel("Colormap")),
GridCell(self.btn_cmap),
GridCell(self.btn_cmap_settings, stretch=True),
]
elif func == self.setup_plot_buttons:
row = [
GridCell(QtWidgets.QLabel("Plot type")),
GridCell(self.combo_plot),
GridCell(self.btn_datagrid, stretch=True),
]
elif func == self.setup_projection_buttons:
row = [
GridCell(QtWidgets.QLabel("Projection")),
GridCell(self.btn_proj),
GridCell(self.btn_proj_settings, stretch=True),
]
elif func == self.setup_labels_button:
row = [GridCell(self.btn_labels, colspan=3)]
elif func == self.setup_separation_line:
row = [GridCell(QHLine(), colspan=3)]
elif func == self.setup_dimension_box:
row = [GridCell(self.dimension_box, colspan=3)]
elif func == self.setup_fix_extent_buttons:
hbox = QtWidgets.QHBoxLayout()
hbox.addWidget(self.btn_fix_extent)
hbox.addWidget(self.btn_fix_lonlatbox)
hbox.addWidget(self.btn_reset_extent)
row = [
GridCell(QtWidgets.QLabel("Viewport")),
GridCell(hbox, colspan=3),
]
else:
raise ValueError(f"Unknown function {func}")
return [row]
@property
def fmt_setup_functions(self) -> List[Callable]:
"""Get a list of rows for formatoptions.
This property returns a list of callable. Each callable should setup
a horizonal (or widget) that is added to the :attr:`layout` vbox.
"""
return [
self.setup_color_buttons,
self.setup_plot_buttons,
self.setup_projection_buttons,
self.setup_labels_button,
self.setup_separation_line,
self.setup_fix_extent_buttons,
self.setup_separation_line,
self.setup_dimension_box,
]
[docs]
def setup_labels_button(self) -> None:
"""Add a button to modify the text labels."""
self.btn_labels = utils.add_pushbutton(
"Edit labels",
self.edit_labels,
"Edit title, colorbar labels, etc.",
)
[docs]
def setup_fix_extent_buttons(self):
"""Setup the buttons for fixing the extent or data"""
self.btn_fix_extent = utils.add_pushbutton(
"Fix view",
self.fix_extent,
(
"Fix the current view such that it is not update on redrawing "
"the plot."
),
)
self.btn_fix_lonlatbox = utils.add_pushbutton(
"Restrict data",
self.fix_data,
(
"Extract the data for the plot, such that it matches the "
"current view."
),
)
self.btn_reset_extent = utils.add_pushbutton(
"Reset",
self.reset_extent_and_lonlatbox,
(
"Reset the view and the data selection to the default "
"(i.e. to the full data)."
),
)
[docs]
def fix_extent(self):
"""Call the fix_map_extent method of all map plotters."""
for plotter in self.sp.plotters:
if hasattr(plotter, "fix_map_extent"):
plotter.fix_map_extent()
[docs]
def fix_data(self):
"""Call the fix_lonlatbox method of all map plotters."""
for plotter in self.sp.plotters:
if hasattr(plotter, "fix_lonlatbox"):
plotter.fix_lonlatbox()
[docs]
def reset_extent_and_lonlatbox(self):
"""Reset the extent and the lonlatbox to the defaults."""
for plotter in self.sp.plotters:
plotter.update(
lonlatbox=plotter.lonlatbox.default,
map_extent=plotter.map_extent.default,
force=True,
)
[docs]
def setup_plot_buttons(self) -> None:
"""Setup the second row of formatoption widgets."""
self.combo_plot = QtWidgets.QComboBox()
self.plot_types: List[Optional[PlotType]] = [
PlotType.mesh,
PlotType.contourf,
PlotType.contour,
PlotType.poly,
None,
]
self.combo_plot.setEditable(False)
self.combo_plot.addItems(
[
"Default",
"Filled contours",
"Contours",
"Gridcell polygons",
"Disable",
]
)
self.combo_plot.currentIndexChanged.connect(self._set_plot_type)
self.btn_datagrid = utils.add_pushbutton(
"Gridcell boundaries",
self.toggle_datagrid,
"Toggle the visibility of grid cell boundaries",
)
self.btn_datagrid.setCheckable(True)
return
[docs]
def setup_color_buttons(self) -> None:
"""Set up the buttons to change the colormap, etc."""
self.btn_cmap = pswc.CmapButton()
self.btn_cmap.setSizePolicy(
QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding
)
self.btn_cmap.setToolTip("Select a different colormap")
self.btn_cmap.colormap_changed.connect(self.set_cmap)
self.btn_cmap.colormap_changed[mcol.Colormap].connect(self.set_cmap)
self.setup_cmap_menu()
self.btn_cmap_settings = utils.add_pushbutton(
utils.get_icon("color_settings"),
self.edit_color_settings,
"Edit color settings",
icon=True,
)
[docs]
def open_cmap_dialog(self, N: int = 10) -> None:
"""Open the dialog to change the colormap.
Parameters
----------
N: int
The number of colormaps to show
See Also
--------
psy_simple.widgets.colors.CmapButton
"""
if self.plotter:
N = self.plotter.plot.mappable.get_cmap().N
else:
N = 10
self.btn_cmap.open_cmap_dialog(N)
[docs]
def get_projection_label(self, proj: str) -> str:
"""Get the label for a projection in the GUI.
Parameters
----------
proj: str
The projection value for the
:attr:`~psy_maps.plotters.FieldPlotter.plot` formatoption
Returns
-------
str
The label of the projection in the GUI
See Also
--------
get_projection_label
projection_map
psy_maps.plotters.FieldPlotter.plot
"""
return projection_map.get(proj, proj)
[docs]
def get_projection_value(self, label: str) -> str:
"""Get the value for the `projection` formatoption.
This method is the inverse of :meth:`get_projection_label`.
Parameters
----------
label: str
The projection label that is used in the GUI
:attr:`~psy_maps.plotters.FieldPlotter.plot` formatoption
Returns
-------
str
The value to use for the ``projection`` formatoption
See Also
--------
get_projection_label
projection_map
psy_maps.plotters.FieldPlotter.plot
"""
inv_map = {lbl: proj for proj, lbl in projection_map.items()}
return inv_map.get(label, rcParams["projections"][0])
[docs]
def setup_projection_buttons(self) -> None:
"""Set up the buttons to modify the basemap."""
self.btn_proj = utils.add_pushbutton(
self.get_projection_label(rcParams["projections"][0]),
self.choose_next_projection,
"Change the basemap projection",
toolbutton=True,
)
self.btn_proj.setMenu(self.setup_projection_menu())
self.btn_proj.setSizePolicy(
QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding
)
self.btn_proj.setPopupMode(QtWidgets.QToolButton.MenuButtonPopup)
self.btn_proj_settings = utils.add_pushbutton(
utils.get_icon("proj_settings"),
self.edit_basemap_settings,
"Edit basemap settings",
icon=True,
)
[docs]
def setup_dimension_box(self) -> None:
"""Set up a box to control, what is the x and y-dimension."""
self.dimension_box = QtWidgets.QGridLayout()
self.dimension_box.addWidget(QtWidgets.QLabel("x-Dimension:"), 0, 0)
self.combo_xdim = QtWidgets.QComboBox()
self.dimension_box.addWidget(self.combo_xdim, 0, 1)
self.dimension_box.addWidget(QtWidgets.QLabel("y-Dimension:"), 0, 2)
self.combo_ydim = QtWidgets.QComboBox()
self.dimension_box.addWidget(self.combo_ydim, 0, 3)
self.dimension_box.addWidget(QtWidgets.QLabel("x-Coordinate:"), 1, 0)
self.combo_xcoord = QtWidgets.QComboBox()
self.dimension_box.addWidget(self.combo_xcoord, 1, 1)
self.dimension_box.addWidget(QtWidgets.QLabel("y-Coordinate:"), 1, 2)
self.combo_ycoord = QtWidgets.QComboBox()
self.dimension_box.addWidget(self.combo_ycoord, 1, 3)
self.combo_xdim.currentTextChanged.connect(self.set_xcoord)
self.combo_ydim.currentTextChanged.connect(self.set_ycoord)
for combo in self.coord_combos:
combo.currentIndexChanged.connect(self.trigger_refresh)
[docs]
def set_xcoord(self, text: str) -> None:
"""Set the name of the x-coordinate."""
self.set_combo_text(self.combo_xcoord, text)
[docs]
def set_ycoord(self, text: str) -> None:
"""Set the name of the y-coordinate."""
self.set_combo_text(self.combo_ycoord, text)
[docs]
def set_combo_text(self, combo: QtWidgets.QComboBox, text: str) -> None:
"""Convenience function to update set the current text of a combobox.
Parameters
----------
combo: PyQt5.QtWidgets.QComboBox
The combobox to modify
text: str
The item to use"""
items = list(map(combo.itemText, range(combo.count())))
if text in items:
combo.setCurrentIndex(items.index(text))
[docs]
def init_dims(
self, var: DataArray
) -> Dict[Union[Hashable, str, Any], Any]:
"""Get the formatoptions for a new plot.
This method updates the coordinates combo boxes with the
x- and y-coordinate of the variable.
Parameters
----------
var: xarray.Variable
The variable in the base dataset
Returns
-------
dict
A mapping from formatoption or dimension to the corresponding value
for the plotmethod.
"""
ret = super().init_dims(var)
dims: Dict[Hashable, Union[int, slice]] = {}
xdim = ydim = None
if self.combo_xdim.currentIndex():
xdim = self.combo_xdim.currentText()
if xdim in var.dims:
dims[xdim] = slice(None)
if self.combo_ydim.currentIndex():
ydim = self.combo_ydim.currentText()
if ydim in var.dims:
dims[ydim] = slice(None)
if dims:
missing = [dim for dim in var.dims if dim not in dims]
for dim in missing:
dims[dim] = 0
if len(dims) == 1 and xdim != ydim:
if xdim is None:
xdim = missing[-1]
else:
ydim = missing[-1]
dims[missing[-1]] = slice(None) # keep the last dimension
ret["dims"] = dims
if self.combo_xcoord.currentIndex():
xcoord = self.combo_xcoord.currentText()
ret["decoder"] = {"x": {xcoord}}
if self.combo_ycoord.currentIndex():
ycoord = self.combo_ycoord.currentText()
ret.setdefault("decoder", {})
ret["decoder"]["y"] = {ycoord}
if (
xdim is not None
and xdim in var.dims
and ydim is not None
and ydim in var.dims
):
ret["transpose"] = var.dims.index(xdim) < var.dims.index(ydim)
return ret
[docs]
def valid_variables(self, ds: Dataset) -> List[Hashable]:
"""Get a list of variables that can be visualized with this plotmethod.
Parameters
----------
ds: xarray.Dataset
The dataset to use
Returns
-------
list of str
List of variable names to plot
"""
valid = super().valid_variables(ds)
if not any(combo.count() for combo in self.coord_combos) or not any(
combo.currentIndex() for combo in self.coord_combos
):
return valid
if self.combo_xdim.currentIndex():
xdim = self.combo_xdim.currentText()
valid = [v for v in valid if xdim in ds[v].dims]
if self.combo_ydim.currentIndex():
ydim = self.combo_xdim.currentText()
valid = [v for v in valid if ydim in ds[v].dims]
if self.combo_xcoord.currentIndex():
xc_dims = set(ds[self.combo_xcoord.currentText()].dims)
valid = [v for v in valid if xc_dims.intersection(ds[v].dims)]
if self.combo_ycoord.currentIndex():
yc_dims = set(ds[self.combo_ycoord.currentText()].dims)
valid = [v for v in valid if yc_dims.intersection(ds[v].dims)]
return valid
@property
def coord_combos(self) -> List[QtWidgets.QComboBox]:
"""Get the combo boxes for x- and y-dimension and -coordinates."""
return [
self.combo_xdim,
self.combo_ydim,
self.combo_xcoord,
self.combo_ycoord,
]
[docs]
@contextlib.contextmanager
def block_combos(self) -> Iterator[None]:
"""Temporarilly block any signal of the :attr:`coord_combos`."""
for combo in self.coord_combos:
combo.blockSignals(True)
yield
for combo in self.coord_combos:
combo.blockSignals(False)
[docs]
def setEnabled(self, b: bool) -> None:
"""Enable or disable the projection and color buttons.
Parameters
----------
b: bool
If True, enable the buttons, else disable.
"""
self.btn_proj_settings.setEnabled(b)
self.proj_settings_action.setEnabled(b)
self.btn_datagrid.setEnabled(b)
self.color_settings_action.setEnabled(b)
self.btn_cmap_settings.setEnabled(b)
self.btn_labels.setEnabled(b)
self.btn_fix_extent.setEnabled(b)
self.btn_fix_lonlatbox.setEnabled(b)
self.btn_reset_extent.setEnabled(b)
[docs]
def set_cmap(self, cmap: str) -> None:
"""Update the plotter with the given colormap.
Parameters
----------
cmap: str
The colormap name.
"""
plotter = self.plotter
if plotter and "cmap" in plotter:
plotter.update(cmap=cmap)
[docs]
def toggle_datagrid(self) -> None:
"""Toggle the visibility of the grid cell boundaries."""
if self.plotter:
if self.btn_datagrid.isChecked():
self.plotter.update(datagrid="k-")
else:
self.plotter.update(datagrid=None)
[docs]
def edit_labels(self) -> None:
"""Open the dialog to edit the text labels in the plot."""
dialogs.LabelDialog.update_project(
self.sp, "figtitle", "title", "clabel"
)
[docs]
def edit_color_settings(self) -> None:
"""Open the dialog to edit the color settings."""
if self.sp and self.plotter:
dialogs.CmapDialog.update_project(self.sp)
if isinstance(self.plotter.cmap.value, str):
self.btn_cmap.setText(self.plotter.cmap.value)
else:
self.btn_cmap.setText("Custom")
[docs]
def choose_next_projection(self) -> None:
"""Choose the next projection from the rcParams `projection` value."""
select = False
nprojections = len(rcParams["projections"])
current = self.get_projection_value(self.btn_proj.text())
for i, proj in enumerate(cycle(rcParams["projections"])):
if proj == current:
select = True
elif select or i == nprojections:
break
self.set_projection(proj)
[docs]
def set_projection(self, proj: str) -> None:
"""Update the projection of the plot with the given projection.
Parameters
----------
projection: str
The projection name for the
:attr:`~psy_maps.plotters.FieldPlotter.projection` formatoption.
"""
self.btn_proj.setText(self.get_projection_label(proj))
plotter = self.plotter
if plotter and "projection" in plotter:
plotter.update(projection=proj)
def _set_plot_type(self, i: int) -> None:
"""Set the plot type from the index in :attr:`combo_plot`."""
self.set_plot_type(self.plot_types[i])
[docs]
def set_plot_type(self, plot_type: Optional[PlotType]) -> None:
"""Update the value for the ``plot`` formatoption
Parameters
----------
ploy_type: :class:`PlotType` or None
The value for the :attr:`~psy_simple.plotters.Simple2DPlotter.plot`
formatoption.
"""
plotter = self.plotter
if plotter:
plotter.update(plot=plot_type)
[docs]
def edit_basemap_settings(self) -> None:
"""Open a dialog to edit the basemap and projection settings."""
if self.plotter:
dialogs.BasemapDialog.update_plotter(self.plotter)
[docs]
def get_fmts(
self, var: DataArray, init: bool = False
) -> Dict[Union[Hashable, str, Any], Any]:
"""Get the formatoptions for a new plot.
Parameters
----------
var: xarray.Variable
The variable in the base dataset
init: bool
If True, call the :meth:`init_dims` method to inject necessary
formatoptions and dimensions for the initialization.
Returns
-------
dict
A mapping from formatoption or dimension to the corresponding value
for the plotmethod.
"""
fmts: Dict[Union[Hashable, str, Any], Any] = {}
fmts["cmap"] = self.btn_cmap.text()
if "projection" in self.formatoptions:
fmts["projection"] = self.get_projection_value(
self.btn_proj.text()
)
if "time" in var.dims:
fmts["title"] = "%(time)s"
if "long_name" in var.attrs:
fmts["clabel"] = "%(long_name)s"
else:
fmts["clabel"] = "%(name)s"
if "units" in var.attrs:
fmts["clabel"] += " %(units)s"
fmts["plot"] = self.plot_types[self.combo_plot.currentIndex()]
if fmts["plot"] == "contour":
# we need to set a global map extend, see
# https://github.com/SciTools/cartopy/issues/1673
fmts["map_extent"] = "global"
if init:
fmts.update(self.init_dims(var))
return fmts
[docs]
def refresh(self, ds: Optional[Dataset]) -> None:
"""Refresh this widget from the given dataset."""
self.setEnabled(bool(self.sp))
auto = "Set automatically"
self.refresh_from_sp()
with self.block_combos():
if ds is None:
ds = xr.Dataset()
current_dims = set(
map(
self.combo_xdim.itemText, range(1, self.combo_xdim.count())
)
)
ds_dims = list(
map(str, (dim for dim, n in ds.dims.items() if n > 1))
)
if current_dims != set(ds_dims):
self.combo_xdim.clear()
self.combo_ydim.clear()
self.combo_xdim.addItems([auto] + ds_dims)
self.combo_ydim.addItems([auto] + ds_dims)
current_coords = set(
map(
self.combo_xcoord.itemText,
range(1, self.combo_xcoord.count()),
)
)
ds_coords = list(
map(str, (c for c, arr in ds.coords.items() if arr.ndim))
)
if current_coords != set(ds_coords):
self.combo_xcoord.clear()
self.combo_ycoord.clear()
self.combo_xcoord.addItems([auto] + ds_coords)
self.combo_ycoord.addItems([auto] + ds_coords)
enable_combos = not bool(self.sp)
if not enable_combos and self.combo_xdim.isEnabled():
self.reset_combos = [
combo.currentIndex() == 0 for combo in self.coord_combos
]
elif enable_combos and not self.combo_xdim.isEnabled():
for reset, combo in zip(self.reset_combos, self.coord_combos):
if reset:
combo.setCurrentIndex(0)
self.reset_combos = [False] * len(self.coord_combos)
for combo in self.coord_combos:
combo.setEnabled(enable_combos)
if not enable_combos:
data = self.data
xdim = str(data.psy.get_dim("x"))
ydim = str(data.psy.get_dim("y"))
self.combo_xdim.setCurrentText(xdim)
self.combo_ydim.setCurrentText(ydim)
xcoord = data.psy.get_coord("x")
xcoord = xcoord.name if xcoord is not None else xdim
ycoord = data.psy.get_coord("y")
ycoord = ycoord.name if ycoord is not None else ydim
self.combo_xcoord.setCurrentText(xcoord)
self.combo_ycoord.setCurrentText(ycoord)
[docs]
def refresh_from_sp(self) -> None:
"""Refresh this widget from the plotters project."""
plotter = self.plotter
if plotter:
if isinstance(plotter.projection.value, str):
self.btn_proj.setText(
self.get_projection_label(plotter.projection.value)
)
else:
self.btn_proj.setText("Custom")
if isinstance(plotter.cmap.value, str):
self.btn_cmap.setText(plotter.cmap.value)
else:
self.btn_cmap.setText("Custom")
[docs]
def transform(self, x: float, y: float) -> Tuple[float, float]:
"""Transform coordinates of a point to the plots projection.
Parameters
----------
x: float
The x-coordinate in axes coordinates
y: float
The y-coordinate in axes coordinates
Returns
-------
float
The transformed x-coordinate `x`
float
The transformed y-coordinate `y`
"""
import cartopy.crs as ccrs
import numpy as np
plotter = self.plotter
if not plotter:
raise ValueError(
"Cannot transform the coordinates as no plot is shown "
"currently!"
)
x, y = plotter.transform.projection.transform_point(
x, y, plotter.ax.projection
)
# shift if necessary
if isinstance(plotter.transform.projection, ccrs.PlateCarree):
coord = plotter.plot.xcoord
if coord.min() >= 0 and x < 0:
x -= 360
elif coord.max() <= 180 and x > 180:
x -= 360
if "rad" in coord.attrs.get("units", "").lower():
x = np.deg2rad(x)
y = np.deg2rad(y)
return x, y
[docs]
def get_slice(
self, x: float, y: float
) -> Optional[Dict[Hashable, Union[int, slice]]]:
"""Get the slice for the selected coordinates.
This method is called when the user clicks on the coordinates in the
plot.
See Also
--------
psy_view.ds_widget.DatasetWidget.display_line
Notes
-----
This is reimplemented in the :class:`MapPlotWidget`.
"""
import numpy as np
data = self.data.psy.base.psy[self.data.name]
x, y = self.transform(x, y)
plotter = self.plotter
if not plotter:
raise ValueError(
"Cannot transform the coordinates as no plot is shown "
"currently!"
)
fmto = plotter.plot
xcoord = fmto.xcoord
ycoord = fmto.ycoord
if fmto.decoder.is_unstructured(fmto.raw_data) or xcoord.ndim == 2:
xy = xcoord.values.ravel() + 1j * ycoord.values.ravel()
dist = np.abs(xy - (x + 1j * y))
imin = np.nanargmin(dist)
if xcoord.ndim == 2:
ncols = data.shape[-2]
return dict(zip(data.dims[-2:], [imin // ncols, imin % ncols]))
else:
return {data.dims[-1]: imin}
else:
ix: int = xcoord.indexes[xcoord.name].get_loc(x, method="nearest")
iy: int = ycoord.indexes[ycoord.name].get_loc(y, method="nearest")
return dict(zip(data.dims[-2:], [iy, ix]))
[docs]
class Plot2DWidget(MapPlotWidget):
"""A widget to control the plot2d plotmethod."""
plotmethod = "plot2d"
@property
def fmt_setup_functions(self) -> List[Callable]:
"""Get a list of rows for formatoptions.
This property returns a list of callable. Each callable should setup
a horizonal (or widget) that is added to the :attr:`layout` vbox.
"""
ret = super().fmt_setup_functions
ret.remove(self.setup_projection_buttons)
idx_extent = ret.index(self.setup_fix_extent_buttons)
# remove the separator and the function
del ret[idx_extent - 1]
del ret[idx_extent - 1]
return ret
[docs]
def setEnabled(self, b: bool) -> None:
"""Enable or disable the datagrid and color buttons.
Parameters
----------
b: bool
If True, enable the buttons, else disable.
"""
self.btn_datagrid.setEnabled(b)
self.btn_cmap_settings.setEnabled(b)
self.btn_labels.setEnabled(b)
[docs]
def edit_labels(self) -> None:
"""Open the dialog to edit the text labels in the plot."""
dialogs.LabelDialog.update_project(
self.sp, "figtitle", "title", "clabel", "xlabel", "ylabel"
)
[docs]
def transform(self, x: float, y: float) -> Tuple[float, float]:
"""Reimplemented to not transform the coordinates."""
return x, y
[docs]
def refresh_from_sp(self) -> None:
"""Refresh this widget from the plotters project."""
plotter = self.plotter
if plotter:
if isinstance(plotter.cmap.value, str):
self.btn_cmap.setText(plotter.cmap.value)
[docs]
class LinePlotWidget(PlotMethodWidget):
"""A widget to control the lineplot plotmethod."""
plotmethod = "lineplot"
@property
def fmt_setup_functions(self) -> List[Callable]:
return [
self.setup_dimension_combo,
self.setup_line_selection,
self.setup_labels_button,
]
[docs]
def get_rows(self, func: Callable) -> List[List[GridCell]]:
"""Get the rows of the formatoption widgets.
This method should take callable from the :attr:`fmt_setup_functions`
list and return the rows corresponding to :attr:`formatoption_rows`.
"""
if func == self.setup_dimension_combo:
row = [
GridCell(QtWidgets.QLabel("x-Dimension:")),
GridCell(self.combo_dims),
]
elif func == self.setup_line_selection:
row = [
GridCell(QtWidgets.QLabel("Active line:")),
GridCell(self.combo_lines),
GridCell(self.btn_add),
GridCell(self.btn_del),
]
elif func == self.setup_labels_button:
row = [GridCell(self.btn_labels, colspan=4)]
else:
raise ValueError(f"Unknown function {func}")
return [row]
[docs]
def setup_dimension_combo(self) -> None:
"""Set up the combobox for selecting the x-dimension."""
self.combo_dims = QtWidgets.QComboBox()
self.combo_dims.setEditable(False)
self.combo_dims.currentIndexChanged.connect(self.trigger_reset)
self.combo_dims.setSizePolicy(
QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding
)
[docs]
def setup_line_selection(self) -> None:
"""Setup the row for selecting new lines."""
self.combo_lines = QtWidgets.QComboBox()
self.combo_lines.setEditable(False)
self.combo_lines.setSizePolicy(
QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding
)
self.combo_lines.currentIndexChanged.connect(self.trigger_refresh)
self.btn_add = utils.add_pushbutton(
QtGui.QIcon(get_psy_icon("plus")),
lambda: self.add_line(),
"Add a line to the plot",
icon=True,
)
self.btn_del = utils.add_pushbutton(
QtGui.QIcon(get_psy_icon("minus")),
self.remove_line,
"Add a line to the plot",
icon=True,
)
[docs]
def setup_labels_button(self) -> None:
"""Add a button to modify the text labels."""
self.btn_labels = utils.add_pushbutton(
"Edit labels",
self.edit_labels,
"Edit title, x-label, legendlabels, etc.",
)
@property
def xdim(self) -> str:
"""Get the x-dimension for the plot."""
return self.combo_dims.currentText()
@xdim.setter
def xdim(self, xdim: Hashable) -> None:
if xdim != self.combo_dims.currentText():
self.combo_dims.setCurrentText(str(xdim))
@property
def data(self) -> DataArray:
"""The first array in the list."""
data = super().data
if len(data) - 1 < self.combo_lines.currentIndex():
return data[0]
else:
return data[self.combo_lines.currentIndex()]
[docs]
def add_line(self, name: Hashable = None, **sl) -> None:
"""Add a line to the plot.
This method adds a new line for the plot specified by the given
`name` of the variable and the slices.
Parameters
----------
name: str
The variable name to display
``**sl``
The slices to turn the `name` variable into a 1D-array.
"""
if not self.sp:
raise ValueError("No plot has yet been initialized!")
ds = self.data.psy.base
xdim = self.xdim
if name is None:
name, ok = QtWidgets.QInputDialog.getItem(
self,
"New line",
"Select a variable",
[key for key, var in ds.items() if xdim in var.dims],
)
if not ok:
return
arr = ds.psy[name]
for key, val in self.data.psy.idims.items():
if key in arr.dims:
sl.setdefault(key, val)
for dim in arr.dims:
if dim != xdim:
sl.setdefault(dim, 0)
self.sp[0].append(arr.psy[sl], new_name=True)
item = self.item_texts[-1]
self.sp.update(replot=True)
self.combo_lines.addItem(item)
self.combo_lines.setCurrentText(item)
self.trigger_refresh()
[docs]
def remove_line(self) -> None:
"""Remove the current line from the plot."""
if not self.sp:
raise ValueError(
"No plot has yet been initialized, so I cannot remove any line!"
)
i = self.combo_lines.currentIndex()
self.sp[0].pop(i)
self.sp.update(replot=True)
self.combo_lines.setCurrentText(self.item_texts[i - 1 if i else 0])
self.changed.emit(self.plotmethod)
@property
def item_texts(self) -> List[str]:
"""Get the labels for the individual lines."""
if not self.sp:
return []
return [
f"Line {i}: {arr.psy._short_info()}"
for i, arr in enumerate(self.sp[0])
]
[docs]
def init_dims(
self, var: DataArray
) -> Dict[Union[Hashable, str, Any], Any]:
"""Get the formatoptions for a new plot.
Parameters
----------
var: xarray.Variable
The variable in the base dataset
Returns
-------
dict
A mapping from formatoption or dimension to the corresponding value
for the plotmethod.
"""
ret: Dict[Union[Hashable, str, Any], Any] = {}
xdim: Union[None, Hashable, str] = self.xdim or next(
(d for d in var.dims if var[d].size > 1), None # type: ignore
)
if self.array_info:
arr_names = {}
for arrname, d in self.array_info.items():
if arrname != "attrs":
dims = d["dims"].copy()
if xdim in dims:
for dim, sl in dims.items():
if not isinstance(sl, int):
dims[dim] = 0
dims[xdim] = slice(None)
dims["name"] = d["name"]
arr_names[arrname] = dims
ret["arr_names"] = arr_names
del self.array_info
else:
if xdim not in var.dims:
xdim = next((d for d in var.dims if var[d].size > 1), None)
if xdim is None:
raise ValueError(
f"Cannot plot variable {var.name} with size smaller than "
"2"
)
ret[xdim] = slice(None)
for d in var.dims:
if d != xdim:
ret[d] = 0
return ret
[docs]
def edit_labels(self) -> None:
"""Open the dialog to edit the text labels in the plot."""
dialogs.LabelDialog.update_project(
self.sp, "figtitle", "title", "xlabel", "ylabel", "legendlabels"
)
[docs]
@contextlib.contextmanager
def block_combos(self) -> Iterator[None]:
"""Temporarilly block any signal of the combo boxes."""
self.combo_dims.blockSignals(True)
self.combo_lines.blockSignals(True)
yield
self.combo_dims.blockSignals(False)
self.combo_lines.blockSignals(False)
[docs]
def valid_variables(self, ds: Dataset) -> List[Hashable]:
"""Get a list of variables that can be visualized with this plotmethod.
Parameters
----------
ds: xarray.Dataset
The dataset to use
Returns
-------
list of str
List of variable names to plot
"""
valid = list(ds)
if not self.sp or len(self.sp[0]) < 2:
return valid
else:
current_dim = self.combo_dims.currentText()
return [v for v in valid if current_dim in ds[v].dims]
[docs]
def setEnabled(self, b: bool) -> None:
"""Enable or disable the datagrid and color buttons.
Parameters
----------
b: bool
If True, enable the buttons, else disable.
"""
self.btn_add.setEnabled(b)
self.btn_del.setEnabled(b and self.sp and len(self.sp[0]) > 1)
self.btn_labels.setEnabled(b)
[docs]
def refresh(self, ds) -> None:
"""Refresh this widget from the given dataset."""
self.setEnabled(bool(self.sp))
if self.sp:
with self.block_combos():
self.combo_dims.clear()
all_dims = list(
chain.from_iterable(
[
[d for i, d in enumerate(a.dims) if a.shape[i] > 1]
for a in arr.psy.iter_base_variables
]
for arr in self.sp[0]
)
)
intersection = set(all_dims[0])
for dims in all_dims[1:]:
intersection.intersection_update(dims)
new_dims = list(
filter(
lambda d: d in intersection,
unique_everseen(chain.from_iterable(all_dims)),
)
)
self.combo_dims.addItems(new_dims)
self.combo_dims.setCurrentIndex(
new_dims.index(self.data.dims[-1])
)
# fill lines combo
current = self.combo_lines.currentIndex()
self.combo_lines.clear()
descriptions = self.item_texts
short_descs = [textwrap.shorten(s, 50) for s in descriptions]
self.combo_lines.addItems(short_descs)
for i, desc in enumerate(descriptions):
self.combo_lines.setItemData(
i, desc, QtCore.Qt.ToolTipRole
)
if current < len(descriptions):
self.combo_lines.setCurrentText(short_descs[current])
else:
with self.block_combos():
self.combo_dims.clear()
self.combo_lines.clear()