-
Notifications
You must be signed in to change notification settings - Fork 60
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add the automatic differentiation multibody solver based on JAX #305
base: develop
Are you sure you want to change the base?
Conversation
Also cleaned up a lot of unneeded comments in some files, and prevents one of the tests from leaving undeleted .vtu files.
to the angle of attack (in radians) and then the ``C_L``, ``C_D`` and ``C_M``. | ||
to the angle of attack (in radians) and then the ``C_L``, ``C_D`` and ``C_M``. | ||
|
||
Multibody file |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
New documentation - should be general for both multibody implementations.
@@ -145,7 +145,8 @@ def run(self): | |||
"openpyxl>=3.0.10", | |||
"lxml>=4.4.1", | |||
"PySocks", | |||
"PyYAML" | |||
"PyYAML", | |||
"jax", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added JAX as a new dependency during install
@@ -270,15 +265,28 @@ def generate_zeta_timestep_info(self, structure_tstep, aero_tstep, beam, setting | |||
raise NotImplementedError(str(self.data_dict['control_surface_type'][i_control_surface]) + | |||
' control surfaces are not yet implemented') | |||
|
|||
|
|||
# add sweep for aerogrid warping in constraint defintition |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
New code for dynamically sweeping the aero grid. This shouldn't impact any existing cases as it will ignore it if the warp factor parameter does not exist.
@@ -62,7 +61,7 @@ def get_coefs(self, aoa_deg): | |||
cd = self.cd_interp(aoa_deg) | |||
cm = self.cm_interp(aoa_deg) | |||
|
|||
return cl, cd, cm | |||
return cl[0], cd[0], cm[0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed one of the pesky errors converting an array to a scalar during the unit test - this seems to work fine.
|
||
|
||
@controller_interface.controller | ||
class MultibodyController(controller_interface.BaseController): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
New multibody controller for setting angle between beams.
@@ -0,0 +1,368 @@ | |||
import numpy as np |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
New solver, very similar to the existing multibody solver in its inputs.
@@ -323,11 +312,10 @@ def run(self, **kwargs): | |||
""" | |||
|
|||
aero_tstep = settings_utils.set_value_or_default(kwargs, 'aero_step', self.data.aero.timestep_info[-1]) | |||
structure_tstep = settings_utils.set_value_or_default(kwargs, 'structural_step', self.data.structure.timestep_info[-1]) | |||
convect_wake = settings_utils.set_value_or_default(kwargs, 'convect_wake', False) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I cut some parameters here as they aren't actually referenced again
@@ -71,6 +71,8 @@ class NewmarkBeta(_BaseTimeIntegrator): | |||
|
|||
def __init__(self): | |||
|
|||
self.sys_size = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some variables used later were not declared in init
@@ -0,0 +1,209 @@ | |||
import numpy as np |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
New time integrator for the JAX solver, should have the same numerics as the existing implementation.
@@ -544,7 +476,8 @@ def get_body(self, ibody): | |||
int_list_nodes = np.arange(0, ibody_beam.num_node, 1) | |||
for ielem in range(ibody_beam.num_elem): | |||
for inode_in_elem in range(ibody_beam.num_node_elem): | |||
ibody_beam.connectivities[ielem, inode_in_elem] = int_list_nodes[ibody_nodes == ibody_beam.connectivities[ielem, inode_in_elem]] | |||
ibody_beam.connectivities[ielem, inode_in_elem] = int_list_nodes[ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Another fix for converting arrays to scalars which creates a warning during unit test.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good spot and fix!
@@ -0,0 +1,498 @@ | |||
from typing import Callable, Any, Optional, Type, cast |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is where the multibody constraint numerics take place
@@ -1989,13 +1991,15 @@ def check(self): | |||
raise RuntimeError(("'behaviour' parameter is required in '%s' lagrange constraint" % self.behaviour)) | |||
|
|||
|
|||
def generate_multibody_file(list_LagrangeConstraints, list_Bodies, route, case_name): | |||
def generate_multibody_file(list_LagrangeConstraints, list_Bodies, route, case_name, use_jax=False): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I need to add this additional setting as the generate case class checks for constraints in the existing solver; as I have new constraints which aren't implemented there, it fails when this is False.
@@ -0,0 +1,325 @@ | |||
""" | |||
Multibody library for the NonlinearDynamicMultibodyJAX solver |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Copy of the existing multibody utils library, with some things cut.
@@ -322,7 +301,7 @@ def test_doublependulum_hinge_slanted_lateralrot(self): | |||
def tearDown(self): | |||
solver_path = os.path.abspath(os.path.dirname(os.path.realpath(__file__))) | |||
solver_path += '/' | |||
for name in [name_hinge_slanted, name_hinge_slanted_pen, name_hinge_slanted_lateralrot]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed an issue with the double angled pendulum test case where the names weren't correct
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ahh sorry! my bad😂😂
@@ -0,0 +1,368 @@ | |||
import numpy as np | |||
import typing |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to add typing as a dependency?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Typing is the built-in Python type hinting, so this doesn't require any new packages
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ahhhhhh ok! thanks
sharpy/solvers/statictrim.py
Outdated
@@ -317,17 +317,12 @@ def trim_algorithm(self): | |||
|
|||
def evaluate(self, alpha, deflection_gamma, thrust): | |||
if not np.isfinite(alpha): | |||
import pdb; pdb.set_trace() | |||
raise ValueError |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here I think the reason behind breaking and not terminating the run is because staticcoupled has no post-processing - so rather than killing it (and leaving no trace to what happened) they preferred to leave a break point here. If we haven't got plans to introduce post-processing for staticcoupled iterations, is there any neater alternative to just throwing an error here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was having the trace cause me issues when running on the HPC a while ago, I'm not sure if it's good practice to have traces like this in “production” code (although I may be wrong). This error is thrown when one of the trim gradients is zero, which is currently occurring for some of my cases and I believe is a bug, that I'm currently looking into. If static coupled fails, the code won't get this far, as the FORTRAN will instead throw a singular matrix error.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
that makes sense - I meant more in the off chance that staticcoupled did return eg. same total forces for a parametrised run with different geometry (perhaps the staticcoupled get total forces function working on the timestep[-1] time information, and somehow got contaminated? just a wild guess) - that would kill the trim routine like what you've seen. but yeah I agree it is probably best to clean up production code for an HPC environment assuming no user input is possible further
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you add a message to the ValueError
so that it is possible to trace where it is coming from?
@@ -208,7 +196,7 @@ def generate_aero_file(): | |||
|
|||
working_elem = 0 | |||
working_node = 0 | |||
# right wing (surface 0, beam 0) | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
my take is to leave these in - it helps with understanding the model generation procedure which if now there's no simplification coming up the pipeline seems to have a big learning curve.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good spot, must have gotten carried away with deleting commented code at some point! Will add these comments back in
@@ -207,7 +193,7 @@ def generate_aero_file(): | |||
|
|||
working_elem = 0 | |||
working_node = 0 | |||
# right wing (surface 0, beam 0) | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
likewise, I suggest leaving them in, happy to have a discussion
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice work, @ben-l-p! Thanks for building a parallel for most files used by the new jax routines - that should hopefully make the work of migrating over in the future much easier. Happy to merge given this passes all tests and that you've been using it day-to-day already.
sharpy/aero/models/aerogrid.py
Outdated
try: | ||
cst_name = f"constraint_{i_constraint:02d}" | ||
ctrl_id = structure_tstep.mb_dict[cst_name]['controller_id'].decode('UTF-8') | ||
f_warp = structure_tstep.mb_dict[cst_name]['aerogrid_warp_factor'][i_elem, i_local_node] | ||
ang_z = structure_tstep.mb_prescribed_dict[ctrl_id]['delta_psi'][2] | ||
ang_warp += f_warp * ang_z | ||
except KeyError: | ||
continue |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Out of curiosity, why is a KeyError
ok?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The constraints here are input as dictionaries and have different key-value pairs depending on their functionality (hinge axis, node number etc.). Warping the aerodynamic grid will occur here if a constraint has both controller_id
and aerogrid_warp_factor
entries, with the intended behavior to skip this code if a constraint is missing one. A key error will occur if it's a constraint that is missing either of these entries (and therefore is not a constraint which requires the warping). Of course, it could technically fail if I had coded it wrong and delta_psi
is not defined, and this would incorrectly ignore this code, but I'm pretty sure that can't happen.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see. Is the logic of it along this line?
if 'controller_id' in structure_tstep.mb_dict[cst_name] and 'aerogrid_warp_factor' in structure_tstep.mb_dict[cst_name]:
f_warp = structure_tstep.mb_dict[cst_name]['aerogrid_warp_factor'][i_elem, i_local_node]
ctrl_id = structure_tstep.mb_dict[cst_name]['controller_id'].decode('UTF-8')
ang_z = structure_tstep.mb_prescribed_dict[ctrl_id]['delta_psi'][2]
ang_warp += f_warp * ang_z
) | ||
|
||
def __init__(self): | ||
self.in_dict = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's in_dict
? Is it an input dictionary or a boolean about whether something is in a dictionary
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have included this to be consistent with the other controllers available, it's the settings dictionary, and I can add a comment to state this. I do not like how the controller code (for all of them) is structured, but I would rather they're all the same at least.
if controlled_state["structural"].mb_prescribed_dict is None: | ||
controlled_state["structural"].mb_prescribed_dict = dict() | ||
controlled_state["structural"].mb_prescribed_dict[self.controller_id] = { | ||
"psi": control_command, | ||
"psi_dot": psi_dot, | ||
} | ||
controlled_state["structural"].mb_prescribed_dict[self.controller_id].update( | ||
{"delta_psi": control_command - self.prescribed_ang_time_history[0, :]} | ||
) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think there isn't a need to declare the dict()
and use update()
by simplifying it to,
controlled_state["structural"].mb_prescribed_dict[self.controller_id] = {
"psi": control_command,
"psi_dot": psi_dot,
"delta_psi": control_command - self.prescribed_ang_time_history[0, :]
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good spot!
sharpy/solvers/statictrim.py
Outdated
@@ -317,17 +317,12 @@ def trim_algorithm(self): | |||
|
|||
def evaluate(self, alpha, deflection_gamma, thrust): | |||
if not np.isfinite(alpha): | |||
import pdb; pdb.set_trace() | |||
raise ValueError |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you add a message to the ValueError
so that it is possible to trace where it is coming from?
sharpy/utils/generate_cases.py
Outdated
try: | ||
constraint_id.create_dataset("rot_axisA2", | ||
data=getattr(constraint, "rot_axisA2")) | ||
except: | ||
pass |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I know this is existing code. But, why are there try
, except
blocks that don't catch any error? Should they be like the newly introduced code where it catches an AttributeError
?
lc_settings: list[dict] = [] | ||
self.num_lm_tot = 0 | ||
for i in range(self.data.structure.ini_mb_dict['num_constraints']): | ||
lc_settings.append(self.data.structure.ini_mb_dict[f'constraint_{i:02d}']) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Small style comment - If i
is always going to be an int
as it comes from range()
, then there shouldn't be a need to specify the format as f"{i:0d}"
, f"{i}"
should be enough
(there are a couple more of this below)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The number needs to be formatted to have two digits, i.e. 00, 01, 02. Not a fan of this method as it means this has to happen in a few places in the code, but unless if I overhaul the multibody then it's a necessity for backwards compatibility.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah I see! That makes sense. tbf I didn't realise that it had to have two digits so that's my bad
Btw I really like your PR message, it's super clear about what's in the PR etc |
I have added a new solver
NonlinearDynamicMultibodyJAX
with the similar functionality asNonlinearDynamicMultibody
, except it does not include anyDynamicTrim
routine (this can be added in due course when required). This new solver makes use of AD to calculate the Jacobians which leads to a much nicer and more compact definition of constraints, particularly for use of the penalty method. A lot of the key constraint types from the original solver are included here, such as:The constraints numerics and the derivatives are defined in the
sharpy/structure/utils/lagrangeconstraintsjax
file. I have included new versions of all the other parts of the framework to prevent changes to one breaking the other solver etc., hence whysharpy/utils/multibodyjax
andsharpy/solvers/timeintegratorsjax
exist.Also included is the ability for controlled actuation between beams, which I primarily implemented for testing variable sweep wings, however the formulation/implementation is general for any 3D rotation. These can be controlled with the new
sharpy/controllers/multibodycontroller
controller type, which takes a Cartesian rotation vector time series as input. To allow this to sweep the wing correctly, the ability to warp the aero grid has been added with a newaerogrid_warp_factor
parameter in the multibody file. This allows for a gradual sweep around a kink, and should not effect for existing cases, as it has no effect is the parameter is not included in the H5 file.A test case for this solver is included in the form of a flexible double pendulum comparison. A free double pendulum case is run, the angles from the two hinges extracted and applied onto a prescribed model, where both should yield the “same” result for structural deflections.
Also included is some documentation on the multibody case files (should be general for both multibody solvers). In recent testing I have found a bug in the
StaticTrim
routine which was not present in v2.0 and I am currently investigating, which will be another PR in due course, but it is not connected to the multibody implementation.I have also done some code cleanups I found along the way, due to PyCharm largely doing it for me, but for files not related to the new solver the functionality should not have been changed. Lastly, I found one of the unit test cases was creating files which didn't get deleted at the end, which has also been fixed.