spf/spf.go

484 lines
13 KiB
Go

package spf
import (
"net"
"strings"
)
// ErrorType classifies the outcome carried by an Error.
type ErrorType int
// Outcome categories reported by Error.Type. The values are assigned with
// iota, so ErrTypeNeutral is the zero value of ErrorType.
const (
// ErrTypeNeutral marks a neutral result.
ErrTypeNeutral ErrorType = iota
// ErrTypeNone marks the absence of a usable record.
ErrTypeNone
// ErrTypeFail marks an explicit fail result.
ErrTypeFail
// ErrTypeSoftFail marks a soft fail result.
ErrTypeSoftFail
// ErrTypeInternal marks a failure of the check itself (for example a DNS
// lookup or CIDR parsing error) rather than a policy result.
ErrTypeInternal
)
// Shared, preallocated results returned by CheckIP. They are pointers, so
// callers can compare against them directly or inspect Type().
var (
// ErrNeutral reports a neutral result (type ErrTypeNeutral).
ErrNeutral = &Error{error: "record returned neutral", errorType: ErrTypeNeutral}
// ErrFail reports an explicit fail result (type ErrTypeFail).
ErrFail = &Error{error: "record returned explicit fail", errorType: ErrTypeFail}
// ErrSoftFail reports a soft fail result (type ErrTypeSoftFail).
ErrSoftFail = &Error{error: "record returned soft fail", errorType: ErrTypeSoftFail}
// ErrNone reports that no record was returned (type ErrTypeNone).
ErrNone = &Error{error: "no record returned", errorType: ErrTypeNone}
)
// Error is the result type of an SPF check: a human-readable message paired
// with an ErrorType category. It satisfies the built-in error interface
// through its Error method.
type Error struct {
// error is the message returned by Error().
error string
// errorType is the category returned by Type().
errorType ErrorType
}
// Error returns the message text. It dereferences e, so it must not be
// called on a nil *Error.
func (e *Error) Error() string {
return e.error
}
// Type returns the category of this result. It dereferences e, so it must not
// be called on a nil *Error.
func (e *Error) Type() ErrorType {
return e.errorType
}
// spfMaxPolicyFetches caps how many include/redirect policies one CheckIP call
// may fetch. The value mirrors the limit of 10 DNS-querying terms from
// RFC 7208 section 4.6.4 and keeps include/redirect loops from recursing
// without bound.
const spfMaxPolicyFetches = 10

// spfQuery carries the state shared by every nested policy evaluation that
// belongs to a single CheckIP call.
type spfQuery struct {
	rawIP   string // sender address exactly as passed to CheckIP; used for the PTR lookup
	client  net.IP // parsed sender address; nil when rawIP is not a valid IP
	fetches int    // include/redirect policies fetched so far
}

// CheckIP evaluates the SPF policy published in the TXT records of domain
// against the sending address ip.
//
// It returns nil when the policy passes. Otherwise it returns:
//   - ErrFail, ErrSoftFail or ErrNeutral when the first matching mechanism
//     carries the "-", "~" or "?" qualifier;
//   - ErrNeutral when a policy exists but no mechanism matches and there is
//     no redirect modifier;
//   - ErrNone when the TXT lookup fails or no "v=spf1" record is published;
//   - an *Error of type ErrTypeInternal when evaluation itself breaks
//     (malformed CIDR, DNS failure other than "not found", or more than
//     spfMaxPolicyFetches include/redirect fetches).
//
// Supported mechanisms are all, ip4, ip6, a, mx, ptr, exists and include,
// each with any qualifier, plus the redirect modifier. Unknown terms and other
// modifiers are skipped. Macro expansion is not performed. An ip that cannot
// be parsed never matches an address-based mechanism.
func CheckIP(ip string, domain string) *Error {
	q := &spfQuery{rawIP: ip, client: net.ParseIP(ip)}
	return q.check(domain)
}

// check fetches the TXT records of domain and evaluates every SPF record
// found, in order, returning the first decisive result.
func (q *spfQuery) check(domain string) *Error {
	txt, err := net.LookupTXT(domain)
	if err != nil {
		return ErrNone
	}
	sawRecord := false
	for _, record := range txt {
		terms := strings.Fields(record)
		// The version tag must be exactly "v=spf1"; a prefix test would also
		// accept things like "v=spf10".
		if len(terms) == 0 || !strings.EqualFold(terms[0], "v=spf1") {
			continue
		}
		sawRecord = true
		if result, decided := q.evalRecord(domain, terms[1:]); decided {
			return result
		}
	}
	if !sawRecord {
		return ErrNone
	}
	// A policy exists but nothing in it matched: the default result is neutral.
	return ErrNeutral
}

