Skip to content

Commit

Permalink
Merge pull request #63 from NREL/fix/speed-trace-from-slts-in-ssts
Browse files Browse the repository at this point in the history
Fix/speed trace from slts in ssts
  • Loading branch information
calbaker authored Sep 12, 2024
2 parents 14700b4 + 953d709 commit 49381bd
Show file tree
Hide file tree
Showing 28 changed files with 10,651 additions and 499 deletions.
60 changes: 21 additions & 39 deletions python/altrios/demos/set_speed_train_sim_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import time
import matplotlib.pyplot as plt
import numpy as np
import polars as pl
import pandas as pd
import seaborn as sns
import os
Expand All @@ -16,7 +17,7 @@
SHOW_PLOTS = alt.utils.show_plots()
PYTEST = os.environ.get("PYTEST", "false").lower() == "true"

SAVE_INTERVAL = 1
SAVE_INTERVAL = 100

# Build the train config
rail_vehicle_loaded = alt.RailVehicle.from_file(
Expand Down Expand Up @@ -77,10 +78,9 @@

# Load the network and link path through the network.
network = alt.Network.from_file(
alt.resources_root() / "networks/Taconite.yaml")
network.set_speed_set_for_train_type(alt.TrainType.Freight)
alt.resources_root() / "networks/Taconite-NoBalloon.yaml")
link_path = alt.LinkPath.from_csv_file(
alt.resources_root() / "demo_data/link_points_idx.csv"
alt.resources_root() / "demo_data/link_path.csv"
)

# load the prescribed speed trace that the train will follow
Expand Down Expand Up @@ -142,7 +142,7 @@

ax[-1].plot(
np.array(train_sim.history.time_seconds) / 3_600,
train_sim.speed_trace.speed_meters_per_second,
np.array(train_sim.speed_trace.speed_meters_per_second)[::SAVE_INTERVAL][1:],
)
ax[-1].set_xlabel('Time [hr]')
ax[-1].set_ylabel('Speed [m/s]')
Expand All @@ -153,37 +153,19 @@
plt.tight_layout()
plt.show()

if PYTEST:
# to access these checks, run `SHOW_PLOTS=f PYTEST=true python set_speed_train_sim_demo.py`
import json
json_path = alt.resources_root() / "test_assets/set_speed_ts_demo.json"
with open(json_path, 'r') as file:
train_sim_reference = json.load(file)

dist_msg = f"`train_sim.state.total_dist_meters`: {train_sim.state.total_dist_meters}\n" + \
f"`train_sim_reference['state']['total_dist']`: {train_sim_reference['state']['total_dist']}"
energy_whl_out_msg = f"`train_sim.state.energy_whl_out_joules`: {train_sim.state.energy_whl_out_joules}\n" + \
f"`train_sim_reference['state']['energy_whl_out']`: {train_sim_reference['state']['energy_whl_out']}"
train_sim_fuel = train_sim.loco_con.get_energy_fuel_joules()
train_sim_reference_fuel = sum(
loco['loco_type']['ConventionalLoco']['fc']['state']['energy_fuel'] if 'ConventionalLoco' in loco['loco_type'] else 0
for loco in train_sim_reference['loco_con']['loco_vec']
)
fuel_msg = f"`train_sim_fuel`: {train_sim_fuel}\n`train_sim_referenc_fuel`: {train_sim_reference_fuel}"
train_sim_net_res = train_sim.loco_con.get_net_energy_res_joules()
train_sim_reference_net_res = sum(
loco['loco_type']['BatteryElectricLoco']['res']['state']['energy_out_chemical'] if 'BatteryElectricLoco' in loco['loco_type'] else 0
for loco in train_sim_reference['loco_con']['loco_vec']
)
net_res_msg = f"`train_sim_net_res`: {train_sim_net_res}\n`train_sim_referenc_net_res`: {train_sim_reference_net_res}"

# check total distance
assert train_sim.state.total_dist_meters == train_sim_reference["state"]["total_dist"], dist_msg

# check total tractive energy
assert train_sim.state.energy_whl_out_joules == train_sim_reference["state"]["energy_whl_out"], energy_whl_out_msg

# check consist-level fuel usage
assert train_sim_fuel == train_sim_reference_fuel, fuel_msg

# check consist-level battery usage
# whether to run assertions, enabled by default
ENABLE_ASSERTS = os.environ.get("ENABLE_ASSERTS", "true").lower() == "true"
# whether to override reference files used in assertions, disabled by default
ENABLE_REF_OVERRIDE = os.environ.get("ENABLE_REF_OVERRIDE", "false").lower() == "true"
# directory for reference files for checking sim results against expected results
ref_dir = alt.resources_root() / "demo_data/set_speed_train_sim_demo/"

if ENABLE_REF_OVERRIDE:
ref_dir.mkdir(exist_ok=True, parents=True)
df:pl.DataFrame = train_sim.to_dataframe().lazy().collect()[-1]
df.write_csv(ref_dir / "to_dataframe_expected.csv")
if ENABLE_ASSERTS:
print("Checking output of `to_dataframe`")
to_dataframe_expected = pl.scan_csv(ref_dir / "to_dataframe_expected.csv").collect()[-1]
assert to_dataframe_expected.equals(train_sim.to_dataframe()[-1])
print("Success!")
29 changes: 23 additions & 6 deletions python/altrios/demos/speed_limit_train_sim_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
import altrios as alt
sns.set_theme()

