From 754048e3d61a7034a78c70eec7fb3e367ed76953 Mon Sep 17 00:00:00 2001 From: Arzumify Date: Sun, 28 Apr 2024 21:24:50 +0100 Subject: [PATCH] More error handling --- .gitignore | 1 + main.go | 328 ++++++++++++++++++++++++++++++++++++++++++----------- schema.sql | 2 +- 3 files changed, 264 insertions(+), 67 deletions(-) diff --git a/.gitignore b/.gitignore index d18854c..1d824b7 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ config.ini database.db burgerauth +.idea diff --git a/main.go b/main.go index d474fe2..6bd28c1 100644 --- a/main.go +++ b/main.go @@ -6,6 +6,7 @@ import ( "database/sql" "encoding/base64" "encoding/hex" + "errors" "fmt" "os" "regexp" @@ -20,7 +21,7 @@ import ( "golang.org/x/crypto/scrypt" ) -const SALT_CHARS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" +const salt_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" func genSalt(length int) string { if length <= 0 { @@ -35,7 +36,7 @@ func genSalt(length int) string { } for i := range salt { - salt[i] = SALT_CHARS[int(randomBytes[i])%len(SALT_CHARS)] + salt[i] = salt_chars[int(randomBytes[i])%len(salt_chars)] } return string(salt) } @@ -76,12 +77,17 @@ func get_db_connection() *sql.DB { func get_user(id int) (string, string, string, bool) { conn := get_db_connection() - defer conn.Close() + defer func(conn *sql.DB) { + err := conn.Close() + if err != nil { + fmt.Println("[ERROR] Unknown in get_user() defer at", strconv.FormatInt(time.Now().Unix(), 10)+":", err) + } + }(conn) var created, username, password string err := conn.QueryRow("SELECT created, username, password FROM users WHERE id = ? LIMIT 1", id).Scan(&created, &username, &password) norows := false if err != nil { - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { norows = true } else { fmt.Println("[ERROR] Unknown in get_user() at", strconv.FormatInt(time.Now().Unix(), 10)+":", err) @@ -93,12 +99,17 @@ func get_user(id int) (string, string, string, bool) { func get_user_from_session(session string) (int, bool) { conn := get_db_connection() - defer conn.Close() + defer func(conn *sql.DB) { + err := conn.Close() + if err != nil { + fmt.Println("[ERROR] Unknown in get_user_from_session() defer at", strconv.FormatInt(time.Now().Unix(), 10)+":", err) + } + }(conn) var id int err := conn.QueryRow("SELECT id FROM sessions WHERE session = ? LIMIT 1", session).Scan(&id) norows := false if err != nil { - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { norows = true } else { fmt.Println("[ERROR] Unknown in get_user_from_session() at", strconv.FormatInt(time.Now().Unix(), 10)+":", err) @@ -110,12 +121,17 @@ func get_user_from_session(session string) (int, bool) { func check_username_taken(username string) (int, bool) { conn := get_db_connection() - defer conn.Close() + defer func(conn *sql.DB) { + err := conn.Close() + if err != nil { + fmt.Println("[ERROR] Unknown in check_username_taken() defer at", strconv.FormatInt(time.Now().Unix(), 10)+":", err) + } + }(conn) var id int err := conn.QueryRow("SELECT id FROM users WHERE lower(username) = ? LIMIT 1", username).Scan(&id) norows := false if err != nil { - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { norows = true } else { fmt.Println("[ERROR] Unknown in get_user() at", strconv.FormatInt(time.Now().Unix(), 10)+":", err) @@ -134,7 +150,11 @@ func init_db() { } else { fmt.Print("Proceeding will overwrite the database. Proceed? (y/n) ") var answer string - fmt.Scanln(&answer) + _, err := fmt.Scanln(&answer) + if err != nil { + fmt.Println("[ERROR] Unknown while scanning input at", strconv.FormatInt(time.Now().Unix(), 10)+":", err) + return + } if answer == "y" || answer == "Y" { if err := generateDB(); err != nil { fmt.Println("[ERROR] Unknown while generating database at", strconv.FormatInt(time.Now().Unix(), 10)+":", err) @@ -153,7 +173,12 @@ func generateDB() error { if err != nil { return err } - defer db.Close() + defer func(db *sql.DB) { + err := db.Close() + if err != nil { + fmt.Println("[ERROR] Unknown in generateDB() defer at", strconv.FormatInt(time.Now().Unix(), 10)+":", err) + } + }(db) schemaBytes, err := os.ReadFile("schema.sql") if err != nil { @@ -260,7 +285,11 @@ func main() { router.POST("/api/signup", func(c *gin.Context) { var data map[string]interface{} - c.ShouldBindJSON(&data) + err := c.ShouldBindJSON(&data) + if err != nil { + c.JSON(400, gin.H{"error": "Invalid JSON"}) + return + } username := data["username"].(string) password := data["password"].(string) @@ -280,23 +309,40 @@ func main() { hashedPassword := hash(password, genSalt(16)) conn := get_db_connection() - defer conn.Close() + defer func(conn *sql.DB) { + err := conn.Close() + if err != nil { + fmt.Println("[ERROR] Unknown in /api/signup defer at", strconv.FormatInt(time.Now().Unix(), 10)+":", err) + } + }(conn) - conn.Exec("INSERT INTO users (username, password, created, uniqueid) VALUES (?, ?, ?, ?)", username, hashedPassword, strconv.FormatInt(time.Now().Unix(), 10), genSalt(512)) + _, err = conn.Exec("INSERT INTO users (username, password, created, uniqueid) VALUES (?, ?, ?, ?)", username, hashedPassword, strconv.FormatInt(time.Now().Unix(), 10), genSalt(512)) + if err != nil { + fmt.Println("[ERROR] Unknown in /api/signup user creation at", strconv.FormatInt(time.Now().Unix(), 10)+":", err) + return + } fmt.Println("[INFO] Added new user at", time.Now().Unix()) userid, _ := check_username_taken(username) randomchars := genSalt(512) - conn.Exec("INSERT INTO sessions (session, id, device) VALUES (?, ?, ?)", randomchars, userid, c.Request.Header.Get("User-Agent")) + _, err = conn.Exec("INSERT INTO sessions (session, id, device) VALUES (?, ?, ?)", randomchars, userid, c.Request.Header.Get("User-Agent")) + if err != nil { + fmt.Println("[ERROR] Unknown in /api/signup session creation at", strconv.FormatInt(time.Now().Unix(), 10)+":", err) + return + } c.JSON(200, gin.H{"key": randomchars}) }) router.POST("/api/login", func(c *gin.Context) { var data map[string]interface{} - c.ShouldBindJSON(&data) + err := c.ShouldBindJSON(&data) + if err != nil { + c.JSON(400, gin.H{"error": "Invalid JSON"}) + return + } username := data["username"].(string) password := data["password"].(string) @@ -320,13 +366,26 @@ func main() { randomchars := genSalt(512) conn := get_db_connection() - defer conn.Close() + defer func(conn *sql.DB) { + err := conn.Close() + if err != nil { + fmt.Println("[ERROR] Unknown in /api/login defer at", strconv.FormatInt(time.Now().Unix(), 10)+":", err) + } + }(conn) - conn.Exec("INSERT INTO sessions (session, id, device) VALUES (?, ?, ?)", randomchars, userid, c.Request.Header.Get("User-Agent")) + _, err = conn.Exec("INSERT INTO sessions (session, id, device) VALUES (?, ?, ?)", randomchars, userid, c.Request.Header.Get("User-Agent")) + if err != nil { + fmt.Println("[ERROR] Unknown in /api/login session creation at", strconv.FormatInt(time.Now().Unix(), 10)+":", err) + return + } if passwordchange == "yes" { hashpassword := hash(newpass, "") - conn.Exec("UPDATE users SET password = ? WHERE username = ?", hashpassword, username) + _, err = conn.Exec("UPDATE users SET password = ? WHERE username = ?", hashpassword, username) + if err != nil { + fmt.Println("[ERROR] Unknown in /api/login password change at", strconv.FormatInt(time.Now().Unix(), 10)+":", err) + return + } } c.JSON(200, gin.H{"key": randomchars}) @@ -334,7 +393,11 @@ func main() { router.POST("/api/userinfo", func(c *gin.Context) { var data map[string]interface{} - c.ShouldBindJSON(&data) + err := c.ShouldBindJSON(&data) + if err != nil { + c.JSON(400, gin.H{"error": "Invalid JSON"}) + return + } secretkey := data["secretKey"].(string) @@ -359,14 +422,19 @@ func main() { token := strings.Fields(c.Request.Header["Authorization"][0])[1] conn := get_db_connection() - defer conn.Close() + defer func(conn *sql.DB) { + err := conn.Close() + if err != nil { + fmt.Println("[ERROR] Unknown in /userinfo defer at", strconv.FormatInt(time.Now().Unix(), 10)+":", err) + } + }(conn) var blacklisted bool err := conn.QueryRow("SELECT blacklisted FROM blacklist WHERE openid = ? LIMIT 1", token).Scan(&blacklisted) if err == nil { c.JSON(400, gin.H{"error": "Token is in blacklist"}) return } else { - if err != sql.ErrNoRows { + if !errors.Is(err, sql.ErrNoRows) { fmt.Println("[ERROR] Unknown in /userinfo blacklist at", strconv.FormatInt(time.Now().Unix(), 10)+":", err) return } @@ -417,19 +485,28 @@ func main() { router.POST("/api/uniqueid", func(c *gin.Context) { var data map[string]interface{} - c.ShouldBindJSON(&data) + err := c.ShouldBindJSON(&data) + if err != nil { + c.JSON(400, gin.H{"error": "Invalid JSON"}) + return + } token := data["access_token"].(string) conn := get_db_connection() - defer conn.Close() + defer func(conn *sql.DB) { + err := conn.Close() + if err != nil { + fmt.Println("[ERROR] Unknown in /api/uniqueid defer at", strconv.FormatInt(time.Now().Unix(), 10)+":", err) + } + }(conn) var blacklisted bool - err := conn.QueryRow("SELECT blacklisted FROM blacklist WHERE token = ? LIMIT 1", token).Scan(&blacklisted) + err = conn.QueryRow("SELECT blacklisted FROM blacklist WHERE token = ? LIMIT 1", token).Scan(&blacklisted) if err == nil { c.JSON(400, gin.H{"error": "Token is in blacklist"}) return } else { - if err != sql.ErrNoRows { + if !errors.Is(err, sql.ErrNoRows) { fmt.Println("[ERROR] Unknown in /api/uniqueid blacklist at", strconv.FormatInt(time.Now().Unix(), 10)+":", err) return } @@ -471,7 +548,7 @@ func main() { var uniqueid string err = conn.QueryRow("SELECT uniqueid FROM users WHERE id = ? LIMIT 1", userid).Scan(&uniqueid) if err != nil { - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { c.JSON(400, gin.H{"error": "User does not exist"}) return } else { @@ -509,7 +586,15 @@ func main() { var appidcheck, rdiruricheck string - conn.QueryRow("SELECT appId, rdiruri FROM oauth WHERE appId = ? LIMIT 1", appId).Scan(&appidcheck, &rdiruricheck) + err := conn.QueryRow("SELECT appId, rdiruri FROM oauth WHERE appId = ? LIMIT 1", appId).Scan(&appidcheck, &rdiruricheck) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + c.String(401, "OAuth screening failed") + } else { + fmt.Println("[ERROR] Unknown in /api/auth at", strconv.FormatInt(time.Now().Unix(), 10)+":", err) + } + return + } if !(rdiruricheck == redirect_uri) { c.String(401, "Redirect URI does not match") @@ -544,18 +629,26 @@ func main() { secret_token, _ := jwt.NewWithClaims(jwt.SigningMethodHS256, datatemplate2).SignedString([]byte(SECRET_KEY)) randombytes := genSalt(512) - conn.Exec("INSERT INTO logins (appId, secret, nextsecret, code, nextcode, creator, openid, nextopenid, pkce, pkcemethod) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", appId, randombytes, "none", secret_token, "none", userid, jwt_token, "none", code, codemethod) + _, err = conn.Exec("INSERT INTO logins (appId, secret, nextsecret, code, nextcode, creator, openid, nextopenid, pkce, pkcemethod) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", appId, randombytes, "none", secret_token, "none", userid, jwt_token, "none", code, codemethod) + if err != nil { + fmt.Println("[ERROR] Unknown in /api/auth insert at", strconv.FormatInt(time.Now().Unix(), 10)+":", err) + return + } if randombytes != "" { c.Redirect(302, redirect_uri+"?code="+randombytes+"&state="+state) } else { c.String(500, "Something went wrong on our end. Please report this bug at https://centrifuge.hectabit.org/hectabit/burgerauth and refer to the docs for more detail. Include this error code: Secretkey not found.") - fmt.Println("[ERROR] Secretkey not found at", time.Now().Unix()) + fmt.Println("[ERROR] Secretkey not found at", strconv.FormatInt(time.Now().Unix(), 10)) } }) router.POST("/api/tokenauth", func(c *gin.Context) { - c.Request.ParseForm() + err := c.Request.ParseForm() + if err != nil { + c.JSON(400, gin.H{"error": "Invalid form data"}) + return + } data := c.Request.Form appId := data.Get("client_id") @@ -571,11 +664,24 @@ func main() { } conn := get_db_connection() - defer conn.Close() + defer func(conn *sql.DB) { + err := conn.Close() + if err != nil { + fmt.Println("[ERROR] Unknown in /api/tokenauth defer at", strconv.FormatInt(time.Now().Unix(), 10)+":", err) + } + }(conn) var appidcheck, secretcheck, openid, logincode, pkce, pkcemethod string - conn.QueryRow("SELECT o.appId, o.secret, l.openid, l.code, l.pkce, l.pkcemethod FROM oauth AS o JOIN logins AS l ON o.appId = l.appId WHERE o.appId = ? AND l.secret = ? LIMIT 1;", appId, code).Scan(&appidcheck, &secretcheck, &openid, &logincode, &pkce, &pkcemethod) + err = conn.QueryRow("SELECT o.appId, o.secret, l.openid, l.code, l.pkce, l.pkcemethod FROM oauth AS o JOIN logins AS l ON o.appId = l.appId WHERE o.appId = ? AND l.secret = ? LIMIT 1;", appId, code).Scan(&appidcheck, &secretcheck, &openid, &logincode, &pkce, &pkcemethod) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + c.JSON(401, gin.H{"error": "OAuth screening failed"}) + } else { + fmt.Println("[ERROR] Unknown in /api/tokenauth at", strconv.FormatInt(time.Now().Unix(), 10)+":", err) + } + return + } if appidcheck != appId { c.JSON(401, gin.H{"error": "OAuth screening failed"}) return @@ -608,14 +714,22 @@ func main() { } } - conn.Exec("DELETE FROM logins WHERE code = ?", logincode) + _, err = conn.Exec("DELETE FROM logins WHERE code = ?", logincode) + if err != nil { + fmt.Println("[ERROR] Unknown in /api/tokenauth delete at", strconv.FormatInt(time.Now().Unix(), 10)+":", err) + return + } c.JSON(200, gin.H{"access_token": logincode, "token_type": "bearer", "expires_in": 2592000, "id_token": openid}) }) router.POST("/api/deleteauth", func(c *gin.Context) { var data map[string]interface{} - c.ShouldBindJSON(&data) + err := c.ShouldBindJSON(&data) + if err != nil { + c.JSON(400, gin.H{"error": "Invalid JSON"}) + return + } secretKey := data["secretKey"].(string) appId := data["appId"].(string) @@ -627,10 +741,15 @@ func main() { } conn := get_db_connection() - defer conn.Close() - _, err := conn.Exec("DELETE FROM oauth WHERE appId = ? AND creator = ?", appId, id) + defer func(conn *sql.DB) { + err := conn.Close() + if err != nil { + fmt.Println("[ERROR] Unknown in /api/deleteauth defer at", strconv.FormatInt(time.Now().Unix(), 10)+":", err) + } + }(conn) + _, err = conn.Exec("DELETE FROM oauth WHERE appId = ? AND creator = ?", appId, id) if err != nil { - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { c.JSON(400, gin.H{"error": "AppID Not found"}) } else { fmt.Println("[ERROR] Unknown in /api/deleteauth at", strconv.FormatInt(time.Now().Unix(), 10)+":", err) @@ -643,7 +762,11 @@ func main() { router.POST("/api/newauth", func(c *gin.Context) { var data map[string]interface{} - c.ShouldBindJSON(&data) + err := c.ShouldBindJSON(&data) + if err != nil { + c.JSON(400, gin.H{"error": "Invalid JSON"}) + return + } secretKey := data["secretKey"].(string) appId := data["appId"].(string) @@ -658,12 +781,17 @@ func main() { var testsecret string secret := genSalt(512) conn := get_db_connection() - defer conn.Close() + defer func(conn *sql.DB) { + err := conn.Close() + if err != nil { + fmt.Println("[ERROR] Unknown in /api/newauth defer at", strconv.FormatInt(time.Now().Unix(), 10)+":", err) + } + }(conn) for { err := conn.QueryRow("SELECT secret FROM oauth WHERE secret = ?", secret).Scan(&testsecret) if err != nil { - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { break } else { fmt.Println("[ERROR] Unknown in /api/newauth secretselect at", strconv.FormatInt(time.Now().Unix(), 10)+":", err) @@ -675,9 +803,9 @@ func main() { } } - _, err := conn.Exec("SELECT secret FROM oauth WHERE appId = ?", appId) + _, err = conn.Exec("SELECT secret FROM oauth WHERE appId = ?", appId) if err != nil { - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { fmt.Println("[Info] New Oauth source added with ID:", appId) } else { fmt.Println("[ERROR] Unknown in /api/newauth at", strconv.FormatInt(time.Now().Unix(), 10)+":", err) @@ -688,14 +816,22 @@ func main() { secret = genSalt(512) } - conn.Exec("INSERT INTO oauth (appId, creator, secret, rdiruri) VALUES (?, ?, ?, ?)", appId, id, secret, rdiruri) + _, err = conn.Exec("INSERT INTO oauth (appId, creator, secret, rdiruri) VALUES (?, ?, ?, ?)", appId, id, secret, rdiruri) + if err != nil { + fmt.Println("[ERROR] Unknown in /api/newauth insert at", strconv.FormatInt(time.Now().Unix(), 10)+":", err) + return + } c.JSON(200, gin.H{"key": secret}) }) router.POST("/api/listauth", func(c *gin.Context) { var data map[string]interface{} - c.ShouldBindJSON(&data) + err := c.ShouldBindJSON(&data) + if err != nil { + c.JSON(400, gin.H{"error": "Invalid JSON"}) + return + } secretKey := data["secretKey"].(string) @@ -706,14 +842,24 @@ func main() { } conn := get_db_connection() - defer conn.Close() + defer func(conn *sql.DB) { + err := conn.Close() + if err != nil { + fmt.Println("[ERROR] Unknown in /api/listauth defer at", strconv.FormatInt(time.Now().Unix(), 10)+":", err) + } + }(conn) rows, err := conn.Query("SELECT appId FROM oauth WHERE creator = ? ORDER BY creator DESC", id) if err != nil { c.JSON(500, gin.H{"error": "Failed to query database"}) return } - defer rows.Close() + defer func(rows *sql.Rows) { + err := rows.Close() + if err != nil { + fmt.Println("[ERROR] Unknown in /api/listauth rows close at", strconv.FormatInt(time.Now().Unix(), 10)+":", err) + } + }(rows) var datatemplate []map[string]interface{} for rows.Next() { @@ -735,7 +881,11 @@ func main() { router.POST("/api/deleteaccount", func(c *gin.Context) { var data map[string]interface{} - c.ShouldBindJSON(&data) + err := c.ShouldBindJSON(&data) + if err != nil { + c.JSON(400, gin.H{"error": "Invalid JSON"}) + return + } secretKey := data["secretKey"].(string) @@ -746,11 +896,16 @@ func main() { } conn := get_db_connection() - defer conn.Close() + defer func(conn *sql.DB) { + err := conn.Close() + if err != nil { + fmt.Println("[ERROR] Unknown in /api/deleteaccount defer at", strconv.FormatInt(time.Now().Unix(), 10)+":", err) + } + }(conn) - _, err := conn.Exec("DELETE FROM userdata WHERE AND creator = ?", id) + _, err = conn.Exec("DELETE FROM userdata WHERE creator = ?", id) if err != nil { - if err != sql.ErrNoRows { + if !errors.Is(err, sql.ErrNoRows) { fmt.Println("[ERROR] Unknown in /api/deleteuser userdata at", strconv.FormatInt(time.Now().Unix(), 10)+":", err) c.JSON(500, gin.H{"error": "Unknown error occured"}) } @@ -758,7 +913,7 @@ func main() { _, err = conn.Exec("DELETE FROM logins WHERE creator = ?", id) if err != nil { - if err != sql.ErrNoRows { + if !errors.Is(err, sql.ErrNoRows) { fmt.Println("[ERROR] Unknown in /api/deleteuser logins at", strconv.FormatInt(time.Now().Unix(), 10)+":", err) c.JSON(500, gin.H{"error": "Unknown error occured"}) } @@ -766,7 +921,7 @@ func main() { _, err = conn.Exec("DELETE FROM oauth WHERE creator = ?", id) if err != nil { - if err != sql.ErrNoRows { + if !errors.Is(err, sql.ErrNoRows) { fmt.Println("[ERROR] Unknown in /api/deleteuser oauth at", strconv.FormatInt(time.Now().Unix(), 10)+":", err) c.JSON(500, gin.H{"error": "Unknown error occured"}) } @@ -774,7 +929,7 @@ func main() { _, err = conn.Exec("DELETE FROM users WHERE id = ?", id) if err != nil { - if err != sql.ErrNoRows { + if !errors.Is(err, sql.ErrNoRows) { fmt.Println("[ERROR] Unknown in /api/deleteuser logins at", strconv.FormatInt(time.Now().Unix(), 10)+":", err) c.JSON(500, gin.H{"error": "Unknown error occured"}) } @@ -785,7 +940,11 @@ func main() { router.POST("/api/sessions/list", func(c *gin.Context) { var data map[string]interface{} - c.ShouldBindJSON(&data) + err := c.ShouldBindJSON(&data) + if err != nil { + c.JSON(400, gin.H{"error": "Invalid JSON"}) + return + } secretKey := data["secretKey"].(string) @@ -796,16 +955,26 @@ func main() { } conn := get_db_connection() - defer conn.Close() + defer func(conn *sql.DB) { + err := conn.Close() + if err != nil { + fmt.Println("[ERROR] Unknown in /api/sessions/list defer at", strconv.FormatInt(time.Now().Unix(), 10)+":", err) + } + }(conn) rows, err := conn.Query("SELECT sessionid, session, device FROM sessions WHERE id = ? ORDER BY id DESC", id) if err != nil { - if err != sql.ErrNoRows { + if !errors.Is(err, sql.ErrNoRows) { fmt.Println("[ERROR] Unknown in /api/sessions/list at", strconv.FormatInt(time.Now().Unix(), 10)+":", err) c.JSON(500, gin.H{"error": "Unknown error occured"}) } } - defer rows.Close() + defer func(rows *sql.Rows) { + err := rows.Close() + if err != nil { + fmt.Println("[ERROR] Unknown in /api/sessions/list rows close at", strconv.FormatInt(time.Now().Unix(), 10)+":", err) + } + }(rows) var datatemplate []map[string]interface{} for rows.Next() { @@ -831,7 +1000,11 @@ func main() { router.POST("/api/sessions/remove", func(c *gin.Context) { var data map[string]interface{} - c.ShouldBindJSON(&data) + err := c.ShouldBindJSON(&data) + if err != nil { + c.JSON(400, gin.H{"error": "Invalid JSON"}) + return + } secretKey := data["secretKey"].(string) sessionId := data["sessionId"].(string) @@ -843,10 +1016,15 @@ func main() { } conn := get_db_connection() - defer conn.Close() - _, err := conn.Exec("DELETE FROM sessions WHERE sessionid = ? AND id = ?", sessionId, id) + defer func(conn *sql.DB) { + err := conn.Close() + if err != nil { + fmt.Println("[ERROR] Unknown in /api/sessions/remove defer at", strconv.FormatInt(time.Now().Unix(), 10)+":", err) + } + }(conn) + _, err = conn.Exec("DELETE FROM sessions WHERE sessionid = ? AND id = ?", sessionId, id) if err != nil { - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { c.JSON(422, gin.H{"error": "SessionID Not found"}) } else { fmt.Println("[ERROR] Unknown in /api/sessions/remove at", strconv.FormatInt(time.Now().Unix(), 10)+":", err) @@ -859,22 +1037,36 @@ func main() { router.POST("/api/listusers", func(c *gin.Context) { var data map[string]interface{} - c.ShouldBindJSON(&data) + err := c.ShouldBindJSON(&data) + if err != nil { + c.JSON(400, gin.H{"error": "Invalid JSON"}) + return + } masterkey := data["masterkey"].(string) if masterkey == SECRET_KEY { conn := get_db_connection() - defer conn.Close() + defer func(conn *sql.DB) { + err := conn.Close() + if err != nil { + fmt.Println("[ERROR] Unknown in /api/listusers defer at", strconv.FormatInt(time.Now().Unix(), 10)+":", err) + } + }(conn) rows, err := conn.Query("SELECT * FROM users ORDER BY id DESC") if err != nil { - if err != sql.ErrNoRows { + if !errors.Is(err, sql.ErrNoRows) { fmt.Println("[ERROR] Unknown in /api/listusers at", strconv.FormatInt(time.Now().Unix(), 10)+":", err) c.JSON(500, gin.H{"error": "Unknown error occured"}) } } - defer rows.Close() + defer func(rows *sql.Rows) { + err := rows.Close() + if err != nil { + fmt.Println("[ERROR] Unknown in /api/listusers rows close at", strconv.FormatInt(time.Now().Unix(), 10)+":", err) + } + }(rows) var datatemplate []map[string]interface{} for rows.Next() { @@ -897,5 +1089,9 @@ func main() { fmt.Println("[INFO] Server started at", time.Now().Unix()) fmt.Println("[INFO] Welcome to Burgerauth! Today we are running on IP " + HOST + " on port " + strconv.Itoa(PORT) + ".") - router.Run(HOST + ":" + strconv.Itoa(PORT)) + err := router.Run(HOST + ":" + strconv.Itoa(PORT)) + if err != nil { + fmt.Println("[FATAL] Server failed to start at", time.Now().Unix(), err) + return + } } diff --git a/schema.sql b/schema.sql index 38475d9..34a606c 100644 --- a/schema.sql +++ b/schema.sql @@ -23,7 +23,7 @@ CREATE TABLE sessions ( sessionid INTEGER PRIMARY KEY AUTOINCREMENT, session TEXT NOT NULL, id INTEGER NOT NULL, - device TEXT NOT NULL DEFAULT "?" + device TEXT NOT NULL DEFAULT '?' ); CREATE TABLE logins (