159 lines
5 KiB
Go
159 lines
5 KiB
Go
package authrequest
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"time"
|
|
|
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/model"
|
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/logger"
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
var ErrNotFound = errors.New("backend not found")
|
|
|
|
const authRequestRows = `"id", "client_id", "backend_id", "scopes", "redirect_uri", "state", "nonce", "response_type", "creation_time", "done", "code_challenge", "code_challenge_method", "auth_time", "user_id", "consent"`
|
|
|
|
type AuthRequestDB interface {
|
|
GetAuthRequestByID(ctx context.Context, id uuid.UUID) (*model.AuthRequest, error)
|
|
CreateAuthRequest(ctx context.Context, req model.AuthRequest) error
|
|
ValidateAuthRequest(ctx context.Context, reqID uuid.UUID, userID string) error
|
|
DeleteAuthRequest(ctx context.Context, reqID uuid.UUID) error
|
|
GiveConsent(ctx context.Context, reqID uuid.UUID) error
|
|
}
|
|
|
|
type sqlAuthRequestDB struct {
|
|
db *sql.DB
|
|
}
|
|
|
|
func (db *sqlAuthRequestDB) GetAuthRequestByID(ctx context.Context, id uuid.UUID) (*model.AuthRequest, error) {
|
|
logger.L.Debugf("Getting auth request with id %s", id)
|
|
query := fmt.Sprintf(`SELECT %s FROM "auth_request" WHERE "id" = ?`, authRequestRows)
|
|
row := db.db.QueryRowContext(ctx, query, id)
|
|
|
|
var res model.AuthRequest
|
|
var scopesStr []byte
|
|
|
|
var timestamp *time.Time
|
|
|
|
if err := row.Scan(&res.ID, &res.ClientID, &res.BackendID, &scopesStr, &res.RedirectURI, &res.State, &res.Nonce, &res.ResponseType, &res.CreationDate, &res.DoneVal, &res.CodeChallenge, &res.CodeChallengeMethod, ×tamp, &res.UserID, &res.Consent); err != nil {
|
|
return nil, fmt.Errorf("failed to get auth request from DB: %w", err)
|
|
}
|
|
if timestamp != nil {
|
|
res.AuthTime = *timestamp
|
|
}
|
|
if err := json.Unmarshal(scopesStr, &res.Scopes); err != nil {
|
|
return nil, fmt.Errorf("invalid format for scopes: %w", err)
|
|
}
|
|
|
|
return &res, nil
|
|
}
|
|
|
|
func (db *sqlAuthRequestDB) CreateAuthRequest(ctx context.Context, req model.AuthRequest) error {
|
|
logger.L.Debugf("Creating a new auth request between client app %s and backend %s", req.ClientID, req.BackendID)
|
|
tx, err := db.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to start transaction: %w", err)
|
|
}
|
|
defer func() { _ = tx.Rollback() }()
|
|
|
|
scopesStr, err := json.Marshal(req.Scopes)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to serialize scopes: %w", err)
|
|
}
|
|
|
|
query := fmt.Sprintf(`INSERT INTO "auth_request" (%s) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, NULL, '', 0)`, authRequestRows)
|
|
_, err = tx.ExecContext(ctx, query,
|
|
req.ID, req.ClientID, req.BackendID,
|
|
scopesStr, req.RedirectURI, req.State,
|
|
req.Nonce, req.ResponseType, req.CreationDate, false,
|
|
req.CodeChallenge, req.CodeChallengeMethod,
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to insert in DB: %w", err)
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return fmt.Errorf("failed to commit transaction: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (db *sqlAuthRequestDB) ValidateAuthRequest(ctx context.Context, reqID uuid.UUID, userID string) error {
|
|
logger.L.Debugf("Validating auth request %s", reqID)
|
|
tx, err := db.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to start transaction: %w", err)
|
|
}
|
|
defer func() { _ = tx.Rollback() }()
|
|
|
|
res, err := tx.ExecContext(ctx, `UPDATE "auth_request" SET done = true, auth_time = $1, user_id = $2 WHERE id = $3`, time.Now().UTC(), userID, reqID)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to update in DB: %w", err)
|
|
}
|
|
affectedRows, err := res.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to check number of affected rows: %w", err)
|
|
}
|
|
if affectedRows != 1 {
|
|
return ErrNotFound
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return fmt.Errorf("failed to commit transaction: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (db *sqlAuthRequestDB) GiveConsent(ctx context.Context, reqID uuid.UUID) error {
|
|
tx, err := db.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to start transaction: %w", err)
|
|
}
|
|
defer func() { _ = tx.Rollback() }()
|
|
|
|
res, err := tx.ExecContext(ctx, `UPDATE "auth_request" SET consent = true WHERE id = $1`, reqID)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to update in DB: %w", err)
|
|
}
|
|
affectedRows, err := res.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to check number of affected rows: %w", err)
|
|
}
|
|
if affectedRows != 1 {
|
|
return ErrNotFound
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return fmt.Errorf("failed to commit transaction: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (db *sqlAuthRequestDB) DeleteAuthRequest(ctx context.Context, reqID uuid.UUID) error {
|
|
logger.L.Debugf("Deleting auth request: %s", reqID)
|
|
tx, err := db.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to start transaction: %w", err)
|
|
}
|
|
defer func() { _ = tx.Rollback() }()
|
|
|
|
_, err = tx.ExecContext(ctx, `DELETE FROM "auth_request" WHERE id = $1`, reqID.String())
|
|
if err != nil {
|
|
return fmt.Errorf("failed to delete auth request: %w", err)
|
|
}
|
|
if err := tx.Commit(); err != nil {
|
|
return fmt.Errorf("failed to commit transaction: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func New(db *sql.DB) *sqlAuthRequestDB {
|
|
return &sqlAuthRequestDB{db: db}
|
|
}
|