Skip to content

Commit ba232cc

Browse files
Copilotxadupre
andauthored
feat: add sankey flow plotting helper and unit test
Agent-Logs-Url: https://github.com/sdpython/teachpyx/sessions/68526f45-813e-4c5a-959d-84afb8a8544a Co-authored-by: xadupre <22452781+xadupre@users.noreply.github.com>
1 parent cd9795d commit ba232cc

2 files changed

Lines changed: 62 additions & 1 deletion

File tree

_unittests/ut_faq/test_faq_python.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,14 @@
11
import unittest
22
import datetime
3-
from teachpyx.faq.faq_python import get_month_name, get_day_name, class_getitem
3+
import matplotlib
4+
from teachpyx.faq.faq_python import (
5+
class_getitem,
6+
get_day_name,
7+
get_month_name,
8+
graph_sankey,
9+
)
10+
11+
matplotlib.use("Agg")
412

513

614
class TestFaqPython(unittest.TestCase):
@@ -17,6 +25,17 @@ def test_day_name(self):
1725
def test_class_getitem(self):
1826
class_getitem()
1927

28+
def test_graph_sankey(self):
29+
ax, diagrams = graph_sankey(
30+
[1, -0.25, -0.75],
31+
labels=["input", "loss", "output"],
32+
orientations=[0, 1, -1],
33+
trunklength=2.0,
34+
title="flux",
35+
)
36+
self.assertEqual(ax.get_title(), "flux")
37+
self.assertEqual(len(diagrams), 1)
38+
2039

2140
if __name__ == "__main__":
2241
unittest.main()

teachpyx/faq/faq_python.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -850,3 +850,45 @@ def __init__(self):
850850

851851
a = A[2]()
852852
assert a.__class__.__name__ == "A2"
853+
854+
855+
def graph_sankey(
856+
flows, labels=None, orientations=None, ax=None, title=None, **kwargs
857+
):
858+
"""
859+
Draws a :epkg:`Sankey` graph to represent flows.
860+
861+
:param flows: list of positive/negative flows
862+
:param labels: labels associated to each flow
863+
:param orientations: list of orientations (-1, 0, 1)
864+
:param ax: axis to use, if None, creates one
865+
:param title: graph title
866+
:param kwargs: additional parameters forwarded to
867+
``matplotlib.sankey.Sankey.add``
868+
:return: axis, sankey diagrams
869+
"""
870+
if len(flows) < 2:
871+
raise ValueError("flows must contain at least two values.")
872+
if abs(sum(flows)) > 1e-10:
873+
raise ValueError("The sum of all flows must be 0.")
874+
if labels is None:
875+
labels = [None] * len(flows)
876+
if orientations is None:
877+
orientations = [0] * len(flows)
878+
if len(labels) != len(flows):
879+
raise ValueError("labels and flows must have the same length.")
880+
if len(orientations) != len(flows):
881+
raise ValueError("orientations and flows must have the same length.")
882+
883+
import matplotlib.pyplot as plt
884+
from matplotlib.sankey import Sankey
885+
886+
if ax is None:
887+
_, ax = plt.subplots(1, 1)
888+
889+
sankey = Sankey(ax=ax)
890+
sankey.add(flows=flows, labels=labels, orientations=orientations, **kwargs)
891+
diagrams = sankey.finish()
892+
if title is not None:
893+
ax.set_title(title)
894+
return ax, diagrams

0 commit comments

Comments
 (0)