feat/epic-48-replace-dex #20

Merged
faercol merged 20 commits from feat/epic-48-replace-dex into main 2024-10-27 15:16:40 +00:00
7 changed files with 187 additions and 5 deletions
Showing only changes of commit e99fabafb9 - Show all commits

View file

@ -0,0 +1,144 @@
package auth
import (
"bytes"
"fmt"
"html/template"
"io"
"net/http"
"path/filepath"
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/helpers"
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db"
"github.com/google/uuid"
"go.uber.org/zap"
)
const ApprovalRoute = "/approval"
var scopeDescriptions = map[string]string{
"offline_access": "Have offline access",
"profile": "View basic profile information",
"email": "View your email address",
"groups": "View your groups",
}
func scopeDescription(rawScope string) string {
if desc, ok := scopeDescriptions[rawScope]; ok {
return desc
}
return rawScope
}
type approvalData struct {
Scopes []string
Client string
AuthReqID string
}
type ApprovalController struct {
l *zap.SugaredLogger
st db.Storage
baseDir string
}
func NewApprovalController(l *zap.SugaredLogger, st db.Storage, baseDir string) *ApprovalController {
return &ApprovalController{
l: l,
st: st,
baseDir: baseDir,
}
}
func (c *ApprovalController) handleFormResponse(w http.ResponseWriter, r *http.Request) {
reqID, err := uuid.Parse(r.Form.Get("req"))
if err != nil {
c.l.Errorf("Invalid request ID: %s", err)
helpers.HandleResponse(w, r, http.StatusBadRequest, []byte("invalid query format"), c.l)
return
}
if r.Form.Get("approval") != "approve" {
c.l.Debug("Approval rejected")
helpers.HandleResponse(w, r, http.StatusBadRequest, []byte("approval rejected"), c.l)
return
}
if err := c.st.AuthRequestStorage().GiveConsent(r.Context(), reqID); err != nil {
c.l.Errorf("Failed to approve request: %s", err)
helpers.HandleResponse(w, r, http.StatusInternalServerError, nil, c.l)
return
}
http.Redirect(w, r, fmt.Sprintf("/callback?code=%s&state=%s", r.Form.Get("code"), reqID.String()), http.StatusSeeOther)
}
func (c *ApprovalController) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
c.l.Errorf("Failed to parse query: %s", err)
helpers.HandleResponse(w, r, http.StatusBadRequest, []byte("invalid query format"), c.l)
return
}
if r.Method == http.MethodPost {
c.handleFormResponse(w, r)
return
}
state := r.Form.Get("state")
reqID, err := uuid.Parse(state)
if err != nil {
c.l.Errorf("Invalid state %q: %s", state, err)
helpers.HandleResponse(w, r, http.StatusBadRequest, []byte("unexpected state"), c.l)
return
}
req, err := c.st.AuthRequestStorage().GetAuthRequestByID(r.Context(), reqID)
if err != nil {
c.l.Errorf("Failed to get auth request from DB: %s", err)
helpers.HandleResponse(w, r, http.StatusInternalServerError, nil, c.l)
return
}
app, err := c.st.ClientStorage().GetClientByID(r.Context(), req.ClientID)
if err != nil {
c.l.Errorf("Failed to get client details from DB: %s", err)
helpers.HandleResponse(w, r, http.StatusInternalServerError, nil, c.l)
return
}
data := approvalData{
Scopes: []string{},
Client: app.Name,
AuthReqID: reqID.String(),
}
for _, s := range req.Scopes {
if s == "openid" { // it's implied we want that, no consent is really important there
continue
}
data.Scopes = append(data.Scopes, scopeDescription(s))
}
lp := filepath.Join(c.baseDir, "templates", "approval.html")
hdrTpl := filepath.Join(c.baseDir, "templates", "header.html")
footTpl := filepath.Join(c.baseDir, "templates", "footer.html")
tmpl, err := template.New("approval.html").ParseFiles(hdrTpl, footTpl, lp)
if err != nil {
c.l.Errorf("Failed to parse templates: %s", err)
helpers.HandleResponse(w, r, http.StatusInternalServerError, nil, c.l)
return
}
buf := new(bytes.Buffer)
if err := tmpl.Execute(buf, data); err != nil {
c.l.Errorf("Failed to execute template: %s", err)
helpers.HandleResponse(w, r, http.StatusInternalServerError, nil, c.l)
return
}
_, err = io.Copy(w, buf)
if err != nil {
c.l.Errorf("Failed to write response: %s", err)
helpers.HandleResponse(w, r, http.StatusInternalServerError, nil, c.l)
return
}
}

