Skip to content

Commit

Permalink
more trainer tweaks.
Browse files Browse the repository at this point in the history
  • Loading branch information
JulioJerez committed Sep 14, 2024
1 parent 54b9f10 commit a0ab4ca
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -545,8 +545,8 @@ namespace ndAdvancedRobot
ndFloat32 rollReward = ndExp(-100.0f * deltaRoll * deltaRoll);
ndFloat32 pitchReward = ndExp(-100.0f * deltaPitch * deltaPitch);

const ndFloat32 rewardWeight = 1.0 / 6.0f;
return rewardWeight * (posit_xReward + posit_yReward + azimuthReward + yawReward + rollReward + pitchReward);
const ndFloat32 rewardWeight = 1.0 / 8.0f;
return rewardWeight * (posit_xReward + posit_yReward + yawReward + rollReward + pitchReward + 3.0f * azimuthReward);

#else
ndQuaternion effectorRotation(effectorMatrix);
Expand Down Expand Up @@ -620,9 +620,29 @@ namespace ndAdvancedRobot
observation->m_delta_y = ndBrainFloat(m_targetLocation.m_y - currenPosit.m_y);
observation->m_deltaAzimuth = ndBrainFloat(ndAnglesSub(m_targetLocation.m_azimuth, azimuth));

//ndTrace(("%f %f %f %f\n", rotation.m_x, rotation.m_y, rotation.m_z, rotation.m_w));
//ndTrace(("%f %f %f %f\n\n", currenPosit.m_x, currenPosit.m_y, currenPosit.m_z, azimuth * ndRadToDegree));
//ndTrace(("%f %f %f %f\n", m_targetLocation.m_x, m_targetLocation.m_y, 0.0f, m_targetLocation.m_azimuth * ndRadToDegree));
//static int xxxx = 0;
//if (xxxx)
//{
// const ndMatrix invBaseMatrix(m_base_rotator->CalculateGlobalMatrix1().OrthoInverse());
// const ndMatrix effectorMatrix(m_effectorMatrixOffset * m_arm_4->CalculateGlobalMatrix0() * invBaseMatrix);
//
// ndFloat32 azimuth1 = 0.0f;
// const ndVector& posit = effectorMatrix.m_posit;
// if ((posit.m_y * posit.m_y + posit.m_z * posit.m_z) > 1.0e-3f)
// {
// azimuth1 = ndAtan2(posit.m_z, posit.m_y);
// }
// const ndMatrix aximuthMatrix(ndPitchMatrix(azimuth1));
// const ndVector currenPosit(aximuthMatrix.UnrotateVector(posit) - m_effectorPositOffset);
//
// ndFloat32 dx = m_targetLocation.m_x - currenPosit.m_x;
// ndFloat32 dy = m_targetLocation.m_y - currenPosit.m_y;
// ndFloat32 dAzimuth = ndAnglesSub(m_targetLocation.m_azimuth, azimuth);
//
// ndAssert(ndAbs(dx - observation->m_delta_x) < 1.0e-3f);
// ndAssert(ndAbs(dy - observation->m_delta_y) < 1.0e-3f);
// ndAssert(ndAbs(dAzimuth - observation->m_deltaAzimuth) < 1.0e-3f);
//}
}

