Skip to content

Commit

Permalink
fixed all the failing tests and added environment variables to make f…
Browse files Browse the repository at this point in the history
…uture updates easier
  • Loading branch information
calbaker committed Sep 12, 2024
1 parent 6a4cf54 commit 953d709
Show file tree
Hide file tree
Showing 7 changed files with 472 additions and 727 deletions.
101 changes: 22 additions & 79 deletions python/altrios/demos/set_speed_train_sim_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import time
import matplotlib.pyplot as plt
import numpy as np
import polars as pl
import pandas as pd
import seaborn as sns
import os
Expand All @@ -16,7 +17,7 @@
SHOW_PLOTS = alt.utils.show_plots()
PYTEST = os.environ.get("PYTEST", "false").lower() == "true"

SAVE_INTERVAL = 1
SAVE_INTERVAL = 100

# Build the train config
rail_vehicle_loaded = alt.RailVehicle.from_file(
Expand Down Expand Up @@ -78,13 +79,10 @@
# Load the network and link path through the network.
network = alt.Network.from_file(
alt.resources_root() / "networks/Taconite-NoBalloon.yaml")
network.set_speed_set_for_train_type(alt.TrainType.Freight)
# file created from ./speed_limit_train_sim_demo.py:L92
link_path = alt.LinkPath.from_csv_file(
alt.resources_root() / "demo_data/link_path.csv"
)

# file created from ./speed_limit_train_sim_demo.py:`link_path.to_csv_file`
# load the prescribed speed trace that the train will follow
speed_trace = alt.SpeedTrace.from_csv_file(
alt.resources_root() / "demo_data/speed_trace.csv"
Expand All @@ -105,9 +103,7 @@
t1 = time.perf_counter()
print(f'Time to simulate: {t1 - t0:.5g}')

loco0: alt.Locomotive = train_sim.loco_con.loco_vec.tolist()[0]

fig, ax = plt.subplots(4, 1, sharex=True)
fig, ax = plt.subplots(3, 1, sharex=True)
ax[0].plot(
np.array(train_sim.history.time_seconds) / 3_600,
np.array(train_sim.history.pwr_whl_out_watts) / 1e6,
Expand Down Expand Up @@ -144,85 +140,32 @@
ax[1].set_ylabel('Force [MN]')
ax[1].legend()

ax[2].plot(
np.array(train_sim.history.time_seconds) / 3_600,
np.array(loco0.res.history.soc)
)
ax[2].set_ylabel('SOC')

ax[-1].plot(
np.array(train_sim.history.time_seconds) / 3_600,
train_sim.history.speed_meters_per_second,
label='achieved'
)
ax[-1].plot(
np.array(train_sim.history.time_seconds) / 3_600,
train_sim.history.speed_limit_meters_per_second,
label='limit'
np.array(train_sim.speed_trace.speed_meters_per_second)[::SAVE_INTERVAL][1:],
)
ax[-1].set_xlabel('Time [hr]')
ax[-1].set_ylabel('Speed [m/s]')
ax[-1].legend()
plt.suptitle("Set Speed Train Sim Demo")

fig1, ax1 = plt.subplots(3, 1, sharex=True)
ax1[0].plot(
np.array(train_sim.history.time_seconds) / 3_600,
np.array(train_sim.history.offset_in_link_meters) / 1_000,
label='current link',
)
ax1[0].plot(
np.array(train_sim.history.time_seconds) / 3_600,
np.array(train_sim.history.offset_meters) / 1_000,
label='overall',
)
ax1[0].legend()
ax1[0].set_ylabel('Net Dist. [km]')

ax1[1].plot(
np.array(train_sim.history.time_seconds) / 3_600,
train_sim.history.link_idx_front,
linestyle='',
marker='.',
)
ax1[1].set_ylabel('Link Idx Front')

ax1[-1].plot(
np.array(train_sim.history.time_seconds) / 3_600,
train_sim.history.speed_meters_per_second,
)
ax1[-1].set_xlabel('Time [hr]')
ax1[-1].set_ylabel('Speed [m/s]')

plt.suptitle("Set Speed Train Sim Demo")
plt.tight_layout()



fig2, ax2 = plt.subplots(3, 1, sharex=True)
ax2[0].plot(
np.array(train_sim.history.time_seconds) / 3_600,
np.array(train_sim.history.pwr_whl_out_watts) / 1e6,
label="tract pwr",
)
ax2[0].set_ylabel('Power [MW]')
ax2[0].legend()

ax2[1].plot(
np.array(train_sim.history.time_seconds) / 3_600,
np.array(train_sim.history.grade_front) * 100.,
)
ax2[1].set_ylabel('Grade [%] at\nHead End')

ax2[-1].plot(
np.array(train_sim.history.time_seconds) / 3_600,
train_sim.history.speed_meters_per_second,
)
ax2[-1].set_xlabel('Time [hr]')
ax2[-1].set_ylabel('Speed [m/s]')

plt.suptitle("Set Speed Train Sim Demo")
plt.tight_layout()

if SHOW_PLOTS:
plt.tight_layout()
plt.show()

# whether to run assertions, enabled by default
ENABLE_ASSERTS = os.environ.get("ENABLE_ASSERTS", "true").lower() == "true"
# whether to override reference files used in assertions, disabled by default
ENABLE_REF_OVERRIDE = os.environ.get("ENABLE_REF_OVERRIDE", "false").lower() == "true"
# directory for reference files for checking sim results against expected results
ref_dir = alt.resources_root() / "demo_data/set_speed_train_sim_demo/"

if ENABLE_REF_OVERRIDE:
ref_dir.mkdir(exist_ok=True, parents=True)
df:pl.DataFrame = train_sim.to_dataframe().lazy().collect()[-1]
df.write_csv(ref_dir / "to_dataframe_expected.csv")
if ENABLE_ASSERTS:
print("Checking output of `to_dataframe`")
to_dataframe_expected = pl.scan_csv(ref_dir / "to_dataframe_expected.csv").collect()[-1]
assert to_dataframe_expected.equals(train_sim.to_dataframe()[-1])
print("Success!")
36 changes: 19 additions & 17 deletions python/altrios/demos/speed_limit_train_sim_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,11 @@
import altrios as alt
sns.set_theme()

# Uncomment and run `maturin develop --release --features logging` to enable logging,
# which is needed because logging bogs the CPU and is off by default.
# alt.utils.set_log_level("DEBUG")

SHOW_PLOTS = alt.utils.show_plots()

SAVE_INTERVAL = 1
SAVE_INTERVAL = 100

# Build the train config
rail_vehicle_loaded = alt.RailVehicle.from_file(
Expand Down Expand Up @@ -82,7 +80,6 @@

location_map = alt.import_locations(
alt.resources_root() / "networks/default_locations.csv")

train_sim: alt.SpeedLimitTrainSim = tsb.make_speed_limit_train_sim(
location_map=location_map,
save_interval=SAVE_INTERVAL,
Expand All @@ -99,9 +96,12 @@
False,
)[0]

# Uncomment the following lines to overwrite `set_speed_train_sim_demo.py` `link_path`
# link_path = alt.LinkPath([x.link_idx for x in timed_link_path.tolist()])
# link_path.to_csv_file(alt.resources_root() / "demo_data/link_path.csv")
# whether to override files used by set_speed_train_sim_demo.py
OVERRIDE_SSTS_INPUTS = os.environ.get("OVERRIDE_SSTS_INPUTS", "false").lower() == "true"
if OVERRIDE_SSTS_INPUTS:
print("Overriding files used by `set_speed_train_sim_demo.py`")
link_path = alt.LinkPath([x.link_idx for x in timed_link_path.tolist()])
link_path.to_csv_file(alt.resources_root() / "demo_data/link_path.csv")

# uncomment this line to see example of logging functionality
# alt.utils.set_log_level("DEBUG")
Expand All @@ -116,13 +116,14 @@
assert len(train_sim.history) > 1

# Uncomment the following lines to overwrite `set_speed_train_sim_demo.py` `speed_trace`
# speed_trace = alt.SpeedTrace(
# train_sim.history.time_seconds.tolist(),
# train_sim.history.speed_meters_per_second.tolist()
# )
# speed_trace.to_csv_file(
# alt.resources_root() / "demo_data/speed_trace.csv"
# )
if OVERRIDE_SSTS_INPUTS:
speed_trace = alt.SpeedTrace(
train_sim.history.time_seconds.tolist(),
train_sim.history.speed_meters_per_second.tolist()
)
speed_trace.to_csv_file(
alt.resources_root() / "demo_data/speed_trace.csv"
)

loco0:alt.Locomotive = train_sim.loco_con.loco_vec.tolist()[0]

Expand Down Expand Up @@ -244,6 +245,7 @@


if SHOW_PLOTS:
plt.tight_layout()
plt.show()
# Impact of sweep of battery capacity TODO: make this happen

Expand All @@ -256,10 +258,10 @@

if ENABLE_REF_OVERRIDE:
ref_dir.mkdir(exist_ok=True, parents=True)
df:pl.DataFrame = train_sim.to_dataframe().lazy().collect()
df:pl.DataFrame = train_sim.to_dataframe().lazy().collect()[-1]
df.write_csv(ref_dir / "to_dataframe_expected.csv")
if ENABLE_ASSERTS:
print("Checking output of `to_dataframe`")
to_dataframe_expected = pl.scan_csv(ref_dir / "to_dataframe_expected.csv").collect()
assert to_dataframe_expected.equals(train_sim.to_dataframe())
to_dataframe_expected = pl.scan_csv(ref_dir / "to_dataframe_expected.csv").collect()[-1]
assert to_dataframe_expected.equals(train_sim.to_dataframe()[-1])
print("Success!")

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Loading

0 comments on commit 953d709

Please sign in to comment.