package authrequest

import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"time"

	"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/model"
	"git.faercol.me/faercol/polyculeconnect/polyculeconnect/logger"
	"github.com/google/uuid"
)

// ErrNotFound is returned by ValidateAuthRequest when no auth request matches
// the given ID.
var ErrNotFound = errors.New("auth request not found")

// authRequestRows lists the columns of the "auth_request_2" table. The order
// matters: it must match the Scan targets in scanAuthRequest and the VALUES
// list in CreateAuthRequest.
const authRequestRows = `"id", "client_id", "backend_id", "scopes", "redirect_uri", "state", "nonce", "response_type", "creation_time", "done", "code_challenge", "code_challenge_method", "auth_time", "claim_user_id", "claim_username", "claim_email"`

// AuthRequestDB is the storage interface for auth requests.
type AuthRequestDB interface {
	// GetAuthRequestByID returns the auth request stored under id.
	GetAuthRequestByID(ctx context.Context, id uuid.UUID) (*model.AuthRequest, error)
	// GetAuthRequestByUserID returns an auth request whose claimed user has the given id.
	GetAuthRequestByUserID(ctx context.Context, id uuid.UUID) (*model.AuthRequest, error)
	// CreateAuthRequest inserts a new request with done=false and no auth time or user claims.
	CreateAuthRequest(ctx context.Context, req model.AuthRequest) error
	// ValidateAuthRequest marks a request as done and records the current time and the user's claims.
	ValidateAuthRequest(ctx context.Context, reqID uuid.UUID, user *model.User) error
	// DeleteAuthRequest removes a request; deleting a missing request is not an error.
	DeleteAuthRequest(ctx context.Context, reqID uuid.UUID) error
}

// Compile-time check that sqlAuthRequestDB satisfies AuthRequestDB.
var _ AuthRequestDB = (*sqlAuthRequestDB)(nil)

// sqlAuthRequestDB implements AuthRequestDB on top of a database/sql handle.
type sqlAuthRequestDB struct {
	db *sql.DB
}

// dbUser holds the raw claim_* columns before they are converted into a
// model.User.
type dbUser struct {
	id       string
	username string
	email    string
}

// scanAuthRequest decodes a single row selected with authRequestRows.
//
// A NULL auth_time leaves AuthTime at its zero value. An empty claim_user_id
// (what CreateAuthRequest inserts) leaves User nil. Scan errors, including
// sql.ErrNoRows, are wrapped with %w so callers can still match them.
func scanAuthRequest(row *sql.Row) (*model.AuthRequest, error) {
	var (
		res       model.AuthRequest
		user      dbUser
		scopesStr []byte
		authTime  *time.Time // nil when the auth_time column is NULL
	)

	if err := row.Scan(
		&res.ID, &res.ClientID, &res.BackendID, &scopesStr, &res.RedirectURI,
		&res.State, &res.Nonce, &res.ResponseType, &res.CreationDate, &res.DoneVal,
		&res.CodeChallenge, &res.CodeChallengeMethod, &authTime,
		&user.id, &user.username, &user.email,
	); err != nil {
		return nil, fmt.Errorf("failed to get auth request from DB: %w", err)
	}

	if authTime != nil {
		res.AuthTime = *authTime
	}

	if user.id != "" {
		userID, err := uuid.Parse(user.id)
		if err != nil {
			return nil, fmt.Errorf("invalid format for user id: %w", err)
		}
		res.User = &model.User{
			ID:       userID,
			Username: user.username,
			Email:    user.email,
		}
	}

	// Scopes are stored as a JSON-encoded value (see CreateAuthRequest).
	if err := json.Unmarshal(scopesStr, &res.Scopes); err != nil {
		return nil, fmt.Errorf("invalid format for scopes: %w", err)
	}

	return &res, nil
}

// GetAuthRequestByID returns the auth request stored under id. When no such
// request exists, the returned error wraps sql.ErrNoRows.
func (db *sqlAuthRequestDB) GetAuthRequestByID(ctx context.Context, id uuid.UUID) (*model.AuthRequest, error) {
	logger.L.Debugf("Getting auth request with id %s", id)

	query := fmt.Sprintf(`SELECT %s FROM "auth_request_2" WHERE "id" = $1`, authRequestRows)
	return scanAuthRequest(db.db.QueryRowContext(ctx, query, id.String()))
}

