diff --git a/polyculeconnect/controller/auth/approval.go b/polyculeconnect/controller/auth/approval.go new file mode 100644 index 0000000..097c348 --- /dev/null +++ b/polyculeconnect/controller/auth/approval.go @@ -0,0 +1,144 @@ +package auth + +import ( + "bytes" + "fmt" + "html/template" + "io" + "net/http" + "path/filepath" + + "git.faercol.me/faercol/polyculeconnect/polyculeconnect/helpers" + "git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db" + "github.com/google/uuid" + "go.uber.org/zap" +) + +const ApprovalRoute = "/approval" + +var scopeDescriptions = map[string]string{ + "offline_access": "Have offline access", + "profile": "View basic profile information", + "email": "View your email address", + "groups": "View your groups", +} + +func scopeDescription(rawScope string) string { + if desc, ok := scopeDescriptions[rawScope]; ok { + return desc + } + return rawScope +} + +type approvalData struct { + Scopes []string + Client string + AuthReqID string +} + +type ApprovalController struct { + l *zap.SugaredLogger + st db.Storage + baseDir string +} + +func NewApprovalController(l *zap.SugaredLogger, st db.Storage, baseDir string) *ApprovalController { + return &ApprovalController{ + l: l, + st: st, + baseDir: baseDir, + } +} + +func (c *ApprovalController) handleFormResponse(w http.ResponseWriter, r *http.Request) { + reqID, err := uuid.Parse(r.Form.Get("req")) + if err != nil { + c.l.Errorf("Invalid request ID: %s", err) + helpers.HandleResponse(w, r, http.StatusBadRequest, []byte("invalid query format"), c.l) + return + } + + if r.Form.Get("approval") != "approve" { + c.l.Debug("Approval rejected") + helpers.HandleResponse(w, r, http.StatusBadRequest, []byte("approval rejected"), c.l) + return + } + + if err := c.st.AuthRequestStorage().GiveConsent(r.Context(), reqID); err != nil { + c.l.Errorf("Failed to approve request: %s", err) + helpers.HandleResponse(w, r, http.StatusInternalServerError, nil, c.l) + return + } + + http.Redirect(w, r, fmt.Sprintf("/callback?code=%s&state=%s", r.Form.Get("code"), reqID.String()), http.StatusSeeOther) +} + +func (c *ApprovalController) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + c.l.Errorf("Failed to parse query: %s", err) + helpers.HandleResponse(w, r, http.StatusBadRequest, []byte("invalid query format"), c.l) + return + } + + if r.Method == http.MethodPost { + c.handleFormResponse(w, r) + return + } + + state := r.Form.Get("state") + reqID, err := uuid.Parse(state) + if err != nil { + c.l.Errorf("Invalid state %q: %s", state, err) + helpers.HandleResponse(w, r, http.StatusBadRequest, []byte("unexpected state"), c.l) + return + } + + req, err := c.st.AuthRequestStorage().GetAuthRequestByID(r.Context(), reqID) + if err != nil { + c.l.Errorf("Failed to get auth request from DB: %s", err) + helpers.HandleResponse(w, r, http.StatusInternalServerError, nil, c.l) + return + } + + app, err := c.st.ClientStorage().GetClientByID(r.Context(), req.ClientID) + if err != nil { + c.l.Errorf("Failed to get client details from DB: %s", err) + helpers.HandleResponse(w, r, http.StatusInternalServerError, nil, c.l) + return + } + + data := approvalData{ + Scopes: []string{}, + Client: app.Name, + AuthReqID: reqID.String(), + } + for _, s := range req.Scopes { + if s == "openid" { // it's implied we want that, no consent is really important there + continue + } + data.Scopes = append(data.Scopes, scopeDescription(s)) + } + + lp := filepath.Join(c.baseDir, "templates", "approval.html") + hdrTpl := filepath.Join(c.baseDir, "templates", "header.html") + footTpl := filepath.Join(c.baseDir, "templates", "footer.html") + tmpl, err := template.New("approval.html").ParseFiles(hdrTpl, footTpl, lp) + if err != nil { + c.l.Errorf("Failed to parse templates: %s", err) + helpers.HandleResponse(w, r, http.StatusInternalServerError, nil, c.l) + return + } + buf := new(bytes.Buffer) + + if err := tmpl.Execute(buf, data); err != nil { + c.l.Errorf("Failed to execute template: %s", err) + helpers.HandleResponse(w, r, http.StatusInternalServerError, nil, c.l) + return + } + _, err = io.Copy(w, buf) + if err != nil { + c.l.Errorf("Failed to write response: %s", err) + helpers.HandleResponse(w, r, http.StatusInternalServerError, nil, c.l) + return + } +} diff --git a/polyculeconnect/controller/auth/authcallback.go b/polyculeconnect/controller/auth/authcallback.go index f7ca9be..cd32116 100644 --- a/polyculeconnect/controller/auth/authcallback.go +++ b/polyculeconnect/controller/auth/authcallback.go @@ -1,6 +1,7 @@ package auth import ( + "fmt" "net/http" "git.faercol.me/faercol/polyculeconnect/polyculeconnect/helpers" @@ -99,6 +100,12 @@ func (c *CallbackDispatchController) ServeHTTP(w http.ResponseWriter, r *http.Re return } + if !req.Consent { + c.l.Debug("Redirecting to consent endpoint") + http.Redirect(w, r, fmt.Sprintf("/approval?state=%s&code=%s", state, r.URL.Query().Get("code")), http.StatusSeeOther) + return + } + callbackHandler, ok := c.callbackHandlers[req.BackendID] if !ok { c.l.Errorf("Backend %s does not exist for request %s", req.ID, req.BackendID) diff --git a/polyculeconnect/internal/db/authrequest/authrequest.go b/polyculeconnect/internal/db/authrequest/authrequest.go index 55aa451..b923954 100644 --- a/polyculeconnect/internal/db/authrequest/authrequest.go +++ b/polyculeconnect/internal/db/authrequest/authrequest.go @@ -15,13 +15,14 @@ import ( var ErrNotFound = errors.New("backend not found") -const authRequestRows = `"id", "client_id", "backend_id", "scopes", "redirect_uri", "state", "nonce", "response_type", "creation_time", "done", "code_challenge", "code_challenge_method", "auth_time", "user_id"` +const authRequestRows = `"id", "client_id", "backend_id", "scopes", "redirect_uri", "state", "nonce", "response_type", "creation_time", "done", "code_challenge", "code_challenge_method", "auth_time", "user_id", "consent"` type AuthRequestDB interface { GetAuthRequestByID(ctx context.Context, id uuid.UUID) (*model.AuthRequest, error) CreateAuthRequest(ctx context.Context, req model.AuthRequest) error ValidateAuthRequest(ctx context.Context, reqID uuid.UUID, userID string) error DeleteAuthRequest(ctx context.Context, reqID uuid.UUID) error + GiveConsent(ctx context.Context, reqID uuid.UUID) error } type sqlAuthRequestDB struct { @@ -38,7 +39,7 @@ func (db *sqlAuthRequestDB) GetAuthRequestByID(ctx context.Context, id uuid.UUID var timestamp *time.Time - if err := row.Scan(&res.ID, &res.ClientID, &res.BackendID, &scopesStr, &res.RedirectURI, &res.State, &res.Nonce, &res.ResponseType, &res.CreationDate, &res.DoneVal, &res.CodeChallenge, &res.CodeChallengeMethod, ×tamp, &res.UserID); err != nil { + if err := row.Scan(&res.ID, &res.ClientID, &res.BackendID, &scopesStr, &res.RedirectURI, &res.State, &res.Nonce, &res.ResponseType, &res.CreationDate, &res.DoneVal, &res.CodeChallenge, &res.CodeChallengeMethod, ×tamp, &res.UserID, &res.Consent); err != nil { return nil, fmt.Errorf("failed to get auth request from DB: %w", err) } if timestamp != nil { @@ -64,7 +65,7 @@ func (db *sqlAuthRequestDB) CreateAuthRequest(ctx context.Context, req model.Aut return fmt.Errorf("failed to serialize scopes: %w", err) } - query := fmt.Sprintf(`INSERT INTO "auth_request" (%s) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, NULL, '')`, authRequestRows) + query := fmt.Sprintf(`INSERT INTO "auth_request" (%s) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, NULL, '', 0)`, authRequestRows) _, err = tx.ExecContext(ctx, query, req.ID, req.ClientID, req.BackendID, scopesStr, req.RedirectURI, req.State, @@ -109,6 +110,32 @@ func (db *sqlAuthRequestDB) ValidateAuthRequest(ctx context.Context, reqID uuid. return nil } +func (db *sqlAuthRequestDB) GiveConsent(ctx context.Context, reqID uuid.UUID) error { + tx, err := db.db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("failed to start transaction: %w", err) + } + defer func() { _ = tx.Rollback() }() + + res, err := tx.ExecContext(ctx, `UPDATE "auth_request" SET consent = true WHERE id = $1`, reqID) + if err != nil { + return fmt.Errorf("failed to update in DB: %w", err) + } + affectedRows, err := res.RowsAffected() + if err != nil { + return fmt.Errorf("failed to check number of affected rows: %w", err) + } + if affectedRows != 1 { + return ErrNotFound + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("failed to commit transaction: %w", err) + } + + return nil +} + func (db *sqlAuthRequestDB) DeleteAuthRequest(ctx context.Context, reqID uuid.UUID) error { logger.L.Debugf("Deleting auth request: %s", reqID) tx, err := db.db.BeginTx(ctx, nil) diff --git a/polyculeconnect/internal/model/authrequest.go b/polyculeconnect/internal/model/authrequest.go index 43a462d..21cb2ee 100644 --- a/polyculeconnect/internal/model/authrequest.go +++ b/polyculeconnect/internal/model/authrequest.go @@ -34,6 +34,7 @@ type AuthRequest struct { User *User DoneVal bool + Consent bool } func (a AuthRequest) GetID() string { diff --git a/polyculeconnect/migrations/2_consent.down.sql b/polyculeconnect/migrations/2_consent.down.sql new file mode 100644 index 0000000..f4fac91 --- /dev/null +++ b/polyculeconnect/migrations/2_consent.down.sql @@ -0,0 +1 @@ +ALTER TABLE "auth_request" DROP COLUMN consent; diff --git a/polyculeconnect/migrations/2_consent.up.sql b/polyculeconnect/migrations/2_consent.up.sql new file mode 100644 index 0000000..ccfe5b2 --- /dev/null +++ b/polyculeconnect/migrations/2_consent.up.sql @@ -0,0 +1 @@ +ALTER TABLE "auth_request" ADD COLUMN consent INTEGER NOT NULL DEFAULT 0; \ No newline at end of file diff --git a/polyculeconnect/server/server.go b/polyculeconnect/server/server.go index 65d8415..027203e 100644 --- a/polyculeconnect/server/server.go +++ b/polyculeconnect/server/server.go @@ -68,8 +68,9 @@ func New(appConf *config.AppConfig, oidcHandler *op.Provider, st *storage.Storag } controllers := map[string]http.Handler{ - ui.StaticRoute: middlewares.WithLogger(ui.NewStaticController(appConf.StaticDir), logger), - "/": middlewares.WithLogger(ui.NewIndexController(logger, oidcHandler, appConf.StaticDir), logger), + ui.StaticRoute: middlewares.WithLogger(ui.NewStaticController(appConf.StaticDir), logger), + auth.ApprovalRoute: middlewares.WithLogger(auth.NewApprovalController(logger, st.LocalStorage, appConf.StaticDir), logger), + "/": middlewares.WithLogger(ui.NewIndexController(logger, oidcHandler, appConf.StaticDir), logger), } userInfoHandler := auth.NewAuthCallbackController(logger, st)