forked from 2iw31Zhv/diffsmat_py
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_opt.py
117 lines (102 loc) · 3.96 KB
/
test_opt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import numpy as np
import torch
import torch.optim as optim
import matplotlib.pyplot as plt
import maxpy.rcwa as rcwa
import matplotlib.pyplot as plt
import time
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using GPU" if device == "cuda" else "Using CPU")
# ---- Simulation / optimisation settings ----
nx, ny = 8, 8  # half of the harmonics along x and y directions
# number of harmonics is (2 * nx + 1) x (2 * ny + 1)
nx_grid = 20  # grid points along x; the Fourier transform is analytical, so this can be very small
n_opt = 10  # number of optimisation grid cells along one direction
wavelength = 1.55
Lx = 1.  # period along x
Ly = 1.  # period along y
n_mode = 2  # number of modes to be optimised
ny_grid = int(nx_grid * Ly / Lx)  # grid points along y, scaled by the period ratio
k0 = 2 * np.pi / wavelength  # free-space wavevector

# Permittivity inside the design material (3.48**2) and of the surrounding medium.
eps_in = torch.tensor(3.48*3.48+0j, device=device, requires_grad=False, dtype=torch.complex128)
eps_out = torch.tensor(1.+0j, device=device, requires_grad=False, dtype=torch.complex128)

# Design variables: an n_opt x n_opt density map, initialised uniformly at 0.5.
de = 0.5 * torch.ones(n_opt, n_opt, device=device, dtype=torch.float64)
de.requires_grad_(True)

# Port modes are computed once, in vacuum, from a shared coefficient object.
coeff = rcwa.MaxwellCoeff(nx, ny, Lx, Ly, device=device)
port_mode = rcwa.MaxwellMode()
port_mode.compute_in_vacuum(wavelength, coeff, device=device)

# Rank the port modes by Re(valsqrt) / k0 / k0, largest first, and keep the
# first n_mode of them.
# NOTE(review): named "neff" but divided by k0 twice; only the ordering is used
# below, so the scale does not matter here — confirm if the values are reused.
neff_port = port_mode.valsqrt.real / k0 / k0
neff_port, port_indices = torch.sort(neff_port, descending=True)
select_indices = port_indices[:n_mode]
def get_permittivity(de):
    """Build the permittivity map of one unit cell from the design densities.

    The cell is filled with ``eps_out``. A centred ``n_opt x n_opt`` window is
    then shifted by ``(eps_in - eps_out) * de``. With densities in [0, 1] this
    interpolates linearly between ``eps_out`` (de = 0) and ``eps_in`` (de = 1).

    Args:
        de: real design-density tensor that broadcasts to ``(n_opt, n_opt)``
            (normally exactly that shape).

    Returns:
        Tensor of shape ``(nx_grid, ny_grid)``. It is complex because
        ``eps_out`` is complex, and it is differentiable with respect to ``de``.
    """
    ex = eps_out * torch.ones(nx_grid, ny_grid, device=device, dtype=torch.float64)
    # Start indices of the design window, centred in the cell. Spanning
    # [x0, x0 + n_opt) keeps the window exactly n_opt cells wide. The symmetric
    # form `c - n_opt//2 : c + n_opt//2` is one cell short for odd n_opt and
    # would then not match an n_opt x n_opt `de`.
    x0 = nx_grid // 2 - n_opt // 2
    y0 = ny_grid // 2 - n_opt // 2
    ex[x0:x0 + n_opt, y0:y0 + n_opt] += (eps_in - eps_out) * de
    return ex
USE_DIFFERENTIABLE_EIG = False  # False for our method


def compute_loss(de):
    """Simulate the structure described by ``de`` and return the objective.

    With ``T = smat.Tuu()`` and ``m0, m1 = select_indices[0], select_indices[1]``,
    the value returned is ``-|T[m0, m0]|**2 + |T[m1, m1]|**2``. Both diagonal
    terms are printed on every call.

    Side effect: calls ``coeff.compute(...)`` on the shared module-level
    ``coeff`` with the permittivity built from ``de``.
    """
    eps_map = get_permittivity(de)
    coeff.compute(wavelength, eps_map, device=device)

    # Eigen-solve the layer with the selected solver.
    mode = rcwa.MaxwellMode()
    if USE_DIFFERENTIABLE_EIG:
        mode.compute_diff(coeff, device=device)
    else:
        mode.compute(coeff, device=device)

    smat = rcwa.ScatteringMatrix()
    smat.compute(mode, 1.)
    # Project onto the port modes with the routine matching the solver choice.
    project = smat.port_project_diff if USE_DIFFERENTIABLE_EIG else smat.port_project
    project(port_mode, coeff)

    power = torch.abs(smat.Tuu()) ** 2
    keep, drop = select_indices[0], select_indices[1]
    print(power[keep, keep].item(), power[drop, drop].item())
    # maximize the transmission of one polarization
    # minimize the transmission of the other polarization
    return -power[keep, keep] + power[drop, drop]
# Save an image of the real part of the starting permittivity map.
eps_init = get_permittivity(de).detach().cpu().numpy()
plt.imshow(eps_init.real)
plt.savefig("de_init.png")
plt.close()
# ---- Gradient-based optimisation of the design densities with Adam ----
loss_history = []
niters = 20
optimizer = optim.Adam([de], lr=5e-2)
t1 = time.perf_counter()
for i in range(niters):
    optimizer.zero_grad()
    loss = compute_loss(de)
    loss_value = loss.item()
    loss_history.append(loss_value)
    loss.backward()
    optimizer.step()
    # Project the densities back onto [0, 1]. This is done in place under
    # no_grad so autograd does not record it and `de` stays the same leaf
    # tensor the optimiser holds. This is preferred over assigning to `.data`.
    with torch.no_grad():
        de.clamp_(0., 1.)
    print("i, loss = ", i, loss_value)
t2 = time.perf_counter()
print("time = ", t2 - t1)
# Optional baseline: the loss curve of an earlier run, labelled
# "Lorentzian broadening" in the plot. It is written by uncommenting the
# np.save line below in that run.
# np.save("loss_history_v2.npy", np.array(loss_history))
try:
    loss_v2 = np.load("loss_history_v2.npy")
except FileNotFoundError:
    # No baseline saved yet: skip the comparison instead of aborting
    # before the final figures are written.
    loss_v2 = None

plt.imshow(get_permittivity(de).detach().cpu().numpy().real)
plt.savefig("de_final.png")
plt.close()

# Loss history of this run, overlaid on the baseline when one is available.
if loss_v2 is not None:
    plt.plot(loss_v2, label="Lorentzian broadening")
plt.plot(loss_history, label="our method")
if loss_v2 is not None:
    plt.legend()
plt.xlabel("iteration")
plt.ylabel("loss")
plt.savefig("loss_history.png")
plt.close()

# Preview of the design hard-thresholded at density 0.4.
plt.imshow(get_permittivity((de >= 0.4).float()).detach().cpu().numpy().real)
plt.savefig("de_binarized.png")
plt.close()
# Disabled diagnostic: loss of the design binarised at 11 evenly spaced
# thresholds in [0, 1], plotted against the threshold.
if False:
    thresholds = torch.linspace(0., 1., 11, device=device)
    binarized_loss = np.array([
        [compute_loss((de >= t).float()).detach().cpu().numpy()]
        for t in thresholds
    ])
    plt.plot(thresholds.detach().cpu().numpy(), binarized_loss[:, 0], '-', label='loss')
    plt.savefig('binarized_loss.png')
    plt.close()