Skip to content

Commit

Permalink
MongoDB input - add batching (#1)
Browse files Browse the repository at this point in the history
Signed-off-by: Brad Anderson <[email protected]>
  • Loading branch information
boorad authored Oct 13, 2023
1 parent 665f238 commit 4608bae
Show file tree
Hide file tree
Showing 3 changed files with 212 additions and 43 deletions.
155 changes: 125 additions & 30 deletions internal/impl/mongodb/input.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"

"github.com/benthosdev/benthos/v4/public/service"
)
Expand All @@ -14,15 +15,16 @@ import (
const (
// FindInputOperation is the "operation" config value that selects a find query.
FindInputOperation = "find"
// AggregateInputOperation is the "operation" config value that selects an
// aggregation pipeline.
AggregateInputOperation = "aggregate"
// DefaultBatchSize is the default number of documents per batch.
// NOTE(review): the "batchSize" field spec sets its default with a literal
// 1000 rather than this constant — confirm the two are meant to stay in sync.
DefaultBatchSize = 1000
)

func mongoConfigSpec() *service.ConfigSpec {
return service.NewConfigSpec().
// Stable(). TODO
Version("3.64.0").
Categories("Services").
Summary("Executes a find query and creates a message for each row received.").
Description(`Once the rows from the query are exhausted this input shuts down, allowing the pipeline to gracefully terminate (or the next input in a [sequence](/docs/components/inputs/sequence) to execute).`).
Summary("Executes a query and creates a message for each document received.").
Description(`Once the documents from the query are exhausted, this input shuts down, allowing the pipeline to gracefully terminate (or the next input in a [sequence](/docs/components/inputs/sequence) to execute).`).
Fields(clientFields()...).
Field(service.NewStringField("collection").Description("The collection to select from.")).
Field(service.NewStringEnumField("operation", FindInputOperation, AggregateInputOperation).
Expand All @@ -44,26 +46,47 @@ func mongoConfigSpec() *service.ConfigSpec {
Example(`
root.from = {"$lte": timestamp_unix()}
root.to = {"$gte": timestamp_unix()}
`))
`)).
Field(service.NewIntField("batchSize").
Description("A number of documents at which the batch should be flushed. Greater than `0`. Operations: `find`, `aggregate`").
Optional().
Default(1000).
Version("4.22.0")).
Field(service.NewIntMapField("sort").
Description("An object specifying fields to sort by, and the respective sort order (`1` ascending, `-1` descending). Operations: `find`").
Optional().
Example(`
name: 1
age: -1
`).
Version("4.22.0")).
Field(service.NewIntField("limit").
Description("A number of documents to return. Operations: `find`").
Optional().
Version("4.22.0"))
}

// init registers the "mongodb" input, built from mongoConfigSpec, with the
// service package. It panics if registration returns an error.
func init() {
err := service.RegisterInput(
err := service.RegisterBatchInput(
"mongodb", mongoConfigSpec(),
func(conf *service.ParsedConfig, mgr *service.Resources) (service.Input, error) {
return newMongoInput(conf)
func(conf *service.ParsedConfig, mgr *service.Resources) (service.BatchInput, error) {
return newMongoInput(conf, mgr.Logger())
})
if err != nil {
panic(err)
}
}

func newMongoInput(conf *service.ParsedConfig) (service.Input, error) {
func newMongoInput(conf *service.ParsedConfig, logger *service.Logger) (service.BatchInput, error) {
var (
batchSize, limit int
sort map[string]int
)

mClient, database, err := getClient(conf)
if err != nil {
return nil, err
}

collection, err := conf.FieldString("collection")
if err != nil {
return nil, err
Expand All @@ -84,14 +107,36 @@ func newMongoInput(conf *service.ParsedConfig) (service.Input, error) {
if err != nil {
return nil, err
}

return service.AutoRetryNacks(&mongoInput{
if conf.Contains("batchSize") {
batchSize, err = conf.FieldInt("batchSize")
if err != nil {
return nil, err
}
}
if conf.Contains("sort") {
sort, err = conf.FieldIntMap("sort")
if err != nil {
return nil, err
}
}
if conf.Contains("limit") {
limit, err = conf.FieldInt("limit")
if err != nil {
return nil, err
}
}
return service.AutoRetryNacksBatched(&mongoInput{
query: query,
collection: collection,
client: mClient,
database: database,
operation: operation,
marshalCanon: marshalMode == string(JSONMarshalModeCanonical),
batchSize: int32(batchSize),
sort: sort,
limit: int64(limit),
count: 0,
logger: logger,
}), nil
}

Expand All @@ -103,6 +148,11 @@ type mongoInput struct {
cursor *mongo.Cursor
operation string
marshalCanon bool
batchSize int32
sort map[string]int
limit int64
count int
logger *service.Logger
}

func (m *mongoInput) Connect(ctx context.Context) error {
Expand All @@ -118,11 +168,21 @@ func (m *mongoInput) Connect(ctx context.Context) error {
collection := m.database.Collection(m.collection)
switch m.operation {
case "find":
m.cursor, err = collection.Find(ctx, m.query)
var findOptions *options.FindOptions
findOptions, err = m.getFindOptions()
if err != nil {
return fmt.Errorf("error parsing 'find' options: %v", err)
}
m.cursor, err = collection.Find(ctx, m.query, findOptions)
case "aggregate":
m.cursor, err = collection.Aggregate(ctx, m.query)
var aggregateOptions *options.AggregateOptions
aggregateOptions, err = m.getAggregateOptions()
if err != nil {
return fmt.Errorf("error parsing 'aggregate' options: %v", err)
}
m.cursor, err = collection.Aggregate(ctx, m.query, aggregateOptions)
default:
return fmt.Errorf("opertaion %s not supported. the supported values are \"find\" and \"aggregate\"", m.operation)
return fmt.Errorf("operation '%s' not supported. the supported values are 'find' and 'aggregate'", m.operation)
}
if err != nil {
_ = m.client.Disconnect(ctx)
Expand All @@ -131,33 +191,68 @@ func (m *mongoInput) Connect(ctx context.Context) error {
return nil
}

func (m *mongoInput) Read(ctx context.Context) (*service.Message, service.AckFunc, error) {
func (m *mongoInput) ReadBatch(ctx context.Context) (service.MessageBatch, service.AckFunc, error) {
i := 0
batch := make(service.MessageBatch, m.batchSize)

if m.cursor == nil {
return nil, nil, service.ErrNotConnected
}
if !m.cursor.Next(ctx) {
return nil, nil, service.ErrEndOfInput
}
var decoded any
if err := m.cursor.Decode(&decoded); err != nil {
return nil, nil, err
}

data, err := bson.MarshalExtJSON(decoded, m.marshalCanon, false)
if err != nil {
return nil, nil, err
}
for m.cursor.Next(ctx) {
msg := service.NewMessage(nil)
msg.MetaSet("mongo_database", m.database.Name())
msg.MetaSet("mongo_collection", m.collection)

msg := service.NewMessage(nil)
msg.SetBytes(data)
return msg, func(ctx context.Context, err error) error {
return nil
}, nil
var decoded any
if err := m.cursor.Decode(&decoded); err != nil {
msg.SetError(err)
} else {
data, err := bson.MarshalExtJSON(decoded, m.marshalCanon, false)
if err != nil {
msg.SetError(err)
}
msg.SetBytes(data)
}
batch[i] = msg
i++
m.count++

if m.cursor.RemainingBatchLength() == 0 {
return batch[:i], func(ctx context.Context, err error) error {
return nil
}, nil
}
}
return nil, nil, service.ErrEndOfInput
}

// Close logs the number of documents read and disconnects the MongoDB client.
// It does nothing and returns nil unless both a cursor and a client are set.
func (m *mongoInput) Close(ctx context.Context) error {
	if m.cursor == nil || m.client == nil {
		return nil
	}

	m.logger.Debugf("Got %d documents from '%s' collection", m.count, m.collection)
	return m.client.Disconnect(ctx)
}

// getFindOptions builds the options for a "find" operation from the input's
// configuration. Batch size and limit are applied only when positive, and the
// sort specification only when one is set. The returned error is always nil.
//
// NOTE(review): m.sort is a Go map, which has no key order; the driver is
// understood to reject multi-key maps for ordered arguments such as sort —
// confirm against the driver's FindOptions documentation.
func (m *mongoInput) getFindOptions() (*options.FindOptions, error) {
	opts := options.Find()

	if size := m.batchSize; size > 0 {
		opts.SetBatchSize(size)
	}
	if sortSpec := m.sort; sortSpec != nil {
		opts.SetSort(sortSpec)
	}
	if n := m.limit; n > 0 {
		opts.SetLimit(n)
	}

	return opts, nil
}

// getAggregateOptions builds the options for an "aggregate" operation. The
// batch size is applied only when positive. The returned error is always nil.
func (m *mongoInput) getAggregateOptions() (*options.AggregateOptions, error) {
	opts := options.Aggregate()

	if size := m.batchSize; size > 0 {
		opts.SetBatchSize(size)
	}

	return opts, nil
}
57 changes: 46 additions & 11 deletions internal/impl/mongodb/input_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
"github.com/benthosdev/benthos/v4/public/service"
)

func TestSQLSelectInputEmptyShutdown(t *testing.T) {
func TestMongoInputEmptyShutdown(t *testing.T) {
conf := `
url: "mongodb://localhost:27017"
username: foouser
Expand All @@ -32,13 +32,14 @@ query: |

spec := mongoConfigSpec()
env := service.NewEnvironment()
resources := service.MockResources()

mongoConfig, err := spec.ParseYAML(conf, env)
require.NoError(t, err)

selectInput, err := newMongoInput(mongoConfig)
mongoInput, err := newMongoInput(mongoConfig, resources.Logger())
require.NoError(t, err)
require.NoError(t, selectInput.Close(context.Background()))
require.NoError(t, mongoInput.Close(context.Background()))
}

func TestInputIntegration(t *testing.T) {
Expand Down Expand Up @@ -128,13 +129,19 @@ func TestInputIntegration(t *testing.T) {
placeholderConf string
jsonMarshalMode JSONMarshalMode
}
limit := int64(3)
cases := map[string]testCase{
"find": {
query: func(coll *mongo.Collection) (*mongo.Cursor, error) {
return coll.Find(context.Background(), bson.M{
"age": bson.M{
"$gte": 18,
},
}, &options.FindOptions{
Sort: bson.M{
"name": 1,
},
Limit: &limit,
})
},
placeholderConf: `
Expand All @@ -146,6 +153,10 @@ collection: "TestCollection"
json_marshal_mode: relaxed
query: |
root.age = {"$gte": 18}
batchSize: 2
sort:
name: 1
limit: 3
`,
jsonMarshalMode: JSONMarshalModeRelaxed,
},
Expand All @@ -160,7 +171,12 @@ query: |
},
},
bson.M{
"$limit": 3,
"$sort": bson.M{
"name": 1,
},
},
bson.M{
"$limit": limit,
},
})
},
Expand All @@ -181,10 +197,16 @@ query: |
}
}
},
{
"$sort": {
"name": 1
}
},
{
"$limit": 3
}
]
batchSize: 2
`,
jsonMarshalMode: JSONMarshalModeCanonical,
},
Expand Down Expand Up @@ -227,27 +249,40 @@ func testInput(

spec := mongoConfigSpec()
env := service.NewEnvironment()
resources := service.MockResources()

mongoConfig, err := spec.ParseYAML(conf, env)
require.NoError(t, err)

selectInput, err := newMongoInput(mongoConfig)
mongoInput, err := newMongoInput(mongoConfig, resources.Logger())
require.NoError(t, err)

ctx := context.Background()
err = selectInput.Connect(ctx)
err = mongoInput.Connect(ctx)
require.NoError(t, err)
for _, wMsg := range wantMsgs {
msg, ack, err := selectInput.Read(ctx)

// read all batches
var actualMsgs service.MessageBatch
for {
batch, ack, err := mongoInput.ReadBatch(ctx)
if err == service.ErrEndOfInput {
break
}
require.NoError(t, err)
actualMsgs = append(actualMsgs, batch...)
require.NoError(t, ack(ctx, nil))
}

// compare to wanted messages
for i, wMsg := range wantMsgs {
msg := actualMsgs[i]
msgBytes, err := msg.AsBytes()
require.NoError(t, err)
assert.JSONEq(t, string(wMsg), string(msgBytes))
require.NoError(t, ack(ctx, nil))
}
_, ack, err := selectInput.Read(ctx)
_, ack, err := mongoInput.ReadBatch(ctx)
assert.Equal(t, service.ErrEndOfInput, err)
require.Nil(t, ack)

require.NoError(t, selectInput.Close(context.Background()))
require.NoError(t, mongoInput.Close(context.Background()))
}
Loading

0 comments on commit 4608bae

Please sign in to comment.