Skip to content

Commit

Permalink
WIP: Add support for transformer models
Browse files Browse the repository at this point in the history
Add a simple example model for detecting defects, both with finetuning and with using embeddings
to train an XGBoost model

Fixes #39
  • Loading branch information
marco-c committed Oct 27, 2023
1 parent 798f215 commit 49e79c8
Show file tree
Hide file tree
Showing 6 changed files with 204 additions and 2 deletions.
2 changes: 2 additions & 0 deletions bugbug/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
"component": "bugbug.models.component.ComponentModel",
"component_nn": "bugbug.models.component_nn.ComponentNNModel",
"defect": "bugbug.models.defect.DefectModel",
"defect_finetuning": "bugbug.models.defect.DefectFinetuningModel",
"defect_embedding": "bugbug.models.defect.DefectEmbeddingModel",
"defectenhancementtask": "bugbug.models.defect_enhancement_task.DefectEnhancementTaskModel",
"devdocneeded": "bugbug.models.devdocneeded.DevDocNeededModel",
"duplicate": "bugbug.models.duplicate.DuplicateModel",
Expand Down
98 changes: 98 additions & 0 deletions bugbug/models/defect.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,21 @@
import logging
from typing import Any

import torch
import xgboost
from imblearn.over_sampling import BorderlineSMOTE
from sklearn.compose import ColumnTransformer
from sklearn.feature_extraction import DictVectorizer
from sklearn.pipeline import Pipeline
from skorch import NeuralNetClassifier
from skorch.callbacks import ProgressBar
from skorch.hf import HuggingfacePretrainedTokenizer
from torch import nn

from bugbug import bug_features, bugzilla, feature_cleanup, labels, utils
from bugbug.model import BugModel
from bugbug.nn import DistilBertModule, ExtractEmbeddings, get_training_device
from bugbug.utils import MergeText

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -283,3 +290,94 @@ def overwrite_classes(self, bugs, classes, probabilities):
classes[i] = 0 if not probabilities else [1.0, 0.0]

return classes


class DefectFinetuningModel(DefectModel):
    """Defect classifier that fine-tunes a DistilBERT transformer on the
    merged bug title and comments text."""

    def __init__(self, last_layer_only=True, **kwargs):
        super().__init__(**kwargs)

        # Oversampling, feature importance and cross-validation don't apply
        # to the transformer training pipeline.
        self.sampler = None
        self.calculate_importance = False
        self.cross_validation_enabled = False

        # No hand-crafted features: only the raw text is fed to the model.
        self.extraction_pipeline = Pipeline(
            [
                ("bug_extractor", bug_features.BugExtractor([], [], rollback=True)),
                ("extract", MergeText(["title", "comments"])),
            ]
        )

        tokenizer = HuggingfacePretrainedTokenizer(
            "distilbert-base-uncased", max_length=512
        )
        classifier = NeuralNetClassifier(
            DistilBertModule,
            module__name="distilbert-base-uncased",
            module__num_labels=2,
            module__last_layer_only=last_layer_only,
            optimizer=torch.optim.AdamW,
            lr=6e-5,
            max_epochs=2,
            criterion=nn.CrossEntropyLoss,
            batch_size=4,
            iterator_train__shuffle=True,
            device=get_training_device(),
            callbacks=[ProgressBar()],
        )
        self.clf = Pipeline(
            [
                ("tokenizer", tokenizer),
                ("classifier", classifier),
            ]
        )

    def get_feature_names(self):
        """Return no feature names: there is no interpretable feature set."""
        return []


class DefectEmbeddingModel(DefectModel):
    """Defect classifier training an XGBoost model on frozen transformer
    embeddings of the merged bug title and comments text."""

    def __init__(self, **kwargs):
        # Removed stray debug `print(**kwargs)`: it unpacked kwargs as
        # keyword arguments to print(), raising TypeError for any keyword
        # that print() does not accept.
        super().__init__(**kwargs)

        # Oversampling, feature importance and cross-validation don't apply
        # to this pipeline.
        self.sampler = None
        self.calculate_importance = False
        self.cross_validation_enabled = False

        # No hand-crafted features: only the raw text is fed to the model.
        self.extraction_pipeline = Pipeline(
            [
                (
                    "bug_extractor",
                    bug_features.BugExtractor([], [], rollback=True),
                ),
                ("extract", MergeText(["title", "comments"])),
            ]
        )

        self.clf = Pipeline(
            [
                (
                    "tokenizer",
                    HuggingfacePretrainedTokenizer(
                        "distilbert-base-uncased", max_length=512
                    ),
                ),
                ("extract_embeddings", ExtractEmbeddings("distilbert-base-uncased")),
                (
                    "classifier",
                    xgboost.XGBClassifier(n_jobs=utils.get_physical_cpu_count()),
                ),
            ]
        )

    def get_feature_names(self):
        """Return no feature names: there is no interpretable feature set."""
        return []
42 changes: 42 additions & 0 deletions bugbug/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
# You can obtain one at http://mozilla.org/MPL/2.0/.

import torch
from sklearn.base import BaseEstimator, ClassifierMixin, TransformerMixin
from torch import nn
from transformers import AutoModel, AutoModelForSequenceClassification

from bugbug.utils import numpy_to_dict

Expand Down Expand Up @@ -55,3 +58,42 @@ def predict_proba(self, X):

