185 lines
4.6 KiB
Go
185 lines
4.6 KiB
Go
|
|
package auth
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"database/sql"
|
|||
|
|
"errors"
|
|||
|
|
"fmt"
|
|||
|
|
"time"
|
|||
|
|
|
|||
|
|
"github.com/JACKYMYPERSON/hldrCenter/init/database/cache"
|
|||
|
|
"github.com/google/uuid"
|
|||
|
|
_ "modernc.org/sqlite"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
type Session struct {
|
|||
|
|
SessionID string // 会话ID(主键)
|
|||
|
|
UserID int // 关联用户ID
|
|||
|
|
CreatedAt time.Time // 创建时间
|
|||
|
|
ExpiredAt time.Time // 过期时间
|
|||
|
|
UserAgent string // 客户端标识
|
|||
|
|
IpAddress string // 客户端IP
|
|||
|
|
IsValid int // 是否有效(1=有效,0=无效)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
var db *sql.DB
|
|||
|
|
|
|||
|
|
// InitDB 初始化数据库连接并创建会话表
|
|||
|
|
func InitDB(dbPath string) error {
|
|||
|
|
var err error
|
|||
|
|
// 打开SQLite数据库(文件不存在则自动创建)
|
|||
|
|
db, err = sql.Open("sqlite", dbPath)
|
|||
|
|
if err != nil {
|
|||
|
|
return fmt.Errorf("数据库连接失败: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 验证连接有效性
|
|||
|
|
if err := db.Ping(); err != nil {
|
|||
|
|
return fmt.Errorf("数据库ping失败: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 创建sessions表(如果不存在)
|
|||
|
|
createTableSQL := `
|
|||
|
|
CREATE TABLE IF NOT EXISTS sessions (
|
|||
|
|
session_id TEXT PRIMARY KEY,
|
|||
|
|
user_id INTEGER NOT NULL,
|
|||
|
|
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
|||
|
|
expired_at DATETIME NOT NULL,
|
|||
|
|
user_agent TEXT,
|
|||
|
|
ip_address TEXT,
|
|||
|
|
is_valid INTEGER NOT NULL DEFAULT 1
|
|||
|
|
);
|
|||
|
|
`
|
|||
|
|
_, err = db.Exec(createTableSQL)
|
|||
|
|
if err != nil {
|
|||
|
|
return fmt.Errorf("创建会话表失败: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
var utc8, _ = time.LoadLocation("Asia/Shanghai")
|
|||
|
|
|
|||
|
|
func CreateSession(userID int, userAgent, ipAddress string, expireDuration time.Duration) (*Session, error) {
|
|||
|
|
// 生成唯一session_id(UUID v4)
|
|||
|
|
sessionID := uuid.New().String()
|
|||
|
|
|
|||
|
|
// 获取东八区当前时间
|
|||
|
|
now := time.Now().In(utc8)
|
|||
|
|
// 计算东八区过期时间(当前时间+过期时长)
|
|||
|
|
expiredAt := now.Add(expireDuration)
|
|||
|
|
|
|||
|
|
// 插入会话记录(时间格式化为东八区的RFC3339字符串)
|
|||
|
|
_, err := cache.GlobalDB.Exec(`
|
|||
|
|
INSERT INTO auth_sessions
|
|||
|
|
(session_id, user_id, created_at, expired_at, user_agent, ip_address)
|
|||
|
|
VALUES (?, ?, ?, ?, ?, ?)
|
|||
|
|
`,
|
|||
|
|
sessionID,
|
|||
|
|
userID,
|
|||
|
|
now.Format(time.RFC3339), // 东八区创建时间
|
|||
|
|
expiredAt.Format(time.RFC3339), // 东八区过期时间
|
|||
|
|
userAgent,
|
|||
|
|
ipAddress)
|
|||
|
|
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("创建会话失败:%v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 返回创建的会话信息(时间均为东八区)
|
|||
|
|
return &Session{
|
|||
|
|
SessionID: sessionID,
|
|||
|
|
UserID: userID,
|
|||
|
|
CreatedAt: now, // 东八区创建时间
|
|||
|
|
ExpiredAt: expiredAt, // 东八区过期时间
|
|||
|
|
UserAgent: userAgent,
|
|||
|
|
IpAddress: ipAddress,
|
|||
|
|
IsValid: 1,
|
|||
|
|
}, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ValidateSession 验证会话有效性(请求时身份校验)
|
|||
|
|
// 参数:session_id(从Cookie或Header中获取)
|
|||
|
|
// 返回:有效则返回会话信息,无效则返回错误
|
|||
|
|
func ValidateSession(sessionID string) (*Session, error) {
|
|||
|
|
if sessionID == "" {
|
|||
|
|
return nil, errors.New("session_id不能为空")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 查询会话记录(同时检查是否有效、是否过期)
|
|||
|
|
var session Session
|
|||
|
|
err := cache.GlobalDB.QueryRow(`
|
|||
|
|
SELECT session_id, user_id, created_at, expired_at, user_agent, ip_address, is_valid
|
|||
|
|
FROM auth_sessions
|
|||
|
|
WHERE session_id = ?
|
|||
|
|
AND is_valid = 1
|
|||
|
|
AND expired_at > CURRENT_TIMESTAMP
|
|||
|
|
`, sessionID).Scan(
|
|||
|
|
&session.SessionID,
|
|||
|
|
&session.UserID,
|
|||
|
|
&session.CreatedAt,
|
|||
|
|
&session.ExpiredAt,
|
|||
|
|
&session.UserAgent,
|
|||
|
|
&session.IpAddress,
|
|||
|
|
&session.IsValid,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
// 处理查询结果
|
|||
|
|
switch {
|
|||
|
|
case err == sql.ErrNoRows:
|
|||
|
|
return nil, errors.New("会话无效或已过期")
|
|||
|
|
case err != nil:
|
|||
|
|
return nil, fmt.Errorf("验证会话失败:%v", err)
|
|||
|
|
default:
|
|||
|
|
return &session, nil
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// InvalidateSession 注销会话(用户登出时调用)
|
|||
|
|
func InvalidateSession(sessionID string) error {
|
|||
|
|
if sessionID == "" {
|
|||
|
|
return errors.New("session_id不能为空")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 将会话标记为无效(而非删除,便于日志追溯)
|
|||
|
|
result, err := db.Exec(`
|
|||
|
|
UPDATE auth_sessions
|
|||
|
|
SET is_valid = 0
|
|||
|
|
WHERE session_id = ?
|
|||
|
|
`, sessionID)
|
|||
|
|
if err != nil {
|
|||
|
|
return fmt.Errorf("注销会话失败:%v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 检查是否有记录被更新
|
|||
|
|
rowsAffected, err := result.RowsAffected()
|
|||
|
|
if err != nil {
|
|||
|
|
return fmt.Errorf("检查会话状态失败:%v", err)
|
|||
|
|
}
|
|||
|
|
if rowsAffected == 0 {
|
|||
|
|
return errors.New("会话不存在或已失效")
|
|||
|
|
}
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// CleanupExpiredSessions 清理过期/无效会话(建议定时任务调用)
|
|||
|
|
func CleanupExpiredSessions() error {
|
|||
|
|
// 删除已过期或已无效的会话
|
|||
|
|
_, err := db.Exec(`
|
|||
|
|
DELETE FROM auth_sessions
|
|||
|
|
WHERE is_valid = 0
|
|||
|
|
OR expired_at <= CURRENT_TIMESTAMP
|
|||
|
|
`)
|
|||
|
|
if err != nil {
|
|||
|
|
return fmt.Errorf("清理过期会话失败:%v", err)
|
|||
|
|
}
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// CloseDB 关闭数据库连接
|
|||
|
|
func CloseDB() error {
|
|||
|
|
if db != nil {
|
|||
|
|
return db.Close()
|
|||
|
|
}
|
|||
|
|
return nil
|
|||
|
|
}
|