From ca9a1a9d7f60801b3d706efb1e806e0863ce2443 Mon Sep 17 00:00:00 2001 From: mattbr0wn Date: Sat, 17 Aug 2024 18:00:57 +0100 Subject: [PATCH] enhanced error handling, reduce number of dns calls --- internal/dns/{spf.go => authorizedSenders.go} | 30 +---- internal/dns/dns.go | 98 ++++++++++++++++ internal/dns/emailProvider.go | 29 +++++ internal/dns/known_email_providers.toml | 1 + internal/dns/mx.go | 106 ------------------ internal/mailserver/mailserver.go | 27 +++-- internal/run/run.go | 6 + mailvalidate/validation.go | 34 ++++-- 8 files changed, 176 insertions(+), 155 deletions(-) rename internal/dns/{spf.go => authorizedSenders.go} (67%) create mode 100644 internal/dns/dns.go create mode 100644 internal/dns/emailProvider.go delete mode 100644 internal/dns/mx.go diff --git a/internal/dns/spf.go b/internal/dns/authorizedSenders.go similarity index 67% rename from internal/dns/spf.go rename to internal/dns/authorizedSenders.go index e7b5e88..e6e5a84 100644 --- a/internal/dns/spf.go +++ b/internal/dns/authorizedSenders.go @@ -1,12 +1,8 @@ package dns import ( - "fmt" - "net" "regexp" "strings" - - "github.com/customeros/mailsherpa/internal/syntax" ) type AuthorizedSenders struct { @@ -17,29 +13,11 @@ type AuthorizedSenders struct { Other []string } -func GetAuthorizedSenders(email string, knownProviders *KnownProviders) (AuthorizedSenders, error) { - spfRecord, err := getSPFRecord(email) - if err != nil { - return AuthorizedSenders{}, fmt.Errorf("error getting SPF record: %w", err) - } - return processIncludes(spfRecord, knownProviders), nil -} - -func getSPFRecord(email string) (string, error) { - _, domain, ok := syntax.GetEmailUserAndDomain(email) - if !ok { - return "", fmt.Errorf("invalid email address") - } - records, err := net.LookupTXT(domain) - if err != nil { - return "", fmt.Errorf("error looking up TXT records: %w", err) - } - for _, record := range records { - if strings.HasPrefix(strings.TrimSpace(record), "v=spf1") { - return record, nil - } +func GetAuthorizedSenders(dns DNS, knownProviders *KnownProviders) AuthorizedSenders { + if dns.SPF == "" { + return AuthorizedSenders{} } - return "", fmt.Errorf("no SPF record found for domain %s", domain) + return processIncludes(dns.SPF, knownProviders) } func processIncludes(spfRecord string, knownProviders *KnownProviders) AuthorizedSenders { diff --git a/internal/dns/dns.go b/internal/dns/dns.go new file mode 100644 index 0000000..95a0550 --- /dev/null +++ b/internal/dns/dns.go @@ -0,0 +1,98 @@ +package dns + +import ( + "fmt" + "net" + "sort" + "strings" + + "github.com/customeros/mailsherpa/internal/syntax" +) + +type DNS struct { + MX []string + SPF string + Errors []string +} + +func GetDNS(email string) DNS { + var dns DNS + var mxErr error + var spfErr error + + dns.MX, mxErr = getMXRecordsForEmail(email) + dns.SPF, spfErr = getSPFRecord(email) + if mxErr != nil { + dns.Errors = append(dns.Errors, mxErr.Error()) + } + if spfErr != nil { + dns.Errors = append(dns.Errors, spfErr.Error()) + } + return dns +} + +func getMXRecordsForEmail(email string) ([]string, error) { + mxRecords, err := getRawMXRecords(email) + if err != nil { + return nil, err + } + + // Sort MX records by priority (lower number = higher priority) + sort.Slice(mxRecords, func(i, j int) bool { + return mxRecords[i].Pref < mxRecords[j].Pref + }) + + stripDot := func(s string) string { + return strings.ToLower(strings.TrimSuffix(s, ".")) + } + + // Extract hostnames into a string array + result := make([]string, len(mxRecords)) + for i, mx := range mxRecords { + result[i] = stripDot(mx.Host) + } + + return result, nil +} + +func getRawMXRecords(email string) ([]*net.MX, error) { + _, domain, ok := syntax.GetEmailUserAndDomain(email) + if !ok { + return nil, fmt.Errorf("Invalid domain") + } + + mxRecords, err := net.LookupMX(domain) + if err != nil { + return nil, err + } + + return mxRecords, nil +} + +func getSPFRecord(email string) (string, error) { + _, domain, ok := syntax.GetEmailUserAndDomain(email) + if !ok { + return "", fmt.Errorf("invalid email address") + } + records, err := net.LookupTXT(domain) + if err != nil { + return "", fmt.Errorf("error looking up TXT records: %w", err) + } + for _, record := range records { + spfRecord := parseTXTRecord(record) + if strings.HasPrefix(spfRecord, "v=spf1") { + return spfRecord, nil + } + } + return "", fmt.Errorf("no SPF record found for domain %s", domain) +} + +func parseTXTRecord(record string) string { + // Remove surrounding quotes if present + record = strings.Trim(record, "\"") + + // Replace multiple spaces with a single space + record = strings.Join(strings.Fields(record), " ") + + return record +} diff --git a/internal/dns/emailProvider.go b/internal/dns/emailProvider.go new file mode 100644 index 0000000..1b37df6 --- /dev/null +++ b/internal/dns/emailProvider.go @@ -0,0 +1,29 @@ +package dns + +func GetEmailProviderFromMx(dns DNS, knownProviders KnownProviders) (emailProvider, firewall string) { + if len(dns.MX) == 0 { + return "", "" + } + for _, record := range dns.MX { + domain := extractRootDomain(record) + provider, category := knownProviders.GetProviderByDomain(domain) + if provider == "" { + return domain, "" + } + + switch category { + case "enterprise": + return provider, "" + case "webmail": + return provider, "" + case "hosting": + return provider, "" + case "security": + return "", provider + default: + return "", "" + } + } + + return "", "" +} diff --git a/internal/dns/known_email_providers.toml b/internal/dns/known_email_providers.toml index d69f158..af006ad 100644 --- a/internal/dns/known_email_providers.toml +++ b/internal/dns/known_email_providers.toml @@ -93,6 +93,7 @@ domains = [ ["encrypttitan.net", "encrypt titan"], ["forcepoint.com", "forcepoint"], ["greathorn.com", "greathorn"], + ["iphmx.com", "cisco ironport"], ["mailcontrol.com", "mail control"], ["messagelabs.com", "broadcom"], ["mimecast.com", "mimecast"], diff --git a/internal/dns/mx.go b/internal/dns/mx.go deleted file mode 100644 index 5be8cd1..0000000 --- a/internal/dns/mx.go +++ /dev/null @@ -1,106 +0,0 @@ -package dns - -import ( - "fmt" - "net" - "sort" - "strings" - - "github.com/customeros/mailsherpa/internal/syntax" -) - -func GetEmailProviderFromMx(email string, knownProviders KnownProviders) (string, error) { - mx, err := GetMXRecordsForEmail(email) - if err != nil { - return "", err - } - - for _, record := range mx { - domain := extractRootDomain(record) - provider, category := knownProviders.GetProviderByDomain(domain) - if provider == "" { - return domain, nil - } - - switch category { - case "enterprise": - return provider, nil - case "webmail": - return provider, nil - case "hosting": - return provider, nil - default: - if category == "security" { - return "unknown", nil - } - return domain, nil - } - } - - return "", nil -} - -func GetMXRecordsForEmail(email string) ([]string, error) { - mxRecords, err := getRawMXRecords(email) - if err != nil { - return nil, err - } - - // Sort MX records by priority (lower number = higher priority) - sort.Slice(mxRecords, func(i, j int) bool { - return mxRecords[i].Pref < mxRecords[j].Pref - }) - - stripDot := func(s string) string { - return strings.ToLower(strings.TrimSuffix(s, ".")) - } - - // Extract hostnames into a string array - result := make([]string, len(mxRecords)) - for i, mx := range mxRecords { - result[i] = stripDot(mx.Host) - } - - return result, nil -} - -func getEmailServiceProviderFromMX(mxRecords []string) string { - if len(mxRecords) == 0 { - return "" - } - - // Use the first MX record as a reference - parts := strings.Split(mxRecords[0], ".") - numParts := len(parts) - - if numParts < 2 { - return "" - } - - // Start with the last two parts as the potential root domain - root := strings.Join(parts[numParts-2:], ".") - - // Check if all MX records contain this potential root - for _, record := range mxRecords { - if !strings.HasSuffix(record, root) { - // If not, return the last part only (TLD) - return parts[numParts-1] - } - } - - return root -} - -func getRawMXRecords(email string) ([]*net.MX, error) { - _, domain, ok := syntax.GetEmailUserAndDomain(email) - if !ok { - return nil, fmt.Errorf("Invalid domain") - } - - mxRecords, err := net.LookupMX(domain) - if err != nil { - return nil, err - } - - return mxRecords, nil -} diff --git a/internal/mailserver/mailserver.go b/internal/mailserver/mailserver.go index 2816938..05b0130 100644 --- a/internal/mailserver/mailserver.go +++ b/internal/mailserver/mailserver.go @@ -28,25 +28,30 @@ type ProxySetup struct { Password string } -func VerifyEmailAddress(email, fromDomain, fromEmail string) (bool, SMPTValidation, error) { +func VerifyEmailAddress(email, fromDomain, fromEmail string, dnsRecords dns.DNS) (bool, SMPTValidation, error) { results := SMPTValidation{} var isVerified bool - mxServers, err := dns.GetMXRecordsForEmail(email) - if err != nil { - return false, results, err - } - - if len(mxServers) == 0 { - return false, results, fmt.Errorf("no MX records found for domain") + if len(dnsRecords.MX) == 0 { + results.Description = "No MX records for domain" + return false, results, nil } var conn net.Conn var client *bufio.Reader + var err error + var connected bool + + for i := 0; i < len(dnsRecords.MX); i++ { + conn, client, err = connectToSMTP(dnsRecords.MX[i]) + if err == nil { + connected = true + break + } + } - conn, client, err = connectToSMTP(mxServers[0]) - if err != nil { - return false, results, err + if !connected { + return false, results, fmt.Errorf("failed to connect to any MX server: %w", err) } defer conn.Close() diff --git a/internal/run/run.go b/internal/run/run.go index 065ccb5..e1d47f3 100644 --- a/internal/run/run.go +++ b/internal/run/run.go @@ -4,6 +4,7 @@ import ( "fmt" "os" + "github.com/customeros/mailsherpa/internal/dns" "github.com/customeros/mailsherpa/mailvalidate" ) @@ -48,6 +49,7 @@ func BuildRequest(email string) mailvalidate.EmailValidationRequest { FromDomain: fromDomain, FromEmail: fmt.Sprintf("%s.%s@%s", firstname, lastname, fromDomain), CatchAllTestUser: mailvalidate.GenerateCatchAllUsername(), + Dns: dns.GetDNS(email), } return request } @@ -58,6 +60,10 @@ func BuildResponse(emailAddress string, syntax mailvalidate.SyntaxValidation, do isRisky = true } + if !domain.HasMXRecord { + email.SmtpSuccess = true + } + risk := VerifyEmailRisk{ IsFirewalled: domain.IsFirewalled, IsFreeAccount: email.IsFreeAccount, diff --git a/mailvalidate/validation.go b/mailvalidate/validation.go index 617d123..c929b66 100644 --- a/mailvalidate/validation.go +++ b/mailvalidate/validation.go @@ -17,6 +17,7 @@ type EmailValidationRequest struct { FromDomain string FromEmail string CatchAllTestUser string + Dns dns.DNS } type DomainValidation struct { @@ -26,6 +27,8 @@ type DomainValidation struct { IsFirewalled bool IsCatchAll bool CanConnectSMTP bool + HasMXRecord bool + HasSPFRecord bool } type EmailValidation struct { @@ -78,28 +81,33 @@ func ValidateDomainWithCustomKnownProviders(validationRequest EmailValidationReq return results, errors.Wrap(err, "Invalid request") } - provider, err := dns.GetEmailProviderFromMx(validationRequest.Email, knownProviders) - if err != nil { - return results, errors.Wrap(err, "Error getting provider from MX records") + if len(validationRequest.Dns.MX) != 0 { + results.HasMXRecord = true + provider, firewall := dns.GetEmailProviderFromMx(validationRequest.Dns, knownProviders) + results.Provider = provider + if firewall != "" { + results.Firewall = firewall + results.IsFirewalled = true + } } - results.Provider = provider - authorizedSenders, err := dns.GetAuthorizedSenders(validationRequest.Email, &knownProviders) - if err != nil { - return results, errors.Wrap(err, "Error getting authorized senders from spf records") + if validationRequest.Dns.SPF != "" { + results.HasSPFRecord = true + authorizedSenders := dns.GetAuthorizedSenders(validationRequest.Dns, &knownProviders) + results.AuthorizedSenders = authorizedSenders } - results.AuthorizedSenders = authorizedSenders - if results.Provider == "unknown" && len(results.AuthorizedSenders.Enterprise) > 0 { + + if results.Provider == "" && len(results.AuthorizedSenders.Enterprise) > 0 { results.Provider = results.AuthorizedSenders.Enterprise[0] } - if results.Provider == "unknown" && len(results.AuthorizedSenders.Webmail) > 0 { + if results.Provider == "" && len(results.AuthorizedSenders.Webmail) > 0 { results.Provider = results.AuthorizedSenders.Webmail[0] } - if results.Provider == "unknown" && len(results.AuthorizedSenders.Hosting) > 0 { + if results.Provider == "" && len(results.AuthorizedSenders.Hosting) > 0 { results.Provider = results.AuthorizedSenders.Hosting[0] } - if len(results.AuthorizedSenders.Security) > 0 { + if !results.IsFirewalled && len(results.AuthorizedSenders.Security) > 0 { results.IsFirewalled = true results.Firewall = results.AuthorizedSenders.Security[0] } @@ -145,6 +153,7 @@ func ValidateEmail(validationRequest EmailValidationRequest) (EmailValidation, e validationRequest.Email, validationRequest.FromDomain, validationRequest.FromEmail, + validationRequest.Dns, ) if err != nil { return results, errors.Wrap(err, "Error validating email via SMTP") @@ -235,6 +244,7 @@ func catchAllTest(validationRequest EmailValidationRequest) (bool, mailserver.SM catchAllEmail, validationRequest.FromDomain, validationRequest.FromEmail, + validationRequest.Dns, ) if err != nil { log.Printf("Error validating email via SMTP: %v", err)