package spf import ( "net" "strings" ) type ErrorType int const ( ErrTypeNeutral ErrorType = iota ErrTypeNone ErrTypeFail ErrTypeSoftFail ErrTypeInternal ) var ( ErrNeutral = &Error{error: "record returned neutral", errorType: ErrTypeNeutral} ErrFail = &Error{error: "record returned explicit fail", errorType: ErrTypeFail} ErrSoftFail = &Error{error: "record returned soft fail", errorType: ErrTypeSoftFail} ErrNone = &Error{error: "no record returned", errorType: ErrTypeNone} ) type Error struct { error string errorType ErrorType } func (e *Error) Error() string { return e.error } func (e *Error) Type() ErrorType { return e.errorType } func CheckIP(ip string, domain string) *Error { txt, err := net.LookupTXT(domain) if err != nil { return ErrNone } for _, record := range txt { if strings.HasPrefix(record, "v=spf1") { parts := strings.Split(record, " ") for _, part := range parts { switch { case strings.HasPrefix(part, "all") || strings.HasPrefix(part, "+all"): return nil case strings.HasPrefix(part, "redirect="): err := CheckIP(ip, part[9:]) if err != nil { if err.Type() != ErrTypeInternal { return err } } else { return nil } case strings.HasPrefix(part, "ip4:") || strings.HasPrefix(part, "ip6:"): if strings.Contains(part, "/") { _, ipNet, err := net.ParseCIDR(part[4:]) if err != nil { return &Error{error: err.Error(), errorType: ErrTypeInternal} } if ipNet.Contains(net.ParseIP(ip)) { return nil } } else { if part[4:] == ip { return nil } } case strings.HasPrefix(part, "+ip4:") || strings.HasPrefix(part, "+ip6:"): if strings.Contains(part, "/") { _, ipNet, err := net.ParseCIDR(part[5:]) if err != nil { return &Error{error: err.Error(), errorType: ErrTypeInternal} } if ipNet.Contains(net.ParseIP(ip)) { return nil } } else { if part[5:] == ip { return nil } } case strings.HasPrefix(part, "-ip4:") || strings.HasPrefix(part, "-ip6:"): if strings.Contains(part, "/") { _, ipNet, err := net.ParseCIDR(part[5:]) if err != nil { return &Error{error: err.Error(), errorType: ErrTypeInternal} } if ipNet.Contains(net.ParseIP(ip)) { return ErrFail } } else { if part[5:] == ip { return ErrFail } } case strings.HasPrefix(part, "~ip4:") || strings.HasPrefix(part, "~ip6:"): if strings.Contains(part, "/") { _, ipNet, err := net.ParseCIDR(part[5:]) if err != nil { return &Error{error: err.Error(), errorType: ErrTypeInternal} } if ipNet.Contains(net.ParseIP(ip)) { return ErrSoftFail } } else { if part[5:] == ip { return ErrSoftFail } } case strings.HasPrefix(part, "?ip4:") || strings.HasPrefix(part, "?ip6:"): if strings.Contains(part, "/") { _, ipNet, err := net.ParseCIDR(part[5:]) if err != nil { return &Error{error: err.Error(), errorType: ErrTypeInternal} } if ipNet.Contains(net.ParseIP(ip)) { return ErrNeutral } } else { if part[5:] == ip { return ErrNeutral } } case strings.HasPrefix(part, "include:"): err := CheckIP(ip, part[8:]) if err != nil { if err.Type() != ErrTypeInternal { return err } } else { return nil } case strings.HasPrefix(part, "+include:"): err := CheckIP(ip, part[9:]) if err != nil { if err.Type() != ErrTypeInternal { return err } } else { return nil } case strings.HasPrefix(part, "-include:"): err := CheckIP(ip, part[9:]) if err == nil { return ErrFail } case strings.HasPrefix(part, "~include:"): err := CheckIP(ip, part[9:]) if err == nil { return ErrSoftFail } case strings.HasPrefix(part, "?include:"): err := CheckIP(ip, part[9:]) if err == nil { return ErrNeutral } case part == "a" || part == "+a": ips, err := net.LookupIP(domain) if err != nil { return &Error{error: err.Error(), errorType: ErrTypeInternal} } for _, ipCheck := range ips { if ipCheck.String() == ip { return nil } } case part == "-a": ips, err := net.LookupIP(domain) if err != nil { return &Error{error: err.Error(), errorType: ErrTypeInternal} } for _, ipCheck := range ips { if ipCheck.String() == ip { return ErrFail } } case part == "~a": ips, err := net.LookupIP(domain) if err != nil { return &Error{error: err.Error(), errorType: ErrTypeInternal} } for _, ipCheck := range ips { if ipCheck.String() == ip { return ErrSoftFail } } case part == "?a": ips, err := net.LookupIP(domain) if err != nil { return &Error{error: err.Error(), errorType: ErrTypeInternal} } for _, ipCheck := range ips { if ipCheck.String() == ip { return ErrNeutral } } case strings.HasPrefix(part, "a:"): ips, err := net.LookupIP(part[2:]) if err != nil { return &Error{error: err.Error(), errorType: ErrTypeInternal} } for _, ipCheck := range ips { if ipCheck.String() == ip { return nil } } case strings.HasPrefix(part, "+a:"): ips, err := net.LookupIP(part[3:]) if err != nil { return &Error{error: err.Error(), errorType: ErrTypeInternal} } for _, ipCheck := range ips { if ipCheck.String() == ip { return nil } } case strings.HasPrefix(part, "-a:"): ips, err := net.LookupIP(part[3:]) if err != nil { return &Error{error: err.Error(), errorType: ErrTypeInternal} } for _, ipCheck := range ips { if ipCheck.String() == ip { return ErrFail } } case strings.HasPrefix(part, "~a:"): ips, err := net.LookupIP(part[3:]) if err != nil { return &Error{error: err.Error(), errorType: ErrTypeInternal} } for _, ipCheck := range ips { if ipCheck.String() == ip { return ErrSoftFail } } case strings.HasPrefix(part, "?a:"): ips, err := net.LookupIP(part[3:]) if err != nil { return &Error{error: err.Error(), errorType: ErrTypeInternal} } for _, ipCheck := range ips { if ipCheck.String() == ip { return ErrNeutral } } case strings.HasPrefix(part, "mx") || strings.HasPrefix(part, "+mx"): mxs, err := net.LookupMX(domain) if err != nil { return &Error{error: err.Error(), errorType: ErrTypeInternal} } for _, mx := range mxs { ips, err := net.LookupIP(mx.Host) if err != nil { return &Error{error: err.Error(), errorType: ErrTypeInternal} } for _, ipCheck := range ips { if ipCheck.String() == ip { return nil } } } case strings.HasPrefix(part, "-mx"): mxs, err := net.LookupMX(domain) if err != nil { return &Error{error: err.Error(), errorType: ErrTypeInternal} } for _, mx := range mxs { ips, err := net.LookupIP(mx.Host) if err != nil { return &Error{error: err.Error(), errorType: ErrTypeInternal} } for _, ipCheck := range ips { if ipCheck.String() == ip { return ErrFail } } } case strings.HasPrefix(part, "~mx"): mxs, err := net.LookupMX(domain) if err != nil { return &Error{error: err.Error(), errorType: ErrTypeInternal} } for _, mx := range mxs { ips, err := net.LookupIP(mx.Host) if err != nil { return &Error{error: err.Error(), errorType: ErrTypeInternal} } for _, ipCheck := range ips { if ipCheck.String() == ip { return ErrSoftFail } } } case strings.HasPrefix(part, "?mx"): mxs, err := net.LookupMX(domain) if err != nil { return &Error{error: err.Error(), errorType: ErrTypeInternal} } for _, mx := range mxs { ips, err := net.LookupIP(mx.Host) if err != nil { return &Error{error: err.Error(), errorType: ErrTypeInternal} } for _, ipCheck := range ips { if ipCheck.String() == ip { return ErrNeutral } } } case strings.HasPrefix(part, "mx:"): mxs, err := net.LookupMX(part[3:]) if err != nil { return &Error{error: err.Error(), errorType: ErrTypeInternal} } for _, mx := range mxs { ips, err := net.LookupIP(mx.Host) if err != nil { return &Error{error: err.Error(), errorType: ErrTypeInternal} } for _, ipCheck := range ips { if ipCheck.String() == ip { return nil } } } case strings.HasPrefix(part, "+mx:"): mxs, err := net.LookupMX(part[4:]) if err != nil { return &Error{error: err.Error(), errorType: ErrTypeInternal} } for _, mx := range mxs { ips, err := net.LookupIP(mx.Host) if err != nil { return &Error{error: err.Error(), errorType: ErrTypeInternal} } for _, ipCheck := range ips { if ipCheck.String() == ip { return nil } } } case strings.HasPrefix(part, "-mx:"): mxs, err := net.LookupMX(part[4:]) if err != nil { return &Error{error: err.Error(), errorType: ErrTypeInternal} } for _, mx := range mxs { ips, err := net.LookupIP(mx.Host) if err != nil { return &Error{error: err.Error(), errorType: ErrTypeInternal} } for _, ipCheck := range ips { if ipCheck.String() == ip { return ErrFail } } } case strings.HasPrefix(part, "~mx:"): mxs, err := net.LookupMX(part[4:]) if err != nil { return &Error{error: err.Error(), errorType: ErrTypeInternal} } for _, mx := range mxs { ips, err := net.LookupIP(mx.Host) if err != nil { return &Error{error: err.Error(), errorType: ErrTypeInternal} } for _, ipCheck := range ips { if ipCheck.String() == ip { return ErrSoftFail } } } case strings.HasPrefix(part, "?mx:"): mxs, err := net.LookupMX(part[4:]) if err != nil { return &Error{error: err.Error(), errorType: ErrTypeInternal} } for _, mx := range mxs { ips, err := net.LookupIP(mx.Host) if err != nil { return &Error{error: err.Error(), errorType: ErrTypeInternal} } for _, ipCheck := range ips { if ipCheck.String() == ip { return ErrNeutral } } } case strings.HasPrefix(part, "ptr:"): names, err := net.LookupAddr(ip) if err != nil { return &Error{error: err.Error(), errorType: ErrTypeInternal} } for _, name := range names { if strings.HasSuffix(name, part[4:]+".") { return nil } } case strings.HasPrefix(part, "+ptr:"): names, err := net.LookupAddr(ip) if err != nil { return &Error{error: err.Error(), errorType: ErrTypeInternal} } for _, name := range names { if strings.HasSuffix(name, part[5:]+".") { return nil } } case strings.HasPrefix(part, "-ptr:"): names, err := net.LookupAddr(ip) if err != nil { return &Error{error: err.Error(), errorType: ErrTypeInternal} } for _, name := range names { if strings.HasSuffix(name, part[5:]+".") { return ErrFail } } case strings.HasPrefix(part, "~ptr:"): names, err := net.LookupAddr(ip) if err != nil { return &Error{error: err.Error(), errorType: ErrTypeInternal} } for _, name := range names { if strings.HasSuffix(name, part[5:]+".") { return ErrSoftFail } } case strings.HasPrefix(part, "?ptr:"): names, err := net.LookupAddr(ip) if err != nil { return &Error{error: err.Error(), errorType: ErrTypeInternal} } for _, name := range names { if strings.HasSuffix(name, part[5:]+".") { return ErrNeutral } } case strings.HasPrefix(part, "exists:"): ips, err := net.LookupIP(part[7:]) if err != nil { return &Error{error: err.Error(), errorType: ErrTypeInternal} } if len(ips) > 0 { return nil } case strings.HasPrefix(part, "+exists:"): ips, err := net.LookupIP(part[8:]) if err != nil { return &Error{error: err.Error(), errorType: ErrTypeInternal} } if len(ips) > 0 { return nil } case strings.HasPrefix(part, "-exists:"): ips, err := net.LookupIP(part[8:]) if err != nil { return &Error{error: err.Error(), errorType: ErrTypeInternal} } if len(ips) > 0 { return ErrFail } case strings.HasPrefix(part, "-all"): return ErrFail case strings.HasPrefix(part, "~all"): return ErrSoftFail case strings.HasPrefix(part, "?all"): return ErrNeutral } } } } return ErrFail }