369 lines
9.7 KiB
Go
369 lines
9.7 KiB
Go
|
|
package common
|
||
|
|
|
||
|
|
import (
|
||
|
|
"bytes"
|
||
|
|
"errors"
|
||
|
|
"fmt"
|
||
|
|
"net/http"
|
||
|
|
"net/http/httptest"
|
||
|
|
"os"
|
||
|
|
"testing"
|
||
|
|
|
||
|
|
"github.com/gin-gonic/gin"
|
||
|
|
"github.com/stretchr/testify/assert"
|
||
|
|
)
|
||
|
|
|
||
|
|
func TestConnectingDatabase(t *testing.T) {
|
||
|
|
asserts := assert.New(t)
|
||
|
|
db := Init()
|
||
|
|
dbPath := GetDBPath()
|
||
|
|
// Test create & close DB
|
||
|
|
_, err := os.Stat(dbPath)
|
||
|
|
asserts.NoError(err, "Db should exist")
|
||
|
|
sqlDB, err := db.DB()
|
||
|
|
asserts.NoError(err, "Should get sql.DB")
|
||
|
|
asserts.NoError(sqlDB.Ping(), "Db should be able to ping")
|
||
|
|
|
||
|
|
// Test get a connecting from connection pools
|
||
|
|
connection := GetDB()
|
||
|
|
sqlDB, err = connection.DB()
|
||
|
|
asserts.NoError(err, "Should get sql.DB")
|
||
|
|
asserts.NoError(sqlDB.Ping(), "Db should be able to ping")
|
||
|
|
sqlDB.Close()
|
||
|
|
|
||
|
|
// Test DB exceptions
|
||
|
|
os.Chmod(dbPath, 0000)
|
||
|
|
db = Init()
|
||
|
|
sqlDB, err = db.DB()
|
||
|
|
asserts.NoError(err, "Should get sql.DB")
|
||
|
|
asserts.Error(sqlDB.Ping(), "Db should not be able to ping")
|
||
|
|
sqlDB.Close()
|
||
|
|
os.Chmod(dbPath, 0644)
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestConnectingTestDatabase(t *testing.T) {
|
||
|
|
asserts := assert.New(t)
|
||
|
|
// Test create & close DB
|
||
|
|
db := TestDBInit()
|
||
|
|
testDBPath := GetTestDBPath()
|
||
|
|
_, err := os.Stat(testDBPath)
|
||
|
|
asserts.NoError(err, "Db should exist")
|
||
|
|
sqlDB, err := db.DB()
|
||
|
|
asserts.NoError(err, "Should get sql.DB")
|
||
|
|
asserts.NoError(sqlDB.Ping(), "Db should be able to ping")
|
||
|
|
TestDBFree(db)
|
||
|
|
|
||
|
|
// Test close delete DB
|
||
|
|
db = TestDBInit()
|
||
|
|
TestDBFree(db)
|
||
|
|
_, err = os.Stat(testDBPath)
|
||
|
|
|
||
|
|
asserts.Error(err, "Db should not exist")
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestDBDirCreation(t *testing.T) {
|
||
|
|
asserts := assert.New(t)
|
||
|
|
// Set a nested path
|
||
|
|
os.Setenv("TEST_DB_PATH", "tmp/nested/test.db")
|
||
|
|
defer os.Unsetenv("TEST_DB_PATH")
|
||
|
|
|
||
|
|
db := TestDBInit()
|
||
|
|
testDBPath := GetTestDBPath()
|
||
|
|
_, err := os.Stat(testDBPath)
|
||
|
|
asserts.NoError(err, "Db should exist in nested directory")
|
||
|
|
TestDBFree(db)
|
||
|
|
|
||
|
|
// Cleanup directory
|
||
|
|
os.RemoveAll("tmp/nested")
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestDBPathOverride(t *testing.T) {
|
||
|
|
asserts := assert.New(t)
|
||
|
|
customPath := "./custom_test.db"
|
||
|
|
os.Setenv("TEST_DB_PATH", customPath)
|
||
|
|
defer os.Unsetenv("TEST_DB_PATH")
|
||
|
|
|
||
|
|
asserts.Equal(customPath, GetTestDBPath(), "Should use env var")
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestRandString(t *testing.T) {
|
||
|
|
asserts := assert.New(t)
|
||
|
|
|
||
|
|
var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
|
||
|
|
str := RandString(0)
|
||
|
|
asserts.Equal(str, "", "length should be ''")
|
||
|
|
|
||
|
|
str = RandString(10)
|
||
|
|
asserts.Equal(len(str), 10, "length should be 10")
|
||
|
|
for _, ch := range str {
|
||
|
|
asserts.Contains(letters, ch, "char should be a-z|A-Z|0-9")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestRandInt(t *testing.T) {
|
||
|
|
asserts := assert.New(t)
|
||
|
|
|
||
|
|
// Test that RandInt returns a value in valid range
|
||
|
|
val := RandInt()
|
||
|
|
asserts.GreaterOrEqual(val, 0, "RandInt should be >= 0")
|
||
|
|
asserts.Less(val, 1000000, "RandInt should be < 1000000")
|
||
|
|
|
||
|
|
// Test multiple calls return different values (statistically)
|
||
|
|
vals := make(map[int]bool)
|
||
|
|
for i := 0; i < 10; i++ {
|
||
|
|
vals[RandInt()] = true
|
||
|
|
}
|
||
|
|
asserts.Greater(len(vals), 1, "RandInt should return varied values")
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestGenToken(t *testing.T) {
|
||
|
|
asserts := assert.New(t)
|
||
|
|
|
||
|
|
token := GenToken(2)
|
||
|
|
|
||
|
|
asserts.IsType(token, string("token"), "token type should be string")
|
||
|
|
asserts.Len(token, 115, "JWT's length should be 115")
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestGenTokenMultipleUsers(t *testing.T) {
|
||
|
|
asserts := assert.New(t)
|
||
|
|
|
||
|
|
token1 := GenToken(1)
|
||
|
|
token2 := GenToken(2)
|
||
|
|
token100 := GenToken(100)
|
||
|
|
|
||
|
|
asserts.NotEqual(token1, token2, "Different user IDs should generate different tokens")
|
||
|
|
asserts.NotEqual(token2, token100, "Different user IDs should generate different tokens")
|
||
|
|
// Token length can vary by 1 character due to timestamp changes
|
||
|
|
asserts.GreaterOrEqual(len(token1), 114, "JWT's length should be >= 114 for user 1")
|
||
|
|
asserts.LessOrEqual(len(token1), 120, "JWT's length should be <= 120 for user 1")
|
||
|
|
asserts.GreaterOrEqual(len(token100), 114, "JWT's length should be >= 114 for user 100")
|
||
|
|
asserts.LessOrEqual(len(token100), 120, "JWT's length should be <= 120 for user 100")
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHeaderTokenMock(t *testing.T) {
|
||
|
|
asserts := assert.New(t)
|
||
|
|
|
||
|
|
req, _ := http.NewRequest("GET", "/test", nil)
|
||
|
|
token := GenToken(5)
|
||
|
|
HeaderTokenMock(req, 5)
|
||
|
|
|
||
|
|
authHeader := req.Header.Get("Authorization")
|
||
|
|
asserts.Equal(fmt.Sprintf("Token %s", token), authHeader, "Authorization header should be set correctly")
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestExtractTokenFromHeader(t *testing.T) {
|
||
|
|
asserts := assert.New(t)
|
||
|
|
|
||
|
|
token := "valid.jwt.token"
|
||
|
|
header := fmt.Sprintf("Token %s", token)
|
||
|
|
|
||
|
|
extracted := ExtractTokenFromHeader(header)
|
||
|
|
asserts.Equal(token, extracted, "Should extract token from header")
|
||
|
|
|
||
|
|
invalidHeader := "Bearer " + token
|
||
|
|
extracted = ExtractTokenFromHeader(invalidHeader)
|
||
|
|
asserts.Empty(extracted, "Should return empty for non-Token header")
|
||
|
|
|
||
|
|
shortHeader := "Token"
|
||
|
|
extracted = ExtractTokenFromHeader(shortHeader)
|
||
|
|
asserts.Empty(extracted, "Should return empty for short header")
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestVerifyTokenClaims(t *testing.T) {
|
||
|
|
asserts := assert.New(t)
|
||
|
|
|
||
|
|
// Test valid token
|
||
|
|
userID := uint(123)
|
||
|
|
token := GenToken(userID)
|
||
|
|
claims, err := VerifyTokenClaims(token)
|
||
|
|
asserts.NoError(err, "VerifyTokenClaims should not error for valid token")
|
||
|
|
asserts.Equal(float64(userID), claims["id"], "Claims should contain correct user ID")
|
||
|
|
|
||
|
|
// Test invalid token
|
||
|
|
_, err = VerifyTokenClaims("invalid.token.string")
|
||
|
|
asserts.Error(err, "VerifyTokenClaims should error for invalid token")
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestNewValidatorError(t *testing.T) {
|
||
|
|
asserts := assert.New(t)
|
||
|
|
|
||
|
|
type Login struct {
|
||
|
|
Username string `form:"username" json:"username" binding:"required,alphanum,min=4,max=255"`
|
||
|
|
Password string `form:"password" json:"password" binding:"required,min=8,max=255"`
|
||
|
|
}
|
||
|
|
|
||
|
|
var requestTests = []struct {
|
||
|
|
bodyData string
|
||
|
|
expectedCode int
|
||
|
|
responseRegexg string
|
||
|
|
msg string
|
||
|
|
}{
|
||
|
|
{
|
||
|
|
`{"username": "wangzitian0","password": "0123456789"}`,
|
||
|
|
http.StatusOK,
|
||
|
|
`{"status":"you are logged in"}`,
|
||
|
|
"valid data and should return StatusCreated",
|
||
|
|
},
|
||
|
|
{
|
||
|
|
`{"username": "wangzitian0","password": "01234567866"}`,
|
||
|
|
http.StatusUnauthorized,
|
||
|
|
`{"errors":{"user":"wrong username or password"}}`,
|
||
|
|
"wrong login status should return StatusUnauthorized",
|
||
|
|
},
|
||
|
|
{
|
||
|
|
`{"username": "wangzitian0","password": "0122"}`,
|
||
|
|
http.StatusUnprocessableEntity,
|
||
|
|
`{"errors":{"Password":"{min: 8}"}}`,
|
||
|
|
"invalid password of too short and should return StatusUnprocessableEntity",
|
||
|
|
},
|
||
|
|
{
|
||
|
|
`{"username": "_wangzitian0","password": "0123456789"}`,
|
||
|
|
http.StatusUnprocessableEntity,
|
||
|
|
`{"errors":{"Username":"{key: alphanum}"}}`,
|
||
|
|
"invalid username of non alphanum and should return StatusUnprocessableEntity",
|
||
|
|
},
|
||
|
|
}
|
||
|
|
|
||
|
|
r := gin.Default()
|
||
|
|
|
||
|
|
r.POST("/login", func(c *gin.Context) {
|
||
|
|
var json Login
|
||
|
|
if err := Bind(c, &json); err == nil {
|
||
|
|
if json.Username == "wangzitian0" && json.Password == "0123456789" {
|
||
|
|
c.JSON(http.StatusOK, gin.H{"status": "you are logged in"})
|
||
|
|
} else {
|
||
|
|
c.JSON(http.StatusUnauthorized, NewError("user", errors.New("wrong username or password")))
|
||
|
|
}
|
||
|
|
} else {
|
||
|
|
c.JSON(http.StatusUnprocessableEntity, NewValidatorError(err))
|
||
|
|
}
|
||
|
|
})
|
||
|
|
|
||
|
|
for _, testData := range requestTests {
|
||
|
|
bodyData := testData.bodyData
|
||
|
|
req, err := http.NewRequest("POST", "/login", bytes.NewBufferString(bodyData))
|
||
|
|
req.Header.Set("Content-Type", "application/json")
|
||
|
|
asserts.NoError(err)
|
||
|
|
|
||
|
|
w := httptest.NewRecorder()
|
||
|
|
r.ServeHTTP(w, req)
|
||
|
|
|
||
|
|
asserts.Equal(testData.expectedCode, w.Code, "Response Status - "+testData.msg)
|
||
|
|
asserts.Regexp(testData.responseRegexg, w.Body.String(), "Response Content - "+testData.msg)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestNewError(t *testing.T) {
|
||
|
|
assert := assert.New(t)
|
||
|
|
|
||
|
|
db := TestDBInit()
|
||
|
|
defer TestDBFree(db)
|
||
|
|
|
||
|
|
type NonExistentTable struct {
|
||
|
|
Field string
|
||
|
|
}
|
||
|
|
// db.AutoMigrate(NonExistentTable{}) // Intentionally skipped to cause error
|
||
|
|
|
||
|
|
err := db.Find(&NonExistentTable{Field: "value"}).Error
|
||
|
|
if err == nil {
|
||
|
|
err = errors.New("no such table: non_existent_tables")
|
||
|
|
}
|
||
|
|
|
||
|
|
commonError := NewError("database", err)
|
||
|
|
assert.IsType(commonError, commonError, "commonError should have right type")
|
||
|
|
// The exact error message might vary by driver, checking key presence is safer, but keeping original assertion style
|
||
|
|
assert.Contains(commonError.Errors, "database", "commonError should contain database key")
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestDatabaseDirCreation(t *testing.T) {
|
||
|
|
asserts := assert.New(t)
|
||
|
|
|
||
|
|
// Test directory creation in Init
|
||
|
|
origDBPath := os.Getenv("DB_PATH")
|
||
|
|
defer os.Setenv("DB_PATH", origDBPath)
|
||
|
|
|
||
|
|
// Create a temp dir path
|
||
|
|
tempDir := "./tmp/test_nested/db"
|
||
|
|
os.Setenv("DB_PATH", tempDir+"/test.db")
|
||
|
|
|
||
|
|
// Clean up before test
|
||
|
|
os.RemoveAll("./tmp/test_nested")
|
||
|
|
|
||
|
|
// Init should create the directory
|
||
|
|
|
||
|
|
db := Init()
|
||
|
|
|
||
|
|
sqlDB, err := db.DB()
|
||
|
|
|
||
|
|
asserts.NoError(err, "Should get sql.DB")
|
||
|
|
|
||
|
|
asserts.NoError(sqlDB.Ping(), "DB should be created in nested directory")
|
||
|
|
|
||
|
|
// Clean up after test
|
||
|
|
|
||
|
|
sqlDB.Close()
|
||
|
|
|
||
|
|
os.RemoveAll("./tmp/test_nested")
|
||
|
|
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestDBInitDirCreation(t *testing.T) {
|
||
|
|
|
||
|
|
asserts := assert.New(t)
|
||
|
|
|
||
|
|
// Test directory creation in TestDBInit
|
||
|
|
|
||
|
|
origTestDBPath := os.Getenv("TEST_DB_PATH")
|
||
|
|
|
||
|
|
defer os.Setenv("TEST_DB_PATH", origTestDBPath)
|
||
|
|
|
||
|
|
// Create a temp dir path
|
||
|
|
|
||
|
|
tempDir := "./tmp/test_nested_testdb"
|
||
|
|
|
||
|
|
os.Setenv("TEST_DB_PATH", tempDir+"/test.db")
|
||
|
|
|
||
|
|
// Clean up before test
|
||
|
|
|
||
|
|
os.RemoveAll(tempDir)
|
||
|
|
|
||
|
|
// TestDBInit should create the directory
|
||
|
|
|
||
|
|
db := TestDBInit()
|
||
|
|
|
||
|
|
sqlDB, err := db.DB()
|
||
|
|
|
||
|
|
asserts.NoError(err, "Should get sql.DB")
|
||
|
|
|
||
|
|
asserts.NoError(sqlDB.Ping(), "Test DB should be created in nested directory")
|
||
|
|
|
||
|
|
// Clean up after test
|
||
|
|
|
||
|
|
TestDBFree(db)
|
||
|
|
|
||
|
|
os.RemoveAll(tempDir)
|
||
|
|
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestDatabaseWithCurrentDirectory(t *testing.T) {
|
||
|
|
asserts := assert.New(t)
|
||
|
|
|
||
|
|
// Test with simple filename (no directory)
|
||
|
|
origDBPath := os.Getenv("DB_PATH")
|
||
|
|
defer os.Setenv("DB_PATH", origDBPath)
|
||
|
|
|
||
|
|
os.Setenv("DB_PATH", "test_simple.db")
|
||
|
|
|
||
|
|
// Init should work without directory creation
|
||
|
|
db := Init()
|
||
|
|
sqlDB, err := db.DB()
|
||
|
|
|
||
|
|
asserts.NoError(err, "Should get sql.DB")
|
||
|
|
asserts.NoError(sqlDB.Ping(), "DB should be created in current directory")
|
||
|
|
|
||
|
|
// Clean up
|
||
|
|
sqlDB.Close()
|
||
|
|
os.Remove("test_simple.db")
|
||
|
|
}
|