diff --git a/httpserver/httpserver b/httpserver/httpserver index c8ecef6..15c82a5 100755 Binary files a/httpserver/httpserver and b/httpserver/httpserver differ diff --git a/httpserver/main.go b/httpserver/main.go index 7aa710d..2b56ac0 100644 --- a/httpserver/main.go +++ b/httpserver/main.go @@ -30,11 +30,12 @@ func handleErr(err error, errCode int) { func main() { if len(os.Args) < 2 { - err, errCode := httpserver.StartServer("8000", "./", "0.0.0.0", "1.0") + err, errCode := httpserver.StartServer("8000", "./", "0.0.0.0", "1.0", -1) handleErr(err, errCode) } else { args := os.Args[1:] protocolVer, path, address, port := "1.0", "./", "0.0.0.0", "8000" + throttleRate := int64(-1) for num, arg := range args { if strings.Contains(arg, "-") { if strings.Contains(arg, "-p") || strings.Contains(arg, "--protocol") { @@ -42,7 +43,7 @@ func main() { protocolVer = args[num+1] args = append(args[:num+1], args[num+2:]...) } else { - fmt.Println("usage: httpserver [-h] [--cgi] [-b ADDRESS] [-d DIRECTORY] [-p VERSION] [port]\nhttpserver: error: argument -p/--protocol: expected one argument") + fmt.Println("usage: httpserver [-h] [--cgi] [-b ADDRESS] [-d DIRECTORY] [-p VERSION] [-t RATE] [port]\nhttpserver: error: argument -p/--protocol: expected one argument") os.Exit(2) } } else if strings.Contains(arg, "-d") || strings.Contains(arg, "--directory") { @@ -50,7 +51,7 @@ func main() { path = args[num+1] args = append(args[:num+1], args[num+2:]...) } else { - fmt.Println("usage: httpserver [-h] [--cgi] [-b ADDRESS] [-d DIRECTORY] [-p VERSION] [port]\nhttpserver: error: argument -d/--directory: expected one argument") + fmt.Println("usage: httpserver [-h] [--cgi] [-b ADDRESS] [-d DIRECTORY] [-p VERSION] [-t RATE] [port]\nhttpserver: error: argument -d/--directory: expected one argument") os.Exit(2) } } else if strings.Contains(arg, "-b") || strings.Contains(arg, "--bind") { @@ -58,14 +59,26 @@ func main() { address = args[num+1] args = append(args[:num+1], args[num+2:]...) } else { - fmt.Println("usage: httpserver [-h] [--cgi] [-b ADDRESS] [-d DIRECTORY] [-p VERSION] [port]\nhttpserver: error: argument -b/--bind: expected one argument") + fmt.Println("usage: httpserver [-h] [--cgi] [-b ADDRESS] [-d DIRECTORY] [-p VERSION] [-t RATE] [port]\nhttpserver: error: argument -b/--bind: expected one argument") + os.Exit(2) + } + } else if strings.Contains(arg, "-t") || strings.Contains(arg, "--throttle") { + if len(args) > num+1 { + var err error + throttleRate, err = strconv.ParseInt(args[num+1], 10, 64) + if err != nil { + fmt.Println("usage: httpserver [-h] [-b ADDRESS] [-d DIRECTORY] [-p VERSION] [-t RATE] [port]\nhttpserver: error: argument -t/--throttle: invalid int value: '" + args[num+1] + "'") + os.Exit(2) + } + } else { + fmt.Println("usage: httpserver [-h] [-b ADDRESS] [-d DIRECTORY] [-p VERSION] [-t RATE] [port]\nhttpserver: error: argument -t/--throttle: expected one argument") os.Exit(2) } } else if strings.Contains(arg, "-h") || strings.Contains(arg, "--help") { - fmt.Println("usage: httpserver [-h] [-b ADDRESS] [-d DIRECTORY] [-p VERSION] [port]\n\npositional arguments:\n port bind to this port (default: 8000)\n\noptions:\n -h, --help show this help message and exit\n -b ADDRESS, --bind ADDRESS\n bind to this address (default: all interfaces)\n -d DIRECTORY, --directory DIRECTORY\n serve this directory (default: current directory)\n -p VERSION, --protocol VERSION\n conform to this HTTP version (default: HTTP/1.0)") + fmt.Println("usage: httpserver [-h] [-b ADDRESS] [-d DIRECTORY] [-p VERSION] [-t RATE] [port]\n\npositional arguments:\n port bind to this port (default: 8000)\n\noptions:\n -h, --help show this help message and exit\n -b ADDRESS, --bind ADDRESS\n bind to this address (default: all interfaces)\n -d DIRECTORY, --directory DIRECTORY\n serve this directory (default: current directory)\n -p VERSION, --protocol VERSION\n conform to this HTTP version (default: HTTP/1.0)") os.Exit(0) } else { - fmt.Println("usage: httpserver [-h] [-b ADDRESS] [-d DIRECTORY] [-p VERSION] [port]") + fmt.Println("usage: httpserver [-h] [-b ADDRESS] [-d DIRECTORY] [-p VERSION] [-t RATE] [port]") fmt.Println("httpserver: error: unrecognized arguments: " + arg) os.Exit(2) } @@ -74,7 +87,7 @@ func main() { if args[num] == arg { _, err := strconv.Atoi(arg) if err != nil { - fmt.Println("usage: httpserver [-h] [-b ADDRESS] [-d DIRECTORY] [-p VERSION] [port]") + fmt.Println("usage: httpserver [-h] [-b ADDRESS] [-d DIRECTORY] [-p VERSION] [-t RATE] [port]") fmt.Println("httpserver: error: argument port: invalid int value: '" + arg + "'") os.Exit(2) } else { @@ -84,7 +97,7 @@ func main() { } } } - err, errCode := httpserver.StartServer(port, path, address, protocolVer) + err, errCode := httpserver.StartServer(port, path, address, protocolVer, throttleRate) handleErr(err, errCode) } } diff --git a/main.go b/main.go index a34e4cb..9345378 100644 --- a/main.go +++ b/main.go @@ -14,7 +14,7 @@ import ( // 1234567. Very clever, Google. var timeLayout = "02/Jan/2006 15:04:05" -func StartServer(port string, path string, address string, protocolVer string) (error, int) { +func StartServer(port string, path string, address string, protocolVer string, throttleRate int64) (error, int) { var httpServer *http.Server addressPort := address + ":" + port fileServer := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -29,11 +29,18 @@ func StartServer(port string, path string, address string, protocolVer string) ( fmt.Println(ip + " - - [" + time.Now().Format(timeLayout) + "] \"" + r.Method + " " + r.URL.Path + " " + r.Proto + "\" " + "200" + " -") }) + var throttledFileServer http.Handler + if throttleRate != -1 { + throttledFileServer = ThrottleMiddleware(throttleRate)(fileServer) + } else { + throttledFileServer = fileServer + } + fmt.Println("Serving HTTP on", address, "port", port, "(http://"+address+":"+port+"/) ...") if protocolVer == "2.0" || protocolVer == "2" { - httpServer = &http.Server{Addr: addressPort, Handler: fileServer, TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler))} + httpServer = &http.Server{Addr: addressPort, Handler: throttledFileServer, TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler))} } else { - httpServer = &http.Server{Addr: addressPort, Handler: fileServer} + httpServer = &http.Server{Addr: addressPort, Handler: throttledFileServer} } err := httpServer.ListenAndServe() if err != nil { @@ -45,3 +52,59 @@ func StartServer(port string, path string, address string, protocolVer string) ( } return nil, 0 } + +func ThrottleMiddleware(rate int64, burstSize ...int64) func(http.Handler) http.Handler { + defaultBurstSize := int64(128) + if len(burstSize) > 0 && burstSize[0] > 0 { + defaultBurstSize = burstSize[0] + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tw := &ThrottledResponseWriter{ + writer: w, + rate: rate, + burstSize: defaultBurstSize, + } + next.ServeHTTP(tw, r) + }) + } +} + +type ThrottledResponseWriter struct { + writer http.ResponseWriter + rate int64 + burstSize int64 +} + +func (tw *ThrottledResponseWriter) Write(p []byte) (int, error) { + if len(p) == 0 { + return 0, nil + } + + var written int + for written < len(p) { + chunkSize := tw.burstSize + if int64(len(p)-written) < chunkSize { + chunkSize = int64(len(p) - written) + } + + n, err := tw.writer.Write(p[written : written+int(chunkSize)]) + if err != nil { + return written + n, err + } + written += n + + time.Sleep(time.Duration(chunkSize*8*int64(time.Second)/tw.rate) * time.Nanosecond) + } + + return written, nil +} + +func (tw *ThrottledResponseWriter) Header() http.Header { + return tw.writer.Header() +} + +func (tw *ThrottledResponseWriter) WriteHeader(statusCode int) { + tw.writer.WriteHeader(statusCode) +}