Skip to content

Commit

Permalink
merging branches
Browse files Browse the repository at this point in the history
  • Loading branch information
amirDahari1 committed Jul 30, 2024
1 parent 1112759 commit 4d704cd
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 49 deletions.
4 changes: 4 additions & 0 deletions representativity/prediction_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
from scipy.optimize import curve_fit
from functools import partial

import os

print(os.getcwd())

'''
File: prediction_error.py
Expand Down
49 changes: 0 additions & 49 deletions representativity/tpc_expectation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,6 @@
import time
from scipy.stats import norm

def generate_sg_tpc(netG, mode, imsize):
'''
This function is used to predict the integral range of the microstructure.
'''
lf = imsize//32 + 2 # the size of G's input
single_img = util.generate_image(netG, lf=lf, threed=mode=='3D', reps=1)
if single_img.any():
single_img = single_img.cpu()[0]
dims = len(single_img.shape)
tpc = util.tpc_radial(single_img, threed=dims == 3)
return tpc

def tpc_by_radius(tpc):
tpc = np.array(tpc)
middle_idx = np.array(tpc.shape)//2
Expand All @@ -47,36 +35,6 @@ def tpc_by_radius(tpc):
tpc_res.append(np.sum(tpc_vec[dist_bool])/np.sum(vec_arr[dist_bool]))
return pf, pf**2, tpc_res, dist_indices

def tpc_check():
all_data, micros, netG, v_names, run_v_names = ms.json_preprocessing()

edge_lengths_pred = all_data[f'data_gen_2D']['edge_lengths_pred']
for j, p in enumerate(micros):

try:
netG.load_state_dict(torch.load(p + "_Gen.pt"))
except: # if the image is greayscale it's excepting because there's only 1 channel
continue
n = p.split('/')[-1]
args = (edge_lengths_pred[10], netG, '2D')
tpc_results, pfs, pf_squares = tpcs_radius(generate_sg_tpc, args)
# print(f'{len(pf_squares)} / {test_runs} done for {n}')
mean_tpc_results = np.mean(np.stack(tpc_results,axis=0), axis=0)
tpc_fig.plot(mean_tpc_results, label='mean tpc')
real_pf_squared = np.mean(pfs)**2
# print(f'real pf squared = {np.round(real_pf_squared, 6)}')
# print(f'pf squared = {np.round(np.mean(pf_squares), 6)}')
tpc_fig.plot([real_pf_squared]*len(mean_tpc_results), label='real pf squared')
tpc_fig.plot([np.mean(pf_squares)]*len(mean_tpc_results), label='pf squared')
tpc_fig.xlabel('Growing Radius')
tpc_fig.ylabel('TPC')
tpc_fig.legend()
tpc_fig.savefig(f'tpc_results/{n}_tpc.png')
tpc_fig.close()
# print(f'end tpc = {np.round(np.mean(end_tpc_results), 6)}')
# print(f'end tpc std = {np.round(np.std(end_tpc_results), 6)}\n')
print(f'{p} done\n')

def tpcs_radius(gen_func, test_runs, args):
tpc_results = []
tpcs = []
Expand Down Expand Up @@ -133,7 +91,6 @@ def fill_img_with_circles(img, circle_radius, circle_centers):
return img

if __name__ == '__main__':
# tpc_check()
pfs =[]
imsize = 100
circle_size = 20
Expand All @@ -160,12 +117,6 @@ def fill_img_with_circles(img, circle_radius, circle_centers):
axs[f'circle{i}'].set_xlabel(f'$\omega_{i}$')
axs[f'circle{i}'].set_xticks([])
axs[f'circle{i}'].set_yticks([])



# im0_tpc = make_tpc(circle_ims[0])
# im0_tpc_by_radius = tpc_by_radius(im0_tpc)[2]
# tpc_fig.plot(im0_tpc_by_radius, label='TPC of $\omega_0$')

tpc_results, pfs, pf_squares, dist_len = tpcs_radius(make_circles_tpc, run_tests, args=args)
bins = 30
Expand Down
3 changes: 3 additions & 0 deletions representativity/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def ps_error_prediction(dim, data, confidence, error_target):
large_im_repeats = 1
in_the_bounds_w_model = []
in_the_bounds_wo_model = []
iters = 0
for generator, params in ps_generators.items():
for value_comb in product(*params.values()):
args = {key: value for key, value in zip(params.keys(), value_comb)}
Expand All @@ -143,6 +144,8 @@ def ps_error_prediction(dim, data, confidence, error_target):
print(f'One image stat analysis cls: {one_im_stat_analysis_cls}')
one_im_clss.append(one_im_stat_analysis_cls)
for i in range(2):
print(f'Iteration {iters}')
iters += 1
with_model = i == 0
im_err, l_for_err_target, cls = util.make_error_prediction(small_im,
conf=confidence, err_targ=error_target, model_error=with_model, n_divisions=301)
Expand Down

0 comments on commit 4d704cd

Please sign in to comment.