-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add scripts for beluga_tutorial visualizations (#387)
### Proposed changes Related to #305, split from #340. #### Type of change - [ ] 🐛 Bugfix (change which fixes an issue) - [x] 🚀 Feature (change which adds functionality) - [ ] 📚 Documentation (change which fixes or extends documentation) ### Checklist - [x] Lint and unit tests (if any) pass locally with my changes - [ ] I have added tests that prove my fix is effective or that my feature works - [ ] I have added necessary documentation (if appropriate) - [x] All commits have been signed for [DCO](https://developercertificate.org/) --------- Signed-off-by: Alon Druck <[email protected]>
- Loading branch information
Showing
2 changed files
with
172 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright 2024 Ekumen, Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Visualization Scripts for Beluga Tutorial. | ||
This script uses the matplotlib to visualize the data generated | ||
by beluga_tutorial example code. | ||
""" | ||
|
||
import argparse | ||
import matplotlib.pyplot as plt | ||
import matplotlib.animation as animation | ||
import numpy as np | ||
import yaml | ||
|
||
|
||
def plot_stages(yaml_data, axs, index): | ||
particles = yaml_data['simulation_records'][index]["particles"] | ||
for j, stage in enumerate(['current', 'prediction', 'update']): | ||
states = particles[stage]['states'] | ||
ax = axs[j] | ||
ax.clear() | ||
ax.hist( | ||
states, | ||
bins=np.arange(min(states), max(states) + 0.5, 0.5), | ||
color='skyblue', | ||
edgecolor='black', | ||
) | ||
ax.set_title(f"{stage} - Sim Cycle {index+1}") | ||
ax.set_xlabel("States") | ||
ax.set_ylabel("NP") | ||
ax.set_xlim(-5, 100) | ||
ax.set_xticks(np.arange(-5, 101, 5)) | ||
|
||
landmark_map = yaml_data["landmark_map"] | ||
ax_landmark = axs[3] | ||
ax_landmark.clear() | ||
|
||
for x in range(101): | ||
if x in landmark_map: | ||
color = 'red' | ||
else: | ||
color = 'blue' | ||
ax_landmark.bar(x, 1, color=color) | ||
|
||
ax_landmark.set_title("Landmark Map") | ||
ax_landmark.set_xlabel("States") | ||
ax_landmark.set_xlim(-5, 100) | ||
ax_landmark.set_ylim(0, 2) | ||
ax_landmark.set_xticks(np.arange(-5, 101, 5)) | ||
|
||
ground_truth = yaml_data['simulation_records'][index]["ground_truth"] | ||
ax_ground_truth = axs[4] | ||
ax_ground_truth.clear() | ||
ax_ground_truth.bar(ground_truth, 1, color='green') | ||
ax_ground_truth.set_title(f"Ground Truth: {ground_truth}") | ||
ax_ground_truth.set_xlabel("States") | ||
ax_ground_truth.set_xlim(-5, 100) | ||
ax_ground_truth.set_ylim(0, 2) | ||
ax_ground_truth.set_xticks(np.arange(-5, 101, 5)) | ||
|
||
mean = "{:.3f}".format(yaml_data['simulation_records'][index]["estimation"]["mean"]) | ||
sd = "{:.3f}".format(yaml_data['simulation_records'][index]["estimation"]["sd"]) | ||
plt.text( | ||
0.5, | ||
0.5, | ||
f"Mean: {mean}\nSD: {sd}", | ||
ha='center', | ||
va='center', | ||
transform=axs[4].transAxes, | ||
bbox=dict(facecolor='white', alpha=0.5), | ||
) | ||
|
||
plt.tight_layout() | ||
plt.draw() | ||
|
||
|
||
def main(argv=None) -> int: | ||
"""Run the entry point of the program.""" | ||
parser = argparse.ArgumentParser(description=globals()['__doc__']) | ||
|
||
parser.add_argument( | ||
'record_file_path', | ||
type=str, | ||
help='Absolute path to the record file generated by the beluga_tutorial example code', | ||
) | ||
|
||
parser.add_argument( | ||
'-m', | ||
'--manual-control', | ||
action='store_true', | ||
help='Manual controlling the time steps of the visualization', | ||
) | ||
|
||
parser.add_argument( | ||
'-i', | ||
'--interval-ms', | ||
type=int, | ||
help='Delay between frames in milliseconds.', | ||
required=False, | ||
default=250, | ||
) | ||
|
||
parser.add_argument( | ||
'-r', | ||
'--repeat-animation', | ||
action='store_true', | ||
help='Repeat the animation when it is finished', | ||
) | ||
|
||
args = parser.parse_args(argv) | ||
|
||
with open(args.record_file_path, 'r') as file: | ||
yaml_data = yaml.safe_load(file) | ||
|
||
fig, axs = plt.subplots(5, 1) | ||
num_frames = len(yaml_data['simulation_records']) | ||
|
||
def plot_stages_update(current_frame: int) -> None: | ||
"""Update plots (Helper function).""" | ||
plot_stages(yaml_data, axs, current_frame) | ||
|
||
if args.manual_control: | ||
current_frame = 0 | ||
plot_stages_update(current_frame) | ||
|
||
def on_key(event): | ||
"""Handle key events (Callback function).""" | ||
nonlocal current_frame | ||
|
||
if event.key == 'right': | ||
current_frame += 1 | ||
elif event.key == 'left': | ||
current_frame -= 1 | ||
|
||
current_frame %= num_frames | ||
plot_stages_update(current_frame) | ||
|
||
fig.canvas.mpl_connect('key_press_event', on_key) | ||
|
||
else: | ||
anim = animation.FuncAnimation( # noqa: F841 | ||
fig, | ||
plot_stages_update, | ||
frames=num_frames, | ||
blit=False, | ||
interval=args.interval_ms, | ||
repeat=args.repeat_animation, | ||
) | ||
|
||
plt.show() | ||
|
||
|
||
if __name__ == '__main__': | ||
exit(main()) |