69 lines
1.9 KiB
Go
69 lines
1.9 KiB
Go
|
package token
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"database/sql"
|
||
|
"encoding/json"
|
||
|
"fmt"
|
||
|
|
||
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/model"
|
||
|
"github.com/google/uuid"
|
||
|
)
|
||
|
|
||
|
// strArrayToSlice decodes rawVal, expected to hold a JSON array of strings,
// into a []string. Any decoding failure (empty input, malformed JSON, a
// non-string element, ...) yields nil rather than an error.
func strArrayToSlice(rawVal string) []string {
	var values []string
	decodeErr := json.Unmarshal([]byte(rawVal), &values)
	if decodeErr == nil {
		return values
	}
	// Drop whatever was partially decoded before the failure.
	return nil
}
|
||
|
|
||
|
// sliceToStrArray encodes rawVal as a JSON array of strings, the textual form
// that strArrayToSlice reads back.
//
// The result is always a JSON array: a nil or empty slice gives "[]".
// encoding/json would otherwise render a nil slice as "null", which is not an
// array and disagrees with the "[]" fallback used when encoding fails.
func sliceToStrArray(rawVal []string) string {
	if len(rawVal) == 0 {
		return "[]"
	}
	res, err := json.Marshal(rawVal)
	if err != nil {
		// Best effort: fall back to an empty array instead of failing.
		return "[]"
	}
	return string(res)
}
|
||
|
|
||
|
type TokenDB interface {
|
||
|
AddRefreshToken(ctx context.Context, refreshToken *model.RefreshToken) error
|
||
|
GetRefreshTokenByID(ctx context.Context, id uuid.UUID) (*model.RefreshToken, error)
|
||
|
}
|
||
|
|
||
|
// sqlTokenDB implements TokenDB on top of a database/sql handle; db is the
// handle every query and transaction is issued on.
type sqlTokenDB struct{ db *sql.DB }
|
||
|
|
||
|
func (db *sqlTokenDB) AddRefreshToken(ctx context.Context, refreshToken *model.RefreshToken) error {
|
||
|
tx, err := db.db.BeginTx(ctx, nil)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("failed to start transaction: %w", err)
|
||
|
}
|
||
|
defer func() { _ = tx.Rollback() }()
|
||
|
|
||
|
if _, err := tx.ExecContext(ctx, `INSERT INTO "refresh_token" ("id", "client_id", "user_id", "scopes", "auth_time") VALUES ($1, $2, $3, $4, $5)`, refreshToken.ID, refreshToken.ClientID, refreshToken.UserID, sliceToStrArray(refreshToken.Scopes), refreshToken.AuthTime); err != nil {
|
||
|
return fmt.Errorf("failed to exec query: %w", err)
|
||
|
}
|
||
|
|
||
|
if err := tx.Commit(); err != nil {
|
||
|
return fmt.Errorf("failed to commit transaction: %w", err)
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (db *sqlTokenDB) GetRefreshTokenByID(ctx context.Context, id uuid.UUID) (*model.RefreshToken, error) {
|
||
|
row := db.db.QueryRowContext(ctx, `SELECT "id", "client_id", "user_id", "scopes", "auth_time" FROM "refresh_token" WHERE "id" = ?`, id)
|
||
|
var res model.RefreshToken
|
||
|
var strScopes string
|
||
|
if err := row.Scan(&res.ID, &res.ClientID, &res.UserID, &strScopes, &res.AuthTime); err != nil {
|
||
|
return nil, fmt.Errorf("failed to query DB: %w", err)
|
||
|
}
|
||
|
res.Scopes = strArrayToSlice(strScopes)
|
||
|
return &res, nil
|
||
|
}
|
||
|
|
||
|
func New(db *sql.DB) TokenDB {
|
||
|
return &sqlTokenDB{db: db}
|
||
|
}
|