// evalRecord walks the terms of one SPF record. The boolean result reports
// whether the record produced a decision; the *Error is that decision.
func (q *spfQuery) evalRecord(domain string, terms []string) (*Error, bool) {
	redirect := ""
	for _, term := range terms {
		// Modifiers have the form name=value. Only redirect influences the
		// result, and only once every mechanism has failed to match.
		if eq := strings.IndexByte(term, '='); eq >= 0 && !strings.ContainsAny(term[:eq], ":/") {
			if redirect == "" && strings.EqualFold(term[:eq], "redirect") {
				redirect = term[eq+1:]
			}
			continue
		}
		qualifier, name, arg := spfSplitTerm(term)
		matched, failure := q.match(domain, name, arg)
		if failure != nil {
			return failure, true
		}
		if matched {
			return spfResult(qualifier), true
		}
	}
	if redirect != "" {
		return q.fetch(redirect), true
	}
	return nil, false
}

// fetch evaluates the policy of another domain on behalf of include or
// redirect, enforcing the per-call fetch budget.
func (q *spfQuery) fetch(domain string) *Error {
	if q.fetches >= spfMaxPolicyFetches {
		return &Error{error: "too many include/redirect lookups", errorType: ErrTypeInternal}
	}
	q.fetches++
	return q.check(domain)
}

// match reports whether the mechanism name with argument arg matches the
// sender. arg is whatever followed the name in the term, including its leading
// ':' or '/'. Unknown mechanisms never match.
func (q *spfQuery) match(domain, name, arg string) (bool, *Error) {
	switch name {
	case "all":
		return true, nil
	case "ip4", "ip6":
		return q.matchLiteral(strings.TrimPrefix(arg, ":"))
	case "include":
		// include matches only when the included policy passes. Policy results
		// such as fail or softfail mean "no match" here, while evaluation
		// errors are propagated to the caller.
		result := q.fetch(strings.TrimPrefix(arg, ":"))
		if result != nil && result.Type() == ErrTypeInternal {
			return false, result
		}
		return result == nil, nil
	case "a":
		host, cidr := spfHostAndCIDR(arg, domain)
		ips, err := net.LookupIP(host)
		if err != nil {
			return false, spfLookupError(err)
		}
		return q.matchAny(ips, cidr)
	case "mx":
		host, cidr := spfHostAndCIDR(arg, domain)
		return q.matchMX(host, cidr)
	case "ptr":
		target, _ := spfHostAndCIDR(arg, domain)
		return q.matchPTR(target)
	case "exists":
		ips, err := net.LookupIP(strings.TrimPrefix(arg, ":"))
		if err != nil {
			return false, spfLookupError(err)
		}
		return len(ips) > 0, nil
	}
	return false, nil
}

// matchLiteral handles ip4/ip6: spec is either a single address or a CIDR
// network. Addresses are compared as parsed IPs, not as strings, so alternate
// textual spellings of the same address still match.
func (q *spfQuery) matchLiteral(spec string) (bool, *Error) {
	if strings.Contains(spec, "/") {
		_, network, err := net.ParseCIDR(spec)
		if err != nil {
			return false, &Error{error: err.Error(), errorType: ErrTypeInternal}
		}
		return q.client != nil && network.Contains(q.client), nil
	}
	addr := net.ParseIP(spec)
	// Both sides must parse: two nil IPs would otherwise compare as equal.
	return q.client != nil && addr != nil && addr.Equal(q.client), nil
}

// matchMX resolves the MX hosts of host and compares their addresses with the
// sender, honouring the optional CIDR suffix.
func (q *spfQuery) matchMX(host, cidr string) (bool, *Error) {
	mxs, err := net.LookupMX(host)
	if err != nil {
		return false, spfLookupError(err)
	}
	for _, mx := range mxs {
		ips, err := net.LookupIP(mx.Host)
		if err != nil {
			if failure := spfLookupError(err); failure != nil {
				return false, failure
			}
			continue // this exchanger has no address records
		}
		if matched, failure := q.matchAny(ips, cidr); matched || failure != nil {
			return matched, failure
		}
	}
	return false, nil
}

