Adds support for ColBERT model #4

Open · wants to merge 3 commits into base: main
27 changes: 27 additions & 0 deletions cmd/converter/converter.go
@@ -0,0 +1,27 @@
package main

import (
	"fmt"
	"os"
	"strconv"

	"github.com/nlpodyssey/cybertron/pkg/converter"
)

func main() {
	if len(os.Args) != 3 {
		fmt.Printf("Usage: %s <model_path> <overwrite_if_exists>\n", os.Args[0])
		os.Exit(1)
	}
	modelDir := os.Args[1]
	overwriteIfExists, err := strconv.ParseBool(os.Args[2])
	if err != nil {
		fmt.Printf("Failed to parse overwrite_if_exists: %s\n", err)
		os.Exit(1)
	}
	fmt.Printf("Converting model from dir %s\n", modelDir)
	err = converter.Convert[float32](modelDir, overwriteIfExists)
	if err != nil {
		panic(err)
	}
}
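For reviewers, an example invocation of this CLI: the first argument is the model directory, the second toggles overwriting a previously converted model. The models/colbert path is hypothetical; the second output line is the Printf from main above.

$ go run ./cmd/converter models/colbert true
Converting model from dir models/colbert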
93 changes: 38 additions & 55 deletions pkg/converter/bert/convert.go
@@ -96,21 +96,16 @@ func Convert[T float.DType](modelDir string, overwriteIfExist bool) error {
 		panic(err)
 	}
 
-	m := bert.New[T](config, repo)
-	bertForQuestionAnswering := bert.NewModelForQuestionAnswering[T](m)
-	bertForSequenceClassification := bert.NewModelForSequenceClassification[T](m)
-	bertForTokenClassification := bert.NewModelForTokenClassification[T](m)
-	bertForSequenceEncoding := bert.NewModelForSequenceEncoding(m)
-
+	baseModel := bert.New[T](config, repo)
 	{
 		source := pyParams.Pop("bert.embeddings.word_embeddings.weight")
-		size := m.Embeddings.Tokens.Config.Size
+		size := baseModel.Embeddings.Tokens.Config.Size
 		for i := 0; i < config.VocabSize; i++ {
 			key, _ := vocab.Term(i)
 			if len(key) == 0 {
 				continue // skip empty key
 			}
-			item, _ := m.Embeddings.Tokens.Embedding(key)
+			item, _ := baseModel.Embeddings.Tokens.Embedding(key)
 			item.ReplaceValue(mat.NewVecDense[T](source[i*size : (i+1)*size]))
 		}
 	}
@@ -119,7 +114,7 @@ func Convert[T float.DType](modelDir string, overwriteIfExist bool) error {
 
 	{
 		source := pyParams.Pop("bert.embeddings.position_embeddings.weight")
-		dest := m.Embeddings.Positions
+		dest := baseModel.Embeddings.Positions
 		for i := 0; i < config.MaxPositionEmbeddings; i++ {
 			item, _ := dest.Embedding(i)
 			item.ReplaceValue(mat.NewVecDense[T](source[i*cols : (i+1)*cols]))
@@ -128,27 +123,43 @@ func Convert[T float.DType](modelDir string, overwriteIfExist bool) error {
 
 	{
 		source := pyParams.Pop("bert.embeddings.token_type_embeddings.weight")
-		dest := m.Embeddings.TokenTypes
+		dest := baseModel.Embeddings.TokenTypes
 		for i := 0; i < config.TypeVocabSize; i++ {
 			item, _ := dest.Embedding(i)
 			item.ReplaceValue(mat.NewVecDense[T](source[i*cols : (i+1)*cols]))
 		}
 	}
 
 	params := make(paramsMap)
-	mapPooler(m.Pooler, params)
-	mapEmbeddingsLayerNorm(m.Embeddings.Norm, params)
-	mapEncoderParams(m.Encoder, params)
-	mapQAClassifier(bertForQuestionAnswering.Classifier, params)
-
-	{
-		// both architectures map `classifier` params
-		switch config.Architectures[0] {
-		case "BertForSequenceClassification":
-			mapSeqClassifier(bertForSequenceClassification.Classifier, params)
-		case "BertForTokenClassification":
-			mapTokenClassifier(bertForTokenClassification.Classifier, params)
-		}
-	}
+	mapPooler(baseModel.Pooler, params)
+	mapEmbeddingsLayerNorm(baseModel.Embeddings.Norm, params)
+	mapEncoderParams(baseModel.Encoder, params)
+
+	var finalModel any
+	switch config.Architectures[0] {
+	case "BertBase":
+		finalModel = baseModel
+	case "BertForQuestionAnswering":
+		qaModel := bert.NewModelForQuestionAnswering[T](baseModel)
+		mapQAClassifier(qaModel.Classifier, params)
+		finalModel = qaModel
+
+	case "BertForSequenceClassification":
+		scModel := bert.NewModelForSequenceClassification[T](baseModel)
+		mapSeqClassifier(scModel.Classifier, params)
+		finalModel = scModel
+
+	case "BertForTokenClassification":
+		tcModel := bert.NewModelForTokenClassification[T](baseModel)
+		mapTokenClassifier(tcModel.Classifier, params)
+		finalModel = tcModel
+
+	case "HF_ColBERT":
+		colbertModel := bert.NewColbertModel[T](baseModel)
+		mapLinear(colbertModel.Linear, params)
+		finalModel = colbertModel
+	default:
+		panic(fmt.Errorf("bert: unsupported architecture %s", config.Architectures[0]))
+	}
 
 	mapping := make(map[string]*mappingParam)
@@ -191,42 +202,14 @@ func Convert[T float.DType](modelDir string, overwriteIfExist bool) error {
 	}
 
 	fmt.Printf("Serializing model to \"%s\"... ", goModelFilename)
+	err = nn.DumpToFile(finalModel, goModelFilename)
+	if err != nil {
+		return err
+	}
 	if config.Architectures == nil {
 		config.Architectures = append(config.Architectures, "BertBase")
 	}
-
-	{
-		switch config.Architectures[0] {
-		case "BertBase":
-			err := nn.DumpToFile(m, goModelFilename)
-			if err != nil {
-				return err
-			}
-		case "BertModel":
-			err := nn.DumpToFile(bertForSequenceEncoding, goModelFilename)
-			if err != nil {
-				return err
-			}
-		case "BertForQuestionAnswering":
-			err := nn.DumpToFile(bertForQuestionAnswering, goModelFilename)
-			if err != nil {
-				return err
-			}
-		case "BertForSequenceClassification":
-			err := nn.DumpToFile(bertForSequenceClassification, goModelFilename)
-			if err != nil {
-				return err
-			}
-		case "BertForTokenClassification":
-			err := nn.DumpToFile(bertForTokenClassification, goModelFilename)
-			if err != nil {
-				return err
-			}
-		default:
-			panic(fmt.Errorf("bert: unsupported architecture %s", config.Architectures[0]))
-		}
-	}
 
 	fmt.Println("Done.")
 
 	return nil
4 changes: 4 additions & 0 deletions pkg/converter/bert/mapper.go
@@ -66,6 +66,10 @@ func mapTokenClassifier(model *linear.Model, params paramsMap) {
 	params["classifier.weight"] = model.W.Value()
 	params["classifier.bias"] = model.B.Value()
 }
+func mapLinear(model *linear.Model, params paramsMap) {
+	params["linear.weight"] = model.W.Value()
+	params["linear.bias"] = model.B.Value()
+}
 
 // mapProjectionLayer maps the projection layer parameters.
 func mapQAClassifier(model *linear.Model, params paramsMap) {
37 changes: 37 additions & 0 deletions pkg/models/bert/colbert.go
@@ -0,0 +1,37 @@
package bert

import (
	"encoding/gob"

	"github.com/nlpodyssey/spago/ag"
	"github.com/nlpodyssey/spago/mat/float"
	"github.com/nlpodyssey/spago/nn"
	"github.com/nlpodyssey/spago/nn/linear"
)

type ColbertModel struct {
	nn.Model
	// Bert is the fine-tuned BERT model.
	Bert *Model
	// Linear is the linear layer for dimensionality reduction.
	Linear *linear.Model
}

func init() {
	gob.Register(&ColbertModel{})
}

// NewColbertModel returns a new model for information retrieval using ColBERT.
func NewColbertModel[T float.DType](bert *Model) *ColbertModel {
	return &ColbertModel{
		Bert: bert,
		// TODO: read the size of the dimensionality-reduction layer from the
		// model config (artifact-config.metadata, key: dim).
		Linear: linear.New[T](bert.Config.HiddenSize, 128),
	}
}

// Forward returns the ColBERT representation of the provided tokens.
func (m *ColbertModel) Forward(tokens []string) []ag.Node {
	return m.Linear.Forward(m.Bert.Encode(tokens)...)
}
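Regarding the TODO above: a minimal sketch of how the reduction size could be resolved instead of hard-coding 128. This helper is hypothetical and not part of this PR; the artifact.metadata filename and the dim key follow the comment's hint but are otherwise assumptions (it also assumes imports of encoding/json, os, and path/filepath).

// readColbertDim is a hypothetical helper (not part of this PR): it reads the
// ColBERT output dimension from the checkpoint metadata, falling back to 128.
func readColbertDim(modelDir string) int {
	data, err := os.ReadFile(filepath.Join(modelDir, "artifact.metadata")) // assumed filename
	if err != nil {
		return 128 // fall back to the common ColBERT default
	}
	var meta struct {
		Dim int `json:"dim"`
	}
	if err := json.Unmarshal(data, &meta); err != nil || meta.Dim == 0 {
		return 128
	}
	return meta.Dim
}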
1 change: 1 addition & 0 deletions pkg/tasks/scoring/colbert/.gitignore
@@ -0,0 +1 @@
testdata
128 changes: 128 additions & 0 deletions pkg/tasks/scoring/colbert/ranking.go
@@ -0,0 +1,128 @@
package colbert

import (
	"fmt"
	"path/filepath"
	"strings"

	"github.com/nlpodyssey/cybertron/pkg/models/bert"
	"github.com/nlpodyssey/cybertron/pkg/tokenizers"
	"github.com/nlpodyssey/cybertron/pkg/tokenizers/wordpiecetokenizer"
	"github.com/nlpodyssey/cybertron/pkg/vocabulary"
	"github.com/nlpodyssey/spago/ag"
	"github.com/nlpodyssey/spago/embeddings/store/diskstore"
	"github.com/nlpodyssey/spago/nn"
)

// SpecialDocumentMarker is the ColBERT document marker ([D] in the paper),
// mapped to the unused BERT vocabulary token [unused1].
const SpecialDocumentMarker = "[unused1]"

// SpecialQueryMarker is the ColBERT query marker ([Q] in the paper),
// mapped to the unused BERT vocabulary token [unused0].
const SpecialQueryMarker = "[unused0]"

// DocumentScorer scores documents against a query using ColBERT late interaction.
type DocumentScorer struct {
	Model     *bert.ColbertModel
	Tokenizer *wordpiecetokenizer.WordPieceTokenizer
}

// LoadDocumentScorer loads a converted ColBERT model, its vocabulary, and its
// embeddings repository from modelPath.
func LoadDocumentScorer(modelPath string) (*DocumentScorer, error) {
	vocab, err := vocabulary.NewFromFile(filepath.Join(modelPath, "vocab.txt"))
	if err != nil {
		return nil, fmt.Errorf("failed to load vocabulary: %w", err)
	}

	tokenizer := wordpiecetokenizer.New(vocab)

	embeddingsRepo, err := diskstore.NewRepository(filepath.Join(modelPath, "repo"), diskstore.ReadOnlyMode)
	if err != nil {
		return nil, fmt.Errorf("failed to load embeddings repository: %w", err)
	}

	m, err := nn.LoadFromFile[*bert.ColbertModel](filepath.Join(modelPath, "spago_model.bin"))
	if err != nil {
		return nil, fmt.Errorf("failed to load colbert model: %w", err)
	}

	err = m.Bert.SetEmbeddings(embeddingsRepo)
	if err != nil {
		return nil, fmt.Errorf("failed to set embeddings: %w", err)
	}
	return &DocumentScorer{
		Model:     m,
		Tokenizer: tokenizer,
	}, nil
}

// encode lowercases and tokenizes the text, wraps it as [CLS] <marker> ... [SEP],
// runs the model, L2-normalizes the token embeddings, and drops punctuation tokens.
func (r *DocumentScorer) encode(text string, specialMarker string) []ag.Node {
	tokens := r.Tokenizer.Tokenize(strings.ToLower(text))

	stringTokens := tokenizers.GetStrings(tokens)
	stringTokens = append([]string{wordpiecetokenizer.DefaultClassToken, specialMarker}, stringTokens...)
	stringTokens = append(stringTokens, wordpiecetokenizer.DefaultSequenceSeparator)
	embeddings := normalizeEmbeddings(r.Model.Forward(stringTokens))
	return filterEmbeddings(embeddings, stringTokens)
}

// EncodeDocument returns the ColBERT embeddings of a document.
func (r *DocumentScorer) EncodeDocument(text string) []ag.Node {
	return r.encode(text, SpecialDocumentMarker)
}

// EncodeQuery returns the ColBERT embeddings of a query.
func (r *DocumentScorer) EncodeQuery(text string) []ag.Node {
	return r.encode(text, SpecialQueryMarker)
}

// ScoreDocument computes the ColBERT late-interaction (MaxSim) score: for each
// query token embedding, the maximum similarity over all document token
// embeddings, summed over the query tokens.
func (r *DocumentScorer) ScoreDocument(query []ag.Node, document []ag.Node) ag.Node {
	var score ag.Node = ag.Scalar(0.0)
	for i, q := range query {
		if i < 3 || i >= len(query)-1 {
			continue // don't take special tokens into consideration
		}
		score = ag.Add(score, r.maxSimilarity(q, document))
	}
	return score
}

// maxSimilarity returns the maximum dot product between the query token
// embedding and the document token embeddings; since the embeddings are
// L2-normalized, this is the maximum cosine similarity.
func (r *DocumentScorer) maxSimilarity(query ag.Node, document []ag.Node) ag.Node {
	var maxSim ag.Node = ag.Scalar(0.0)
	for i, d := range document {
		if i < 3 || i >= len(document)-1 {
			continue // don't take special tokens into consideration
		}
		sim := ag.Dot(query, d)
		maxSim = ag.Max(maxSim, sim)
	}
	return maxSim
}

// normalizeEmbeddings performs L2 normalization of each embedding, so that
// dot products between embeddings correspond to cosine similarities.
func normalizeEmbeddings(embeddings []ag.Node) []ag.Node {
	normalized := make([]ag.Node, len(embeddings))
	for i, e := range embeddings {
		normalized[i] = ag.DivScalar(e, ag.Sqrt(ag.ReduceSum(ag.Square(e))))
	}
	return normalized
}

// isPunctuation reports whether token is a single punctuation character.
func isPunctuation(token string) bool {
	const punctuation = ".,!?:;-'\"()[]{}*&%$#@=+_~/\\|`^><"
	return len(token) == 1 && strings.ContainsAny(token, punctuation)
}

// filterEmbeddings drops the embeddings of punctuation tokens, following the
// ColBERT practice of excluding punctuation from document representations.
func filterEmbeddings(embeddings []ag.Node, tokens []string) []ag.Node {
	filtered := make([]ag.Node, 0, len(embeddings))
	for i, e := range embeddings {
		if isPunctuation(tokens[i]) {
			continue
		}
		filtered = append(filtered, e)
	}
	return filtered
}
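For review context, a minimal end-to-end sketch of the new scorer. The models/colbert path assumes a checkpoint already converted by the CLI above, and printing the node's raw value via Value() is an assumption about the spago node API; both are illustrative, not part of this PR.

package main

import (
	"fmt"

	"github.com/nlpodyssey/cybertron/pkg/tasks/scoring/colbert"
)

func main() {
	// "models/colbert" is a hypothetical path to a converted ColBERT model.
	scorer, err := colbert.LoadDocumentScorer("models/colbert")
	if err != nil {
		panic(err)
	}
	query := scorer.EncodeQuery("what is late interaction?")
	document := scorer.EncodeDocument("ColBERT scores a document by summing, over query tokens, the maximum similarity to any document token.")
	score := scorer.ScoreDocument(query, document)
	fmt.Println("score:", score.Value())
}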