diff --git a/docs/cookbook/global-steel-flows.csv b/docs/cookbook/global-steel-flows.csv
new file mode 100644
index 0000000..32d29f9
--- /dev/null
+++ b/docs/cookbook/global-steel-flows.csv
@@ -0,0 +1,84 @@
+source,target,value,type
+DR,EF,65.8,forward
+DR,loss1,0.5,loss
+BF,EF,44.6,forward
+BF,OBC,826.2,forward
+BF,OHF,34,forward
+BF,FIC,23.6,forward
+BF,loss1,6.8,loss
+SP,loss1,5.7,loss
+SP,EF,351,forward
+SP,OBC,206,forward
+SP,FIC,11.6,forward
+EF,SM,410.3,forward
+EF,loss2,51.2,loss
+OBC,SM,898.8,forward
+OBC,loss2,133.8,loss
+OHF,SM,29.6,forward
+OHF,loss2,4.4,loss
+SM,CC bloom,100.6,forward
+SM,CC billet,487.6,forward
+SM,CC slab,646.7,forward
+SM,ingot casting,89.1,forward
+SM,SPC,4,forward
+SM,loss3,10.5,loss
+scrap iron and additives,FIC,33.6,forward
+CC bloom,SEM,99.2,forward
+CC bloom,SPC,0.7,backward
+CC bloom,loss4,0.6,loss
+CC bloom,CC bloom,3.3,backward
+CC billet,RBM,442.2,forward
+CC billet,HSM,41.8,forward
+CC billet,SP,2,backward
+ingot casting,SPC,6.7,backward
+ingot casting,PRM,82,forward
+ingot casting,loss4,0.5,loss
+ingot casting,ingot casting,1.4,backward
+SPC,cast steel/iron,10.5,forward
+SPC,loss4,0.1,loss
+SPC,SPC,9.6,backward
+FIC,cast steel/iron,68.3,forward
+FIC,loss4,0.5,loss
+FIC,FIC,34.7,backward
+PRM,SEM,5.2,forward
+PRM,RBM,49.1,forward
+PRM,PLM,6.9,forward
+PRM,HSM,14.6,forward
+PRM,SP,5.3,backward
+PRM,loss5,0.8,loss
+SEM,sections,94,forward
+SEM,SP,8.9,backward
+SEM,loss6,1.6,loss
+RBM,STP,30,forward
+RBM,bar and rod,431.8,forward
+RBM,SP,22,backward
+RBM,loss6,7.5,loss
+PLM,TWP,15,forward
+PLM,plate,110,forward
+PLM,SP,12.3,backward
+PLM,loss6,1.6,loss
+HSM,TWP,51.7,forward
+HSM,CRM,288.1,forward
+HSM,GP,10.3,forward
+HSM,hot rolled coil,190,forward
+CC billet,loss4,1.6,loss
+CC billet,CC billet,8.8,backward
+CC slab,PLM,132,forward
+CC slab,HSM,508,forward
+CC slab,SP,3.7,backward
+CC slab,loss4,3,loss
+CC slab,CC slab,16.5,backward
+TWP,tube,62.4,forward
+TWP,SP,4.3,backward
+CRM,cold rolled coil,145.3,forward
+CRM,GP,116.1,forward
+HSM,SP,18.8,backward
+HSM,loss6,5.6,loss
+STP,tube,27.7,forward
+STP,SP,2.3,backward
+CRM,SP,14.2,backward
+GP,cold rolled coil,113.2,forward
+GP,hot rolled coil,10,forward
+GP,SP,3.2,backward
+TM,cold rolled coil,11.6,forward
+TM,SP,0.8,backward
diff --git a/docs/cookbook/global-steel-processes.csv b/docs/cookbook/global-steel-processes.csv
new file mode 100644
index 0000000..d359301
--- /dev/null
+++ b/docs/cookbook/global-steel-processes.csv
@@ -0,0 +1,38 @@
+id,layer,band,type
+DR,0,1,
+BF,0,1,
+SP,0,1,
+EF,1,1,
+OBC,1,1,
+OHF,1,1,
+SM,2,1,
+CC bloom,3,1,CC
+CC billet,3,1,CC
+CC slab,3,1,CC
+ingot casting,3,1,other
+SPC,3,1,other
+FIC,3,1,other
+PRM,4,1,
+SEM,5,1,
+RBM,5,1,
+PLM,5,1,
+HSM,5,1,
+STP,6,1,
+TWP,6,1,
+CRM,6,1,
+GP,7,1,
+TM,7,1,
+sections,8,1,t/p/s
+tube,8,1,t/p/s
+bar and rod,8,1,
+plate,8,1,t/p/s
+cold rolled coil,8,1,rolled coil
+hot rolled coil,8,1,rolled coil
+cast steel/iron,8,1,
+loss1,1,0,
+loss2,2,0,
+loss3,3,0,
+loss4,4,0,
+loss5,5,0,
+loss6,6,0,
+scrap iron and additives,2,0,
diff --git a/docs/cookbook/hybrid-sankey-diagrams-paper-fruit-example.ipynb b/docs/cookbook/hybrid-sankey-diagrams-paper-fruit-example.ipynb
index 5faac35..7d1caea 100644
--- a/docs/cookbook/hybrid-sankey-diagrams-paper-fruit-example.ipynb
+++ b/docs/cookbook/hybrid-sankey-diagrams-paper-fruit-example.ipynb
@@ -13,10 +13,12 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
+ "from attr import evolve\n",
+ "import pandas as pd\n",
"from floweaver import *"
]
},
@@ -29,7 +31,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -45,9 +47,96 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 3,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " source | \n",
+ " target | \n",
+ " material | \n",
+ " time | \n",
+ " value | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " farm1 | \n",
+ " eat1 | \n",
+ " apples | \n",
+ " 2011-08-01 | \n",
+ " 2.720691 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " eat1 | \n",
+ " landfill Cambridge | \n",
+ " apples | \n",
+ " 2011-08-01 | \n",
+ " 1.904484 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " eat1 | \n",
+ " composting Cambridge | \n",
+ " apples | \n",
+ " 2011-08-01 | \n",
+ " 0.816207 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " farm1 | \n",
+ " eat1 | \n",
+ " apples | \n",
+ " 2011-08-02 | \n",
+ " 8.802195 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " eat1 | \n",
+ " landfill Cambridge | \n",
+ " apples | \n",
+ " 2011-08-02 | \n",
+ " 6.161537 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " source target material time value\n",
+ "0 farm1 eat1 apples 2011-08-01 2.720691\n",
+ "1 eat1 landfill Cambridge apples 2011-08-01 1.904484\n",
+ "2 eat1 composting Cambridge apples 2011-08-01 0.816207\n",
+ "3 farm1 eat1 apples 2011-08-02 8.802195\n",
+ "4 eat1 landfill Cambridge apples 2011-08-02 6.161537"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
"dataset._flows.head()"
]
@@ -61,9 +150,98 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 4,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " type | \n",
+ " location | \n",
+ " function | \n",
+ " sector | \n",
+ "
\n",
+ " \n",
+ " id | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " inputs | \n",
+ " stock | \n",
+ " * | \n",
+ " inputs | \n",
+ " NaN | \n",
+ "
\n",
+ " \n",
+ " farm1 | \n",
+ " process | \n",
+ " Cambridge | \n",
+ " small farm | \n",
+ " farming | \n",
+ "
\n",
+ " \n",
+ " farm2 | \n",
+ " process | \n",
+ " Cambridge | \n",
+ " small farm | \n",
+ " farming | \n",
+ "
\n",
+ " \n",
+ " farm3 | \n",
+ " process | \n",
+ " Ely | \n",
+ " small farm | \n",
+ " farming | \n",
+ "
\n",
+ " \n",
+ " farm4 | \n",
+ " process | \n",
+ " Ely | \n",
+ " allotment | \n",
+ " farming | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " type location function sector\n",
+ "id \n",
+ "inputs stock * inputs NaN\n",
+ "farm1 process Cambridge small farm farming\n",
+ "farm2 process Cambridge small farm farming\n",
+ "farm3 process Ely small farm farming\n",
+ "farm4 process Ely allotment farming"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
"dataset._dim_process.head()"
]
@@ -77,7 +255,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
@@ -100,7 +278,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@@ -130,7 +308,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
@@ -152,7 +330,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
@@ -177,22 +355,37 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 13,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "6b2016bccf944911b2d938601ab28ada",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "SankeyWidget(groups=[{'id': '__w2_compost_0', 'type': 'group', 'title': '', 'nodes': ['__w2_compost_0^*']}, {'…"
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
"sdd = SankeyDefinition(nodes, bundles, ordering,\n",
" flow_partition=dataset.partition('material'))\n",
- "weave(sdd, dataset) \\\n",
- " .to_widget(width=570, height=550, margins=dict(left=70, right=90))"
+ "sankey_data.to_widget(width=570, height=550, margins=dict(left=70, right=90))"
]
}
],
"metadata": {
"kernelspec": {
- "display_name": "Python 3",
+ "display_name": "venvsankey3.10.9",
"language": "python",
- "name": "python3"
+ "name": "venvsankey3.10.9"
},
"language_info": {
"codemirror_mode": {
@@ -204,9 +397,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.6.3"
+ "version": "3.10.9"
}
},
"nbformat": 4,
- "nbformat_minor": 1
+ "nbformat_minor": 4
}
diff --git a/docs/cookbook/layout-optimisation-2.ipynb b/docs/cookbook/layout-optimisation-2.ipynb
new file mode 100644
index 0000000..38c6553
--- /dev/null
+++ b/docs/cookbook/layout-optimisation-2.ipynb
@@ -0,0 +1,234 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "1a72f1bd-c89a-4131-a229-e3911c603838",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from attr import evolve\n",
+ "import pandas as pd\n",
+ "from floweaver import *"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "9dc2965f-d3f5-4de7-9467-a18d9c9e1cef",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dataset = Dataset.from_csv('fruit_flows.csv', 'fruit_processes.csv')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "3e7f8ceb-6cf2-4efb-8cbc-71f4588ed9cb",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "farm_ids = ['farm{}'.format(i) for i in range(1, 16)]\n",
+ "\n",
+ "farm_partition_5 = Partition.Simple('process', [('Other farms', farm_ids[5:])] + farm_ids[:5])\n",
+ "partition_fruit = Partition.Simple('material', ['bananas', 'apples', 'oranges'])\n",
+ "partition_sector = Partition.Simple('process.sector', ['government', 'industry', 'domestic'])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "4e3f75d8-ac18-4658-8a87-3743a2aae488",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "nodes = {\n",
+ " 'inputs': ProcessGroup(['inputs'], title='Inputs'),\n",
+ " 'compost': ProcessGroup('function == \"composting stock\"', title='Compost'),\n",
+ " 'farms': ProcessGroup('function in [\"allotment\", \"large farm\", \"small farm\"]', farm_partition_5),\n",
+ " 'eat': ProcessGroup('function == \"consumers\" and location != \"London\"', partition_sector,\n",
+ " title='consumers by sector'),\n",
+ " 'landfill': ProcessGroup('function == \"landfill\" and location != \"London\"', title='Landfill'),\n",
+ " 'composting': ProcessGroup('function == \"composting process\" and location != \"London\"', title='Composting'),\n",
+ "\n",
+ " 'fruit': Waypoint(partition_fruit, title='fruit type'),\n",
+ " 'w1': Waypoint(direction='L', title=''),\n",
+ " 'w2': Waypoint(direction='L', title=''),\n",
+ " 'export fruit': Waypoint(Partition.Simple('material', ['apples', 'bananas', 'oranges'])),\n",
+ " 'exports': Waypoint(title='Exports'),\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "d745d2c7-fd7f-4a78-8d5e-13ea4ea304e3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ordering = [\n",
+ " [[], ['inputs', 'compost'], []],\n",
+ " [[], ['farms'], ['w2']],\n",
+ " [['exports'], ['fruit'], []],\n",
+ " [[], ['eat'], []],\n",
+ " [['export fruit'], ['landfill', 'composting'], ['w1']],\n",
+ "]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "ddc41c09-4de7-45de-9399-098338272673",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "bundles = [\n",
+ " Bundle('inputs', 'farms'),\n",
+ " Bundle('compost', 'farms'),\n",
+ " Bundle('farms', 'eat', waypoints=['fruit']),\n",
+ " Bundle('farms', 'compost', waypoints=['w2']),\n",
+ " Bundle('eat', 'landfill'),\n",
+ " Bundle('eat', 'composting'),\n",
+ " Bundle('composting', 'compost', waypoints=['w1', 'w2']),\n",
+ " Bundle('farms', Elsewhere, waypoints=['exports', 'export fruit']),\n",
+ "]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "4dcbf636-162a-45f1-a064-35509036e97d",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "74732ba63d4c4f5893d0900b3f039210",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "SankeyWidget(groups=[{'id': '__w2_compost_0', 'type': 'group', 'title': '', 'nodes': ['__w2_compost_0^*']}, {'…"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "sdd = SankeyDefinition(nodes, bundles, ordering,\n",
+ " flow_partition=dataset.partition('material'))\n",
+ "sankey_data = weave(sdd, dataset)\n",
+ "sankey_data.to_widget(width=700, height=450, margins=dict(left=70, right=90))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "5982f210-f3e4-401c-96a4-55499d2b04dd",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "10.3086576461792\n"
+ ]
+ }
+ ],
+ "source": [
+ "sankey_data_evolved = optimise_node_order(sankey_data, group_nodes=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "0dfa2a67-429a-423a-822a-862aee5d2548",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "fe4cafd87eeb49f1b5f1ff56862c2849",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "VBox(children=(SankeyWidget(groups=[{'id': '__w2_compost_0', 'type': 'group', 'title': '', 'nodes': ['__w2_com…"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "sankey_data_evolved.to_widget(width=700, height=450, margins=dict(left=100, right=120), debugging=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "1be49fec-0a87-40b1-9e0d-47db9219e532",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "3.057634115219116\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "9a04b5a3ba154477ac713b436c517016",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "SankeyWidget(groups=[{'id': '__w2_compost_0', 'type': 'group', 'title': '', 'nodes': ['__w2_compost_0^*']}, {'…"
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "sankey_data_evolved.to_widget(layout=optimise_node_positions(sankey_data_evolved, scale = 0.02))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "33c745e1-a08c-43d6-94eb-dc6a6fb5e6d5",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "venvsankey3.10.9",
+ "language": "python",
+ "name": "venvsankey3.10.9"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.7"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/docs/cookbook/layout-optimisation-3.ipynb b/docs/cookbook/layout-optimisation-3.ipynb
new file mode 100644
index 0000000..a2c7891
--- /dev/null
+++ b/docs/cookbook/layout-optimisation-3.ipynb
@@ -0,0 +1,166 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "8a757c91-2d0d-41a0-b7eb-7616c8c75f7e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from attr import evolve\n",
+ "import pandas as pd\n",
+ "from floweaver import *"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "425e31cd-2ddc-4424-9494-55d03feb8af2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "flows = pd.read_csv('global-steel-flows.csv').to_dict('records')\n",
+ "processes = pd.read_csv('global-steel-processes.csv').to_dict('records')\n",
+ "dataset = Dataset.from_csv('global-steel-flows.csv','global-steel-processes.csv')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "c3465193-10c9-4121-9562-770086efd9de",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "nodes = generate_nodes(processes)\n",
+ "ordering = generate_ordering(processes)\n",
+ "ordering,nodes,bundles = generate_waypoints_bundles(processes, flows, ordering, nodes)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "2c892e69-3f7e-4b72-85bc-8ea8c48459fa",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "4c72ccafc3c64810a24b86797c7aa134",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "SankeyWidget(groups=[{'id': 'wp0', 'type': 'group', 'title': '', 'nodes': ['wp0^*']}, {'id': 'fwpBFFIC0', 'typ…"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "sdd = SankeyDefinition(nodes, bundles, ordering, flow_partition=dataset.partition('type'))\n",
+ "sankey_data = weave(sdd, dataset)\n",
+ "sankey_data.to_widget(width=1000, height=550, margins=dict(left=70, right=90))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "321b0ee0-5704-43d5-8a1c-c40bb3844361",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "5.4460608959198\n"
+ ]
+ }
+ ],
+ "source": [
+ "sankey_data_evolved = optimise_node_order(sankey_data, group_nodes=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "id": "78f312e7-766e-47b9-acc9-7965925b60cd",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "fb22f404852745aaa9abd8001d2901fe",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "VBox(children=(SankeyWidget(groups=[{'id': 'wp0', 'type': 'group', 'title': '', 'nodes': ['wp0^*']}, {'id': 'f…"
+ ]
+ },
+ "execution_count": 32,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "sankey_data_evolved.to_widget(width=1000, height=550, margins=dict(left=100, right=120), debugging=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "id": "73aebb9f-d4f9-4944-88d3-34f8f40ce75d",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "1.390228033065796\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "3bb38e3f338d4e1f83c3c80abc12a660",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "SankeyWidget(groups=[{'id': 'wp0', 'type': 'group', 'title': '', 'nodes': ['wp0^*']}, {'id': 'fwpBFFIC0', 'typ…"
+ ]
+ },
+ "execution_count": 30,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "sankey_data_evolved.to_widget(layout=optimise_node_positions(sankey_data_evolved, scale = 0.16, width=1000, height=650, margins=dict(left=100, right=120)))"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "venvsankey3.10.9",
+ "language": "python",
+ "name": "venvsankey3.10.9"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.9"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/docs/cookbook/layout-optimisation.ipynb b/docs/cookbook/layout-optimisation.ipynb
new file mode 100644
index 0000000..d20a977
--- /dev/null
+++ b/docs/cookbook/layout-optimisation.ipynb
@@ -0,0 +1,286 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "38a096d4-8ad1-4f2d-b8be-6affdcd71c21",
+ "metadata": {},
+ "source": [
+ "# Layout optimisation\n",
+ "\n",
+ "This example uses the same data as in the [US energy consumption example](us-energy-consumption.ipynb) to demonstrate node order and position optimisation. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "6c86326f-4c4b-4468-a583-049f503f1af7",
+ "metadata": {
+ "editable": true,
+ "slideshow": {
+ "slide_type": ""
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "from attr import evolve\n",
+ "import pandas as pd\n",
+ "from floweaver import *"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "692e1ccf-c748-45b9-87f2-d7e8e4f5f021",
+ "metadata": {},
+ "source": [
+ "Load the data and set up the Sankey Diagram Definition, as in the previous example:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "83d56948-5c13-4254-bf85-0c4a009627ea",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dataset = Dataset.from_csv(\"us-energy-consumption.csv\", dim_process_filename=\"us-energy-consumption-processes.csv\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "f490b5b0-ad19-48fc-bef3-0456e154e864",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "sources = ['Solar', 'Nuclear', 'Hydro', 'Wind', 'Geothermal',\n",
+ " 'Natural_Gas', 'Coal', 'Biomass', 'Petroleum']\n",
+ "\n",
+ "uses = ['Residential', 'Commercial', 'Industrial', 'Transportation']"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "9664a9bb-f21b-411d-8571-315cb58b42b2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "nodes = {\n",
+ " 'sources': ProcessGroup('type == \"source\"', Partition.Simple('process', sources), title='Sources'),\n",
+ " 'imports': ProcessGroup(['Net_Electricity_Import'], title='Net electricity imports'),\n",
+ " 'electricity': ProcessGroup(['Electricity_Generation'], title='Electricity Generation'),\n",
+ " 'uses': ProcessGroup('type == \"use\"', partition=Partition.Simple('process', uses)),\n",
+ " \n",
+ " 'energy_services': ProcessGroup(['Energy_Services'], title='Energy services'),\n",
+ " 'rejected': ProcessGroup(['Rejected_Energy'], title='Rejected energy'),\n",
+ " \n",
+ " 'direct_use': Waypoint(Partition.Simple('source', [\n",
+ " # This is a hack to hide the labels of the partition, there should be a better way...\n",
+ " (' '*i, [k]) for i, k in enumerate(sources)\n",
+ " ])),\n",
+ "}\n",
+ "\n",
+ "ordering = [\n",
+ " [[], ['sources'], []],\n",
+ " [['imports'], ['electricity', 'direct_use'], []],\n",
+ " [[], ['uses'], []],\n",
+ " [[], ['rejected', 'energy_services'], []]\n",
+ "]\n",
+ "\n",
+ "bundles = [\n",
+ " Bundle('sources', 'electricity'),\n",
+ " Bundle('sources', 'uses', waypoints=['direct_use']),\n",
+ " Bundle('electricity', 'uses'),\n",
+ " Bundle('imports', 'uses'),\n",
+ " Bundle('uses', 'energy_services'),\n",
+ " Bundle('uses', 'rejected'),\n",
+ " Bundle('electricity', 'rejected'),\n",
+ "]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "0c5cfcd8-9f3a-407f-8edf-00c9822e42de",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "palette = {\n",
+ " 'Solar': 'gold',\n",
+ " 'Nuclear': 'red',\n",
+ " 'Hydro': 'blue',\n",
+ " 'Wind': 'purple',\n",
+ " 'Geothermal': 'brown',\n",
+ " 'Natural_Gas': 'steelblue',\n",
+ " 'Coal': 'black',\n",
+ " 'Biomass': 'lightgreen',\n",
+ " 'Petroleum': 'green',\n",
+ " 'Electricity': 'orange',\n",
+ " 'Rejected energy': 'lightgrey',\n",
+ " 'Energy services': 'dimgrey',\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "a26aa4f7-e57e-48c6-b1f1-6f80ec294755",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "sdd = SankeyDefinition(nodes, bundles, ordering,\n",
+ " flow_partition=dataset.partition('type'))\n",
+ "sankey_data = weave(sdd, dataset, palette=palette)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6bd354ef-b106-45ef-9333-e9abcdf2f8a7",
+ "metadata": {},
+ "source": [
+ "This is the default, un-optimised layout:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "d112fb8e-73d6-4b27-a15e-df481b4d4860",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "d80a44262aac424fabb551d8e63cd454",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "SankeyWidget(groups=[{'id': 'sources', 'type': 'process', 'title': 'Sources', 'nodes': ['sources^Solar', 'sour…"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "sankey_data.to_widget(width=700, height=450, margins=dict(left=100, right=120))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "81d211d5-d8ee-4a3d-887c-898eb1fffc2c",
+ "metadata": {},
+ "source": [
+ "Optimise the node ordering:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "4e7a3d95-6167-4175-9b34-17bd3c79ced8",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "5.276489734649658\n"
+ ]
+ }
+ ],
+ "source": [
+ "sankey_data_evolved = optimise_node_order(sankey_data, group_nodes=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "93987910-ac87-49b3-a976-a4f132586901",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "fa4fba3f80c84239a2571d6cf00119e5",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "VBox(children=(SankeyWidget(groups=[{'id': 'sources', 'type': 'process', 'title': 'Sources', 'nodes': ['source…"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "sankey_data_evolved.to_widget(width=700, height=450, margins=dict(left=100, right=120), debugging=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "9ebd8d00-bc92-4858-88e4-214f48b6c6d4",
+ "metadata": {},
+ "source": [
+ "Optimise the node positions to make flows as straight as possible:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "28392d80-9ace-47a5-ad96-5b2f43b5f2e8",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "0.5095484256744385\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "a3a825fda81e4aa9a650e572824f8e08",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "SankeyWidget(groups=[{'id': 'sources', 'type': 'process', 'title': 'Sources', 'nodes': ['sources^Solar', 'sour…"
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "sankey_data_evolved.to_widget(layout=optimise_node_positions(sankey_data_evolved, scale=1.5))"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "venvsankey3.10.9",
+ "language": "python",
+ "name": "venvsankey3.10.9"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.9"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/docs/installation.rst b/docs/installation.rst
index bc7d741..3f09956 100644
--- a/docs/installation.rst
+++ b/docs/installation.rst
@@ -10,6 +10,12 @@ install floweaver using pip:
$ pip install floweaver
+To install the additional dependencies for optimising diagram layout, use:
+
+.. code-block:: shell
+
+ $ pip install floweaver[mip]
+
If you use Jupyter notebooks -- a good way to get started -- you will also want
to install `ipysankeywidget `_,
an IPython widget to interactively display Sankey diagrams::
@@ -62,3 +68,5 @@ To open Jupyter Notebook and begin to work on the Sankey. Write in the Command L
$ jupyter notebook
[not sure about this :D]
+
+To use the optimisation tools, ``gcc`` also needs to be installed, e.g. using
+Homebrew:
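+
+.. code-block:: shell
+
+   $ brew install gcc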
diff --git a/floweaver/__init__.py b/floweaver/__init__.py
index f183bd8..3223d0a 100644
--- a/floweaver/__init__.py
+++ b/floweaver/__init__.py
@@ -9,11 +9,15 @@
from .results_graph import results_graph
from .augment_view_graph import elsewhere_bundles, augment
from .hierarchy import Hierarchy
-from .sankey_data import SankeyData, SankeyLink, SankeyNode
+from .sankey_data import SankeyData, SankeyLink, SankeyNode, SankeyLayout
from .color_scales import CategoricalScale, QuantitativeScale
from .weave import weave
+from .diagram_optimisation import optimise_node_order, optimise_node_positions, optimise_hybrid_model
+from .dataset_manipulation import generate_nodes, generate_ordering, generate_waypoints_bundles
__all__ = ['Dataset', 'Partition', 'Group', 'SankeyDefinition', 'ProcessGroup',
'Waypoint', 'Bundle', 'Elsewhere', 'view_graph', 'results_graph',
'elsewhere_bundles', 'augment', 'Hierarchy', 'weave', 'SankeyData',
- 'SankeyLink', 'SankeyNode', 'CategoricalScale', 'QuantitativeScale']
+ 'SankeyLink', 'SankeyNode', 'SankeyLayout', 'CategoricalScale', 'QuantitativeScale',
+ "optimise_node_order", "optimise_node_positions", "optimise_hybrid_model",
+ "generate_nodes", "generate_ordering", "generate_waypoints_bundles"]
diff --git a/floweaver/dataset_manipulation.py b/floweaver/dataset_manipulation.py
new file mode 100644
index 0000000..96c32ed
--- /dev/null
+++ b/floweaver/dataset_manipulation.py
@@ -0,0 +1,196 @@
+# Module containing the functions for flow data set manipulation
+from .partition import Partition
+from .sankey_definition import ProcessGroup, Bundle, Waypoint
+## Helper function: NaN is the only value that is not equal to itself
+def isNaN(num):
+    return num != num
+
+## Function that will assemble the nodes dictionary
+def generate_nodes(node_def, group_by = 'id', node_uid = 'id', partition_groups = None):
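+    """Build the `nodes` dictionary of ProcessGroups from a list of process records.
+
+    Records whose `group_by` field is empty become individual ProcessGroups keyed by
+    `node_uid`; the rest are grouped by their `group_by` value. Groups named in
+    `partition_groups` (or all groups, if `partition_groups == 'all'`) are given a
+    Partition over their member nodes.
+    """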
+ # Loop through the dictionary and find all distinct node types and add them to the node dictionary
+ node_types = []
+ nodes = {}
+ group_lists = {}
+ for node in node_def:
+ # If NaN means the group_by field for that node is empty
+ if isNaN(node[group_by]):
+ # Add the node_uid as a node since not a part of a group
+ nodes[node[node_uid]] = ProcessGroup(node_uid + ' == "' + node[node_uid] + '"')
+
+ elif node[group_by] not in node_types:
+ nodes[node[group_by]] = ProcessGroup(group_by + ' == "' + node[group_by] + '"')
+
+ # Populate the group_lists dictionary
+ if partition_groups == 'all':
+ group_lists[node[group_by]] = []
+ group_lists[node[group_by]].append(node[node_uid])
+
+ elif partition_groups and node[group_by] in partition_groups:
+ group_lists[node[group_by]] = []
+ group_lists[node[group_by]].append(node[node_uid])
+
+ node_types.append(node[group_by])
+
+ # If the group_by already visited, need to add to the group_lists regardless
+ else:
+ # Populate the group_lists dictionary
+ if partition_groups == 'all':
+ group_lists[node[group_by]].append(node[node_uid])
+
+ elif partition_groups and node[group_by] in partition_groups:
+ group_lists[node[group_by]].append(node[node_uid])
+
+ # Now loop through group_lists and add all the partitions
+ for group in group_lists:
+ nodes[group].partition = Partition.Simple('process', group_lists[group])
+
+ return nodes
+
+## Assemble the ordering array
+def generate_ordering(node_def, group_by='id', node_uid='id'):
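+    """Build the `ordering` array (layers of bands) from a list of process records.
+
+    Each record's 'layer' and 'band' fields determine its position; grouped nodes
+    appear once under their group id.
+    """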
+ # Will first loop through and determine the dimension for the ordering array
+ layers = 0
+ bands = 0
+ for node in node_def:
+ layers = max(layers,node['layer'])
+ bands = max(bands,node['band'])
+
+ ordering = [ [ [] for i in range(bands + 1) ] for i in range(layers + 1) ]
+
+    # NOTE: This relies on the assumption that all nodes in a group are in the same layer/band
+ visited = []
+ for node in node_def:
+
+ if isNaN(node[group_by]):
+ ordering[node['layer']][node['band']].append(node[node_uid])
+ elif node[group_by] not in visited:
+ visited.append(node[group_by])
+ ordering[node['layer']][node['band']].append(node[group_by])
+
+ return ordering
+
+## Function that returns the ordering, nodes and bundles
+def generate_waypoints_bundles(node_def, flows, ordering, nodes, group_by = 'id', node_uid = 'id'):
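+    """Add the waypoints and bundles needed to route the given flows.
+
+    Returns the updated (ordering, nodes, bundles). Reverse flows are routed through
+    left-direction waypoints in an extra band added at the bottom of each layer;
+    forward flows that skip layers get a chain of right-direction waypoints; flows
+    within a group or from a node to itself are bundled with a flow_selection.
+    """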
+ # Function takes in everything required to make decisions and output updated definition
+
+ # Generate a dictionary of nodes:layer pairs to increase code efficiency
+ node_layers = {}
+ for node in node_def:
+ node_layers[node[node_uid]] = node['layer']
+
+ # Generate a dictionary of nodes:bands pairs to increase code efficiency
+ node_bands = {}
+ for node in node_def:
+ node_bands[node[node_uid]] = node['band']
+
+ reverse_present = False # Variable declaring whether a reverse waypoint 'band' has been created
+
+ bundles = [] # Empty list for storing the bundles as they are generated
+
+ # Generate a dictionary of nodes:group_by pairs to increase code efficiency
+ node_group = {}
+ for node in node_def:
+ if isNaN(node[group_by]):
+ node_group[node[node_uid]] = node[node_uid]
+ else:
+ node_group[node[node_uid]] = node[group_by]
+
+ for flow in flows:
+
+ # Create a flow_selection if required
+        fs = 'source == "' + node_group[flow['source']] + '" and target == "' + node_group[flow['target']] + '"'
+
+ # Store the node layers in variables for code clarity
+ target = node_layers[flow['target']]
+ source = node_layers[flow['source']]
+
+ #print(node_group[flow['source']],node_group[flow['target']])
+
+ # If the target is in the same or prior layer to the source
+ if (target <= source) and (flow['source'] != flow['target']):
+
+ #If this is the first reverse flow then add the reverse wp band
+ if not reverse_present:
+ reverse_present = True
+ for layer in ordering:
+ layer.append([]) # Add an empty list, a new layer at the bottom
+
+ # Will loop through all layers between the source and target inclusive adding the wps
+ for layer in range(target, source + 1):
+
+ # If no reverse waypoint already added then add one for all required layers
+ if not ordering[layer][-1]:
+ ordering[layer][-1].append('wp' + str(layer))
+ nodes['wp' + str(layer)] = Waypoint(direction='L', title = '')
+
+ # If group is the same but the node is different need to add the flow selection
+ if (node_group[flow['source']] == node_group[flow['target']]) and (Bundle(node_group[flow['source']],
+ node_group[flow['target']],
+ flow_selection = fs)) not in bundles:
+ bundles.append(Bundle(node_group[flow['source']],
+ node_group[flow['target']],
+ flow_selection = fs))
+
+ # If in the same layer then only one waypoint required
+ elif target == source and (Bundle(node_group[flow['source']],
+ node_group[flow['target']],
+ waypoints=['wp' + str(source)]) not in bundles):
+ bundles.append(Bundle(node_group[flow['source']],
+ node_group[flow['target']],
+ waypoints=['wp' + str(source)]))
+
+ # If the bundle isn't in the bundles list and the layers are not the same
+ elif ((Bundle(node_group[flow['source']],
+ node_group[flow['target']],
+ waypoints=['wp' + str(source),'wp' + str(target)]) not in bundles) and
+ target != source):
+ bundles.append(Bundle(node_group[flow['source']],
+ node_group[flow['target']],
+ waypoints=['wp' + str(source),'wp' + str(target)]))
+
+        # If it's not a reverse flow, it will be either a long or a short forward flow
+
+        # First provide all the logic for a long flow (i.e. target - source > 1)
+ elif (target - source) > 1:
+
+ #Create a temporary waypoint list
+ wp_list = []
+
+ uid = 0 # Create a uid counter for the waypoint nodes
+ # Loop through all the layers in between the source and target exclusive
+ for layer in range(source + 1, target):
+
+ wp_name = 'fwp' + node_group[flow['source']] + node_group[flow['target']] + str(uid)
+
+ # Check if corresponding waypoint already exists
+ if wp_name not in ordering[layer][node_bands[flow['source']]]:
+ ordering[layer][node_bands[flow['source']]].append(wp_name)
+ nodes[wp_name] = Waypoint(direction='R', title = '')
+
+ # Add the wp to the wp list for this flow
+ wp_list.append(wp_name)
+ uid += 1
+
+ # Add the bundle with waypoints if not already existing
+ if (Bundle(node_group[flow['source']],
+ node_group[flow['target']],
+ waypoints=wp_list)) not in bundles:
+ bundles.append(Bundle(node_group[flow['source']],
+ node_group[flow['target']],
+ waypoints=wp_list))
+
+ # For flows from a node to itself:
+ elif (Bundle(node_group[flow['source']],
+ node_group[flow['target']],
+ flow_selection = fs)) not in bundles and (flow['source'] == flow['target']):
+ bundles.append(Bundle(node_group[flow['source']],
+ node_group[flow['target']],
+ flow_selection = fs))
+
+ # ELSE: if not a reverse or a long flow its going to be a normal short flow. Simple bundle
+ elif (flow['source'] != flow['target']) and (Bundle(node_group[flow['source']],
+ node_group[flow['target']]) not in bundles):
+ bundles.append(Bundle(node_group[flow['source']],
+ node_group[flow['target']]))
+
+ return ordering, nodes, bundles
\ No newline at end of file
diff --git a/floweaver/diagram_optimisation.py b/floweaver/diagram_optimisation.py
new file mode 100644
index 0000000..c8cc416
--- /dev/null
+++ b/floweaver/diagram_optimisation.py
@@ -0,0 +1,1023 @@
+# Module containing all the functions for Floweaver SDD optimisation
+from mip import *
+from functools import cmp_to_key
+from attr import evolve
+import statistics
+import time
+from ipysankeywidget import SankeyWidget
+from ipywidgets import Layout, Output
+from .sankey_data import SankeyLayout
+
+# Function that returns the inputs required for the optimisation model to function
+def model_inputs(sankey_data, group_nodes = False):
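+    """Extract the node-order optimisation inputs from `sankey_data`.
+
+    Returns a dict with 'node_layer_set', 'node_band_set', 'edges', 'exit_edges',
+    'return_edges' and 'edge_weight', plus 'groups' and 'group_ordering' when
+    `group_nodes` is True.
+    """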
+
+ ## Create the node band/layer sets for the model and a dictionary of node:{layer,band}
+ order = sankey_data.ordering.layers
+ node_layer_set = [ [] for i in range(len(order))]
+ node_band_set = [ [ [] for i in range(len(order)) ] for i in range(len(order[0])) ]
+ node_dict = {}
+ for i in range(len(order)):
+ for j in range(len(order[i])):
+ for k in order[i][j]:
+ # Append in correct locations
+ node_layer_set[i].append(k)
+ node_band_set[j][i].append(k)
+ # Add to the node_dict in correct location
+ node_dict[k] = {'layer':i, 'band':j}
+
+ ## Now need to create all the edge sets (main, exit, return)
+ flows = sankey_data.links
+ nodes = sankey_data.nodes
+ edges = [ [] for i in range(len(order))] # Set of main edges by layer
+ exit_edges = [ [] for i in range(len(order))] # Set of exit edges by layer
+ return_edges = [ [] for i in range(len(order))] # Set of main edges by layer
+ edge_weight = {} # Empty dict for edge weights
+
+ # Create a node_dir dictionary containing the node directions
+ node_dir = {}
+ for node in nodes:
+ node_dir[node.id] = node.direction
+
+ for flow in flows:
+
+ sl = node_dict[flow.source]['layer'] # save source layer to variable
+ tl = node_dict[flow.target]['layer'] # save target layer to variable
+
+ # FIRST CONDITION: If the nodes are in the same layer then exit or return edge
+ if sl == tl:
+
+ # If the source node has a direction of 'L' then it will be a return node
+ if node_dir[flow.source] == 'L':
+ return_edges[sl].append((flow.source,flow.target))
+ edge_weight[(flow.source,flow.target)] = flow.link_width
+ # If the source node has a direction of 'R' then it will be an exit node
+ else:
+ exit_edges[sl].append((flow.source,flow.target))
+ edge_weight[(flow.source,flow.target)] = flow.link_width
+
+ else: # If not return/exit then just a normal edge to add to edges main
+
+ # BUT need to have the lower layer node first so use if statements
+ if sl < tl:
+ edges[sl].append((flow.source,flow.target))
+ edge_weight[(flow.source,flow.target)] = flow.link_width
+ else:
+ edges[tl].append((flow.target,flow.source))
+ edge_weight[(flow.target,flow.source)] = flow.link_width
+
+ # Wrap all the lists etc into a model inputs dictionary
+ model_inputs = {
+ 'node_layer_set': node_layer_set,
+ 'node_band_set': node_band_set,
+ 'edges': edges,
+ 'exit_edges': exit_edges,
+ 'return_edges': return_edges,
+ 'edge_weight': edge_weight
+ }
+
+ # If the nodes are being grouped:
+ if group_nodes:
+
+ # Create the group_ordering list
+ group_ordering = [ [] for layer in order ]
+ groups = {}
+
+        # Loop through all the layers in the order; a node id ending in '*' is not part of a group.
+        # Add each group to the group ordering and build up the groups by splitting node ids on the caret ('^').
+ # Loop through all the layer indices
+ for i in range(len(order)):
+
+ # Loop through each band in each layer
+ for band in order[i]:
+
+ # Loop through each node within the band:
+ for node in band:
+
+ # Create temp variable of the node split
+ temp = node.split('^')
+ # If the second item in the list is a * then its not part of a group, can ignore
+ if temp[1] != '*':
+
+ # If the group not already in groups dictionary
+ if temp[0] not in groups.keys():
+ groups[temp[0]] = []
+
+ # Add the node to the list
+ groups[temp[0]].append(node)
+
+ # If the group not in the ordering, add it to the ordering
+ if temp[0] not in group_ordering[i]:
+ group_ordering[i].append(temp[0])
+
+ # Add the two new model parameters to the model dict
+ model_inputs['groups'] = groups
+ model_inputs['group_ordering'] = group_ordering
+
+ return model_inputs
+
+## Function that takes in the inputs and optimises the model
+def optimise_node_order_model(model_inputs, group_nodes = False):
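+    """Build and solve the crossing-minimisation MIP for the given model inputs.
+
+    Returns the optimised node order as a nested list arranged by layer and band.
+    """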
+
+    # Raise an error if grouping is requested but the grouping inputs are missing
+    if group_nodes and ('group_ordering' not in model_inputs or 'groups' not in model_inputs):
+        raise Exception("The provided model input does not contain the 'groups' and 'group_ordering' keys")
+
+ ### Define the model
+ m = Model("sankey")
+
+ # Unpack the model input dictionary
+ node_layer_set = model_inputs['node_layer_set']
+ node_band_set = model_inputs['node_band_set']
+ edges = model_inputs['edges']
+ exit_edges = model_inputs['exit_edges']
+ return_edges = model_inputs['return_edges']
+ edge_weight = model_inputs['edge_weight']
+
+ # Create a list of all the node pairings in each layer
+ pairs_by_layer = [[ (u1,u2) for u1 in layer
+ for u2 in layer
+ if u1 != u2 ]
+ for layer in node_layer_set ]
+
+ ### Binary Decision Variables Section
+
+ # Create a dictionary of binary decision variables called 'x' containing the relative positions of the nodes in a layer
+ x = { k: m.add_var(var_type=BINARY) for layer in pairs_by_layer for k in layer }
+
+ # If utilising group_nodes then execute the following code
+ if group_nodes:
+
+ group_ordering = model_inputs['group_ordering']
+ groups = model_inputs['groups']
+
+ # Create a list of all the y binary variables (regarding the relative position of nodes to node groups)
+ node_group_pairs = [ [] for layer in node_layer_set ]
+
+ # The group_ordering is done by LAYER only - just like node_layer_set.
+ for i in range(len(node_layer_set)):
+ for U in group_ordering[i]:
+ for u2 in node_layer_set[i]:
+ # Only add the pairing IF the node, u2 is not in the group U.
+ if u2 not in groups[U]:
+ node_group_pairs[i].append((U,u2))
+
+ # Now generate all the binary variables 'y' for the relative position of node_groups and nodes
+ y = { k: m.add_var(var_type=BINARY) for layer in node_group_pairs for k in layer }
+
+ # Create a dictionary of binary decision variables called 'c' containing whether any two edges cross
+ c_main_main = { (u1v1,u2v2): m.add_var(var_type=BINARY) for Ek in edges for u1v1 in Ek for u2v2 in Ek
+ if u1v1 != u2v2
+ }
+
+ # Dictionary for binary decision variables for an 'exit' flow crossing with a 'forward' flow
+ c_exit_forward = { (u1v1,u2wp): m.add_var(var_type=BINARY) for Ek in edges for Ee in exit_edges
+ # Check if the edges are in the same layer or not
+ if edges.index(Ek) == exit_edges.index(Ee)
+ for u1v1 in Ek for u2wp in Ee
+ # Ignore edges from the same starting node 'u'
+ if u1v1[0] != u2wp[0]
+ }
+
+ # Dictionary of binary decision variables for the crossing of two 'exit' flows
+ c_exit_exit = { (u1wp1,u2wp2): m.add_var(var_type=BINARY) for Ee in exit_edges for u1wp1 in Ee for u2wp2 in Ee
+ # Do not add variable for a flow crossing itself
+ if u1wp1 != u2wp2
+ }
+
+ # Dictionary of binary decision variables for the crossing of return and forward flows
+ c_return_forward = { (u1v1,wpv2): m.add_var(var_type=BINARY) for Ek in edges for Er in return_edges
+ # Check if the return flow is one layer in front of the forward flow
+ if edges.index(Ek) + 1 == return_edges.index(Er)
+ for u1v1 in Ek
+ for wpv2 in Er
+ # Ignore edges to the same 'v' node
+ if u1v1[1] != wpv2[1]
+ }
+
+ # Dictionary of binary decision variables for the crossing of two 'return' flows
+ c_return_return = { (wp1v1,wp2v2): m.add_var(var_type=BINARY) for Er in return_edges for wp1v1 in Er for wp2v2 in Er
+ # Do not add variable for a flow crossing itself
+ if wp1v1 != wp2v2
+ }
+
+ # Objective Function
+
+    # The objective function in full; it may need to be modified later
+
+ m.objective = minimize( # Area of main edge crossings
+ xsum(edge_weight[u1v1]*edge_weight[u2v2]*c_main_main[u1v1,u2v2]
+ for (u1v1,u2v2) in c_main_main.keys()) +
+ # Area of crossings between exit and main edges
+ xsum(edge_weight[u1v1]*edge_weight[u2wp]*c_exit_forward[u1v1,u2wp]
+ for (u1v1,u2wp) in c_exit_forward.keys()) +
+ # Area of crossings between exit edges
+ xsum(edge_weight[u1wp1]*edge_weight[u2wp2]*c_exit_exit[u1wp1,u2wp2]
+ for (u1wp1,u2wp2) in c_exit_exit.keys()) +
+ # Area of crossings between return and main edges
+ xsum(edge_weight[u1v1]*edge_weight[wpv2]*c_return_forward[u1v1,wpv2]
+ for (u1v1,wpv2) in c_return_forward.keys()) +
+ # Area of crossings between return edges
+ xsum(edge_weight[wp1v1]*edge_weight[wp2v2]*c_return_return[wp1v1,wp2v2]
+ for (wp1v1,wp2v2) in c_return_return.keys())
+ )
+
+ ### Constraints section, the following cells will contain all the constraints to be added to the model
+
+ # If grouping nodes generate the required constraints
+ if group_nodes:
+
+ #########################################
+ for i in range(len(node_layer_set)):
+ for u1 in node_layer_set[i]:
+
+ # First figure out what group u1 is in
+ U = ''
+ for group in groups:
+ if u1 in groups[group]:
+ U = group
+
+ for u2 in node_layer_set[i]:
+
+ if U: # Check if U is an empty string, meaning not in a group
+
+ # Apply the constraint ONLY if u2 not in U
+ if u2 not in groups[U]:
+
+ # Add the constraint
+ m += (y[U,u2] == x[u1,u2])
+
+ ## Constraints for the ordering variables 'x'
+ layer_index = 0
+ for layer in node_layer_set:
+ for u1 in layer:
+ for u2 in layer:
+ # Do not refer a node to itself
+ if u1 != u2:
+ # x is Binary, either u1 above u2 or u2 above u1 (total of the two 'x' values must be 1)
+ m += (x[u1,u2] + x[u2,u1] == 1)
+
+ ## Band constraints
+ # return the relative band positions of u1 and u2
+ for band in node_band_set:
+ # Find the band index for u1 and u2
+ if u1 in band[layer_index]:
+ u1_band = node_band_set.index(band)
+ if u2 in band[layer_index]:
+ u2_band = node_band_set.index(band)
+ # Determine 'x' values based off the band indices (note 0 is the highest band)
+ if u1_band < u2_band:
+ m += (x[u1,u2] == 1)
+ elif u1_band > u2_band:
+ m += (x[u1,u2] == 0)
+ # No else constraint necessary
+
+ ## Transitivity Constraints
+ for u3 in layer:
+ if u1 != u3 and u2 != u3:
+ m += (x[u3,u1] >= x[u3,u2] + x[u2,u1] - 1)
+ # Increment the current layer by 1
+ layer_index += 1
+
+ ## Constraints for c_main_main
+ for Ek in edges:
+ for (u1,v1) in Ek:
+ for (u2,v2) in Ek:
+ # Only consider 'c' values for crossings where the edges are not the same and the start/end nodes are different
+ if (u1,v1) != (u2,v2) and u1 != u2 and v1 != v2:
+ m += (c_main_main[(u1,v1),(u2,v2)] + x[u2,u1] + x[v1,v2] >= 1)
+ m += (c_main_main[(u1,v1),(u2,v2)] + x[u1,u2] + x[v2,v1] >= 1)
+
+    ## Constraints for c_exit_forward
+ for Ek in edges:
+ for Ee in exit_edges:
+ # Only consider the combinations of edges where the edges are in the same layer
+ if edges.index(Ek) == exit_edges.index(Ee):
+ for (u1,v1) in Ek:
+ for (u2,wp) in Ee:
+ # Only consider 'c' values for the crossings where the starting nodes is NOT the same
+ if u1 != u2 and u1 != wp:
+ m += (c_exit_forward[(u1,v1),(u2,wp)] + x[u2,u1] + x[u1,wp] >= 1)
+ m += (c_exit_forward[(u1,v1),(u2,wp)] + x[u1,u2] + x[wp,u1] >= 1)
+
+ ## Constraints for c_exit_exit
+ for Ee in exit_edges:
+ for (u1,wp1) in Ee:
+ for (u2,wp2) in Ee:
+ # Only consider 'c' values for the crossings where the start and waypoints are not the same
+ if u1 != u2 and wp1 != wp2:
+ m += (c_exit_exit[(u1,wp1),(u2,wp2)] + x[u1,u2] + x[u2,wp1] + x[wp1,wp2] >= 1)
+ m += (c_exit_exit[(u1,wp1),(u2,wp2)] + x[u2,u1] + x[wp1,u2] + x[wp2,wp1] >= 1)
+ m += (c_exit_exit[(u1,wp1),(u2,wp2)] + x[u1,wp2] + x[wp2,wp1] + x[wp1,u2] >= 1)
+ m += (c_exit_exit[(u1,wp1),(u2,wp2)] + x[wp2,u1] + x[wp1,wp2] + x[u2,wp1] >= 1)
+ m += (c_exit_exit[(u1,wp1),(u2,wp2)] + x[wp1,u2] + x[u2,u1] + x[u1,wp2] >= 1)
+ m += (c_exit_exit[(u1,wp1),(u2,wp2)] + x[u2,wp1] + x[u1,u2] + x[wp2,u1] >= 1)
+ m += (c_exit_exit[(u1,wp1),(u2,wp2)] + x[wp1,wp2] + x[wp2,u1] + x[u1,u2] >= 1)
+ m += (c_exit_exit[(u1,wp1),(u2,wp2)] + x[wp2,wp1] + x[u1,wp2] + x[u2,u1] >= 1)
+
+ ## Constraints for c_return_forward
+ for Ek in edges:
+ for Er in return_edges:
+ # Only consider 'c' values if the return flow is one layer in front of the forward flow
+ if edges.index(Ek) + 1 == return_edges.index(Er):
+ for (u1,v1) in Ek:
+ for (wp,v2) in Er:
+ # Only consider values where the final nodes are not the same
+ # AND the final node of the main flow is not the waypoint
+ if v1 != v2 and v1 != wp:
+ m += (c_return_forward[(u1,v1),(wp,v2)] + x[v2,v1] + x[v1,wp] >= 1)
+ m += (c_return_forward[(u1,v1),(wp,v2)] + x[v1,v2] + x[wp,v1] >= 1)
+
+ ## Constraints for c_return_return
+ for Er in return_edges:
+ for (wp1,v1) in Er:
+ for (wp2,v2) in Er:
+ # Only consider edges where the waypoint and end nodes are not the same
+ if wp1 != wp2 and v1 != v2:
+ m += (c_return_return[(wp1,v1),(wp2,v2)] + x[v1,v2] + x[v2,wp1] + x[wp1,wp2] >= 1)
+ m += (c_return_return[(wp1,v1),(wp2,v2)] + x[v2,v1] + x[wp1,v2] + x[wp2,wp1] >= 1)
+ m += (c_return_return[(wp1,v1),(wp2,v2)] + x[v1,wp2] + x[wp2,wp1] + x[wp1,v2] >= 1)
+ m += (c_return_return[(wp1,v1),(wp2,v2)] + x[wp2,v1] + x[wp1,wp2] + x[v2,wp1] >= 1)
+ m += (c_return_return[(wp1,v1),(wp2,v2)] + x[wp1,v2] + x[v2,v1] + x[v1,wp2] >= 1)
+ m += (c_return_return[(wp1,v1),(wp2,v2)] + x[v2,wp1] + x[v1,v2] + x[wp2,v1] >= 1)
+ m += (c_return_return[(wp1,v1),(wp2,v2)] + x[wp1,wp2] + x[wp2,v1] + x[v1,v2] >= 1)
+ m += (c_return_return[(wp1,v1),(wp2,v2)] + x[wp2,wp1] + x[v1,wp2] + x[v2,v1] >= 1)
+
+ ### Optimise the Model using a ILP Solver
+ start_time = time.time()
+ status = m.optimize(max_seconds=20)
+ end_time = time.time()
+ runtime = end_time - start_time
+ print(runtime)
+
+ ### Define a function that decodes the solution (i.e. compares nodes in a layer)
+
+ def cmp_nodes(u1,u2):
+        # If the optimised x is >= 0.99 then u1 is above u2 - thus u1 comes first
+ if x[u1,u2].x >= 0.99:
+ return -1
+ else:
+ return 1
+
+ ### Return Solution
+
+ # Optimised node order arranged in layers
+ sorted_order = [ sorted(layer,key=cmp_to_key(cmp_nodes)) for layer in node_layer_set ]
+
+ # Optimised order arranged in layers and bands
+ banded_order = [[] for i in range(len(node_layer_set))]
+
+ for i in range(len(node_layer_set)):
+ start_index = 0
+ for band in node_band_set:
+ end_index = len(band[i]) + start_index
+ banded_order[i].append(sorted_order[i][start_index:end_index])
+ start_index = end_index
+
+ return banded_order
+
+
+def optimise_node_order(sankey_data, group_nodes=False):
+ """Optimise node order to avoid flows crossings.
+
+ Returns new version of `sankey_data` with updated `ordering`.
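+
+    Example (as used in the layout-optimisation cookbook notebooks)::
+
+        sankey_data = weave(sdd, dataset)
+        sankey_data_evolved = optimise_node_order(sankey_data, group_nodes=True)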
+ """
+
+ model = model_inputs(sankey_data, group_nodes=group_nodes)
+ opt_order = optimise_node_order_model(model, group_nodes=group_nodes)
+ new_sankey_data = evolve(sankey_data, ordering=opt_order)
+ return new_sankey_data
+
+
+# Create a function that creates all the required inputs for the straightness optimisation model
+def straightness_model(sankey_data):
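+    """Extract the inputs for the straightness (node position) optimisation model.
+
+    Returns a dict with 'node_layer_set', 'node_band_set', 'edges', 'edge_weight'
+    and 'node_weight'.
+    """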
+
+ ## Create the node_layer_set
+ order = sankey_data.ordering.layers
+ node_layer_set = [ [] for i in range(len(order))]
+ node_band_set = [ [] for i in range(len(order[0])) ]
+ node_dict = {}
+ # loop through and add all the nodes into the node layer set
+ for i in range(len(order)):
+ for j in range(len(order[i])):
+ for k in order[i][j]:
+ # Append in correct locations
+ node_layer_set[i].append(k)
+ node_band_set[j].append(k)
+ # Add to the node_dict in correct location
+ node_dict[k] = {'layer':i, 'band':j, 'w_in':0, 'w_out':0}
+
+ # Create the flows list
+ flows = sankey_data.links
+ # Create the empty edges dictionary
+ edges = []
+ # Create edge weights dictionary
+ edge_weight = {}
+
+ for flow in flows:
+
+ sl = node_dict[flow.source]['layer'] # save source layer to variable
+ tl = node_dict[flow.target]['layer'] # save target layer to variable
+
+ # Ensure we are only considering the forward/main flows
+ if sl < tl:
+ edges.append((flow.source,flow.target))
+ edge_weight[(flow.source,flow.target)] = flow.link_width
+
+    # Determine the 'node weights' by ascertaining the maximum flow either in or out of each node
+ for flow in flows:
+
+ # Calculate the maximum possible weight of each node
+ node_dict[flow.source]['w_out'] += flow.link_width
+ node_dict[flow.target]['w_in'] += flow.link_width
+
+ # Figure out the maximum weight and assign it to a dictionary of node weightings
+ node_weight = {}
+ for node in node_dict:
+ # Assign value of the max weight!
+ node_weight[node] = max(node_dict[node]['w_in'], node_dict[node]['w_out'])
+
+ model_inputs = {
+ 'node_layer_set': node_layer_set,
+ 'node_band_set': node_band_set,
+ 'edges': edges,
+ 'edge_weight': edge_weight,
+ 'node_weight': node_weight
+ }
+
+ return model_inputs
+
+
+# Define a new function for optimising the vertical position
+def optimise_position_model(model_inputs, scale, wslb = 1):
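+    """Solve the straightness MIP and return a dict mapping node id to y-coordinate.
+
+    `wslb` is the lower bound on the white space between adjacent nodes in a layer.
+    """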
+
+ ### Define the model
+ m = Model("sankey")
+
+ # Unpack the model input dictionary
+ node_layer_set = model_inputs['node_layer_set']
+ node_band_set = model_inputs['node_band_set']
+ edges = model_inputs['edges']
+ edge_weight = model_inputs['edge_weight']
+ node_weight = model_inputs['node_weight']
+
+ y = { node: m.add_var(name=f'y[{node}]', var_type=CONTINUOUS)
+ for layer in node_layer_set for node in layer
+ }
+
+ # Create the white space variables
+ d = {}
+ for i in range(len(node_layer_set)):
+
+ # Add the base_line to first node variable
+ d[('b',node_layer_set[i][0])] = m.add_var(var_type=CONTINUOUS, lb = 0)
+
+ # loop through all the pairings
+ for j in range(len(node_layer_set[i])):
+ if j+1 != len(node_layer_set[i]):
+ d[(node_layer_set[i][j],node_layer_set[i][j+1])] = m.add_var(var_type=CONTINUOUS, lb = wslb)
+
+ # Completely straight
+ b = {}
+ M = 10*max(node_weight.values())
+ penalty = statistics.stdev(edge_weight.values())
+
+ # Create all the deviation variables
+ s = {}
+ for edge in edges:
+ s[edge] = m.add_var(var_type = CONTINUOUS)
+ b[edge] = m.add_var(var_type = BINARY)
+
+ # Create a list of all the node pairings in each layer
+ pairs_by_layer = [[ (u1,u2) for u1 in layer
+ for u2 in layer
+ if u1 != u2 ]
+ for layer in node_layer_set ]
+
+    ### Relative node-order values
+    # 'x' records the relative position of each pair of nodes in a layer; the node order is
+    # already fixed at this stage, so these are constants (0/1) rather than binary decision variables
+ x = {} # { k: m.add_var(var_type=BINARY) for layer in pairs_by_layer for k in layer }
+
+ ### Now go through and create the constraints
+
+ ## First create the constraints linking y values to white_spaces and weights
+
+ # Create the list of lists containing all the variables for each node y coord to perform xsum!
+ node_lists = {}
+ for layer in node_layer_set:
+
+ # Loop through all the nodes in the layer and do it accordingly
+ for i, node in enumerate(layer):
+ node_lists[node] = []
+ # All nodes require the baseline spacing
+ node_lists[node].append(d[('b',layer[0])])
+ if i != 0:
+ # If not the first node, need to add whitespace for all prior node pairs and prior node weights
+ for j in range(i):
+ # If i+1 is in range
+ #if j+1 != len(node_layer_set[i]):
+ if j+1 != len(layer):
+ # For each node up to i add the weight
+ node_lists[node].append(node_weight[layer[j]]*scale)
+ node_lists[node].append(d[(layer[j],layer[j+1])])
+ # Now the list has been assembled add the constraint!
+ m += (y[node] == xsum(node_lists[node][i] for i in range(len(node_lists[node]))))
+
+ ## Constraints for the ordering variables 'x'
+ layer_index = 0
+ for layer in node_layer_set:
+ for u1 in layer:
+ for u2 in layer:
+ # Do not refer a node to itself
+ if u1 != u2:
+ # x is Binary, either u1 above u2 or u2 above u1 (total of the two 'x' values must be 1)
+ #m += (x[u1,u2] + x[u2,u1] == 1)
+
+ u1_pos = node_layer_set[layer_index].index(u1)
+ u2_pos = node_layer_set[layer_index].index(u2)
+
+ # Determine 'x' values based off the node position (note 0 is the highest)
+ x[u1,u2] = 1 if u1_pos < u2_pos else 0
+
+ # Increment the current layer by 1
+ layer_index += 1
+
+ ## Create all the straightness constraints
+ # Loop through all the edges and add the two required constraints
+ for (u,v) in edges:
+ index_v = [i for i, layers in enumerate(node_layer_set) if v in layers]
+ index_u = [i for i, layers in enumerate(node_layer_set) if u in layers]
+ layer_v = [j for j in node_layer_set[index_v[0]] if (u,j) in edges and j != v]
+ layer_u = [j for j in node_layer_set[index_u[0]] if (j,v) in edges and j != u]
+
+ y_start = y[u] + xsum(edge_weight[(u,w)]*x[w,v]*scale for w in layer_v)
+ y_end = y[v] + xsum(edge_weight[(w,v)]*x[w,u]*scale for w in layer_u)
+
+ m += y_start - y_end <= M * b[(u,v)]
+ m += y_end - y_start <= M * b[(u,v)]
+ m += (s[(u,v)] >= y_start - y_end)
+ m += (s[(u,v)] >= -(y_start - y_end))
+
+ ## Create all the band constraints (ie higher bands above lower bands)
+
+ # Loop through the node_band_set and add all the nodes accordingly
+ # First loop through
+ for i, bandu in enumerate(node_band_set):
+ for u in bandu:
+
+ # Now for each 'u' node loop through all the other nodes
+ for j, bandv in enumerate(node_band_set):
+ for v in bandv:
+
+ # Only add the constraint if the second band is greater than the first
+ if j > i:
+ m += (y[v] >= y[u] + node_weight[u]*scale)
+
+ ### OBJECTIVE FUNCTION: minimise, for each edge, (deviation + fixed penalty when not straight) * (scaled flow width)^2
+ m.objective = minimize( xsum((s[edge] + penalty*scale*edge_weight[edge]*b[edge]) * (scale*edge_weight[edge])**2 for edge in s.keys()))
+
+ start_time = time.time()
+ # Run the model and optimise!
+ status = m.optimize(max_seconds=10)
+
+ end_time = time.time()
+ runtime = end_time - start_time
+ print(runtime)
+
+ ### Decode the solution by running through and creating simplified dictionary
+ y_coordinates = {}
+ for node in y:
+ y_coordinates[node] = y[node].x
+ return y_coordinates
+
+
+def optimise_node_positions(sankey_data,
+ width=None,
+ height=None,
+ margins=None,
+ scale=None,
+ minimum_gap=10):
+ """Optimise node positions to maximise straightness.
+
+ Returns a `SankeyLayout` with `node_positions` set for the nodes in `sankey_data`.
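+
+ A minimal usage sketch (assuming `sankey_data` is a `SankeyData` instance):
+
+ layout = optimise_node_positions(sankey_data, minimum_gap=10)
+ widget = sankey_data.to_widget(layout=layout)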
+ """
+
+ # Apply default margins if not specified
+ if margins is None:
+ margins = {}
+ margins = {
+ "top": 50,
+ "bottom": 15,
+ "left": 130,
+ "right": 130,
+ **margins,
+ }
+
+ model = straightness_model(sankey_data)
+
+ if scale is None:
+ # FIXME can optimise this too, if not specified? Or calculate from
+ # `height` and `minimum_gap`, if specified.
+ scale = 1
+
+ # Optimise the y-coordinates of the nodes
+ ys = optimise_position_model(model, scale, wslb=minimum_gap)
+ ys = {k: y + margins['top'] for k, y in ys.items()}
+
+ # Work out appropriate diagram height, if not specified explicitly
+ if height is None:
+ max_y1 = max(y0 + model['node_weight'][k] for k, y0 in ys.items())
+ height = max_y1 + margins['bottom']
+
+ # X-coordinates
+
+ n_layers = len(sankey_data.ordering.layers)
+
+ # Work out an appropriate diagram width, if not specified explicitly
+ if width is None:
+ # FIXME this could be smarter, and consider how much curvature there is:
+ # if all flows are thin or relatively straight, the layers can be closer
+ # together.
+ width = 150 * (n_layers - 1) + margins['left'] + margins['right']
+
+ # Work out the horizontal space available once the left and right margins are removed
+ max_w = max(0, width - margins['left'] - margins['right'])
+ xs = {
+ node_id: margins['left'] + i / (n_layers - 1) * max_w
+ for i, layer in enumerate(sankey_data.ordering.layers)
+ for band in layer
+ for node_id in band
+ }
+
+ # Overall layout
+ node_positions = {
+ node.id: [xs[node.id], ys[node.id]]
+ for node in sankey_data.nodes
+ }
+ layout = SankeyLayout(width=width, height=height, scale=scale, node_positions=node_positions)
+ return layout
+
+
+# Code for running the multi-objective MIP model
+def optimise_hybrid_model(straightness_model,
+ crossing_model,
+ group_nodes = False,
+ wslb = 1,
+ wsub = 10,
+ crossing_weight = 0.5,
+ straightness_weight = 0.5):
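+ """Jointly minimise weighted link crossings and deviation from straightness.
+
+ `crossing_model` and `straightness_model` are model-input dictionaries giving
+ the layer/band structure, edges and weights; `crossing_weight` and
+ `straightness_weight` set the relative importance of the two objectives.
+ Returns the optimised node order (arranged by layer and band) together with
+ the node y-coordinates.
+ """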
+
+ ### Define the model
+ m = Model("sankey")
+
+ ##########################################################################################################
+ # MINIMISE THE CROSSINGS MODEL
+ ##########################################################################################################
+
+ # Raise an error if grouping was requested but the model input lacks the group data
+ if group_nodes and ('group_ordering' not in crossing_model or 'groups' not in crossing_model):
+ raise Exception("The provided model input does not contain the keys 'group_ordering' and 'groups'")
+
+ # Unpack the model input dictionary
+ node_layer_set = crossing_model['node_layer_set']
+ node_band_set = crossing_model['node_band_set']
+ edges = crossing_model['edges']
+ exit_edges = crossing_model['exit_edges']
+ return_edges = crossing_model['return_edges']
+ edge_weight = crossing_model['edge_weight']
+
+ # Create a list of all the node pairings in each layer
+ pairs_by_layer = [[ (u1,u2) for u1 in layer
+ for u2 in layer
+ if u1 != u2 ]
+ for layer in node_layer_set ]
+
+ ### Binary Decision Variables Section
+
+ # Create a dictionary of binary decision variables called 'x' containing the relative positions of the nodes in a layer
+ x = { k: m.add_var(var_type=BINARY) for layer in pairs_by_layer for k in layer }
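+ # x[u1,u2] == 1 means u1 is placed above u2 within their layer (position 0 is the top)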
+
+ # If utilising group_nodes then execute the following code
+ if group_nodes:
+
+ group_ordering = crossing_model['group_ordering']
+ groups = crossing_model['groups']
+
+ # Create a list of all the (group, node) pairs needed for the binary variables giving the relative position of nodes to node groups
+ node_group_pairs = [ [] for layer in node_layer_set ]
+
+ # The group_ordering is done by LAYER only - just like node_layer_set.
+ for i in range(len(node_layer_set)):
+ for U in group_ordering[i]:
+ for u2 in node_layer_set[i]:
+ # Only add the pairing IF the node, u2 is not in the group U.
+ if u2 not in groups[U]:
+ node_group_pairs[i].append((U,u2))
+
+ # Now generate all the binary variables 'g' for the relative position of node groups and nodes
+ g = { k: m.add_var(var_type=BINARY) for layer in node_group_pairs for k in layer }
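+ # g[U,u2] == 1 means every node of group U lies above node u2; tying g to the
+ # individual x values (see the constraints below) keeps each group's nodes on
+ # the same side of any node outside the group.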
+
+ # Create a dictionary of binary decision variables called 'c' containing whether any two edges cross
+ c_main_main = { (u1v1,u2v2): m.add_var(var_type=BINARY) for Ek in edges for u1v1 in Ek for u2v2 in Ek
+ if u1v1 != u2v2
+ }
+
+ # Dictionary for binary decision variables for an 'exit' flow crossing with a 'forward' flow
+ c_exit_forward = { (u1v1,u2wp): m.add_var(var_type=BINARY) for Ek in edges for Ee in exit_edges
+ # Check if the edges are in the same layer or not
+ if edges.index(Ek) == exit_edges.index(Ee)
+ for u1v1 in Ek for u2wp in Ee
+ # Ignore edges from the same starting node 'u'
+ if u1v1[0] != u2wp[0]
+ }
+
+ # Dictionary of binary decision variables for the crossing of two 'exit' flows
+ c_exit_exit = { (u1wp1,u2wp2): m.add_var(var_type=BINARY) for Ee in exit_edges for u1wp1 in Ee for u2wp2 in Ee
+ # Do not add variable for a flow crossing itself
+ if u1wp1 != u2wp2
+ }
+
+ # Dictionary of binary decision variables for the crossing of return and forward flows
+ c_return_forward = { (u1v1,wpv2): m.add_var(var_type=BINARY) for Ek in edges for Er in return_edges
+ # Check if the return flow is one layer in front of the forward flow
+ if edges.index(Ek) + 1 == return_edges.index(Er)
+ for u1v1 in Ek
+ for wpv2 in Er
+ # Ignore edges to the same 'v' node
+ if u1v1[1] != wpv2[1]
+ }
+
+ # Dictionary of binary decision variables for the crossing of two 'return' flows
+ c_return_return = { (wp1v1,wp2v2): m.add_var(var_type=BINARY) for Er in return_edges for wp1v1 in Er for wp2v2 in Er
+ # Do not add variable for a flow crossing itself
+ if wp1v1 != wp2v2
+ }
+
+ ### Constraints section, the following cells will contain all the constraints to be added to the model
+
+ # If grouping nodes generate the required constraints
+ if group_nodes:
+
+ for i in range(len(node_layer_set)):
+ for u1 in node_layer_set[i]:
+
+ # First figure out what group u1 is in
+ U = ''
+ for group in groups:
+ if u1 in groups[group]:
+ U = group
+
+ for u2 in node_layer_set[i]:
+
+ if U: # Only proceed if u1 belongs to a group (U is non-empty)
+
+ # Apply the constraint ONLY if u2 not in U
+ if u2 not in groups[U]:
+
+ # Add the constraint
+ m += (g[U,u2] == x[u1,u2])
+
+ ## Constraints for the ordering variables 'x'
+ layer_index = 0
+ for layer in node_layer_set:
+ for u1 in layer:
+ for u2 in layer:
+ # Do not refer a node to itself
+ if u1 != u2:
+ # x is Binary, either u1 above u2 or u2 above u1 (total of the two 'x' values must be 1)
+ m += (x[u1,u2] + x[u2,u1] == 1)
+
+ ## Band constraints
+ # return the relative band positions of u1 and u2
+ for band in node_band_set:
+ # Find the band index for u1 and u2
+ if u1 in band[layer_index]:
+ u1_band = node_band_set.index(band)
+ if u2 in band[layer_index]:
+ u2_band = node_band_set.index(band)
+ # Determine 'x' values based off the band indices (note 0 is the highest band)
+ if u1_band < u2_band:
+ m += (x[u1,u2] == 1)
+ elif u1_band > u2_band:
+ m += (x[u1,u2] == 0)
+ # No else constraint necessary
+
+ ## Transitivity Constraints
+ for u3 in layer:
+ if u1 != u3 and u2 != u3:
+ m += (x[u3,u1] >= x[u3,u2] + x[u2,u1] - 1)
+ # Increment the current layer by 1
+ layer_index += 1
+
+ ## Constraints for c_main_main
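+ # Two forward edges (u1,v1) and (u2,v2) cross exactly when their endpoints are in
+ # opposite vertical order (u1 above u2 but v1 below v2, or vice versa); in either
+ # case one of the constraints below forces the crossing indicator to 1.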
+ for Ek in edges:
+ for (u1,v1) in Ek:
+ for (u2,v2) in Ek:
+ # Only consider 'c' values for crossings where the edges are not the same and the start/end nodes are different
+ if (u1,v1) != (u2,v2) and u1 != u2 and v1 != v2:
+ m += (c_main_main[(u1,v1),(u2,v2)] + x[u2,u1] + x[v1,v2] >= 1)
+ m += (c_main_main[(u1,v1),(u2,v2)] + x[u1,u2] + x[v2,v1] >= 1)
+
+ ## Constraints for c_exit_forward
+ for Ek in edges:
+ for Ee in exit_edges:
+ # Only consider the combinations of edges where the edges are in the same layer
+ if edges.index(Ek) == exit_edges.index(Ee):
+ for (u1,v1) in Ek:
+ for (u2,wp) in Ee:
+ # Only consider 'c' values for the crossings where the starting nodes is NOT the same
+ if u1 != u2:
+ m += (c_exit_forward[(u1,v1),(u2,wp)] + x[u2,u1] + x[u1,wp] >= 1)
+ m += (c_exit_forward[(u1,v1),(u2,wp)] + x[u1,u2] + x[wp,u1] >= 1)
+
+ ## Constraints for c_exit_exit
+ for Ee in exit_edges:
+ for (u1,wp1) in Ee:
+ for (u2,wp2) in Ee:
+ # Only consider 'c' values for the crossings where the start and waypoints are not the same
+ if u1 != u2 and wp1 != wp2:
+ m += (c_exit_exit[(u1,wp1),(u2,wp2)] + x[u1,u2] + x[u2,wp1] + x[wp1,wp2] >= 1)
+ m += (c_exit_exit[(u1,wp1),(u2,wp2)] + x[u2,u1] + x[wp1,u2] + x[wp2,wp1] >= 1)
+ m += (c_exit_exit[(u1,wp1),(u2,wp2)] + x[u1,wp2] + x[wp2,wp1] + x[wp1,u2] >= 1)
+ m += (c_exit_exit[(u1,wp1),(u2,wp2)] + x[wp2,u1] + x[wp1,wp2] + x[u2,wp1] >= 1)
+ m += (c_exit_exit[(u1,wp1),(u2,wp2)] + x[wp1,u2] + x[u2,u1] + x[u1,wp2] >= 1)
+ m += (c_exit_exit[(u1,wp1),(u2,wp2)] + x[u2,wp1] + x[u1,u2] + x[wp2,u1] >= 1)
+ m += (c_exit_exit[(u1,wp1),(u2,wp2)] + x[wp1,wp2] + x[wp2,u1] + x[u1,u2] >= 1)
+ m += (c_exit_exit[(u1,wp1),(u2,wp2)] + x[wp2,wp1] + x[u1,wp2] + x[u2,u1] >= 1)
+
+ ## Constraints for c_return_forward
+ for Ek in edges:
+ for Er in return_edges:
+ # Only consider 'c' values if the return flow is one layer in front of the forward flow
+ if edges.index(Ek) + 1 == return_edges.index(Er):
+ for (u1,v1) in Ek:
+ for (wp,v2) in Er:
+ # Only consider values where the final nodes are not the same
+ # AND the final node of the main flow is not the waypoint
+ if v1 != v2 and v1 != wp:
+ m += (c_return_forward[(u1,v1),(wp,v2)] + x[v2,v1] + x[v1,wp] >= 1)
+ m += (c_return_forward[(u1,v1),(wp,v2)] + x[v1,v2] + x[wp,v1] >= 1)
+
+ ## Constraints for c_return_return
+ for Er in return_edges:
+ for (wp1,v1) in Er:
+ for (wp2,v2) in Er:
+ # Only consider edges where the waypoint and end nodes are not the same
+ if wp1 != wp2 and v1 != v2:
+ m += (c_return_return[(wp1,v1),(wp2,v2)] + x[v1,v2] + x[v2,wp1] + x[wp1,wp2] >= 1)
+ m += (c_return_return[(wp1,v1),(wp2,v2)] + x[v2,v1] + x[wp1,v2] + x[wp2,wp1] >= 1)
+ m += (c_return_return[(wp1,v1),(wp2,v2)] + x[v1,wp2] + x[wp2,wp1] + x[wp1,v2] >= 1)
+ m += (c_return_return[(wp1,v1),(wp2,v2)] + x[wp2,v1] + x[wp1,wp2] + x[v2,wp1] >= 1)
+ m += (c_return_return[(wp1,v1),(wp2,v2)] + x[wp1,v2] + x[v2,v1] + x[v1,wp2] >= 1)
+ m += (c_return_return[(wp1,v1),(wp2,v2)] + x[v2,wp1] + x[v1,v2] + x[wp2,v1] >= 1)
+ m += (c_return_return[(wp1,v1),(wp2,v2)] + x[wp1,wp2] + x[wp2,v1] + x[v1,v2] >= 1)
+ m += (c_return_return[(wp1,v1),(wp2,v2)] + x[wp2,wp1] + x[v1,wp2] + x[v2,v1] >= 1)
+
+ ##########################################################################################################
+ # MAXIMISE THE STRAIGHTNESS
+ ##########################################################################################################
+
+ # Unpack the model input dictionary
+ node_layer_set1 = straightness_model['node_layer_set']
+ node_band_set1 = straightness_model['node_band_set']
+ edges1 = straightness_model['edges']
+ edge_weight1 = straightness_model['edge_weight']
+ node_weight1 = straightness_model['node_weight']
+
+ # Create all the y variables - one for every node
+ y = { node: m.add_var(var_type=CONTINUOUS)
+ for layer in node_layer_set1 for node in layer
+ }
+
+ # Create a list of all the node pairings in each layer
+ pairs = [[ (u1,u2) for u1 in layer
+ for u2 in layer
+ if u1 != u2 ]
+ for layer in node_layer_set1 ]
+
+ # Create continuous variables 'dx' representing the product of the whitespace d[v] and the ordering variable x[v,u], linearised below
+ dx = { k: m.add_var(var_type=CONTINUOUS) for layer in pairs for k in layer }
+
+ # Create the white space variables
+ d = {}
+ for i in range(len(node_layer_set1)):
+
+ # Add the base_line to first node variable
+ d[f'b{i}'] = m.add_var(var_type=CONTINUOUS, lb = 0)
+
+ # loop through all the nodes
+ for node in node_layer_set1[i]:
+ d[node] = m.add_var(var_type=CONTINUOUS, lb = wslb, ub = wsub)
+
+ # Create all the deviation variables
+ s = {}
+ for edge in edges1:
+ s[edge] = m.add_var(var_type=CONTINUOUS)
+
+ ### Now go through and create the constraints
+
+ ## First create the constraints linking y values to white_spaces and weights
+
+ # Loop through all layers and add the constraint
+ for i, layer in enumerate(node_layer_set1):
+
+ # Loop through all the nodes in the layer
+ for u in layer:
+
+ # Loop through and add all the constraints for the dx variable
+ for v in layer:
+ if v != u:
+ # Add all the 4 constraints
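+ # These four constraints linearise dx[v,u] = d[v] * x[v,u], using the lower
+ # and upper bounds wslb and wsub on the whitespace d[v]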
+ m += ( dx[v,u] <= wsub * x[v,u] )
+ m += ( dx[v,u] >= wslb * x[v,u] )
+ m += ( dx[v,u] <= d[v] - wslb * (1-x[v,u]) )
+ m += ( dx[v,u] >= d[v] - wsub * (1-x[v,u]) )
+
+ # y[u] = baseline gap for the layer + sum over the nodes v above u of (height of v + whitespace below v)
+ m += ( d[f'b{i}'] + xsum( (node_weight1[v]*x[v,u] + dx[v,u]) for v in layer if v != u ) == y[u] )
+
+ ## Create all the straightness constraints
+
+ # Loop through all the edges and add the two required constraints
+ for (u,v) in edges1:
+ m += (s[(u,v)] >= y[u] - y[v])
+ m += (s[(u,v)] >= -(y[u] - y[v]))
+
+ ## Create all the band constraints (ie higher bands above lower bands)
+
+ # Loop through the node_band_set and add all the nodes accordingly
+ # First loop through
+ for i, bandu in enumerate(node_band_set1):
+ for u in bandu:
+
+ # Now for each 'u' node loop through all the other nodes
+ for j, bandv in enumerate(node_band_set1):
+ for v in bandv:
+
+ # Only add the constraint if the second band is greater than the first
+ if j > i:
+ m += (y[v] >= y[u] + node_weight1[u])
+
+ #########################################################################################################
+ ### Objective Function
+ #########################################################################################################
+
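+ # Each crossing term is weighted by the product of the two crossing edges' widths
+ # (the "area" of the crossing) and each straightness term by the edge's width, so
+ # wide flows are kept straight and uncrossed in preference to thin ones.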
+ m.objective = minimize( # Area of main edge crossings
+ crossing_weight * (
+ xsum(edge_weight[u1v1]*edge_weight[u2v2]*c_main_main[u1v1,u2v2]
+ for (u1v1,u2v2) in c_main_main.keys()) +
+ # Area of crossings between exit and main edges
+ xsum(edge_weight[u1v1]*edge_weight[u2wp]*c_exit_forward[u1v1,u2wp]
+ for (u1v1,u2wp) in c_exit_forward.keys()) +
+ # Area of crossings between exit edges
+ xsum(edge_weight[u1wp1]*edge_weight[u2wp2]*c_exit_exit[u1wp1,u2wp2]
+ for (u1wp1,u2wp2) in c_exit_exit.keys()) +
+ # Area of crossings between return and main edges
+ xsum(edge_weight[u1v1]*edge_weight[wpv2]*c_return_forward[u1v1,wpv2]
+ for (u1v1,wpv2) in c_return_forward.keys()) +
+ # Area of crossings between return edges
+ xsum(edge_weight[wp1v1]*edge_weight[wp2v2]*c_return_return[wp1v1,wp2v2]
+ for (wp1v1,wp2v2) in c_return_return.keys())
+ ) +
+ straightness_weight * (
+ xsum(s[edge]*edge_weight[edge] for edge in s.keys())
+ )
+ )
+
+ # Run the model and optimise!
+ status = m.optimize(max_solutions = 500)
+
+ #########################################################################################################
+ ### Decode Solution
+ #########################################################################################################
+
+ ### Decode the solution by running through and creating simplified dictionary
+ y_coordinates = {}
+ for node in y:
+ y_coordinates[node] = y[node].x
+
+ ### Define a function that decodes the solution (i.e. compares nodes in a layer)
+
+ def cmp_nodes(u1,u2):
+ # If the optimised x is >= 0.99 then u1 is above u2 - thus u1 comes first
+ if x[u1,u2].x >= 0.99:
+ return -1
+ else:
+ return 1
+
+ ### Return Solution
+
+ # Optimised node order arranged in layers
+ sorted_order = [ sorted(layer,key=cmp_to_key(cmp_nodes)) for layer in node_layer_set ]
+
+ # Optimised order arranged in layers and bands
+ banded_order = [[] for i in range(len(node_layer_set))]
+
+ for i in range(len(node_layer_set)):
+ start_index = 0
+ for band in node_band_set:
+ end_index = len(band[i]) + start_index
+ banded_order[i].append(sorted_order[i][start_index:end_index])
+ start_index = end_index
+
+ return banded_order, y_coordinates
diff --git a/floweaver/sankey_data.py b/floweaver/sankey_data.py
index 6d222eb..51bd02a 100644
--- a/floweaver/sankey_data.py
+++ b/floweaver/sankey_data.py
@@ -21,6 +21,15 @@
_validate_opt_str = attr.validators.optional(attr.validators.instance_of(str))
+@attr.s(slots=True, frozen=True)
+class SankeyLayout:
+ """Visual/geometric properties of a Sankey diagram."""
+ width = attr.ib(float)
+ height = attr.ib(float)
+ scale = attr.ib(default=None)
+ node_positions = attr.ib(default=None)
+
+
@attr.s(slots=True, frozen=True)
class SankeyData(object):
nodes = attr.ib()
@@ -29,11 +38,12 @@ class SankeyData(object):
ordering = attr.ib(converter=_convert_ordering, default=Ordering([[]]))
dataset = attr.ib(default=None)
- def to_json(self, filename=None, format=None):
+ def to_json(self, filename=None, format=None, layout=None):
"""Convert data to JSON-ready dictionary."""
+
if format == "widget":
data = {
- "nodes": [n.to_json(format) for n in self.nodes],
+ "nodes": [n.to_json(format, layout) for n in self.nodes],
"links": [l.to_json(format) for l in self.links],
"order": self.ordering.layers,
"groups": self.groups,
@@ -46,7 +56,7 @@ def to_json(self, filename=None, format=None):
"authors": [],
"layers": self.ordering.layers,
},
- "nodes": [n.to_json(format) for n in self.nodes],
+ "nodes": [n.to_json(format, layout) for n in self.nodes],
"links": [l.to_json(format) for l in self.links],
"groups": self.groups,
}
@@ -59,19 +69,45 @@ def to_json(self, filename=None, format=None):
def to_widget(
self,
- width=700,
- height=500,
+ width=None,
+ height=None,
margins=None,
align_link_types=False,
link_label_format="",
link_label_min_width=5,
debugging=False,
+ layout=None,
):
+ """Convert to an ipysankeywidget SankeyWidget.
+
+ `layout` provides width, height and scale, but can be overridden by the
+ `width` and `height` arguments.
+
+ `margins` are used when automatically laying out the node positions, but
+ are ignored when a `layout` is passed which contains explicit node
+ positions.
+
+ """
if SankeyWidget is None:
raise RuntimeError("ipysankeywidget is required")
- if margins is None:
+ if width is None:
+ width = layout.width if layout is not None else 700
+ if height is None:
+ height = layout.height if layout is not None else 500
+
+ has_positions = layout is not None and layout.node_positions is not None
+
+ if has_positions:
+ # Assume the layout has already accounted for margins as needed
+ margins = {
+ "top": 0,
+ "bottom": 0,
+ "left": 0,
+ "right": 0,
+ }
+ elif margins is None:
margins = {
"top": 25,
"bottom": 10,
@@ -79,7 +115,10 @@ def to_widget(
"right": 130,
}
- value = self.to_json(format="widget")
+ # Convert to JSON format, embedding node positions if specified in
+ # `layout`.
+ value = self.to_json(format="widget", layout=layout)
+
widget = SankeyWidget(
nodes=value["nodes"],
links=value["links"],
@@ -88,10 +127,15 @@ def to_widget(
align_link_types=align_link_types,
linkLabelFormat=link_label_format,
linkLabelMinWidth=link_label_min_width,
- layout=Layout(width=str(width), height=str(height)),
+ layout= Layout(width=str(width), height=str(height)),
margins=margins,
+ node_position_attr=('position' if has_positions else None),
)
+ # Set the scale if explicitly defined by the layout
+ if layout is not None and layout.scale is not None:
+ widget.scale = layout.scale
+
if debugging:
output = Output()
@@ -137,13 +181,13 @@ class SankeyNode(object):
direction = attr.ib(validator=_validate_direction, default="R")
hidden = attr.ib(default=False)
style = attr.ib(default=None, validator=_validate_opt_str)
- from_elsewhere_links = attr.ib(default=list)
- to_elsewhere_links = attr.ib(default=list)
+ from_elsewhere_links = attr.ib(default=attr.Factory(list))
+ to_elsewhere_links = attr.ib(default=attr.Factory(list))
- def to_json(self, format=None):
+ def to_json(self, format=None, layout=None):
"""Convert node to JSON-ready dictionary."""
if format == "widget":
- return {
+ result = {
"id": self.id,
"title": self.title if self.title is not None else self.id,
"direction": self.direction.lower(),
@@ -153,7 +197,7 @@ def to_json(self, format=None):
"toElsewhere": [l.to_json(format) for l in self.to_elsewhere_links]
}
else:
- return {
+ result = {
"id": self.id,
"title": self.title if self.title is not None else self.id,
"style": {
@@ -162,6 +206,12 @@ def to_json(self, format=None):
"type": self.style if self.style is not None else "default",
},
}
+ if layout is not None and layout.node_positions is not None:
+ try:
+ result["position"] = layout.node_positions[self.id]
+ except KeyError:
+ raise KeyError(f"No node position specified for node \"{self.id}\"")
+ return result
def _validate_opacity(instance, attr, value):
@@ -178,7 +228,7 @@ class SankeyLink(object):
type = attr.ib(default=None, validator=_validate_opt_str)
time = attr.ib(default=None, validator=_validate_opt_str)
link_width = attr.ib(default=0.0, converter=float)
- data = attr.ib(default=lambda: {"value": 0.0})
+ data = attr.ib(default=attr.Factory(lambda: {"value": 0.0}))
title = attr.ib(default=None, validator=_validate_opt_str)
color = attr.ib(default=None, validator=_validate_opt_str)
opacity = attr.ib(default=1.0, converter=float, validator=_validate_opacity)
diff --git a/setup.py b/setup.py
index 8b6bc19..c9b219e 100644
--- a/setup.py
+++ b/setup.py
@@ -67,6 +67,7 @@ def find_version(*file_paths):
],
extras_require={
'dev': [],
+ 'mip': ['mip'],
'test': ['pytest', 'matplotlib', 'codecov', 'pytest-cov'],
'docs': ['sphinx', 'nbsphinx', 'jupyter_client', 'ipykernel', 'ipysankeywidget']
},
diff --git a/test/test_hierarchy.py b/test/test_hierarchy.py
index 6e204eb..c8f4ec6 100644
--- a/test/test_hierarchy.py
+++ b/test/test_hierarchy.py
@@ -21,5 +21,5 @@ def test_hierarchy():
assert h('East Anglia') == "location in ['Cambridge', 'Ely', 'Escape \"']"
assert h('*') == None
- with pytest.raises(KeyError):
+ with pytest.raises((KeyError, nx.NetworkXError)):
h('unknown')
diff --git a/test/test_node_position_optimisation.py b/test/test_node_position_optimisation.py
new file mode 100644
index 0000000..7ee32e1
--- /dev/null
+++ b/test/test_node_position_optimisation.py
@@ -0,0 +1,90 @@
+import pytest
+
+from floweaver.diagram_optimisation import optimise_node_positions
+from floweaver.sankey_data import SankeyData, SankeyNode, SankeyLink, SankeyLayout
+
+
+def test_node_positions_straight():
+ data = SankeyData(nodes=[SankeyNode(id='a'), SankeyNode(id='b')],
+ links=[SankeyLink(source='a', target='b', link_width=3)],
+ ordering=[['a'], ['b']])
+ layout = optimise_node_positions(data, margins=dict(left=10, top=20, right=10, bottom=20), scale=1)
+
+ # Width not specified -- assumes a suitable gap between layers
+ assumed_gap = 150
+ assert layout.node_positions == {
+ "a": [10, 20],
+ "b": [10 + 150, 20],
+ }
+
+
+TEST_DATA_SIMPLE_MERGE = SankeyData(
+ nodes=[SankeyNode(id='a1'), SankeyNode(id='a2'), SankeyNode(id='b')],
+ links=[
+ SankeyLink(source='a1', target='b', link_width=3),
+ SankeyLink(source='a2', target='b', link_width=3),
+ ],
+ ordering=[['a1', 'a2'], ['b']]
+)
+
+
+def test_node_positions_no_overlap():
+ # Check y positions do not overlap
+ dy_a1 = 3
+ minimum_gap = 10
+ margins = dict(left=10, top=20, right=10, bottom=20)
+ layout = optimise_node_positions(TEST_DATA_SIMPLE_MERGE, margins=margins, scale=1, minimum_gap=minimum_gap)
+ assert layout.node_positions['a1'][1] == 20
+ assert layout.node_positions['a2'][1] >= 20 + dy_a1 + minimum_gap
+
+
+#@pytest.mark.xfail(reason='need to account for scale when calculating node positions')
+def test_node_positions_no_overlap_with_scale():
+ # Check y positions do not overlap
+ scale = 2
+ dy_a1 = 3 * scale
+ minimum_gap = 10
+ margins = dict(left=10, top=20, right=10, bottom=20)
+ layout = optimise_node_positions(TEST_DATA_SIMPLE_MERGE, margins=margins, scale=scale, minimum_gap=minimum_gap)
+ assert layout.node_positions['a1'][1] == 20
+ assert layout.node_positions['a2'][1] >= 20 + dy_a1 + minimum_gap
+
+
+@pytest.mark.xfail(reason='need to account for offset between node position and link position')
+def test_node_positions_target_in_between_sources():
+ layout = optimise_node_positions(TEST_DATA_SIMPLE_MERGE, scale=1)
+ y = lambda k: layout.node_positions[k][1]
+ assert y('b') > y('a1')
+ assert y('b') + 6 <= y('a2') + 3
+
+
+# This test case has a first "start" node going to the top node in the second
+# layer "a1" in order to offset the lower node "a2" away from the top of the
+# diagram. "a2" is connected to two nodes in the following layer, "b1" and "b2".
+# We want "b2" to be aligned so that its link (the lower of the two leaving
+# "a2") is straight. If the offsets between node positions and link positions
+# are not accounted for properly this will fail.
+TEST_DATA_OFFSETS = SankeyData(
+ nodes=[
+ SankeyNode(id='start'),
+ SankeyNode(id='a1'),
+ SankeyNode(id='a2'),
+ SankeyNode(id='b1'),
+ SankeyNode(id='b2'),
+ ],
+ links=[
+ SankeyLink(source='start', target='a1', link_width=30),
+ SankeyLink(source='a1', target='b1', link_width=3),
+ SankeyLink(source='a2', target='b1', link_width=30),
+ SankeyLink(source='a2', target='b2', link_width=30),
+ ],
+ ordering=[['start'], ['a1', 'a2'], ['b1', 'b2']]
+)
+
+
+#@pytest.mark.xfail(reason='need to account for offset between node position and link position')
+def test_node_positions_aligns_links_straight():
+ layout = optimise_node_positions(TEST_DATA_OFFSETS, scale=1)
+ y = lambda k: layout.node_positions[k][1]
+ dy_link_a2_b2 = 30 # offset of the a2 -> b2 link within a2 (it sits below the width-30 a2 -> b1 link)
+ assert y('b2') == y('a2') + dy_link_a2_b2
diff --git a/test/test_sankey_data.py b/test/test_sankey_data.py
index d542476..4b5d8af 100644
--- a/test/test_sankey_data.py
+++ b/test/test_sankey_data.py
@@ -1,6 +1,6 @@
import pytest
-from floweaver.sankey_data import SankeyData, SankeyNode, SankeyLink
+from floweaver.sankey_data import SankeyData, SankeyNode, SankeyLink, SankeyLayout
def test_sankey_data():
@@ -21,6 +21,23 @@ def test_sankey_data_json():
assert json['links'] == [l.to_json() for l in data.links]
+def test_sankey_data_node_positions():
+ data = SankeyData(nodes=[SankeyNode(id='a')],
+ links=[SankeyLink(source='a', target='a')])
+ json1 = data.to_json()
+
+ layout = SankeyLayout(
+ width=100,
+ height=100,
+ scale=1,
+ node_positions={"a": [3, 4]}
+ )
+ json2 = data.to_json(layout=layout)
+
+ assert "position" not in json1["nodes"][0]
+ assert json2["nodes"][0]["position"] == [3, 4]
+
+
def test_sankey_data_node_json():
assert SankeyNode(id='a').to_json() == {
'id': 'a',