Skip to content

Commit

Permalink
remove softmax decompose
Browse files: browse the repository at this point in the history
  • Loading branch information
dgolubovicTT committed Jul 27, 2024
1 parent 7885d9f commit cc9c3c3
Show file tree
Hide file tree
Showing 8 changed files with 1 addition and 37 deletions.
7 changes: 0 additions & 7 deletions pybuda/pybuda/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,6 @@ class CompilerConfig:
enable_auto_transposing_placement: bool = ("PYBUDA_ENABLE_AUTO_TRANSPOSE" in os.environ) # compiler automatically detects ops to transpose on placement when the flag is set
fracture_groups: List[Tuple[List[Tuple[str, int, int]], List[str], List[int]]] = field(default_factory=lambda: list()) # see insert_fracture_group
conv_multi_op_fracture_factor_override: Dict[str, int] = field(default_factory=lambda: dict()) # override multi op fracture factor for conv
enable_stable_softmax: bool = True
enable_single_buffer_fallback: bool = False

backend_opt_level: int = 4 # backend optimization level
Expand Down Expand Up @@ -235,9 +234,6 @@ def apply_env_config_overrides(self):
if "PYBUDA_PRESTRIDE_DISABLE" in os.environ:
self.enable_conv_prestride = not bool(int(os.environ["PYBUDA_PRESTRIDE_DISABLE"]))

if "PYBUDA_DISABLE_STABLE_SOFTMAX" in os.environ:
self.enable_stable_softmax = not bool(int(os.environ["PYBUDA_DISABLE_STABLE_SOFTMAX"]))

if "PYBUDA_CONVERT_PARAMS_TO_TVM" in os.environ:
self.convert_framework_params_to_tvm = bool(int(os.environ["PYBUDA_CONVERT_PARAMS_TO_TVM"]))