def predict(self, X):
return self.predict_proba(X).argmax(axis=-1)


class ExtractEmbeddings(BaseEstimator, TransformerMixin):
    """Transformer mapping tokenized inputs to a fixed-size embedding, taken
    from the first (CLS) token of the pretrained model's last hidden state."""

    def __init__(self, model_name: str):
        # Store the constructor parameter under its own name so sklearn's
        # get_params()/clone() work correctly on this estimator.
        self.model_name = model_name
        self.model = AutoModel.from_pretrained(model_name)

    def fit(self, X, y=None):
        # Stateless: nothing to learn. `y` defaults to None per the sklearn
        # transformer convention, so fit(X) alone is valid.
        return self

    def transform(self, X):
        with torch.no_grad():
            # Use the first-token embedding as the sequence representation.
            # TODO: support .last_hidden_state.mean(dim=1) as an alternative
            return self.model(**X).last_hidden_state[:, 0, :]


def get_training_device() -> str:
    """Return the torch device string to train on.

    "cuda" when a CUDA-capable GPU is available, "cpu" otherwise.
    """
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


class DistilBertModule(nn.Module):
    """skorch-compatible module wrapping a HuggingFace sequence-classification
    model, optionally freezing the DistilBERT backbone."""

    def __init__(self, name, num_labels, last_layer_only=False):
        super().__init__()
        self.name = name
        self.num_labels = num_labels
        self.last_layer_only = last_layer_only

        self.reset_weights()

    def reset_weights(self):
        """(Re)load the pretrained weights, freezing the backbone if asked."""
        self.bert = AutoModelForSequenceClassification.from_pretrained(
            self.name, num_labels=self.num_labels
        )
        if not self.last_layer_only:
            return
        # Freeze the DistilBERT backbone so only the classification head
        # receives gradient updates.
        for parameter in self.bert.distilbert.parameters():
            parameter.requires_grad = False

    def forward(self, **kwargs):
        # Return raw logits; the loss criterion is applied outside this module.
        return self.bert(**kwargs).logits
11 changes: 11 additions & 0 deletions bugbug/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,17 @@ def transform(self, data):
return np.array([elem[self.key] for elem in data]).reshape(-1, 1)


class MergeText(BaseEstimator, TransformerMixin):
    """Concatenate several text columns of a DataFrame into one
    space-separated string per row."""

    def __init__(self, cols):
        self.cols = cols

    def fit(self, X, y=None):
        # Stateless transformer: nothing to learn.
        return self

    def transform(self, X):
        # Row-wise join of the selected columns with a single space.
        return X[self.cols].agg(" ".join, axis=1)


class MissingOrdinalEncoder(OrdinalEncoder):
"""Ordinal encoder that ignores missing values encountered after training.
Expand Down
4 changes: 3 additions & 1 deletion scripts/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ def go(self, args):
model_name = args.model

model_class = get_model_class(model_name)
parameter_names = set(inspect.signature(model_class.__init__).parameters)
parameter_names = set(inspect.signature(model_class.__init__).parameters) - {
"kwargs"
}
parameters = {
key: value for key, value in vars(args).items() if key in parameter_names
}
Expand Down
49 changes: 48 additions & 1 deletion tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from scripts import trainer


def test_trainer():
# Test xgboost model on TF-IDF
def test_trainer_simple():
# Pretend the DB was already downloaded and no new DB is available.

url = "https://community-tc.services.mozilla.com/api/index/v1/task/project.bugbug.data_bugs.latest/artifacts/public/bugs.json"
Expand All @@ -29,3 +30,49 @@ def test_trainer():
)

trainer.Trainer().go(trainer.parse_args(["defect"]))


# Test finetuning of transformer model
def test_trainer_finetuning():
    # Pretend the DB was already downloaded and no new DB is available.

    bugs_url = "https://community-tc.services.mozilla.com/api/index/v1/task/project.bugbug.data_bugs.latest/artifacts/public/bugs.json"

    # The version endpoint reports the version we already have locally...
    responses.add(
        responses.GET,
        f"{bugs_url}.version",
        body=str(db.DATABASES[bugzilla.BUGS_DB]["version"]),
        status=200,
    )

    # ...and the archive's ETag is unchanged, so nothing is re-downloaded.
    responses.add(
        responses.HEAD,
        f"{bugs_url}.zst",
        headers={"ETag": "etag"},
        status=200,
    )

    trainer.Trainer().go(trainer.parse_args(["defect_finetuning"]))


# Test xgboost model on transformer model's embeddings
def test_trainer_embedding():
    # Pretend the DB was already downloaded and no new DB is available.

    bugs_url = "https://community-tc.services.mozilla.com/api/index/v1/task/project.bugbug.data_bugs.latest/artifacts/public/bugs.json"

    # The version endpoint reports the version we already have locally...
    responses.add(
        responses.GET,
        f"{bugs_url}.version",
        body=str(db.DATABASES[bugzilla.BUGS_DB]["version"]),
        status=200,
    )

    # ...and the archive's ETag is unchanged, so nothing is re-downloaded.
    responses.add(
        responses.HEAD,
        f"{bugs_url}.zst",
        headers={"ETag": "etag"},
        status=200,
    )

    trainer.Trainer().go(trainer.parse_args(["defect_embedding"]))

0 comments on commit 49e79c8

Please sign in to comment.