Skip to content

Commit

Permalink
Merge pull request #13 from ilijamt/select-custom-database
Browse files Browse the repository at this point in the history
Select custom database
  • Loading branch information
securingsincity authored Jan 11, 2017
2 parents 5314e13 + 91ba6ba commit c2a7770
Show file tree
Hide file tree
Showing 8 changed files with 110 additions and 10 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
.idea/
8 changes: 6 additions & 2 deletions collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ func (v *ValidationError) Error() string {

type Collection struct {
Name string
Database string
Context *Context
Connection *Connection
}

Expand All @@ -77,12 +79,14 @@ func (d DocumentNotFoundError) Error() string {
return "Document not found"
}

// Collection ...
func (c *Collection) Collection() *mgo.Collection {
return c.Connection.Session.DB(c.Connection.Config.Database).C(c.Name)
return c.Connection.Session.DB(c.Database).C(c.Name)
}

// CollectionOnSession ...
func (c *Collection) collectionOnSession(sess *mgo.Session) *mgo.Collection {
return sess.DB(c.Connection.Config.Database).C(c.Name)
return sess.DB(c.Database).C(c.Name)
}

func (c *Collection) PreSave(doc Document) error {
Expand Down
5 changes: 5 additions & 0 deletions collection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,26 +23,31 @@ type hookedDocument struct {

func (h *hookedDocument) BeforeSave(c *Collection) error {
h.RanBeforeSave = true
So(c.Context.Get("foo"), ShouldEqual, "bar")
return nil
}

func (h *hookedDocument) AfterSave(c *Collection) error {
h.RanAfterSave = true
So(c.Context.Get("foo"), ShouldEqual, "bar")
return nil
}

func (h *hookedDocument) BeforeDelete(c *Collection) error {
h.RanBeforeDelete = true
So(c.Context.Get("foo"), ShouldEqual, "bar")
return nil
}

func (h *hookedDocument) AfterDelete(c *Collection) error {
h.RanAfterDelete = true
So(c.Context.Get("foo"), ShouldEqual, "bar")
return nil
}

func (h *hookedDocument) AfterFind(c *Collection) error {
h.RanAfterFind = true
So(c.Context.Get("foo"), ShouldEqual, "bar")
return nil
}

Expand Down
30 changes: 30 additions & 0 deletions context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package bongo

// Context struct
type Context struct {
set map[string]interface{}
}

// Get ...
func (c *Context) Get(key string) interface{} {
if value, ok := c.set[key]; ok {
return value
}
return nil
}

func (c *Context) Delete(key string) bool {
if _, ok := c.set[key]; ok {
delete(c.set, key)
return true
}
return false
}

// Set ...
func (c *Context) Set(key string, value interface{}) {
if c.set == nil {
c.set = make(map[string]interface{})
}
c.set[key] = value
}
27 changes: 27 additions & 0 deletions context_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package bongo

import (
"testing"

. "github.com/smartystreets/goconvey/convey"
)

func Test_Context(t *testing.T) {
Convey("Context", t, func() {

Convey("Setting context, checking it and deleting it", func() {
c := &Context{}
c.Set("foo", "bar")
So(c.Get("foo"), ShouldEqual, "bar")
So(c.Delete("foo"), ShouldBeTrue)
})

Convey("Invalid Keys", func() {
c := &Context{}
So(c.Get("foo"), ShouldBeNil)
So(c.Delete("foo"), ShouldBeFalse)
})

})

}
18 changes: 13 additions & 5 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@ type Connection struct {
Config *Config
Session *mgo.Session
// collection []Collection
Context *Context
}

// Create a new connection and run Connect()
func Connect(config *Config) (*Connection, error) {
conn := &Connection{
Config: config,
Config: config,
Context: &Context{},
}

err := conn.Connect()
Expand Down Expand Up @@ -57,7 +59,6 @@ func (m *Connection) Connect() (err error) {
}

session, err := mgo.DialWithInfo(m.Config.DialInfo)

if err != nil {
return err
}
Expand All @@ -69,11 +70,18 @@ func (m *Connection) Connect() (err error) {
return nil
}

func (m *Connection) Collection(name string) *Collection {

// Just create a new instance - it's cheap and only has name
// CollectionFromDatabase ...
func (m *Connection) CollectionFromDatabase(name string, database string) *Collection {
// Just create a new instance - it's cheap and only has name and a database name
return &Collection{
Connection: m,
Context: m.Context,
Database: database,
Name: name,
}
}

// Collection ...
func (m *Connection) Collection(name string) *Collection {
return m.CollectionFromDatabase(name, m.Config.Database)
}
29 changes: 27 additions & 2 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ func getConnection() *Connection {
}

conn, err := Connect(conf)
conn.Context.Set("foo", "bar")

if err != nil {
panic(err)
Expand Down Expand Up @@ -45,6 +46,10 @@ func TestConnect(t *testing.T) {
defer conn.Session.Close()
So(err, ShouldEqual, nil)

conn.Context.Set("foo", "bar")
value := conn.Context.Get("foo")
So(value, ShouldEqual, "bar")

err = conn.Session.Ping()
So(err, ShouldEqual, nil)
})
Expand All @@ -54,9 +59,29 @@ func TestRetrieveCollection(t *testing.T) {
Convey("should be able to retrieve a collection instance from a connection", t, func() {
conn := getConnection()
defer conn.Session.Close()
col := conn.Collection("tests")

col := conn.Collection("tests");
So(col.Name, ShouldEqual, "tests")
So(col.Connection, ShouldEqual, conn)

So(col.Context.Get("foo"), ShouldEqual, "bar")

So(conn.Config.Database, ShouldEqual, col.Database)
})
Convey("should be able to retrieve a collection instance from a connection with different databases", t, func() {
conn := getConnection()
defer conn.Session.Close()

col1 := conn.CollectionFromDatabase("tests", "test1");
So(col1.Name, ShouldEqual, "tests")
So(col1.Connection, ShouldEqual, conn)
So(col1.Database, ShouldEqual, "test1")

col2 := conn.CollectionFromDatabase("tests", "test2");
So(col2.Name, ShouldEqual, "tests")
So(col2.Connection, ShouldEqual, conn)
So(col2.Database, ShouldEqual, "test2")

So(col2.Connection, ShouldEqual, col1.Connection)
So(col1.Database, ShouldNotEqual, col2.Database)
})
}
2 changes: 1 addition & 1 deletion resultSet.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func (r *ResultSet) Paginate(perPage, page int) (*PaginationInfo, error) {
// Get count on a different session to avoid blocking
sess := r.Collection.Connection.Session.Copy()

count, err := sess.DB(r.Collection.Connection.Config.Database).C(r.Collection.Name).Find(r.Params).Count()
count, err := sess.DB(r.Collection.Database).C(r.Collection.Name).Find(r.Params).Count()
sess.Close()

if err != nil {
Expand Down

0 comments on commit c2a7770

Please sign in to comment.