// matchPTR implements the ptr mechanism: a reverse name of the sender must be
// target or a subdomain of it, and must resolve back to the sender address.
func (q *spfQuery) matchPTR(target string) (bool, *Error) {
	if q.client == nil {
		return false, nil
	}
	names, err := net.LookupAddr(q.rawIP)
	if err != nil {
		return false, spfLookupError(err)
	}
	target = strings.ToLower(strings.TrimSuffix(target, "."))
	for _, name := range names {
		host := strings.ToLower(strings.TrimSuffix(name, "."))
		// Compare on label boundaries so "notexample.com" does not match
		// "example.com".
		if host != target && !strings.HasSuffix(host, "."+target) {
			continue
		}
		// Forward-confirm the name; reverse DNS alone is controlled by
		// whoever owns the sending address.
		ips, err := net.LookupIP(name)
		if err != nil {
			continue
		}
		for _, addr := range ips {
			if addr.Equal(q.client) {
				return true, nil
			}
		}
	}
	return false, nil
}

// matchAny reports whether the sender equals one of addrs or, when cidr
// supplies a prefix length for that address family, shares its network.
func (q *spfQuery) matchAny(addrs []net.IP, cidr string) (bool, *Error) {
	if q.client == nil {
		return false, nil
	}
	v4len, v6len := spfSplitCIDR(cidr)
	for _, addr := range addrs {
		prefix := v6len
		if addr.To4() != nil {
			prefix = v4len
		}
		if prefix == "" {
			if addr.Equal(q.client) {
				return true, nil
			}
			continue
		}
		_, network, err := net.ParseCIDR(addr.String() + "/" + prefix)
		if err != nil {
			return false, &Error{error: err.Error(), errorType: ErrTypeInternal}
		}
		if network.Contains(q.client) {
			return true, nil
		}
	}
	return false, nil
}

// spfSplitTerm splits a mechanism term into its qualifier ('+' when omitted),
// its lower-cased name, and the remainder starting at the first ':' or '/'.
func spfSplitTerm(term string) (qualifier byte, name, arg string) {
	qualifier = '+'
	if term != "" && strings.IndexByte("+-~?", term[0]) >= 0 {
		qualifier, term = term[0], term[1:]
	}
	if end := strings.IndexAny(term, ":/"); end >= 0 {
		return qualifier, strings.ToLower(term[:end]), term[end:]
	}
	return qualifier, strings.ToLower(term), ""
}

// spfHostAndCIDR interprets the argument of a, mx or ptr. A leading ':'
// introduces an explicit host; otherwise fallback is used. cidr is the
// trailing "/..." part including its slash, or "" when absent.
func spfHostAndCIDR(arg, fallback string) (host, cidr string) {
	if !strings.HasPrefix(arg, ":") {
		return fallback, arg
	}
	arg = arg[1:]
	if slash := strings.IndexByte(arg, '/'); slash >= 0 {
		return arg[:slash], arg[slash:]
	}
	return arg, ""
}

// spfSplitCIDR splits a dual CIDR suffix ("/24", "//64" or "/24//64") into
// the IPv4 and IPv6 prefix lengths, leaving an absent part as "".
func spfSplitCIDR(cidr string) (v4, v6 string) {
	if cidr == "" {
		return "", ""
	}
	if strings.HasPrefix(cidr, "//") {
		return "", cidr[2:]
	}
	cidr = cidr[1:]
	if i := strings.Index(cidr, "//"); i >= 0 {
		return cidr[:i], cidr[i+2:]
	}
	return cidr, ""
}

// spfResult maps a qualifier to the result returned when its mechanism
// matches; '+' (pass) is represented by nil.
func spfResult(qualifier byte) *Error {
	switch qualifier {
	case '-':
		return ErrFail
	case '~':
		return ErrSoftFail
	case '?':
		return ErrNeutral
	}
	return nil
}

// spfLookupError converts a DNS lookup error into an evaluation result. A
// "not found" answer yields nil, meaning the mechanism simply does not match;
// any other failure becomes an ErrTypeInternal error.
func spfLookupError(err error) *Error {
	// The net lookup functions report DNS failures as *net.DNSError.
	if dnsErr, ok := err.(*net.DNSError); ok && dnsErr.IsNotFound {
		return nil
	}
	return &Error{error: err.Error(), errorType: ErrTypeInternal}
}