diff --git a/taufactor/taufactor.py b/taufactor/taufactor.py index a724ac2..00db295 100644 --- a/taufactor/taufactor.py +++ b/taufactor/taufactor.py @@ -191,11 +191,16 @@ def check_convergence(self, verbose, conv_crit): self.old_fl = self.new_fl return False - def check_vertical_flux(self, conv_crit): + def calc_vertical_flux(self): + '''Calculates the vertical flux through the volume''' vert_flux = self.conc[:, 1:-1, 1:-1, 1:-1] - \ self.conc[:, :-2, 1:-1, 1:-1] vert_flux[self.conc[:, :-2, 1:-1, 1:-1] == 0] = 0 vert_flux[self.conc[:, 1:-1, 1:-1, 1:-1] == 0] = 0 + return vert_flux + + def check_vertical_flux(self, conv_crit): + vert_flux = self.calc_vertical_flux() fl = torch.sum(vert_flux, (0, 2, 3))[1:-1] err = (fl.max() - fl.min())/(fl.max()) if fl.min() == 0: @@ -286,10 +291,15 @@ def solve(self, iter_limit=5000, verbose=True, conv_crit=2*10**-2, D_0=1): self.end_simulation(iter_limit, verbose, start) return self.tau - def check_vertical_flux(self, conv_crit): + def calc_vertical_flux(self): + '''Calculates the vertical flux through the volume''' vert_flux = abs(self.conc - torch.roll(self.conc, 1, 1)) vert_flux[self.conc == 0] = 0 vert_flux[torch.roll(self.conc, 1, 1) == 0] = 0 + return vert_flux + + def check_vertical_flux(self, conv_crit): + vert_flux = self.calc_vertical_flux() fl = torch.sum(vert_flux, (0, 2, 3))[3:-2] err = (fl.max() - fl.min())*2/(fl.max() + fl.min()) if err < conv_crit or torch.isnan(err).item(): @@ -488,10 +498,15 @@ def check_convergence(self, verbose, conv_crit): return False - def check_vertical_flux(self, conv_crit): + def calc_vertical_flux(self): + '''Calculates the vertical flux through the volume''' vert_flux = (self.conc[:, 1:-1, 1:-1, 1:-1] - self.conc[:, :-2, 1:-1, 1:-1]) * self.pre_factors[1][:, :-2, 1:-1, 1:-1] vert_flux[self.nn == torch.inf] = 0 + return vert_flux + + def check_vertical_flux(self, conv_crit): + vert_flux = self.calc_vertical_flux() fl = torch.sum(vert_flux, (0, 2, 3))[2:-2] err = (fl.max() - fl.min())*2/(fl.max() + fl.min()) if err < conv_crit or torch.isnan(err).item():