Expand Down Expand Up @@ -390,7 +386,6 @@ def set_configuration_options(
backend_runtime_args: Optional[str] = None,
enable_auto_fusing: Optional[bool] = None,
enable_conv_prestride: Optional[bool] = None,
enable_stable_softmax: Optional[bool] = None,
amp_level: Optional[int] = None,
harvested_rows: Optional[List[List[int]]] = None,
store_backend_db_to_yaml: Optional[bool] = None,
Expand Down Expand Up @@ -541,8 +536,6 @@ def set_configuration_options(
g_compiler_config.enable_auto_fusing = enable_auto_fusing
if enable_conv_prestride is not None:
g_compiler_config.enable_conv_prestride = enable_conv_prestride
if enable_stable_softmax is not None:
g_compiler_config.enable_stable_softmax = enable_stable_softmax
if amp_level is not None:
g_compiler_config.amp_level = amp_level
if harvested_rows is not None:
Expand Down
22 changes: 1 addition & 21 deletions pybuda/pybuda/op/eval/pybuda/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,28 +459,8 @@ def decompose_post_autograd(op_type, attr, dc, inputs):
Result of the operation.
"""

if op_type == "softmax":

assert len(inputs) == 1, "Softmax should have one operand."
assert len(attr) == 2, "Softmax should have two attributes."
x = inputs[0]
dim = attr[0]
stable = attr[1]

if stable and dc.get_compiler_cfg().enable_stable_softmax:
res_max = dc.op("reduce_max", (x, ), (dim, ))
res_x_max = dc.op("subtract", (x, res_max), ())
res_exp = dc.op(Exp.create(), (res_x_max, ), ())
else:
res_exp = dc.op(Exp.create(), (x, ), ())


res_exp_sum = dc.op("reduce_sum", (res_exp, ), (dim, ))
res_exp_sum = dc.op("add", (res_exp_sum, dc.tensor(torch.zeros(res_exp_sum.shape.as_list()) + 1e-10)), ())
res_exp_sum_recip = dc.op(Reciprocal.create(), (res_exp_sum, ), ())
result = dc.op("multiply", (res_exp, res_exp_sum_recip), ())
dc.fuse(result)
if op_type == "softmax":
return

if op_type == "softmax_bw":
Expand Down
1 change: 0 additions & 1 deletion pybuda/test/falcon/pybudify.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def __init__(self, pt_module, device='silicon', arch='wormhole_b0', precision='f

# pybuda workarounds
os.environ["GOLDEN_WORMHOLE_B0"] = "1" # golden should always simulate a B0 as that's all we use now
os.environ["PYBUDA_ENABLE_STABLE_SOFTMAX"] = "1" # improved accuracy - pybuda team surprised we need it though
os.environ["PYBUDA_CONVERT_PARAMS_TO_TVM"] = "0" # faster compile times... why would this ever be 1?
os.environ["TT_BACKEND_TIMEOUT"] = "0" # default is too aggressive for large models?

Expand Down
1 change: 0 additions & 1 deletion pybuda/test/falcon/tests/falcon_modules/falcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -1139,7 +1139,6 @@ def __init__(self, args):
# pybuda workarounds
os.environ["GOLDEN_WORMHOLE_B0"] = "1"
os.environ["WHA0_DISABLE_RELAY_BUFS"] = "1"
os.environ["PYBUDA_ENABLE_STABLE_SOFTMAX"] = "1"
os.environ["PYBUDA_CONVERT_PARAMS_TO_TVM"] = "0"
os.environ["TT_BACKEND_TIMEOUT"] = "0"

Expand Down
1 change: 0 additions & 1 deletion pybuda/test/llama/pybudify_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def __init__(self, pt_module, device='silicon', arch='wormhole_b0', precision='f
#os.environ["PYBUDA_DISABLE_FORK_JOIN_BUF"] = "1"
# os.environ["PYBUDA_DRAM_PICK_CAPACITY"] = "1"
os.environ["WHA0_DISABLE_RELAY_BUFS"] = "1"
os.environ["PYBUDA_ENABLE_STABLE_SOFTMAX"] = "1"
os.environ["PYBUDA_FUSE_STOP_ON_RECIPROCAL"] = "1"
os.environ["PYBUDA_PLACER_SNAKE"] = "1"
os.environ["LOGGER_LEVEL"] = log_level
Expand Down
1 change: 0 additions & 1 deletion pybuda/test/model_demos/high_prio/cnn/pytorch/test_swin.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ def test_swin_v1_tiny_4_224_hf_pytorch(test_device):
compiler_cfg = pybuda.config._get_global_compiler_config()
compiler_cfg.retain_tvm_python_files = True
compiler_cfg.enable_tvm_constant_prop = True
os.environ["PYBUDA_ENABLE_STABLE_SOFTMAX"] = "1"
os.environ["TVM_BACKTRACE"]="1"

# STEP 2: Create PyBuda module from PyTorch model
Expand Down
3 changes: 0 additions & 3 deletions pybuda/test/model_demos/models/falcon/pybudify.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,6 @@ def __init__(
os.environ[
"GOLDEN_WORMHOLE_B0"
] = "1" # golden should always simulate a B0 as that's all we use now
os.environ[
"PYBUDA_ENABLE_STABLE_SOFTMAX"
] = "1" # improved accuracy - pybuda team surprised we need it though
os.environ[
"PYBUDA_CONVERT_PARAMS_TO_TVM"
] = "0" # faster compile times... why would this ever be 1?
Expand Down
2 changes: 0 additions & 2 deletions pybuda/test/test_fusing.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,6 @@ def forward(self, act1):
@pytest.mark.parametrize("dim", ["r", "c"])
def test_softmax(test_device, test_kind, dim):

pybuda.set_configuration_options(enable_stable_softmax=False)

os.environ["PYBUDA_FUSE_REDUCE"] = "1"

dim_index = -1 if dim == "c" else -2
Expand Down

0 comments on commit cc9c3c3

Please sign in to comment.