Add refresh token flow (#42)
Some checks failed
/ docker-build-only (push) Failing after 34s
/ go-test (push) Failing after 1m4s

This commit is contained in:
Melora Hugues 2024-10-19 16:21:04 +02:00
parent 92d014965b
commit f0011e183d
8 changed files with 181 additions and 11 deletions

View file

@ -68,8 +68,9 @@ func serve() {
st := storage.Storage{LocalStorage: userDB, InitializedBackends: backends, Key: &signingKey} st := storage.Storage{LocalStorage: userDB, InitializedBackends: backends, Key: &signingKey}
opConf := op.Config{ opConf := op.Config{
CryptoKey: key, CryptoKey: key,
CodeMethodS256: false, CodeMethodS256: false,
GrantTypeRefreshToken: true,
} }
slogger := slog.New(zapslog.NewHandler(logger.L.Desugar().Core(), nil)) slogger := slog.New(zapslog.NewHandler(logger.L.Desugar().Core(), nil))
// slogger := // slogger :=

View file

@ -9,6 +9,7 @@ import (
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/authrequest" "git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/authrequest"
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/backend" "git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/backend"
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/client" "git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/client"
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/token"
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/user" "git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/user"
) )
@ -19,6 +20,7 @@ type Storage interface {
AuthRequestStorage() authrequest.AuthRequestDB AuthRequestStorage() authrequest.AuthRequestDB
AuthCodeStorage() authcode.AuthCodeDB AuthCodeStorage() authcode.AuthCodeDB
UserStorage() user.UserDB UserStorage() user.UserDB
TokenStorage() token.TokenDB
} }
type sqlStorage struct { type sqlStorage struct {
@ -49,6 +51,10 @@ func (s *sqlStorage) UserStorage() user.UserDB {
return user.New(s.db) return user.New(s.db)
} }
func (s *sqlStorage) TokenStorage() token.TokenDB {
return token.New(s.db)
}
func New(conf config.AppConfig) (Storage, error) { func New(conf config.AppConfig) (Storage, error) {
db, err := sql.Open("sqlite3", conf.StorageConfig.File) db, err := sql.Open("sqlite3", conf.StorageConfig.File)
if err != nil { if err != nil {

View file

@ -0,0 +1,68 @@
package token
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/model"
"github.com/google/uuid"
)
func strArrayToSlice(rawVal string) []string {
var res []string
if err := json.Unmarshal([]byte(rawVal), &res); err != nil {
return nil
}
return res
}
func sliceToStrArray(rawVal []string) string {
res, err := json.Marshal(rawVal)
if err != nil {
return "[]"
}
return string(res)
}
type TokenDB interface {
AddRefreshToken(ctx context.Context, refreshToken *model.RefreshToken) error
GetRefreshTokenByID(ctx context.Context, id uuid.UUID) (*model.RefreshToken, error)
}
type sqlTokenDB struct {
db *sql.DB
}
func (db *sqlTokenDB) AddRefreshToken(ctx context.Context, refreshToken *model.RefreshToken) error {
tx, err := db.db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("failed to start transaction: %w", err)
}
defer func() { _ = tx.Rollback() }()
if _, err := tx.ExecContext(ctx, `INSERT INTO "refresh_token" ("id", "client_id", "user_id", "scopes", "auth_time") VALUES ($1, $2, $3, $4, $5)`, refreshToken.ID, refreshToken.ClientID, refreshToken.UserID, sliceToStrArray(refreshToken.Scopes), refreshToken.AuthTime); err != nil {
return fmt.Errorf("failed to exec query: %w", err)
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
return nil
}
func (db *sqlTokenDB) GetRefreshTokenByID(ctx context.Context, id uuid.UUID) (*model.RefreshToken, error) {
row := db.db.QueryRowContext(ctx, `SELECT "id", "client_id", "user_id", "scopes", "auth_time" FROM "refresh_token" WHERE "id" = ?`, id)
var res model.RefreshToken
var strScopes string
if err := row.Scan(&res.ID, &res.ClientID, &res.UserID, &strScopes, &res.AuthTime); err != nil {
return nil, fmt.Errorf("failed to query DB: %w", err)
}
res.Scopes = strArrayToSlice(strScopes)
return &res, nil
}
func New(db *sql.DB) TokenDB {
return &sqlTokenDB{db: db}
}

View file

@ -17,5 +17,52 @@ type Token struct {
type RefreshToken struct { type RefreshToken struct {
ID uuid.UUID ID uuid.UUID
ClientID string
UserID string
Scopes []string
AuthTime time.Time AuthTime time.Time
} }
func (t RefreshToken) Request() *RefreshTokenRequest {
return &RefreshTokenRequest{
userID: t.UserID,
clientID: t.ClientID,
scopes: t.Scopes,
authTime: t.AuthTime,
}
}
type RefreshTokenRequest struct {
clientID string
authTime time.Time
userID string
scopes []string
}
func (r RefreshTokenRequest) GetAMR() []string {
return []string{}
}
func (r RefreshTokenRequest) GetAudience() []string {
return []string{}
}
func (r RefreshTokenRequest) GetAuthTime() time.Time {
return r.authTime
}
func (r RefreshTokenRequest) GetClientID() string {
return r.clientID
}
func (r RefreshTokenRequest) GetScopes() []string {
return r.scopes
}
func (r RefreshTokenRequest) GetSubject() string {
return r.userID
}
func (r *RefreshTokenRequest) SetCurrentScopes(scopes []string) {
r.scopes = scopes
}

View file

@ -2,6 +2,7 @@ package storage
import ( import (
"context" "context"
"database/sql"
"errors" "errors"
"fmt" "fmt"
"time" "time"
@ -135,14 +136,20 @@ func (s *Storage) DeleteAuthRequest(ctx context.Context, id string) error {
func (s *Storage) CreateAccessToken(ctx context.Context, req op.TokenRequest) (accessTokenID string, expiration time.Time, err error) { func (s *Storage) CreateAccessToken(ctx context.Context, req op.TokenRequest) (accessTokenID string, expiration time.Time, err error) {
accessTokenUUID := uuid.New() accessTokenUUID := uuid.New()
var authTime time.Time
// we are expecting our own request model switch typedReq := req.(type) {
authRequest, ok := req.(*model.AuthRequest) case *model.AuthRequest:
if !ok { logger.L.Debug("Creating access token for new authentication")
authTime = typedReq.AuthTime
case *model.RefreshTokenRequest:
logger.L.Debug("Handling refresh token request")
authTime = typedReq.GetAuthTime()
default:
logger.L.Errorf("Unexpected type for request %v", err)
return "", time.Time{}, errors.New("failed to parse auth request") return "", time.Time{}, errors.New("failed to parse auth request")
} }
authTime := authRequest.AuthTime.UTC()
expiration = authTime.Add(5 * time.Minute) expiration = authTime.Add(5 * time.Minute)
// token := model.Token{ // token := model.Token{
@ -160,14 +167,23 @@ func (s *Storage) CreateAccessToken(ctx context.Context, req op.TokenRequest) (a
func (s *Storage) CreateAccessAndRefreshTokens(ctx context.Context, request op.TokenRequest, currentRefreshToken string) (accessTokenID string, newRefreshTokenID string, expiration time.Time, err error) { func (s *Storage) CreateAccessAndRefreshTokens(ctx context.Context, request op.TokenRequest, currentRefreshToken string) (accessTokenID string, newRefreshTokenID string, expiration time.Time, err error) {
accessTokenUUID := uuid.New() accessTokenUUID := uuid.New()
refreshTokenUUID := uuid.New() refreshTokenUUID := uuid.New()
var authTime time.Time
var clientID string
// we are expecting our own request model switch typedReq := request.(type) {
authRequest, ok := request.(*model.AuthRequest) case *model.AuthRequest:
if !ok { logger.L.Debug("Creating access token for new authentication")
clientID = typedReq.ClientID
authTime = typedReq.AuthTime
case *model.RefreshTokenRequest:
logger.L.Debug("Handling refresh token request")
clientID = typedReq.GetClientID()
authTime = typedReq.GetAuthTime()
default:
logger.L.Errorf("Unexpected type for request %v", err)
return "", "", time.Time{}, errors.New("failed to parse auth request") return "", "", time.Time{}, errors.New("failed to parse auth request")
} }
authTime := authRequest.AuthTime.UTC()
expiration = authTime.Add(5 * time.Minute) expiration = authTime.Add(5 * time.Minute)
// token := model.Token{ // token := model.Token{
@ -178,12 +194,34 @@ func (s *Storage) CreateAccessAndRefreshTokens(ctx context.Context, request op.T
// Audiences: request.GetAudience(), // Audiences: request.GetAudience(),
// Scopes: request.GetScopes(), // Scopes: request.GetScopes(),
// } // }
refreshToken := model.RefreshToken{
ID: refreshTokenUUID,
ClientID: clientID,
UserID: request.GetSubject(),
Scopes: request.GetScopes(),
AuthTime: authTime,
}
if err := s.LocalStorage.TokenStorage().AddRefreshToken(ctx, &refreshToken); err != nil {
return "", "", time.Time{}, fmt.Errorf("failed to insert token in DB: %w", err)
}
return accessTokenUUID.String(), refreshTokenUUID.String(), expiration, nil return accessTokenUUID.String(), refreshTokenUUID.String(), expiration, nil
} }
func (s *Storage) TokenRequestByRefreshToken(ctx context.Context, refreshTokenID string) (op.RefreshTokenRequest, error) { func (s *Storage) TokenRequestByRefreshToken(ctx context.Context, refreshTokenID string) (op.RefreshTokenRequest, error) {
return nil, ErrNotImplemented("TokenRequestByRefreshToken") parsedID, err := uuid.Parse(refreshTokenID)
if err != nil {
return nil, fmt.Errorf("invalid format for refresh token id: %w", err)
}
refreshToken, err := s.LocalStorage.TokenStorage().GetRefreshTokenByID(ctx, parsedID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, op.ErrInvalidRefreshToken
}
return nil, fmt.Errorf("failed to get refresh token: %w", err)
}
return refreshToken.Request(), nil
} }
func (s *Storage) TerminateSession(ctx context.Context, userID string, clientID string) error { func (s *Storage) TerminateSession(ctx context.Context, userID string, clientID string) error {

View file

@ -0,0 +1 @@
DROP TABLE refresh_token;

View file

@ -0,0 +1,9 @@
CREATE TABLE refresh_token (
id TEXT NOT NULL PRIMARY KEY,
client_id TEXT NOT NULL,
user_id TEXT NOT NULL,
scopes blob NOT NULL, -- list of strings, json-encoded
auth_time timestamp NOT NULL,
FOREIGN KEY(client_id) REFERENCES client(id),
FOREIGN KEY(user_id) REFERENCES user(id)
);

Binary file not shown.