Compare commits
2 commits
c741400583
...
f0011e183d
Author | SHA1 | Date | |
---|---|---|---|
f0011e183d | |||
92d014965b |
9 changed files with 182 additions and 12 deletions
|
@ -68,8 +68,9 @@ func serve() {
|
||||||
|
|
||||||
st := storage.Storage{LocalStorage: userDB, InitializedBackends: backends, Key: &signingKey}
|
st := storage.Storage{LocalStorage: userDB, InitializedBackends: backends, Key: &signingKey}
|
||||||
opConf := op.Config{
|
opConf := op.Config{
|
||||||
CryptoKey: key,
|
CryptoKey: key,
|
||||||
CodeMethodS256: false,
|
CodeMethodS256: false,
|
||||||
|
GrantTypeRefreshToken: true,
|
||||||
}
|
}
|
||||||
slogger := slog.New(zapslog.NewHandler(logger.L.Desugar().Core(), nil))
|
slogger := slog.New(zapslog.NewHandler(logger.L.Desugar().Core(), nil))
|
||||||
// slogger :=
|
// slogger :=
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/authrequest"
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/authrequest"
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/backend"
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/backend"
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/client"
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/client"
|
||||||
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/token"
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/user"
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/user"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -19,6 +20,7 @@ type Storage interface {
|
||||||
AuthRequestStorage() authrequest.AuthRequestDB
|
AuthRequestStorage() authrequest.AuthRequestDB
|
||||||
AuthCodeStorage() authcode.AuthCodeDB
|
AuthCodeStorage() authcode.AuthCodeDB
|
||||||
UserStorage() user.UserDB
|
UserStorage() user.UserDB
|
||||||
|
TokenStorage() token.TokenDB
|
||||||
}
|
}
|
||||||
|
|
||||||
type sqlStorage struct {
|
type sqlStorage struct {
|
||||||
|
@ -49,6 +51,10 @@ func (s *sqlStorage) UserStorage() user.UserDB {
|
||||||
return user.New(s.db)
|
return user.New(s.db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *sqlStorage) TokenStorage() token.TokenDB {
|
||||||
|
return token.New(s.db)
|
||||||
|
}
|
||||||
|
|
||||||
func New(conf config.AppConfig) (Storage, error) {
|
func New(conf config.AppConfig) (Storage, error) {
|
||||||
db, err := sql.Open("sqlite3", conf.StorageConfig.File)
|
db, err := sql.Open("sqlite3", conf.StorageConfig.File)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
68
polyculeconnect/internal/db/token/token.go
Normal file
68
polyculeconnect/internal/db/token/token.go
Normal file
|
@ -0,0 +1,68 @@
|
||||||
|
package token
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/model"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
func strArrayToSlice(rawVal string) []string {
|
||||||
|
var res []string
|
||||||
|
if err := json.Unmarshal([]byte(rawVal), &res); err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
func sliceToStrArray(rawVal []string) string {
|
||||||
|
res, err := json.Marshal(rawVal)
|
||||||
|
if err != nil {
|
||||||
|
return "[]"
|
||||||
|
}
|
||||||
|
return string(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
type TokenDB interface {
|
||||||
|
AddRefreshToken(ctx context.Context, refreshToken *model.RefreshToken) error
|
||||||
|
GetRefreshTokenByID(ctx context.Context, id uuid.UUID) (*model.RefreshToken, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type sqlTokenDB struct {
|
||||||
|
db *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *sqlTokenDB) AddRefreshToken(ctx context.Context, refreshToken *model.RefreshToken) error {
|
||||||
|
tx, err := db.db.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to start transaction: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = tx.Rollback() }()
|
||||||
|
|
||||||
|
if _, err := tx.ExecContext(ctx, `INSERT INTO "refresh_token" ("id", "client_id", "user_id", "scopes", "auth_time") VALUES ($1, $2, $3, $4, $5)`, refreshToken.ID, refreshToken.ClientID, refreshToken.UserID, sliceToStrArray(refreshToken.Scopes), refreshToken.AuthTime); err != nil {
|
||||||
|
return fmt.Errorf("failed to exec query: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
return fmt.Errorf("failed to commit transaction: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *sqlTokenDB) GetRefreshTokenByID(ctx context.Context, id uuid.UUID) (*model.RefreshToken, error) {
|
||||||
|
row := db.db.QueryRowContext(ctx, `SELECT "id", "client_id", "user_id", "scopes", "auth_time" FROM "refresh_token" WHERE "id" = ?`, id)
|
||||||
|
var res model.RefreshToken
|
||||||
|
var strScopes string
|
||||||
|
if err := row.Scan(&res.ID, &res.ClientID, &res.UserID, &strScopes, &res.AuthTime); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to query DB: %w", err)
|
||||||
|
}
|
||||||
|
res.Scopes = strArrayToSlice(strScopes)
|
||||||
|
return &res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(db *sql.DB) TokenDB {
|
||||||
|
return &sqlTokenDB{db: db}
|
||||||
|
}
|
|
@ -22,7 +22,7 @@ const getUserQuery = `
|
||||||
WHERE id = ?
|
WHERE id = ?
|
||||||
`
|
`
|
||||||
const insertUserQuery = `
|
const insertUserQuery = `
|
||||||
INSERT INTO user (id, name, family_name, given_name, nickname, picture, updated_at, email, email_verified)
|
INSERT OR REPLACE INTO user (id, name, family_name, given_name, nickname, picture, updated_at, email, email_verified)
|
||||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||||
`
|
`
|
||||||
|
|
||||||
|
|
|
@ -17,5 +17,52 @@ type Token struct {
|
||||||
|
|
||||||
type RefreshToken struct {
|
type RefreshToken struct {
|
||||||
ID uuid.UUID
|
ID uuid.UUID
|
||||||
|
ClientID string
|
||||||
|
UserID string
|
||||||
|
Scopes []string
|
||||||
AuthTime time.Time
|
AuthTime time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t RefreshToken) Request() *RefreshTokenRequest {
|
||||||
|
return &RefreshTokenRequest{
|
||||||
|
userID: t.UserID,
|
||||||
|
clientID: t.ClientID,
|
||||||
|
scopes: t.Scopes,
|
||||||
|
authTime: t.AuthTime,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type RefreshTokenRequest struct {
|
||||||
|
clientID string
|
||||||
|
authTime time.Time
|
||||||
|
userID string
|
||||||
|
scopes []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r RefreshTokenRequest) GetAMR() []string {
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r RefreshTokenRequest) GetAudience() []string {
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r RefreshTokenRequest) GetAuthTime() time.Time {
|
||||||
|
return r.authTime
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r RefreshTokenRequest) GetClientID() string {
|
||||||
|
return r.clientID
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r RefreshTokenRequest) GetScopes() []string {
|
||||||
|
return r.scopes
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r RefreshTokenRequest) GetSubject() string {
|
||||||
|
return r.userID
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RefreshTokenRequest) SetCurrentScopes(scopes []string) {
|
||||||
|
r.scopes = scopes
|
||||||
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ package storage
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
@ -135,14 +136,20 @@ func (s *Storage) DeleteAuthRequest(ctx context.Context, id string) error {
|
||||||
|
|
||||||
func (s *Storage) CreateAccessToken(ctx context.Context, req op.TokenRequest) (accessTokenID string, expiration time.Time, err error) {
|
func (s *Storage) CreateAccessToken(ctx context.Context, req op.TokenRequest) (accessTokenID string, expiration time.Time, err error) {
|
||||||
accessTokenUUID := uuid.New()
|
accessTokenUUID := uuid.New()
|
||||||
|
var authTime time.Time
|
||||||
|
|
||||||
// we are expecting our own request model
|
switch typedReq := req.(type) {
|
||||||
authRequest, ok := req.(*model.AuthRequest)
|
case *model.AuthRequest:
|
||||||
if !ok {
|
logger.L.Debug("Creating access token for new authentication")
|
||||||
|
authTime = typedReq.AuthTime
|
||||||
|
case *model.RefreshTokenRequest:
|
||||||
|
logger.L.Debug("Handling refresh token request")
|
||||||
|
authTime = typedReq.GetAuthTime()
|
||||||
|
default:
|
||||||
|
logger.L.Errorf("Unexpected type for request %v", err)
|
||||||
return "", time.Time{}, errors.New("failed to parse auth request")
|
return "", time.Time{}, errors.New("failed to parse auth request")
|
||||||
}
|
}
|
||||||
|
|
||||||
authTime := authRequest.AuthTime.UTC()
|
|
||||||
expiration = authTime.Add(5 * time.Minute)
|
expiration = authTime.Add(5 * time.Minute)
|
||||||
|
|
||||||
// token := model.Token{
|
// token := model.Token{
|
||||||
|
@ -160,14 +167,23 @@ func (s *Storage) CreateAccessToken(ctx context.Context, req op.TokenRequest) (a
|
||||||
func (s *Storage) CreateAccessAndRefreshTokens(ctx context.Context, request op.TokenRequest, currentRefreshToken string) (accessTokenID string, newRefreshTokenID string, expiration time.Time, err error) {
|
func (s *Storage) CreateAccessAndRefreshTokens(ctx context.Context, request op.TokenRequest, currentRefreshToken string) (accessTokenID string, newRefreshTokenID string, expiration time.Time, err error) {
|
||||||
accessTokenUUID := uuid.New()
|
accessTokenUUID := uuid.New()
|
||||||
refreshTokenUUID := uuid.New()
|
refreshTokenUUID := uuid.New()
|
||||||
|
var authTime time.Time
|
||||||
|
var clientID string
|
||||||
|
|
||||||
// we are expecting our own request model
|
switch typedReq := request.(type) {
|
||||||
authRequest, ok := request.(*model.AuthRequest)
|
case *model.AuthRequest:
|
||||||
if !ok {
|
logger.L.Debug("Creating access token for new authentication")
|
||||||
|
clientID = typedReq.ClientID
|
||||||
|
authTime = typedReq.AuthTime
|
||||||
|
case *model.RefreshTokenRequest:
|
||||||
|
logger.L.Debug("Handling refresh token request")
|
||||||
|
clientID = typedReq.GetClientID()
|
||||||
|
authTime = typedReq.GetAuthTime()
|
||||||
|
default:
|
||||||
|
logger.L.Errorf("Unexpected type for request %v", err)
|
||||||
return "", "", time.Time{}, errors.New("failed to parse auth request")
|
return "", "", time.Time{}, errors.New("failed to parse auth request")
|
||||||
}
|
}
|
||||||
|
|
||||||
authTime := authRequest.AuthTime.UTC()
|
|
||||||
expiration = authTime.Add(5 * time.Minute)
|
expiration = authTime.Add(5 * time.Minute)
|
||||||
|
|
||||||
// token := model.Token{
|
// token := model.Token{
|
||||||
|
@ -178,12 +194,34 @@ func (s *Storage) CreateAccessAndRefreshTokens(ctx context.Context, request op.T
|
||||||
// Audiences: request.GetAudience(),
|
// Audiences: request.GetAudience(),
|
||||||
// Scopes: request.GetScopes(),
|
// Scopes: request.GetScopes(),
|
||||||
// }
|
// }
|
||||||
|
refreshToken := model.RefreshToken{
|
||||||
|
ID: refreshTokenUUID,
|
||||||
|
ClientID: clientID,
|
||||||
|
UserID: request.GetSubject(),
|
||||||
|
Scopes: request.GetScopes(),
|
||||||
|
AuthTime: authTime,
|
||||||
|
}
|
||||||
|
if err := s.LocalStorage.TokenStorage().AddRefreshToken(ctx, &refreshToken); err != nil {
|
||||||
|
return "", "", time.Time{}, fmt.Errorf("failed to insert token in DB: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return accessTokenUUID.String(), refreshTokenUUID.String(), expiration, nil
|
return accessTokenUUID.String(), refreshTokenUUID.String(), expiration, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Storage) TokenRequestByRefreshToken(ctx context.Context, refreshTokenID string) (op.RefreshTokenRequest, error) {
|
func (s *Storage) TokenRequestByRefreshToken(ctx context.Context, refreshTokenID string) (op.RefreshTokenRequest, error) {
|
||||||
return nil, ErrNotImplemented("TokenRequestByRefreshToken")
|
parsedID, err := uuid.Parse(refreshTokenID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid format for refresh token id: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
refreshToken, err := s.LocalStorage.TokenStorage().GetRefreshTokenByID(ctx, parsedID)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return nil, op.ErrInvalidRefreshToken
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("failed to get refresh token: %w", err)
|
||||||
|
}
|
||||||
|
return refreshToken.Request(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Storage) TerminateSession(ctx context.Context, userID string, clientID string) error {
|
func (s *Storage) TerminateSession(ctx context.Context, userID string, clientID string) error {
|
||||||
|
|
1
polyculeconnect/migrations/1_tokens.down.sql
Normal file
1
polyculeconnect/migrations/1_tokens.down.sql
Normal file
|
@ -0,0 +1 @@
|
||||||
|
DROP TABLE refresh_token;
|
9
polyculeconnect/migrations/1_tokens.up.sql
Normal file
9
polyculeconnect/migrations/1_tokens.up.sql
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
CREATE TABLE refresh_token (
|
||||||
|
id TEXT NOT NULL PRIMARY KEY,
|
||||||
|
client_id TEXT NOT NULL,
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
scopes blob NOT NULL, -- list of strings, json-encoded
|
||||||
|
auth_time timestamp NOT NULL,
|
||||||
|
FOREIGN KEY(client_id) REFERENCES client(id),
|
||||||
|
FOREIGN KEY(user_id) REFERENCES user(id)
|
||||||
|
);
|
Binary file not shown.
Loading…
Reference in a new issue