Skip to content

Commit

Permalink
centralize domain handling
Browse files Browse the repository at this point in the history
  • Loading branch information
mattbr0wn authored and alexopenline committed Sep 6, 2024
1 parent cf1b396 commit 4341cf8
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 59 deletions.
22 changes: 3 additions & 19 deletions domaincheck/domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ import (
"fmt"
"net"
"net/http"
"net/url"
"sort"
"strings"
"time"

"github.com/customeros/mailsherpa/internal/syntax"
)

type DNS struct {
Expand Down Expand Up @@ -70,7 +71,7 @@ func CheckRedirects(domain string) (bool, string) {
if resp.StatusCode >= 300 && resp.StatusCode < 400 {
location := resp.Header.Get("Location")
if location != "" {
location = extractDomain(location)
location, _ = syntax.ExtractDomain(location)
if location != domain {
return true, location
}
Expand All @@ -81,23 +82,6 @@ func CheckRedirects(domain string) (bool, string) {
return false, ""
}

func extractDomain(urlStr string) string {
u, err := url.Parse(urlStr)
if err != nil {
return urlStr // Return as-is if parsing fails
}

// Remove 'www.' prefix if present
domain := strings.TrimPrefix(u.Hostname(), "www.")

// Split the domain and get the last two parts (or just one if it's a TLD)
parts := strings.Split(domain, ".")
if len(parts) > 2 {
return strings.Join(parts[len(parts)-2:], ".")
}
return domain
}

func getMXRecordsForDomain(domain string) ([]string, error) {
mxRecords, err := getRawMXRecords(domain)
if err != nil {
Expand Down
39 changes: 6 additions & 33 deletions internal/dns/authorizedSenders.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
package dns

import (
"log"
"regexp"
"strings"

"github.com/customeros/mailsherpa/domaincheck"
"github.com/customeros/mailsherpa/internal/syntax"
)

type AuthorizedSenders struct {
Expand Down Expand Up @@ -45,7 +46,10 @@ func processIncludes(spfRecord string, knownProviders *KnownProviders) Authorize
if len(include) < 2 {
continue
}
includeDomain := extractRootDomain(include[1])
includeDomain, err := syntax.ExtractDomain(include[1])
if err != nil {
log.Printf("Error: %v", err)
}
providerName, category := knownProviders.GetProviderByDomain(includeDomain)
if providerName != "" {
if slice, exists := categoryMap[category]; exists {
Expand All @@ -65,34 +69,3 @@ func appendIfNotExists(slice *[]string, s string) {
}
*slice = append(*slice, s)
}

func extractRootDomain(fullDomain string) string {
parts := strings.Split(fullDomain, ".")
if len(parts) <= 2 {
return fullDomain
}

// List of known ccTLDs with second-level domains
ccTLDsWithSLD := map[string]bool{
"uk": true, "au": true, "nz": true, "jp": true,
}

// Common second-level domains
commonSLDs := map[string]bool{
"com": true, "org": true, "net": true, "edu": true, "gov": true, "co": true,
}

tldIndex := len(parts) - 1
sldIndex := tldIndex - 1

// Check for ccTLDs with second-level domains
if ccTLDsWithSLD[parts[tldIndex]] && commonSLDs[parts[sldIndex]] {
if len(parts) > 3 {
return strings.Join(parts[len(parts)-3:], ".")
}
return fullDomain
}

// For other cases, return the last two parts
return strings.Join(parts[sldIndex:], ".")
}
10 changes: 8 additions & 2 deletions internal/dns/emailProvider.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
package dns

import "github.com/customeros/mailsherpa/domaincheck"
import (
"github.com/customeros/mailsherpa/domaincheck"
"github.com/customeros/mailsherpa/internal/syntax"
)

func GetEmailProviderFromMx(dns domaincheck.DNS, knownProviders KnownProviders) (emailProvider, firewall string) {
if len(dns.MX) == 0 {
return "", ""
}
for _, record := range dns.MX {
domain := extractRootDomain(record)
domain, err := syntax.ExtractDomain(record)
if err != nil {
continue
}
provider, category := knownProviders.GetProviderByDomain(domain)
if provider == "" {
return domain, ""
Expand Down
6 changes: 1 addition & 5 deletions internal/run/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,16 +93,12 @@ func BuildResponse(
SecureGatewayProvider: domain.SecureGatewayProvider,
IsRisky: isRisky,
Risk: risk,
AlternateEmail: email.AlternateEmail,
RetryValidation: email.RetryValidation,
Syntax: syntax,
Smtp: email.SmtpResponse,
MailServerHealth: email.MailServerHealth,
}

if !domain.IsPrimaryDomain {
altEmail := fmt.Sprintf("%s@%s", syntax.User, domain.PrimaryDomain)
response.AlternateEmail.Email = altEmail
}

return response
}
36 changes: 36 additions & 0 deletions internal/syntax/syntax.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package syntax

import (
"fmt"
"net/url"
"regexp"
"strings"
"unicode"
Expand Down Expand Up @@ -124,3 +125,38 @@ func convertToAscii(input string) string {

return string(ascii)
}

func ExtractDomain(fullDomain string) (string, error) {
var domain string
if strings.Contains(fullDomain, "://") {
// It's likely a URL, so parse it
u, err := url.Parse(fullDomain)
if err != nil {
return "", fmt.Errorf("failed to parse URL: %v", err)
}
domain = u.Hostname()
} else {
// It's likely just a domain, so use it as-is
domain = fullDomain
}

// Remove 'www.' prefix if present
domain = strings.TrimPrefix(domain, "www.")

// Get the public suffix (e.g., "co.uk", "com")
_, icann := publicsuffix.PublicSuffix(domain)

// If the public suffix is not managed by ICANN, we might want to handle it differently
if !icann {
return "", fmt.Errorf("non-ICANN managed domain: %s", domain)
}

// Extract the registrable domain (eTLD+1)
registrableDomain, err := publicsuffix.EffectiveTLDPlusOne(domain)
if err != nil {
return "", fmt.Errorf("failed to get eTLD+1: %v", err)
}

return registrableDomain, nil

}
3 changes: 3 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"flag"
"fmt"

"github.com/customeros/mailsherpa/domaincheck"
"github.com/customeros/mailsherpa/internal/cmd"
)

Expand All @@ -24,6 +25,8 @@ func main() {
return
}
cmd.VerifySyntax(args[1], true)
case "redirect":
fmt.Println(domaincheck.PrimaryDomainCheck(args[1]))
case "version":
cmd.Version()
default:
Expand Down

0 comments on commit 4341cf8

Please sign in to comment.