// GetAuthRequestByUserID returns an auth request whose claim_user_id equals
// id. When none matches, the returned error wraps sql.ErrNoRows.
//
// NOTE(review): the query has no ORDER BY. If several requests were validated
// for the same user, whichever row the database yields first is returned.
// Confirm that callers do not rely on getting the latest one.
func (db *sqlAuthRequestDB) GetAuthRequestByUserID(ctx context.Context, id uuid.UUID) (*model.AuthRequest, error) {
	logger.L.Debugf("Getting auth request with user id %s", id)

	query := fmt.Sprintf(`SELECT %s FROM "auth_request_2" WHERE "claim_user_id" = $1`, authRequestRows)
	return scanAuthRequest(db.db.QueryRowContext(ctx, query, id.String()))
}

// inTx runs fn inside a transaction. It commits only if fn returns nil.
// fn's error is returned as-is, so sentinels such as ErrNotFound stay
// comparable.
func (db *sqlAuthRequestDB) inTx(ctx context.Context, fn func(tx *sql.Tx) error) error {
	tx, err := db.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("failed to start transaction: %w", err)
	}
	// After a successful Commit, Rollback returns sql.ErrTxDone, so ignoring
	// its error is safe. On any earlier return it discards the partial work.
	defer func() { _ = tx.Rollback() }()

	if err := fn(tx); err != nil {
		return err
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("failed to commit transaction: %w", err)
	}
	return nil
}

// CreateAuthRequest inserts req with done=false, a NULL auth_time and empty
// user claims. Those columns are filled in later by ValidateAuthRequest.
func (db *sqlAuthRequestDB) CreateAuthRequest(ctx context.Context, req model.AuthRequest) error {
	logger.L.Debugf("Creating a new auth request between client app %s and backend %s", req.ClientID, req.BackendID)

	// Serialize before opening the transaction so a bad payload costs no DB round-trip.
	scopesStr, err := json.Marshal(req.Scopes)
	if err != nil {
		return fmt.Errorf("failed to serialize scopes: %w", err)
	}

	// TODO: when the old table is done, rename into auth_request
	query := fmt.Sprintf(`INSERT INTO "auth_request_2" (%s) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, NULL, '', '', '')`, authRequestRows)

	return db.inTx(ctx, func(tx *sql.Tx) error {
		_, err := tx.ExecContext(ctx, query,
			req.ID, req.ClientID, req.BackendID, scopesStr, req.RedirectURI,
			req.State, req.Nonce, req.ResponseType, req.CreationDate, false,
			req.CodeChallenge, req.CodeChallengeMethod,
		)
		if err != nil {
			return fmt.Errorf("failed to insert in DB: %w", err)
		}
		return nil
	})
}

// ValidateAuthRequest marks the request reqID as done. It stores the current
// UTC time as auth_time and copies user's ID, username and email into the
// claim_* columns.
//
// It returns ErrNotFound (and rolls back) unless exactly one row was updated.
// A nil user is rejected with an error instead of panicking.
func (db *sqlAuthRequestDB) ValidateAuthRequest(ctx context.Context, reqID uuid.UUID, user *model.User) error {
	logger.L.Debugf("Validating auth request %s", reqID)

	if user == nil {
		return errors.New("cannot validate auth request without a user")
	}

	return db.inTx(ctx, func(tx *sql.Tx) error {
		res, err := tx.ExecContext(ctx,
			`UPDATE "auth_request_2" SET done = true, auth_time = $1, claim_user_id = $2, claim_username = $3, claim_email = $4 WHERE id = $5`,
			time.Now().UTC(), user.ID, user.Username, user.Email, reqID.String(),
		)
		if err != nil {
			return fmt.Errorf("failed to update in DB: %w", err)
		}

		affectedRows, err := res.RowsAffected()
		if err != nil {
			return fmt.Errorf("failed to check number of affected rows: %w", err)
		}
		if affectedRows != 1 {
			return ErrNotFound
		}
		return nil
	})
}

// DeleteAuthRequest removes the request reqID. The number of affected rows is
// not checked, so deleting an unknown ID succeeds.
func (db *sqlAuthRequestDB) DeleteAuthRequest(ctx context.Context, reqID uuid.UUID) error {
	logger.L.Debugf("Deleting auth request: %s", reqID)

	return db.inTx(ctx, func(tx *sql.Tx) error {
		if _, err := tx.ExecContext(ctx, `DELETE FROM "auth_request_2" WHERE id = $1`, reqID.String()); err != nil {
			return fmt.Errorf("failed to delete auth request: %w", err)
		}
		return nil
	})
}

// New returns an auth request store backed by db. The caller keeps ownership
// of db.
func New(db *sql.DB) *sqlAuthRequestDB {
	return &sqlAuthRequestDB{db: db}
}