Skip to content

Commit

Permalink
fix: NormalVectorInterpolator setup properly
Browse files Browse the repository at this point in the history
  • Loading branch information
rabii-chaarani committed Mar 6, 2024
1 parent c7e72bd commit 0192a82
Showing 1 changed file with 50 additions and 22 deletions.
72 changes: 50 additions & 22 deletions map2loop/interpolator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from typing import Tuple, Any

from Cython.Includes.numpy import ndarray
# from Cython.Includes.numpy import ndarray
from map2loop.m2l_enums import Datatype
import beartype
import pandas
Expand Down Expand Up @@ -101,17 +101,33 @@ def interpolate(self, map_data: MapData) -> list:

class NormalVectorInterpolator(Interpolator):
"""
Normal vector interpolation class
Args:
NormalVectorInterpolator(Interpolator): Derived from Abstract Base Class
This class is a subclass of the Interpolator abstract base class. It implements the normal vector interpolation
method for a given set of data points. The class is initialised without any arguments.
Attributes:
dataframe (pandas.DataFrame): A DataFrame that stores the processed data points for interpolation.
x (numpy.ndarray): A numpy array that stores the x-coordinates of the data points.
y (numpy.ndarray): A numpy array that stores the y-coordinates of the data points.
xi (numpy.ndarray): A numpy array that stores the x-coordinates of the grid points for interpolation.
yi (numpy.ndarray): A numpy array that stores the y-coordinates of the grid points for interpolation.
interpolator_label (str): A string that stores the label of the interpolator. For this class, it is
"NormalVectorInterpolator".
Methods:
type(): Returns the label of the interpolator.
setup_interpolation(map_data: MapData): Sets up the interpolation by preparing the data points for interpolation.
setup_grid(map_data: MapData): Sets up the grid for interpolation.
interpolator(ni: Any) -> numpy.ndarray: Performs the interpolation for a given set of values.
interpolate(map_data: MapData) -> numpy.ndarray: Executes the interpolation method.
"""

def __init__(self):
"""
Initialiser of for NormalVectorInterpolator class
"""
self.dataframe = None
self.x = None
self.y = None
self.xi = None
self.yi = None
self.interpolator_label = "NormalVectorInterpolator"
Expand All @@ -126,7 +142,6 @@ def type(self):
return self.interpolator_label

@beartype.beartype
@abstractmethod
def setup_interpolation(self, map_data: MapData):
"""
Setup the interpolation method (abstract method)
Expand Down Expand Up @@ -224,6 +239,8 @@ def setup_interpolation(self, map_data: MapData):
)

self.dataframe = contact_orientations
self.x = self.dataframe['geometry'].apply(lambda geom: geom.x).to_numpy()
self.y = self.dataframe['geometry'].apply(lambda geom: geom.y).to_numpy()

@beartype.beartype
def setup_grid(self, map_data: MapData):
Expand All @@ -234,24 +251,27 @@ def setup_grid(self, map_data: MapData):
map_data (map2loop.MapData): a catchall so that access to all map data is available
"""
# Define the desired cell size
cell_size = 0.05 * (map_data.bounding_box["maxx"] - map_data.bounding_box["minx"])
cell_size = 0.01 * (map_data.bounding_box["maxx"] - map_data.bounding_box["minx"])

# Calculate the grid resolution
grid_resolution = round((map_data.bounding_box["maxx"] - map_data.bounding_box["minx"]) / cell_size)

# Generate the grid
xi = numpy.linspace(
x = numpy.linspace(
map_data.bounding_box["minx"], map_data.bounding_box["maxx"], grid_resolution
)
yi = numpy.linspace(
y = numpy.linspace(
map_data.bounding_box["miny"], map_data.bounding_box["maxy"], grid_resolution
)
xi, yi = numpy.meshgrid(x, y)
xi = xi.flatten()
yi = yi.flatten()

self.xi = xi
self.yi = yi

@beartype.beartype
def interpolator(self, x: float, y: float, ni: float, xi: float, yi: float) -> numpy.ndarray:
def interpolator(self, ni: Any) -> numpy.ndarray:
# TODO: 1. add argument for type of interpolator. 2. add code to process different types of
# interpolators from Scipy and use the chosen one
"""
Expand All @@ -268,9 +288,9 @@ def interpolator(self, x: float, y: float, ni: float, xi: float, yi: float) -> n
Rbf: radial basis function object
"""

rbf = Rbf(x, y, ni, function="linear")
rbf = Rbf(self.x, self.y, ni, function="linear")

return rbf(xi, yi)
return rbf(self.xi, self.yi)

@beartype.beartype
def interpolate(self, map_data: MapData) -> numpy.ndarray:
Expand All @@ -284,16 +304,18 @@ def interpolate(self, map_data: MapData) -> numpy.ndarray:
list: sorted list of unit names
"""
self.setup_interpolation(map_data)
self.setup_grid(map_data)

stratigraphic_orientations = self.setup_interpolation(map_data)
x, y = stratigraphic_orientations[["X", "Y"]].to_numpy()
nx, ny, nz = stratigraphic_orientations[["nx", "ny", "nz"]].to_numpy()
xi, yi = self.setup_grid(map_data)
# nx, ny, nz = self.dataframe[["nx", "ny", "nz"]].to_numpy()
nx = self.dataframe["nx"].to_numpy()
ny = self.dataframe["ny"].to_numpy()
nz = self.dataframe["nz"].to_numpy()

# interpolate each component of the normal vector nx, ny, nz
nx_interp = self.interpolator(x, y, nx, xi, yi)
ny_interp = self.interpolator(x, y, ny, xi, yi)
nz_interp = self.interpolator(x, y, nz, xi, yi)
nx_interp = self.interpolator(nx)
ny_interp = self.interpolator(ny)
nz_interp = self.interpolator(nz)

vecs = numpy.array([nx_interp, ny_interp, nz_interp]).T
vecs /= numpy.linalg.norm(vecs, axis=1)[:, None]
Expand Down Expand Up @@ -358,12 +380,18 @@ def setup_grid(self, map_data: MapData):
grid_resolution = round((map_data.bounding_box["maxx"] - map_data.bounding_box["minx"]) / cell_size)

# Generate the grid
self.xi = numpy.linspace(
x = numpy.linspace(
map_data.bounding_box["minx"], map_data.bounding_box["maxx"], grid_resolution
)
self.yi = numpy.linspace(
y = numpy.linspace(
map_data.bounding_box["miny"], map_data.bounding_box["maxy"], grid_resolution
)
# generate the grid
xi, yi = numpy.meshgrid(x, y)
xi = xi.flatten()
yi = yi.flatten()
self.xi = xi
self.yi = yi

@beartype.beartype
def interpolator(self, ni: Any) -> numpy.ndarray:
Expand All @@ -383,7 +411,7 @@ def interpolator(self, ni: Any) -> numpy.ndarray:
return rbf(self.xi, self.yi)

@beartype.beartype
def interpolate(self, map_data: MapData) -> tuple[ndarray | ndarray, ndarray | ndarray]:
def interpolate(self, map_data: MapData) -> numpy.ndarray:
"""
Execute interpolation method (abstract method)
Expand Down

0 comments on commit 0192a82

Please sign in to comment.