From de163dd92b917e8360ac19d7a680391453848e7b Mon Sep 17 00:00:00 2001 From: Dave Voutila Date: Tue, 22 Dec 2020 22:19:54 -0500 Subject: [PATCH] fixed routing table update intervals miss changes I'm not too thrilled with this change...it feels messier...but it's a start at fetching Routing Tables on demand and caching them for some period of time. This solves the problem of a new database being created and it not being accessible immediately. There are still some issues for newly created databases and they need to be ironed out, but existing databases should still behave the same as before. --- backend/backend.go | 100 ++++++++------ backend/monitor.go | 334 ++++++++++++++++++--------------------------- backend/routing.go | 59 ++++---- proxy.go | 25 ++-- 4 files changed, 232 insertions(+), 286 deletions(-) diff --git a/backend/backend.go b/backend/backend.go index f511ac9..5a5f320 100644 --- a/backend/backend.go +++ b/backend/backend.go @@ -1,8 +1,9 @@ package backend import ( + "errors" "log" - "strings" + "net/url" "sync" "time" @@ -10,33 +11,42 @@ import ( ) type Backend struct { - monitor *Monitor - routingTable *RoutingTable - tls bool - log *log.Logger + monitor *Monitor + tls bool + log *log.Logger // map of principals -> hosts -> connections connectionPool map[string]map[string]bolt.BoltConn + routingCache map[string]RoutingTable + info ClusterInfo } func NewBackend(logger *log.Logger, username, password string, uri string, hosts ...string) (*Backend, error) { - monitor, err := NewMonitor(username, password, uri, hosts...) + tls := false + u, err := url.Parse(uri) if err != nil { return nil, err } - routingTable := <-monitor.C - - tls := false - switch strings.Split(uri, ":")[0] { + switch u.Scheme { case "bolt+s", "bolt+ssc", "neo4j+s", "neo4j+ssc": tls = true + case "bolt", "neo4j": + // ok default: + return nil, errors.New("invalid neo4j connection scheme") + } + + monitor, err := NewMonitor(username, password, uri, hosts...) + if err != nil { + return nil, err } return &Backend{ - monitor: monitor, - routingTable: routingTable, - tls: tls, - log: logger, + monitor: monitor, + tls: tls, + log: logger, + connectionPool: make(map[string]map[string]bolt.BoltConn), + routingCache: make(map[string]RoutingTable), + info: <-monitor.Info, }, nil } @@ -44,23 +54,34 @@ func (b *Backend) Version() Version { return b.monitor.Version } -func (b *Backend) RoutingTable() *RoutingTable { - if b.routingTable == nil { - panic("attempting to use uninitialized BackendClient") +func (b *Backend) RoutingTable(db string) (RoutingTable, error) { + table, found := b.routingCache[db] + if found && !table.Expired() { + return table, nil } - b.log.Println("checking routing table...") - if b.routingTable.Expired() { + table, err := b.monitor.UpdateRoutingTable(db) + if err != nil { + return RoutingTable{}, err + } + + b.log.Printf("got routing table for %s: %s", db, table) + return table, nil +} + +func (b *Backend) ClusterInfo() (ClusterInfo, error) { + // XXX: this technically isn't thread safe as we mutate b.info + + if b.info.CreatedAt.Add(30 * time.Second).Before(time.Now()) { select { - case rt := <-b.monitor.C: - b.routingTable = rt - case <-time.After(60 * time.Second): - b.log.Fatal("timeout waiting for new routing table!") + case <-time.After(30 * time.Second): + return ClusterInfo{}, errors.New("timeout waiting for updated ClusterInfo") + case info := <-b.monitor.Info: + b.info = info } } - b.log.Println("using routing table") - return b.routingTable + return b.info, nil } // For now, we'll authenticate to all known hosts up-front to simplify things. @@ -74,7 +95,7 @@ func (b *Backend) Authenticate(hello *bolt.Message) (map[string]bolt.BoltConn, e panic("authenticate requires a Hello message") } - // TODO: clean up this api...push the dirt into Bolt package + // TODO: clean up this api...push the dirt into Bolt package? msg, pos, err := bolt.ParseMap(hello.Data[4:]) if err != nil { b.log.Printf("XXX pos: %d, hello map: %#v\n", pos, msg) @@ -86,25 +107,24 @@ func (b *Backend) Authenticate(hello *bolt.Message) (map[string]bolt.BoltConn, e } b.log.Println("found principal:", principal) - // refresh routing table - // TODO: this api seems backwards...push down into table? - rt := b.RoutingTable() - - // Try authing first with the default db writer before we try others + // Try authing first with a Core cluster member before we try others // this way we can fail fast and not spam a bad set of credentials - writers, _ := rt.WritersFor(rt.DefaultDb) - defaultWriter := writers[0] + info, err := b.ClusterInfo() + if err != nil { + return nil, err + } + defaultHost := info.Hosts[0] - b.log.Printf("trying to auth %s to host %s\n", principal, defaultWriter) + b.log.Printf("trying to auth %s to host %s\n", principal, defaultHost) conn, err := authClient(hello.Data, b.Version().Bytes(), - "tcp", defaultWriter, b.tls) + "tcp", defaultHost, b.tls) if err != nil { return nil, err } // Ok, now to get the rest - conns := make(map[string]bolt.BoltConn, len(rt.Hosts)) - conns[defaultWriter] = bolt.NewDirectConn(conn) + conns := make(map[string]bolt.BoltConn, len(info.Hosts)) + conns[defaultHost] = bolt.NewDirectConn(conn) // We'll need a channel to collect results as we're going to auth // to all hosts asynchronously @@ -112,11 +132,11 @@ func (b *Backend) Authenticate(hello *bolt.Message) (map[string]bolt.BoltConn, e conn bolt.BoltConn host string } - c := make(chan pair, len(rt.Hosts)+1) + c := make(chan pair, len(info.Hosts)+1) var wg sync.WaitGroup - for host := range rt.Hosts { + for _, host := range info.Hosts { // skip the host we already used to test auth - if host != defaultWriter { + if host != defaultHost { wg.Add(1) go func(h string) { defer wg.Done() diff --git a/backend/monitor.go b/backend/monitor.go index 9fe86d5..b31d265 100644 --- a/backend/monitor.go +++ b/backend/monitor.go @@ -10,21 +10,14 @@ import ( "github.com/neo4j/neo4j-go-driver/v4/neo4j" ) -// Modeled after time.Ticker, a Monitor will keep tabs on the Neo4j routing -// table behind the scenes. It auto-adjusts the refresh interval to match -// the server's declared TTL recommendation. -// -// As it creates new RoutingTable instances on the heap, it will put pointers -// to new instances into the channel C. (Similar to how time.Ticker puts the -// current time into its channel.) -// -// Before it puts a new pointer in the channel, it tries to empty it, which -// hopefully reduces the chance of receiving stale entries. +// TODO: what the hell are we doing here? type Monitor struct { - C <-chan *RoutingTable + Info <-chan ClusterInfo halt chan bool driver *neo4j.Driver Version Version + Ttl time.Duration + Host string } type Version struct { @@ -59,6 +52,10 @@ func (v Version) Bytes() []byte { } } +func (m Monitor) UpdateRoutingTable(db string) (RoutingTable, error) { + return getRoutingTable(m.driver, db, m.Host) +} + // Our default Driver configuration provides: // - custom user-agent name // - ability to add in specific list of hosts to use for address resolution @@ -79,7 +76,8 @@ func newConfigurer(hosts []string) func(c *neo4j.Config) { } return addrs } - c.UserAgent = "bolt-proxy/v0" + // TODO: wire into global version string + c.UserAgent = "bolt-proxy/v0.3.0" } } @@ -94,6 +92,8 @@ RETURN [x IN split(head(split(version, "-")), ".") | toInteger(x)] AS version, // which should provide an array of int64s corresponding to the Version. // Return the Version on success, otherwise return an empty Version and // and error. +// +// Note: Aura provides a special version string because Aura is special. func getVersion(driver *neo4j.Driver) (Version, error) { version := Version{} session := (*driver).NewSession(neo4j.SessionConfig{}) @@ -149,8 +149,8 @@ func getVersion(driver *neo4j.Driver) (Version, error) { // Any additional hosts provided will be used as part of a custom address // resolution function via the neo4j.Driver. func NewMonitor(user, password, uri string, hosts ...string) (*Monitor, error) { - c := make(chan *RoutingTable, 1) - h := make(chan bool, 1) + infoChan := make(chan ClusterInfo, 1) + haltChan := make(chan bool, 1) // Try immediately to connect to Neo4j auth := neo4j.BasicAuth(user, password, "") @@ -163,7 +163,6 @@ func NewMonitor(user, password, uri string, hosts ...string) (*Monitor, error) { if err != nil { panic(err) } - // log.Printf("found neo4j version %v\n", version) // TODO: check if in SINGLE, CORE, or READ_REPLICA mode // We can run `CALL dbms.listConfig('dbms.mode') YIELD value` and @@ -171,7 +170,7 @@ func NewMonitor(user, password, uri string, hosts ...string) (*Monitor, error) { // simplify the monitor considerably to just health checks and no // routing table. - // Get the first routing table and ttl details + // Get the cluster members and ttl details u, err := url.Parse(uri) if err != nil { return nil, err @@ -181,45 +180,49 @@ func NewMonitor(user, password, uri string, hosts ...string) (*Monitor, error) { host = host + ":7687" } - rt, err := getNewRoutingTable(&driver, host) + info, err := getClusterInfo(&driver, host) if err != nil { - panic(err) + return nil, err + } + infoChan <- info + + monitor := Monitor{ + Info: infoChan, + halt: haltChan, + driver: &driver, + Version: version, + Ttl: info.Ttl, // since right now it's not dynamic + Host: host, } - c <- rt - monitor := Monitor{c, h, &driver, version} go func() { - // preset the initial ticker to use the first ttl measurement - ticker := time.NewTicker(rt.Ttl) + // TODO: configurable cluster info update frequency + ticker := time.NewTicker(time.Second * 30) for { select { case <-ticker.C: - rt, err := getNewRoutingTable(monitor.driver, host) + info, err := getClusterInfo(monitor.driver, monitor.Host) if err != nil { + // TODO: how do we handle faults??? panic(err) } - ticker.Reset(rt.Ttl) - + ticker.Reset(time.Second * 30) // empty the channel and put the new value in // this looks odd, but even though it's racy, // it should be racy in a safe way since it // doesn't matter if another go routine takes // the value first select { - case <-c: + case <-infoChan: default: } select { - case c <- rt: + case infoChan <- info: default: panic("monitor channel full") } - case <-h: + case <-haltChan: ticker.Stop() - // log.Println("monitor stopped") - case <-time.After(10 * rt.Ttl): - msg := fmt.Sprintf("monitor timeout of 10*%v reached\n", rt.Ttl) - panic(msg) } } }() @@ -234,112 +237,26 @@ func (m *Monitor) Stop() { } } -// local data structure for passing the raw routing table details -// TODO: ttl could be pulled direct via a check of dbms.routing_ttl -// since it's not a dynamic config value as of v4.2 -type table struct { - db string - ttl time.Duration - readers []string - writers []string -} - // Denormalize the routing table to make post-processing easier const ROUTING_QUERY = ` -UNWIND $names AS name -CALL dbms.routing.getRoutingTable({address: $host}, name) +CALL dbms.routing.getRoutingTable({address: $host}, $db) YIELD ttl, servers -WITH name, ttl, servers UNWIND servers AS server -WITH name, ttl, server UNWIND server["addresses"] AS address -RETURN name, ttl, server["role"] AS role, address +RETURN server["role"] AS role, address ` -// Dump the list of databases. We need to keep this simple to support v4.0, -// v4.1, and v4.2 since this has been a moving target in how it works -const SHOW_DATABASES = "SHOW DATABASES" - -// Use SHOW DATABASES to dump the current list of databases with the first -// database name being the default (based on the query logic) -func queryDbNames(driver *neo4j.Driver) ([]string, error) { - session := (*driver).NewSession(neo4j.SessionConfig{ - DatabaseName: "system", - }) - defer session.Close() - - result, err := session.Run(SHOW_DATABASES, nil) - if err != nil { - return nil, err - } - rows, err := result.Collect() - if err != nil { - return nil, err - } - - // create a basic set structure - nameSet := make(map[string]bool) - for _, row := range rows { - val, found := row.Get("name") - if !found { - return nil, errors.New("couldn't find name field in result") - } - name, ok := val.(string) - if !ok { - panic("name isn't a string") - } - - val, found = row.Get("currentStatus") - if !found { - return nil, errors.New("couldn't find currentStatus field in result") - } - status, ok := val.(string) - if !ok { - panic("currentStatus isn't a string") - } - - if status == "online" { - nameSet[name] = true - } - } - - names := make([]string, 0, len(nameSet)) - for key := range nameSet { - names = append(names, key) - } - return names, nil -} - -func queryRoutingTable(driver *neo4j.Driver, host string, names []string) (map[string]table, error) { - session := (*driver).NewSession(neo4j.SessionConfig{}) - defer session.Close() - - result, err := session.ReadTransaction(func(tx neo4j.Transaction) (interface{}, error) { - return routingTableTx(tx, host, names) - }) - if err != nil { - return map[string]table{}, err - } - - tableMap, ok := result.(map[string]table) - if !ok { - return map[string]table{}, errors.New("invalid type for routing table response") - } - - return tableMap, nil -} - // Given a neo4j.Transaction tx, collect the routing table maps for each of // the databases in names. Since this should run in a transaction work function // we return a generic interface{} on success, or nil and an error if failed. // -// The true data type is a map[string]table, mapping database names to their -// respective tables. -func routingTableTx(tx neo4j.Transaction, host string, names []string) (interface{}, error) { - params := make(map[string]interface{}, 1) - params["names"] = names - params["host"] = host - result, err := tx.Run(ROUTING_QUERY, params) +// The true data type is a table struct, mapping providing arrays of readers, +// writers, and routers for the given db +func routingTableTx(tx neo4j.Transaction, host string, db string) (interface{}, error) { + result, err := tx.Run(ROUTING_QUERY, map[string]interface{}{ + "db": db, + "host": host, + }) if err != nil { return nil, err } @@ -349,8 +266,13 @@ func routingTableTx(tx neo4j.Transaction, host string, names []string) (interfac return nil, err } - // expected fields: [name, ttl, role, address] - tableMap := make(map[string]table, len(rows)) + // expected fields: [role, address] + t := RoutingTable{ + Name: db, + Readers: []string{}, + Writers: []string{}, + Routers: []string{}, + } for _, row := range rows { val, found := row.Get("address") if !found { @@ -361,32 +283,6 @@ func routingTableTx(tx neo4j.Transaction, host string, names []string) (interfac panic("addr isn't a string!") } - val, found = row.Get("ttl") - if !found { - return nil, errors.New("missing ttl field in result") - } - ttl, ok := val.(int64) - if !ok { - panic("ttl isn't an integer!") - } - - val, found = row.Get("name") - if !found { - return nil, errors.New("missing name field in result") - } - name, ok := val.(string) - if !ok { - panic("name isn't a string!") - } - - t, found := tableMap[name] - if !found { - t = table{ - db: name, - ttl: time.Duration(ttl) * time.Second, - } - } - val, found = row.Get("role") if !found { return nil, errors.New("missing role field in result") @@ -398,71 +294,103 @@ func routingTableTx(tx neo4j.Transaction, host string, names []string) (interfac switch role { case "READ": - t.readers = append(t.readers, addr) + t.Readers = append(t.Readers, addr) case "WRITE": - t.writers = append(t.writers, addr) + t.Writers = append(t.Writers, addr) case "ROUTE": - continue + t.Routers = append(t.Routers, addr) default: return nil, errors.New("invalid role") } - - tableMap[name] = t } - return tableMap, nil + return t, nil } // Using a pointer to a connected neo4j.Driver, orchestrate fetching the -// database names and get the current routing table for each. -// -// XXX: this is pretty heavy weight :-( -func getNewRoutingTable(driver *neo4j.Driver, host string) (*RoutingTable, error) { - names, err := queryDbNames(driver) - if err != nil { - msg := fmt.Sprintf("error getting database names: %v\n", err) - return nil, errors.New(msg) - } +// routing table for a given database while using the provided host +// routing context. +func getRoutingTable(driver *neo4j.Driver, db, host string) (RoutingTable, error) { + session := (*driver).NewSession(neo4j.SessionConfig{}) + defer session.Close() - tableMap, err := queryRoutingTable(driver, host, names) + result, err := session.ReadTransaction(func(tx neo4j.Transaction) (interface{}, error) { + return routingTableTx(tx, host, db) + }) if err != nil { - msg := fmt.Sprintf("error getting routing table: %v\n", err) - return nil, errors.New(msg) + return RoutingTable{}, err } - - // build the new routing table instance - // TODO: clean this up...seems smelly.. - readers := make(map[string][]string) - writers := make(map[string][]string) - rt := RoutingTable{ - DefaultDb: names[0], - readers: readers, - writers: writers, - CreatedAt: time.Now(), - Hosts: make(map[string]bool), + table, ok := result.(RoutingTable) + if !ok { + panic("invalid return type: expected RoutingTable") } - for db, t := range tableMap { - r := make([]string, len(t.readers)) - copy(r, t.readers) - w := make([]string, len(t.writers)) - copy(w, t.writers) - rt.readers[db] = r - rt.writers[db] = w - - // yes, this is redundant... - rt.Ttl = t.ttl - - // yes, this is also wasteful...construct host sets - for _, host := range r { - rt.Hosts[host] = true + + return table, nil +} + +// Populate a ClusterInfo instance with critical details on our backend +func getClusterInfo(driver *neo4j.Driver, host string) (ClusterInfo, error) { + session := (*driver).NewSession(neo4j.SessionConfig{ + DatabaseName: "system", + }) + defer session.Close() + + // Inline TX function here for building the ClusterInfo on the fly + result, err := session.ReadTransaction(func(tx neo4j.Transaction) (interface{}, error) { + info := ClusterInfo{CreatedAt: time.Now()} + result, err := tx.Run("SHOW DATABASES", nil) + if err != nil { + return info, err } - for _, host := range w { - rt.Hosts[host] = true + rows, err := result.Collect() + if err != nil { + return info, err + } + + for _, row := range rows { + val, found := row.Get("default") + if !found { + return info, errors.New("missing 'default' field") + } + defaultDb, ok := val.(bool) + if !ok { + return info, errors.New("default field isn't a boolean") + } + if defaultDb { + val, found = row.Get("name") + if !found { + return info, errors.New("missing 'name' field") + } + name, ok := val.(string) + if !ok { + return info, errors.New("name field isn't a string") + } + info.DefaultDb = name + } } + + return info, nil + }) + if err != nil { + return ClusterInfo{}, err } - // log.Printf("updated routing table: %s\n", &rt) - // log.Printf("known hosts look like: %v\n", rt.Hosts) + info, ok := result.(ClusterInfo) + if !ok { + panic("result isn't a ClusterInfo struct") + } - return &rt, nil + // For now get details for System db... + rt, err := getRoutingTable(driver, "system", host) + if err != nil { + return info, err + } + hosts := map[string]bool{} + for _, host := range append(rt.Readers, rt.Writers...) { + hosts[host] = true + } + for host := range hosts { + info.Hosts = append(info.Hosts, host) + } + return info, nil } diff --git a/backend/routing.go b/backend/routing.go index c24a52a..227ebf1 100644 --- a/backend/routing.go +++ b/backend/routing.go @@ -1,52 +1,41 @@ package backend import ( - "errors" "fmt" "time" ) type RoutingTable struct { - readers map[string][]string - writers map[string][]string - Hosts map[string]bool - DefaultDb string - Ttl time.Duration - CreatedAt time.Time + Name string + Readers, Writers, Routers []string + CreatedAt time.Time + Ttl time.Duration } -func (rt *RoutingTable) Expired() bool { - now := time.Now() - return rt.CreatedAt.Add(rt.Ttl).Before(now) +func (t RoutingTable) String() string { + return fmt.Sprintf("DbTable{ Name: %s, "+ + "Readers: %s, "+ + "Writers: %s, "+ + "Routers: %s, "+ + "CreatedAt: %s, "+ + "Ttl: %s}", + t.Name, t.Readers, t.Writers, t.Routers, + t.CreatedAt, t.Ttl) } -func (rt *RoutingTable) String() string { - return fmt.Sprintf( - "RoutingTable{ DefaultDb: %s, "+ - "readerMap: %v, "+ - "writerMap: %v, "+ - "Ttl: %s, "+ - "CreatedAt: %s }", - rt.DefaultDb, rt.readers, rt.writers, rt.Ttl, rt.CreatedAt, - ) +type ClusterInfo struct { + DefaultDb string + Ttl time.Duration + Hosts []string + CreatedAt time.Time } -func (rt *RoutingTable) ReadersFor(db string) ([]string, error) { - readers, found := rt.readers[db] - if found { - result := make([]string, len(readers)) - copy(result, readers) - return result, nil - } - return nil, errors.New("no such database") +func (rt RoutingTable) Expired() bool { + return rt.CreatedAt.Add(rt.Ttl).Before(time.Now()) } -func (rt *RoutingTable) WritersFor(db string) ([]string, error) { - writers, found := rt.writers[db] - if found { - result := make([]string, len(writers)) - copy(result, writers) - return result, nil - } - return nil, errors.New("no such database") +func (i ClusterInfo) String() string { + return fmt.Sprintf( + "ClusterInfo{ DefaultDb: %s, Ttl: %v, Hosts: %v, CreatedAt: %v }", + i.DefaultDb, i.Ttl, i.Hosts, i.CreatedAt) } diff --git a/proxy.go b/proxy.go index e18d401..6184e25 100644 --- a/proxy.go +++ b/proxy.go @@ -277,7 +277,8 @@ func handleBoltConn(client bolt.BoltConn, clientVersion []byte, b *backend.Backe // get backend connection pool, err := b.Authenticate(hello) if err != nil { - warn.Fatal(err) + warn.Println(err) + return } // TODO: this seems odd...move parser and version stuff to bolt pkg @@ -360,14 +361,17 @@ func handleBoltConn(client bolt.BoltConn, clientVersion []byte, b *backend.Backe // we need to find a new connection to switch to if startingTx { mode, _ := bolt.ValidateMode(msg.Data) - rt := b.RoutingTable() - db := rt.DefaultDb + info, err := b.ClusterInfo() + if err != nil { + warn.Printf("error getting cluster info: %s\n", err) + return + } + db := info.DefaultDb // get the db name, if any. otherwise, use default var ( - m map[string]interface{} - err error - n int + m map[string]interface{} + n int ) if msg.T == bolt.BeginMsg { m, _, err = bolt.ParseMap(msg.Data[4:]) @@ -416,11 +420,16 @@ func handleBoltConn(client bolt.BoltConn, clientVersion []byte, b *backend.Backe } // Just choose the first one for now...something simple + rt, err := b.RoutingTable(db) + if err != nil { + warn.Printf("error getting routing table for %s: %s\n", db, err) + return + } var hosts []string if mode == bolt.ReadMode { - hosts, err = rt.ReadersFor(db) + hosts = rt.Readers } else { - hosts, err = rt.WritersFor(db) + hosts = rt.Writers } if err != nil { warn.Printf("couldn't find host for '%s' in routing table", db)