diff --git a/README.md b/README.md index ac7a94b..ccbc13f 100644 --- a/README.md +++ b/README.md @@ -136,7 +136,7 @@ The role of the hologram server must have assume role permissions. See permissi For different projects it is recommended that you create IAM roles for each and have your developers assume these roles for testing the software. Hologram supports a command `hologram use ` which will fetch temporary credentials for this role instead of the default developer one until it is reset or another role is assumed. -You will need to modify the Trusted Entities for each of these roles that you create so that the IAM instance profile you created for the Hologram Server can access them. The hologram user must have permission to assume that role. +You will need to modify the Trusted Entities for each of these roles that you create so that the IAM instance profile you created for the Hologram Server can access them. The hologram user must have permission to assume that role. ```json { @@ -158,7 +158,7 @@ The user must have permission to iam:GetUser on itself(resource "arn:aws:iam::AC ### LDAP Based Roles -Hologram supports assigning roles based on a user's LDAP group. Roles can be turned on by setting the `enableLDAPRoles` key to `true` in `config/server.json`. +Hologram supports assigning roles based on a user's LDAP group. Roles can be turned on by setting the `enableServerRoles` key to `true` in `config/server.json`. An LDAP group attribute will have to be chosen for user roles. By default `businessCategory` is chosen for this role since it is part of the core LDAP schema. The attribute used can be modified by editing the `roleAttribute` key in `config/server.json`. The value of this attribute should be the name of the group's role in AWS. diff --git a/cmd/hologram-server/config.go b/cmd/hologram-server/config.go index fa2ebfd..52562cc 100644 --- a/cmd/hologram-server/config.go +++ b/cmd/hologram-server/config.go @@ -18,22 +18,32 @@ type LDAP struct { DN string `json:"dn"` Password string `json:"password"` } `json:"bind"` - UserAttr string `json:"userattr"` - BaseDN string `json:"basedn"` - Host string `json:"host"` - InsecureLDAP bool `json:"insecureldap"` + UserAttr string `json:"userattr"` + BaseDN string `json:"basedn"` + Host string `json:"host"` EnableLDAPRoles bool `json:"enableldaproles"` + InsecureLDAP bool `json:"insecureldap"` RoleAttribute string `json:"roleattr"` DefaultRoleAttr string `json:"defaultroleattr"` } +type KeysFile struct { + FilePath string `json:"filepath"` + UserAttr string `json:"userattr"` + RoleAttr string `json:"roleattr"` + DefaultRoleAttr string `json:"defaultroleattr"` +} + type Config struct { - LDAP LDAP `json:"ldap"` - AWS struct { + UserStorage string `json:"userstorage"` + LDAP LDAP `json:"ldap"` + KeysFile KeysFile `json:"keysfile"` + AWS struct { Account string `json:"account"` DefaultRole string `json:"defaultrole"` } `json:"aws"` - Stats string `json:"stats"` - Listen string `json:"listen"` - CacheTimeout int `json:"cachetimeout"` + EnableServerRoles bool `json:"enableserverroles"` + Stats string `json:"stats"` + Listen string `json:"listen"` + CacheTimeout int `json:"cachetimeout"` } diff --git a/cmd/hologram-server/main.go b/cmd/hologram-server/main.go index e62eb69..70ac3ed 100644 --- a/cmd/hologram-server/main.go +++ b/cmd/hologram-server/main.go @@ -35,6 +35,11 @@ import ( "github.com/peterbourgon/g2s" ) +const ( + LDAPUserStorage = "ldap" + FileUserStorage = "file" +) + func ConnectLDAP(conf LDAP) (*ldap.Conn, error) { var ldapServer *ldap.Conn var err error @@ -72,14 +77,17 @@ func main() { ldapBindPassword = flag.String("ldapBindPassword", "", "LDAP password for bind.") statsdHost = flag.String("statsHost", "", "Address to send statsd metrics to.") iamAccount = flag.String("iamaccount", "", "AWS Account ID for generating IAM Role ARNs") - enableLDAPRoles = flag.Bool("ldaproles", false, "Enable role support using LDAP directory.") - roleAttribute = flag.String("roleattribute", "", "Group attribute to get role from.") - defaultRoleAttr = flag.String("defaultroleattr", "", "User attribute to check to determine a user's default role.") - defaultRole = flag.String("role", "", "AWS role to assume by default.") - configFile = flag.String("conf", "/etc/hologram/server.json", "Config file to load.") - cacheTimeout = flag.Int("cachetime", 3600, "Time in seconds after which to refresh LDAP user cache.") - debugMode = flag.Bool("debug", false, "Enable debug mode.") - config Config + // Still here for backwards compatibility + enableLDAPRoles = flag.Bool("ldaproles", false, "Enable role support using LDAP directory (DEPRECATED: Use enableServerRoles instead).") + enableServerRoles = flag.Bool("serverRoles", false, "Enable role support using server directory.") + roleAttribute = flag.String("roleattribute", "", "Group attribute to get role from.") + defaultRoleAttr = flag.String("defaultroleattr", "", "User attribute to check to determine a user's default role.") + defaultRole = flag.String("role", "", "AWS role to assume by default.") + userStorage = flag.String("userStorage", LDAPUserStorage, "User storage type (ldap, file)") + configFile = flag.String("conf", "/etc/hologram/server.json", "Config file to load.") + cacheTimeout = flag.Int("cachetime", 3600, "Time in seconds after which to refresh LDAP user cache.") + debugMode = flag.Bool("debug", false, "Enable debug mode.") + config Config ) flag.Parse() @@ -105,6 +113,16 @@ func main() { } // Merge in command flag options. + if config.UserStorage == "" { + config.UserStorage = *userStorage + } + + // Validating user storage value + if config.UserStorage != LDAPUserStorage && config.UserStorage != FileUserStorage { + log.Errorf("Invalid user storage value: %s. Possible values (%s, %s)", config.UserStorage, LDAPUserStorage, FileUserStorage) + os.Exit(1) + } + if *ldapAddress != "" { config.LDAP.Host = *ldapAddress } @@ -137,10 +155,14 @@ func main() { config.AWS.DefaultRole = *defaultRole } - if *enableLDAPRoles { + if *enableServerRoles || *enableLDAPRoles { config.LDAP.EnableLDAPRoles = true } + if config.LDAP.EnableLDAPRoles || *enableLDAPRoles { + config.EnableServerRoles = true + } + if *defaultRoleAttr != "" { config.LDAP.DefaultRoleAttr = *defaultRoleAttr } @@ -177,21 +199,35 @@ func main() { stsConnection := sts.New(session.New(&aws.Config{})) credentialsService := server.NewDirectSessionTokenService(config.AWS.Account, stsConnection) - open := func() (server.LDAPImplementation, error) { return ConnectLDAP(config.LDAP) } - ldapServer, err := server.NewPersistentLDAP(open) - if err != nil { - log.Errorf("Fatal error, exiting: %s", err.Error()) - os.Exit(1) - } + var ( + userCache server.UserCache + userStorageImpl server.UserStorage + ) - ldapCache, err := server.NewLDAPUserCache(ldapServer, stats, config.LDAP.UserAttr, config.LDAP.BaseDN, config.LDAP.EnableLDAPRoles, config.LDAP.RoleAttribute, config.AWS.DefaultRole, config.LDAP.DefaultRoleAttr) - if err != nil { - log.Errorf("Top-level error in LDAPUserCache layer: %s", err.Error()) - os.Exit(1) + if config.UserStorage == LDAPUserStorage { + open := func() (server.LDAPImplementation, error) { return ConnectLDAP(config.LDAP) } + userStorageImpl, err := server.NewPersistentLDAP(open) + if err != nil { + log.Errorf("Fatal error, exiting: %s", err.Error()) + os.Exit(1) + } + + userCache, err = server.NewLDAPUserCache(userStorageImpl, stats, config.LDAP.UserAttr, config.LDAP.BaseDN, config.EnableServerRoles, config.LDAP.RoleAttribute, config.AWS.DefaultRole, config.LDAP.DefaultRoleAttr) + if err != nil { + log.Errorf("Top-level error in LDAPUserCache layer: %s", err.Error()) + os.Exit(1) + } + } else if config.UserStorage == FileUserStorage { + open := func() ([]byte, error) { return ioutil.ReadFile(config.KeysFile.FilePath) } + dump := func(data []byte) error { + return ioutil.WriteFile(config.KeysFile.FilePath, data, os.FileMode(500)) + } + userStorageImpl := server.NewPersistentKeysFile(open, dump, config.KeysFile.UserAttr, config.KeysFile.RoleAttr) + userCache, _ = server.NewKeysFileUserCache(userStorageImpl, stats, config.EnableServerRoles, config.KeysFile.UserAttr, config.KeysFile.RoleAttr, config.AWS.DefaultRole, config.KeysFile.DefaultRoleAttr) } - serverHandler := server.New(ldapCache, credentialsService, config.AWS.DefaultRole, stats, ldapServer, config.LDAP.UserAttr, config.LDAP.BaseDN, config.LDAP.EnableLDAPRoles, config.LDAP.DefaultRoleAttr) - server, err := remote.NewServer(config.Listen, serverHandler.HandleConnection) + serverHandler := server.New(userCache, credentialsService, config.AWS.DefaultRole, stats, userStorageImpl, config.EnableServerRoles) + server, _ := remote.NewServer(config.Listen, serverHandler.HandleConnection) // Wait for a signal from the OS to shutdown. terminate := make(chan os.Signal) @@ -227,10 +263,12 @@ WaitForTermination: log.DebugMode(false) case <-reloadCacheSigHup: log.Info("Force-reloading user cache.") - ldapCache.Update() + err := userCache.Update() + log.Errorf("Error while updating cache: %s", err.Error()) case <-cacheTimeoutTicker.C: - log.Info("Cache timeout. Reloading user cache.") - ldapCache.Update() + err := userCache.Update() + log.Errorf("Error while updating cache: %s", err.Error()) + } } diff --git a/protocol/hologram.proto b/protocol/hologram.proto index ac165ec..bebb617 100644 --- a/protocol/hologram.proto +++ b/protocol/hologram.proto @@ -51,7 +51,7 @@ message ServerRequest { SSHChallengeResponse challengeResponse = 5; MFATokenResponse tokenResponse = 6; GetUserCredentials getUserCredentials = 7; - AddSSHKey addSSHkey = 8; + AddSSHKey addSSHkey = 8; } } diff --git a/server/credentials.go b/server/credentials.go index 3a94286..697d558 100644 --- a/server/credentials.go +++ b/server/credentials.go @@ -30,7 +30,7 @@ credentials to calling processes. No caching is done of these results other than that which the CredentialService does itself. */ type CredentialService interface { - AssumeRole(user *User, role string, enableLDAPRoles bool) (*sts.Credentials, error) + AssumeRole(user *User, role string, enableServerRoles bool) (*sts.Credentials, error) } /* @@ -77,12 +77,12 @@ func (s *directSessionTokenService) buildARN(role string) string { return arn } -func (s *directSessionTokenService) AssumeRole(user *User, role string, enableLDAPRoles bool) (*sts.Credentials, error) { +func (s *directSessionTokenService) AssumeRole(user *User, role string, enableServerRoles bool) (*sts.Credentials, error) { var arn string = s.buildARN(role) log.Debug("Checking ARN %s against user %s (with access %s)", arn, user.Username, user.ARNs) - if enableLDAPRoles { + if enableServerRoles { found := false for _, a := range user.ARNs { a = s.buildARN(a) diff --git a/server/persistent_keys_file.go b/server/persistent_keys_file.go new file mode 100644 index 0000000..be45b4d --- /dev/null +++ b/server/persistent_keys_file.go @@ -0,0 +1,98 @@ +package server + +import ( + "encoding/json" + "errors" + "fmt" +) + +type KeysMap map[string]map[string]interface{} + +type persistentKeysFile struct { + // Function that return the contents of the file + open func() ([]byte, error) + // Function to dump contents to the file + dump func([]byte) error + + userAttr string + roleAttr string + // Map from public ssh keys to a list of roles + keys KeysMap +} + +func (pkf *persistentKeysFile) Load() error { + fileContent, err := pkf.open() + if err != nil { + return err + } + + var keys KeysMap + + if err := json.Unmarshal(fileContent, &keys); err != nil { + return err + } + + pkf.keys = keys + + return nil +} + +func (pkf *persistentKeysFile) Keys() (KeysMap, error) { + if pkf.keys == nil { + err := pkf.Load() + if err != nil { + return nil, err + } + } + return pkf.keys, nil +} + +func (pkf *persistentKeysFile) Search(username string) (map[string]interface{}, error) { + if pkf.keys == nil { + err := pkf.Load() + if err != nil { + return nil, err + } + } + + data := map[string]interface{}{ + "username": username, + "password": "", + } + + sshPublicKeys := []string{} + + found := false + for key, userData := range pkf.keys { + u, _ := userData[pkf.userAttr] + user, _ := u.(string) + password, _ := userData["password"] + passwordHash, _ := password.(string) + if user == username { + sshPublicKeys = append(sshPublicKeys, key) + data["password"] = passwordHash + found = true + } + } + if found { + data["sshPublicKeys"] = sshPublicKeys + return data, nil + } + + return nil, errors.New(fmt.Sprintf("User %s not found!", username)) +} + +func (pkf *persistentKeysFile) SearchUser(userData map[string]string) (map[string]interface{}, error) { + return pkf.Search(userData["username"]) +} + +func (pkf *persistentKeysFile) Modify(username, sshPublicKey string) error { + pkf.keys[sshPublicKey] = map[string]interface{}{"username": username} + + keysBytes, _ := json.Marshal(pkf.keys) + return pkf.dump(keysBytes) // Dump contents of keys +} + +func NewPersistentKeysFile(open func() ([]byte, error), dump func([]byte) error, userAttr, roleAttr string) KeysFile { + return &persistentKeysFile{open: open, dump: dump, userAttr: userAttr, roleAttr: roleAttr} +} diff --git a/server/persistent_keys_file_test.go b/server/persistent_keys_file_test.go new file mode 100644 index 0000000..1cbf683 --- /dev/null +++ b/server/persistent_keys_file_test.go @@ -0,0 +1,58 @@ +package server_test + +import ( + "sort" + "testing" + + "github.com/AdRoll/hologram/server" + . "github.com/smartystreets/goconvey/convey" +) + +func TestPersistentKeysFile(t *testing.T) { + data := `{ + "KEY1": {"username": "user1", "password": "pass1", "roles": ["role1", "role11"]}, + "KEY2": {"username": "user2", "password": "pass2", "roles": ["role2", "role22"]}, + "KEY3": {"username": "user1", "password": "pass1", "roles": ["role111", "role1111"]} + }` + + open := func() ([]byte, error) { + return []byte(data), nil + } + + dump := func([]byte) error { + return nil + } + + Convey("Given data from keys file", t, func() { + Convey("Content from file should be loaded correctly", func() { + pkf := server.NewPersistentKeysFile(open, dump, "username", "roles") + err := pkf.Load() + So(err, ShouldBeNil) + }) + + Convey("An existing key in file should be found", func() { + pkf := server.NewPersistentKeysFile(open, dump, "username", "roles") + + keys := []string{"KEY3", "KEY1"} + sort.Strings(keys) + expected := map[string]interface{}{ + "username": "user1", + "sshPublicKeys": keys, + "password": "pass1", + } + actual, err := pkf.Search("user1") + actualKeys := actual["sshPublicKeys"] + sort.Strings(actualKeys.([]string)) + So(err, ShouldBeNil) + So(actual, ShouldResemble, expected) + }) + + Convey("An non existing key in file shouldn't be found", func() { + pkf := server.NewPersistentKeysFile(open, dump, "username", "roles") + + user, err := pkf.Search("missing user") + So(err, ShouldNotBeNil) + So(user, ShouldBeNil) + }) + }) +} diff --git a/server/persistent_ldap.go b/server/persistent_ldap.go index d812876..cba8934 100644 --- a/server/persistent_ldap.go +++ b/server/persistent_ldap.go @@ -1,12 +1,17 @@ package server import ( + "errors" + "fmt" + "github.com/nmcclain/ldap" ) type persistentLDAP struct { - open func() (LDAPImplementation, error) - conn LDAPImplementation + open func() (LDAPImplementation, error) + conn LDAPImplementation + baseDN string + userAttr string } func (pl *persistentLDAP) Refresh() error { @@ -28,8 +33,30 @@ func (pl *persistentLDAP) Search(searchRequest *ldap.SearchRequest) (*ldap.Searc } } +func (pl *persistentLDAP) SearchUser(userData map[string]string) (map[string]interface{}, error) { + sr := ldap.NewSearchRequest( + pl.baseDN, + ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false, + fmt.Sprintf("(%s=%s)", pl.userAttr, userData["username"]), + []string{"sshPublicKey", pl.userAttr, "userPassword"}, + nil) + r, err := pl.Search(sr) + if err != nil { + return nil, err + } + + if len(r.Entries) == 0 { + return nil, errors.New(fmt.Sprintf("User %s not found!", userData["username"])) + } + + return map[string]interface{}{ + "password": r.Entries[0].GetAttributeValue("userPassword"), + "sshPublicKeys": r.Entries[0].GetAttributeValues("sshPublicKey"), + }, nil +} + func (pl *persistentLDAP) Modify(modifyRequest *ldap.ModifyRequest) error { - if err := pl.conn.Modify(modifyRequest); err != nil && err.(*ldap.Error).ResultCode == ldap.ErrorNetwork { + if err := pl.conn.Modify(modifyRequest); err != nil && err.(*ldap.Error).ResultCode == ldap.ErrorNetwork { pl.Refresh() return pl.conn.Modify(modifyRequest) } else { @@ -37,6 +64,12 @@ func (pl *persistentLDAP) Modify(modifyRequest *ldap.ModifyRequest) error { } } +func (pl *persistentLDAP) ModifyUser(data map[string]string) error { + mr := ldap.NewModifyRequest(data["DN"]) + mr.Add("sshPublicKey", []string{data["sshPublicKey"]}) + return pl.Modify(mr) +} + func NewPersistentLDAP(open func() (LDAPImplementation, error)) (LDAPImplementation, error) { conn, err := open() if err != nil { diff --git a/server/persistent_ldap_test.go b/server/persistent_ldap_test.go index 18be030..88703ea 100644 --- a/server/persistent_ldap_test.go +++ b/server/persistent_ldap_test.go @@ -12,7 +12,7 @@ import ( // A server that fails after every call to Search/Modify! type FallibleLDAPServer struct { underlying *StubLDAPServer - dead bool + dead bool } func (fls *FallibleLDAPServer) Search(s *ldap.SearchRequest) (*ldap.SearchResult, error) { @@ -31,7 +31,6 @@ func (fls *FallibleLDAPServer) Modify(m *ldap.ModifyRequest) error { return fls.underlying.Modify(m) } - func TestPersistentLDAP(t *testing.T) { connWillFail := false @@ -59,7 +58,7 @@ func TestPersistentLDAP(t *testing.T) { expected, err := s.Search(nil) So(err, ShouldBeNil) actual, err := ldapServer.Search(nil) - So(err, ShouldBeNil) + So(err, ShouldBeNil) So(expected, ShouldResemble, actual) }) @@ -67,7 +66,7 @@ func TestPersistentLDAP(t *testing.T) { expected, err := s.Search(nil) So(err, ShouldBeNil) actual, err := ldapServer.Search(nil) - So(err, ShouldBeNil) + So(err, ShouldBeNil) So(expected, ShouldResemble, actual) }) @@ -81,7 +80,7 @@ func TestPersistentLDAP(t *testing.T) { Convey("An initially broken connection to an LDAP server should fail fast", t, func() { ldapServer, err = server.NewPersistentLDAP(open) - So(err, ShouldNotBeNil) + So(err, ShouldNotBeNil) So(ldapServer, ShouldBeNil) }) } diff --git a/server/server.go b/server/server.go index 3207902..9645c3f 100644 --- a/server/server.go +++ b/server/server.go @@ -23,7 +23,6 @@ import ( "github.com/AdRoll/hologram/log" "github.com/AdRoll/hologram/protocol" "github.com/aws/aws-sdk-go/service/sts" - "github.com/nmcclain/ldap" "github.com/peterbourgon/g2s" "golang.org/x/crypto/ssh" ) @@ -37,16 +36,18 @@ server is a wrapper for all of the connection and message handlers that this server implements. */ type server struct { - authenticator Authenticator - userCache UserCache - credentials CredentialService - stats g2s.Statter - defaultRole string - ldapServer LDAPImplementation - userAttr string - baseDN string - enableLDAPRoles bool - defaultRoleAttr string + authenticator Authenticator + userCache UserCache + credentials CredentialService + stats g2s.Statter + defaultRole string + userStorage UserStorage + enableServerRoles bool +} + +type UserStorage interface { + SearchUser(map[string]string) (map[string]interface{}, error) + ModifyUser(map[string]string) error } /* @@ -111,11 +112,11 @@ func (sm *server) HandleServerRequest(m protocol.MessageReadWriteCloser, r *prot } if user != nil { - creds, err := sm.credentials.AssumeRole(user, role, sm.enableLDAPRoles) + creds, err := sm.credentials.AssumeRole(user, role, sm.enableServerRoles) if err != nil { // Update user cache and try again sm.userCache.Update() - creds, err := sm.credentials.AssumeRole(user, role, sm.enableLDAPRoles) + creds, err := sm.credentials.AssumeRole(user, role, sm.enableServerRoles) if err != nil { // error message from Amazon, so forward that on to the client @@ -128,7 +129,7 @@ func (sm *server) HandleServerRequest(m protocol.MessageReadWriteCloser, r *prot sm.stats.Counter(1.0, "errors.assumeRole", 1) // Attempt to use the default role to fall back - creds, err = sm.credentials.AssumeRole(user, user.DefaultRole, sm.enableLDAPRoles) + creds, err = sm.credentials.AssumeRole(user, user.DefaultRole, sm.enableServerRoles) if err == nil { m.Write(makeCredsResponse(creds)) } @@ -148,12 +149,16 @@ func (sm *server) HandleServerRequest(m protocol.MessageReadWriteCloser, r *prot } if user != nil { - creds, err := sm.credentials.AssumeRole(user, user.DefaultRole, sm.enableLDAPRoles) + creds, err := sm.credentials.AssumeRole(user, user.DefaultRole, sm.enableServerRoles) if err != nil { log.Errorf("Error trying to handle GetUserCredentials: %s", err.Error()) // Update user cache and try again - sm.userCache.Update() - creds, err = sm.credentials.AssumeRole(user, user.DefaultRole, sm.enableLDAPRoles) + err := sm.userCache.Update() + if err != nil { + log.Errorf("Error trying to update cache: %s", err.Error()) + return + } + creds, err = sm.credentials.AssumeRole(user, user.DefaultRole, sm.enableServerRoles) if err != nil { errStr := fmt.Sprintf("Could not get user credentials. %s may not have been given Hologram access yet.", user.Username) errMsg := &protocol.Message{ @@ -169,35 +174,29 @@ func (sm *server) HandleServerRequest(m protocol.MessageReadWriteCloser, r *prot } } else if addSSHKeyMsg := r.GetAddSSHkey(); addSSHKeyMsg != nil { sm.stats.Counter(1.0, "messages.addSSHKeyMsg", 1) - // Search for the user specified in this request. - sr := ldap.NewSearchRequest( - sm.baseDN, - ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false, - fmt.Sprintf("(%s=%s)", sm.userAttr, addSSHKeyMsg.GetUsername()), - []string{"sshPublicKey", sm.userAttr, "userPassword"}, - nil) - - user, err := sm.ldapServer.Search(sr) + userData := map[string]string{ + "username": addSSHKeyMsg.GetUsername(), + } + resp, err := sm.userStorage.SearchUser(userData) if err != nil { log.Errorf("Error trying to handle addSSHKeyMsg: %s", err.Error()) return } - if len(user.Entries) == 0 { - log.Errorf("User %s not found!", addSSHKeyMsg.GetUsername()) - return - } - // Check their password. - password := user.Entries[0].GetAttributeValue("userPassword") - if password != addSSHKeyMsg.GetPasswordhash() { + password := resp["password"] + passwordHash, _ := password.(string) + if passwordHash != addSSHKeyMsg.GetPasswordhash() { log.Errorf("Provided password for user %s does not match %s!", addSSHKeyMsg.GetUsername(), password) return } + sshKeys := resp["sshPublicKeys"] + sshPublicKeys, _ := sshKeys.([]string) + // Check to see if this SSH key already exists. - for _, k := range user.Entries[0].GetAttributeValues("sshPublicKey") { + for _, k := range sshPublicKeys { if k == addSSHKeyMsg.GetSshkeybytes() { log.Warning("User %s already has this SSH key. Doing nothing.", addSSHKeyMsg.GetUsername()) successMsg := &protocol.Message{Success: &protocol.Success{}} @@ -206,11 +205,15 @@ func (sm *server) HandleServerRequest(m protocol.MessageReadWriteCloser, r *prot } } - mr := ldap.NewModifyRequest(user.Entries[0].DN) - mr.Add("sshPublicKey", []string{addSSHKeyMsg.GetSshkeybytes()}) - err = sm.ldapServer.Modify(mr) + modifyData := map[string]string{} + for k, v := range resp { + strVal, _ := v.(string) + modifyData[k] = strVal + } + modifyData["sshPublicKey"] = addSSHKeyMsg.GetSshkeybytes() + err = sm.userStorage.ModifyUser(modifyData) if err != nil { - log.Errorf("Could not modify LDAP user: %s", err.Error()) + log.Errorf("Could not modify user: %s", err.Error()) return } @@ -307,21 +310,15 @@ func New(userCache UserCache, credentials CredentialService, defaultRole string, stats g2s.Statter, - ldapServer LDAPImplementation, - userAttr string, - baseDN string, - enableLDAPRoles bool, - defaultRoleAttr string) *server { + userStorage UserStorage, + enableServerRoles bool) *server { return &server{ - credentials: credentials, - authenticator: userCache, - userCache: userCache, - defaultRole: defaultRole, - stats: stats, - ldapServer: ldapServer, - userAttr: userAttr, - baseDN: baseDN, - enableLDAPRoles: enableLDAPRoles, - defaultRoleAttr: defaultRoleAttr, + credentials: credentials, + authenticator: userCache, + userCache: userCache, + defaultRole: defaultRole, + stats: stats, + userStorage: userStorage, + enableServerRoles: enableServerRoles, } } diff --git a/server/server_test.go b/server/server_test.go index f0b367a..e61d028 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -76,7 +76,7 @@ func (*dummyCredentials) GetSessionToken(user *server.User) (*sts.Credentials, e }, nil } -func (*dummyCredentials) AssumeRole(user *server.User, role string, enableLDAPRoles bool) (*sts.Credentials, error) { +func (*dummyCredentials) AssumeRole(user *server.User, role string, enableServerRoles bool) (*sts.Credentials, error) { accessKey := "access_key" secretKey := "secret" token := "token" @@ -119,6 +119,14 @@ func (l *DummyLDAP) Search(*ldap.SearchRequest) (*ldap.SearchResult, error) { }, nil } +func (l *DummyLDAP) SearchUser(userData map[string]string) (map[string]interface{}, error) { + r, _ := l.Search(nil) + return map[string]interface{}{ + "password": r.Entries[0].GetAttributeValue("userPassword"), + "sshPublicKeys": r.Entries[0].GetAttributeValues("sshPublicKey"), + }, nil +} + func (l *DummyLDAP) Modify(mr *ldap.ModifyRequest) error { if reflect.DeepEqual(mr, l.req) { l.sshKeys = []string{"test"} @@ -126,6 +134,10 @@ func (l *DummyLDAP) Modify(mr *ldap.ModifyRequest) error { return nil } +func (l *DummyLDAP) ModifyUser(data map[string]string) error { + return l.Modify(l.req) +} + func TestServerStateMachine(t *testing.T) { // This silly thing is needed for equality testing for the LDAP dummy. neededModifyRequest := ldap.NewModifyRequest("something") @@ -139,7 +151,7 @@ func TestServerStateMachine(t *testing.T) { sshKeys: []string{}, req: neededModifyRequest, } - testServer := server.New(authenticator, &dummyCredentials{}, "default", g2s.Noop(), ldap, "cn", "dc=testdn,dc=com", false, "") + testServer := server.New(authenticator, &dummyCredentials{}, "default", g2s.Noop(), ldap, false) r, w := io.Pipe() testConnection := protocol.NewMessageConnection(ReadWriter(r, w)) diff --git a/server/usercache.go b/server/usercache.go index 89e0f70..8bc07ae 100644 --- a/server/usercache.go +++ b/server/usercache.go @@ -56,16 +56,16 @@ type LDAPImplementation interface { ldapUserCache connects to LDAP and pulls user settings from it. */ type ldapUserCache struct { - users map[string]*User - groups map[string][]string - server LDAPImplementation - stats g2s.Statter - userAttr string - baseDN string - enableLDAPRoles bool - roleAttribute string - defaultRole string - defaultRoleAttr string + users map[string]*User + groups map[string][]string + server LDAPImplementation + stats g2s.Statter + userAttr string + baseDN string + enableServerRoles bool + roleAttribute string + defaultRole string + defaultRoleAttr string } /* @@ -77,7 +77,7 @@ been recently added to LDAP work, instead of requiring a server restart. */ func (luc *ldapUserCache) Update() error { start := time.Now() - if luc.enableLDAPRoles { + if luc.enableServerRoles { groupSearchRequest := ldap.NewSearchRequest( luc.baseDN, ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, @@ -132,7 +132,7 @@ func (luc *ldapUserCache) Update() error { userDefaultRole := luc.defaultRole arns := []string{} - if luc.enableLDAPRoles { + if luc.enableServerRoles { userDefaultRole = entry.GetAttributeValue(luc.defaultRoleAttr) if userDefaultRole == "" { userDefaultRole = luc.defaultRole @@ -176,9 +176,6 @@ func (luc *ldapUserCache) _verify(username string, challenge []byte, sshSig *ssh return nil, nil } -/* - - */ func (luc *ldapUserCache) Authenticate(username string, challenge []byte, sshSig *ssh.Signature) ( *User, error) { // Loop through all of the keys and attempt verification. @@ -189,7 +186,10 @@ func (luc *ldapUserCache) Authenticate(username string, challenge []byte, sshSig luc.stats.Counter(1.0, "ldapCacheMiss", 1) // We should update LDAP cache again to retry keys. - luc.Update() + err := luc.Update() + if err != nil { + return nil, err + } return luc._verify(username, challenge, sshSig) } return retUser, nil @@ -198,18 +198,18 @@ func (luc *ldapUserCache) Authenticate(username string, challenge []byte, sshSig /* NewLDAPUserCache returns a properly-configured LDAP cache. */ -func NewLDAPUserCache(server LDAPImplementation, stats g2s.Statter, userAttr string, baseDN string, enableLDAPRoles bool, roleAttribute string, defaultRole string, defaultRoleAttr string) (*ldapUserCache, error) { +func NewLDAPUserCache(server LDAPImplementation, stats g2s.Statter, userAttr string, baseDN string, enableServerRoles bool, roleAttribute string, defaultRole string, defaultRoleAttr string) (*ldapUserCache, error) { retCache := &ldapUserCache{ - users: map[string]*User{}, - groups: map[string][]string{}, - server: server, - stats: stats, - userAttr: userAttr, - baseDN: baseDN, - enableLDAPRoles: enableLDAPRoles, - roleAttribute: roleAttribute, - defaultRole: defaultRole, - defaultRoleAttr: defaultRoleAttr, + users: map[string]*User{}, + groups: map[string][]string{}, + server: server, + stats: stats, + userAttr: userAttr, + baseDN: baseDN, + enableServerRoles: enableServerRoles, + roleAttribute: roleAttribute, + defaultRole: defaultRole, + defaultRoleAttr: defaultRoleAttr, } updateError := retCache.Update() @@ -217,3 +217,139 @@ func NewLDAPUserCache(server LDAPImplementation, stats g2s.Statter, userAttr str // Start updating the user cache. return retCache, updateError } + +type KeysFile interface { + Search(string) (map[string]interface{}, error) + Load() error + Keys() (KeysMap, error) +} + +/* + keysFileUserCache read the file that contains public ssh keys and user info + . +*/ +type keysFileUserCache struct { + users map[string]*User + stats g2s.Statter + keysFile KeysFile + userAttr string + enableServerRoles bool + roleAttr string + defaultRole string + defaultRoleAttr string +} + +func (kfuc *keysFileUserCache) Update() error { + start := time.Now() + + users := map[string]*User{} + seenRoles := map[[2]string]bool{} + + err := kfuc.keysFile.Load() // Load keys from file + if err != nil { + return err + } + keys, err := kfuc.keysFile.Keys() + + if err != nil { + return err + } + + for key, userData := range keys { + username := userData[kfuc.userAttr].(string) + defaultRole, ok := userData[kfuc.defaultRoleAttr].(string) + if !ok || defaultRole == "" { + defaultRole = kfuc.defaultRole + } + user, found := users[username] + if !found { // Create a new user in the cache if doesn't exist + user = &User{ + Username: username, + SSHKeys: []ssh.PublicKey{}, + ARNs: []string{}, + DefaultRole: defaultRole, + } + } + + sshKeyBytes, _ := base64.StdEncoding.DecodeString(key) + sshKey, err := ssh.ParsePublicKey(sshKeyBytes) + if err != nil { + sshKey, _, _, _, err = ssh.ParseAuthorizedKey([]byte(key)) + if err != nil { + log.Warning("SSH key parsing for user %s failed (key was '%s')! This key will not be added into the keys file.", username, key) + continue + } + } + user.SSHKeys = append(user.SSHKeys, sshKey) + + if kfuc.enableServerRoles { + roles := userData[kfuc.roleAttr].([]interface{}) + for _, r := range roles { + role := r.(string) + if seenRoles[[2]string{username, role}] { + continue + } + user.ARNs = append(user.ARNs, role) + seenRoles[[2]string{username, role}] = true + } + } + + // Create/Update user + users[username] = user + } + + kfuc.users = users + + log.Debug("Keys file information re-cached.") + kfuc.stats.Timing(1.0, "keysFileCacheUpdate", time.Since(start)) + + return nil +} + +func (kfuc *keysFileUserCache) Users() map[string]*User { + return kfuc.users +} + +func (kfuc *keysFileUserCache) verify(challenge []byte, sshSig *ssh.Signature) (*User, error) { + for _, user := range kfuc.users { + for _, sshKey := range user.SSHKeys { + if err := sshKey.Verify(challenge, sshSig); err == nil { + return user, nil + } + } + } + return nil, nil +} + +func (kfuc *keysFileUserCache) Authenticate(username string, challenge []byte, sshSig *ssh.Signature) (*User, error) { + user, _ := kfuc.verify(challenge, sshSig) + + if user == nil { + log.Debug("Could not find %s in the keys file cache; updating from the file.", username) + kfuc.stats.Counter(1.0, "keysFileCacheMiss", 1) + + // We should update keys file cache again to retry keys. + err := kfuc.Update() + if err != nil { + return nil, err + } + return kfuc.verify(challenge, sshSig) + } + return user, nil +} + +func NewKeysFileUserCache(keysFile KeysFile, stats g2s.Statter, enableServerRoles bool, userAttr string, roleAttr string, defaultRole string, defaultRoleAttr string) (*keysFileUserCache, error) { + kfuc := &keysFileUserCache{ + users: map[string]*User{}, + stats: stats, + keysFile: keysFile, + userAttr: userAttr, + enableServerRoles: enableServerRoles, + roleAttr: roleAttr, + defaultRole: defaultRole, + defaultRoleAttr: defaultRoleAttr, + } + + err := kfuc.Update() + return kfuc, err +} diff --git a/server/usercache_test.go b/server/usercache_test.go index a13d152..4447f0d 100644 --- a/server/usercache_test.go +++ b/server/usercache_test.go @@ -217,3 +217,153 @@ func TestLDAPUserCache(t *testing.T) { }) }) } + +type stubKeysFile struct { + KeysData server.KeysMap +} + +func (skf *stubKeysFile) Load() error { + return nil +} + +func (skf *stubKeysFile) Keys() (server.KeysMap, error) { + return skf.KeysData, nil +} + +func (skf *stubKeysFile) Search(sshKey string) (map[string]interface{}, error) { + return nil, nil +} + +func TestKeysFileUserCache(t *testing.T) { + Convey("Given an keys file user cache connected to our server", t, func() { + // The SSH agent stuff was moved up here so that we can use it to + // dynamically create the LDAP result object. + sshSock := os.Getenv("SSH_AUTH_SOCK") + if sshSock == "" { + t.Skip() + } + + c, err := net.Dial("unix", sshSock) + if err != nil { + t.Fatal(err) + } + agent := agent.NewClient(c) + keys, err := agent.List() + if err != nil { + t.Fatal(err) + } + + keyValue := base64.StdEncoding.EncodeToString(keys[0].Blob) + + // Load in an additional key from the test data. + privateKey, _ := ssh.ParsePrivateKey(testKey) + testPublicKey := base64.StdEncoding.EncodeToString(privateKey.PublicKey().Marshal()) + + skf := &stubKeysFile{ + KeysData: map[string]map[string]interface{}{ + keyValue: { + "username": "user1", + "roles": []interface{}{"role1", "role11"}, + }, + testPublicKey: { + "username": "user1", + "roles": []interface{}{"role1", "role11"}, + }, + }, + } + + kfuc, err := server.NewKeysFileUserCache(skf, g2s.Noop(), false, "username", "roles", "", "") + So(err, ShouldBeNil) + So(kfuc, ShouldNotBeNil) + + Convey("It should retrieve users from file", func() { + So(kfuc.Users(), ShouldNotBeEmpty) + }) + + Convey("It should verify the current user positively.", func() { + success := false + + for i := 0; i < len(keys); i++ { + challenge := randomBytes(64) + sig, err := agent.Sign(keys[i], challenge) + if err != nil { + t.Fatal(err) + } + verifiedUser, err := kfuc.Authenticate("ericallen", challenge, sig) + success = success || (verifiedUser != nil) + } + + So(success, ShouldEqual, true) + }) + + Convey("When a user is requested that cannot be found in the cache", func() { + // Use an SSH key we're guaranteed to not have. + oldData := skf.KeysData[keyValue] + delete(skf.KeysData, keyValue) + kfuc.Update() + + // Swap the key back and try verifying. + // We should still get a result back. + skf.KeysData[keyValue] = oldData + success := false + + for i := 0; i < len(keys); i++ { + challenge := randomBytes(64) + sig, err := agent.Sign(keys[i], challenge) + if err != nil { + t.Fatal(err) + } + verifiedUser, err := kfuc.Authenticate("ericallen", challenge, sig) + success = success || (verifiedUser != nil) + } + + Convey("Then it should update from file again and find the user.", func() { + So(success, ShouldEqual, true) + }) + }) + + Convey("When a user with multiple SSH keys assigned tries to use Hologram", func() { + Convey("The system should allow them to use any key.", func() { + success := false + + for i := 0; i < len(keys); i++ { + challenge := randomBytes(64) + sig, err := privateKey.Sign(cryptrand.Reader, challenge) + if err != nil { + t.Fatal(err) + } + verifiedUser, err := kfuc.Authenticate("ericallen", challenge, sig) + success = success || (verifiedUser != nil) + } + + So(success, ShouldEqual, true) + + }) + }) + + testAuthorizedKey := string(ssh.MarshalAuthorizedKey(privateKey.PublicKey())) + + skf = &stubKeysFile{ + KeysData: map[string]map[string]interface{}{ + testAuthorizedKey: { + "username": "user1", + "roles": []interface{}{"role1", "role11"}, + }, + }, + } + kfuc, err = server.NewKeysFileUserCache(skf, g2s.Noop(), false, "username", "roles", "", "") + So(err, ShouldBeNil) + So(kfuc, ShouldNotBeNil) + + Convey("The usercache should understand the SSH authorized_keys format", func() { + challenge := randomBytes(64) + sig, err := privateKey.Sign(cryptrand.Reader, challenge) + if err != nil { + t.Fatal(err) + } + verifiedUser, err := kfuc.Authenticate("ericallen", challenge, sig) + So(verifiedUser, ShouldNotBeNil) + So(err, ShouldBeNil) + }) + }) +}