Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for nan_to_num, atan2 & bitwise_or op #57

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4620,6 +4620,83 @@ def scaled_dot_product_attention(self, inputs, input_types):
attn_weight = _op.reshape(attn_weight, newshape=[-4, batch_size, -1, -2])

return attn_weight


def nan_to_num(self, inputs, input_types):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add function docs.

Same for other functions added in this PR

Copy link
Contributor Author

@kamalrajkannan78 kamalrajkannan78 Jan 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As explanations are added as comments, I just added function docs like below. Let me know if any specific format should be followed.


"""
Mimics the behavior of torch.nan_to_num (https://pytorch.org/docs/stable/generated/torch.nan_to_num.html).
"""

# Extract input tensor and replacement values
data = inputs[0]
nan_value = inputs[1]
posinf = inputs[2]
neginf = inputs[3]

# Ensure the data type is one of the supported floating-point types
dtype = input_types[0]
assert dtype in ["float16", "float32"], f"Unsupported dtype: {dtype}. Supported types are ['float16', 'float32']."

# Define the maximum and minimum representable values for the data type
dtype_max = np.finfo(dtype).max
kamalrajkannan78 marked this conversation as resolved.
Show resolved Hide resolved
dtype_min = np.finfo(dtype).min

# Create constants for NaN, positive infinity, and negative infinity replacements
nan_tensor = tvm.relay.const(nan_value if nan_value is not None else 0.0, dtype)
posinf_tensor = tvm.relay.const(posinf if posinf is not None else dtype_max, dtype)
neginf_tensor = tvm.relay.const(neginf if neginf is not None else dtype_min, dtype)

# Replace NaN values with the specified or default value
data= tvm.relay.where(_op.isnan(data) , nan_tensor, data)

# Replace positive infinity with the specified or greatest finite value representable by input’s dtype
data = tvm.relay.where(tvm.relay.greater(data, posinf_tensor),posinf_tensor, data)

# Replace negative infinity with the specified or least finite value representable by input’s dtype
result = tvm.relay.where(tvm.relay.less(data, neginf_tensor), neginf_tensor, data)

return result


def atan2(self, inputs, input_types):

kamalrajkannan78 marked this conversation as resolved.
Show resolved Hide resolved
"""
Mimics the behavior of torch.atan2 (https://pytorch.org/docs/stable/generated/torch.atan2.html).
"""

data_1 = inputs[1] # x (denominator)
data_2 = inputs[0] # y (numerator)

# Compute the ratio y/x. This is the tangent of the angle.
kamalrajkannan78 marked this conversation as resolved.
Show resolved Hide resolved
ratio = tvm.relay.divide(data_2, data_1)

# Compute the arctangent of the ratio, which gives the angle in the range [-π/2, π/2].
atan_res = tvm.relay.atan(ratio)

# Define constants for π and 0 for use in correction logic.
pi = tvm.relay.const(np.pi, "float32") # π constant
zero = tvm.relay.const(0.0, "float32") # Zero constant

# Compute the correction term to adjust the angle to the correct quadrant.
# If x < 0:
# - If y >= 0, add π to the angle (to move from 1st to 2nd quadrant).
# - If y < 0, subtract π from the angle (to move from 4th to 3rd quadrant).
# If x >= 0, no correction is needed.
correction = tvm.relay.where(
tvm.relay.less(data_1, zero), # Check if x < 0
tvm.relay.where(
tvm.relay.greater_equal(data_2, zero), # Check if y >= 0
pi, # Add π if x < 0 and y >= 0
-pi # Subtract π if x < 0 and y < 0
),
zero # No correction if x >= 0
)

# Add the correction term to the arctangent result.
result = tvm.relay.add(atan_res, correction)

return result

# Operator mappings
def create_convert_map(self):
Expand Down Expand Up @@ -4893,6 +4970,7 @@ def create_convert_map(self):
"aten::mv": self.mv,
"aten::grid_sampler": self.grid_sampler,
"aten::__ior__": self.make_elemwise("bitwise_or"),
"aten::bitwise_or_": self.make_elemwise("bitwise_or"),
"aten::__iand__": self.make_elemwise("bitwise_and"),
"aten::__ixor__": self.make_elemwise("bitwise_xor"),
"aten::__lshift__": self.make_elemwise("left_shift"),
Expand Down Expand Up @@ -4920,6 +4998,8 @@ def create_convert_map(self):
"aten::linalg_vector_norm": self.linalg_vector_norm,
"aten::scaled_dot_product_attention": self.scaled_dot_product_attention,
"aten::lift_fresh": self.identity,
"aten::nan_to_num": self.nan_to_num,
"aten::atan2": self.atan2,
}

def update_convert_map(self, custom_map):
Expand Down
9 changes: 6 additions & 3 deletions python/tvm/relay/op/contrib/forge/forge_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3257,10 +3257,13 @@ def callback(self, pre, post, node_map):

data = pre_node_map[self.data][0]

cond = tvm.relay.equal(data, tvm.relay.const(np.nan, dtype="float32"))
where = tvm.relay.where(cond, tvm.relay.const(True), tvm.relay.const(False))
# NaN (Not a Number) is the only value in floating-point arithmetic that is not equal to itself.
# So, comparing data with itself will return True if data is NaN, and False otherwise.
# This condition is used to identify NaN values in the data tensor.

return where
cond = tvm.relay.not_equal(data, data)

return cond


class RemoveRedundantBinaryStacks(DFPatternCallback):
Expand Down