feat/epic-48-replace-dex #20
8 changed files with 181 additions and 11 deletions
|
@ -68,8 +68,9 @@ func serve() {
|
|||
|
||||
st := storage.Storage{LocalStorage: userDB, InitializedBackends: backends, Key: &signingKey}
|
||||
opConf := op.Config{
|
||||
CryptoKey: key,
|
||||
CodeMethodS256: false,
|
||||
CryptoKey: key,
|
||||
CodeMethodS256: false,
|
||||
GrantTypeRefreshToken: true,
|
||||
}
|
||||
slogger := slog.New(zapslog.NewHandler(logger.L.Desugar().Core(), nil))
|
||||
// slogger :=
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/authrequest"
|
||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/backend"
|
||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/client"
|
||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/token"
|
||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/user"
|
||||
)
|
||||
|
||||
|
@ -19,6 +20,7 @@ type Storage interface {
|
|||
AuthRequestStorage() authrequest.AuthRequestDB
|
||||
AuthCodeStorage() authcode.AuthCodeDB
|
||||
UserStorage() user.UserDB
|
||||
TokenStorage() token.TokenDB
|
||||
}
|
||||
|
||||
type sqlStorage struct {
|
||||
|
@ -49,6 +51,10 @@ func (s *sqlStorage) UserStorage() user.UserDB {
|
|||
return user.New(s.db)
|
||||
}
|
||||
|
||||
func (s *sqlStorage) TokenStorage() token.TokenDB {
|
||||
return token.New(s.db)
|
||||
}
|
||||
|
||||
func New(conf config.AppConfig) (Storage, error) {
|
||||
db, err := sql.Open("sqlite3", conf.StorageConfig.File)
|
||||
if err != nil {
|
||||
|
|
68
polyculeconnect/internal/db/token/token.go
Normal file
68
polyculeconnect/internal/db/token/token.go
Normal file
|
@ -0,0 +1,68 @@
|
|||
package token
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/model"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func strArrayToSlice(rawVal string) []string {
|
||||
var res []string
|
||||
if err := json.Unmarshal([]byte(rawVal), &res); err != nil {
|
||||
return nil
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func sliceToStrArray(rawVal []string) string {
|
||||
res, err := json.Marshal(rawVal)
|
||||
if err != nil {
|
||||
return "[]"
|
||||
}
|
||||
return string(res)
|
||||
}
|
||||
|
||||
type TokenDB interface {
|
||||
AddRefreshToken(ctx context.Context, refreshToken *model.RefreshToken) error
|
||||
GetRefreshTokenByID(ctx context.Context, id uuid.UUID) (*model.RefreshToken, error)
|
||||
}
|
||||
|
||||
type sqlTokenDB struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func (db *sqlTokenDB) AddRefreshToken(ctx context.Context, refreshToken *model.RefreshToken) error {
|
||||
tx, err := db.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start transaction: %w", err)
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
if _, err := tx.ExecContext(ctx, `INSERT INTO "refresh_token" ("id", "client_id", "user_id", "scopes", "auth_time") VALUES ($1, $2, $3, $4, $5)`, refreshToken.ID, refreshToken.ClientID, refreshToken.UserID, sliceToStrArray(refreshToken.Scopes), refreshToken.AuthTime); err != nil {
|
||||
return fmt.Errorf("failed to exec query: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return fmt.Errorf("failed to commit transaction: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *sqlTokenDB) GetRefreshTokenByID(ctx context.Context, id uuid.UUID) (*model.RefreshToken, error) {
|
||||
row := db.db.QueryRowContext(ctx, `SELECT "id", "client_id", "user_id", "scopes", "auth_time" FROM "refresh_token" WHERE "id" = ?`, id)
|
||||
var res model.RefreshToken
|
||||
var strScopes string
|
||||
if err := row.Scan(&res.ID, &res.ClientID, &res.UserID, &strScopes, &res.AuthTime); err != nil {
|
||||
return nil, fmt.Errorf("failed to query DB: %w", err)
|
||||
}
|
||||
res.Scopes = strArrayToSlice(strScopes)
|
||||
return &res, nil
|
||||
}
|
||||
|
||||
func New(db *sql.DB) TokenDB {
|
||||
return &sqlTokenDB{db: db}
|
||||
}
|
|
@ -17,5 +17,52 @@ type Token struct {
|
|||
|
||||
type RefreshToken struct {
|
||||
ID uuid.UUID
|
||||
ClientID string
|
||||
UserID string
|
||||
Scopes []string
|
||||
AuthTime time.Time
|
||||
}
|
||||
|
||||
func (t RefreshToken) Request() *RefreshTokenRequest {
|
||||
return &RefreshTokenRequest{
|
||||
userID: t.UserID,
|
||||
clientID: t.ClientID,
|
||||
scopes: t.Scopes,
|
||||
authTime: t.AuthTime,
|
||||
}
|
||||
}
|
||||
|
||||
type RefreshTokenRequest struct {
|
||||
clientID string
|
||||
authTime time.Time
|
||||
userID string
|
||||
scopes []string
|
||||
}
|
||||
|
||||
func (r RefreshTokenRequest) GetAMR() []string {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
func (r RefreshTokenRequest) GetAudience() []string {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
func (r RefreshTokenRequest) GetAuthTime() time.Time {
|
||||
return r.authTime
|
||||
}
|
||||
|
||||
func (r RefreshTokenRequest) GetClientID() string {
|
||||
return r.clientID
|
||||
}
|
||||
|
||||
func (r RefreshTokenRequest) GetScopes() []string {
|
||||
return r.scopes
|
||||
}
|
||||
|
||||
func (r RefreshTokenRequest) GetSubject() string {
|
||||
return r.userID
|
||||
}
|
||||
|
||||
func (r *RefreshTokenRequest) SetCurrentScopes(scopes []string) {
|
||||
r.scopes = scopes
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package storage
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
@ -135,14 +136,20 @@ func (s *Storage) DeleteAuthRequest(ctx context.Context, id string) error {
|
|||
|
||||
func (s *Storage) CreateAccessToken(ctx context.Context, req op.TokenRequest) (accessTokenID string, expiration time.Time, err error) {
|
||||
accessTokenUUID := uuid.New()
|
||||
var authTime time.Time
|
||||
|
||||
// we are expecting our own request model
|
||||
authRequest, ok := req.(*model.AuthRequest)
|
||||
if !ok {
|
||||
switch typedReq := req.(type) {
|
||||
case *model.AuthRequest:
|
||||
logger.L.Debug("Creating access token for new authentication")
|
||||
authTime = typedReq.AuthTime
|
||||
case *model.RefreshTokenRequest:
|
||||
logger.L.Debug("Handling refresh token request")
|
||||
authTime = typedReq.GetAuthTime()
|
||||
default:
|
||||
logger.L.Errorf("Unexpected type for request %v", err)
|
||||
return "", time.Time{}, errors.New("failed to parse auth request")
|
||||
}
|
||||
|
||||
authTime := authRequest.AuthTime.UTC()
|
||||
expiration = authTime.Add(5 * time.Minute)
|
||||
|
||||
// token := model.Token{
|
||||
|
@ -160,14 +167,23 @@ func (s *Storage) CreateAccessToken(ctx context.Context, req op.TokenRequest) (a
|
|||
func (s *Storage) CreateAccessAndRefreshTokens(ctx context.Context, request op.TokenRequest, currentRefreshToken string) (accessTokenID string, newRefreshTokenID string, expiration time.Time, err error) {
|
||||
accessTokenUUID := uuid.New()
|
||||
refreshTokenUUID := uuid.New()
|
||||
var authTime time.Time
|
||||
var clientID string
|
||||
|
||||
// we are expecting our own request model
|
||||
authRequest, ok := request.(*model.AuthRequest)
|
||||
if !ok {
|
||||
switch typedReq := request.(type) {
|
||||
case *model.AuthRequest:
|
||||
logger.L.Debug("Creating access token for new authentication")
|
||||
clientID = typedReq.ClientID
|
||||
authTime = typedReq.AuthTime
|
||||
case *model.RefreshTokenRequest:
|
||||
logger.L.Debug("Handling refresh token request")
|
||||
clientID = typedReq.GetClientID()
|
||||
authTime = typedReq.GetAuthTime()
|
||||
default:
|
||||
logger.L.Errorf("Unexpected type for request %v", err)
|
||||
return "", "", time.Time{}, errors.New("failed to parse auth request")
|
||||
}
|
||||
|
||||
authTime := authRequest.AuthTime.UTC()
|
||||
expiration = authTime.Add(5 * time.Minute)
|
||||
|
||||
// token := model.Token{
|
||||
|
@ -178,12 +194,34 @@ func (s *Storage) CreateAccessAndRefreshTokens(ctx context.Context, request op.T
|
|||
// Audiences: request.GetAudience(),
|
||||
// Scopes: request.GetScopes(),
|
||||
// }
|
||||
refreshToken := model.RefreshToken{
|
||||
ID: refreshTokenUUID,
|
||||
ClientID: clientID,
|
||||
UserID: request.GetSubject(),
|
||||
Scopes: request.GetScopes(),
|
||||
AuthTime: authTime,
|
||||
}
|
||||
if err := s.LocalStorage.TokenStorage().AddRefreshToken(ctx, &refreshToken); err != nil {
|
||||
return "", "", time.Time{}, fmt.Errorf("failed to insert token in DB: %w", err)
|
||||
}
|
||||
|
||||
return accessTokenUUID.String(), refreshTokenUUID.String(), expiration, nil
|
||||
}
|
||||
|
||||
func (s *Storage) TokenRequestByRefreshToken(ctx context.Context, refreshTokenID string) (op.RefreshTokenRequest, error) {
|
||||
return nil, ErrNotImplemented("TokenRequestByRefreshToken")
|
||||
parsedID, err := uuid.Parse(refreshTokenID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid format for refresh token id: %w", err)
|
||||
}
|
||||
|
||||
refreshToken, err := s.LocalStorage.TokenStorage().GetRefreshTokenByID(ctx, parsedID)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, op.ErrInvalidRefreshToken
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get refresh token: %w", err)
|
||||
}
|
||||
return refreshToken.Request(), nil
|
||||
}
|
||||
|
||||
func (s *Storage) TerminateSession(ctx context.Context, userID string, clientID string) error {
|
||||
|
|
1
polyculeconnect/migrations/1_tokens.down.sql
Normal file
1
polyculeconnect/migrations/1_tokens.down.sql
Normal file
|
@ -0,0 +1 @@
|
|||
DROP TABLE refresh_token;
|
9
polyculeconnect/migrations/1_tokens.up.sql
Normal file
9
polyculeconnect/migrations/1_tokens.up.sql
Normal file
|
@ -0,0 +1,9 @@
|
|||
CREATE TABLE refresh_token (
|
||||
id TEXT NOT NULL PRIMARY KEY,
|
||||
client_id TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
scopes blob NOT NULL, -- list of strings, json-encoded
|
||||
auth_time timestamp NOT NULL,
|
||||
FOREIGN KEY(client_id) REFERENCES client(id),
|
||||
FOREIGN KEY(user_id) REFERENCES user(id)
|
||||
);
|
Binary file not shown.
Loading…
Reference in a new issue