diff --git a/setup.py b/setup.py index 9bb233953..cfa72c4c9 100644 --- a/setup.py +++ b/setup.py @@ -201,9 +201,9 @@ def append_nvcc_threads(nvcc_extra_args): # "--ptxas-options=-O2", # "-lineinfo", # "-DFLASHATTENTION_DISABLE_BACKWARD", - # "-DFLASHATTENTION_DISABLE_DROPOUT", + "-DFLASHATTENTION_DISABLE_DROPOUT", # "-DFLASHATTENTION_DISABLE_ALIBI", - # "-DFLASHATTENTION_DISABLE_UNEVEN_K", + "-DFLASHATTENTION_DISABLE_UNEVEN_K", # "-DFLASHATTENTION_DISABLE_LOCAL", ] + generator_flag