Skip to content

Commit

Permalink
enhanced error handling, reduce number of dns calls
Browse files Browse the repository at this point in the history
  • Loading branch information
mattbr0wn committed Aug 17, 2024
1 parent 4bbaaf2 commit ca9a1a9
Show file tree
Hide file tree
Showing 8 changed files with 176 additions and 155 deletions.
30 changes: 4 additions & 26 deletions internal/dns/spf.go → internal/dns/authorizedSenders.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
package dns

import (
"fmt"
"net"
"regexp"
"strings"

"github.com/customeros/mailsherpa/internal/syntax"
)

type AuthorizedSenders struct {
Expand All @@ -17,29 +13,11 @@ type AuthorizedSenders struct {
Other []string
}

func GetAuthorizedSenders(email string, knownProviders *KnownProviders) (AuthorizedSenders, error) {
spfRecord, err := getSPFRecord(email)
if err != nil {
return AuthorizedSenders{}, fmt.Errorf("error getting SPF record: %w", err)
}
return processIncludes(spfRecord, knownProviders), nil
}

func getSPFRecord(email string) (string, error) {
_, domain, ok := syntax.GetEmailUserAndDomain(email)
if !ok {
return "", fmt.Errorf("invalid email address")
}
records, err := net.LookupTXT(domain)
if err != nil {
return "", fmt.Errorf("error looking up TXT records: %w", err)
}
for _, record := range records {
if strings.HasPrefix(strings.TrimSpace(record), "v=spf1") {
return record, nil
}
func GetAuthorizedSenders(dns DNS, knownProviders *KnownProviders) AuthorizedSenders {
if dns.SPF == "" {
return AuthorizedSenders{}
}
return "", fmt.Errorf("no SPF record found for domain %s", domain)
return processIncludes(dns.SPF, knownProviders)
}

func processIncludes(spfRecord string, knownProviders *KnownProviders) AuthorizedSenders {
Expand Down
98 changes: 98 additions & 0 deletions internal/dns/dns.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package dns

import (
"fmt"
"net"
"sort"
"strings"

"github.com/customeros/mailsherpa/internal/syntax"
)

type DNS struct {
MX []string
SPF string
Errors []string
}

func GetDNS(email string) DNS {
var dns DNS
var mxErr error
var spfErr error

dns.MX, mxErr = getMXRecordsForEmail(email)
dns.SPF, spfErr = getSPFRecord(email)
if mxErr != nil {
dns.Errors = append(dns.Errors, mxErr.Error())
}
if spfErr != nil {
dns.Errors = append(dns.Errors, spfErr.Error())
}
return dns
}

func getMXRecordsForEmail(email string) ([]string, error) {
mxRecords, err := getRawMXRecords(email)
if err != nil {
return nil, err
}

// Sort MX records by priority (lower number = higher priority)
sort.Slice(mxRecords, func(i, j int) bool {
return mxRecords[i].Pref < mxRecords[j].Pref
})

stripDot := func(s string) string {
return strings.ToLower(strings.TrimSuffix(s, "."))
}

// Extract hostnames into a string array
result := make([]string, len(mxRecords))
for i, mx := range mxRecords {
result[i] = stripDot(mx.Host)
}

return result, nil
}

func getRawMXRecords(email string) ([]*net.MX, error) {
_, domain, ok := syntax.GetEmailUserAndDomain(email)
if !ok {
return nil, fmt.Errorf("Invalid domain")
}

mxRecords, err := net.LookupMX(domain)
if err != nil {
return nil, err
}

return mxRecords, nil
}

func getSPFRecord(email string) (string, error) {
_, domain, ok := syntax.GetEmailUserAndDomain(email)
if !ok {
return "", fmt.Errorf("invalid email address")
}
records, err := net.LookupTXT(domain)
if err != nil {
return "", fmt.Errorf("error looking up TXT records: %w", err)
}
for _, record := range records {
spfRecord := parseTXTRecord(record)
if strings.HasPrefix(spfRecord, "v=spf1") {
return spfRecord, nil
}
}
return "", fmt.Errorf("no SPF record found for domain %s", domain)
}

func parseTXTRecord(record string) string {
// Remove surrounding quotes if present
record = strings.Trim(record, "\"")

// Replace multiple spaces with a single space
record = strings.Join(strings.Fields(record), " ")

return record
}
29 changes: 29 additions & 0 deletions internal/dns/emailProvider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package dns

func GetEmailProviderFromMx(dns DNS, knownProviders KnownProviders) (emailProvider, firewall string) {
if len(dns.MX) == 0 {
return "", ""
}
for _, record := range dns.MX {
domain := extractRootDomain(record)
provider, category := knownProviders.GetProviderByDomain(domain)
if provider == "" {
return domain, ""
}

switch category {
case "enterprise":
return provider, ""
case "webmail":
return provider, ""
case "hosting":
return provider, ""
case "security":
return "", provider
default:
return "", ""
}
}

return "", ""
}
1 change: 1 addition & 0 deletions internal/dns/known_email_providers.toml
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ domains = [
["encrypttitan.net", "encrypt titan"],
["forcepoint.com", "forcepoint"],
["greathorn.com", "greathorn"],
["iphmx.com", "cisco ironport"],
["mailcontrol.com", "mail control"],
["messagelabs.com", "broadcom"],
["mimecast.com", "mimecast"],
Expand Down
106 changes: 0 additions & 106 deletions internal/dns/mx.go

This file was deleted.

27 changes: 16 additions & 11 deletions internal/mailserver/mailserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,25 +28,30 @@ type ProxySetup struct {
Password string
}

func VerifyEmailAddress(email, fromDomain, fromEmail string) (bool, SMPTValidation, error) {
func VerifyEmailAddress(email, fromDomain, fromEmail string, dnsRecords dns.DNS) (bool, SMPTValidation, error) {
results := SMPTValidation{}
var isVerified bool

mxServers, err := dns.GetMXRecordsForEmail(email)
if err != nil {
return false, results, err
}

if len(mxServers) == 0 {
return false, results, fmt.Errorf("no MX records found for domain")
if len(dnsRecords.MX) == 0 {
results.Description = "No MX records for domain"
return false, results, nil
}

var conn net.Conn
var client *bufio.Reader
var err error
var connected bool

for i := 0; i < len(dnsRecords.MX); i++ {
conn, client, err = connectToSMTP(dnsRecords.MX[i])
if err == nil {
connected = true
break
}
}

conn, client, err = connectToSMTP(mxServers[0])
if err != nil {
return false, results, err
if !connected {
return false, results, fmt.Errorf("failed to connect to any MX server: %w", err)
}

defer conn.Close()
Expand Down
6 changes: 6 additions & 0 deletions internal/run/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"os"

"github.com/customeros/mailsherpa/internal/dns"
"github.com/customeros/mailsherpa/mailvalidate"
)

Expand Down Expand Up @@ -48,6 +49,7 @@ func BuildRequest(email string) mailvalidate.EmailValidationRequest {
FromDomain: fromDomain,
FromEmail: fmt.Sprintf("%s.%s@%s", firstname, lastname, fromDomain),
CatchAllTestUser: mailvalidate.GenerateCatchAllUsername(),
Dns: dns.GetDNS(email),
}
return request
}
Expand All @@ -58,6 +60,10 @@ func BuildResponse(emailAddress string, syntax mailvalidate.SyntaxValidation, do
isRisky = true
}

if !domain.HasMXRecord {
email.SmtpSuccess = true
}

risk := VerifyEmailRisk{
IsFirewalled: domain.IsFirewalled,
IsFreeAccount: email.IsFreeAccount,
Expand Down
Loading

0 comments on commit ca9a1a9

Please sign in to comment.