Files
hldrCenter/server/util/auth/auth.go

185 lines
4.6 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package auth
import (
"database/sql"
"errors"
"fmt"
"time"
"github.com/JACKYMYPERSON/hldrCenter/init/database/cache"
"github.com/google/uuid"
_ "modernc.org/sqlite"
)
type Session struct {
SessionID string // 会话ID主键
UserID int // 关联用户ID
CreatedAt time.Time // 创建时间
ExpiredAt time.Time // 过期时间
UserAgent string // 客户端标识
IpAddress string // 客户端IP
IsValid int // 是否有效1=有效0=无效)
}
var db *sql.DB
// InitDB 初始化数据库连接并创建会话表
func InitDB(dbPath string) error {
var err error
// 打开SQLite数据库文件不存在则自动创建
db, err = sql.Open("sqlite", dbPath)
if err != nil {
return fmt.Errorf("数据库连接失败: %w", err)
}
// 验证连接有效性
if err := db.Ping(); err != nil {
return fmt.Errorf("数据库ping失败: %w", err)
}
// 创建sessions表如果不存在
createTableSQL := `
CREATE TABLE IF NOT EXISTS sessions (
session_id TEXT PRIMARY KEY,
user_id INTEGER NOT NULL,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
expired_at DATETIME NOT NULL,
user_agent TEXT,
ip_address TEXT,
is_valid INTEGER NOT NULL DEFAULT 1
);
`
_, err = db.Exec(createTableSQL)
if err != nil {
return fmt.Errorf("创建会话表失败: %w", err)
}
return nil
}
var utc8, _ = time.LoadLocation("Asia/Shanghai")
func CreateSession(userID int, userAgent, ipAddress string, expireDuration time.Duration) (*Session, error) {
// 生成唯一session_idUUID v4
sessionID := uuid.New().String()
// 获取东八区当前时间
now := time.Now().In(utc8)
// 计算东八区过期时间(当前时间+过期时长)
expiredAt := now.Add(expireDuration)
// 插入会话记录时间格式化为东八区的RFC3339字符串
_, err := cache.GlobalDB.Exec(`
INSERT INTO auth_sessions
(session_id, user_id, created_at, expired_at, user_agent, ip_address)
VALUES (?, ?, ?, ?, ?, ?)
`,
sessionID,
userID,
now.Format(time.RFC3339), // 东八区创建时间
expiredAt.Format(time.RFC3339), // 东八区过期时间
userAgent,
ipAddress)
if err != nil {
return nil, fmt.Errorf("创建会话失败:%v", err)
}
// 返回创建的会话信息(时间均为东八区)
return &Session{
SessionID: sessionID,
UserID: userID,
CreatedAt: now, // 东八区创建时间
ExpiredAt: expiredAt, // 东八区过期时间
UserAgent: userAgent,
IpAddress: ipAddress,
IsValid: 1,
}, nil
}
// ValidateSession 验证会话有效性(请求时身份校验)
// 参数session_id从Cookie或Header中获取
// 返回:有效则返回会话信息,无效则返回错误
func ValidateSession(sessionID string) (*Session, error) {
if sessionID == "" {
return nil, errors.New("session_id不能为空")
}
// 查询会话记录(同时检查是否有效、是否过期)
var session Session
err := cache.GlobalDB.QueryRow(`
SELECT session_id, user_id, created_at, expired_at, user_agent, ip_address, is_valid
FROM auth_sessions
WHERE session_id = ?
AND is_valid = 1
AND expired_at > CURRENT_TIMESTAMP
`, sessionID).Scan(
&session.SessionID,
&session.UserID,
&session.CreatedAt,
&session.ExpiredAt,
&session.UserAgent,
&session.IpAddress,
&session.IsValid,
)
// 处理查询结果
switch {
case err == sql.ErrNoRows:
return nil, errors.New("会话无效或已过期")
case err != nil:
return nil, fmt.Errorf("验证会话失败:%v", err)
default:
return &session, nil
}
}
// InvalidateSession 注销会话(用户登出时调用)
func InvalidateSession(sessionID string) error {
if sessionID == "" {
return errors.New("session_id不能为空")
}
// 将会话标记为无效(而非删除,便于日志追溯)
result, err := db.Exec(`
UPDATE auth_sessions
SET is_valid = 0
WHERE session_id = ?
`, sessionID)
if err != nil {
return fmt.Errorf("注销会话失败:%v", err)
}
// 检查是否有记录被更新
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("检查会话状态失败:%v", err)
}
if rowsAffected == 0 {
return errors.New("会话不存在或已失效")
}
return nil
}
// CleanupExpiredSessions 清理过期/无效会话(建议定时任务调用)
func CleanupExpiredSessions() error {
// 删除已过期或已无效的会话
_, err := db.Exec(`
DELETE FROM auth_sessions
WHERE is_valid = 0
OR expired_at <= CURRENT_TIMESTAMP
`)
if err != nil {
return fmt.Errorf("清理过期会话失败:%v", err)
}
return nil
}
// CloseDB 关闭数据库连接
func CloseDB() error {
if db != nil {
return db.Close()
}
return nil
}