% buildPINNs.m: build and train a physics-informed neural network (PINN)
% for a mass-spring-damper system.
%% Load data
% The MAT-file provides the system parameters (m, mu, k), the analytic solution
% coefficients (A, B, omega1, omega2), the time span (t0, tmax), the measured
% samples (tdata, xdata), and the collocation points (tpinns).
load massSpringDamperData.mat
% Analytic solution of the damped oscillator, used as ground truth.
xsolFcn = @(t) real(A.*exp(omega1.*t) + B.*exp(omega2.*t));
plotMassSpringDamperData(t0, tmax, tdata, xdata, tpinns, xsolFcn)
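% Note (an assumption about how the data file was generated, not stated in it):
% omega1 and omega2 are the roots of the characteristic polynomial of
% m*x'' + mu*x' + k*x = 0, with A and B fixed by the initial conditions.
% A one-line sketch to recover the roots from the loaded parameters:
%   omegas = roots([m, mu, k]);   % omegas(1), omegas(2) correspond to omega1, omega2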
%% Build neural network
% Two-hidden-layer fully connected network with tanh activations, mapping
% time t (one feature) to displacement x (one output).
inputSize = 1;
outputSize = 1;
numHiddenUnits = 128;
layers = [
    featureInputLayer(inputSize)
    fullyConnectedLayer(numHiddenUnits)
    tanhLayer()
    fullyConnectedLayer(numHiddenUnits)
    tanhLayer()
    fullyConnectedLayer(outputSize) ];
net = dlnetwork(layers);
% Inspect the architecture interactively.
deepNetworkDesigner(net)
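% Optional sanity check (a sketch added here, not part of the original script):
% an untrained forward pass with a formatted scalar input should return a
% 1-by-1 output.
%   x0 = predict(net, dlarray(0, 'CB'));
%   size(x0)   % expect [1 1]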
%% Train the neural network
% Specify training hyperparameters.
numIterations = 5e3;
% Specify Adam hyperparameters. mp and vp hold the moving averages of the
% gradients and squared gradients; adamupdate initializes them when passed [].
learnRate = 0.01;
mp = [];
vp = [];
% Prepare data for training: format as dlarrays with channel (C) and batch (B)
% dimensions.
tdata = dlarray(tdata, 'CB');
xdata = dlarray(xdata, 'CB');
tpinns = dlarray(tpinns, 'CB');
% Create the training progress plot.
monitor = trainingProgressMonitor(Metrics=["Loss", "LossPINN", "LossData"]);
fig = figure();
% Accelerate the model loss: dlaccelerate traces the function and caches the
% result for faster repeated evaluation inside the loop.
accFcn = dlaccelerate(@modelLoss);
for iteration = 1:numIterations
    % Evaluate the loss and gradients through the accelerated function.
    [loss, gradients, lossPinn, lossData] = dlfeval(accFcn, net, tdata, xdata, tpinns, m, mu, k);
    % Update the learnable parameters with Adam.
    [net, mp, vp] = adamupdate(net, gradients, mp, vp, iteration, learnRate);
    recordMetrics(monitor, iteration, ...
        Loss=loss, ...
        LossPINN=lossPinn, ...
        LossData=lossData);
    % Every 50 iterations, plot predictions against the analytic solution at
    % random test times.
    if mod(iteration, 50) == 0
        ttest = sort(rand(100,1)).*tmax;
        xtest = xsolFcn(ttest);
        % ttest is passed to predict as a numeric array; if your release
        % rejects numeric input, wrap it as dlarray(ttest', 'CB').
        xpred = predict(net, ttest);
        plotModelPredictions(fig, ttest, xtest, xpred, iteration);
    end
end
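%% Evaluate the trained network
% A minimal post-training check (a sketch; tgrid, xhat, xtrue, and relErr are
% names introduced here, not part of the original file): relative L2 error of
% the PINN against the analytic solution on a uniform grid.
tgrid = linspace(t0, tmax, 200);
xhat = extractdata(predict(net, dlarray(tgrid, 'CB')));
xtrue = xsolFcn(tgrid);
relErr = norm(xhat(:) - xtrue(:)) / norm(xtrue(:));
fprintf("Relative L2 error vs. analytic solution: %.3e\n", relErr);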
function [loss, gradients, lossPinn, lossData] = modelLoss(net, tdata, xdata, tpinns, m, mu, k)
    % Weighted sum of the physics (ODE residual) loss and the data-fit loss.
    % The weights 0.1 and 0.05 are fixed hyperparameters balancing the two terms.
    lossPinn = pinnsLoss(net, tpinns, m, mu, k);
    lossData = dataLoss(net, tdata, xdata);
    loss = 0.1.*lossPinn + 0.05.*lossData;
    gradients = dlgradient(loss, net.Learnables);
end
function loss = pinnsLoss(net, t, m, mu, k)
    % Penalize the residual of the governing ODE, m*x'' + mu*x' + k*x = 0,
    % at the collocation points t.
    x = forward(net, t);
    % Time derivatives via automatic differentiation. EnableHigherDerivatives
    % is required so that xt can itself be differentiated to obtain xtt.
    xt = dlgradient(sum(x, 'all'), t, EnableHigherDerivatives=true);
    xtt = dlgradient(sum(xt, 'all'), t);
    residual = m.*xtt + mu.*xt + k.*x;
    loss = mean(residual.^2, 'all');
end
function loss = dataLoss(net, t, xtarget)
    % L2 loss between network predictions and the measured displacements.
    x = forward(net, t);
    loss = l2loss(x, xtarget);
end