diff --git a/examples-gallery/betts_10_50_solution.npy b/examples-gallery/betts_10_50_solution.npy new file mode 100644 index 00000000..7c3082e1 Binary files /dev/null and b/examples-gallery/betts_10_50_solution.npy differ diff --git a/opty/direct_collocation.py b/opty/direct_collocation.py index 67175244..6b04417f 100644 --- a/opty/direct_collocation.py +++ b/opty/direct_collocation.py @@ -436,7 +436,7 @@ def plot_trajectories(self, vector, axes=None): return axes @_optional_plt_dep - def plot_constraint_violations(self, vector, axes=None): + def plot_constraint_violations(self, vector, axes=None, detailed_eoms=False): """Returns an axis with the state constraint violations plotted versus node number and the instance constraints as a bar graph. @@ -446,6 +446,12 @@ def plot_constraint_violations(self, vector, axes=None): The initial guess, solution, or any other vector that is in the canonical form. + detailed_eoms : boolean, optional. If True, the equations of motion + will be plotted in a separate plot for each state. Default is False. + + axes : ndarray of AxesSubplot, optional. If given, it is the user's + responsibility to provide the correct number of axes. + Returns ======= axes : ndarray of AxesSubplot @@ -523,16 +529,43 @@ def plot_constraint_violations(self, vector, axes=None): con_nodes = range(1, self.collocator.num_collocation_nodes) if axes is None: - fig, axes = plt.subplots(1 + num_plots, 1, - figsize=(6.4, 1.50*(1 + num_plots)), - layout='compressed') + if detailed_eoms == False or self.collocator.num_states == 1: + num_eom_plots = 1 + else: + num_eom_plots = self.collocator.num_states + fig, axes = plt.subplots(num_eom_plots + num_plots, 1, + figsize=(6.4, 1.75*(num_eom_plots + num_plots)), + constrained_layout=True) + + else: + num_eom_plots = len(axes) - num_plots axes = np.asarray(axes).ravel() - axes[0].plot(con_nodes, state_violations.T) - axes[0].set_title('Constraint violations') - axes[0].set_xlabel('Node Number') - axes[0].set_ylabel('EoM violation') + if detailed_eoms == False or self.collocator.num_states == 1: + axes[0].plot(con_nodes, state_violations.T) + axes[0].set_title('Constraint violations') + axes[0].set_xlabel('Node Number') + axes[0].set_ylabel('EoM violation') + + else: + for i in range(self.collocator.num_states): + k = i + 1 + if k in (11,12,13): + msg = 'th' + elif k % 10 == 1: + msg = 'st' + elif k % 10 == 2: + msg = 'nd' + elif k % 10 == 3: + msg = 'rd' + else: + msg = 'th' + + axes[i].plot(con_nodes, state_violations[i]) + axes[i].set_ylabel(f'{str(k)}-{msg} EOM violation') + axes[num_eom_plots-1].set_xlabel('Node Number') + axes[0].set_title('Constraint violations') if self.collocator.instance_constraints is not None: # reduce the instance constrtaints to 2 digits after the decimal @@ -575,11 +608,11 @@ def plot_constraint_violations(self, vector, axes=None): inst_constr = instance_constr_plot[beginn: endd] width = [0.06*num_ticks for _ in range(num_ticks)] - axes[i+1].bar(range(num_ticks), inst_viol, + axes[i+num_eom_plots].bar(range(num_ticks), inst_viol, tick_label=[sm.latex(s, mode='inline') for s in inst_constr], width=width) - axes[i+1].set_ylabel('Instance') - axes[i+1].set_xticklabels(axes[i+1].get_xticklabels(), + axes[i+num_eom_plots].set_ylabel('Instance') + axes[i+num_eom_plots].set_xticklabels(axes[i+num_eom_plots].get_xticklabels(), rotation=rotation) return axes