Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 57 additions & 10 deletions panel/simdec_app.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import bisect
import io
import re

from bokeh.models import PrintfTickFormatter
from bokeh.models.widgets.tables import NumberFormatter
Expand All @@ -15,10 +16,8 @@
from simdec.sensitivity_indices import SensitivityAnalysisResult
from simdec.visualization import sequential_cmaps, single_color_to_colormap


# panel app
pn.extension("tabulator")
pn.extension("floatpanel")
pn.extension("tabulator", "floatpanel", notifications=True)

pn.config.sizing_mode = "stretch_width"
pn.config.throttled = True
Expand All @@ -43,34 +42,82 @@
)


def _validate_csv_bytes(raw_bytes):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, but why not just use one big try/except and return a generic explanation of the error if we don't want the raw message from pandas? In the end, we don't have many errors to describe.

Otherwise, for the regex, we typically want to compile the pattern once for performance reasons. Here it is also re-evaluated for every column in the list comprehension.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I first tried it with try/except blocks, but was unable to do it in a way that ensures the dashboard stops showing the buffering icon. I'll get back to work 👍

"""Pre-parse validation. Returns an error string or None."""
try:
first_line = raw_bytes.decode("utf-8").split("\n")[0].strip()
except UnicodeDecodeError:
return "File encoding error. Please use files in UTF-8."

if "," not in first_line:
detected = (
"Semicolons(';')"
if ";" in first_line
else "tabs"
if "\t" in first_line
else "Unknown delimiter"
)
return f"Wrong column delimiter {detected}. Save the data with commas ',' as the delimiter"

col_names = [c.strip().strip('"').strip("'") for c in first_line.split(",")]
bad_cols = [c for c in col_names if re.search(r"[^A-Za-z0-9_ \-.]", c)]
if bad_cols:
return (
f"Special characters found in column name(s): {bad_cols}."
f"Column names may contain only letters, numbers and underscores."
f"Please rename columns {bad_cols} before uploading data again."
)
return None


@pn.cache
def load_data(text_fname):
    """Load the dataset used by the app.

    Parameters
    ----------
    text_fname : bytes-like or None
        Raw content of a CSV file. ``None`` selects the bundled example
        dataset ``tests/data/stress.csv``.

    Returns
    -------
    data : DataFrame or None
        The parsed data. ``None`` is returned when the content fails the
        pre-parse validation or cannot be parsed by pandas; in both cases an
        error notification is emitted through ``pn.state.notifications``.

    """
    if text_fname is None:
        return pd.read_csv("tests/data/stress.csv")

    raw_bytes = bytes(text_fname)

    # Run pre-validation: cheap checks on the header line so the user gets a
    # targeted message instead of a raw pandas error.
    error = _validate_csv_bytes(raw_bytes)
    if error:
        pn.state.notifications.error(error, duration=0)
        return None

    # Try parsing. The catch is deliberately broad: this is the UI boundary
    # and any parsing failure is reported to the user rather than raised.
    try:
        return pd.read_csv(io.BytesIO(raw_bytes))
    except Exception as e:
        pn.state.notifications.error(f"Could not parse CSV {e}.", duration=0)
        return None


@pn.cache
def column_inputs(data, output):
    """List the columns of ``data`` that can serve as inputs.

    Parameters
    ----------
    data : DataFrame or None
        Loaded data. ``None`` means no data is available.
    output : hashable
        Label of the output column, which is excluded from the result.

    Returns
    -------
    inputs : list
        Column labels of ``data`` without ``output``. Empty when ``data`` is
        ``None``.

    """
    if data is None:
        return []
    inputs = list(data.columns)
    # Guarded removal: ``output`` may not be a column of ``data`` (for
    # instance when no output is selected), and ``list.remove`` would raise
    # ``ValueError`` in that case.
    if output in inputs:
        inputs.remove(output)
    return inputs


@pn.cache
def column_output(data):
    """List the columns of ``data`` that can be selected as output.

    Parameters
    ----------
    data : DataFrame or None
        Loaded data. ``None`` means no data is available.

    Returns
    -------
    columns : list
        All column labels of ``data``, or an empty list when ``data`` is
        ``None``.

    """
    if data is None:
        return []
    return list(data.columns)


@pn.cache
def filtered_data(data, output_name):
    """Select the output column(s) from ``data``.

    Parameters
    ----------
    data : DataFrame or None
        Loaded data. ``None`` means no data is available.
    output_name : label or list of labels
        Column(s) to extract. A falsy value (``None``, ``""``, ``[]``) means
        nothing is selected.

    Returns
    -------
    output : Series or DataFrame
        ``data[output_name]`` when the label(s) exist. An empty float Series
        when there is no data or no selection. When the lookup raises
        ``KeyError``, the first column is returned instead: as a one-column
        DataFrame if ``output_name`` is a list, as a Series otherwise.

    """
    if data is None or not output_name:
        return pd.Series(dtype=float)
    try:
        return data[output_name]
    except KeyError:
        # Unknown label(s): fall back to the first column, keeping the same
        # dimensionality the caller asked for.
        if isinstance(output_name, list):
            return data.iloc[:, [0]]
        return data.iloc[:, 0]