# Uncomment and run `maturin develop --release --features logging` to enable logging,
# which is needed because logging bogs the CPU and is off by default.
# alt.utils.set_log_level("DEBUG")

SHOW_PLOTS = alt.utils.show_plots()
Expand Down Expand Up @@ -82,7 +80,6 @@

location_map = alt.import_locations(
alt.resources_root() / "networks/default_locations.csv")

train_sim: alt.SpeedLimitTrainSim = tsb.make_speed_limit_train_sim(
location_map=location_map,
save_interval=SAVE_INTERVAL,
Expand All @@ -99,6 +96,16 @@
False,
)[0]

# whether to override files used by set_speed_train_sim_demo.py
OVERRIDE_SSTS_INPUTS = os.environ.get("OVERRIDE_SSTS_INPUTS", "false").lower() == "true"
if OVERRIDE_SSTS_INPUTS:
print("Overriding files used by `set_speed_train_sim_demo.py`")
link_path = alt.LinkPath([x.link_idx for x in timed_link_path.tolist()])
link_path.to_csv_file(alt.resources_root() / "demo_data/link_path.csv")

# uncomment this line to see example of logging functionality
# alt.utils.set_log_level("DEBUG")

t0 = time.perf_counter()
train_sim.walk_timed_path(
network=network,
Expand All @@ -108,6 +115,16 @@
print(f'Time to simulate: {t1 - t0:.5g}')
assert len(train_sim.history) > 1

# Uncomment the following lines to overwrite `set_speed_train_sim_demo.py` `speed_trace`
if OVERRIDE_SSTS_INPUTS:
speed_trace = alt.SpeedTrace(
train_sim.history.time_seconds.tolist(),
train_sim.history.speed_meters_per_second.tolist()
)
speed_trace.to_csv_file(
alt.resources_root() / "demo_data/speed_trace.csv"
)

loco0:alt.Locomotive = train_sim.loco_con.loco_vec.tolist()[0]

fig, ax = plt.subplots(4, 1, sharex=True)
Expand Down Expand Up @@ -241,10 +258,10 @@

if ENABLE_REF_OVERRIDE:
ref_dir.mkdir(exist_ok=True, parents=True)
df:pl.DataFrame = train_sim.to_dataframe().lazy().collect()
df:pl.DataFrame = train_sim.to_dataframe().lazy().collect()[-1]
df.write_csv(ref_dir / "to_dataframe_expected.csv")
if ENABLE_ASSERTS:
print("Checking output of `to_dataframe`")
to_dataframe_expected = pl.scan_csv(ref_dir / "to_dataframe_expected.csv").collect()
assert to_dataframe_expected.equals(train_sim.to_dataframe())
to_dataframe_expected = pl.scan_csv(ref_dir / "to_dataframe_expected.csv").collect()[-1]
assert to_dataframe_expected.equals(train_sim.to_dataframe()[-1])
print("Success!")
14 changes: 10 additions & 4 deletions python/altrios/optimization/cal_and_val.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,12 @@ class ModelError(object):
"""
Dataclass class for calculating model error of various ALTRIOS objects w.r.t. test data.
Fields:
Attributes:
- `ser_model_dict`: `dict` variable in which:
- key: a `str` representing trip keyword string
- value: a `str` converted from Rust locomotive models' serialization method
- key: a `str` representing trip keyword string
- value: a `str` converted from Rust locomotive models' serialization method
- `model_type`: `str` that can only be `'ConsistSimulation'`, `'SetSpeedTrainSim'` or `'LocomotiveSimulation'`;
indicates which model to instantiate during optimization process
Expand All @@ -85,7 +87,11 @@ class ModelError(object):
- `params`: a tuple whose individual element is a `str` containing hierarchical paths to parameters
to manipulate starting from one of the 3 possible Rust model structs
- `verbose`: `bool`; if `True`, the verbose of error calculation will be printed
- `verbose`: `bool`: if `True`, the verbose of error calculation will be printed
- `debug`: `bool`: if `True`, prints more stuff
- `allow_partial`: whether to allow partial runs, if True, errors out whenever a run can't be completed
"""
# `ser_model_dict` and `dfs` should have the same keys
ser_model_dict: Dict[str, str]
Expand Down

Large diffs are not rendered by default.

92 changes: 92 additions & 0 deletions python/altrios/resources/demo_data/link_path.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
link_idx
71
70
69
72
164
163
74
67
608
895
167
170
168
169
172
76
78
77
80
362
180
179
79
371
75
82
183
186
83
372
375
48
84
47
227
229
621
616
909
722
623
595
656
620
622
625
782
785
598
891
780
778
779
781
776
777
927
936
913
916
540
1013
1007
1017
1015
1010
1012
1011
1008
1009
1016
553
1018
1014
986
983
1034
560
1053
508
502
505
504
507
510
513
521
558
559
556
557
85 changes: 0 additions & 85 deletions python/altrios/resources/demo_data/link_points_idx.csv

This file was deleted.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Loading

0 comments on commit 49381bd

Please sign in to comment.