diff --git a/backend/backend.go b/backend/backend.go index f511ac9..5a5f320 100644 --- a/backend/backend.go +++ b/backend/backend.go @@ -1,8 +1,9 @@ package backend import ( + "errors" "log" - "strings" + "net/url" "sync" "time" @@ -10,33 +11,42 @@ import ( ) type Backend struct { - monitor *Monitor - routingTable *RoutingTable - tls bool - log *log.Logger + monitor *Monitor + tls bool + log *log.Logger // map of principals -> hosts -> connections connectionPool map[string]map[string]bolt.BoltConn + routingCache map[string]RoutingTable + info ClusterInfo } func NewBackend(logger *log.Logger, username, password string, uri string, hosts ...string) (*Backend, error) { - monitor, err := NewMonitor(username, password, uri, hosts...) + tls := false + u, err := url.Parse(uri) if err != nil { return nil, err } - routingTable := <-monitor.C - - tls := false - switch strings.Split(uri, ":")[0] { + switch u.Scheme { case "bolt+s", "bolt+ssc", "neo4j+s", "neo4j+ssc": tls = true + case "bolt", "neo4j": + // ok default: + return nil, errors.New("invalid neo4j connection scheme") + } + + monitor, err := NewMonitor(username, password, uri, hosts...) + if err != nil { + return nil, err } return &Backend{ - monitor: monitor, - routingTable: routingTable, - tls: tls, - log: logger, + monitor: monitor, + tls: tls, + log: logger, + connectionPool: make(map[string]map[string]bolt.BoltConn), + routingCache: make(map[string]RoutingTable), + info: <-monitor.Info, }, nil } @@ -44,23 +54,34 @@ func (b *Backend) Version() Version { return b.monitor.Version } -func (b *Backend) RoutingTable() *RoutingTable { - if b.routingTable == nil { - panic("attempting to use uninitialized BackendClient") +func (b *Backend) RoutingTable(db string) (RoutingTable, error) { + table, found := b.routingCache[db] + if found && !table.Expired() { + return table, nil } - b.log.Println("checking routing table...") - if b.routingTable.Expired() { + table, err := b.monitor.UpdateRoutingTable(db) + if err != nil { + return RoutingTable{}, err + } + + b.log.Printf("got routing table for %s: %s", db, table) + return table, nil +} + +func (b *Backend) ClusterInfo() (ClusterInfo, error) { + // XXX: this technically isn't thread safe as we mutate b.info + + if b.info.CreatedAt.Add(30 * time.Second).Before(time.Now()) { select { - case rt := <-b.monitor.C: - b.routingTable = rt - case <-time.After(60 * time.Second): - b.log.Fatal("timeout waiting for new routing table!") + case <-time.After(30 * time.Second): + return ClusterInfo{}, errors.New("timeout waiting for updated ClusterInfo") + case info := <-b.monitor.Info: + b.info = info } } - b.log.Println("using routing table") - return b.routingTable + return b.info, nil } // For now, we'll authenticate to all known hosts up-front to simplify things. @@ -74,7 +95,7 @@ func (b *Backend) Authenticate(hello *bolt.Message) (map[string]bolt.BoltConn, e panic("authenticate requires a Hello message") } - // TODO: clean up this api...push the dirt into Bolt package + // TODO: clean up this api...push the dirt into Bolt package? msg, pos, err := bolt.ParseMap(hello.Data[4:]) if err != nil { b.log.Printf("XXX pos: %d, hello map: %#v\n", pos, msg) @@ -86,25 +107,24 @@ func (b *Backend) Authenticate(hello *bolt.Message) (map[string]bolt.BoltConn, e } b.log.Println("found principal:", principal) - // refresh routing table - // TODO: this api seems backwards...push down into table? - rt := b.RoutingTable() - - // Try authing first with the default db writer before we try others + // Try authing first with a Core cluster member before we try others // this way we can fail fast and not spam a bad set of credentials - writers, _ := rt.WritersFor(rt.DefaultDb) - defaultWriter := writers[0] + info, err := b.ClusterInfo() + if err != nil { + return nil, err + } + defaultHost := info.Hosts[0] - b.log.Printf("trying to auth %s to host %s\n", principal, defaultWriter) + b.log.Printf("trying to auth %s to host %s\n", principal, defaultHost) conn, err := authClient(hello.Data, b.Version().Bytes(), - "tcp", defaultWriter, b.tls) + "tcp", defaultHost, b.tls) if err != nil { return nil, err } // Ok, now to get the rest - conns := make(map[string]bolt.BoltConn, len(rt.Hosts)) - conns[defaultWriter] = bolt.NewDirectConn(conn) + conns := make(map[string]bolt.BoltConn, len(info.Hosts)) + conns[defaultHost] = bolt.NewDirectConn(conn) // We'll need a channel to collect results as we're going to auth // to all hosts asynchronously @@ -112,11 +132,11 @@ func (b *Backend) Authenticate(hello *bolt.Message) (map[string]bolt.BoltConn, e conn bolt.BoltConn host string } - c := make(chan pair, len(rt.Hosts)+1) + c := make(chan pair, len(info.Hosts)+1) var wg sync.WaitGroup - for host := range rt.Hosts { + for _, host := range info.Hosts { // skip the host we already used to test auth - if host != defaultWriter { + if host != defaultHost { wg.Add(1) go func(h string) { defer wg.Done() diff --git a/backend/monitor.go b/backend/monitor.go index 9fe86d5..b31d265 100644 --- a/backend/monitor.go +++ b/backend/monitor.go @@ -10,21 +10,14 @@ import ( "github.com/neo4j/neo4j-go-driver/v4/neo4j" ) -// Modeled after time.Ticker, a Monitor will keep tabs on the Neo4j routing -// table behind the scenes. It auto-adjusts the refresh interval to match -// the server's declared TTL recommendation. -// -// As it creates new RoutingTable instances on the heap, it will put pointers -// to new instances into the channel C. (Similar to how time.Ticker puts the -// current time into its channel.) -// -// Before it puts a new pointer in the channel, it tries to empty it, which -// hopefully reduces the chance of receiving stale entries. +// TODO: what the hell are we doing here? type Monitor struct { - C <-chan *RoutingTable + Info <-chan ClusterInfo halt chan bool driver *neo4j.Driver Version Version + Ttl time.Duration + Host string } type Version struct { @@ -59,6 +52,10 @@ func (v Version) Bytes() []byte { } } +func (m Monitor) UpdateRoutingTable(db string) (RoutingTable, error) { + return getRoutingTable(m.driver, db, m.Host) +} + // Our default Driver configuration provides: // - custom user-agent name // - ability to add in specific list of hosts to use for address resolution @@ -79,7 +76,8 @@ func newConfigurer(hosts []string) func(c *neo4j.Config) { } return addrs } - c.UserAgent = "bolt-proxy/v0" + // TODO: wire into global version string + c.UserAgent = "bolt-proxy/v0.3.0" } } @@ -94,6 +92,8 @@ RETURN [x IN split(head(split(version, "-")), ".") | toInteger(x)] AS version, // which should provide an array of int64s corresponding to the Version. // Return the Version on success, otherwise return an empty Version and // and error. +// +// Note: Aura provides a special version string because Aura is special. func getVersion(driver *neo4j.Driver) (Version, error) { version := Version{} session := (*driver).NewSession(neo4j.SessionConfig{}) @@ -149,8 +149,8 @@ func getVersion(driver *neo4j.Driver) (Version, error) { // Any additional hosts provided will be used as part of a custom address // resolution function via the neo4j.Driver. func NewMonitor(user, password, uri string, hosts ...string) (*Monitor, error) { - c := make(chan *RoutingTable, 1) - h := make(chan bool, 1) + infoChan := make(chan ClusterInfo, 1) + haltChan := make(chan bool, 1) // Try immediately to connect to Neo4j auth := neo4j.BasicAuth(user, password, "") @@ -163,7 +163,6 @@ func NewMonitor(user, password, uri string, hosts ...string) (*Monitor, error) { if err != nil { panic(err) } - // log.Printf("found neo4j version %v\n", version) // TODO: check if in SINGLE, CORE, or READ_REPLICA mode // We can run `CALL dbms.listConfig('dbms.mode') YIELD value` and @@ -171,7 +170,7 @@ func NewMonitor(user, password, uri string, hosts ...string) (*Monitor, error) { // simplify the monitor considerably to just health checks and no // routing table. - // Get the first routing table and ttl details + // Get the cluster members and ttl details u, err := url.Parse(uri) if err != nil { return nil, err @@ -181,45 +180,49 @@ func NewMonitor(user, password, uri string, hosts ...string) (*Monitor, error) { host = host + ":7687" } - rt, err := getNewRoutingTable(&driver, host) + info, err := getClusterInfo(&driver, host) if err != nil { - panic(err) + return nil, err + } + infoChan <- info + + monitor := Monitor{ + Info: infoChan, + halt: haltChan, + driver: &driver, + Version: version, + Ttl: info.Ttl, // since right now it's not dynamic + Host: host, } - c <- rt - monitor := Monitor{c, h, &driver, version} go func() { - // preset the initial ticker to use the first ttl measurement - ticker := time.NewTicker(rt.Ttl) + // TODO: configurable cluster info update frequency + ticker := time.NewTicker(time.Second * 30) for { select { case <-ticker.C: - rt, err := getNewRoutingTable(monitor.driver, host) + info, err := getClusterInfo(monitor.driver, monitor.Host) if err != nil { + // TODO: how do we handle faults??? panic(err) } - ticker.Reset(rt.Ttl) - + ticker.Reset(time.Second * 30) // empty the channel and put the new value in // this looks odd, but even though it's racy, // it should be racy in a safe way since it // doesn't matter if another go routine takes // the value first select { - case <-c: + case <-infoChan: default: } select { - case c <- rt: + case infoChan <- info: default: panic("monitor channel full") } - case <-h: + case <-haltChan: ticker.Stop() - // log.Println("monitor stopped") - case <-time.After(10 * rt.Ttl): - msg := fmt.Sprintf("monitor timeout of 10*%v reached\n", rt.Ttl) - panic(msg) } } }() @@ -234,112 +237,26 @@ func (m *Monitor) Stop() { } } -// local data structure for passing the raw routing table details -// TODO: ttl could be pulled direct via a check of dbms.routing_ttl -// since it's not a dynamic config value as of v4.2 -type table struct { - db string - ttl time.Duration - readers []string - writers []string -} - // Denormalize the routing table to make post-processing easier const ROUTING_QUERY = ` -UNWIND $names AS name -CALL dbms.routing.getRoutingTable({address: $host}, name) +CALL dbms.routing.getRoutingTable({address: $host}, $db) YIELD ttl, servers -WITH name, ttl, servers UNWIND servers AS server -WITH name, ttl, server UNWIND server["addresses"] AS address -RETURN name, ttl, server["role"] AS role, address +RETURN server["role"] AS role, address ` -// Dump the list of databases. We need to keep this simple to support v4.0, -// v4.1, and v4.2 since this has been a moving target in how it works -const SHOW_DATABASES = "SHOW DATABASES" - -// Use SHOW DATABASES to dump the current list of databases with the first -// database name being the default (based on the query logic) -func queryDbNames(driver *neo4j.Driver) ([]string, error) { - session := (*driver).NewSession(neo4j.SessionConfig{ - DatabaseName: "system", - }) - defer session.Close() - - result, err := session.Run(SHOW_DATABASES, nil) - if err != nil { - return nil, err - } - rows, err := result.Collect() - if err != nil { - return nil, err - } - - // create a basic set structure - nameSet := make(map[string]bool) - for _, row := range rows { - val, found := row.Get("name") - if !found { - return nil, errors.New("couldn't find name field in result") - } - name, ok := val.(string) - if !ok { - panic("name isn't a string") - } - - val, found = row.Get("currentStatus") - if !found { - return nil, errors.New("couldn't find currentStatus field in result") - } - status, ok := val.(string) - if !ok { - panic("currentStatus isn't a string") - } - - if status == "online" { - nameSet[name] = true - } - } - - names := make([]string, 0, len(nameSet)) - for key := range nameSet { - names = append(names, key) - } - return names, nil -} - -func queryRoutingTable(driver *neo4j.Driver, host string, names []string) (map[string]table, error) { - session := (*driver).NewSession(neo4j.SessionConfig{}) - defer session.Close() - - result, err := session.ReadTransaction(func(tx neo4j.Transaction) (interface{}, error) { - return routingTableTx(tx, host, names) - }) - if err != nil { - return map[string]table{}, err - } - - tableMap, ok := result.(map[string]table) - if !ok { - return map[string]table{}, errors.New("invalid type for routing table response") - } - - return tableMap, nil -} - // Given a neo4j.Transaction tx, collect the routing table maps for each of // the databases in names. Since this should run in a transaction work function // we return a generic interface{} on success, or nil and an error if failed. // -// The true data type is a map[string]table, mapping database names to their -// respective tables. -func routingTableTx(tx neo4j.Transaction, host string, names []string) (interface{}, error) { - params := make(map[string]interface{}, 1) - params["names"] = names - params["host"] = host - result, err := tx.Run(ROUTING_QUERY, params) +// The true data type is a table struct, mapping providing arrays of readers, +// writers, and routers for the given db +func routingTableTx(tx neo4j.Transaction, host string, db string) (interface{}, error) { + result, err := tx.Run(ROUTING_QUERY, map[string]interface{}{ + "db": db, + "host": host, + }) if err != nil { return nil, err } @@ -349,8 +266,13 @@ func routingTableTx(tx neo4j.Transaction, host string, names []string) (interfac return nil, err } - // expected fields: [name, ttl, role, address] - tableMap := make(map[string]table, len(rows)) + // expected fields: [role, address] + t := RoutingTable{ + Name: db, + Readers: []string{}, + Writers: []string{}, + Routers: []string{}, + } for _, row := range rows { val, found := row.Get("address") if !found { @@ -361,32 +283,6 @@ func routingTableTx(tx neo4j.Transaction, host string, names []string) (interfac panic("addr isn't a string!") } - val, found = row.Get("ttl") - if !found { - return nil, errors.New("missing ttl field in result") - } - ttl, ok := val.(int64) - if !ok { - panic("ttl isn't an integer!") - } - - val, found = row.Get("name") - if !found { - return nil, errors.New("missing name field in result") - } - name, ok := val.(string) - if !ok { - panic("name isn't a string!") - } - - t, found := tableMap[name] - if !found { - t = table{ - db: name, - ttl: time.Duration(ttl) * time.Second, - } - } - val, found = row.Get("role") if !found { return nil, errors.New("missing role field in result") @@ -398,71 +294,103 @@ func routingTableTx(tx neo4j.Transaction, host string, names []string) (interfac switch role { case "READ": - t.readers = append(t.readers, addr) + t.Readers = append(t.Readers, addr) case "WRITE": - t.writers = append(t.writers, addr) + t.Writers = append(t.Writers, addr) case "ROUTE": - continue + t.Routers = append(t.Routers, addr) default: return nil, errors.New("invalid role") } - - tableMap[name] = t } - return tableMap, nil + return t, nil } // Using a pointer to a connected neo4j.Driver, orchestrate fetching the -// database names and get the current routing table for each. -// -// XXX: this is pretty heavy weight :-( -func getNewRoutingTable(driver *neo4j.Driver, host string) (*RoutingTable, error) { - names, err := queryDbNames(driver) - if err != nil { - msg := fmt.Sprintf("error getting database names: %v\n", err) - return nil, errors.New(msg) - } +// routing table for a given database while using the provided host +// routing context. +func getRoutingTable(driver *neo4j.Driver, db, host string) (RoutingTable, error) { + session := (*driver).NewSession(neo4j.SessionConfig{}) + defer session.Close() - tableMap, err := queryRoutingTable(driver, host, names) + result, err := session.ReadTransaction(func(tx neo4j.Transaction) (interface{}, error) { + return routingTableTx(tx, host, db) + }) if err != nil { - msg := fmt.Sprintf("error getting routing table: %v\n", err) - return nil, errors.New(msg) + return RoutingTable{}, err } - - // build the new routing table instance - // TODO: clean this up...seems smelly.. - readers := make(map[string][]string) - writers := make(map[string][]string) - rt := RoutingTable{ - DefaultDb: names[0], - readers: readers, - writers: writers, - CreatedAt: time.Now(), - Hosts: make(map[string]bool), + table, ok := result.(RoutingTable) + if !ok { + panic("invalid return type: expected RoutingTable") } - for db, t := range tableMap { - r := make([]string, len(t.readers)) - copy(r, t.readers) - w := make([]string, len(t.writers)) - copy(w, t.writers) - rt.readers[db] = r - rt.writers[db] = w - - // yes, this is redundant... - rt.Ttl = t.ttl - - // yes, this is also wasteful...construct host sets - for _, host := range r { - rt.Hosts[host] = true + + return table, nil +} + +// Populate a ClusterInfo instance with critical details on our backend +func getClusterInfo(driver *neo4j.Driver, host string) (ClusterInfo, error) { + session := (*driver).NewSession(neo4j.SessionConfig{ + DatabaseName: "system", + }) + defer session.Close() + + // Inline TX function here for building the ClusterInfo on the fly + result, err := session.ReadTransaction(func(tx neo4j.Transaction) (interface{}, error) { + info := ClusterInfo{CreatedAt: time.Now()} + result, err := tx.Run("SHOW DATABASES", nil) + if err != nil { + return info, err } - for _, host := range w { - rt.Hosts[host] = true + rows, err := result.Collect() + if err != nil { + return info, err + } + + for _, row := range rows { + val, found := row.Get("default") + if !found { + return info, errors.New("missing 'default' field") + } + defaultDb, ok := val.(bool) + if !ok { + return info, errors.New("default field isn't a boolean") + } + if defaultDb { + val, found = row.Get("name") + if !found { + return info, errors.New("missing 'name' field") + } + name, ok := val.(string) + if !ok { + return info, errors.New("name field isn't a string") + } + info.DefaultDb = name + } } + + return info, nil + }) + if err != nil { + return ClusterInfo{}, err } - // log.Printf("updated routing table: %s\n", &rt) - // log.Printf("known hosts look like: %v\n", rt.Hosts) + info, ok := result.(ClusterInfo) + if !ok { + panic("result isn't a ClusterInfo struct") + } - return &rt, nil + // For now get details for System db... + rt, err := getRoutingTable(driver, "system", host) + if err != nil { + return info, err + } + hosts := map[string]bool{} + for _, host := range append(rt.Readers, rt.Writers...) { + hosts[host] = true + } + for host := range hosts { + info.Hosts = append(info.Hosts, host) + } + return info, nil } diff --git a/backend/routing.go b/backend/routing.go index c24a52a..227ebf1 100644 --- a/backend/routing.go +++ b/backend/routing.go @@ -1,52 +1,41 @@ package backend import ( - "errors" "fmt" "time" ) type RoutingTable struct { - readers map[string][]string - writers map[string][]string - Hosts map[string]bool - DefaultDb string - Ttl time.Duration - CreatedAt time.Time + Name string + Readers, Writers, Routers []string + CreatedAt time.Time + Ttl time.Duration } -func (rt *RoutingTable) Expired() bool { - now := time.Now() - return rt.CreatedAt.Add(rt.Ttl).Before(now) +func (t RoutingTable) String() string { + return fmt.Sprintf("DbTable{ Name: %s, "+ + "Readers: %s, "+ + "Writers: %s, "+ + "Routers: %s, "+ + "CreatedAt: %s, "+ + "Ttl: %s}", + t.Name, t.Readers, t.Writers, t.Routers, + t.CreatedAt, t.Ttl) } -func (rt *RoutingTable) String() string { - return fmt.Sprintf( - "RoutingTable{ DefaultDb: %s, "+ - "readerMap: %v, "+ - "writerMap: %v, "+ - "Ttl: %s, "+ - "CreatedAt: %s }", - rt.DefaultDb, rt.readers, rt.writers, rt.Ttl, rt.CreatedAt, - ) +type ClusterInfo struct { + DefaultDb string + Ttl time.Duration + Hosts []string + CreatedAt time.Time } -func (rt *RoutingTable) ReadersFor(db string) ([]string, error) { - readers, found := rt.readers[db] - if found { - result := make([]string, len(readers)) - copy(result, readers) - return result, nil - } - return nil, errors.New("no such database") +func (rt RoutingTable) Expired() bool { + return rt.CreatedAt.Add(rt.Ttl).Before(time.Now()) } -func (rt *RoutingTable) WritersFor(db string) ([]string, error) { - writers, found := rt.writers[db] - if found { - result := make([]string, len(writers)) - copy(result, writers) - return result, nil - } - return nil, errors.New("no such database") +func (i ClusterInfo) String() string { + return fmt.Sprintf( + "ClusterInfo{ DefaultDb: %s, Ttl: %v, Hosts: %v, CreatedAt: %v }", + i.DefaultDb, i.Ttl, i.Hosts, i.CreatedAt) } diff --git a/proxy.go b/proxy.go index e18d401..6184e25 100644 --- a/proxy.go +++ b/proxy.go @@ -277,7 +277,8 @@ func handleBoltConn(client bolt.BoltConn, clientVersion []byte, b *backend.Backe // get backend connection pool, err := b.Authenticate(hello) if err != nil { - warn.Fatal(err) + warn.Println(err) + return } // TODO: this seems odd...move parser and version stuff to bolt pkg @@ -360,14 +361,17 @@ func handleBoltConn(client bolt.BoltConn, clientVersion []byte, b *backend.Backe // we need to find a new connection to switch to if startingTx { mode, _ := bolt.ValidateMode(msg.Data) - rt := b.RoutingTable() - db := rt.DefaultDb + info, err := b.ClusterInfo() + if err != nil { + warn.Printf("error getting cluster info: %s\n", err) + return + } + db := info.DefaultDb // get the db name, if any. otherwise, use default var ( - m map[string]interface{} - err error - n int + m map[string]interface{} + n int ) if msg.T == bolt.BeginMsg { m, _, err = bolt.ParseMap(msg.Data[4:]) @@ -416,11 +420,16 @@ func handleBoltConn(client bolt.BoltConn, clientVersion []byte, b *backend.Backe } // Just choose the first one for now...something simple + rt, err := b.RoutingTable(db) + if err != nil { + warn.Printf("error getting routing table for %s: %s\n", db, err) + return + } var hosts []string if mode == bolt.ReadMode { - hosts, err = rt.ReadersFor(db) + hosts = rt.Readers } else { - hosts, err = rt.WritersFor(db) + hosts = rt.Writers } if err != nil { warn.Printf("couldn't find host for '%s' in routing table", db)