Cleanup DB a bit, and start correctly handling users (#42)
This commit is contained in:
parent
93d7b13928
commit
8d805cefe6
34 changed files with 312 additions and 186 deletions
|
@ -1,12 +1,14 @@
|
||||||
package cmd
|
package cmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/cmd/utils"
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/cmd/utils"
|
||||||
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db"
|
||||||
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/model"
|
||||||
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/logger"
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/services"
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/services"
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/services/app"
|
|
||||||
"github.com/dexidp/dex/storage"
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -60,7 +62,12 @@ func generateSecret(interactive bool, currentValue, valueName string) (string, e
|
||||||
|
|
||||||
func addNewApp() {
|
func addNewApp() {
|
||||||
c := utils.InitConfig("")
|
c := utils.InitConfig("")
|
||||||
s := utils.InitStorage(c)
|
logger.Init(c.LogLevel)
|
||||||
|
|
||||||
|
s, err := db.New(*c)
|
||||||
|
if err != nil {
|
||||||
|
utils.Failf("failed to init storage: %s", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
clientID, err := generateSecret(appInteractive, appClientID, "client ID")
|
clientID, err := generateSecret(appInteractive, appClientID, "client ID")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -71,14 +78,18 @@ func addNewApp() {
|
||||||
utils.Fail(err.Error())
|
utils.Fail(err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
appConf := storage.Client{
|
appConf := model.ClientConfig{
|
||||||
ID: clientID,
|
ID: clientID,
|
||||||
Secret: clientSecret,
|
Secret: clientSecret,
|
||||||
Name: appName,
|
Name: appName,
|
||||||
RedirectURIs: appRedirectURIs,
|
RedirectURIs: appRedirectURIs,
|
||||||
}
|
}
|
||||||
if err := app.New(s).AddApp(appConf); err != nil {
|
clt := model.Client{
|
||||||
utils.Failf("Failed to add new app to storage: %s", err.Error())
|
ClientConfig: appConf,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.ClientStorage().AddClient(context.Background(), &clt); err != nil {
|
||||||
|
utils.Failf("failed to create app: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf("New app %s added.\n", appName)
|
fmt.Printf("New app %s added.\n", appName)
|
||||||
|
|
|
@ -3,6 +3,7 @@ package cmd
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/cmd/utils"
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/cmd/utils"
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db"
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db"
|
||||||
|
@ -18,6 +19,7 @@ var (
|
||||||
backendIssuer string
|
backendIssuer string
|
||||||
backendClientID string
|
backendClientID string
|
||||||
backendClientSecret string
|
backendClientSecret string
|
||||||
|
backendScopes []string
|
||||||
)
|
)
|
||||||
|
|
||||||
var backendAddCmd = &cobra.Command{
|
var backendAddCmd = &cobra.Command{
|
||||||
|
@ -38,6 +40,15 @@ Parameters to provide:
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func scopesValid(scopes []string) bool {
|
||||||
|
for _, s := range scopes {
|
||||||
|
if s == "openid" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func addNewBackend() {
|
func addNewBackend() {
|
||||||
c := utils.InitConfig("")
|
c := utils.InitConfig("")
|
||||||
logger.Init(c.LogLevel)
|
logger.Init(c.LogLevel)
|
||||||
|
@ -54,6 +65,10 @@ func addNewBackend() {
|
||||||
utils.Fail("Empty client secret")
|
utils.Fail("Empty client secret")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !scopesValid(backendScopes) {
|
||||||
|
utils.Failf("Invalid list of scopes %s", strings.Join(backendScopes, ", "))
|
||||||
|
}
|
||||||
|
|
||||||
backendIDUUID := uuid.New()
|
backendIDUUID := uuid.New()
|
||||||
|
|
||||||
backendConf := model.Backend{
|
backendConf := model.Backend{
|
||||||
|
@ -64,6 +79,7 @@ func addNewBackend() {
|
||||||
ClientSecret: backendClientSecret,
|
ClientSecret: backendClientSecret,
|
||||||
Issuer: backendIssuer,
|
Issuer: backendIssuer,
|
||||||
RedirectURI: c.RedirectURI(),
|
RedirectURI: c.RedirectURI(),
|
||||||
|
Scopes: backendScopes,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
if err := s.BackendStorage().AddBackend(context.Background(), &backendConf); err != nil {
|
if err := s.BackendStorage().AddBackend(context.Background(), &backendConf); err != nil {
|
||||||
|
@ -81,4 +97,5 @@ func init() {
|
||||||
backendAddCmd.Flags().StringVarP(&backendIssuer, "issuer", "d", "", "Full hostname of the backend")
|
backendAddCmd.Flags().StringVarP(&backendIssuer, "issuer", "d", "", "Full hostname of the backend")
|
||||||
backendAddCmd.Flags().StringVarP(&backendClientID, "client-id", "", "", "OIDC Client ID for the backend")
|
backendAddCmd.Flags().StringVarP(&backendClientID, "client-id", "", "", "OIDC Client ID for the backend")
|
||||||
backendAddCmd.Flags().StringVarP(&backendClientSecret, "client-secret", "", "", "OIDC Client secret for the backend")
|
backendAddCmd.Flags().StringVarP(&backendClientSecret, "client-secret", "", "", "OIDC Client secret for the backend")
|
||||||
|
backendAddCmd.Flags().StringArrayVarP(&backendScopes, "scopes", "s", []string{"openid", "profile", "email"}, "OIDC Scopes asked to the backend")
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,10 +4,12 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/cmd/utils"
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/cmd/utils"
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db"
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db"
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/backend"
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/backend"
|
||||||
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/logger"
|
||||||
"github.com/dexidp/dex/storage"
|
"github.com/dexidp/dex/storage"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
@ -21,10 +23,12 @@ Optional parameters:
|
||||||
- app-id: id of the backend to display. If empty, display the list of available backends instead`,
|
- app-id: id of the backend to display. If empty, display the list of available backends instead`,
|
||||||
Args: cobra.MaximumNArgs(1),
|
Args: cobra.MaximumNArgs(1),
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
s, err := db.New(*utils.InitConfig(""))
|
conf := utils.InitConfig("")
|
||||||
|
s, err := db.New(*conf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.Failf("Failed to init storage: %s", err.Error())
|
utils.Failf("Failed to init storage: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
logger.Init(conf.LogLevel)
|
||||||
|
|
||||||
if len(args) > 0 {
|
if len(args) > 0 {
|
||||||
showBackend(args[0], s.BackendStorage())
|
showBackend(args[0], s.BackendStorage())
|
||||||
|
@ -50,6 +54,7 @@ func showBackend(backendName string, backendService backend.BackendDB) {
|
||||||
printProperty("Client ID", backendConfig.Config.ClientID)
|
printProperty("Client ID", backendConfig.Config.ClientID)
|
||||||
printProperty("Client secret", backendConfig.Config.ClientSecret)
|
printProperty("Client secret", backendConfig.Config.ClientSecret)
|
||||||
printProperty("Redirect URI", backendConfig.Config.RedirectURI)
|
printProperty("Redirect URI", backendConfig.Config.RedirectURI)
|
||||||
|
printProperty("Scopes", strings.Join(backendConfig.Config.Scopes, ", "))
|
||||||
}
|
}
|
||||||
|
|
||||||
func listBackends(backendStorage backend.BackendDB) {
|
func listBackends(backendStorage backend.BackendDB) {
|
||||||
|
|
|
@ -19,7 +19,6 @@ import (
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/storage"
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/storage"
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/logger"
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/logger"
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/server"
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/server"
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/services"
|
|
||||||
"github.com/go-jose/go-jose/v4"
|
"github.com/go-jose/go-jose/v4"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
@ -49,9 +48,6 @@ func serve() {
|
||||||
logger.Init(conf.LogLevel)
|
logger.Init(conf.LogLevel)
|
||||||
logger.L.Infof("Initialized logger with level %v", conf.LogLevel)
|
logger.L.Infof("Initialized logger with level %v", conf.LogLevel)
|
||||||
|
|
||||||
storageType := utils.InitStorage(conf)
|
|
||||||
logger.L.Infof("Initialized storage backend %q", conf.StorageType)
|
|
||||||
|
|
||||||
userDB, err := db.New(*conf)
|
userDB, err := db.New(*conf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.Failf("failed to init user DB: %s", err.Error())
|
utils.Failf("failed to init user DB: %s", err.Error())
|
||||||
|
@ -83,38 +79,11 @@ func serve() {
|
||||||
op.WithHttpInterceptors(middlewares.WithBackendFromRequestMiddleware),
|
op.WithHttpInterceptors(middlewares.WithBackendFromRequestMiddleware),
|
||||||
}
|
}
|
||||||
|
|
||||||
// logger.L.Info("Initializing authentication backends")
|
|
||||||
// backendConfs, err := userDB.BackendStorage().GetAllBackends(context.Background())
|
|
||||||
// if err != nil {
|
|
||||||
// utils.Failf("failed to get backend configs from the DB: %s", err.Error())
|
|
||||||
// }
|
|
||||||
|
|
||||||
// TODO: check if we need to do it this way or
|
|
||||||
// - do a try-loop?
|
|
||||||
// - only init when using them in a request?
|
|
||||||
// for _, c := range backendConfs {
|
|
||||||
// logger.L.Debugf("Initializing backend %s", c.Name)
|
|
||||||
// b, err := client.New(context.Background(), c, logger.L)
|
|
||||||
// if err != nil {
|
|
||||||
// utils.Failf("failed to init backend client: %s", err.Error())
|
|
||||||
// }
|
|
||||||
// backends[c.ID] = b
|
|
||||||
// }
|
|
||||||
// if len(backends) == 0 {
|
|
||||||
// logger.L.Warn("No auth backend loaded")
|
|
||||||
// } else {
|
|
||||||
// logger.L.Infof("Initialized %d auth backends", len(backends))
|
|
||||||
// }
|
|
||||||
|
|
||||||
provider, err := op.NewProvider(&opConf, &st, op.StaticIssuer(conf.Issuer), options...)
|
provider, err := op.NewProvider(&opConf, &st, op.StaticIssuer(conf.Issuer), options...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.Failf("failed to init OIDC provider: %s", err.Error())
|
utils.Failf("failed to init OIDC provider: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := services.AddDefaultBackend(storageType); err != nil {
|
|
||||||
logger.L.Errorf("Failed to add connector for backend RefuseAll to stage: %s", err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.L.Info("Initializing server")
|
logger.L.Info("Initializing server")
|
||||||
s, err := server.New(conf, provider, &st, logger.L)
|
s, err := server.New(conf, provider, &st, logger.L)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -36,17 +36,27 @@ func (c *AuthCallbackController) HandleUserInfoCallback(w http.ResponseWriter, r
|
||||||
|
|
||||||
c.l.Infof("Successful login from %s", info.Email)
|
c.l.Infof("Successful login from %s", info.Email)
|
||||||
user := model.User{
|
user := model.User{
|
||||||
ID: uuid.New(),
|
Subject: info.Subject,
|
||||||
|
Name: info.Name,
|
||||||
|
FamilyName: info.FamilyName,
|
||||||
|
GivenName: info.GivenName,
|
||||||
|
Picture: info.Picture,
|
||||||
|
UpdatedAt: info.UpdatedAt.AsTime(),
|
||||||
Email: info.Email,
|
Email: info.Email,
|
||||||
Username: info.PreferredUsername,
|
EmailVerified: bool(info.EmailVerified),
|
||||||
}
|
}
|
||||||
|
|
||||||
err = c.st.LocalStorage.AuthRequestStorage().ValidateAuthRequest(r.Context(), requestID, &user)
|
err = c.st.LocalStorage.AuthRequestStorage().ValidateAuthRequest(r.Context(), requestID, user.Subject)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.l.Errorf("Failed to validate auth request from storage: %s", err)
|
c.l.Errorf("Failed to validate auth request from storage: %s", err)
|
||||||
helpers.HandleResponse(w, r, http.StatusInternalServerError, []byte("failed to perform authentication"), c.l)
|
helpers.HandleResponse(w, r, http.StatusInternalServerError, []byte("failed to perform authentication"), c.l)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if err := c.st.LocalStorage.UserStorage().AddUser(r.Context(), &user); err != nil {
|
||||||
|
c.l.Errorf("Failed to add related user to storageL %w", err)
|
||||||
|
helpers.HandleResponse(w, r, http.StatusInternalServerError, []byte("failed to perform authentication"), c.l)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
http.Redirect(w, r, "/authorize/callback?id="+state, http.StatusFound)
|
http.Redirect(w, r, "/authorize/callback?id="+state, http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
|
@ -46,20 +46,5 @@ func (c *AuthRedirectController) ServeHTTP(w http.ResponseWriter, r *http.Reques
|
||||||
helpers.HandleResponse(w, r, http.StatusBadRequest, []byte("unknown request id"), c.l)
|
helpers.HandleResponse(w, r, http.StatusBadRequest, []byte("unknown request id"), c.l)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// backend, err := c.st.LocalStorage.BackendStorage().GetBackendByID(r.Context(), req.BackendID)
|
|
||||||
// if err != nil {
|
|
||||||
// c.l.Errorf("Failed to get backend from DB: %s", err)
|
|
||||||
// helpers.HandleResponse(w, r, http.StatusInternalServerError, []byte("failed to perform auth"), c.l)
|
|
||||||
// return
|
|
||||||
// }
|
|
||||||
|
|
||||||
// provider, err := rp.NewRelyingPartyOIDC(r.Context(), backend.Config.Issuer, backend.Config.ClientID, backend.Config.ClientSecret, backend.Config.RedirectURI, req.Scopes)
|
|
||||||
// if err != nil {
|
|
||||||
// c.l.Errorf("Failed to init relying party: %s", err)
|
|
||||||
// helpers.HandleResponse(w, r, http.StatusInternalServerError, []byte("failed to perform auth"), c.l)
|
|
||||||
// return
|
|
||||||
// }
|
|
||||||
|
|
||||||
rp.AuthURLHandler(func() string { return requestIDStr }, c.provider).ServeHTTP(w, r)
|
rp.AuthURLHandler(func() string { return requestIDStr }, c.provider).ServeHTTP(w, r)
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,7 +29,7 @@ func (db *sqlAuthCodeDB) CreateAuthCode(ctx context.Context, code model.AuthCode
|
||||||
}
|
}
|
||||||
defer func() { _ = tx.Rollback() }()
|
defer func() { _ = tx.Rollback() }()
|
||||||
|
|
||||||
query := `INSERT INTO "auth_code_2" ("id", "auth_request_id", "code") VALUES ($1, $2, $3)`
|
query := `INSERT INTO "auth_code" ("id", "auth_request_id", "code") VALUES ($1, $2, $3)`
|
||||||
_, err = tx.ExecContext(ctx, query, code.CodeID, code.RequestID, code.Code)
|
_, err = tx.ExecContext(ctx, query, code.CodeID, code.RequestID, code.Code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to insert in DB: %w", err)
|
return fmt.Errorf("failed to insert in DB: %w", err)
|
||||||
|
@ -43,7 +43,7 @@ func (db *sqlAuthCodeDB) CreateAuthCode(ctx context.Context, code model.AuthCode
|
||||||
|
|
||||||
func (db *sqlAuthCodeDB) GetAuthCodeByCode(ctx context.Context, code string) (*model.AuthCode, error) {
|
func (db *sqlAuthCodeDB) GetAuthCodeByCode(ctx context.Context, code string) (*model.AuthCode, error) {
|
||||||
logger.L.Debugf("Getting auth code %s from DB", code)
|
logger.L.Debugf("Getting auth code %s from DB", code)
|
||||||
query := `SELECT "id", "auth_request_id", "code" FROM "auth_code_2" WHERE "code" = ?`
|
query := `SELECT "id", "auth_request_id", "code" FROM "auth_code" WHERE "code" = ?`
|
||||||
row := db.db.QueryRowContext(ctx, query, code)
|
row := db.db.QueryRowContext(ctx, query, code)
|
||||||
|
|
||||||
var res model.AuthCode
|
var res model.AuthCode
|
||||||
|
|
|
@ -15,13 +15,12 @@ import (
|
||||||
|
|
||||||
var ErrNotFound = errors.New("backend not found")
|
var ErrNotFound = errors.New("backend not found")
|
||||||
|
|
||||||
const authRequestRows = `"id", "client_id", "backend_id", "scopes", "redirect_uri", "state", "nonce", "response_type", "creation_time", "done", "code_challenge", "code_challenge_method", "auth_time", "claim_user_id", "claim_username", "claim_email"`
|
const authRequestRows = `"id", "client_id", "backend_id", "scopes", "redirect_uri", "state", "nonce", "response_type", "creation_time", "done", "code_challenge", "code_challenge_method", "auth_time", "user_id"`
|
||||||
|
|
||||||
type AuthRequestDB interface {
|
type AuthRequestDB interface {
|
||||||
GetAuthRequestByID(ctx context.Context, id uuid.UUID) (*model.AuthRequest, error)
|
GetAuthRequestByID(ctx context.Context, id uuid.UUID) (*model.AuthRequest, error)
|
||||||
GetAuthRequestByUserID(ctx context.Context, id uuid.UUID) (*model.AuthRequest, error)
|
|
||||||
CreateAuthRequest(ctx context.Context, req model.AuthRequest) error
|
CreateAuthRequest(ctx context.Context, req model.AuthRequest) error
|
||||||
ValidateAuthRequest(ctx context.Context, reqID uuid.UUID, user *model.User) error
|
ValidateAuthRequest(ctx context.Context, reqID uuid.UUID, userID string) error
|
||||||
DeleteAuthRequest(ctx context.Context, reqID uuid.UUID) error
|
DeleteAuthRequest(ctx context.Context, reqID uuid.UUID) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -29,75 +28,22 @@ type sqlAuthRequestDB struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
type dbUser struct {
|
|
||||||
id string
|
|
||||||
username string
|
|
||||||
email string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db *sqlAuthRequestDB) GetAuthRequestByID(ctx context.Context, id uuid.UUID) (*model.AuthRequest, error) {
|
func (db *sqlAuthRequestDB) GetAuthRequestByID(ctx context.Context, id uuid.UUID) (*model.AuthRequest, error) {
|
||||||
logger.L.Debugf("Getting auth request with id %s", id)
|
logger.L.Debugf("Getting auth request with id %s", id)
|
||||||
query := fmt.Sprintf(`SELECT %s FROM "auth_request_2" WHERE "id" = ?`, authRequestRows)
|
query := fmt.Sprintf(`SELECT %s FROM "auth_request" WHERE "id" = ?`, authRequestRows)
|
||||||
row := db.db.QueryRowContext(ctx, query, id)
|
row := db.db.QueryRowContext(ctx, query, id)
|
||||||
|
|
||||||
var res model.AuthRequest
|
var res model.AuthRequest
|
||||||
var user dbUser
|
|
||||||
var scopesStr []byte
|
var scopesStr []byte
|
||||||
|
|
||||||
var timestamp *time.Time
|
var timestamp *time.Time
|
||||||
|
|
||||||
if err := row.Scan(&res.ID, &res.ClientID, &res.BackendID, &scopesStr, &res.RedirectURI, &res.State, &res.Nonce, &res.ResponseType, &res.CreationDate, &res.DoneVal, &res.CodeChallenge, &res.CodeChallengeMethod, ×tamp, &user.id, &user.username, &user.email); err != nil {
|
if err := row.Scan(&res.ID, &res.ClientID, &res.BackendID, &scopesStr, &res.RedirectURI, &res.State, &res.Nonce, &res.ResponseType, &res.CreationDate, &res.DoneVal, &res.CodeChallenge, &res.CodeChallengeMethod, ×tamp, &res.UserID); err != nil {
|
||||||
return nil, fmt.Errorf("failed to get auth request from DB: %w", err)
|
return nil, fmt.Errorf("failed to get auth request from DB: %w", err)
|
||||||
}
|
}
|
||||||
if timestamp != nil {
|
if timestamp != nil {
|
||||||
res.AuthTime = *timestamp
|
res.AuthTime = *timestamp
|
||||||
}
|
}
|
||||||
if user.id != "" {
|
|
||||||
userID, err := uuid.Parse(user.id)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("invalid format for user id: %w", err)
|
|
||||||
}
|
|
||||||
res.User = &model.User{
|
|
||||||
ID: userID,
|
|
||||||
Username: user.username,
|
|
||||||
Email: user.email,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(scopesStr, &res.Scopes); err != nil {
|
|
||||||
return nil, fmt.Errorf("invalid format for scopes: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &res, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db *sqlAuthRequestDB) GetAuthRequestByUserID(ctx context.Context, id uuid.UUID) (*model.AuthRequest, error) {
|
|
||||||
logger.L.Debugf("Getting auth request with user id %s", id)
|
|
||||||
query := fmt.Sprintf(`SELECT %s FROM "auth_request_2" WHERE "claim_user_id" = ?`, authRequestRows)
|
|
||||||
row := db.db.QueryRowContext(ctx, query, id)
|
|
||||||
|
|
||||||
var res model.AuthRequest
|
|
||||||
var user dbUser
|
|
||||||
var scopesStr []byte
|
|
||||||
|
|
||||||
var timestamp *time.Time
|
|
||||||
|
|
||||||
if err := row.Scan(&res.ID, &res.ClientID, &res.BackendID, &scopesStr, &res.RedirectURI, &res.State, &res.Nonce, &res.ResponseType, &res.CreationDate, &res.DoneVal, &res.CodeChallenge, &res.CodeChallengeMethod, ×tamp, &user.id, &user.username, &user.email); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to get auth request from DB: %w", err)
|
|
||||||
}
|
|
||||||
if timestamp != nil {
|
|
||||||
res.AuthTime = *timestamp
|
|
||||||
}
|
|
||||||
if user.id != "" {
|
|
||||||
userID, err := uuid.Parse(user.id)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("invalid format for user id: %w", err)
|
|
||||||
}
|
|
||||||
res.User = &model.User{
|
|
||||||
ID: userID,
|
|
||||||
Username: user.username,
|
|
||||||
Email: user.email,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(scopesStr, &res.Scopes); err != nil {
|
if err := json.Unmarshal(scopesStr, &res.Scopes); err != nil {
|
||||||
return nil, fmt.Errorf("invalid format for scopes: %w", err)
|
return nil, fmt.Errorf("invalid format for scopes: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -118,8 +64,7 @@ func (db *sqlAuthRequestDB) CreateAuthRequest(ctx context.Context, req model.Aut
|
||||||
return fmt.Errorf("failed to serialize scopes: %w", err)
|
return fmt.Errorf("failed to serialize scopes: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: when the old table is done, rename into auth_request
|
query := fmt.Sprintf(`INSERT INTO "auth_request" (%s) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, NULL, '')`, authRequestRows)
|
||||||
query := fmt.Sprintf(`INSERT INTO "auth_request_2" (%s) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, NULL, '', '', '')`, authRequestRows)
|
|
||||||
_, err = tx.ExecContext(ctx, query,
|
_, err = tx.ExecContext(ctx, query,
|
||||||
req.ID, req.ClientID, req.BackendID,
|
req.ID, req.ClientID, req.BackendID,
|
||||||
scopesStr, req.RedirectURI, req.State,
|
scopesStr, req.RedirectURI, req.State,
|
||||||
|
@ -137,7 +82,7 @@ func (db *sqlAuthRequestDB) CreateAuthRequest(ctx context.Context, req model.Aut
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *sqlAuthRequestDB) ValidateAuthRequest(ctx context.Context, reqID uuid.UUID, user *model.User) error {
|
func (db *sqlAuthRequestDB) ValidateAuthRequest(ctx context.Context, reqID uuid.UUID, userID string) error {
|
||||||
logger.L.Debugf("Validating auth request %s", reqID)
|
logger.L.Debugf("Validating auth request %s", reqID)
|
||||||
tx, err := db.db.BeginTx(ctx, nil)
|
tx, err := db.db.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -145,7 +90,7 @@ func (db *sqlAuthRequestDB) ValidateAuthRequest(ctx context.Context, reqID uuid.
|
||||||
}
|
}
|
||||||
defer func() { _ = tx.Rollback() }()
|
defer func() { _ = tx.Rollback() }()
|
||||||
|
|
||||||
res, err := tx.ExecContext(ctx, `UPDATE "auth_request_2" SET done = true, auth_time = $1, claim_user_id = $2, claim_username = $3, claim_email = $4 WHERE id = $5`, time.Now().UTC(), user.ID, user.Username, user.Email, reqID.String())
|
res, err := tx.ExecContext(ctx, `UPDATE "auth_request" SET done = true, auth_time = $1, user_id = $2 WHERE id = $3`, time.Now().UTC(), userID, reqID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to update in DB: %w", err)
|
return fmt.Errorf("failed to update in DB: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -172,7 +117,7 @@ func (db *sqlAuthRequestDB) DeleteAuthRequest(ctx context.Context, reqID uuid.UU
|
||||||
}
|
}
|
||||||
defer func() { _ = tx.Rollback() }()
|
defer func() { _ = tx.Rollback() }()
|
||||||
|
|
||||||
_, err = tx.ExecContext(ctx, `DELETE FROM "auth_request_2" WHERE id = $1`, reqID.String())
|
_, err = tx.ExecContext(ctx, `DELETE FROM "auth_request" WHERE id = $1`, reqID.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to delete auth request: %w", err)
|
return fmt.Errorf("failed to delete auth request: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package backend
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
@ -13,7 +14,7 @@ import (
|
||||||
|
|
||||||
var ErrNotFound = errors.New("backend not found")
|
var ErrNotFound = errors.New("backend not found")
|
||||||
|
|
||||||
const backendRows = `"id", "name", "oidc_issuer", "oidc_client_id", "oidc_client_secret", "oidc_redirect_uri"`
|
const backendRows = `"id", "name", "oidc_issuer", "oidc_client_id", "oidc_client_secret", "oidc_redirect_uri", "oidc_scopes"`
|
||||||
|
|
||||||
type scannable interface {
|
type scannable interface {
|
||||||
Scan(dest ...any) error
|
Scan(dest ...any) error
|
||||||
|
@ -36,13 +37,19 @@ type sqlBackendDB struct {
|
||||||
|
|
||||||
func backendFromRow(row scannable) (*model.Backend, error) {
|
func backendFromRow(row scannable) (*model.Backend, error) {
|
||||||
var res model.Backend
|
var res model.Backend
|
||||||
|
var scopesStr []byte
|
||||||
|
|
||||||
if err := row.Scan(&res.ID, &res.Name, &res.Config.Issuer, &res.Config.ClientID, &res.Config.ClientSecret, &res.Config.RedirectURI); err != nil {
|
fmt.Println(string(scopesStr))
|
||||||
|
|
||||||
|
if err := row.Scan(&res.ID, &res.Name, &res.Config.Issuer, &res.Config.ClientID, &res.Config.ClientSecret, &res.Config.RedirectURI, &scopesStr); err != nil {
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
return nil, ErrNotFound
|
return nil, ErrNotFound
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("invalid format for backend: %w", err)
|
return nil, fmt.Errorf("invalid format for backend: %w", err)
|
||||||
}
|
}
|
||||||
|
if err := json.Unmarshal(scopesStr, &res.Config.Scopes); err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid value for scopes: %w", err)
|
||||||
|
}
|
||||||
return &res, nil
|
return &res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -85,12 +92,18 @@ func (db *sqlBackendDB) AddBackend(ctx context.Context, newBackend *model.Backen
|
||||||
}
|
}
|
||||||
defer func() { _ = tx.Rollback() }()
|
defer func() { _ = tx.Rollback() }()
|
||||||
|
|
||||||
query := fmt.Sprintf(`INSERT INTO "backend" (%s) VALUES ($1, $2, $3, $4, $5, $6)`, backendRows)
|
scopesStr, err := json.Marshal(newBackend.Config.Scopes)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to serialize scopes: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
query := fmt.Sprintf(`INSERT INTO "backend" (%s) VALUES ($1, $2, $3, $4, $5, $6, $7)`, backendRows)
|
||||||
_, err = tx.ExecContext(
|
_, err = tx.ExecContext(
|
||||||
ctx, query,
|
ctx, query,
|
||||||
newBackend.ID, newBackend.Name,
|
newBackend.ID, newBackend.Name,
|
||||||
newBackend.Config.Issuer, newBackend.Config.ClientID,
|
newBackend.Config.Issuer, newBackend.Config.ClientID,
|
||||||
newBackend.Config.ClientSecret, newBackend.Config.RedirectURI,
|
newBackend.Config.ClientSecret, newBackend.Config.RedirectURI,
|
||||||
|
scopesStr,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to insert in DB: %w", err)
|
return fmt.Errorf("failed to insert in DB: %w", err)
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/authrequest"
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/authrequest"
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/backend"
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/backend"
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/client"
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/client"
|
||||||
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/user"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Storage interface {
|
type Storage interface {
|
||||||
|
@ -17,6 +18,7 @@ type Storage interface {
|
||||||
BackendStorage() backend.BackendDB
|
BackendStorage() backend.BackendDB
|
||||||
AuthRequestStorage() authrequest.AuthRequestDB
|
AuthRequestStorage() authrequest.AuthRequestDB
|
||||||
AuthCodeStorage() authcode.AuthCodeDB
|
AuthCodeStorage() authcode.AuthCodeDB
|
||||||
|
UserStorage() user.UserDB
|
||||||
}
|
}
|
||||||
|
|
||||||
type sqlStorage struct {
|
type sqlStorage struct {
|
||||||
|
@ -43,6 +45,10 @@ func (s *sqlStorage) AuthCodeStorage() authcode.AuthCodeDB {
|
||||||
return authcode.New(s.db)
|
return authcode.New(s.db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *sqlStorage) UserStorage() user.UserDB {
|
||||||
|
return user.New(s.db)
|
||||||
|
}
|
||||||
|
|
||||||
func New(conf config.AppConfig) (Storage, error) {
|
func New(conf config.AppConfig) (Storage, error) {
|
||||||
db, err := sql.Open("sqlite3", conf.StorageConfig.File)
|
db, err := sql.Open("sqlite3", conf.StorageConfig.File)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -18,6 +18,7 @@ const clientRows = `"client"."id", "client"."secret", "client"."redirect_uris",
|
||||||
|
|
||||||
type ClientDB interface {
|
type ClientDB interface {
|
||||||
GetClientByID(ctx context.Context, id string) (*model.Client, error)
|
GetClientByID(ctx context.Context, id string) (*model.Client, error)
|
||||||
|
AddClient(ctx context.Context, client *model.Client) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type sqlClientDB struct {
|
type sqlClientDB struct {
|
||||||
|
@ -32,6 +33,14 @@ func strArrayToSlice(rawVal string) []string {
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func sliceToStrArray(rawVal []string) string {
|
||||||
|
res, err := json.Marshal(rawVal)
|
||||||
|
if err != nil {
|
||||||
|
return "[]"
|
||||||
|
}
|
||||||
|
return string(res)
|
||||||
|
}
|
||||||
|
|
||||||
func clientFromRow(row *sql.Row) (*model.Client, error) {
|
func clientFromRow(row *sql.Row) (*model.Client, error) {
|
||||||
var res model.Client
|
var res model.Client
|
||||||
redirectURIsStr := ""
|
redirectURIsStr := ""
|
||||||
|
@ -57,6 +66,28 @@ func (db *sqlClientDB) GetClientByID(ctx context.Context, id string) (*model.Cli
|
||||||
return clientFromRow(row)
|
return clientFromRow(row)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (db *sqlClientDB) AddClient(ctx context.Context, client *model.Client) error {
|
||||||
|
logger.L.Debugf("Creating client %s", client.Name)
|
||||||
|
query := `INSERT INTO "client" ("id", "secret", "redirect_uris", "trusted_peers", "name") VALUES ($1, $2, $3, $4, $5)`
|
||||||
|
|
||||||
|
tx, err := db.db.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to start transaction: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = tx.Rollback() }()
|
||||||
|
if affectedRows, err := tx.ExecContext(ctx, query, client.ID, client.Secret, sliceToStrArray(client.RedirectURIs()), sliceToStrArray(client.TrustedPeers), client.Name); err != nil {
|
||||||
|
return fmt.Errorf("failed to insert in DB: %w", err)
|
||||||
|
} else if nbAffected, err := affectedRows.RowsAffected(); err != nil {
|
||||||
|
return fmt.Errorf("failed to check number of affected rows: %w", err)
|
||||||
|
} else if nbAffected != 1 {
|
||||||
|
return fmt.Errorf("unexpected number of affected rows: %d", nbAffected)
|
||||||
|
}
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
return fmt.Errorf("failed to commit transaction: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func New(db *sql.DB) *sqlClientDB {
|
func New(db *sql.DB) *sqlClientDB {
|
||||||
return &sqlClientDB{db: db}
|
return &sqlClientDB{db: db}
|
||||||
}
|
}
|
||||||
|
|
63
polyculeconnect/internal/db/user/user.go
Normal file
63
polyculeconnect/internal/db/user/user.go
Normal file
|
@ -0,0 +1,63 @@
|
||||||
|
package user
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
type UserDB interface {
|
||||||
|
AddUser(ctx context.Context, user *model.User) error
|
||||||
|
GetUserBySubject(ctx context.Context, subject string) (*model.User, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
var ErrNotFound = errors.New("not found")
|
||||||
|
|
||||||
|
const getUserQuery = `
|
||||||
|
SELECT id, name, family_name, given_name, nickname, picture, updated_at, email, email_verified
|
||||||
|
FROM user
|
||||||
|
WHERE id = ?
|
||||||
|
`
|
||||||
|
const insertUserQuery = `
|
||||||
|
INSERT INTO user (id, name, family_name, given_name, nickname, picture, updated_at, email, email_verified)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||||
|
`
|
||||||
|
|
||||||
|
type sqlUserDB struct {
|
||||||
|
db *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *sqlUserDB) GetUserBySubject(ctx context.Context, subject string) (*model.User, error) {
|
||||||
|
row := db.db.QueryRowContext(ctx, getUserQuery, subject)
|
||||||
|
var res model.User
|
||||||
|
if err := row.Scan(&res.Subject, &res.Name, &res.FamilyName, &res.GivenName, &res.Nickname, &res.Picture, &res.UpdatedAt, &res.Email, &res.EmailVerified); err != nil {
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return nil, ErrNotFound
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("failed to read result from DB: %w", err)
|
||||||
|
}
|
||||||
|
return &res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *sqlUserDB) AddUser(ctx context.Context, user *model.User) error {
|
||||||
|
tx, err := db.db.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to start transaction: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = tx.Rollback() }()
|
||||||
|
|
||||||
|
if _, err := tx.ExecContext(ctx, insertUserQuery, user.Subject, user.Name, user.FamilyName, user.GivenName, user.Nickname, user.Picture, user.UpdatedAt, user.Email, user.EmailVerified); err != nil {
|
||||||
|
return fmt.Errorf("failed to insert in DB: %w", err)
|
||||||
|
}
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
return fmt.Errorf("failed to commit transaction: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(db *sql.DB) *sqlUserDB {
|
||||||
|
return &sqlUserDB{db: db}
|
||||||
|
}
|
|
@ -30,6 +30,7 @@ type AuthRequest struct {
|
||||||
BackendID uuid.UUID
|
BackendID uuid.UUID
|
||||||
Backend *Backend
|
Backend *Backend
|
||||||
|
|
||||||
|
UserID string
|
||||||
User *User
|
User *User
|
||||||
|
|
||||||
DoneVal bool
|
DoneVal bool
|
||||||
|
@ -94,7 +95,7 @@ func (a AuthRequest) GetSubject() string {
|
||||||
if a.User == nil {
|
if a.User == nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
return a.User.ID.String()
|
return a.User.Subject
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a AuthRequest) Done() bool {
|
func (a AuthRequest) Done() bool {
|
||||||
|
|
|
@ -7,6 +7,7 @@ type BackendOIDCConfig struct {
|
||||||
ClientID string
|
ClientID string
|
||||||
ClientSecret string
|
ClientSecret string
|
||||||
RedirectURI string
|
RedirectURI string
|
||||||
|
Scopes []string
|
||||||
}
|
}
|
||||||
|
|
||||||
type Backend struct {
|
type Backend struct {
|
||||||
|
|
|
@ -1,9 +1,22 @@
|
||||||
package model
|
package model
|
||||||
|
|
||||||
import "github.com/google/uuid"
|
import (
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
ID uuid.UUID
|
// Part of openid scope
|
||||||
|
Subject string
|
||||||
|
|
||||||
|
// Part of profile scope
|
||||||
|
Name string
|
||||||
|
FamilyName string
|
||||||
|
GivenName string
|
||||||
|
Nickname string
|
||||||
|
Picture string
|
||||||
|
UpdatedAt time.Time
|
||||||
|
|
||||||
|
// part of email scope
|
||||||
Email string
|
Email string
|
||||||
Username string
|
EmailVerified bool
|
||||||
}
|
}
|
||||||
|
|
|
@ -66,7 +66,20 @@ func (s *Storage) AuthRequestByID(ctx context.Context, requestID string) (op.Aut
|
||||||
return nil, fmt.Errorf("invalid format for uuid: %w", err)
|
return nil, fmt.Errorf("invalid format for uuid: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.LocalStorage.AuthRequestStorage().GetAuthRequestByID(ctx, id)
|
req, err := s.LocalStorage.AuthRequestStorage().GetAuthRequestByID(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get auth request from DB: %w", err)
|
||||||
|
}
|
||||||
|
if req.UserID == "" {
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := s.LocalStorage.UserStorage().GetUserBySubject(ctx, req.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get user information from DB: %w", err)
|
||||||
|
}
|
||||||
|
req.User = user
|
||||||
|
return req, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Storage) AuthRequestByCode(ctx context.Context, requestCode string) (op.AuthRequest, error) {
|
func (s *Storage) AuthRequestByCode(ctx context.Context, requestCode string) (op.AuthRequest, error) {
|
||||||
|
@ -77,7 +90,20 @@ func (s *Storage) AuthRequestByCode(ctx context.Context, requestCode string) (op
|
||||||
return nil, fmt.Errorf("failed to get auth code from DB: %w", err)
|
return nil, fmt.Errorf("failed to get auth code from DB: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.LocalStorage.AuthRequestStorage().GetAuthRequestByID(ctx, authCode.RequestID)
|
req, err := s.LocalStorage.AuthRequestStorage().GetAuthRequestByID(ctx, authCode.RequestID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get auth request from DB: %w", err)
|
||||||
|
}
|
||||||
|
if req.UserID == "" {
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := s.LocalStorage.UserStorage().GetUserBySubject(ctx, req.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get user information from DB: %w", err)
|
||||||
|
}
|
||||||
|
req.User = user
|
||||||
|
return req, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Storage) SaveAuthCode(ctx context.Context, id string, code string) error {
|
func (s *Storage) SaveAuthCode(ctx context.Context, id string, code string) error {
|
||||||
|
@ -253,20 +279,28 @@ func (s *Storage) AuthorizeClientIDSecret(ctx context.Context, clientID, clientS
|
||||||
func (s *Storage) SetUserinfoFromScopes(ctx context.Context, userinfo *oidc.UserInfo, userID, clientID string, scopes []string) error {
|
func (s *Storage) SetUserinfoFromScopes(ctx context.Context, userinfo *oidc.UserInfo, userID, clientID string, scopes []string) error {
|
||||||
logger.L.Debugf("Setting user info for user %s", userID)
|
logger.L.Debugf("Setting user info for user %s", userID)
|
||||||
|
|
||||||
parsedID, err := uuid.Parse(userID)
|
user, err := s.LocalStorage.UserStorage().GetUserBySubject(ctx, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("invalid userID: %w", err)
|
return fmt.Errorf("failed to get user from DB: %w", err)
|
||||||
}
|
}
|
||||||
req, err := s.LocalStorage.AuthRequestStorage().GetAuthRequestByUserID(ctx, parsedID)
|
|
||||||
if err != nil {
|
for _, s := range scopes {
|
||||||
return fmt.Errorf("failed to get auth request from DB: %w", err)
|
switch s {
|
||||||
}
|
case "openid":
|
||||||
if req.User == nil {
|
userinfo.Subject = user.Subject
|
||||||
return errors.New("no user associated to that ID")
|
case "profile":
|
||||||
|
userinfo.Name = user.Name
|
||||||
|
userinfo.FamilyName = user.FamilyName
|
||||||
|
userinfo.GivenName = user.GivenName
|
||||||
|
userinfo.Nickname = user.Nickname
|
||||||
|
userinfo.Picture = user.Picture
|
||||||
|
userinfo.UpdatedAt = oidc.FromTime(user.UpdatedAt)
|
||||||
|
case "email":
|
||||||
|
userinfo.Email = user.Email
|
||||||
|
userinfo.EmailVerified = oidc.Bool(user.EmailVerified)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
userinfo.PreferredUsername = req.User.Username
|
|
||||||
userinfo.Email = req.User.Email
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1 +0,0 @@
|
||||||
DROP TABLE "backend";
|
|
|
@ -1,8 +0,0 @@
|
||||||
CREATE TABLE "backend" (
|
|
||||||
id TEXT NOT NULL PRIMARY KEY,
|
|
||||||
name TEXT NOT NULL UNIQUE,
|
|
||||||
oidc_issuer TEXT NOT NULL,
|
|
||||||
oidc_client_id TEXT NOT NULL,
|
|
||||||
oidc_client_secret TEXT NOT NULL,
|
|
||||||
oidc_redirect_uri TEXT NOT NULL
|
|
||||||
);
|
|
5
polyculeconnect/migrations/0_initial_schema.down.sql
Normal file
5
polyculeconnect/migrations/0_initial_schema.down.sql
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
DROP TABLE "auth_code";
|
||||||
|
DROP TABLE "auth_request";
|
||||||
|
DROP TABLE "user";
|
||||||
|
DROP TABLE "backend";
|
||||||
|
DROP TABLE "client";
|
58
polyculeconnect/migrations/0_initial_schema.up.sql
Normal file
58
polyculeconnect/migrations/0_initial_schema.up.sql
Normal file
|
@ -0,0 +1,58 @@
|
||||||
|
CREATE TABLE "backend" (
|
||||||
|
id TEXT NOT NULL PRIMARY KEY,
|
||||||
|
name TEXT NOT NULL UNIQUE,
|
||||||
|
oidc_issuer TEXT NOT NULL,
|
||||||
|
oidc_client_id TEXT NOT NULL,
|
||||||
|
oidc_client_secret TEXT NOT NULL,
|
||||||
|
oidc_redirect_uri TEXT NOT NULL,
|
||||||
|
oidc_scopes blob NOT NULL DEFAULT '[]' -- list of strings, json-encoded,
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE "client" (
|
||||||
|
id TEXT NOT NULL PRIMARY KEY,
|
||||||
|
secret TEXT NOT NULL,
|
||||||
|
redirect_uris blob NOT NULL,
|
||||||
|
trusted_peers blob NOT NULL,
|
||||||
|
public integer NOT NULL DEFAULT 0,
|
||||||
|
name TEXT NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE "user" (
|
||||||
|
id TEXT NOT NULL PRIMARY KEY,
|
||||||
|
name TEXT NOT NULL DEFAULT '',
|
||||||
|
family_name TEXT NOT NULL DEFAULT '',
|
||||||
|
given_name TEXT NOT NULL DEFAULT '',
|
||||||
|
nickname TEXT NOT NULL DEFAULT '',
|
||||||
|
picture TEXT NOT NULL DEFAULT '',
|
||||||
|
updated_at timestamp,
|
||||||
|
email TEXT NOT NULL DEFAULT '',
|
||||||
|
email_verified INTEGER NOT NULL DEFAULT 0
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE "auth_request" (
|
||||||
|
id TEXT NOT NULL PRIMARY KEY,
|
||||||
|
client_id TEXT NOT NULL,
|
||||||
|
backend_id TEXT NOT NULL,
|
||||||
|
scopes blob NOT NULL, -- list of strings, json-encoded
|
||||||
|
redirect_uri TEXT NOT NULL,
|
||||||
|
state TEXT NOT NULL,
|
||||||
|
nonce TEXT NOT NULL,
|
||||||
|
response_type TEXT NOT NULL,
|
||||||
|
creation_time timestamp NOT NULL,
|
||||||
|
done INTEGER NOT NULL DEFAULT 0,
|
||||||
|
code_challenge STRING NOT NULL DEFAULT '',
|
||||||
|
code_challenge_method STRING NOT NULL DEFAULT '',
|
||||||
|
auth_time timestamp,
|
||||||
|
user_id TEXT NOT NULL DEFAULT '',
|
||||||
|
FOREIGN KEY(backend_id) REFERENCES backend(id),
|
||||||
|
FOREIGN KEY(client_id) REFERENCES client(id),
|
||||||
|
FOREIGN KEY(user_id) REFERENCES user(id)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE "auth_code" (
|
||||||
|
id TEXT NOT NULL PRIMARY KEY,
|
||||||
|
code TEXT NOT NULL,
|
||||||
|
auth_request_id TEXT NOT NULL,
|
||||||
|
FOREIGN KEY(auth_request_id) REFERENCES auth_request(id)
|
||||||
|
);
|
||||||
|
|
|
@ -1 +0,0 @@
|
||||||
DROP TABLE "auth_request_2";
|
|
|
@ -1,11 +0,0 @@
|
||||||
CREATE TABLE "auth_request_2" (
|
|
||||||
id TEXT NOT NULL PRIMARY KEY,
|
|
||||||
client_id TEXT NOT NULL,
|
|
||||||
backend_id TEXT NOT NULL,
|
|
||||||
scopes blob NOT NULL, -- list of strings, json-encoded
|
|
||||||
redirect_uri TEXT NOT NULL,
|
|
||||||
state TEXT NOT NULL,
|
|
||||||
nonce TEXT NOT NULL,
|
|
||||||
response_type TEXT NOT NULL,
|
|
||||||
creation_time timestamp NOT NULL
|
|
||||||
);
|
|
|
@ -1 +0,0 @@
|
||||||
ALTER TABLE "auth_request_2" DROP COLUMN done;
|
|
|
@ -1 +0,0 @@
|
||||||
ALTER TABLE "auth_request_2" ADD COLUMN done INTEGER NOT NULL DEFAULT 0;
|
|
|
@ -1 +0,0 @@
|
||||||
DROP TABLE "auth_code_2";
|
|
|
@ -1,5 +0,0 @@
|
||||||
CREATE TABLE "auth_code_2" (
|
|
||||||
id TEXT NOT NULL PRIMARY KEY,
|
|
||||||
code TEXT NOT NULL,
|
|
||||||
auth_request_id TEXT NOT NULL
|
|
||||||
);
|
|
|
@ -1,2 +0,0 @@
|
||||||
ALTER TABLE "auth_request_2" DROP COLUMN code_challenge;
|
|
||||||
ALTER TABLE "auth_request_2" DROP COLUMN code_challenge_method;
|
|
|
@ -1,2 +0,0 @@
|
||||||
ALTER TABLE "auth_request_2" ADD COLUMN code_challenge STRING NOT NULL DEFAULT '';
|
|
||||||
ALTER TABLE "auth_request_2" ADD COLUMN code_challenge_method STRING NOT NULL DEFAULT '';
|
|
|
@ -1 +0,0 @@
|
||||||
ALTER TABLE "auth_request_2" DROP COLUMN auth_time;
|
|
|
@ -1 +0,0 @@
|
||||||
ALTER TABLE "auth_request_2" ADD COLUMN auth_time timestamp;
|
|
|
@ -1,3 +0,0 @@
|
||||||
ALTER TABLE "auth_request_2" DROP COLUMN claim_user_id;
|
|
||||||
ALTER TABLE "auth_request_2" DROP COLUMN claim_username;
|
|
||||||
ALTER TABLE "auth_request_2" DROP COLUMN claim_email;
|
|
|
@ -1,3 +0,0 @@
|
||||||
ALTER TABLE "auth_request_2" ADD COLUMN claim_user_id string;
|
|
||||||
ALTER TABLE "auth_request_2" ADD COLUMN claim_username string;
|
|
||||||
ALTER TABLE "auth_request_2" ADD COLUMN claim_email string;
|
|
Binary file not shown.
|
@ -81,7 +81,7 @@ func New(appConf *config.AppConfig, oidcHandler *op.Provider, st *storage.Storag
|
||||||
return nil, fmt.Errorf("failed to get list of backends from storage: %w", err)
|
return nil, fmt.Errorf("failed to get list of backends from storage: %w", err)
|
||||||
}
|
}
|
||||||
for _, b := range backends {
|
for _, b := range backends {
|
||||||
provider, err := rp.NewRelyingPartyOIDC(context.Background(), b.Config.Issuer, b.Config.ClientID, b.Config.ClientSecret, b.Config.RedirectURI, []string{"openid", "email"})
|
provider, err := rp.NewRelyingPartyOIDC(context.Background(), b.Config.Issuer, b.Config.ClientID, b.Config.ClientSecret, b.Config.RedirectURI, b.Config.Scopes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create connector for backend %s: %w", b.Name, err)
|
return nil, fmt.Errorf("failed to create connector for backend %s: %w", b.Name, err)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue