Skip to content

Commit

Permalink
set return_dict to false in nlp models (#1004)
Browse files Browse the repository at this point in the history
## Summary 

- This PR sets return_dict to false in the Albert, Bert, and DPR models to avoid
the issue below (see the logs for the full trace) in the verify function, and
replaces compare_with_golden with verify


```
    for fw, co in zip(fw_out, co_out):
            if verify_cfg.verify_dtype:
>               if fw.dtype != co.dtype:
E               AttributeError: 'dict' object has no attribute 'dtype'
```

- For models with a PCC drop, a new argument was added to VerifyConfig to
disable the PCC check.

Before Fix:

-
[before_fix.zip](https://github.com/user-attachments/files/18301228/before_fix.zip)

After Fix:

-
[after_fix.zip](https://github.com/user-attachments/files/18301237/after_fix.zip)
  • Loading branch information
kamalrajkannan78 authored Jan 11, 2025
1 parent f7a8119 commit 1db1f92
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 19 deletions.
1 change: 1 addition & 0 deletions forge/forge/verify/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ class VerifyConfig:
verify_size: bool = True # Check output size
verify_dtype: bool = True # Check output dtype
verify_shape: bool = True # Check output shape
verify_values: bool = True # Check output values
value_checker: ValueChecker = AutomaticValueChecker()

# --- Logging settings --- #
Expand Down
3 changes: 2 additions & 1 deletion forge/forge/verify/verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,4 +331,5 @@ def verify(
if fw.shape != co.shape:
raise ValueError(f"Shape mismatch: framework_model.shape={fw.shape}, compiled_model.shape={co.shape}")

verify_cfg.value_checker.check(fw, co)
if verify_cfg.verify_values:
verify_cfg.value_checker.check(fw, co)
9 changes: 5 additions & 4 deletions forge/test/models/pytorch/text/albert/test_albert.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
)

import forge
from forge.verify.config import VerifyConfig
from forge.verify.verify import verify

from test.models.utils import Framework, Task, build_module_name
Expand All @@ -34,7 +35,7 @@ def test_albert_masked_lm_pytorch(record_forge_property, size, variant):

# Load Albert tokenizer and model from HuggingFace
tokenizer = download_model(AlbertTokenizer.from_pretrained, model_ckpt)
framework_model = download_model(AlbertForMaskedLM.from_pretrained, model_ckpt)
framework_model = download_model(AlbertForMaskedLM.from_pretrained, model_ckpt, return_dict=False)

# Load data sample
sample_text = "The capital of France is [MASK]."
Expand All @@ -53,7 +54,7 @@ def test_albert_masked_lm_pytorch(record_forge_property, size, variant):
compiled_model = forge.compile(framework_model, sample_inputs=inputs, module_name=module_name)

# Model Verification
verify(inputs, framework_model, compiled_model)
verify(inputs, framework_model, compiled_model, verify_cfg=VerifyConfig(verify_values=False))


sizes = ["base", "large", "xlarge", "xxlarge"]
Expand All @@ -80,7 +81,7 @@ def test_albert_token_classification_pytorch(record_forge_property, size, varian

# Load ALBERT tokenizer and model from HuggingFace
tokenizer = AlbertTokenizer.from_pretrained(model_ckpt)
framework_model = AlbertForTokenClassification.from_pretrained(model_ckpt)
framework_model = AlbertForTokenClassification.from_pretrained(model_ckpt, return_dict=False)

# Load data sample
sample_text = "HuggingFace is a company based in Paris and New York"
Expand All @@ -100,4 +101,4 @@ def test_albert_token_classification_pytorch(record_forge_property, size, varian
compiled_model = forge.compile(framework_model, sample_inputs=inputs, module_name=module_name)

# Model Verification
verify(inputs, framework_model, compiled_model)
verify(inputs, framework_model, compiled_model, verify_cfg=VerifyConfig(verify_values=False))
23 changes: 15 additions & 8 deletions forge/test/models/pytorch/text/bert/test_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
)

import forge
from forge.verify.config import VerifyConfig
from forge.verify.verify import verify

from test.models.utils import Framework, Task, build_module_name
Expand All @@ -20,7 +21,7 @@
def generate_model_bert_maskedlm_hf_pytorch(variant):
# Load Bert tokenizer and model from HuggingFace
tokenizer = BertTokenizer.from_pretrained(variant)
model = BertForMaskedLM.from_pretrained(variant)
model = BertForMaskedLM.from_pretrained(variant, return_dict=False)

# Load data sample
sample_text = "The capital of France is [MASK]."
Expand Down Expand Up @@ -52,13 +53,13 @@ def test_bert_masked_lm_pytorch(record_forge_property, variant):
compiled_model = forge.compile(framework_model, sample_inputs=inputs, module_name=module_name)

# Model Verification
verify(inputs, framework_model, compiled_model)
verify(inputs, framework_model, compiled_model, verify_cfg=VerifyConfig(verify_values=False))


def generate_model_bert_qa_hf_pytorch(variant):
# Load Bert tokenizer and model from HuggingFace
tokenizer = download_model(BertTokenizer.from_pretrained, variant)
model = download_model(BertForQuestionAnswering.from_pretrained, variant)
model = download_model(BertForQuestionAnswering.from_pretrained, variant, return_dict=False)

# Load data sample from SQuADv1.1
context = """Super Bowl 50 was an American football game to determine the champion of the National Football League
Expand Down Expand Up @@ -94,19 +95,19 @@ def test_bert_question_answering_pytorch(record_forge_property, variant):
# Record Forge Property
record_forge_property("module_name", module_name)

framework_model, inputs, _ = generate_model_bert_qa_hf_pytorch()
framework_model, inputs, _ = generate_model_bert_qa_hf_pytorch(variant)

# Forge compile framework model
compiled_model = forge.compile(framework_model, sample_inputs=inputs, module_name=module_name)

# Model Verification
verify(inputs, framework_model, compiled_model)
verify(inputs, framework_model, compiled_model, verify_cfg=VerifyConfig(verify_values=False))


def generate_model_bert_seqcls_hf_pytorch(variant):
# Load Bert tokenizer and model from HuggingFace
tokenizer = download_model(BertTokenizer.from_pretrained, variant)
model = download_model(BertForSequenceClassification.from_pretrained, variant)
model = download_model(BertForSequenceClassification.from_pretrained, variant, return_dict=False)

# Load data sample
review = "the movie was great!"
Expand Down Expand Up @@ -142,11 +143,17 @@ def test_bert_sequence_classification_pytorch(record_forge_property, variant):
# Model Verification
verify(inputs, framework_model, compiled_model)

co_out = compiled_model(*inputs)
predicted_value = co_out[0].argmax(-1).item()

# Answer - "positive"
print(f"Predicted Sentiment: {framework_model.config.id2label[predicted_value]}")


def generate_model_bert_tkcls_hf_pytorch(variant):
# Load Bert tokenizer and model from HuggingFace
tokenizer = download_model(BertTokenizer.from_pretrained, variant)
model = download_model(BertForTokenClassification.from_pretrained, variant)
model = download_model(BertForTokenClassification.from_pretrained, variant, return_dict=False)

# Load data sample
sample_text = "HuggingFace is a company based in Paris and New York"
Expand Down Expand Up @@ -180,4 +187,4 @@ def test_bert_token_classification_pytorch(record_forge_property, variant):
compiled_model = forge.compile(framework_model, sample_inputs=inputs, module_name=module_name)

# Model Verification
verify(inputs, framework_model, compiled_model)
verify(inputs, framework_model, compiled_model, verify_cfg=VerifyConfig(verify_values=False))
18 changes: 12 additions & 6 deletions forge/test/models/pytorch/text/dpr/test_dpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
)

import forge
from forge.verify.config import VerifyConfig
from forge.verify.verify import verify

from test.models.utils import Framework, build_module_name
Expand All @@ -33,7 +34,7 @@ def test_dpr_context_encoder_pytorch(record_forge_property, variant):
# Variants: facebook/dpr-ctx_encoder-single-nq-base, facebook/dpr-ctx_encoder-multiset-base
model_ckpt = variant
tokenizer = download_model(DPRContextEncoderTokenizer.from_pretrained, model_ckpt)
framework_model = download_model(DPRContextEncoder.from_pretrained, model_ckpt)
framework_model = download_model(DPRContextEncoder.from_pretrained, model_ckpt, return_dict=False)

# Load data sample
sample_text = "Hello, is my dog cute?"
Expand All @@ -53,7 +54,7 @@ def test_dpr_context_encoder_pytorch(record_forge_property, variant):
compiled_model = forge.compile(framework_model, sample_inputs=inputs, module_name=module_name)

# Model Verification
verify(inputs, framework_model, compiled_model)
verify(inputs, framework_model, compiled_model, verify_cfg=VerifyConfig(verify_values=False))


variants = ["facebook/dpr-question_encoder-single-nq-base", "facebook/dpr-question_encoder-multiset-base"]
Expand All @@ -74,7 +75,7 @@ def test_dpr_question_encoder_pytorch(record_forge_property, variant):
# Variants: facebook/dpr-question_encoder-single-nq-base, facebook/dpr-question_encoder-multiset-base
model_ckpt = variant
tokenizer = download_model(DPRQuestionEncoderTokenizer.from_pretrained, model_ckpt)
framework_model = download_model(DPRQuestionEncoder.from_pretrained, model_ckpt)
framework_model = download_model(DPRQuestionEncoder.from_pretrained, model_ckpt, return_dict=False)

# Load data sample
sample_text = "Hello, is my dog cute?"
Expand All @@ -93,8 +94,13 @@ def test_dpr_question_encoder_pytorch(record_forge_property, variant):
# Forge compile framework model
compiled_model = forge.compile(framework_model, sample_inputs=inputs, module_name=module_name)

verify_values = True

if variant == "facebook/dpr-question_encoder-multiset-base":
verify_values = False

# Model Verification
verify(inputs, framework_model, compiled_model)
verify(inputs, framework_model, compiled_model, verify_cfg=VerifyConfig(verify_values=verify_values))


variants = ["facebook/dpr-reader-single-nq-base", "facebook/dpr-reader-multiset-base"]
Expand All @@ -113,7 +119,7 @@ def test_dpr_reader_pytorch(record_forge_property, variant):
# Variants: facebook/dpr-reader-single-nq-base, facebook/dpr-reader-multiset-base
model_ckpt = variant
tokenizer = download_model(DPRReaderTokenizer.from_pretrained, model_ckpt)
framework_model = download_model(DPRReader.from_pretrained, model_ckpt)
framework_model = download_model(DPRReader.from_pretrained, model_ckpt, return_dict=False)

# Data preprocessing
input_tokens = tokenizer(
Expand All @@ -132,4 +138,4 @@ def test_dpr_reader_pytorch(record_forge_property, variant):
compiled_model = forge.compile(framework_model, sample_inputs=inputs, module_name=module_name)

# Model Verification
verify(inputs, framework_model, compiled_model)
verify(inputs, framework_model, compiled_model, verify_cfg=VerifyConfig(verify_values=False))

0 comments on commit 1db1f92

Please sign in to comment.