Expand Down Expand Up @@ -350,7 +397,7 @@ def csv_data(

interactive_column_output = pn.bind(column_output, interactive_file)
# hack to make the default selection faster
interactive_output_ = pn.bind(lambda x: x[0], interactive_column_output)
interactive_output_ = pn.bind(lambda x: x[0] if x else None, interactive_column_output)
selector_output = pn.widgets.Select(
name="Output", value=interactive_output_, options=interactive_column_output
)
Expand Down
79 changes: 78 additions & 1 deletion src/simdec/visualization.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import copy
import functools
import itertools
from typing import Literal
from typing import Literal, Optional

import colorsys
import matplotlib as mpl
Expand Down Expand Up @@ -135,17 +135,25 @@ def palette(
def visualization(
*,
bins: pd.DataFrame,
bins2: Optional[pd.DataFrame] = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would keep this in its own function. A good argument for that is that we have two very distinct code paths in a conditional and no shared state between them. That means we have two different semantics. I would make another function for the two-output part and, if we really wanted to have a single function, then have a higher-level wrapper.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good👍 I'll make it into a new function.

palette: list[list[float]],
n_bins: str | int = "auto",
kind: Literal["histogram", "boxplot"] = "histogram",
ax=None,
output_name: str = "Output 1",
output_name2: str = "Output 2",
xlim: Optional[tuple[float, float]] = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For optionals, prefer this syntax now: `xlim: tuple[float, float] | None = None`.

ylim: Optional[tuple[float, float]] = None,
r_scatter: float = 1.0,
) -> plt.Axes:
"""Histogram plot of scenarios.

Parameters
----------
bins : DataFrame
Multidimensional bins.
bins2 : DataFrame
Multidimensional bins for output 2
palette : list of int of size (n, 4)
List of colours corresponding to scenarios.
n_bins : str or int
Expand All @@ -154,16 +162,85 @@ def visualization(
Histogram or Box Plot.
ax : Axes, optional
Matplotlib axis.
output_name : str, default "Output 1"
Name of the primary output variable.
output_name2 : str, default "Output 2"
Name of the second output variable.
xlim : tuple of float, optional
Minimum and maximum values for the x-axis (Output 1).
ylim : tuple of float, optional
Minimum and maximum values for the y-axis (Output 2).
Comment on lines +165 to +172
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To me, these are better set when working on the figure. This function is supposed to be a minimal building block to build something on top of, as it takes in an ax and returns an ax.

r_scatter : float, default 1.0
The portion of data points displayed on the scatter plot (0 to 1).

Returns
-------
axs : Axes
Matplotlib axis for two-output graph.
ax : Axes
Matplotlib axis.

"""
# needed to get the correct stacking order
bins.columns = pd.RangeIndex(start=len(bins.columns), stop=0, step=-1)

if bins2 is not None:
fig, axs = plt.subplots(2, 2, sharex="col", sharey="row", figsize=(8, 8))
axs[0, 1].axis("off")

sns.histplot(
bins,
multiple="stack",
stat="probability",
palette=palette,
common_bins=True,
common_norm=True,
bins=n_bins,
legend=False,
ax=axs[0, 0],
)
axs[0, 0].set_xlim(xlim)
axs[0, 0].set_box_aspect(1)
axs[0, 0].axis("off")

data = pd.concat([pd.melt(bins), pd.melt(bins2)["value"]], axis=1)
data.columns = ["c", "x", "y"]

if r_scatter < 1.0:
data = data.sample(frac=r_scatter)

sns.scatterplot(
data=data,
x="x",
y="y",
hue="c",
palette=palette,
ax=axs[1, 0],
legend=False,
)
axs[1, 0].set(xlabel=output_name, ylabel=output_name2)
axs[1, 0].set_box_aspect(1)

sns.histplot(
data,
y="y",
hue="c",
multiple="stack",
stat="probability",
palette=palette,
common_bins=True,
common_norm=True,
bins=40,
legend=False,
ax=axs[1, 1],
)
axs[1, 1].set_ylim(ylim)
axs[1, 1].set_box_aspect(1)
axs[1, 1].axis("off")

fig.subplots_adjust(wspace=-0.015, hspace=0)
return axs[1, 0]

if kind == "histogram":
ax = sns.histplot(
bins,
Expand Down
32 changes: 32 additions & 0 deletions tests/test_visualization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import pytest
import pandas as pd
import matplotlib.pyplot as plt
import simdec as sd


def test_visualization_single_output():
    """Both single-output kinds return a Matplotlib ``Axes``."""
    bins = pd.DataFrame({"s1": [1, 2], "s2": [3, 4]})
    palette = [[1, 0, 0, 1], [0, 1, 0, 1]]

    try:
        ax = sd.visualization(bins=bins, palette=palette, kind="histogram")
        assert isinstance(ax, plt.Axes)

        ax_box = sd.visualization(bins=bins, palette=palette, kind="boxplot")
        assert isinstance(ax_box, plt.Axes)
    finally:
        # Do not leak figures into other tests.
        plt.close("all")


def test_visualization_two_outputs():
    """Passing ``bins2`` builds the 2x2 layout and labels the scatter axis."""
    bins = pd.DataFrame({"s1": [1, 2]})
    bins2 = pd.DataFrame({"s1": [5, 6]})
    palette = [[1, 0, 0, 1]]

    try:
        ax = sd.visualization(bins=bins, bins2=bins2, palette=palette)

        assert ax.get_xlabel() == "Output 1"
        assert len(ax.figure.axes) == 4
    finally:
        # Do not leak the figure into other tests.
        plt.close("all")


def test_visualization_invalid_kind():
    """An unsupported ``kind`` is rejected with a ``ValueError``."""
    bins = pd.DataFrame({"s1": [1]})
    with pytest.raises(ValueError, match="'kind' can only be 'histogram' or 'boxplot'"):
        sd.visualization(bins=bins, palette=[[1, 0, 0, 1]], kind="invalid")
Loading