//#pragma optimize( "", off )
Expand Down Expand Up @@ -1050,14 +1070,15 @@ namespace ndAdvancedRobot
,m_bestActor()
,m_outFile(nullptr)
,m_timer(ndGetTimeInMicroseconds())
//,m_horizon(ndFloat32(1.0f) / (ndFloat32(1.0f) - m_discountFactor))
,m_maxScore(ndFloat32(-1.0e10f))
,m_saveScore(m_maxScore)
,m_discountFactor(0.99f)
,m_horizon(ndFloat32(1.0f) / (ndFloat32(1.0f) - m_discountFactor))
,m_lastEpisode(0xffffffff)
,m_stopTraining(ndUnsigned32(4000)* ndUnsigned32(1000000))
,m_savePartial(0)
,m_stopTraining(ndUnsigned32(2000)* ndUnsigned32(1000000))
,m_modelIsTrained(false)
{
m_horizon = ndFloat32(1.0f) / (ndFloat32(1.0f) - m_discountFactor);
m_outFile = fopen("robotArmReach-vpg.csv", "wb");
fprintf(m_outFile, "vpg\n");

Expand Down Expand Up @@ -1099,8 +1120,8 @@ namespace ndAdvancedRobot

ndInt32 countX = 22;
ndInt32 countZ = 23;
//countX = 1;
//countZ = 1;
countX = 10;
countZ = 10;

// add a hidden battery of models to generate trajectories in parallel
for (ndInt32 i = 0; i < countZ; ++i)
Expand Down Expand Up @@ -1197,7 +1218,7 @@ namespace ndAdvancedRobot

episodeCount -= m_master->GetEposideCount();
ndFloat32 rewardTrajectory = m_master->GetAverageFrames() * m_master->GetAverageScore();
if (rewardTrajectory >= ndFloat32(m_maxScore))
if (rewardTrajectory >= m_maxScore)
{
if (m_lastEpisode != m_master->GetEposideCount())
{
Expand All @@ -1217,21 +1238,23 @@ namespace ndAdvancedRobot
fflush(m_outFile);
}
}
}

if (stopTraining / 100000000 == m_savePartial)
{
m_savePartial++;
char fileName[1024];
ndBrain* const actor = m_master->GetActor();
ndGetWorkingFileName("ndRobotArmReach_actor.dnn", fileName);
actor->SaveToFile(fileName);
if (rewardTrajectory > m_saveScore)
{
char fileName[1024];
m_saveScore = ndFloor(rewardTrajectory) + 2.0f;

ndBrain* const critic = m_master->GetCritic();
ndGetWorkingFileName("ndRobotArmReach_critic.dnn", fileName);
critic->SaveToFile(fileName);
// save partial controller in case of crash
ndBrain* const actor = m_master->GetActor();
ndGetWorkingFileName("ndRobotArmReach_actor.dnn", fileName);
actor->SaveToFile(fileName);

ndBrain* const critic = m_master->GetCritic();
ndGetWorkingFileName("ndRobotArmReach_critic.dnn", fileName);
critic->SaveToFile(fileName);
}
}

if ((stopTraining >= m_stopTraining) || (100.0f * m_master->GetAverageScore() / m_horizon > 96.0f))
{
char fileName[1024];
Expand All @@ -1252,12 +1275,13 @@ namespace ndAdvancedRobot
ndList<ndModelArticulation*> m_models;
FILE* m_outFile;
ndUnsigned64 m_timer;
ndFloat32 m_horizon;
ndFloat32 m_maxScore;
ndFloat32 m_saveScore;
ndFloat32 m_discountFactor;
ndFloat32 m_horizon;
ndUnsigned32 m_lastEpisode;
ndUnsigned32 m_stopTraining;
ndUnsigned32 m_savePartial;

bool m_modelIsTrained;
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -782,6 +782,7 @@ void ndBrainAgentContinuePolicyGradient_TrainerMaster::OptimizeStep()
{
ndBrainAgentContinuePolicyGradient_Trainer* const agent = node->GetInfo();
agent->m_trajectory.SetCount(0);
agent->ResetModel();
}
}
}
Expand Down
3 changes: 2 additions & 1 deletion newton-4.00/sdk/dCore/ndQuaternion.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ class ndQuaternion: public ndVector
};

inline ndQuaternion::ndQuaternion()
:ndVector(ndVector::m_wOne)
//:ndVector(ndVector::m_wOne)
:ndVector(ndFloat32 (0.0f), ndFloat32(0.0f), ndFloat32(0.0f), ndFloat32(1.0f))
{
}

Expand Down

0 comments on commit a0ab4ca

Please sign in to comment.