Skip to content

Commit

Permalink
Improve constraint violation visualization
Browse files Browse the repository at this point in the history
  • Loading branch information
tjstienstra committed Feb 14, 2024
1 parent 6a23d17 commit 912b7ed
Showing 1 changed file with 23 additions and 20 deletions.
43 changes: 23 additions & 20 deletions opty/direct_collocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,27 +367,30 @@ def plot_constraint_violations(self, vector):
- r : number of unknown parameters
"""

N = self.collocator.num_collocation_nodes
con_violations = self.con(vector)
con_nodes = range(self.collocator.num_states,
self.collocator.num_collocation_nodes + 1)
N = len(con_nodes)
fig, axes = plt.subplots(self.collocator.num_states + 1)

for i, (ax, symbol) in enumerate(zip(axes[:-1],
self.collocator.state_symbols)):
ax.plot(con_nodes, con_violations[i * N:i * N + N])
ax.set_ylabel(sm.latex(symbol, mode='inline'))

axes[0].set_title('Constraint Violations')
axes[-2].set_xlabel('Node Number')

left = range(len(con_violations[self.collocator.num_states * N:]))
axes[-1].bar(left, con_violations[self.collocator.num_states * N:],
tick_label=[sm.latex(s, mode='inline')
for s in self.collocator.instance_constraints])
axes[-1].set_ylabel('Instance')
axes[-1].set_xticklabels(axes[-1].get_xticklabels(), rotation=-10)
state_violations = con_violations[
:(N - 1) * len(self.collocator.state_symbols)]
instance_violations = con_violations[len(state_violations):]
state_violations = state_violations.reshape(
(len(self.collocator.state_symbols), N - 1))
con_nodes = range(1, self.collocator.num_collocation_nodes)

plot_inst_viols = self.collocator.instance_constraints is not None
fig, axes = plt.subplots(1 + plot_inst_viols)

axes[0].plot(con_nodes, state_violations.T)
axes[0].set_title('Constraint violations')
axes[0].set_xlabel('Node Number')
axes[0].set_ylabel('EoM violation')

if plot_inst_viols:
axes[-1].bar(
range(len(instance_violations)), instance_violations,
tick_label=[sm.latex(s, mode='inline')
for s in self.collocator.instance_constraints])
axes[-1].set_ylabel('Instance')
axes[-1].set_xticklabels(axes[-1].get_xticklabels(), rotation=-10)

return axes

Expand Down

0 comments on commit 912b7ed

Please sign in to comment.