Skip to content

Commit

Permalink
Merge branch 'main' into penghuic/pytest_by_script
Browse files Browse the repository at this point in the history
  • Loading branch information
PenghuiCheng authored Jan 3, 2025
2 parents 200847b + f634c3c commit c924f3d
Show file tree
Hide file tree
Showing 11 changed files with 67 additions and 34 deletions.
4 changes: 2 additions & 2 deletions .github/ci_expected_accuracy/check_expected.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@


# load csv files
test_data= pd.read_csv(args.csv_file)
test_data= pd.read_csv(args.csv_file, comment='#')
# test_data = test_data.reset_index() # make sure indexes pair with number of rows
# test_data = test_data.sort_values(by=["name"], ascending=True)
test_names = [row["name"] for index, row in test_data.iterrows()]

current_path = pathlib.Path(__file__).parent.resolve()
refer_file = str(current_path) + "/" + args.category + "_" + args.suite + "_" + args.mode + ".csv"
refer_data= pd.read_csv(refer_file)
refer_data= pd.read_csv(refer_file, comment='#')
# refer_data = refer_data.reset_index() # make sure indexes pair with number of rows
# refer_data = refer_data.sort_values(by=["name"], ascending=True)
refer_names = [row["name"] for index, row in refer_data.iterrows()]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ CamemBert,pass,pass,pass,pass,pass
DebertaForMaskedLM,pass,pass,pass,pass,pass
DebertaForQuestionAnswering,pass,pass,pass,pass,pass
DebertaV2ForMaskedLM,pass_due_to_skip,pass_due_to_skip,pass_due_to_skip,pass_due_to_skip,pass_due_to_skip
DebertaV2ForQuestionAnswering,pass,pass,pass,pass,pass
# Skip DebertaV2ForQuestionAnswering for known accuracy issue: https://github.com/intel/torch-xpu-ops/issues/1216
DebertaV2ForQuestionAnswering,fail_accuracy,fail_accuracy,fail_accuracy,pass,pass
DistilBertForMaskedLM,pass,pass,pass,pass,pass
DistilBertForQuestionAnswering,pass,pass,pass,pass,pass
DistillGPT2,pass,pass,pass,pass,pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,5 +102,6 @@ torch_multimodal_clip,pass,pass,pass,eager_fail_to_run,eager_fail_to_run
tts_angular,pass,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run
vgg16,pass,pass,pass,pass,pass
vision_maskrcnn,pass,pass,pass,eager_fail_to_run,eager_fail_to_run
yolov3,pass,pass,pass,pass,pass
# Skip yolov3 for known torchbench issue: https://github.com/intel/torch-xpu-ops/issues/1229
yolov3,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run
hf_Roberta_base,pass,pass,pass,pass,pass
Original file line number Diff line number Diff line change
Expand Up @@ -102,5 +102,6 @@ torch_multimodal_clip,pass,pass,pass,eager_fail_to_run,eager_fail_to_run
tts_angular,pass,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run
vgg16,pass,pass,pass,pass,pass
vision_maskrcnn,pass,pass,pass,eager_fail_to_run,eager_fail_to_run
yolov3,pass,pass,pass,pass,pass
# Skip yolov3 for known torchbench issue: https://github.com/intel/torch-xpu-ops/issues/1229
yolov3,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run
hf_Roberta_base,pass,pass,pass,pass,pass
2 changes: 1 addition & 1 deletion .github/workflows/nightly_ondemand.yml
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ jobs:
fi
echo "TORCH_BRANCH_ID=$(git rev-parse --abbrev-ref HEAD)" |tee -a "${GITHUB_OUTPUT}" >> "${GITHUB_ENV}"
echo "TORCH_COMMIT_ID=$(git rev-parse HEAD)" |tee -a "${GITHUB_OUTPUT}" >> "${GITHUB_ENV}"
echo "TORCHBENCH_COMMIT_ID=$(<third_party/torch-xpu-ops/.github/ci_commit_pins/torchbench.txt)" |tee -a "${GITHUB_OUTPUT}" >> "${GITHUB_ENV}"
echo "TORCHBENCH_COMMIT_ID=$(<.github/ci_commit_pins/torchbench.txt)" |tee -a "${GITHUB_OUTPUT}" >> "${GITHUB_ENV}"
echo "TORCHVISION_COMMIT_ID=$(<.github/ci_commit_pins/vision.txt)" |tee -a "${GITHUB_OUTPUT}" >> "${GITHUB_ENV}"
echo "TORCHAUDIO_COMMIT_ID=$(<.github/ci_commit_pins/audio.txt)" |tee -a "${GITHUB_OUTPUT}" >> "${GITHUB_ENV}"
echo "TRANSFORMERS_VERSION=$(<.ci/docker/ci_commit_pins/huggingface.txt)" |tee -a "${GITHUB_OUTPUT}" >> "${GITHUB_ENV}"
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/nightly_ondemand_rolling.yml
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ jobs:
fi
echo "TORCH_BRANCH_ID=$(git rev-parse --abbrev-ref HEAD)" |tee -a "${GITHUB_OUTPUT}" >> "${GITHUB_ENV}"
echo "TORCH_COMMIT_ID=$(git rev-parse HEAD)" |tee -a "${GITHUB_OUTPUT}" >> "${GITHUB_ENV}"
echo "TORCHBENCH_COMMIT_ID=$(<third_party/torch-xpu-ops/.github/ci_commit_pins/torchbench.txt)" |tee -a "${GITHUB_OUTPUT}" >> "${GITHUB_ENV}"
echo "TORCHBENCH_COMMIT_ID=$(<.github/ci_commit_pins/torchbench.txt)" |tee -a "${GITHUB_OUTPUT}" >> "${GITHUB_ENV}"
echo "TORCHVISION_COMMIT_ID=$(<.github/ci_commit_pins/vision.txt)" |tee -a "${GITHUB_OUTPUT}" >> "${GITHUB_ENV}"
echo "TORCHAUDIO_COMMIT_ID=$(<.github/ci_commit_pins/audio.txt)" |tee -a "${GITHUB_OUTPUT}" >> "${GITHUB_ENV}"
echo "TRANSFORMERS_VERSION=$(<.ci/docker/ci_commit_pins/huggingface.txt)" |tee -a "${GITHUB_OUTPUT}" >> "${GITHUB_ENV}"
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/nightly_ondemand_whl.yml
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ jobs:
echo "TORCHAUDIO_COMMIT_ID=$(python -c 'import torchaudio; print(torchaudio.version.git_version)')" |tee -a "${GITHUB_OUTPUT}" >> "${GITHUB_ENV}"
echo "TRITON_COMMIT_ID=$(python -c 'import triton; print(triton.__version__)')" |tee -a "${GITHUB_OUTPUT}" >> "${GITHUB_ENV}"
cd ../pytorch
echo "TORCHBENCH_COMMIT_ID=$(<third_party/torch-xpu-ops/.github/ci_commit_pins/torchbench.txt)" |tee -a "${GITHUB_OUTPUT}" >> "${GITHUB_ENV}"
echo "TORCHBENCH_COMMIT_ID=$(<.github/ci_commit_pins/torchbench.txt)" |tee -a "${GITHUB_OUTPUT}" >> "${GITHUB_ENV}"
echo "TRANSFORMERS_VERSION=$(<.ci/docker/ci_commit_pins/huggingface.txt)" |tee -a "${GITHUB_OUTPUT}" >> "${GITHUB_ENV}"
echo "TIMM_COMMIT_ID=$(<.ci/docker/ci_commit_pins/timm.txt)" |tee -a "${GITHUB_OUTPUT}" >> "${GITHUB_ENV}"
echo "MODEL_ONLY_NAME=${{ inputs.model }}" |tee -a "${GITHUB_OUTPUT}" >> "${GITHUB_ENV}"
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/pull.yml
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ jobs:
cd ../pytorch
echo "TRITON_COMMIT_ID=$(<.ci/docker/ci_commit_pins/triton-xpu.txt)" >> "${GITHUB_ENV}"
echo "TORCHVISION_COMMIT_ID=$(<.github/ci_commit_pins/vision.txt)" >> "${GITHUB_ENV}"
echo "TORCHBENCH_COMMIT_ID=$(<third_party/torch-xpu-ops/.github/ci_commit_pins/torchbench.txt)" >> "${GITHUB_ENV}"
echo "TORCHBENCH_COMMIT_ID=$(<.github/ci_commit_pins/torchbench.txt)" >> "${GITHUB_ENV}"
echo "TORCHAUDIO_COMMIT_ID=$(<.github/ci_commit_pins/audio.txt)" >> "${GITHUB_ENV}"
echo "TRANSFORMERS_VERSION=$(<.ci/docker/ci_commit_pins/huggingface.txt)" >> "${GITHUB_ENV}"
echo "TIMM_COMMIT_ID=$(<.ci/docker/ci_commit_pins/timm.txt)" >> "${GITHUB_ENV}"
Expand Down Expand Up @@ -144,9 +144,9 @@ jobs:
run: |
rm -rf ${{ github.workspace }}/upload_files
cp -r ${{ github.workspace }}/../pytorch/inductor_log ${{ github.workspace }}/upload_files
failed_case=$(grep "Real failed: models: *[1-9]" ${{ github.workspace }}/upload_files/summary_accuracy.log |wc -l || true)
failed_case=$(grep "Real failed models: *[1-9]" ${{ github.workspace }}/upload_files/summary_accuracy.log |wc -l || true)
if [ ${failed_case} -ne 0 ];then
grep -E "Real failed: models: [1-9]|Summary for" ${{ github.workspace }}/summary_accuracy.log
grep -E "Real failed models: [1-9]|Summary for" ${{ github.workspace }}/upload_files/summary_accuracy.log
exit 1
fi
- name: Upload Inductor XPU E2E Data
Expand Down
29 changes: 10 additions & 19 deletions src/ATen/native/xpu/sycl/BatchNormKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1287,7 +1287,7 @@ struct BatchNormTransformInputVectorizedKernelFunctor {
} else {
invstd =
static_cast<stat_accscalar_t>(1) /
device_sqrt(
std::sqrt(
static_cast<stat_accscalar_t>(var_or_invstd_[plane]) + epsilon_);
}

Expand All @@ -1302,25 +1302,16 @@ struct BatchNormTransformInputVectorizedKernelFunctor {
for (index_t feature_vec_begin = item.get_local_id(1) * VEC_SIZE;
feature_vec_begin < fs;
feature_vec_begin += VEC_SIZE * item.get_local_range(1)) {
auto remaining = fs - feature_vec_begin;
if (remaining < VEC_SIZE) {
for (index_t idx = 0; idx < remaining; ++idx) {
index_t feature = feature_vec_begin + idx;
o[feature] = static_cast<input_scalar_t>(
gamma * (i[feature] - mean) * invstd + beta);
}
} else {
using vec_t = memory::aligned_vector<input_scalar_t, VEC_SIZE>;
vec_t vec;
using vec_t = memory::aligned_vector<input_scalar_t, VEC_SIZE>;
vec_t vec;
#pragma unroll
for (int vt = 0; vt < VEC_SIZE; ++vt) {
index_t feature = feature_vec_begin + vt;
vec[vt] = static_cast<input_scalar_t>(
gamma * (i[feature] - mean) * invstd + beta);
}
input_scalar_t* write_ptr = &o[feature_vec_begin];
*(reinterpret_cast<vec_t*>(write_ptr)) = vec;
for (int vt = 0; vt < VEC_SIZE; ++vt) {
index_t feature = feature_vec_begin + vt;
vec[vt] = static_cast<input_scalar_t>(
gamma * (i[feature] - mean) * invstd + beta);
}
input_scalar_t* write_ptr = &o[feature_vec_begin];
*(reinterpret_cast<vec_t*>(write_ptr)) = vec;
}
}
}
Expand Down Expand Up @@ -1459,7 +1450,7 @@ void batch_norm_elemt_template(
auto output_ptr = (char*)output_reshaped.data_ptr();
if (output_reshaped.is_contiguous() &&
memory::can_vectorize_up_to<input_scalar_t>(output_ptr) >= 4 &&
sizeof(input_scalar_t) < sizeof(float)) {
sizeof(input_scalar_t) < sizeof(float) && input.size(2) % 4 == 0) {
auto kfn = BatchNormTransformInputVectorizedKernelFunctor<
4,
input_scalar_t,
Expand Down
43 changes: 39 additions & 4 deletions test/xpu/extended/skip_list_win_mtl.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
skip_dict = {
# Failed on MTL Windows; skipped for now in Preci
"test_ops_xpu.py": (
"test_compare_cpu_sqrt_xpu_complex64",
"test_backward_nn_functional_adaptive_avg_pool2d_xpu_float32",

"test_compare_cpu_cosh_xpu_complex128",
"test_compare_cpu_frexp_xpu_bfloat16",
"test_compare_cpu_frexp_xpu_float16",
Expand All @@ -14,7 +11,45 @@
"test_compare_cpu_max_pool2d_with_indices_backward_xpu_float32",
"test_compare_cpu_max_pool2d_with_indices_backward_xpu_float64",
"test_compare_cpu_nn_functional_avg_pool2d_xpu_bfloat16",
"test_compare_cpu_nn_functional_avg_pool2d_xpu_float32",
"test_compare_cpu_nn_functional_avg_pool3d_xpu_float32",
"test_compare_cpu_nn_functional_avg_pool3d_xpu_float64",
"test_compare_cpu_nn_functional_batch_norm_xpu_float16",
"test_compare_cpu_nn_functional_interpolate_bicubic_xpu_float32",
"test_compare_cpu_nn_functional_interpolate_bicubic_xpu_float64",
"test_compare_cpu_nn_functional_interpolate_bilinear_xpu_float32",
"test_compare_cpu_nn_functional_interpolate_bilinear_xpu_float64",
"test_compare_cpu_nn_functional_max_pool2d_xpu_bfloat16",
"test_compare_cpu_nn_functional_max_pool2d_xpu_float16",
"test_compare_cpu_nn_functional_max_pool2d_xpu_float32",
"test_compare_cpu_nn_functional_max_pool2d_xpu_float64",
"test_compare_cpu_norm_nuc_xpu_complex128",
"test_compare_cpu_norm_nuc_xpu_complex64",
"test_compare_cpu_norm_nuc_xpu_float32",
"test_compare_cpu_norm_nuc_xpu_float64",
"test_compare_cpu_sinh_xpu_complex128",
"test_compare_cpu_softmax_with_dtype_xpu_bfloat16",
"test_compare_cpu_softmax_with_dtype_xpu_complex128",
"test_compare_cpu_softmax_with_dtype_xpu_complex64",
"test_compare_cpu_softmax_with_dtype_xpu_float64",
"test_compare_cpu_softmax_with_dtype_xpu_int32",
"test_compare_cpu_softmax_with_dtype_xpu_int64",
"test_compare_cpu_softmax_with_dtype_xpu_uint8",
"test_compare_cpu_softmax_xpu_float64",
"test_compare_cpu_square_xpu_complex128",
"test_backward_norm_nuc_xpu_float32",
"test_cow_input_norm_nuc_xpu_float32",
"test_forward_ad_norm_nuc_xpu_float32",
"test_operator_norm_nuc_xpu_float32",
"test_view_replay_norm_nuc_xpu_float32",
"test_compare_cpu_nn_functional_avg_pool2d_xpu_float32",
"test_compare_cpu_nn_functional_avg_pool2d_xpu_float64",
"test_compare_cpu_softmax_with_dtype_xpu_bool",
"test_compare_cpu_softmax_with_dtype_xpu_float32",
"test_compare_cpu_softmax_with_dtype_xpu_int16",
"test_compare_cpu_softmax_with_dtype_xpu_int8",
"test_compare_cpu_nn_functional_avg_pool2d_xpu_float16",
"test_compare_cpu_softmax_with_dtype_xpu_float16",
"test_compare_cpu_softmax_xpu_bfloat16",
"test_compare_cpu_softmax_xpu_float32",
),
}
4 changes: 4 additions & 0 deletions test/xpu/skip_list_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
"test_python_ref_torch_fallback__refs_log10_xpu_complex128",
"test_python_ref_torch_fallback__refs_sigmoid_xpu_complex128",
"test_python_ref_executor__refs_log10_executor_aten_xpu_complex128",
"test_noncontiguous_samples_histogram_xpu_float32",

# TODO: Fix the following tests
"test_out_warning_torch__scaled_mm_xpu",

# To be removed from this file.
# CUDA and XPU both XFAIL now.
Expand Down

0 comments on commit c924f3d

Please sign in to comment.