Skip to content

Commit

Permalink
Updated data structure and solved a couple of bugs in train planner
Browse files Browse the repository at this point in the history
  • Loading branch information
FDsteven committed Sep 3, 2024
1 parent 4173773 commit 0b583b6
Showing 1 changed file with 60 additions and 44 deletions.
104 changes: 60 additions & 44 deletions python/altrios/train_planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,14 +293,18 @@ def generate_demand_trains(
demand_returns.drop("Number_of_Containers"),
demand_rebalancing],
how="diagonal_relaxed")

# if rowx[first three columns] == rowy[first three columns]:
# rowx[fourth column] + rowy[fourth column]
# delete rowy
# combined_row = demand.slice(10,12).select
# demand = demand.group_by()
#Prepare hp_per_ton requirements to merge onto the demand DataFrame
hp_per_ton = (
pl.DataFrame(pd.DataFrame(config.hp_required_per_ton).reset_index(names="Train_Type"))
.melt(id_vars="Train_Type",variable_name="O_D",value_name="HP_Required_Per_Ton")
.with_columns(pl.col("O_D").str.split("_").list.first().alias("Origin"),
pl.col("O_D").str.split("_").list.last().alias("Destination"))
).lazy()
)

#MPrepare ton_per_car requirements to merge onto the demand DataFrame
# TODO: simplify mass API here. Is there a function on the Rust side to get total mass (or should there be)?
Expand All @@ -320,43 +324,50 @@ def get_kg(veh):
.otherwise(pl.col("KG") / utilities.KG_PER_TON)
.alias("Tons_Per_Car"))
.drop(["KG_Empty","KG_Loaded"])
).lazy()
)

demand = (demand.lazy()
.join(ton_per_car, on="Train_Type", how="left")
demand = demand.join(ton_per_car, on="Train_Type", how="left")
# Merge on OD-specific hp_per_ton if the user specified any
.join(hp_per_ton.drop("O_D"),
on=["Origin","Destination","Train_Type"],
how="left")
demand = demand.join(hp_per_ton.drop("O_D"),
on=["Origin","Destination","Train_Type"],
how="left")
# Second, merge on defaults per train type
.join(hp_per_ton.filter((pl.col("O_D") =="Default")).drop(["O_D","Origin","Destination"]),
on=["Train_Type"],
how="left",
suffix="_Default")
demand = demand.join(hp_per_ton.filter((pl.col("O_D") =="Default")).drop(["O_D","Origin","Destination"]),
on=["Train_Type"],
how="left",
suffix="_Default")
# Fill in defaults per train type wherever the user didn't specify OD-specific hp_per_ton
.with_columns(pl.coalesce("HP_Required_Per_Ton", "HP_Required_Per_Ton_Default").alias("HP_Required_Per_Ton"))
.drop("HP_Required_Per_Ton_Default")
demand = demand.with_columns(pl.coalesce("HP_Required_Per_Ton", "HP_Required_Per_Ton_Default").alias("HP_Required_Per_Ton"))
demand = demand.drop("HP_Required_Per_Ton_Default")
# Replace nulls with zero
.with_columns(cs.float().fill_null(0.0), cs.by_dtype(pl.UInt32).fill_null(pl.lit(0).cast(pl.UInt32)))
demand = demand.with_columns(cs.float().fill_null(0.0), cs.by_dtype(pl.UInt32).fill_null(pl.lit(0).cast(pl.UInt32)))
# Convert total number of cars to total number of trains
.with_columns(
(pl.col("Number_of_Cars") * pl.col("Tons_Per_Car")).alias("Tons_Aggregate"),
pl.when(config.single_train_mode)
.then(1)
.when(pl.col("Number_of_Cars") == 0)
.then(0)
.otherwise(
pl.max_horizontal([1,
(pl.when(pl.col("Number_of_Cars").mod(pl.lit(config.target_cars_per_train)).gt(pl.lit(config.min_cars_per_train)))
.then(pl.col("Number_of_Cars").floordiv(pl.lit(config.target_cars_per_train)))
.otherwise(pl.col("Number_of_Cars").floordiv(pl.lit(config.target_cars_per_train))+1))
])
)
.cast(pl.UInt32).alias("Number_of_Trains"))
demand = demand.with_columns(
(pl.col("Number_of_Cars") * pl.col("Tons_Per_Car")).alias("Tons_Aggregate"),
pl.when(config.single_train_mode)
.then(1)
.when(pl.col("Number_of_Cars") == 0)
.then(0)
.otherwise(
pl.max_horizontal([1,
((pl.col("Number_of_Cars").floordiv(pl.lit(config.target_cars_per_train)) + 1))
])
).cast(pl.UInt32).alias("Number_of_Trains"))
# Calculate per-train car counts and tonnage
.with_columns(
pl.col("Tons_Aggregate").truediv(pl.col("Number_of_Trains")).alias("Tons_Per_Train"))
).collect()
demand = demand.with_columns(
pl.col("Tons_Aggregate").truediv(pl.col("Number_of_Trains")).alias("Tons_Per_Train"))
demand = demand.with_columns(
(pl.when(pl.col("Train_Type").str.ends_with("_Empty"))
.then(pl.col("Number_of_Cars"))
.otherwise(0))
.cast(pl.UInt32)
.alias("Cars_Empty"),
(pl.when(pl.col("Train_Type").str.ends_with("_Empty"))
.then(0)
.otherwise(pl.col("Number_of_Cars")))
.cast(pl.UInt32)
.alias("Cars_Loaded")
)
return demand


