From ff35d6f004f16ac71c898066f9ebaf3e83dfa60d Mon Sep 17 00:00:00 2001 From: arzumify Date: Sun, 8 Dec 2024 15:13:33 +0000 Subject: [PATCH] Made the authentication mechanisms more modular by returning an authentication function rather than an address --- smtp.go | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/smtp.go b/smtp.go index 1cc0fad..7d84e1a 100644 --- a/smtp.go +++ b/smtp.go @@ -58,9 +58,11 @@ type DatabaseBackend struct { // AuthenticationBackend is a struct that represents an authentication backend type AuthenticationBackend struct { - Authenticate func(initial string, conn *textproto.Conn) (*Address, error) + Authenticate func(initial string, conn *textproto.Conn) (CheckAddress, error) } +type CheckAddress func(*Address) (bool, error) + func readMultilineCodeResponse(conn *textproto.Conn) (int, string, error) { var lines strings.Builder for { @@ -147,7 +149,7 @@ func speakMultiLine(conn *textproto.Conn, lines []string) error { type Receiver struct { underlyingListener net.Listener hostname string - ownedDomains map[string]any + ownedDomains map[string]struct{} enforceTLS bool tlsConfig *tls.Config database DatabaseBackend @@ -156,9 +158,9 @@ type Receiver struct { // NewReceiver creates a new Receiver func NewReceiver(conn net.Listener, hostname string, ownedDomains []string, enforceTLS bool, database DatabaseBackend, authentication AuthenticationBackend, tlsConfig *tls.Config) *Receiver { - var ownedDomainsMap = make(map[string]any) + var ownedDomainsMap = make(map[string]struct{}) for _, domain := range ownedDomains { - ownedDomainsMap[domain] = nil + ownedDomainsMap[domain] = struct{}{} } return &Receiver{ underlyingListener: conn, @@ -191,7 +193,7 @@ func (fr *Receiver) Serve() error { func (fr *Receiver) handleConnection(conn net.Conn) { var state struct { HELO bool - AUTH *Address + AUTH CheckAddress TLS bool FROM *Address RCPT []*Address @@ -320,14 +322,14 @@ func (fr *Receiver) handleConnection(conn net.Conn) { } continue } else { - address, err := fr.auth.Authenticate(strings.TrimPrefix(line, "AUTH "), textProto) + checkAddress, err := fr.auth.Authenticate(strings.TrimPrefix(line, "AUTH "), textProto) if err != nil { _ = textProto.PrintfLine("421 4.7.0 Temporary server error") _ = conn.Close() return } - if address == nil { + if checkAddress == nil { err = textProto.PrintfLine("535 5.7.8 Authentication failed") if err != nil { _ = textProto.PrintfLine("421 4.7.0 Temporary server error") @@ -335,7 +337,7 @@ func (fr *Receiver) handleConnection(conn net.Conn) { return } } else { - state.AUTH = address + state.AUTH = checkAddress err = textProto.PrintfLine("235 2.7.0 Authentication successful") if err != nil { _ = textProto.PrintfLine("421 4.7.0 Temporary server error") @@ -410,7 +412,14 @@ func (fr *Receiver) handleConnection(conn net.Conn) { continue } - if *address != *state.AUTH { + ok, err := state.AUTH(address) + if err != nil { + _ = textProto.PrintfLine("421 4.7.0 Temporary server error") + _ = conn.Close() + return + } + + if !ok { err = textProto.PrintfLine("535 5.7.8 Authenticated wrong user") if err != nil { _ = textProto.PrintfLine("421 4.7.0 Temporary server error")