View file

@ -1,6 +1,7 @@
package auth package auth
import ( import (
"fmt"
"net/http" "net/http"
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/helpers" "git.faercol.me/faercol/polyculeconnect/polyculeconnect/helpers"
@ -99,6 +100,12 @@ func (c *CallbackDispatchController) ServeHTTP(w http.ResponseWriter, r *http.Re
return return
} }
if !req.Consent {
c.l.Debug("Redirecting to consent endpoint")
http.Redirect(w, r, fmt.Sprintf("/approval?state=%s&code=%s", state, r.URL.Query().Get("code")), http.StatusSeeOther)
return
}
callbackHandler, ok := c.callbackHandlers[req.BackendID] callbackHandler, ok := c.callbackHandlers[req.BackendID]
if !ok { if !ok {
c.l.Errorf("Backend %s does not exist for request %s", req.ID, req.BackendID) c.l.Errorf("Backend %s does not exist for request %s", req.ID, req.BackendID)

View file

@ -15,13 +15,14 @@ import (
var ErrNotFound = errors.New("backend not found") var ErrNotFound = errors.New("backend not found")
const authRequestRows = `"id", "client_id", "backend_id", "scopes", "redirect_uri", "state", "nonce", "response_type", "creation_time", "done", "code_challenge", "code_challenge_method", "auth_time", "user_id"` const authRequestRows = `"id", "client_id", "backend_id", "scopes", "redirect_uri", "state", "nonce", "response_type", "creation_time", "done", "code_challenge", "code_challenge_method", "auth_time", "user_id", "consent"`
type AuthRequestDB interface { type AuthRequestDB interface {
GetAuthRequestByID(ctx context.Context, id uuid.UUID) (*model.AuthRequest, error) GetAuthRequestByID(ctx context.Context, id uuid.UUID) (*model.AuthRequest, error)
CreateAuthRequest(ctx context.Context, req model.AuthRequest) error CreateAuthRequest(ctx context.Context, req model.AuthRequest) error
ValidateAuthRequest(ctx context.Context, reqID uuid.UUID, userID string) error ValidateAuthRequest(ctx context.Context, reqID uuid.UUID, userID string) error
DeleteAuthRequest(ctx context.Context, reqID uuid.UUID) error DeleteAuthRequest(ctx context.Context, reqID uuid.UUID) error
GiveConsent(ctx context.Context, reqID uuid.UUID) error
} }
type sqlAuthRequestDB struct { type sqlAuthRequestDB struct {
@ -38,7 +39,7 @@ func (db *sqlAuthRequestDB) GetAuthRequestByID(ctx context.Context, id uuid.UUID
var timestamp *time.Time var timestamp *time.Time
if err := row.Scan(&res.ID, &res.ClientID, &res.BackendID, &scopesStr, &res.RedirectURI, &res.State, &res.Nonce, &res.ResponseType, &res.CreationDate, &res.DoneVal, &res.CodeChallenge, &res.CodeChallengeMethod, &timestamp, &res.UserID); err != nil { if err := row.Scan(&res.ID, &res.ClientID, &res.BackendID, &scopesStr, &res.RedirectURI, &res.State, &res.Nonce, &res.ResponseType, &res.CreationDate, &res.DoneVal, &res.CodeChallenge, &res.CodeChallengeMethod, &timestamp, &res.UserID, &res.Consent); err != nil {
return nil, fmt.Errorf("failed to get auth request from DB: %w", err) return nil, fmt.Errorf("failed to get auth request from DB: %w", err)
} }
if timestamp != nil { if timestamp != nil {
@ -64,7 +65,7 @@ func (db *sqlAuthRequestDB) CreateAuthRequest(ctx context.Context, req model.Aut
return fmt.Errorf("failed to serialize scopes: %w", err) return fmt.Errorf("failed to serialize scopes: %w", err)
} }
query := fmt.Sprintf(`INSERT INTO "auth_request" (%s) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, NULL, '')`, authRequestRows) query := fmt.Sprintf(`INSERT INTO "auth_request" (%s) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, NULL, '', 0)`, authRequestRows)
_, err = tx.ExecContext(ctx, query, _, err = tx.ExecContext(ctx, query,
req.ID, req.ClientID, req.BackendID, req.ID, req.ClientID, req.BackendID,
scopesStr, req.RedirectURI, req.State, scopesStr, req.RedirectURI, req.State,
@ -109,6 +110,32 @@ func (db *sqlAuthRequestDB) ValidateAuthRequest(ctx context.Context, reqID uuid.
return nil return nil
} }
func (db *sqlAuthRequestDB) GiveConsent(ctx context.Context, reqID uuid.UUID) error {
tx, err := db.db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("failed to start transaction: %w", err)
}
defer func() { _ = tx.Rollback() }()
res, err := tx.ExecContext(ctx, `UPDATE "auth_request" SET consent = true WHERE id = $1`, reqID)
if err != nil {
return fmt.Errorf("failed to update in DB: %w", err)
}
affectedRows, err := res.RowsAffected()
if err != nil {
return fmt.Errorf("failed to check number of affected rows: %w", err)
}
if affectedRows != 1 {
return ErrNotFound
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
return nil
}
func (db *sqlAuthRequestDB) DeleteAuthRequest(ctx context.Context, reqID uuid.UUID) error { func (db *sqlAuthRequestDB) DeleteAuthRequest(ctx context.Context, reqID uuid.UUID) error {
logger.L.Debugf("Deleting auth request: %s", reqID) logger.L.Debugf("Deleting auth request: %s", reqID)
tx, err := db.db.BeginTx(ctx, nil) tx, err := db.db.BeginTx(ctx, nil)

View file

@ -34,6 +34,7 @@ type AuthRequest struct {
User *User User *User
DoneVal bool DoneVal bool
Consent bool
} }
func (a AuthRequest) GetID() string { func (a AuthRequest) GetID() string {

View file

@ -0,0 +1 @@
ALTER TABLE "auth_request" DROP COLUMN consent;

View file

@ -0,0 +1 @@
ALTER TABLE "auth_request" ADD COLUMN consent INTEGER NOT NULL DEFAULT 0;

View file

@ -68,8 +68,9 @@ func New(appConf *config.AppConfig, oidcHandler *op.Provider, st *storage.Storag
} }
controllers := map[string]http.Handler{ controllers := map[string]http.Handler{
ui.StaticRoute: middlewares.WithLogger(ui.NewStaticController(appConf.StaticDir), logger), ui.StaticRoute: middlewares.WithLogger(ui.NewStaticController(appConf.StaticDir), logger),
"/": middlewares.WithLogger(ui.NewIndexController(logger, oidcHandler, appConf.StaticDir), logger), auth.ApprovalRoute: middlewares.WithLogger(auth.NewApprovalController(logger, st.LocalStorage, appConf.StaticDir), logger),
"/": middlewares.WithLogger(ui.NewIndexController(logger, oidcHandler, appConf.StaticDir), logger),
} }
userInfoHandler := auth.NewAuthCallbackController(logger, st) userInfoHandler := auth.NewAuthCallbackController(logger, st)