From 694fd073fdbf8d7bdb3027cd6d23dbf006b5a803 Mon Sep 17 00:00:00 2001 From: arzumify Date: Thu, 12 Dec 2024 21:08:05 +0000 Subject: [PATCH] Fixed a lot of various issues found via testing --- backend.go | 9 ----- go.mod | 2 + mailbox.go | 21 ++++++---- main.go | 112 +++++++++++++++++++++++++++++++++++++++++------------ user.go | 3 +- 5 files changed, 105 insertions(+), 42 deletions(-) diff --git a/backend.go b/backend.go index e7889c5..dd3b4e6 100644 --- a/backend.go +++ b/backend.go @@ -2,7 +2,6 @@ package main import ( "crypto/ed25519" - "fmt" "github.com/emersion/go-imap" "github.com/emersion/go-imap/backend" ) @@ -33,21 +32,13 @@ func (be *Backend) Login(_ *imap.ConnInfo, username, token string) (backend.User openMessages: make(map[*Message]struct{}), } - fmt.Println("YOU'VE GOT THIS FAR") - _, err = user.GetMailbox("INBOX") if err != nil { - fmt.Println("NO INBOX") err := user.CreateMailbox("INBOX") if err != nil { - fmt.Println("Failed to create mailbox: " + err.Error()) return nil, err } - - fmt.Println("INBOX CREATED") } - fmt.Println("LOGIN SUCCESSFUL") - return user, nil } diff --git a/go.mod b/go.mod index ef9a2f5..66b5d4f 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,8 @@ require ( github.com/lib/pq v1.10.9 ) +replace "git.ailur.dev/ailur/smtp" v1.1.0 => "/home/liqing/Projects/libraries/smtp" + require ( git.ailur.dev/ailur/spf v1.0.1 // indirect github.com/go-chi/chi/v5 v5.1.0 // indirect diff --git a/mailbox.go b/mailbox.go index 4970ce9..6ad543b 100644 --- a/mailbox.go +++ b/mailbox.go @@ -240,18 +240,20 @@ func (mbox *Mailbox) ListMessages(useUid bool, seqSet *imap.SeqSet, items []imap seqNum++ var uid, size uint32 - var date time.Time + var date int64 var flagsRaw string var ownerRaw []byte var idRaw []byte err := messages.Scan(&idRaw, &uid, &date, &size, &flagsRaw, &ownerRaw) if err != nil { + println(err.Error()) // Skip any emails that can't be read continue } - msg, err := LoadRawMessage(idRaw, uid, date, size, flagsRaw, ownerRaw, mbox) + msg, err := LoadRawMessage(idRaw, uid, time.Unix(date, 0), size, flagsRaw, ownerRaw, mbox) if err != nil { + println(err.Error()) // Skip any emails that fail to load from disk continue } @@ -301,7 +303,7 @@ func (mbox *Mailbox) SearchMessages(useUid bool, criteria *imap.SearchCriteria) seqNum++ var uid, size uint32 - var date time.Time + var date int64 var flagsRaw string var ownerRaw []byte var idRaw []byte @@ -311,7 +313,7 @@ func (mbox *Mailbox) SearchMessages(useUid bool, criteria *imap.SearchCriteria) continue } - msg, err := LoadRawMessage(idRaw, uid, date, size, flagsRaw, ownerRaw, mbox) + msg, err := LoadRawMessage(idRaw, uid, time.Unix(date, 0), size, flagsRaw, ownerRaw, mbox) if err != nil { // Skip any emails that fail to load from disk continue @@ -367,7 +369,12 @@ func (mbox *Mailbox) CreateMessage(flags []string, date time.Time, body imap.Lit return err } - _, err = Database.DB.Exec("INSERT INTO messages (mailbox, id, uid, created, bodySize, flags, owner) VALUES ($1, $2, $3, $4, $5, $6, $7)", mbox.id[:], messageID[:], uid, date, len(b), string(flagsRaw), mbox.user.sub[:]) + _, err = Database.DB.Exec("INSERT INTO messages (mailbox, id, uid, created, bodySize, flags, owner) VALUES ($1, $2, $3, $4, $5, $6, $7)", mbox.id[:], messageID[:], uid, date.Unix(), len(b), string(flagsRaw), mbox.user.sub[:]) + if err != nil { + return err + } + + err = StoreFile(messageID.String(), b, mbox.user.sub) if err != nil { return err } @@ -458,7 +465,7 @@ func (mbox *Mailbox) CopyMessages(useUid bool, seqset *imap.SeqSet, destName str seqNum++ var uid, size uint32 - var date time.Time + var date int64 var flagsRaw string var idRaw, ownerRaw []byte err := messages.Scan(&idRaw, &uid, &date, &size, &flagsRaw, &ownerRaw) @@ -467,7 +474,7 @@ func (mbox *Mailbox) CopyMessages(useUid bool, seqset *imap.SeqSet, destName str continue } - msg, err := LoadRawMessage(idRaw, uid, date, size, flagsRaw, ownerRaw, mbox) + msg, err := LoadRawMessage(idRaw, uid, time.Unix(date, 0), size, flagsRaw, ownerRaw, mbox) if err != nil { // Skip any emails that fail to load from disk continue diff --git a/main.go b/main.go index 3457e37..32b0f56 100644 --- a/main.go +++ b/main.go @@ -2,11 +2,14 @@ package main import ( "bytes" + "encoding/base64" "errors" + "fmt" "io" "net" "net/url" "os" + "strings" "sync" "time" @@ -52,7 +55,18 @@ func log(message string, messageType uint64) { } } +// Authenticate Fake for testing func Authenticate(token string, config OAuthConfig) (uuid.UUID, error) { + println("called") + return uuid.MustParse("e59fece6-256f-4799-bb31-321268387d12"), nil +} + +// GetUsername Fake for testing +func GetUsername(token string, config OAuthConfig) (string, error) { + return "arzumify", nil +} + +func RAuthenticate(token string, config OAuthConfig) (uuid.UUID, error) { parsedToken, err := jwt.Parse(token, func(token *jwt.Token) (interface{}, error) { return config.PublicKey, nil }) @@ -77,7 +91,7 @@ func Authenticate(token string, config OAuthConfig) (uuid.UUID, error) { return uuid.MustParse(claims["sub"].(string)), nil } -func GetUsername(token string, config OAuthConfig) (string, error) { +func RGetUsername(token string, config OAuthConfig) (string, error) { var responseData struct { Username string `json:"username"` } @@ -126,7 +140,7 @@ func StoreFile(name string, data []byte, owner uuid.UUID) error { case 0: return nil case 1, 2: - return errors.New(response.Message.(string)) + return response.Message.(error) default: return errors.New("unknown error") } @@ -181,7 +195,7 @@ func DeleteFile(name string, owner uuid.UUID) error { case 0: return nil case 1, 2: - return errors.New(response.Message.(string)) + return response.Message.(error) default: return errors.New("unknown error") } @@ -477,38 +491,85 @@ var ( return nil }, } - NewSMTPAuthenticationBackend = func(OAuthRegistration OAuthConfig) smtp.AuthenticationBackend { - return smtp.AuthenticationBackend{ - Authenticate: func(initial string, conn *textproto.Conn) (smtp.CheckAddress, error) { - sub, err := Authenticate(initial, OAuthRegistration) +) + +func NewSMTPAuthenticationBackend(OAuthRegistration OAuthConfig) smtp.AuthenticationBackend { + return smtp.AuthenticationBackend{ + SupportedMechanisms: []string{"PLAIN", "XOAUTH2", "OAUTHBEARER"}, + Authenticate: func(initial string, conn *textproto.Conn) (checkAddr smtp.CheckAddress, finalErr error) { + initialResponse := strings.SplitN(initial, " ", 2) + + var username, token string + switch initialResponse[0] { + case "PLAIN": + credentials, err := base64.StdEncoding.DecodeString(initialResponse[1]) if err != nil { - return nil, err + return nil, errors.New("421 4.7.0 Malformed credentials") } - return func(address *smtp.Address) (bool, error) { - rows, err := Database.DB.Query("SELECT prefix, suffix FROM emails WHERE creator = $1", sub[:]) + credentialSlice := bytes.SplitN(bytes.TrimPrefix(credentials, []byte{0x00}), []byte{0x00}, 2) + username = string(credentialSlice[0]) + token = string(credentialSlice[1]) + case "OAUTHBEARER", "XOAUTH2": + credentials, err := base64.StdEncoding.DecodeString(initialResponse[1]) + if err != nil { + return nil, errors.New("421 4.7.0 Malformed credentials") + } + + credentialSlice := bytes.SplitN(bytes.TrimSuffix(bytes.TrimPrefix(credentials, []byte("user=")), []byte{0x01, 0x01}), []byte{0x01}, 2) + username = string(credentialSlice[0]) + token = string(credentialSlice[1]) + default: + return nil, errors.New("503 5.5.1 Invalid authentication method: " + initialResponse[0]) + } + + fmt.Println("Username: " + username) + fmt.Println("Token: " + token) + + sub, err := Authenticate(token, OAuthRegistration) + if err != nil { + return nil, errors.New("421 4.7.0 Invalid credentials") + } + + usernameCheck, err := GetUsername(token, OAuthRegistration) + if err != nil { + return nil, errors.New("421 4.7.0 Invalid credentials") + } + + if username != usernameCheck { + return nil, errors.New("421 4.7.0 Username does not match") + } + + return func(address *smtp.Address) (bool, error) { + rows, err := Database.DB.Query("SELECT prefix, suffix FROM emails WHERE creator = $1", sub[:]) + if err != nil { + return false, err + } + + defer func() { + err := rows.Close() + if err != nil { + log("Failed to close rows: "+err.Error()+", resource leaks may occur", 1) + } + }() + + for rows.Next() { + var prefix, suffix string + err = rows.Scan(&prefix, &suffix) if err != nil { return false, err } - for rows.Next() { - var prefix, suffix string - err = rows.Scan(&prefix, &suffix) - if err != nil { - return false, err - } - - if address.Name == prefix && address.Address == suffix { - return true, nil - } + if address.Name == prefix && address.Address == suffix { + return true, nil } + } - return false, nil - }, nil - }, - } + return false, nil + }, nil + }, } -) +} func parseConfig() (hostName string, listenerHost string, ownedDomains []string, enforceTLS bool, enableTLS bool, certificatePath string, keyPath string, err error) { var ok bool @@ -651,6 +712,7 @@ func Main(information library.ServiceInitializationInformation) { log("Failed to listen on port 25: "+err.Error(), 3) return } + smtpBackend := smtp.NewReceiver(smtpListener, hostName, ownedDomains, enforceTLS, SMTPDatabaseBackend, NewSMTPAuthenticationBackend(oauthConfig), smtpTLSConfig) err = smtpBackend.Serve() if err != nil { diff --git a/user.go b/user.go index 4b47a2f..1521fb1 100644 --- a/user.go +++ b/user.go @@ -74,7 +74,8 @@ func (u *User) GetMailbox(name string) (mailbox backend.Mailbox, err error) { } func (u *User) CreateMailbox(name string) error { - _, err := Database.DB.Exec("INSERT INTO mailboxes (mailbox, id, owner) VALUES ($1, $2, $3)", name, uuid.New(), u.sub[:]) + newUUID := uuid.New() + _, err := Database.DB.Exec("INSERT INTO mailboxes (mailbox, id, owner) VALUES ($1, $2, $3)", name, newUUID[:], u.sub[:]) if err != nil { return err }