Cleanup DB a bit, and start correctly handling users (#42)
Some checks failed
/ docker-build-only (push) Failing after 27s
/ go-test (push) Failing after 1m18s

This commit is contained in:
Melora Hugues 2024-10-18 22:06:05 +02:00
parent 93d7b13928
commit 8d805cefe6
34 changed files with 312 additions and 186 deletions

View file

@ -1,12 +1,14 @@
package cmd package cmd
import ( import (
"context"
"fmt" "fmt"
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/cmd/utils" "git.faercol.me/faercol/polyculeconnect/polyculeconnect/cmd/utils"
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db"
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/model"
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/logger"
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/services" "git.faercol.me/faercol/polyculeconnect/polyculeconnect/services"
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/services/app"
"github.com/dexidp/dex/storage"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -60,7 +62,12 @@ func generateSecret(interactive bool, currentValue, valueName string) (string, e
func addNewApp() { func addNewApp() {
c := utils.InitConfig("") c := utils.InitConfig("")
s := utils.InitStorage(c) logger.Init(c.LogLevel)
s, err := db.New(*c)
if err != nil {
utils.Failf("failed to init storage: %s", err.Error())
}
clientID, err := generateSecret(appInteractive, appClientID, "client ID") clientID, err := generateSecret(appInteractive, appClientID, "client ID")
if err != nil { if err != nil {
@ -71,14 +78,18 @@ func addNewApp() {
utils.Fail(err.Error()) utils.Fail(err.Error())
} }
appConf := storage.Client{ appConf := model.ClientConfig{
ID: clientID, ID: clientID,
Secret: clientSecret, Secret: clientSecret,
Name: appName, Name: appName,
RedirectURIs: appRedirectURIs, RedirectURIs: appRedirectURIs,
} }
if err := app.New(s).AddApp(appConf); err != nil { clt := model.Client{
utils.Failf("Failed to add new app to storage: %s", err.Error()) ClientConfig: appConf,
}
if err := s.ClientStorage().AddClient(context.Background(), &clt); err != nil {
utils.Failf("failed to create app: %s", err)
} }
fmt.Printf("New app %s added.\n", appName) fmt.Printf("New app %s added.\n", appName)

View file

@ -3,6 +3,7 @@ package cmd
import ( import (
"context" "context"
"fmt" "fmt"
"strings"
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/cmd/utils" "git.faercol.me/faercol/polyculeconnect/polyculeconnect/cmd/utils"
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db" "git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db"
@ -18,6 +19,7 @@ var (
backendIssuer string backendIssuer string
backendClientID string backendClientID string
backendClientSecret string backendClientSecret string
backendScopes []string
) )
var backendAddCmd = &cobra.Command{ var backendAddCmd = &cobra.Command{
@ -38,6 +40,15 @@ Parameters to provide:
}, },
} }
func scopesValid(scopes []string) bool {
for _, s := range scopes {
if s == "openid" {
return true
}
}
return false
}
func addNewBackend() { func addNewBackend() {
c := utils.InitConfig("") c := utils.InitConfig("")
logger.Init(c.LogLevel) logger.Init(c.LogLevel)
@ -54,6 +65,10 @@ func addNewBackend() {
utils.Fail("Empty client secret") utils.Fail("Empty client secret")
} }
if !scopesValid(backendScopes) {
utils.Failf("Invalid list of scopes %s", strings.Join(backendScopes, ", "))
}
backendIDUUID := uuid.New() backendIDUUID := uuid.New()
backendConf := model.Backend{ backendConf := model.Backend{
@ -64,6 +79,7 @@ func addNewBackend() {
ClientSecret: backendClientSecret, ClientSecret: backendClientSecret,
Issuer: backendIssuer, Issuer: backendIssuer,
RedirectURI: c.RedirectURI(), RedirectURI: c.RedirectURI(),
Scopes: backendScopes,
}, },
} }
if err := s.BackendStorage().AddBackend(context.Background(), &backendConf); err != nil { if err := s.BackendStorage().AddBackend(context.Background(), &backendConf); err != nil {
@ -81,4 +97,5 @@ func init() {
backendAddCmd.Flags().StringVarP(&backendIssuer, "issuer", "d", "", "Full hostname of the backend") backendAddCmd.Flags().StringVarP(&backendIssuer, "issuer", "d", "", "Full hostname of the backend")
backendAddCmd.Flags().StringVarP(&backendClientID, "client-id", "", "", "OIDC Client ID for the backend") backendAddCmd.Flags().StringVarP(&backendClientID, "client-id", "", "", "OIDC Client ID for the backend")
backendAddCmd.Flags().StringVarP(&backendClientSecret, "client-secret", "", "", "OIDC Client secret for the backend") backendAddCmd.Flags().StringVarP(&backendClientSecret, "client-secret", "", "", "OIDC Client secret for the backend")
backendAddCmd.Flags().StringArrayVarP(&backendScopes, "scopes", "s", []string{"openid", "profile", "email"}, "OIDC Scopes asked to the backend")
} }

View file

@ -4,10 +4,12 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"strings"
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/cmd/utils" "git.faercol.me/faercol/polyculeconnect/polyculeconnect/cmd/utils"
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db" "git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db"
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/backend" "git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/backend"
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/logger"
"github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -21,10 +23,12 @@ Optional parameters:
- app-id: id of the backend to display. If empty, display the list of available backends instead`, - app-id: id of the backend to display. If empty, display the list of available backends instead`,
Args: cobra.MaximumNArgs(1), Args: cobra.MaximumNArgs(1),
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
s, err := db.New(*utils.InitConfig("")) conf := utils.InitConfig("")
s, err := db.New(*conf)
if err != nil { if err != nil {
utils.Failf("Failed to init storage: %s", err.Error()) utils.Failf("Failed to init storage: %s", err.Error())
} }
logger.Init(conf.LogLevel)
if len(args) > 0 { if len(args) > 0 {
showBackend(args[0], s.BackendStorage()) showBackend(args[0], s.BackendStorage())
@ -50,6 +54,7 @@ func showBackend(backendName string, backendService backend.BackendDB) {
printProperty("Client ID", backendConfig.Config.ClientID) printProperty("Client ID", backendConfig.Config.ClientID)
printProperty("Client secret", backendConfig.Config.ClientSecret) printProperty("Client secret", backendConfig.Config.ClientSecret)
printProperty("Redirect URI", backendConfig.Config.RedirectURI) printProperty("Redirect URI", backendConfig.Config.RedirectURI)
printProperty("Scopes", strings.Join(backendConfig.Config.Scopes, ", "))
} }
func listBackends(backendStorage backend.BackendDB) { func listBackends(backendStorage backend.BackendDB) {

View file

@ -19,7 +19,6 @@ import (
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/storage" "git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/storage"
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/logger" "git.faercol.me/faercol/polyculeconnect/polyculeconnect/logger"
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/server" "git.faercol.me/faercol/polyculeconnect/polyculeconnect/server"
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/services"
"github.com/go-jose/go-jose/v4" "github.com/go-jose/go-jose/v4"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/spf13/cobra" "github.com/spf13/cobra"
@ -49,9 +48,6 @@ func serve() {
logger.Init(conf.LogLevel) logger.Init(conf.LogLevel)
logger.L.Infof("Initialized logger with level %v", conf.LogLevel) logger.L.Infof("Initialized logger with level %v", conf.LogLevel)
storageType := utils.InitStorage(conf)
logger.L.Infof("Initialized storage backend %q", conf.StorageType)
userDB, err := db.New(*conf) userDB, err := db.New(*conf)
if err != nil { if err != nil {
utils.Failf("failed to init user DB: %s", err.Error()) utils.Failf("failed to init user DB: %s", err.Error())
@ -83,38 +79,11 @@ func serve() {
op.WithHttpInterceptors(middlewares.WithBackendFromRequestMiddleware), op.WithHttpInterceptors(middlewares.WithBackendFromRequestMiddleware),
} }
// logger.L.Info("Initializing authentication backends")
// backendConfs, err := userDB.BackendStorage().GetAllBackends(context.Background())
// if err != nil {
// utils.Failf("failed to get backend configs from the DB: %s", err.Error())
// }
// TODO: check if we need to do it this way or
// - do a try-loop?
// - only init when using them in a request?
// for _, c := range backendConfs {
// logger.L.Debugf("Initializing backend %s", c.Name)
// b, err := client.New(context.Background(), c, logger.L)
// if err != nil {
// utils.Failf("failed to init backend client: %s", err.Error())
// }
// backends[c.ID] = b
// }
// if len(backends) == 0 {
// logger.L.Warn("No auth backend loaded")
// } else {
// logger.L.Infof("Initialized %d auth backends", len(backends))
// }
provider, err := op.NewProvider(&opConf, &st, op.StaticIssuer(conf.Issuer), options...) provider, err := op.NewProvider(&opConf, &st, op.StaticIssuer(conf.Issuer), options...)
if err != nil { if err != nil {
utils.Failf("failed to init OIDC provider: %s", err.Error()) utils.Failf("failed to init OIDC provider: %s", err.Error())
} }
if err := services.AddDefaultBackend(storageType); err != nil {
logger.L.Errorf("Failed to add connector for backend RefuseAll to stage: %s", err.Error())
}
logger.L.Info("Initializing server") logger.L.Info("Initializing server")
s, err := server.New(conf, provider, &st, logger.L) s, err := server.New(conf, provider, &st, logger.L)
if err != nil { if err != nil {

View file

@ -36,17 +36,27 @@ func (c *AuthCallbackController) HandleUserInfoCallback(w http.ResponseWriter, r
c.l.Infof("Successful login from %s", info.Email) c.l.Infof("Successful login from %s", info.Email)
user := model.User{ user := model.User{
ID: uuid.New(), Subject: info.Subject,
Email: info.Email, Name: info.Name,
Username: info.PreferredUsername, FamilyName: info.FamilyName,
GivenName: info.GivenName,
Picture: info.Picture,
UpdatedAt: info.UpdatedAt.AsTime(),
Email: info.Email,
EmailVerified: bool(info.EmailVerified),
} }
err = c.st.LocalStorage.AuthRequestStorage().ValidateAuthRequest(r.Context(), requestID, &user) err = c.st.LocalStorage.AuthRequestStorage().ValidateAuthRequest(r.Context(), requestID, user.Subject)
if err != nil { if err != nil {
c.l.Errorf("Failed to validate auth request from storage: %s", err) c.l.Errorf("Failed to validate auth request from storage: %s", err)
helpers.HandleResponse(w, r, http.StatusInternalServerError, []byte("failed to perform authentication"), c.l) helpers.HandleResponse(w, r, http.StatusInternalServerError, []byte("failed to perform authentication"), c.l)
return return
} }
if err := c.st.LocalStorage.UserStorage().AddUser(r.Context(), &user); err != nil {
c.l.Errorf("Failed to add related user to storageL %w", err)
helpers.HandleResponse(w, r, http.StatusInternalServerError, []byte("failed to perform authentication"), c.l)
return
}
http.Redirect(w, r, "/authorize/callback?id="+state, http.StatusFound) http.Redirect(w, r, "/authorize/callback?id="+state, http.StatusFound)
} }

View file

@ -46,20 +46,5 @@ func (c *AuthRedirectController) ServeHTTP(w http.ResponseWriter, r *http.Reques
helpers.HandleResponse(w, r, http.StatusBadRequest, []byte("unknown request id"), c.l) helpers.HandleResponse(w, r, http.StatusBadRequest, []byte("unknown request id"), c.l)
return return
} }
// backend, err := c.st.LocalStorage.BackendStorage().GetBackendByID(r.Context(), req.BackendID)
// if err != nil {
// c.l.Errorf("Failed to get backend from DB: %s", err)
// helpers.HandleResponse(w, r, http.StatusInternalServerError, []byte("failed to perform auth"), c.l)
// return
// }
// provider, err := rp.NewRelyingPartyOIDC(r.Context(), backend.Config.Issuer, backend.Config.ClientID, backend.Config.ClientSecret, backend.Config.RedirectURI, req.Scopes)
// if err != nil {
// c.l.Errorf("Failed to init relying party: %s", err)
// helpers.HandleResponse(w, r, http.StatusInternalServerError, []byte("failed to perform auth"), c.l)
// return
// }
rp.AuthURLHandler(func() string { return requestIDStr }, c.provider).ServeHTTP(w, r) rp.AuthURLHandler(func() string { return requestIDStr }, c.provider).ServeHTTP(w, r)
} }

View file

@ -29,7 +29,7 @@ func (db *sqlAuthCodeDB) CreateAuthCode(ctx context.Context, code model.AuthCode
} }
defer func() { _ = tx.Rollback() }() defer func() { _ = tx.Rollback() }()
query := `INSERT INTO "auth_code_2" ("id", "auth_request_id", "code") VALUES ($1, $2, $3)` query := `INSERT INTO "auth_code" ("id", "auth_request_id", "code") VALUES ($1, $2, $3)`
_, err = tx.ExecContext(ctx, query, code.CodeID, code.RequestID, code.Code) _, err = tx.ExecContext(ctx, query, code.CodeID, code.RequestID, code.Code)
if err != nil { if err != nil {
return fmt.Errorf("failed to insert in DB: %w", err) return fmt.Errorf("failed to insert in DB: %w", err)
@ -43,7 +43,7 @@ func (db *sqlAuthCodeDB) CreateAuthCode(ctx context.Context, code model.AuthCode
func (db *sqlAuthCodeDB) GetAuthCodeByCode(ctx context.Context, code string) (*model.AuthCode, error) { func (db *sqlAuthCodeDB) GetAuthCodeByCode(ctx context.Context, code string) (*model.AuthCode, error) {
logger.L.Debugf("Getting auth code %s from DB", code) logger.L.Debugf("Getting auth code %s from DB", code)
query := `SELECT "id", "auth_request_id", "code" FROM "auth_code_2" WHERE "code" = ?` query := `SELECT "id", "auth_request_id", "code" FROM "auth_code" WHERE "code" = ?`
row := db.db.QueryRowContext(ctx, query, code) row := db.db.QueryRowContext(ctx, query, code)
var res model.AuthCode var res model.AuthCode

View file

@ -15,13 +15,12 @@ import (
var ErrNotFound = errors.New("backend not found") var ErrNotFound = errors.New("backend not found")
const authRequestRows = `"id", "client_id", "backend_id", "scopes", "redirect_uri", "state", "nonce", "response_type", "creation_time", "done", "code_challenge", "code_challenge_method", "auth_time", "claim_user_id", "claim_username", "claim_email"` const authRequestRows = `"id", "client_id", "backend_id", "scopes", "redirect_uri", "state", "nonce", "response_type", "creation_time", "done", "code_challenge", "code_challenge_method", "auth_time", "user_id"`
type AuthRequestDB interface { type AuthRequestDB interface {
GetAuthRequestByID(ctx context.Context, id uuid.UUID) (*model.AuthRequest, error) GetAuthRequestByID(ctx context.Context, id uuid.UUID) (*model.AuthRequest, error)
GetAuthRequestByUserID(ctx context.Context, id uuid.UUID) (*model.AuthRequest, error)
CreateAuthRequest(ctx context.Context, req model.AuthRequest) error CreateAuthRequest(ctx context.Context, req model.AuthRequest) error
ValidateAuthRequest(ctx context.Context, reqID uuid.UUID, user *model.User) error ValidateAuthRequest(ctx context.Context, reqID uuid.UUID, userID string) error
DeleteAuthRequest(ctx context.Context, reqID uuid.UUID) error DeleteAuthRequest(ctx context.Context, reqID uuid.UUID) error
} }
@ -29,75 +28,22 @@ type sqlAuthRequestDB struct {
db *sql.DB db *sql.DB
} }
type dbUser struct {
id string
username string
email string
}
func (db *sqlAuthRequestDB) GetAuthRequestByID(ctx context.Context, id uuid.UUID) (*model.AuthRequest, error) { func (db *sqlAuthRequestDB) GetAuthRequestByID(ctx context.Context, id uuid.UUID) (*model.AuthRequest, error) {
logger.L.Debugf("Getting auth request with id %s", id) logger.L.Debugf("Getting auth request with id %s", id)
query := fmt.Sprintf(`SELECT %s FROM "auth_request_2" WHERE "id" = ?`, authRequestRows) query := fmt.Sprintf(`SELECT %s FROM "auth_request" WHERE "id" = ?`, authRequestRows)
row := db.db.QueryRowContext(ctx, query, id) row := db.db.QueryRowContext(ctx, query, id)
var res model.AuthRequest var res model.AuthRequest
var user dbUser
var scopesStr []byte var scopesStr []byte
var timestamp *time.Time var timestamp *time.Time
if err := row.Scan(&res.ID, &res.ClientID, &res.BackendID, &scopesStr, &res.RedirectURI, &res.State, &res.Nonce, &res.ResponseType, &res.CreationDate, &res.DoneVal, &res.CodeChallenge, &res.CodeChallengeMethod, &timestamp, &user.id, &user.username, &user.email); err != nil { if err := row.Scan(&res.ID, &res.ClientID, &res.BackendID, &scopesStr, &res.RedirectURI, &res.State, &res.Nonce, &res.ResponseType, &res.CreationDate, &res.DoneVal, &res.CodeChallenge, &res.CodeChallengeMethod, &timestamp, &res.UserID); err != nil {
return nil, fmt.Errorf("failed to get auth request from DB: %w", err) return nil, fmt.Errorf("failed to get auth request from DB: %w", err)
} }
if timestamp != nil { if timestamp != nil {
res.AuthTime = *timestamp res.AuthTime = *timestamp
} }
if user.id != "" {
userID, err := uuid.Parse(user.id)
if err != nil {
return nil, fmt.Errorf("invalid format for user id: %w", err)
}
res.User = &model.User{
ID: userID,
Username: user.username,
Email: user.email,
}
}
if err := json.Unmarshal(scopesStr, &res.Scopes); err != nil {
return nil, fmt.Errorf("invalid format for scopes: %w", err)
}
return &res, nil
}
func (db *sqlAuthRequestDB) GetAuthRequestByUserID(ctx context.Context, id uuid.UUID) (*model.AuthRequest, error) {
logger.L.Debugf("Getting auth request with user id %s", id)
query := fmt.Sprintf(`SELECT %s FROM "auth_request_2" WHERE "claim_user_id" = ?`, authRequestRows)
row := db.db.QueryRowContext(ctx, query, id)
var res model.AuthRequest
var user dbUser
var scopesStr []byte
var timestamp *time.Time
if err := row.Scan(&res.ID, &res.ClientID, &res.BackendID, &scopesStr, &res.RedirectURI, &res.State, &res.Nonce, &res.ResponseType, &res.CreationDate, &res.DoneVal, &res.CodeChallenge, &res.CodeChallengeMethod, &timestamp, &user.id, &user.username, &user.email); err != nil {
return nil, fmt.Errorf("failed to get auth request from DB: %w", err)
}
if timestamp != nil {
res.AuthTime = *timestamp
}
if user.id != "" {
userID, err := uuid.Parse(user.id)
if err != nil {
return nil, fmt.Errorf("invalid format for user id: %w", err)
}
res.User = &model.User{
ID: userID,
Username: user.username,
Email: user.email,
}
}
if err := json.Unmarshal(scopesStr, &res.Scopes); err != nil { if err := json.Unmarshal(scopesStr, &res.Scopes); err != nil {
return nil, fmt.Errorf("invalid format for scopes: %w", err) return nil, fmt.Errorf("invalid format for scopes: %w", err)
} }
@ -118,8 +64,7 @@ func (db *sqlAuthRequestDB) CreateAuthRequest(ctx context.Context, req model.Aut
return fmt.Errorf("failed to serialize scopes: %w", err) return fmt.Errorf("failed to serialize scopes: %w", err)
} }
// TODO: when the old table is done, rename into auth_request query := fmt.Sprintf(`INSERT INTO "auth_request" (%s) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, NULL, '')`, authRequestRows)
query := fmt.Sprintf(`INSERT INTO "auth_request_2" (%s) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, NULL, '', '', '')`, authRequestRows)
_, err = tx.ExecContext(ctx, query, _, err = tx.ExecContext(ctx, query,
req.ID, req.ClientID, req.BackendID, req.ID, req.ClientID, req.BackendID,
scopesStr, req.RedirectURI, req.State, scopesStr, req.RedirectURI, req.State,
@ -137,7 +82,7 @@ func (db *sqlAuthRequestDB) CreateAuthRequest(ctx context.Context, req model.Aut
return nil return nil
} }
func (db *sqlAuthRequestDB) ValidateAuthRequest(ctx context.Context, reqID uuid.UUID, user *model.User) error { func (db *sqlAuthRequestDB) ValidateAuthRequest(ctx context.Context, reqID uuid.UUID, userID string) error {
logger.L.Debugf("Validating auth request %s", reqID) logger.L.Debugf("Validating auth request %s", reqID)
tx, err := db.db.BeginTx(ctx, nil) tx, err := db.db.BeginTx(ctx, nil)
if err != nil { if err != nil {
@ -145,7 +90,7 @@ func (db *sqlAuthRequestDB) ValidateAuthRequest(ctx context.Context, reqID uuid.
} }
defer func() { _ = tx.Rollback() }() defer func() { _ = tx.Rollback() }()
res, err := tx.ExecContext(ctx, `UPDATE "auth_request_2" SET done = true, auth_time = $1, claim_user_id = $2, claim_username = $3, claim_email = $4 WHERE id = $5`, time.Now().UTC(), user.ID, user.Username, user.Email, reqID.String()) res, err := tx.ExecContext(ctx, `UPDATE "auth_request" SET done = true, auth_time = $1, user_id = $2 WHERE id = $3`, time.Now().UTC(), userID, reqID)
if err != nil { if err != nil {
return fmt.Errorf("failed to update in DB: %w", err) return fmt.Errorf("failed to update in DB: %w", err)
} }
@ -172,7 +117,7 @@ func (db *sqlAuthRequestDB) DeleteAuthRequest(ctx context.Context, reqID uuid.UU
} }
defer func() { _ = tx.Rollback() }() defer func() { _ = tx.Rollback() }()
_, err = tx.ExecContext(ctx, `DELETE FROM "auth_request_2" WHERE id = $1`, reqID.String()) _, err = tx.ExecContext(ctx, `DELETE FROM "auth_request" WHERE id = $1`, reqID.String())
if err != nil { if err != nil {
return fmt.Errorf("failed to delete auth request: %w", err) return fmt.Errorf("failed to delete auth request: %w", err)
} }

View file

@ -3,6 +3,7 @@ package backend
import ( import (
"context" "context"
"database/sql" "database/sql"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -13,7 +14,7 @@ import (
var ErrNotFound = errors.New("backend not found") var ErrNotFound = errors.New("backend not found")
const backendRows = `"id", "name", "oidc_issuer", "oidc_client_id", "oidc_client_secret", "oidc_redirect_uri"` const backendRows = `"id", "name", "oidc_issuer", "oidc_client_id", "oidc_client_secret", "oidc_redirect_uri", "oidc_scopes"`
type scannable interface { type scannable interface {
Scan(dest ...any) error Scan(dest ...any) error
@ -36,13 +37,19 @@ type sqlBackendDB struct {
func backendFromRow(row scannable) (*model.Backend, error) { func backendFromRow(row scannable) (*model.Backend, error) {
var res model.Backend var res model.Backend
var scopesStr []byte
if err := row.Scan(&res.ID, &res.Name, &res.Config.Issuer, &res.Config.ClientID, &res.Config.ClientSecret, &res.Config.RedirectURI); err != nil { fmt.Println(string(scopesStr))
if err := row.Scan(&res.ID, &res.Name, &res.Config.Issuer, &res.Config.ClientID, &res.Config.ClientSecret, &res.Config.RedirectURI, &scopesStr); err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound return nil, ErrNotFound
} }
return nil, fmt.Errorf("invalid format for backend: %w", err) return nil, fmt.Errorf("invalid format for backend: %w", err)
} }
if err := json.Unmarshal(scopesStr, &res.Config.Scopes); err != nil {
return nil, fmt.Errorf("invalid value for scopes: %w", err)
}
return &res, nil return &res, nil
} }
@ -85,12 +92,18 @@ func (db *sqlBackendDB) AddBackend(ctx context.Context, newBackend *model.Backen
} }
defer func() { _ = tx.Rollback() }() defer func() { _ = tx.Rollback() }()
query := fmt.Sprintf(`INSERT INTO "backend" (%s) VALUES ($1, $2, $3, $4, $5, $6)`, backendRows) scopesStr, err := json.Marshal(newBackend.Config.Scopes)
if err != nil {
return fmt.Errorf("failed to serialize scopes: %w", err)
}
query := fmt.Sprintf(`INSERT INTO "backend" (%s) VALUES ($1, $2, $3, $4, $5, $6, $7)`, backendRows)
_, err = tx.ExecContext( _, err = tx.ExecContext(
ctx, query, ctx, query,
newBackend.ID, newBackend.Name, newBackend.ID, newBackend.Name,
newBackend.Config.Issuer, newBackend.Config.ClientID, newBackend.Config.Issuer, newBackend.Config.ClientID,
newBackend.Config.ClientSecret, newBackend.Config.RedirectURI, newBackend.Config.ClientSecret, newBackend.Config.RedirectURI,
scopesStr,
) )
if err != nil { if err != nil {
return fmt.Errorf("failed to insert in DB: %w", err) return fmt.Errorf("failed to insert in DB: %w", err)

View file

@ -9,6 +9,7 @@ import (
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/authrequest" "git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/authrequest"
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/backend" "git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/backend"
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/client" "git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/client"
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/user"
) )
type Storage interface { type Storage interface {
@ -17,6 +18,7 @@ type Storage interface {
BackendStorage() backend.BackendDB BackendStorage() backend.BackendDB
AuthRequestStorage() authrequest.AuthRequestDB AuthRequestStorage() authrequest.AuthRequestDB
AuthCodeStorage() authcode.AuthCodeDB AuthCodeStorage() authcode.AuthCodeDB
UserStorage() user.UserDB
} }
type sqlStorage struct { type sqlStorage struct {
@ -43,6 +45,10 @@ func (s *sqlStorage) AuthCodeStorage() authcode.AuthCodeDB {
return authcode.New(s.db) return authcode.New(s.db)
} }
func (s *sqlStorage) UserStorage() user.UserDB {
return user.New(s.db)
}
func New(conf config.AppConfig) (Storage, error) { func New(conf config.AppConfig) (Storage, error) {
db, err := sql.Open("sqlite3", conf.StorageConfig.File) db, err := sql.Open("sqlite3", conf.StorageConfig.File)
if err != nil { if err != nil {

View file

@ -18,6 +18,7 @@ const clientRows = `"client"."id", "client"."secret", "client"."redirect_uris",
type ClientDB interface { type ClientDB interface {
GetClientByID(ctx context.Context, id string) (*model.Client, error) GetClientByID(ctx context.Context, id string) (*model.Client, error)
AddClient(ctx context.Context, client *model.Client) error
} }
type sqlClientDB struct { type sqlClientDB struct {
@ -32,6 +33,14 @@ func strArrayToSlice(rawVal string) []string {
return res return res
} }
func sliceToStrArray(rawVal []string) string {
res, err := json.Marshal(rawVal)
if err != nil {
return "[]"
}
return string(res)
}
func clientFromRow(row *sql.Row) (*model.Client, error) { func clientFromRow(row *sql.Row) (*model.Client, error) {
var res model.Client var res model.Client
redirectURIsStr := "" redirectURIsStr := ""
@ -57,6 +66,28 @@ func (db *sqlClientDB) GetClientByID(ctx context.Context, id string) (*model.Cli
return clientFromRow(row) return clientFromRow(row)
} }
func (db *sqlClientDB) AddClient(ctx context.Context, client *model.Client) error {
logger.L.Debugf("Creating client %s", client.Name)
query := `INSERT INTO "client" ("id", "secret", "redirect_uris", "trusted_peers", "name") VALUES ($1, $2, $3, $4, $5)`
tx, err := db.db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("failed to start transaction: %w", err)
}
defer func() { _ = tx.Rollback() }()
if affectedRows, err := tx.ExecContext(ctx, query, client.ID, client.Secret, sliceToStrArray(client.RedirectURIs()), sliceToStrArray(client.TrustedPeers), client.Name); err != nil {
return fmt.Errorf("failed to insert in DB: %w", err)
} else if nbAffected, err := affectedRows.RowsAffected(); err != nil {
return fmt.Errorf("failed to check number of affected rows: %w", err)
} else if nbAffected != 1 {
return fmt.Errorf("unexpected number of affected rows: %d", nbAffected)
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
return nil
}
func New(db *sql.DB) *sqlClientDB { func New(db *sql.DB) *sqlClientDB {
return &sqlClientDB{db: db} return &sqlClientDB{db: db}
} }

View file

@ -0,0 +1,63 @@
package user
import (
"context"
"database/sql"
"errors"
"fmt"
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/model"
)
type UserDB interface {
AddUser(ctx context.Context, user *model.User) error
GetUserBySubject(ctx context.Context, subject string) (*model.User, error)
}
var ErrNotFound = errors.New("not found")
const getUserQuery = `
SELECT id, name, family_name, given_name, nickname, picture, updated_at, email, email_verified
FROM user
WHERE id = ?
`
const insertUserQuery = `
INSERT INTO user (id, name, family_name, given_name, nickname, picture, updated_at, email, email_verified)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
`
type sqlUserDB struct {
db *sql.DB
}
func (db *sqlUserDB) GetUserBySubject(ctx context.Context, subject string) (*model.User, error) {
row := db.db.QueryRowContext(ctx, getUserQuery, subject)
var res model.User
if err := row.Scan(&res.Subject, &res.Name, &res.FamilyName, &res.GivenName, &res.Nickname, &res.Picture, &res.UpdatedAt, &res.Email, &res.EmailVerified); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound
}
return nil, fmt.Errorf("failed to read result from DB: %w", err)
}
return &res, nil
}
func (db *sqlUserDB) AddUser(ctx context.Context, user *model.User) error {
tx, err := db.db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("failed to start transaction: %w", err)
}
defer func() { _ = tx.Rollback() }()
if _, err := tx.ExecContext(ctx, insertUserQuery, user.Subject, user.Name, user.FamilyName, user.GivenName, user.Nickname, user.Picture, user.UpdatedAt, user.Email, user.EmailVerified); err != nil {
return fmt.Errorf("failed to insert in DB: %w", err)
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
return nil
}
func New(db *sql.DB) *sqlUserDB {
return &sqlUserDB{db: db}
}

View file

@ -30,7 +30,8 @@ type AuthRequest struct {
BackendID uuid.UUID BackendID uuid.UUID
Backend *Backend Backend *Backend
User *User UserID string
User *User
DoneVal bool DoneVal bool
} }
@ -94,7 +95,7 @@ func (a AuthRequest) GetSubject() string {
if a.User == nil { if a.User == nil {
return "" return ""
} }
return a.User.ID.String() return a.User.Subject
} }
func (a AuthRequest) Done() bool { func (a AuthRequest) Done() bool {

View file

@ -7,6 +7,7 @@ type BackendOIDCConfig struct {
ClientID string ClientID string
ClientSecret string ClientSecret string
RedirectURI string RedirectURI string
Scopes []string
} }
type Backend struct { type Backend struct {

View file

@ -1,9 +1,22 @@
package model package model
import "github.com/google/uuid" import (
"time"
)
type User struct { type User struct {
ID uuid.UUID // Part of openid scope
Email string Subject string
Username string
// Part of profile scope
Name string
FamilyName string
GivenName string
Nickname string
Picture string
UpdatedAt time.Time
// part of email scope
Email string
EmailVerified bool
} }

View file

@ -66,7 +66,20 @@ func (s *Storage) AuthRequestByID(ctx context.Context, requestID string) (op.Aut
return nil, fmt.Errorf("invalid format for uuid: %w", err) return nil, fmt.Errorf("invalid format for uuid: %w", err)
} }
return s.LocalStorage.AuthRequestStorage().GetAuthRequestByID(ctx, id) req, err := s.LocalStorage.AuthRequestStorage().GetAuthRequestByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("failed to get auth request from DB: %w", err)
}
if req.UserID == "" {
return req, nil
}
user, err := s.LocalStorage.UserStorage().GetUserBySubject(ctx, req.UserID)
if err != nil {
return nil, fmt.Errorf("failed to get user information from DB: %w", err)
}
req.User = user
return req, nil
} }
func (s *Storage) AuthRequestByCode(ctx context.Context, requestCode string) (op.AuthRequest, error) { func (s *Storage) AuthRequestByCode(ctx context.Context, requestCode string) (op.AuthRequest, error) {
@ -77,7 +90,20 @@ func (s *Storage) AuthRequestByCode(ctx context.Context, requestCode string) (op
return nil, fmt.Errorf("failed to get auth code from DB: %w", err) return nil, fmt.Errorf("failed to get auth code from DB: %w", err)
} }
return s.LocalStorage.AuthRequestStorage().GetAuthRequestByID(ctx, authCode.RequestID) req, err := s.LocalStorage.AuthRequestStorage().GetAuthRequestByID(ctx, authCode.RequestID)
if err != nil {
return nil, fmt.Errorf("failed to get auth request from DB: %w", err)
}
if req.UserID == "" {
return req, nil
}
user, err := s.LocalStorage.UserStorage().GetUserBySubject(ctx, req.UserID)
if err != nil {
return nil, fmt.Errorf("failed to get user information from DB: %w", err)
}
req.User = user
return req, nil
} }
func (s *Storage) SaveAuthCode(ctx context.Context, id string, code string) error { func (s *Storage) SaveAuthCode(ctx context.Context, id string, code string) error {
@ -253,20 +279,28 @@ func (s *Storage) AuthorizeClientIDSecret(ctx context.Context, clientID, clientS
func (s *Storage) SetUserinfoFromScopes(ctx context.Context, userinfo *oidc.UserInfo, userID, clientID string, scopes []string) error { func (s *Storage) SetUserinfoFromScopes(ctx context.Context, userinfo *oidc.UserInfo, userID, clientID string, scopes []string) error {
logger.L.Debugf("Setting user info for user %s", userID) logger.L.Debugf("Setting user info for user %s", userID)
parsedID, err := uuid.Parse(userID) user, err := s.LocalStorage.UserStorage().GetUserBySubject(ctx, userID)
if err != nil { if err != nil {
return fmt.Errorf("invalid userID: %w", err) return fmt.Errorf("failed to get user from DB: %w", err)
} }
req, err := s.LocalStorage.AuthRequestStorage().GetAuthRequestByUserID(ctx, parsedID)
if err != nil { for _, s := range scopes {
return fmt.Errorf("failed to get auth request from DB: %w", err) switch s {
} case "openid":
if req.User == nil { userinfo.Subject = user.Subject
return errors.New("no user associated to that ID") case "profile":
userinfo.Name = user.Name
userinfo.FamilyName = user.FamilyName
userinfo.GivenName = user.GivenName
userinfo.Nickname = user.Nickname
userinfo.Picture = user.Picture
userinfo.UpdatedAt = oidc.FromTime(user.UpdatedAt)
case "email":
userinfo.Email = user.Email
userinfo.EmailVerified = oidc.Bool(user.EmailVerified)
}
} }
userinfo.PreferredUsername = req.User.Username
userinfo.Email = req.User.Email
return nil return nil
} }

View file

@ -1 +0,0 @@
DROP TABLE "backend";

View file

@ -1,8 +0,0 @@
CREATE TABLE "backend" (
id TEXT NOT NULL PRIMARY KEY,
name TEXT NOT NULL UNIQUE,
oidc_issuer TEXT NOT NULL,
oidc_client_id TEXT NOT NULL,
oidc_client_secret TEXT NOT NULL,
oidc_redirect_uri TEXT NOT NULL
);

View file

@ -0,0 +1,5 @@
DROP TABLE "auth_code";
DROP TABLE "auth_request";
DROP TABLE "user";
DROP TABLE "backend";
DROP TABLE "client";

View file

@ -0,0 +1,58 @@
CREATE TABLE "backend" (
id TEXT NOT NULL PRIMARY KEY,
name TEXT NOT NULL UNIQUE,
oidc_issuer TEXT NOT NULL,
oidc_client_id TEXT NOT NULL,
oidc_client_secret TEXT NOT NULL,
oidc_redirect_uri TEXT NOT NULL,
oidc_scopes blob NOT NULL DEFAULT '[]' -- list of strings, json-encoded,
);
CREATE TABLE "client" (
id TEXT NOT NULL PRIMARY KEY,
secret TEXT NOT NULL,
redirect_uris blob NOT NULL,
trusted_peers blob NOT NULL,
public integer NOT NULL DEFAULT 0,
name TEXT NOT NULL
);
CREATE TABLE "user" (
id TEXT NOT NULL PRIMARY KEY,
name TEXT NOT NULL DEFAULT '',
family_name TEXT NOT NULL DEFAULT '',
given_name TEXT NOT NULL DEFAULT '',
nickname TEXT NOT NULL DEFAULT '',
picture TEXT NOT NULL DEFAULT '',
updated_at timestamp,
email TEXT NOT NULL DEFAULT '',
email_verified INTEGER NOT NULL DEFAULT 0
);
CREATE TABLE "auth_request" (
id TEXT NOT NULL PRIMARY KEY,
client_id TEXT NOT NULL,
backend_id TEXT NOT NULL,
scopes blob NOT NULL, -- list of strings, json-encoded
redirect_uri TEXT NOT NULL,
state TEXT NOT NULL,
nonce TEXT NOT NULL,
response_type TEXT NOT NULL,
creation_time timestamp NOT NULL,
done INTEGER NOT NULL DEFAULT 0,
code_challenge STRING NOT NULL DEFAULT '',
code_challenge_method STRING NOT NULL DEFAULT '',
auth_time timestamp,
user_id TEXT NOT NULL DEFAULT '',
FOREIGN KEY(backend_id) REFERENCES backend(id),
FOREIGN KEY(client_id) REFERENCES client(id),
FOREIGN KEY(user_id) REFERENCES user(id)
);
CREATE TABLE "auth_code" (
id TEXT NOT NULL PRIMARY KEY,
code TEXT NOT NULL,
auth_request_id TEXT NOT NULL,
FOREIGN KEY(auth_request_id) REFERENCES auth_request(id)
);

View file

@ -1 +0,0 @@
DROP TABLE "auth_request_2";

View file

@ -1,11 +0,0 @@
CREATE TABLE "auth_request_2" (
id TEXT NOT NULL PRIMARY KEY,
client_id TEXT NOT NULL,
backend_id TEXT NOT NULL,
scopes blob NOT NULL, -- list of strings, json-encoded
redirect_uri TEXT NOT NULL,
state TEXT NOT NULL,
nonce TEXT NOT NULL,
response_type TEXT NOT NULL,
creation_time timestamp NOT NULL
);

View file

@ -1 +0,0 @@
ALTER TABLE "auth_request_2" DROP COLUMN done;

View file

@ -1 +0,0 @@
ALTER TABLE "auth_request_2" ADD COLUMN done INTEGER NOT NULL DEFAULT 0;

View file

@ -1 +0,0 @@
DROP TABLE "auth_code_2";

View file

@ -1,5 +0,0 @@
CREATE TABLE "auth_code_2" (
id TEXT NOT NULL PRIMARY KEY,
code TEXT NOT NULL,
auth_request_id TEXT NOT NULL
);

View file

@ -1,2 +0,0 @@
ALTER TABLE "auth_request_2" DROP COLUMN code_challenge;
ALTER TABLE "auth_request_2" DROP COLUMN code_challenge_method;

View file

@ -1,2 +0,0 @@
ALTER TABLE "auth_request_2" ADD COLUMN code_challenge STRING NOT NULL DEFAULT '';
ALTER TABLE "auth_request_2" ADD COLUMN code_challenge_method STRING NOT NULL DEFAULT '';

View file

@ -1 +0,0 @@
ALTER TABLE "auth_request_2" DROP COLUMN auth_time;

View file

@ -1 +0,0 @@
ALTER TABLE "auth_request_2" ADD COLUMN auth_time timestamp;

View file

@ -1,3 +0,0 @@
ALTER TABLE "auth_request_2" DROP COLUMN claim_user_id;
ALTER TABLE "auth_request_2" DROP COLUMN claim_username;
ALTER TABLE "auth_request_2" DROP COLUMN claim_email;

View file

@ -1,3 +0,0 @@
ALTER TABLE "auth_request_2" ADD COLUMN claim_user_id string;
ALTER TABLE "auth_request_2" ADD COLUMN claim_username string;
ALTER TABLE "auth_request_2" ADD COLUMN claim_email string;

Binary file not shown.

View file

@ -81,7 +81,7 @@ func New(appConf *config.AppConfig, oidcHandler *op.Provider, st *storage.Storag
return nil, fmt.Errorf("failed to get list of backends from storage: %w", err) return nil, fmt.Errorf("failed to get list of backends from storage: %w", err)
} }
for _, b := range backends { for _, b := range backends {
provider, err := rp.NewRelyingPartyOIDC(context.Background(), b.Config.Issuer, b.Config.ClientID, b.Config.ClientSecret, b.Config.RedirectURI, []string{"openid", "email"}) provider, err := rp.NewRelyingPartyOIDC(context.Background(), b.Config.Issuer, b.Config.ClientID, b.Config.ClientSecret, b.Config.RedirectURI, b.Config.Scopes)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create connector for backend %s: %w", b.Name, err) return nil, fmt.Errorf("failed to create connector for backend %s: %w", b.Name, err)
} }