Expand All @@ -377,22 +388,25 @@ def calculate_dispatch_times(
dispatch_times: Tabulated dispatching time for each demand pair for each train type
in hours
"""
return (demand
.filter(pl.col("Number_of_Trains") > 0)
demand = demand \
.filter(pl.col("Number_of_Trains") > 0) \
.select(["Origin","Destination","Train_Type","Number_of_Trains",
"Number_of_Cars",
"Tons_Per_Train","HP_Required_Per_Ton"])
"Tons_Per_Train","HP_Required_Per_Ton", "Cars_Loaded", "Cars_Empty"]) \
.with_columns(
(hours / pl.col("Number_of_Trains")).alias("Interval"),
pl.col("Number_of_Trains").cast(pl.Int32).alias("Number_of_Trains"))
.select(pl.exclude("Number_of_Trains").repeat_by("Number_of_Trains").explode())
pl.col("Number_of_Trains").cast(pl.Int32).alias("Number_of_Trains"),
pl.col("Number_of_Cars").floordiv(pl.col("Number_of_Trains")).alias("Number_of_Cars"),
pl.col("Cars_Empty").floordiv(pl.col("Number_of_Trains")).alias("Cars_Empty"),
pl.col("Cars_Loaded").floordiv(pl.col("Number_of_Trains")).alias("Cars_Loaded"),
).select(pl.exclude("Number_of_Trains").repeat_by("Number_of_Trains").explode()) \
.with_columns(
((pl.col("Interval").cumcount().over(["Origin","Destination","Train_Type"]) + 1) \
* pl.col("Interval")).alias("Hour"))
.drop("Interval")
((pl.col("Interval").cumcount().over(["Origin","Destination","Train_Type"])) \
* pl.col("Interval")).alias("Hour")
).drop("Interval") \
.sort(["Hour","Origin","Destination","Train_Type"])
)


return demand
def build_locopool(
config: TrainPlannerConfig,
demand_file: Union[pl.DataFrame, Path, str],
Expand Down Expand Up @@ -1069,7 +1083,9 @@ def run_train_planner(
loco_pool.filter(selected).get_column('Locomotive_Type'),
pl.Series(repeat(this_train['Origin'], new_row_count)),
pl.Series(repeat(this_train['Destination'], new_row_count)),
pl.Series(repeat(this_train['Number_of_Cars'], new_row_count)),
pl.Series(repeat(this_train['Cars_Loaded'], new_row_count)),
pl.Series(repeat(this_train['Cars_Empty'], new_row_count)),
# pl.Series(repeat(this_train['Number_of_Cars'], new_row_count)),
loco_start_soc_j,
pl.Series(repeat(current_time, new_row_count)),
pl.Series(repeat(current_time + travel_time, new_row_count))],
Expand Down Expand Up @@ -1121,7 +1137,7 @@ def run_train_planner(
str(alt.resources_root() / "networks/default_locations.csv")
)
network = alt.Network.from_file(
str(alt.resources_root() / "networks/Taconite.yaml")
str(alt.resources_root() / "networks/Taconite-NoBalloon.yaml")
)
config = TrainPlannerConfig()
loco_pool = build_locopool(config, defaults.DEMAND_FILE)
Expand Down

0 comments on commit 0b583b6

Please sign in to comment.