diff --git a/reperes/tp_pantilt.zip b/reperes/tp_pantilt.zip index 62e02da..224d6ae 100644 Binary files a/reperes/tp_pantilt.zip and b/reperes/tp_pantilt.zip differ diff --git a/reperes/tp_pantilt/learn_correction.py b/reperes/tp_pantilt/learn_correction.py index 70a36eb..d793f29 100644 --- a/reperes/tp_pantilt/learn_correction.py +++ b/reperes/tp_pantilt/learn_correction.py @@ -1,11 +1,15 @@ import model_correction as model import numpy as np import torch as th +from torchinfo import summary import time from mlp import MLP net = MLP(2, 2) -batch_size = 512 +# summary(net) +# exit() + +batch_size = 256 optimizer = th.optim.Adam(net.parameters(), 1e-3) for k in range(1024): @@ -21,5 +25,5 @@ loss.backward() optimizer.step() print(loss) - -th.save(net.state_dict(), "weights") \ No newline at end of file + +net.save() \ No newline at end of file diff --git a/reperes/tp_pantilt/mlp.py b/reperes/tp_pantilt/mlp.py index 30d09f9..185eb29 100644 --- a/reperes/tp_pantilt/mlp.py +++ b/reperes/tp_pantilt/mlp.py @@ -19,3 +19,9 @@ def __init__(self, input_dimension: int, output_dimension: int): def forward(self, x): return self.net(x) + + def load(self): + self.load_state_dict(th.load("weights")) + + def save(self): + th.save(self.state_dict(), "weights") diff --git a/reperes/tp_pantilt/model_correction.py b/reperes/tp_pantilt/model_correction.py index 1a45dfc..7caaaf9 100644 --- a/reperes/tp_pantilt/model_correction.py +++ b/reperes/tp_pantilt/model_correction.py @@ -15,6 +15,9 @@ def direct(alpha, beta): """ return utils.Rz(alpha) @ utils.translation(0, 0, l1) @ utils.Ry(beta) @ utils.translation(0, 0, l2) + + + def laser(alpha, beta): """ Reçoit en paramètre les angles du robot, retourne la @@ -32,6 +35,12 @@ def laser(alpha, beta): return pos_on_floor[:2] else: return [0., 0.] + + + + + + def inverse(x, y): """ @@ -55,7 +64,7 @@ def inverse_nn(x, y): if net is None: net = MLP(2, 2) - net.load_state_dict(th.load("weights")) + net.load() laser_pos = th.tensor([x, y]) with th.no_grad(): diff --git a/reperes/tp_pantilt/sim.py b/reperes/tp_pantilt/sim.py index c59570b..6220aa8 100644 --- a/reperes/tp_pantilt/sim.py +++ b/reperes/tp_pantilt/sim.py @@ -2,7 +2,7 @@ import time from onshape_to_robot import simulation import pybullet as p -import model +import model_correction as model import argparse sim = simulation.Simulation("pantilt/robot.urdf", fixed=True, panels=True) diff --git a/reperes/tp_pantilt/weights b/reperes/tp_pantilt/weights deleted file mode 100644 index 3fa0170..0000000 Binary files a/reperes/tp_pantilt/weights and /dev/null differ