Compare commits

..

No commits in common. "main" and "v1.0.4" have entirely different histories.
main ... v1.0.4

2 changed files with 36 additions and 28 deletions

View file

@ -22,9 +22,10 @@ var DatabaseBackend = smtp.DatabaseBackend{
// AuthenticationBackend is a smtp.AuthenticationBackend implementation that always returns a fixed address for Authenticate. // AuthenticationBackend is a smtp.AuthenticationBackend implementation that always returns a fixed address for Authenticate.
var AuthenticationBackend = smtp.AuthenticationBackend{ var AuthenticationBackend = smtp.AuthenticationBackend{
Authenticate: func(initial string, conn *textproto.Conn) (smtp.CheckAddress, error) { Authenticate: func(initial string, conn *textproto.Conn) (*smtp.Address, error) {
return func(address *smtp.Address) (bool, error) { return &smtp.Address{
return true, nil Name: "test",
Address: "example.org",
}, nil }, nil
}, },
} }

57
smtp.go
View file

@ -58,12 +58,9 @@ type DatabaseBackend struct {
// AuthenticationBackend is a struct that represents an authentication backend // AuthenticationBackend is a struct that represents an authentication backend
type AuthenticationBackend struct { type AuthenticationBackend struct {
Authenticate func(initial string, conn *textproto.Conn) (CheckAddress, error) Authenticate func(initial string, conn *textproto.Conn) (*Address, error)
SupportedMechanisms []string
} }
type CheckAddress func(*Address) (bool, error)
func readMultilineCodeResponse(conn *textproto.Conn) (int, string, error) { func readMultilineCodeResponse(conn *textproto.Conn) (int, string, error) {
var lines strings.Builder var lines strings.Builder
for { for {
@ -150,7 +147,7 @@ func speakMultiLine(conn *textproto.Conn, lines []string) error {
type Receiver struct { type Receiver struct {
underlyingListener net.Listener underlyingListener net.Listener
hostname string hostname string
ownedDomains map[string]struct{} ownedDomains map[string]any
enforceTLS bool enforceTLS bool
tlsConfig *tls.Config tlsConfig *tls.Config
database DatabaseBackend database DatabaseBackend
@ -159,9 +156,9 @@ type Receiver struct {
// NewReceiver creates a new Receiver // NewReceiver creates a new Receiver
func NewReceiver(conn net.Listener, hostname string, ownedDomains []string, enforceTLS bool, database DatabaseBackend, authentication AuthenticationBackend, tlsConfig *tls.Config) *Receiver { func NewReceiver(conn net.Listener, hostname string, ownedDomains []string, enforceTLS bool, database DatabaseBackend, authentication AuthenticationBackend, tlsConfig *tls.Config) *Receiver {
var ownedDomainsMap = make(map[string]struct{}) var ownedDomainsMap = make(map[string]any)
for _, domain := range ownedDomains { for _, domain := range ownedDomains {
ownedDomainsMap[domain] = struct{}{} ownedDomainsMap[domain] = nil
} }
return &Receiver{ return &Receiver{
underlyingListener: conn, underlyingListener: conn,
@ -194,7 +191,7 @@ func (fr *Receiver) Serve() error {
func (fr *Receiver) handleConnection(conn net.Conn) { func (fr *Receiver) handleConnection(conn net.Conn) {
var state struct { var state struct {
HELO bool HELO bool
AUTH CheckAddress AUTH *Address
TLS bool TLS bool
FROM *Address FROM *Address
RCPT []*Address RCPT []*Address
@ -212,6 +209,8 @@ func (fr *Receiver) handleConnection(conn net.Conn) {
return return
} }
fmt.Println("Connection from", conn.RemoteAddr().String())
for { for {
line, err := textProto.ReadLine() line, err := textProto.ReadLine()
if err != nil { if err != nil {
@ -282,9 +281,6 @@ func (fr *Receiver) handleConnection(conn net.Conn) {
if fr.enforceTLS { if fr.enforceTLS {
capabilities = append(capabilities, "250-REQUIRETLS") capabilities = append(capabilities, "250-REQUIRETLS")
} }
if fr.auth.SupportedMechanisms != nil {
capabilities = append(capabilities, "250-AUTH "+strings.Join(fr.auth.SupportedMechanisms, " "))
}
capabilities = append(capabilities, defaultCapabilities...) capabilities = append(capabilities, defaultCapabilities...)
state.HELO = true state.HELO = true
err = speakMultiLine(textProto, capabilities) err = speakMultiLine(textProto, capabilities)
@ -324,14 +320,14 @@ func (fr *Receiver) handleConnection(conn net.Conn) {
} }
continue continue
} else { } else {
checkAddress, err := fr.auth.Authenticate(strings.TrimPrefix(line, "AUTH "), textProto) address, err := fr.auth.Authenticate(strings.TrimPrefix(line, "AUTH "), textProto)
if err != nil { if err != nil {
_ = textProto.PrintfLine(err.Error()) _ = textProto.PrintfLine("421 4.7.0 Temporary server error")
_ = conn.Close() _ = conn.Close()
return return
} }
if checkAddress == nil { if address == nil {
err = textProto.PrintfLine("535 5.7.8 Authentication failed") err = textProto.PrintfLine("535 5.7.8 Authentication failed")
if err != nil { if err != nil {
_ = textProto.PrintfLine("421 4.7.0 Temporary server error") _ = textProto.PrintfLine("421 4.7.0 Temporary server error")
@ -339,7 +335,7 @@ func (fr *Receiver) handleConnection(conn net.Conn) {
return return
} }
} else { } else {
state.AUTH = checkAddress state.AUTH = address
err = textProto.PrintfLine("235 2.7.0 Authentication successful") err = textProto.PrintfLine("235 2.7.0 Authentication successful")
if err != nil { if err != nil {
_ = textProto.PrintfLine("421 4.7.0 Temporary server error") _ = textProto.PrintfLine("421 4.7.0 Temporary server error")
@ -414,14 +410,7 @@ func (fr *Receiver) handleConnection(conn net.Conn) {
continue continue
} }
ok, err := state.AUTH(address) if *address != *state.AUTH {
if err != nil {
_ = textProto.PrintfLine("421 4.7.0 Temporary server error")
_ = conn.Close()
return
}
if !ok {
err = textProto.PrintfLine("535 5.7.8 Authenticated wrong user") err = textProto.PrintfLine("535 5.7.8 Authenticated wrong user")
if err != nil { if err != nil {
_ = textProto.PrintfLine("421 4.7.0 Temporary server error") _ = textProto.PrintfLine("421 4.7.0 Temporary server error")
@ -590,6 +579,7 @@ func (fr *Receiver) handleConnection(conn net.Conn) {
Host: strings.Split(conn.RemoteAddr().String(), ":")[0], Host: strings.Split(conn.RemoteAddr().String(), ":")[0],
} }
go sendEmail(SenderArgs{ go sendEmail(SenderArgs{
Hostname: fr.hostname,
EnforceTLS: fr.enforceTLS, EnforceTLS: fr.enforceTLS,
}, mail, fr.database, queueID) }, mail, fr.database, queueID)
@ -618,6 +608,7 @@ func (fr *Receiver) handleConnection(conn net.Conn) {
// SenderArgs is a struct that represents the arguments for the Sender // SenderArgs is a struct that represents the arguments for the Sender
type SenderArgs struct { type SenderArgs struct {
Hostname string
EnforceTLS bool EnforceTLS bool
} }
@ -648,7 +639,7 @@ func Send(args SenderArgs, mail *Mail, conn net.Conn, mxHost string) (err error)
return errors.New("unexpected greeting - " + line) return errors.New("unexpected greeting - " + line)
} }
err = textConn.PrintfLine("EHLO %s", mxHost) err = textConn.PrintfLine("EHLO %s", args.Hostname)
if err != nil { if err != nil {
return err return err
} }
@ -682,11 +673,16 @@ func Send(args SenderArgs, mail *Mail, conn net.Conn, mxHost string) (err error)
InsecureSkipVerify: false, InsecureSkipVerify: false,
}) })
err = tlsConn.Handshake()
if err != nil {
return err
}
textConn = textproto.NewConn(tlsConn) textConn = textproto.NewConn(tlsConn)
// Just use HELO, no point using EHLO when we already have all the capabilities // Just use HELO, no point using EHLO when we already have all the capabilities
// This also gets us out of using readMultilineCodeResponse // This also gets us out of using readMultilineCodeResponse
err = textConn.PrintfLine("HELO %s", mxHost) err = textConn.PrintfLine("HELO %s", args.Hostname)
if err != nil { if err != nil {
return err return err
} }
@ -709,7 +705,10 @@ func Send(args SenderArgs, mail *Mail, conn net.Conn, mxHost string) (err error)
} }
code, line, err = textConn.ReadCodeLine(0) code, line, err = textConn.ReadCodeLine(0)
fmt.Println(code, line, err)
if err != nil { if err != nil {
// For some reason the EHLO stuff ends up here
fmt.Println("5")
return err return err
} }
@ -724,7 +723,9 @@ func Send(args SenderArgs, mail *Mail, conn net.Conn, mxHost string) (err error)
} }
code, line, err = textConn.ReadCodeLine(0) code, line, err = textConn.ReadCodeLine(0)
fmt.Println(code, line, err)
if err != nil { if err != nil {
fmt.Println("6")
return err return err
} }
@ -739,7 +740,9 @@ func Send(args SenderArgs, mail *Mail, conn net.Conn, mxHost string) (err error)
} }
code, line, err = textConn.ReadCodeLine(0) code, line, err = textConn.ReadCodeLine(0)
fmt.Println(code, line, err)
if err != nil { if err != nil {
fmt.Println("7")
return err return err
} }
@ -759,7 +762,9 @@ func Send(args SenderArgs, mail *Mail, conn net.Conn, mxHost string) (err error)
} }
code, line, err = textConn.ReadCodeLine(0) code, line, err = textConn.ReadCodeLine(0)
fmt.Println(code, line, err)
if err != nil { if err != nil {
fmt.Println("8")
return err return err
} }
@ -773,7 +778,9 @@ func Send(args SenderArgs, mail *Mail, conn net.Conn, mxHost string) (err error)
} }
code, line, err = textConn.ReadCodeLine(0) code, line, err = textConn.ReadCodeLine(0)
fmt.Println(code, line, err)
if err != nil { if err != nil {
fmt.Println("9")
return err return err
} }