feat/epic-48-replace-dex #20
11 changed files with 255 additions and 50 deletions
|
@ -4,8 +4,11 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/helpers"
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/helpers"
|
||||||
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/model"
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/storage"
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/storage"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/zitadel/oidc/v3/pkg/client/rp"
|
||||||
|
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -23,23 +26,7 @@ func NewAuthCallbackController(l *zap.SugaredLogger, st *storage.Storage) *AuthC
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *AuthCallbackController) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (c *AuthCallbackController) HandleUserInfoCallback(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[*oidc.IDTokenClaims], state string, rp rp.RelyingParty, info *oidc.UserInfo) {
|
||||||
errMsg := r.URL.Query().Get("error")
|
|
||||||
if errMsg != "" {
|
|
||||||
errorDesc := r.URL.Query().Get("error_description")
|
|
||||||
c.l.Errorf("Failed to perform authentication: %s (%s)", errMsg, errorDesc)
|
|
||||||
helpers.HandleResponse(w, r, http.StatusInternalServerError, []byte("failed to perform authentication"), c.l)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
code := r.URL.Query().Get("code")
|
|
||||||
state := r.URL.Query().Get("state")
|
|
||||||
if code == "" || state == "" {
|
|
||||||
c.l.Error("Missing code or state in response")
|
|
||||||
helpers.HandleResponse(w, r, http.StatusInternalServerError, []byte("failed to perform authentication"), c.l)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
requestID, err := uuid.Parse(state)
|
requestID, err := uuid.Parse(state)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.l.Errorf("Invalid state, should be a request UUID, but got %s: %s", state, err)
|
c.l.Errorf("Invalid state, should be a request UUID, but got %s: %s", state, err)
|
||||||
|
@ -47,7 +34,14 @@ func (c *AuthCallbackController) ServeHTTP(w http.ResponseWriter, r *http.Reques
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = c.st.LocalStorage.AuthRequestStorage().ValidateAuthRequest(r.Context(), requestID)
|
c.l.Infof("Successful login from %s", info.Email)
|
||||||
|
user := model.User{
|
||||||
|
ID: uuid.New(),
|
||||||
|
Email: info.Email,
|
||||||
|
Username: info.PreferredUsername,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.st.LocalStorage.AuthRequestStorage().ValidateAuthRequest(r.Context(), requestID, &user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.l.Errorf("Failed to validate auth request from storage: %s", err)
|
c.l.Errorf("Failed to validate auth request from storage: %s", err)
|
||||||
helpers.HandleResponse(w, r, http.StatusInternalServerError, []byte("failed to perform authentication"), c.l)
|
helpers.HandleResponse(w, r, http.StatusInternalServerError, []byte("failed to perform authentication"), c.l)
|
||||||
|
@ -56,3 +50,50 @@ func (c *AuthCallbackController) ServeHTTP(w http.ResponseWriter, r *http.Reques
|
||||||
|
|
||||||
http.Redirect(w, r, "/authorize/callback?id="+state, http.StatusFound)
|
http.Redirect(w, r, "/authorize/callback?id="+state, http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type CallbackDispatchController struct {
|
||||||
|
l *zap.SugaredLogger
|
||||||
|
st *storage.Storage
|
||||||
|
callbackHandlers map[uuid.UUID]http.Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewCallbackDispatchController(l *zap.SugaredLogger, st *storage.Storage, handlers map[uuid.UUID]http.Handler) *CallbackDispatchController {
|
||||||
|
return &CallbackDispatchController{
|
||||||
|
l: l,
|
||||||
|
st: st,
|
||||||
|
callbackHandlers: handlers,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CallbackDispatchController) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
errMsg := r.URL.Query().Get("error")
|
||||||
|
if errMsg != "" {
|
||||||
|
errorDesc := r.URL.Query().Get("error_description")
|
||||||
|
c.l.Errorf("Failed to perform authentication: %s (%s)", errMsg, errorDesc)
|
||||||
|
helpers.HandleResponse(w, r, http.StatusInternalServerError, []byte("failed to perform authentication"), c.l)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
state := r.URL.Query().Get("state")
|
||||||
|
requestID, err := uuid.Parse(state)
|
||||||
|
if err != nil {
|
||||||
|
c.l.Errorf("Invalid state, should be a request UUID, but got %s: %s", state, err)
|
||||||
|
helpers.HandleResponse(w, r, http.StatusInternalServerError, []byte("failed to perform authentication"), c.l)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := c.st.LocalStorage.AuthRequestStorage().GetAuthRequestByID(r.Context(), requestID)
|
||||||
|
if err != nil {
|
||||||
|
c.l.Errorf("Failed to get auth request from DB: %s", err)
|
||||||
|
helpers.HandleResponse(w, r, http.StatusBadRequest, []byte("unknown request id"), c.l)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
callbackHandler, ok := c.callbackHandlers[req.BackendID]
|
||||||
|
if !ok {
|
||||||
|
c.l.Errorf("Backend %s does not exist for request %s", req.ID, req.BackendID)
|
||||||
|
helpers.HandleResponse(w, r, http.StatusNotFound, []byte("unknown backend"), c.l)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
callbackHandler.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
|
54
polyculeconnect/controller/auth/authdispatch.go
Normal file
54
polyculeconnect/controller/auth/authdispatch.go
Normal file
|
@ -0,0 +1,54 @@
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/helpers"
|
||||||
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/storage"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AuthDispatchController struct {
|
||||||
|
l *zap.SugaredLogger
|
||||||
|
st *storage.Storage
|
||||||
|
redirectHandlers map[uuid.UUID]http.Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAuthDispatchController(l *zap.SugaredLogger, storage *storage.Storage, redirectHandlers map[uuid.UUID]http.Handler) *AuthDispatchController {
|
||||||
|
return &AuthDispatchController{
|
||||||
|
l: l,
|
||||||
|
st: storage,
|
||||||
|
redirectHandlers: redirectHandlers,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *AuthDispatchController) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
requestIDStr := r.URL.Query().Get("request_id")
|
||||||
|
if requestIDStr == "" {
|
||||||
|
helpers.HandleResponse(w, r, http.StatusBadRequest, []byte("no request ID in request"), c.l)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
requestID, err := uuid.Parse(requestIDStr)
|
||||||
|
if err != nil {
|
||||||
|
c.l.Errorf("Invalid UUID format for request ID: %s", err)
|
||||||
|
helpers.HandleResponse(w, r, http.StatusBadRequest, []byte("invalid request id"), c.l)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := c.st.LocalStorage.AuthRequestStorage().GetAuthRequestByID(r.Context(), requestID)
|
||||||
|
if err != nil {
|
||||||
|
c.l.Errorf("Failed to get auth request from DB: %s", err)
|
||||||
|
helpers.HandleResponse(w, r, http.StatusBadRequest, []byte("unknown request id"), c.l)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
loginHandler, ok := c.redirectHandlers[req.BackendID]
|
||||||
|
if !ok {
|
||||||
|
c.l.Errorf("Backend %s does not exist for request %s", req.ID, req.BackendID)
|
||||||
|
helpers.HandleResponse(w, r, http.StatusNotFound, []byte("unknown backend"), c.l)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
loginHandler.ServeHTTP(w, r)
|
||||||
|
}
|
|
@ -13,14 +13,16 @@ import (
|
||||||
const AuthRedirectRoute = "/perform_auth"
|
const AuthRedirectRoute = "/perform_auth"
|
||||||
|
|
||||||
type AuthRedirectController struct {
|
type AuthRedirectController struct {
|
||||||
|
provider rp.RelyingParty
|
||||||
l *zap.SugaredLogger
|
l *zap.SugaredLogger
|
||||||
st *storage.Storage
|
st *storage.Storage
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAuthRedirectController(l *zap.SugaredLogger, storage *storage.Storage) *AuthRedirectController {
|
func NewAuthRedirectController(l *zap.SugaredLogger, provider rp.RelyingParty, storage *storage.Storage) *AuthRedirectController {
|
||||||
return &AuthRedirectController{
|
return &AuthRedirectController{
|
||||||
l: l,
|
l: l,
|
||||||
st: storage,
|
st: storage,
|
||||||
|
provider: provider,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -38,26 +40,26 @@ func (c *AuthRedirectController) ServeHTTP(w http.ResponseWriter, r *http.Reques
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := c.st.LocalStorage.AuthRequestStorage().GetAuthRequestByID(r.Context(), requestID)
|
_, err = c.st.LocalStorage.AuthRequestStorage().GetAuthRequestByID(r.Context(), requestID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.l.Errorf("Failed to get auth request from DB: %s", err)
|
c.l.Errorf("Failed to get auth request from DB: %s", err)
|
||||||
helpers.HandleResponse(w, r, http.StatusBadRequest, []byte("unknown request id"), c.l)
|
helpers.HandleResponse(w, r, http.StatusBadRequest, []byte("unknown request id"), c.l)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
backend, err := c.st.LocalStorage.BackendStorage().GetBackendByID(r.Context(), req.BackendID)
|
// backend, err := c.st.LocalStorage.BackendStorage().GetBackendByID(r.Context(), req.BackendID)
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
c.l.Errorf("Failed to get backend from DB: %s", err)
|
// c.l.Errorf("Failed to get backend from DB: %s", err)
|
||||||
helpers.HandleResponse(w, r, http.StatusInternalServerError, []byte("failed to perform auth"), c.l)
|
// helpers.HandleResponse(w, r, http.StatusInternalServerError, []byte("failed to perform auth"), c.l)
|
||||||
return
|
// return
|
||||||
}
|
// }
|
||||||
|
|
||||||
provider, err := rp.NewRelyingPartyOIDC(r.Context(), backend.Config.Issuer, backend.Config.ClientID, backend.Config.ClientSecret, backend.Config.RedirectURI, req.Scopes)
|
// provider, err := rp.NewRelyingPartyOIDC(r.Context(), backend.Config.Issuer, backend.Config.ClientID, backend.Config.ClientSecret, backend.Config.RedirectURI, req.Scopes)
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
c.l.Errorf("Failed to init relying party: %s", err)
|
// c.l.Errorf("Failed to init relying party: %s", err)
|
||||||
helpers.HandleResponse(w, r, http.StatusInternalServerError, []byte("failed to perform auth"), c.l)
|
// helpers.HandleResponse(w, r, http.StatusInternalServerError, []byte("failed to perform auth"), c.l)
|
||||||
return
|
// return
|
||||||
}
|
// }
|
||||||
|
|
||||||
rp.AuthURLHandler(func() string { return requestIDStr }, provider).ServeHTTP(w, r)
|
rp.AuthURLHandler(func() string { return requestIDStr }, c.provider).ServeHTTP(w, r)
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,12 +15,13 @@ import (
|
||||||
|
|
||||||
var ErrNotFound = errors.New("backend not found")
|
var ErrNotFound = errors.New("backend not found")
|
||||||
|
|
||||||
const authRequestRows = `"id", "client_id", "backend_id", "scopes", "redirect_uri", "state", "nonce", "response_type", "creation_time", "done", "code_challenge", "code_challenge_method", "auth_time"`
|
const authRequestRows = `"id", "client_id", "backend_id", "scopes", "redirect_uri", "state", "nonce", "response_type", "creation_time", "done", "code_challenge", "code_challenge_method", "auth_time", "claim_user_id", "claim_username", "claim_email"`
|
||||||
|
|
||||||
type AuthRequestDB interface {
|
type AuthRequestDB interface {
|
||||||
GetAuthRequestByID(ctx context.Context, id uuid.UUID) (*model.AuthRequest, error)
|
GetAuthRequestByID(ctx context.Context, id uuid.UUID) (*model.AuthRequest, error)
|
||||||
|
GetAuthRequestByUserID(ctx context.Context, id uuid.UUID) (*model.AuthRequest, error)
|
||||||
CreateAuthRequest(ctx context.Context, req model.AuthRequest) error
|
CreateAuthRequest(ctx context.Context, req model.AuthRequest) error
|
||||||
ValidateAuthRequest(ctx context.Context, reqID uuid.UUID) error
|
ValidateAuthRequest(ctx context.Context, reqID uuid.UUID, user *model.User) error
|
||||||
DeleteAuthRequest(ctx context.Context, reqID uuid.UUID) error
|
DeleteAuthRequest(ctx context.Context, reqID uuid.UUID) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -28,22 +29,75 @@ type sqlAuthRequestDB struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type dbUser struct {
|
||||||
|
id string
|
||||||
|
username string
|
||||||
|
email string
|
||||||
|
}
|
||||||
|
|
||||||
func (db *sqlAuthRequestDB) GetAuthRequestByID(ctx context.Context, id uuid.UUID) (*model.AuthRequest, error) {
|
func (db *sqlAuthRequestDB) GetAuthRequestByID(ctx context.Context, id uuid.UUID) (*model.AuthRequest, error) {
|
||||||
logger.L.Debugf("Getting auth request with id %s", id)
|
logger.L.Debugf("Getting auth request with id %s", id)
|
||||||
query := fmt.Sprintf(`SELECT %s FROM "auth_request_2" WHERE "id" = ?`, authRequestRows)
|
query := fmt.Sprintf(`SELECT %s FROM "auth_request_2" WHERE "id" = ?`, authRequestRows)
|
||||||
row := db.db.QueryRowContext(ctx, query, id)
|
row := db.db.QueryRowContext(ctx, query, id)
|
||||||
|
|
||||||
var res model.AuthRequest
|
var res model.AuthRequest
|
||||||
|
var user dbUser
|
||||||
var scopesStr []byte
|
var scopesStr []byte
|
||||||
|
|
||||||
var timestamp *time.Time
|
var timestamp *time.Time
|
||||||
|
|
||||||
if err := row.Scan(&res.ID, &res.ClientID, &res.BackendID, &scopesStr, &res.RedirectURI, &res.State, &res.Nonce, &res.ResponseType, &res.CreationDate, &res.DoneVal, &res.CodeChallenge, &res.CodeChallengeMethod, ×tamp); err != nil {
|
if err := row.Scan(&res.ID, &res.ClientID, &res.BackendID, &scopesStr, &res.RedirectURI, &res.State, &res.Nonce, &res.ResponseType, &res.CreationDate, &res.DoneVal, &res.CodeChallenge, &res.CodeChallengeMethod, ×tamp, &user.id, &user.username, &user.email); err != nil {
|
||||||
return nil, fmt.Errorf("failed to get auth request from DB: %w", err)
|
return nil, fmt.Errorf("failed to get auth request from DB: %w", err)
|
||||||
}
|
}
|
||||||
if timestamp != nil {
|
if timestamp != nil {
|
||||||
res.AuthTime = *timestamp
|
res.AuthTime = *timestamp
|
||||||
}
|
}
|
||||||
|
if user.id != "" {
|
||||||
|
userID, err := uuid.Parse(user.id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid format for user id: %w", err)
|
||||||
|
}
|
||||||
|
res.User = &model.User{
|
||||||
|
ID: userID,
|
||||||
|
Username: user.username,
|
||||||
|
Email: user.email,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(scopesStr, &res.Scopes); err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid format for scopes: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *sqlAuthRequestDB) GetAuthRequestByUserID(ctx context.Context, id uuid.UUID) (*model.AuthRequest, error) {
|
||||||
|
logger.L.Debugf("Getting auth request with user id %s", id)
|
||||||
|
query := fmt.Sprintf(`SELECT %s FROM "auth_request_2" WHERE "claim_user_id" = ?`, authRequestRows)
|
||||||
|
row := db.db.QueryRowContext(ctx, query, id)
|
||||||
|
|
||||||
|
var res model.AuthRequest
|
||||||
|
var user dbUser
|
||||||
|
var scopesStr []byte
|
||||||
|
|
||||||
|
var timestamp *time.Time
|
||||||
|
|
||||||
|
if err := row.Scan(&res.ID, &res.ClientID, &res.BackendID, &scopesStr, &res.RedirectURI, &res.State, &res.Nonce, &res.ResponseType, &res.CreationDate, &res.DoneVal, &res.CodeChallenge, &res.CodeChallengeMethod, ×tamp, &user.id, &user.username, &user.email); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get auth request from DB: %w", err)
|
||||||
|
}
|
||||||
|
if timestamp != nil {
|
||||||
|
res.AuthTime = *timestamp
|
||||||
|
}
|
||||||
|
if user.id != "" {
|
||||||
|
userID, err := uuid.Parse(user.id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid format for user id: %w", err)
|
||||||
|
}
|
||||||
|
res.User = &model.User{
|
||||||
|
ID: userID,
|
||||||
|
Username: user.username,
|
||||||
|
Email: user.email,
|
||||||
|
}
|
||||||
|
}
|
||||||
if err := json.Unmarshal(scopesStr, &res.Scopes); err != nil {
|
if err := json.Unmarshal(scopesStr, &res.Scopes); err != nil {
|
||||||
return nil, fmt.Errorf("invalid format for scopes: %w", err)
|
return nil, fmt.Errorf("invalid format for scopes: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -65,7 +119,7 @@ func (db *sqlAuthRequestDB) CreateAuthRequest(ctx context.Context, req model.Aut
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: when the old table is done, rename into auth_request
|
// TODO: when the old table is done, rename into auth_request
|
||||||
query := fmt.Sprintf(`INSERT INTO "auth_request_2" (%s) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, NULL)`, authRequestRows)
|
query := fmt.Sprintf(`INSERT INTO "auth_request_2" (%s) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, NULL, '', '', '')`, authRequestRows)
|
||||||
_, err = tx.ExecContext(ctx, query,
|
_, err = tx.ExecContext(ctx, query,
|
||||||
req.ID, req.ClientID, req.BackendID,
|
req.ID, req.ClientID, req.BackendID,
|
||||||
scopesStr, req.RedirectURI, req.State,
|
scopesStr, req.RedirectURI, req.State,
|
||||||
|
@ -83,7 +137,7 @@ func (db *sqlAuthRequestDB) CreateAuthRequest(ctx context.Context, req model.Aut
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *sqlAuthRequestDB) ValidateAuthRequest(ctx context.Context, reqID uuid.UUID) error {
|
func (db *sqlAuthRequestDB) ValidateAuthRequest(ctx context.Context, reqID uuid.UUID, user *model.User) error {
|
||||||
logger.L.Debugf("Validating auth request %s", reqID)
|
logger.L.Debugf("Validating auth request %s", reqID)
|
||||||
tx, err := db.db.BeginTx(ctx, nil)
|
tx, err := db.db.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -91,7 +145,7 @@ func (db *sqlAuthRequestDB) ValidateAuthRequest(ctx context.Context, reqID uuid.
|
||||||
}
|
}
|
||||||
defer func() { _ = tx.Rollback() }()
|
defer func() { _ = tx.Rollback() }()
|
||||||
|
|
||||||
res, err := tx.ExecContext(ctx, `UPDATE "auth_request_2" SET done = true, auth_time = $1 WHERE id = $2`, time.Now().UTC(), reqID.String())
|
res, err := tx.ExecContext(ctx, `UPDATE "auth_request_2" SET done = true, auth_time = $1, claim_user_id = $2, claim_username = $3, claim_email = $4 WHERE id = $5`, time.Now().UTC(), user.ID, user.Username, user.Email, reqID.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to update in DB: %w", err)
|
return fmt.Errorf("failed to update in DB: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,7 +30,7 @@ type AuthRequest struct {
|
||||||
BackendID uuid.UUID
|
BackendID uuid.UUID
|
||||||
Backend *Backend
|
Backend *Backend
|
||||||
|
|
||||||
UserID uuid.UUID
|
User *User
|
||||||
|
|
||||||
DoneVal bool
|
DoneVal bool
|
||||||
}
|
}
|
||||||
|
@ -91,7 +91,10 @@ func (a AuthRequest) GetState() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a AuthRequest) GetSubject() string {
|
func (a AuthRequest) GetSubject() string {
|
||||||
return a.UserID.String()
|
if a.User == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return a.User.ID.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a AuthRequest) Done() bool {
|
func (a AuthRequest) Done() bool {
|
||||||
|
|
9
polyculeconnect/internal/model/user.go
Normal file
9
polyculeconnect/internal/model/user.go
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
package model
|
||||||
|
|
||||||
|
import "github.com/google/uuid"
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
ID uuid.UUID
|
||||||
|
Email string
|
||||||
|
Username string
|
||||||
|
}
|
|
@ -251,7 +251,22 @@ func (s *Storage) AuthorizeClientIDSecret(ctx context.Context, clientID, clientS
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Storage) SetUserinfoFromScopes(ctx context.Context, userinfo *oidc.UserInfo, userID, clientID string, scopes []string) error {
|
func (s *Storage) SetUserinfoFromScopes(ctx context.Context, userinfo *oidc.UserInfo, userID, clientID string, scopes []string) error {
|
||||||
// we'll use FromRequest instead
|
logger.L.Debugf("Setting user info for user %s", userID)
|
||||||
|
|
||||||
|
parsedID, err := uuid.Parse(userID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid userID: %w", err)
|
||||||
|
}
|
||||||
|
req, err := s.LocalStorage.AuthRequestStorage().GetAuthRequestByUserID(ctx, parsedID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get auth request from DB: %w", err)
|
||||||
|
}
|
||||||
|
if req.User == nil {
|
||||||
|
return errors.New("no user associated to that ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
userinfo.PreferredUsername = req.User.Username
|
||||||
|
userinfo.Email = req.User.Email
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,3 @@
|
||||||
|
ALTER TABLE "auth_request_2" DROP COLUMN claim_user_id;
|
||||||
|
ALTER TABLE "auth_request_2" DROP COLUMN claim_username;
|
||||||
|
ALTER TABLE "auth_request_2" DROP COLUMN claim_email;
|
|
@ -0,0 +1,3 @@
|
||||||
|
ALTER TABLE "auth_request_2" ADD COLUMN claim_user_id string;
|
||||||
|
ALTER TABLE "auth_request_2" ADD COLUMN claim_username string;
|
||||||
|
ALTER TABLE "auth_request_2" ADD COLUMN claim_email string;
|
Binary file not shown.
|
@ -13,6 +13,8 @@ import (
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/controller/ui"
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/controller/ui"
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/middlewares"
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/middlewares"
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/storage"
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/storage"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/zitadel/oidc/v3/pkg/client/rp"
|
||||||
"github.com/zitadel/oidc/v3/pkg/op"
|
"github.com/zitadel/oidc/v3/pkg/op"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
@ -66,12 +68,31 @@ func New(appConf *config.AppConfig, oidcHandler *op.Provider, st *storage.Storag
|
||||||
}
|
}
|
||||||
|
|
||||||
controllers := map[string]http.Handler{
|
controllers := map[string]http.Handler{
|
||||||
auth.AuthCallbackRoute: middlewares.WithLogger(auth.NewAuthCallbackController(logger, st), logger),
|
|
||||||
auth.AuthRedirectRoute: middlewares.WithLogger(auth.NewAuthRedirectController(logger, st), logger),
|
|
||||||
ui.StaticRoute: middlewares.WithLogger(ui.NewStaticController(appConf.StaticDir), logger),
|
ui.StaticRoute: middlewares.WithLogger(ui.NewStaticController(appConf.StaticDir), logger),
|
||||||
"/": middlewares.WithLogger(ui.NewIndexController(logger, oidcHandler, appConf.StaticDir), logger),
|
"/": middlewares.WithLogger(ui.NewIndexController(logger, oidcHandler, appConf.StaticDir), logger),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
userInfoHandler := auth.NewAuthCallbackController(logger, st)
|
||||||
|
loginHandlers := map[uuid.UUID]http.Handler{}
|
||||||
|
callbackHandlers := map[uuid.UUID]http.Handler{}
|
||||||
|
|
||||||
|
backends, err := st.LocalStorage.BackendStorage().GetAllBackends(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get list of backends from storage: %w", err)
|
||||||
|
}
|
||||||
|
for _, b := range backends {
|
||||||
|
provider, err := rp.NewRelyingPartyOIDC(context.Background(), b.Config.Issuer, b.Config.ClientID, b.Config.ClientSecret, b.Config.RedirectURI, []string{"openid", "email"})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create connector for backend %s: %w", b.Name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
loginHandlers[b.ID] = middlewares.WithLogger(auth.NewAuthRedirectController(logger, provider, st), logger)
|
||||||
|
callbackHandlers[b.ID] = middlewares.WithLogger(rp.CodeExchangeHandler(rp.UserinfoCallback(userInfoHandler.HandleUserInfoCallback), provider), logger)
|
||||||
|
}
|
||||||
|
|
||||||
|
controllers[auth.AuthRedirectRoute] = middlewares.WithLogger(auth.NewAuthDispatchController(logger, st, loginHandlers), logger)
|
||||||
|
controllers[auth.AuthCallbackRoute] = middlewares.WithLogger(auth.NewCallbackDispatchController(logger, st, callbackHandlers), logger)
|
||||||
|
|
||||||
m := http.NewServeMux()
|
m := http.NewServeMux()
|
||||||
|
|
||||||
return &Server{
|
return &Server{
|
||||||
|
|
Loading…
Reference in a new issue