diff --git a/polyculeconnect/controller/auth/authcallback.go b/polyculeconnect/controller/auth/authcallback.go index d31bb51..50e588a 100644 --- a/polyculeconnect/controller/auth/authcallback.go +++ b/polyculeconnect/controller/auth/authcallback.go @@ -4,8 +4,11 @@ import ( "net/http" "git.faercol.me/faercol/polyculeconnect/polyculeconnect/helpers" + "git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/model" "git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/storage" "github.com/google/uuid" + "github.com/zitadel/oidc/v3/pkg/client/rp" + "github.com/zitadel/oidc/v3/pkg/oidc" "go.uber.org/zap" ) @@ -23,23 +26,7 @@ func NewAuthCallbackController(l *zap.SugaredLogger, st *storage.Storage) *AuthC } } -func (c *AuthCallbackController) ServeHTTP(w http.ResponseWriter, r *http.Request) { - errMsg := r.URL.Query().Get("error") - if errMsg != "" { - errorDesc := r.URL.Query().Get("error_description") - c.l.Errorf("Failed to perform authentication: %s (%s)", errMsg, errorDesc) - helpers.HandleResponse(w, r, http.StatusInternalServerError, []byte("failed to perform authentication"), c.l) - return - } - - code := r.URL.Query().Get("code") - state := r.URL.Query().Get("state") - if code == "" || state == "" { - c.l.Error("Missing code or state in response") - helpers.HandleResponse(w, r, http.StatusInternalServerError, []byte("failed to perform authentication"), c.l) - return - } - +func (c *AuthCallbackController) HandleUserInfoCallback(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[*oidc.IDTokenClaims], state string, rp rp.RelyingParty, info *oidc.UserInfo) { requestID, err := uuid.Parse(state) if err != nil { c.l.Errorf("Invalid state, should be a request UUID, but got %s: %s", state, err) @@ -47,7 +34,14 @@ func (c *AuthCallbackController) ServeHTTP(w http.ResponseWriter, r *http.Reques return } - err = c.st.LocalStorage.AuthRequestStorage().ValidateAuthRequest(r.Context(), requestID) + c.l.Infof("Successful login from %s", info.Email) + user := model.User{ + ID: uuid.New(), + Email: info.Email, + Username: info.PreferredUsername, + } + + err = c.st.LocalStorage.AuthRequestStorage().ValidateAuthRequest(r.Context(), requestID, &user) if err != nil { c.l.Errorf("Failed to validate auth request from storage: %s", err) helpers.HandleResponse(w, r, http.StatusInternalServerError, []byte("failed to perform authentication"), c.l) @@ -56,3 +50,50 @@ func (c *AuthCallbackController) ServeHTTP(w http.ResponseWriter, r *http.Reques http.Redirect(w, r, "/authorize/callback?id="+state, http.StatusFound) } + +type CallbackDispatchController struct { + l *zap.SugaredLogger + st *storage.Storage + callbackHandlers map[uuid.UUID]http.Handler +} + +func NewCallbackDispatchController(l *zap.SugaredLogger, st *storage.Storage, handlers map[uuid.UUID]http.Handler) *CallbackDispatchController { + return &CallbackDispatchController{ + l: l, + st: st, + callbackHandlers: handlers, + } +} + +func (c *CallbackDispatchController) ServeHTTP(w http.ResponseWriter, r *http.Request) { + errMsg := r.URL.Query().Get("error") + if errMsg != "" { + errorDesc := r.URL.Query().Get("error_description") + c.l.Errorf("Failed to perform authentication: %s (%s)", errMsg, errorDesc) + helpers.HandleResponse(w, r, http.StatusInternalServerError, []byte("failed to perform authentication"), c.l) + return + } + + state := r.URL.Query().Get("state") + requestID, err := uuid.Parse(state) + if err != nil { + c.l.Errorf("Invalid state, should be a request UUID, but got %s: %s", state, err) + helpers.HandleResponse(w, r, http.StatusInternalServerError, []byte("failed to perform authentication"), c.l) + return + } + + req, err := c.st.LocalStorage.AuthRequestStorage().GetAuthRequestByID(r.Context(), requestID) + if err != nil { + c.l.Errorf("Failed to get auth request from DB: %s", err) + helpers.HandleResponse(w, r, http.StatusBadRequest, []byte("unknown request id"), c.l) + return + } + + callbackHandler, ok := c.callbackHandlers[req.BackendID] + if !ok { + c.l.Errorf("Backend %s does not exist for request %s", req.ID, req.BackendID) + helpers.HandleResponse(w, r, http.StatusNotFound, []byte("unknown backend"), c.l) + return + } + callbackHandler.ServeHTTP(w, r) +} diff --git a/polyculeconnect/controller/auth/authdispatch.go b/polyculeconnect/controller/auth/authdispatch.go new file mode 100644 index 0000000..cf36b12 --- /dev/null +++ b/polyculeconnect/controller/auth/authdispatch.go @@ -0,0 +1,54 @@ +package auth + +import ( + "net/http" + + "git.faercol.me/faercol/polyculeconnect/polyculeconnect/helpers" + "git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/storage" + "github.com/google/uuid" + "go.uber.org/zap" +) + +type AuthDispatchController struct { + l *zap.SugaredLogger + st *storage.Storage + redirectHandlers map[uuid.UUID]http.Handler +} + +func NewAuthDispatchController(l *zap.SugaredLogger, storage *storage.Storage, redirectHandlers map[uuid.UUID]http.Handler) *AuthDispatchController { + return &AuthDispatchController{ + l: l, + st: storage, + redirectHandlers: redirectHandlers, + } +} + +func (c *AuthDispatchController) ServeHTTP(w http.ResponseWriter, r *http.Request) { + requestIDStr := r.URL.Query().Get("request_id") + if requestIDStr == "" { + helpers.HandleResponse(w, r, http.StatusBadRequest, []byte("no request ID in request"), c.l) + return + } + + requestID, err := uuid.Parse(requestIDStr) + if err != nil { + c.l.Errorf("Invalid UUID format for request ID: %s", err) + helpers.HandleResponse(w, r, http.StatusBadRequest, []byte("invalid request id"), c.l) + return + } + + req, err := c.st.LocalStorage.AuthRequestStorage().GetAuthRequestByID(r.Context(), requestID) + if err != nil { + c.l.Errorf("Failed to get auth request from DB: %s", err) + helpers.HandleResponse(w, r, http.StatusBadRequest, []byte("unknown request id"), c.l) + return + } + + loginHandler, ok := c.redirectHandlers[req.BackendID] + if !ok { + c.l.Errorf("Backend %s does not exist for request %s", req.ID, req.BackendID) + helpers.HandleResponse(w, r, http.StatusNotFound, []byte("unknown backend"), c.l) + return + } + loginHandler.ServeHTTP(w, r) +} diff --git a/polyculeconnect/controller/auth/authredirect.go b/polyculeconnect/controller/auth/authredirect.go index e128d64..0d44ebc 100644 --- a/polyculeconnect/controller/auth/authredirect.go +++ b/polyculeconnect/controller/auth/authredirect.go @@ -13,14 +13,16 @@ import ( const AuthRedirectRoute = "/perform_auth" type AuthRedirectController struct { - l *zap.SugaredLogger - st *storage.Storage + provider rp.RelyingParty + l *zap.SugaredLogger + st *storage.Storage } -func NewAuthRedirectController(l *zap.SugaredLogger, storage *storage.Storage) *AuthRedirectController { +func NewAuthRedirectController(l *zap.SugaredLogger, provider rp.RelyingParty, storage *storage.Storage) *AuthRedirectController { return &AuthRedirectController{ - l: l, - st: storage, + l: l, + st: storage, + provider: provider, } } @@ -38,26 +40,26 @@ func (c *AuthRedirectController) ServeHTTP(w http.ResponseWriter, r *http.Reques return } - req, err := c.st.LocalStorage.AuthRequestStorage().GetAuthRequestByID(r.Context(), requestID) + _, err = c.st.LocalStorage.AuthRequestStorage().GetAuthRequestByID(r.Context(), requestID) if err != nil { c.l.Errorf("Failed to get auth request from DB: %s", err) helpers.HandleResponse(w, r, http.StatusBadRequest, []byte("unknown request id"), c.l) return } - backend, err := c.st.LocalStorage.BackendStorage().GetBackendByID(r.Context(), req.BackendID) - if err != nil { - c.l.Errorf("Failed to get backend from DB: %s", err) - helpers.HandleResponse(w, r, http.StatusInternalServerError, []byte("failed to perform auth"), c.l) - return - } + // backend, err := c.st.LocalStorage.BackendStorage().GetBackendByID(r.Context(), req.BackendID) + // if err != nil { + // c.l.Errorf("Failed to get backend from DB: %s", err) + // helpers.HandleResponse(w, r, http.StatusInternalServerError, []byte("failed to perform auth"), c.l) + // return + // } - provider, err := rp.NewRelyingPartyOIDC(r.Context(), backend.Config.Issuer, backend.Config.ClientID, backend.Config.ClientSecret, backend.Config.RedirectURI, req.Scopes) - if err != nil { - c.l.Errorf("Failed to init relying party: %s", err) - helpers.HandleResponse(w, r, http.StatusInternalServerError, []byte("failed to perform auth"), c.l) - return - } + // provider, err := rp.NewRelyingPartyOIDC(r.Context(), backend.Config.Issuer, backend.Config.ClientID, backend.Config.ClientSecret, backend.Config.RedirectURI, req.Scopes) + // if err != nil { + // c.l.Errorf("Failed to init relying party: %s", err) + // helpers.HandleResponse(w, r, http.StatusInternalServerError, []byte("failed to perform auth"), c.l) + // return + // } - rp.AuthURLHandler(func() string { return requestIDStr }, provider).ServeHTTP(w, r) + rp.AuthURLHandler(func() string { return requestIDStr }, c.provider).ServeHTTP(w, r) } diff --git a/polyculeconnect/internal/db/authrequest/authrequest.go b/polyculeconnect/internal/db/authrequest/authrequest.go index 0b878a0..0b51f9f 100644 --- a/polyculeconnect/internal/db/authrequest/authrequest.go +++ b/polyculeconnect/internal/db/authrequest/authrequest.go @@ -15,12 +15,13 @@ import ( var ErrNotFound = errors.New("backend not found") -const authRequestRows = `"id", "client_id", "backend_id", "scopes", "redirect_uri", "state", "nonce", "response_type", "creation_time", "done", "code_challenge", "code_challenge_method", "auth_time"` +const authRequestRows = `"id", "client_id", "backend_id", "scopes", "redirect_uri", "state", "nonce", "response_type", "creation_time", "done", "code_challenge", "code_challenge_method", "auth_time", "claim_user_id", "claim_username", "claim_email"` type AuthRequestDB interface { GetAuthRequestByID(ctx context.Context, id uuid.UUID) (*model.AuthRequest, error) + GetAuthRequestByUserID(ctx context.Context, id uuid.UUID) (*model.AuthRequest, error) CreateAuthRequest(ctx context.Context, req model.AuthRequest) error - ValidateAuthRequest(ctx context.Context, reqID uuid.UUID) error + ValidateAuthRequest(ctx context.Context, reqID uuid.UUID, user *model.User) error DeleteAuthRequest(ctx context.Context, reqID uuid.UUID) error } @@ -28,22 +29,75 @@ type sqlAuthRequestDB struct { db *sql.DB } +type dbUser struct { + id string + username string + email string +} + func (db *sqlAuthRequestDB) GetAuthRequestByID(ctx context.Context, id uuid.UUID) (*model.AuthRequest, error) { logger.L.Debugf("Getting auth request with id %s", id) query := fmt.Sprintf(`SELECT %s FROM "auth_request_2" WHERE "id" = ?`, authRequestRows) row := db.db.QueryRowContext(ctx, query, id) var res model.AuthRequest + var user dbUser var scopesStr []byte var timestamp *time.Time - if err := row.Scan(&res.ID, &res.ClientID, &res.BackendID, &scopesStr, &res.RedirectURI, &res.State, &res.Nonce, &res.ResponseType, &res.CreationDate, &res.DoneVal, &res.CodeChallenge, &res.CodeChallengeMethod, ×tamp); err != nil { + if err := row.Scan(&res.ID, &res.ClientID, &res.BackendID, &scopesStr, &res.RedirectURI, &res.State, &res.Nonce, &res.ResponseType, &res.CreationDate, &res.DoneVal, &res.CodeChallenge, &res.CodeChallengeMethod, ×tamp, &user.id, &user.username, &user.email); err != nil { return nil, fmt.Errorf("failed to get auth request from DB: %w", err) } if timestamp != nil { res.AuthTime = *timestamp } + if user.id != "" { + userID, err := uuid.Parse(user.id) + if err != nil { + return nil, fmt.Errorf("invalid format for user id: %w", err) + } + res.User = &model.User{ + ID: userID, + Username: user.username, + Email: user.email, + } + } + if err := json.Unmarshal(scopesStr, &res.Scopes); err != nil { + return nil, fmt.Errorf("invalid format for scopes: %w", err) + } + + return &res, nil +} + +func (db *sqlAuthRequestDB) GetAuthRequestByUserID(ctx context.Context, id uuid.UUID) (*model.AuthRequest, error) { + logger.L.Debugf("Getting auth request with user id %s", id) + query := fmt.Sprintf(`SELECT %s FROM "auth_request_2" WHERE "claim_user_id" = ?`, authRequestRows) + row := db.db.QueryRowContext(ctx, query, id) + + var res model.AuthRequest + var user dbUser + var scopesStr []byte + + var timestamp *time.Time + + if err := row.Scan(&res.ID, &res.ClientID, &res.BackendID, &scopesStr, &res.RedirectURI, &res.State, &res.Nonce, &res.ResponseType, &res.CreationDate, &res.DoneVal, &res.CodeChallenge, &res.CodeChallengeMethod, ×tamp, &user.id, &user.username, &user.email); err != nil { + return nil, fmt.Errorf("failed to get auth request from DB: %w", err) + } + if timestamp != nil { + res.AuthTime = *timestamp + } + if user.id != "" { + userID, err := uuid.Parse(user.id) + if err != nil { + return nil, fmt.Errorf("invalid format for user id: %w", err) + } + res.User = &model.User{ + ID: userID, + Username: user.username, + Email: user.email, + } + } if err := json.Unmarshal(scopesStr, &res.Scopes); err != nil { return nil, fmt.Errorf("invalid format for scopes: %w", err) } @@ -65,7 +119,7 @@ func (db *sqlAuthRequestDB) CreateAuthRequest(ctx context.Context, req model.Aut } // TODO: when the old table is done, rename into auth_request - query := fmt.Sprintf(`INSERT INTO "auth_request_2" (%s) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, NULL)`, authRequestRows) + query := fmt.Sprintf(`INSERT INTO "auth_request_2" (%s) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, NULL, '', '', '')`, authRequestRows) _, err = tx.ExecContext(ctx, query, req.ID, req.ClientID, req.BackendID, scopesStr, req.RedirectURI, req.State, @@ -83,7 +137,7 @@ func (db *sqlAuthRequestDB) CreateAuthRequest(ctx context.Context, req model.Aut return nil } -func (db *sqlAuthRequestDB) ValidateAuthRequest(ctx context.Context, reqID uuid.UUID) error { +func (db *sqlAuthRequestDB) ValidateAuthRequest(ctx context.Context, reqID uuid.UUID, user *model.User) error { logger.L.Debugf("Validating auth request %s", reqID) tx, err := db.db.BeginTx(ctx, nil) if err != nil { @@ -91,7 +145,7 @@ func (db *sqlAuthRequestDB) ValidateAuthRequest(ctx context.Context, reqID uuid. } defer func() { _ = tx.Rollback() }() - res, err := tx.ExecContext(ctx, `UPDATE "auth_request_2" SET done = true, auth_time = $1 WHERE id = $2`, time.Now().UTC(), reqID.String()) + res, err := tx.ExecContext(ctx, `UPDATE "auth_request_2" SET done = true, auth_time = $1, claim_user_id = $2, claim_username = $3, claim_email = $4 WHERE id = $5`, time.Now().UTC(), user.ID, user.Username, user.Email, reqID.String()) if err != nil { return fmt.Errorf("failed to update in DB: %w", err) } diff --git a/polyculeconnect/internal/model/authrequest.go b/polyculeconnect/internal/model/authrequest.go index 343a72e..c1e1475 100644 --- a/polyculeconnect/internal/model/authrequest.go +++ b/polyculeconnect/internal/model/authrequest.go @@ -30,7 +30,7 @@ type AuthRequest struct { BackendID uuid.UUID Backend *Backend - UserID uuid.UUID + User *User DoneVal bool } @@ -91,7 +91,10 @@ func (a AuthRequest) GetState() string { } func (a AuthRequest) GetSubject() string { - return a.UserID.String() + if a.User == nil { + return "" + } + return a.User.ID.String() } func (a AuthRequest) Done() bool { diff --git a/polyculeconnect/internal/model/user.go b/polyculeconnect/internal/model/user.go new file mode 100644 index 0000000..75e87e8 --- /dev/null +++ b/polyculeconnect/internal/model/user.go @@ -0,0 +1,9 @@ +package model + +import "github.com/google/uuid" + +type User struct { + ID uuid.UUID + Email string + Username string +} diff --git a/polyculeconnect/internal/storage/storage.go b/polyculeconnect/internal/storage/storage.go index 44520f5..6575daa 100644 --- a/polyculeconnect/internal/storage/storage.go +++ b/polyculeconnect/internal/storage/storage.go @@ -251,7 +251,22 @@ func (s *Storage) AuthorizeClientIDSecret(ctx context.Context, clientID, clientS } func (s *Storage) SetUserinfoFromScopes(ctx context.Context, userinfo *oidc.UserInfo, userID, clientID string, scopes []string) error { - // we'll use FromRequest instead + logger.L.Debugf("Setting user info for user %s", userID) + + parsedID, err := uuid.Parse(userID) + if err != nil { + return fmt.Errorf("invalid userID: %w", err) + } + req, err := s.LocalStorage.AuthRequestStorage().GetAuthRequestByUserID(ctx, parsedID) + if err != nil { + return fmt.Errorf("failed to get auth request from DB: %w", err) + } + if req.User == nil { + return errors.New("no user associated to that ID") + } + + userinfo.PreferredUsername = req.User.Username + userinfo.Email = req.User.Email return nil } diff --git a/polyculeconnect/migrations/6_add_auth_request_auth_user.down.sql b/polyculeconnect/migrations/6_add_auth_request_auth_user.down.sql new file mode 100644 index 0000000..10a21a1 --- /dev/null +++ b/polyculeconnect/migrations/6_add_auth_request_auth_user.down.sql @@ -0,0 +1,3 @@ +ALTER TABLE "auth_request_2" DROP COLUMN claim_user_id; +ALTER TABLE "auth_request_2" DROP COLUMN claim_username; +ALTER TABLE "auth_request_2" DROP COLUMN claim_email; diff --git a/polyculeconnect/migrations/6_add_auth_request_auth_user.up.sql b/polyculeconnect/migrations/6_add_auth_request_auth_user.up.sql new file mode 100644 index 0000000..bdea95c --- /dev/null +++ b/polyculeconnect/migrations/6_add_auth_request_auth_user.up.sql @@ -0,0 +1,3 @@ +ALTER TABLE "auth_request_2" ADD COLUMN claim_user_id string; +ALTER TABLE "auth_request_2" ADD COLUMN claim_username string; +ALTER TABLE "auth_request_2" ADD COLUMN claim_email string; diff --git a/polyculeconnect/polyculeconnect.db b/polyculeconnect/polyculeconnect.db index 46f8e19..61f7ea9 100644 Binary files a/polyculeconnect/polyculeconnect.db and b/polyculeconnect/polyculeconnect.db differ diff --git a/polyculeconnect/server/server.go b/polyculeconnect/server/server.go index ea9a4be..504ca83 100644 --- a/polyculeconnect/server/server.go +++ b/polyculeconnect/server/server.go @@ -13,6 +13,8 @@ import ( "git.faercol.me/faercol/polyculeconnect/polyculeconnect/controller/ui" "git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/middlewares" "git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/storage" + "github.com/google/uuid" + "github.com/zitadel/oidc/v3/pkg/client/rp" "github.com/zitadel/oidc/v3/pkg/op" "go.uber.org/zap" ) @@ -66,12 +68,31 @@ func New(appConf *config.AppConfig, oidcHandler *op.Provider, st *storage.Storag } controllers := map[string]http.Handler{ - auth.AuthCallbackRoute: middlewares.WithLogger(auth.NewAuthCallbackController(logger, st), logger), - auth.AuthRedirectRoute: middlewares.WithLogger(auth.NewAuthRedirectController(logger, st), logger), - ui.StaticRoute: middlewares.WithLogger(ui.NewStaticController(appConf.StaticDir), logger), - "/": middlewares.WithLogger(ui.NewIndexController(logger, oidcHandler, appConf.StaticDir), logger), + ui.StaticRoute: middlewares.WithLogger(ui.NewStaticController(appConf.StaticDir), logger), + "/": middlewares.WithLogger(ui.NewIndexController(logger, oidcHandler, appConf.StaticDir), logger), } + userInfoHandler := auth.NewAuthCallbackController(logger, st) + loginHandlers := map[uuid.UUID]http.Handler{} + callbackHandlers := map[uuid.UUID]http.Handler{} + + backends, err := st.LocalStorage.BackendStorage().GetAllBackends(context.Background()) + if err != nil { + return nil, fmt.Errorf("failed to get list of backends from storage: %w", err) + } + for _, b := range backends { + provider, err := rp.NewRelyingPartyOIDC(context.Background(), b.Config.Issuer, b.Config.ClientID, b.Config.ClientSecret, b.Config.RedirectURI, []string{"openid", "email"}) + if err != nil { + return nil, fmt.Errorf("failed to create connector for backend %s: %w", b.Name, err) + } + + loginHandlers[b.ID] = middlewares.WithLogger(auth.NewAuthRedirectController(logger, provider, st), logger) + callbackHandlers[b.ID] = middlewares.WithLogger(rp.CodeExchangeHandler(rp.UserinfoCallback(userInfoHandler.HandleUserInfoCallback), provider), logger) + } + + controllers[auth.AuthRedirectRoute] = middlewares.WithLogger(auth.NewAuthDispatchController(logger, st, loginHandlers), logger) + controllers[auth.AuthCallbackRoute] = middlewares.WithLogger(auth.NewCallbackDispatchController(logger, st, callbackHandlers), logger) + m := http.NewServeMux() return &Server{