Add support for nan_to_num, atan2 ops, map bitwise_or_ op, and fix isnan decomposition logic
kamalrajkannan78 committed Jan 10, 2025
1 parent d8027e9 commit 5d4e3f4
Showing 2 changed files with 86 additions and 3 deletions.
80 changes: 80 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
@@ -4620,6 +4620,83 @@ def scaled_dot_product_attention(self, inputs, input_types):
attn_weight = _op.reshape(attn_weight, newshape=[-4, batch_size, -1, -2])

return attn_weight


def nan_to_num(self, inputs, input_types):

"""
Mimics the behavior of torch.nan_to_num (https://pytorch.org/docs/stable/generated/torch.nan_to_num.html).
"""

# Extract input tensor and replacement values
data = inputs[0]
nan_value = inputs[1]
posinf = inputs[2]
neginf = inputs[3]

# Ensure the data type is one of the supported floating-point types
dtype = input_types[0]
assert dtype in ["float16", "float32", "float64"], f"Unsupported dtype: {dtype}. Supported types are ['float16', 'float32', 'float64']."

# Define the maximum and minimum representable values for the data type
dtype_max = np.finfo(dtype).max
dtype_min = np.finfo(dtype).min

# Create constants for NaN, positive infinity, and negative infinity replacements
nan_tensor = tvm.relay.const(nan_value if nan_value is not None else 0.0, dtype)
posinf_tensor = tvm.relay.const(posinf if posinf is not None else dtype_max, dtype)
neginf_tensor = tvm.relay.const(neginf if neginf is not None else dtype_min, dtype)

# Replace NaN values with the specified or default value
data = tvm.relay.where(_op.isnan(data), nan_tensor, data)

# Replace positive infinity with the specified or greatest finite value representable by input’s dtype
data = tvm.relay.where(tvm.relay.greater(data, posinf_tensor), posinf_tensor, data)

# Replace negative infinity with the specified or least finite value representable by input’s dtype
result = tvm.relay.where(tvm.relay.less(data, neginf_tensor), neginf_tensor, data)

return result
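For reference, a minimal NumPy sketch of the same isnan/where/clamp decomposition (illustrative only, not part of this commit; nan_to_num_reference is a hypothetical helper name):

import numpy as np

def nan_to_num_reference(data, nan=0.0, posinf=None, neginf=None, dtype="float32"):
    # Mirror the defaults above: NaN -> 0.0, +inf -> dtype max, -inf -> dtype min.
    posinf = np.finfo(dtype).max if posinf is None else posinf
    neginf = np.finfo(dtype).min if neginf is None else neginf
    out = np.asarray(data, dtype=dtype)
    out = np.where(np.isnan(out), np.array(nan, dtype=dtype), out)    # NaN replacement
    out = np.where(out > posinf, np.array(posinf, dtype=dtype), out)  # +inf (and values above posinf)
    out = np.where(out < neginf, np.array(neginf, dtype=dtype), out)  # -inf (and values below neginf)
    return out

print(nan_to_num_reference([np.nan, np.inf, -np.inf, 1.5]))
# roughly: [ 0.0  3.4028235e+38 -3.4028235e+38  1.5 ]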


def atan2(self, inputs, input_types):

"""
Mimics the behavior of torch.atan2 (https://pytorch.org/docs/stable/generated/torch.atan2.html).
"""

data_1 = inputs[1] # x (denominator)
data_2 = inputs[0] # y (numerator)

# Compute the ratio y/x. This is the tangent of the angle.
ratio = tvm.relay.divide(data_2, data_1)

# Compute the arctangent of the ratio, which gives the angle in the range [-π/2, π/2].
atan_res = tvm.relay.atan(ratio)

# Define constants for π and 0 for use in correction logic.
pi = tvm.relay.const(np.pi, "float32") # π constant
zero = tvm.relay.const(0.0, "float32") # Zero constant

# Compute the correction term to adjust the angle to the correct quadrant.
# If x < 0:
# - If y >= 0, add π to the angle (to move from 1st to 2nd quadrant).
# - If y < 0, subtract π from the angle (to move from 4th to 3rd quadrant).
# If x >= 0, no correction is needed.
correction = tvm.relay.where(
tvm.relay.less(data_1, zero), # Check if x < 0
tvm.relay.where(
tvm.relay.greater_equal(data_2, zero), # Check if y >= 0
pi, # Add π if x < 0 and y >= 0
-pi # Subtract π if x < 0 and y < 0
),
zero # No correction if x >= 0
)

# Add the correction term to the arctangent result.
result = tvm.relay.add(atan_res, correction)

return result
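As a sanity check on the quadrant correction, a small NumPy sketch (illustrative only, not part of this commit; atan2_reference is a hypothetical name, and x == 0 is not handled, matching the decomposition above):

import numpy as np

def atan2_reference(y, x):
    # atan(y/x) plus the same quadrant correction as above (x == 0 not handled).
    correction = np.where(x < 0, np.where(y >= 0, np.pi, -np.pi), 0.0)
    return np.arctan(y / x) + correction

y = np.array([1.0, 1.0, -1.0, -1.0])
x = np.array([1.0, -1.0, 1.0, -1.0])
print(np.allclose(atan2_reference(y, x), np.arctan2(y, x)))  # True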

# Operator mappings
def create_convert_map(self):
@@ -4893,6 +4970,7 @@ def create_convert_map(self):
"aten::mv": self.mv,
"aten::grid_sampler": self.grid_sampler,
"aten::__ior__": self.make_elemwise("bitwise_or"),
"aten::bitwise_or_": self.make_elemwise("bitwise_or"),
"aten::__iand__": self.make_elemwise("bitwise_and"),
"aten::__ixor__": self.make_elemwise("bitwise_xor"),
"aten::__lshift__": self.make_elemwise("left_shift"),
@@ -4920,6 +4998,8 @@ def create_convert_map(self):
"aten::linalg_vector_norm": self.linalg_vector_norm,
"aten::scaled_dot_product_attention": self.scaled_dot_product_attention,
"aten::lift_fresh": self.identity,
"aten::nan_to_num": self.nan_to_num,
"aten::atan2": self.atan2,
}

def update_convert_map(self, custom_map):
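A rough end-to-end sketch of how the aten::nan_to_num and aten::atan2 entries registered above would be exercised (illustrative only, not part of this commit; assumes torch and tvm are available, and the input names/shapes are made up):

import torch
from tvm import relay

class SmallModel(torch.nn.Module):
    def forward(self, x, y):
        # Traces to aten::nan_to_num and aten::atan2, which now hit the converters above.
        return torch.nan_to_num(x).atan2(y)

example = (torch.tensor([float("nan"), 1.0]), torch.tensor([1.0, 1.0]))
traced = torch.jit.trace(SmallModel(), example)
mod, params = relay.frontend.from_pytorch(
    traced, [("x", ((2,), "float32")), ("y", ((2,), "float32"))]
)
print(mod)  # expect isnan/where and atan plus the quadrant correction in the Relay IR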
9 changes: 6 additions & 3 deletions python/tvm/relay/op/contrib/forge/forge_passes.py
@@ -3257,10 +3257,13 @@ def callback(self, pre, post, node_map):

data = pre_node_map[self.data][0]

- cond = tvm.relay.equal(data, tvm.relay.const(np.nan, dtype="float32"))
- where = tvm.relay.where(cond, tvm.relay.const(True), tvm.relay.const(False))
- return where
+ # NaN (Not a Number) is the only value in floating-point arithmetic that is not equal to itself.
+ # So, comparing data with itself will return True if data is NaN, and False otherwise.
+ # This condition is used to identify NaN values in the data tensor.
+ cond = tvm.relay.not_equal(data, data)
+
+ return cond
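The property this relies on is easy to check with plain NumPy (illustrative only, not part of this commit):

import numpy as np

x = np.array([1.0, np.nan, np.inf])
print(x != x)  # [False  True False] -- only NaN compares unequal to itself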


class RemoveRedundantBinaryStacks(DFPatternCallback):
