feat/epic-48-replace-dex #20
7 changed files with 187 additions and 5 deletions
144
polyculeconnect/controller/auth/approval.go
Normal file
144
polyculeconnect/controller/auth/approval.go
Normal file
|
@ -0,0 +1,144 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"io"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
|
||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/helpers"
|
||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db"
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const ApprovalRoute = "/approval"
|
||||
|
||||
var scopeDescriptions = map[string]string{
|
||||
"offline_access": "Have offline access",
|
||||
"profile": "View basic profile information",
|
||||
"email": "View your email address",
|
||||
"groups": "View your groups",
|
||||
}
|
||||
|
||||
func scopeDescription(rawScope string) string {
|
||||
if desc, ok := scopeDescriptions[rawScope]; ok {
|
||||
return desc
|
||||
}
|
||||
return rawScope
|
||||
}
|
||||
|
||||
type approvalData struct {
|
||||
Scopes []string
|
||||
Client string
|
||||
AuthReqID string
|
||||
}
|
||||
|
||||
type ApprovalController struct {
|
||||
l *zap.SugaredLogger
|
||||
st db.Storage
|
||||
baseDir string
|
||||
}
|
||||
|
||||
func NewApprovalController(l *zap.SugaredLogger, st db.Storage, baseDir string) *ApprovalController {
|
||||
return &ApprovalController{
|
||||
l: l,
|
||||
st: st,
|
||||
baseDir: baseDir,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ApprovalController) handleFormResponse(w http.ResponseWriter, r *http.Request) {
|
||||
reqID, err := uuid.Parse(r.Form.Get("req"))
|
||||
if err != nil {
|
||||
c.l.Errorf("Invalid request ID: %s", err)
|
||||
helpers.HandleResponse(w, r, http.StatusBadRequest, []byte("invalid query format"), c.l)
|
||||
return
|
||||
}
|
||||
|
||||
if r.Form.Get("approval") != "approve" {
|
||||
c.l.Debug("Approval rejected")
|
||||
helpers.HandleResponse(w, r, http.StatusBadRequest, []byte("approval rejected"), c.l)
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.st.AuthRequestStorage().GiveConsent(r.Context(), reqID); err != nil {
|
||||
c.l.Errorf("Failed to approve request: %s", err)
|
||||
helpers.HandleResponse(w, r, http.StatusInternalServerError, nil, c.l)
|
||||
return
|
||||
}
|
||||
|
||||
http.Redirect(w, r, fmt.Sprintf("/callback?code=%s&state=%s", r.Form.Get("code"), reqID.String()), http.StatusSeeOther)
|
||||
}
|
||||
|
||||
func (c *ApprovalController) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
c.l.Errorf("Failed to parse query: %s", err)
|
||||
helpers.HandleResponse(w, r, http.StatusBadRequest, []byte("invalid query format"), c.l)
|
||||
return
|
||||
}
|
||||
|
||||
if r.Method == http.MethodPost {
|
||||
c.handleFormResponse(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
state := r.Form.Get("state")
|
||||
reqID, err := uuid.Parse(state)
|
||||
if err != nil {
|
||||
c.l.Errorf("Invalid state %q: %s", state, err)
|
||||
helpers.HandleResponse(w, r, http.StatusBadRequest, []byte("unexpected state"), c.l)
|
||||
return
|
||||
}
|
||||
|
||||
req, err := c.st.AuthRequestStorage().GetAuthRequestByID(r.Context(), reqID)
|
||||
if err != nil {
|
||||
c.l.Errorf("Failed to get auth request from DB: %s", err)
|
||||
helpers.HandleResponse(w, r, http.StatusInternalServerError, nil, c.l)
|
||||
return
|
||||
}
|
||||
|
||||
app, err := c.st.ClientStorage().GetClientByID(r.Context(), req.ClientID)
|
||||
if err != nil {
|
||||
c.l.Errorf("Failed to get client details from DB: %s", err)
|
||||
helpers.HandleResponse(w, r, http.StatusInternalServerError, nil, c.l)
|
||||
return
|
||||
}
|
||||
|
||||
data := approvalData{
|
||||
Scopes: []string{},
|
||||
Client: app.Name,
|
||||
AuthReqID: reqID.String(),
|
||||
}
|
||||
for _, s := range req.Scopes {
|
||||
if s == "openid" { // it's implied we want that, no consent is really important there
|
||||
continue
|
||||
}
|
||||
data.Scopes = append(data.Scopes, scopeDescription(s))
|
||||
}
|
||||
|
||||
lp := filepath.Join(c.baseDir, "templates", "approval.html")
|
||||
hdrTpl := filepath.Join(c.baseDir, "templates", "header.html")
|
||||
footTpl := filepath.Join(c.baseDir, "templates", "footer.html")
|
||||
tmpl, err := template.New("approval.html").ParseFiles(hdrTpl, footTpl, lp)
|
||||
if err != nil {
|
||||
c.l.Errorf("Failed to parse templates: %s", err)
|
||||
helpers.HandleResponse(w, r, http.StatusInternalServerError, nil, c.l)
|
||||
return
|
||||
}
|
||||
buf := new(bytes.Buffer)
|
||||
|
||||
if err := tmpl.Execute(buf, data); err != nil {
|
||||
c.l.Errorf("Failed to execute template: %s", err)
|
||||
helpers.HandleResponse(w, r, http.StatusInternalServerError, nil, c.l)
|
||||
return
|
||||
}
|
||||
_, err = io.Copy(w, buf)
|
||||
if err != nil {
|
||||
c.l.Errorf("Failed to write response: %s", err)
|
||||
helpers.HandleResponse(w, r, http.StatusInternalServerError, nil, c.l)
|
||||
return
|
||||
}
|
||||
}
|
|
@ -1,6 +1,7 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/helpers"
|
||||
|
@ -99,6 +100,12 @@ func (c *CallbackDispatchController) ServeHTTP(w http.ResponseWriter, r *http.Re
|
|||
return
|
||||
}
|
||||
|
||||
if !req.Consent {
|
||||
c.l.Debug("Redirecting to consent endpoint")
|
||||
http.Redirect(w, r, fmt.Sprintf("/approval?state=%s&code=%s", state, r.URL.Query().Get("code")), http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
callbackHandler, ok := c.callbackHandlers[req.BackendID]
|
||||
if !ok {
|
||||
c.l.Errorf("Backend %s does not exist for request %s", req.ID, req.BackendID)
|
||||
|
|
|
@ -15,13 +15,14 @@ import (
|
|||
|
||||
var ErrNotFound = errors.New("backend not found")
|
||||
|
||||
const authRequestRows = `"id", "client_id", "backend_id", "scopes", "redirect_uri", "state", "nonce", "response_type", "creation_time", "done", "code_challenge", "code_challenge_method", "auth_time", "user_id"`
|
||||
const authRequestRows = `"id", "client_id", "backend_id", "scopes", "redirect_uri", "state", "nonce", "response_type", "creation_time", "done", "code_challenge", "code_challenge_method", "auth_time", "user_id", "consent"`
|
||||
|
||||
type AuthRequestDB interface {
|
||||
GetAuthRequestByID(ctx context.Context, id uuid.UUID) (*model.AuthRequest, error)
|
||||
CreateAuthRequest(ctx context.Context, req model.AuthRequest) error
|
||||
ValidateAuthRequest(ctx context.Context, reqID uuid.UUID, userID string) error
|
||||
DeleteAuthRequest(ctx context.Context, reqID uuid.UUID) error
|
||||
GiveConsent(ctx context.Context, reqID uuid.UUID) error
|
||||
}
|
||||
|
||||
type sqlAuthRequestDB struct {
|
||||
|
@ -38,7 +39,7 @@ func (db *sqlAuthRequestDB) GetAuthRequestByID(ctx context.Context, id uuid.UUID
|
|||
|
||||
var timestamp *time.Time
|
||||
|
||||
if err := row.Scan(&res.ID, &res.ClientID, &res.BackendID, &scopesStr, &res.RedirectURI, &res.State, &res.Nonce, &res.ResponseType, &res.CreationDate, &res.DoneVal, &res.CodeChallenge, &res.CodeChallengeMethod, ×tamp, &res.UserID); err != nil {
|
||||
if err := row.Scan(&res.ID, &res.ClientID, &res.BackendID, &scopesStr, &res.RedirectURI, &res.State, &res.Nonce, &res.ResponseType, &res.CreationDate, &res.DoneVal, &res.CodeChallenge, &res.CodeChallengeMethod, ×tamp, &res.UserID, &res.Consent); err != nil {
|
||||
return nil, fmt.Errorf("failed to get auth request from DB: %w", err)
|
||||
}
|
||||
if timestamp != nil {
|
||||
|
@ -64,7 +65,7 @@ func (db *sqlAuthRequestDB) CreateAuthRequest(ctx context.Context, req model.Aut
|
|||
return fmt.Errorf("failed to serialize scopes: %w", err)
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`INSERT INTO "auth_request" (%s) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, NULL, '')`, authRequestRows)
|
||||
query := fmt.Sprintf(`INSERT INTO "auth_request" (%s) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, NULL, '', 0)`, authRequestRows)
|
||||
_, err = tx.ExecContext(ctx, query,
|
||||
req.ID, req.ClientID, req.BackendID,
|
||||
scopesStr, req.RedirectURI, req.State,
|
||||
|
@ -109,6 +110,32 @@ func (db *sqlAuthRequestDB) ValidateAuthRequest(ctx context.Context, reqID uuid.
|
|||
return nil
|
||||
}
|
||||
|
||||
func (db *sqlAuthRequestDB) GiveConsent(ctx context.Context, reqID uuid.UUID) error {
|
||||
tx, err := db.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start transaction: %w", err)
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
res, err := tx.ExecContext(ctx, `UPDATE "auth_request" SET consent = true WHERE id = $1`, reqID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update in DB: %w", err)
|
||||
}
|
||||
affectedRows, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check number of affected rows: %w", err)
|
||||
}
|
||||
if affectedRows != 1 {
|
||||
return ErrNotFound
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return fmt.Errorf("failed to commit transaction: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *sqlAuthRequestDB) DeleteAuthRequest(ctx context.Context, reqID uuid.UUID) error {
|
||||
logger.L.Debugf("Deleting auth request: %s", reqID)
|
||||
tx, err := db.db.BeginTx(ctx, nil)
|
||||
|
|
|
@ -34,6 +34,7 @@ type AuthRequest struct {
|
|||
User *User
|
||||
|
||||
DoneVal bool
|
||||
Consent bool
|
||||
}
|
||||
|
||||
func (a AuthRequest) GetID() string {
|
||||
|
|
1
polyculeconnect/migrations/2_consent.down.sql
Normal file
1
polyculeconnect/migrations/2_consent.down.sql
Normal file
|
@ -0,0 +1 @@
|
|||
ALTER TABLE "auth_request" DROP COLUMN consent;
|
1
polyculeconnect/migrations/2_consent.up.sql
Normal file
1
polyculeconnect/migrations/2_consent.up.sql
Normal file
|
@ -0,0 +1 @@
|
|||
ALTER TABLE "auth_request" ADD COLUMN consent INTEGER NOT NULL DEFAULT 0;
|
|
@ -68,8 +68,9 @@ func New(appConf *config.AppConfig, oidcHandler *op.Provider, st *storage.Storag
|
|||
}
|
||||
|
||||
controllers := map[string]http.Handler{
|
||||
ui.StaticRoute: middlewares.WithLogger(ui.NewStaticController(appConf.StaticDir), logger),
|
||||
"/": middlewares.WithLogger(ui.NewIndexController(logger, oidcHandler, appConf.StaticDir), logger),
|
||||
ui.StaticRoute: middlewares.WithLogger(ui.NewStaticController(appConf.StaticDir), logger),
|
||||
auth.ApprovalRoute: middlewares.WithLogger(auth.NewApprovalController(logger, st.LocalStorage, appConf.StaticDir), logger),
|
||||
"/": middlewares.WithLogger(ui.NewIndexController(logger, oidcHandler, appConf.StaticDir), logger),
|
||||
}
|
||||
|
||||
userInfoHandler := auth.NewAuthCallbackController(logger, st)
|
||||
|
|
Loading…
Reference in a new issue