Compare commits
No commits in common. "a0849388a7461d41194c9fae93489de78a02f74c" and "9206c8e41e6bb732a6fe4b84dbbb8366f1574972" have entirely different histories.
a0849388a7
...
9206c8e41e
21 changed files with 59 additions and 274 deletions
|
@ -59,7 +59,7 @@ func addNewBackend() {
|
||||||
backendConf := model.Backend{
|
backendConf := model.Backend{
|
||||||
ID: backendIDUUID,
|
ID: backendIDUUID,
|
||||||
Name: backendName,
|
Name: backendName,
|
||||||
Config: model.BackendOIDCConfig{
|
OIDCConfig: model.BackendOIDCConfig{
|
||||||
ClientID: backendClientID,
|
ClientID: backendClientID,
|
||||||
ClientSecret: backendClientSecret,
|
ClientSecret: backendClientSecret,
|
||||||
Issuer: backendIssuer,
|
Issuer: backendIssuer,
|
||||||
|
|
|
@ -46,10 +46,10 @@ func showBackend(backendName string, backendService backend.BackendDB) {
|
||||||
fmt.Println("Backend config:")
|
fmt.Println("Backend config:")
|
||||||
printProperty("ID", backendConfig.ID.String())
|
printProperty("ID", backendConfig.ID.String())
|
||||||
printProperty("Name", backendConfig.Name)
|
printProperty("Name", backendConfig.Name)
|
||||||
printProperty("Issuer", backendConfig.Config.Issuer)
|
printProperty("Issuer", backendConfig.OIDCConfig.Issuer)
|
||||||
printProperty("Client ID", backendConfig.Config.ClientID)
|
printProperty("Client ID", backendConfig.OIDCConfig.ClientID)
|
||||||
printProperty("Client secret", backendConfig.Config.ClientSecret)
|
printProperty("Client secret", backendConfig.OIDCConfig.ClientSecret)
|
||||||
printProperty("Redirect URI", backendConfig.Config.RedirectURI)
|
printProperty("Redirect URI", backendConfig.OIDCConfig.RedirectURI)
|
||||||
}
|
}
|
||||||
|
|
||||||
func listBackends(backendStorage backend.BackendDB) {
|
func listBackends(backendStorage backend.BackendDB) {
|
||||||
|
@ -63,7 +63,7 @@ func listBackends(backendStorage backend.BackendDB) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for _, b := range backends {
|
for _, b := range backends {
|
||||||
fmt.Printf("\t - %s: (%s) - %s\n", b.ID, b.Name, b.Config.Issuer)
|
fmt.Printf("\t - %s: (%s) - %s\n", b.ID, b.Name, b.OIDCConfig.Issuer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,6 @@ import (
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/logger"
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/logger"
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/server"
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/server"
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/services"
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/services"
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"github.com/zitadel/oidc/v3/pkg/op"
|
"github.com/zitadel/oidc/v3/pkg/op"
|
||||||
"go.uber.org/zap/exp/zapslog"
|
"go.uber.org/zap/exp/zapslog"
|
||||||
|
@ -52,9 +51,7 @@ func serve() {
|
||||||
utils.Failf("failed to init user DB: %s", err.Error())
|
utils.Failf("failed to init user DB: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
backends := map[uuid.UUID]*client.OIDCClient{}
|
st := storage.Storage{LocalStorage: userDB}
|
||||||
|
|
||||||
st := storage.Storage{LocalStorage: userDB, InitializedBackends: backends}
|
|
||||||
opConf := op.Config{}
|
opConf := op.Config{}
|
||||||
slogger := slog.New(zapslog.NewHandler(logger.L.Desugar().Core(), nil))
|
slogger := slog.New(zapslog.NewHandler(logger.L.Desugar().Core(), nil))
|
||||||
// slogger :=
|
// slogger :=
|
||||||
|
@ -64,28 +61,29 @@ func serve() {
|
||||||
op.WithHttpInterceptors(middlewares.WithBackendFromRequestMiddleware),
|
op.WithHttpInterceptors(middlewares.WithBackendFromRequestMiddleware),
|
||||||
}
|
}
|
||||||
|
|
||||||
// logger.L.Info("Initializing authentication backends")
|
logger.L.Info("Initializing authentication backends")
|
||||||
// backendConfs, err := userDB.BackendStorage().GetAllBackends(context.Background())
|
backends := []*client.OIDCClient{}
|
||||||
// if err != nil {
|
backendConfs, err := userDB.BackendStorage().GetAllBackends(context.Background())
|
||||||
// utils.Failf("failed to get backend configs from the DB: %s", err.Error())
|
if err != nil {
|
||||||
// }
|
utils.Failf("failed to get backend configs from the DB: %s", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: check if we need to do it this way or
|
// TODO: check if we need to do it this way or
|
||||||
// - do a try-loop?
|
// - do a try-loop?
|
||||||
// - only init when using them in a request?
|
// - only init when using them in a request?
|
||||||
// for _, c := range backendConfs {
|
for _, c := range backendConfs {
|
||||||
// logger.L.Debugf("Initializing backend %s", c.Name)
|
logger.L.Debugf("Initializing backend %s", c.Name)
|
||||||
// b, err := client.New(context.Background(), c, logger.L)
|
b, err := client.New(context.Background(), c)
|
||||||
// if err != nil {
|
if err != nil {
|
||||||
// utils.Failf("failed to init backend client: %s", err.Error())
|
utils.Failf("failed to init backend client: %s", err.Error())
|
||||||
// }
|
}
|
||||||
// backends[c.ID] = b
|
backends = append(backends, b)
|
||||||
// }
|
}
|
||||||
// if len(backends) == 0 {
|
if len(backends) == 0 {
|
||||||
// logger.L.Warn("No auth backend loaded")
|
logger.L.Warn("No auth backend loaded")
|
||||||
// } else {
|
} else {
|
||||||
// logger.L.Infof("Initialized %d auth backends", len(backends))
|
logger.L.Infof("Initialized %d auth backends", len(backends))
|
||||||
// }
|
}
|
||||||
|
|
||||||
provider, err := op.NewProvider(&opConf, &st, op.StaticIssuer(conf.Issuer), options...)
|
provider, err := op.NewProvider(&opConf, &st, op.StaticIssuer(conf.Issuer), options...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -97,7 +95,7 @@ func serve() {
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.L.Info("Initializing server")
|
logger.L.Info("Initializing server")
|
||||||
s, err := server.New(conf, provider, &st, logger.L)
|
s, err := server.New(conf, provider, logger.L)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.L.Fatalf("Failed to initialize server: %s", err.Error())
|
logger.L.Fatalf("Failed to initialize server: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,58 +0,0 @@
|
||||||
package auth
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/helpers"
|
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/storage"
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"go.uber.org/zap"
|
|
||||||
)
|
|
||||||
|
|
||||||
const AuthCallbackRoute = "/callback"
|
|
||||||
|
|
||||||
type AuthCallbackController struct {
|
|
||||||
l *zap.SugaredLogger
|
|
||||||
st *storage.Storage
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewAuthCallbackController(l *zap.SugaredLogger, st *storage.Storage) *AuthCallbackController {
|
|
||||||
return &AuthCallbackController{
|
|
||||||
l: l,
|
|
||||||
st: st,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *AuthCallbackController) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
||||||
errMsg := r.URL.Query().Get("error")
|
|
||||||
if errMsg != "" {
|
|
||||||
errorDesc := r.URL.Query().Get("error_description")
|
|
||||||
c.l.Errorf("Failed to perform authentication: %s (%s)", errMsg, errorDesc)
|
|
||||||
helpers.HandleResponse(w, r, http.StatusInternalServerError, []byte("failed to perform authentication"), c.l)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
code := r.URL.Query().Get("code")
|
|
||||||
state := r.URL.Query().Get("state")
|
|
||||||
if code == "" || state == "" {
|
|
||||||
c.l.Error("Missing code or state in response")
|
|
||||||
helpers.HandleResponse(w, r, http.StatusInternalServerError, []byte("failed to perform authentication"), c.l)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
requestID, err := uuid.Parse(state)
|
|
||||||
if err != nil {
|
|
||||||
c.l.Errorf("Invalid state, should be a request UUID, but got %s: %s", state, err)
|
|
||||||
helpers.HandleResponse(w, r, http.StatusInternalServerError, []byte("failed to perform authentication"), c.l)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err = c.st.LocalStorage.AuthRequestStorage().ValidateAuthRequest(r.Context(), requestID)
|
|
||||||
if err != nil {
|
|
||||||
c.l.Errorf("Failed to validate auth request from storage: %s", err)
|
|
||||||
helpers.HandleResponse(w, r, http.StatusInternalServerError, []byte("failed to perform authentication"), c.l)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
http.Redirect(w, r, "/authorize/callback?id="+state, http.StatusFound)
|
|
||||||
}
|
|
|
@ -1,63 +0,0 @@
|
||||||
package auth
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/helpers"
|
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/storage"
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/zitadel/oidc/v3/pkg/client/rp"
|
|
||||||
"go.uber.org/zap"
|
|
||||||
)
|
|
||||||
|
|
||||||
const AuthRedirectRoute = "/perform_auth"
|
|
||||||
|
|
||||||
type AuthRedirectController struct {
|
|
||||||
l *zap.SugaredLogger
|
|
||||||
st *storage.Storage
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewAuthRedirectController(l *zap.SugaredLogger, storage *storage.Storage) *AuthRedirectController {
|
|
||||||
return &AuthRedirectController{
|
|
||||||
l: l,
|
|
||||||
st: storage,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *AuthRedirectController) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
||||||
requestIDStr := r.URL.Query().Get("request_id")
|
|
||||||
if requestIDStr == "" {
|
|
||||||
helpers.HandleResponse(w, r, http.StatusBadRequest, []byte("no request ID in request"), c.l)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
requestID, err := uuid.Parse(requestIDStr)
|
|
||||||
if err != nil {
|
|
||||||
c.l.Errorf("Invalid UUID format for request ID: %s", err)
|
|
||||||
helpers.HandleResponse(w, r, http.StatusBadRequest, []byte("invalid request id"), c.l)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
req, err := c.st.LocalStorage.AuthRequestStorage().GetAuthRequestByID(r.Context(), requestID)
|
|
||||||
if err != nil {
|
|
||||||
c.l.Errorf("Failed to get auth request from DB: %s", err)
|
|
||||||
helpers.HandleResponse(w, r, http.StatusBadRequest, []byte("unknown request id"), c.l)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
backend, err := c.st.LocalStorage.BackendStorage().GetBackendByID(r.Context(), req.BackendID)
|
|
||||||
if err != nil {
|
|
||||||
c.l.Errorf("Failed to get backend from DB: %s", err)
|
|
||||||
helpers.HandleResponse(w, r, http.StatusInternalServerError, []byte("failed to perform auth"), c.l)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
provider, err := rp.NewRelyingPartyOIDC(r.Context(), backend.Config.Issuer, backend.Config.ClientID, backend.Config.ClientSecret, backend.Config.RedirectURI, req.Scopes)
|
|
||||||
if err != nil {
|
|
||||||
c.l.Errorf("Failed to init relying party: %s", err)
|
|
||||||
helpers.HandleResponse(w, r, http.StatusInternalServerError, []byte("failed to perform auth"), c.l)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
rp.AuthURLHandler(func() string { return requestIDStr }, provider).ServeHTTP(w, r)
|
|
||||||
}
|
|
|
@ -5,48 +5,35 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
|
||||||
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db"
|
||||||
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/model"
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/logger"
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/logger"
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/zitadel/oidc/v3/pkg/client/rp"
|
"github.com/zitadel/oidc/v3/pkg/client/rp"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
"go.uber.org/zap/exp/zapslog"
|
"go.uber.org/zap/exp/zapslog"
|
||||||
)
|
)
|
||||||
|
|
||||||
type BackendOIDCConfig struct {
|
|
||||||
Issuer string
|
|
||||||
ClientID string
|
|
||||||
ClientSecret string
|
|
||||||
RedirectURI string
|
|
||||||
}
|
|
||||||
|
|
||||||
type Backend struct {
|
|
||||||
ID uuid.UUID
|
|
||||||
Name string
|
|
||||||
Config BackendOIDCConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
// OIDCClient is an OIDC client which is the client used to access a registered backend
|
|
||||||
type OIDCClient struct {
|
type OIDCClient struct {
|
||||||
backend *Backend
|
conf *model.Backend
|
||||||
|
|
||||||
provider rp.RelyingParty
|
provider rp.RelyingParty
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
|
st db.Storage
|
||||||
l *zap.SugaredLogger
|
l *zap.SugaredLogger
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(ctx context.Context, conf *Backend, l *zap.SugaredLogger) (*OIDCClient, error) {
|
func New(ctx context.Context, conf *model.Backend, l *zap.SugaredLogger) (*OIDCClient, error) {
|
||||||
options := []rp.Option{
|
options := []rp.Option{
|
||||||
rp.WithLogger(slog.New(zapslog.NewHandler(logger.L.Desugar().Core(), nil))),
|
rp.WithLogger(slog.New(zapslog.NewHandler(logger.L.Desugar().Core(), nil))),
|
||||||
}
|
}
|
||||||
pr, err := rp.NewRelyingPartyOIDC(ctx, conf.Config.Issuer, conf.Config.ClientID, conf.Config.ClientSecret, conf.Config.RedirectURI, []string{}, options...)
|
pr, err := rp.NewRelyingPartyOIDC(ctx, conf.OIDCConfig.Issuer, conf.OIDCConfig.ClientID, conf.OIDCConfig.ClientSecret, conf.OIDCConfig.RedirectURI, []string{}, options...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to init relying party provider: %w", err)
|
return nil, fmt.Errorf("failed to init relying party provider: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &OIDCClient{ctx: ctx, backend: conf, provider: pr, l: l}, nil
|
return &OIDCClient{ctx: ctx, conf: conf, provider: pr, l: l}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *OIDCClient) AuthorizationEndpoint() string {
|
func (c *OIDCClient) toto() {
|
||||||
url := rp.AuthURL(uuid.NewString(), c.provider)
|
c.provider.GetDeviceAuthorizationEndpoint()
|
||||||
return url
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,18 +8,16 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/model"
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/model"
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/logger"
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrNotFound = errors.New("backend not found")
|
var ErrNotFound = errors.New("backend not found")
|
||||||
|
|
||||||
const authRequestRows = `"id", "client_id", "backend_id", "scopes", "redirect_uri", "state", "nonce", "response_type", "creation_time", "done"`
|
const authRequestRows = `"id", "client_id", "backend_id", "scopes", "redirect_uri", "state", "nonce", "response_type", "creation_time"`
|
||||||
|
|
||||||
type AuthRequestDB interface {
|
type AuthRequestDB interface {
|
||||||
GetAuthRequestByID(ctx context.Context, id uuid.UUID) (*model.AuthRequest, error)
|
GetAuthRequestByID(ctx context.Context, id uuid.UUID) (*model.AuthRequest, error)
|
||||||
CreateAuthRequest(ctx context.Context, req model.AuthRequest) error
|
CreateAuthRequest(ctx context.Context, req model.AuthRequest) error
|
||||||
ValidateAuthRequest(ctx context.Context, reqID uuid.UUID) error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type sqlAuthRequestDB struct {
|
type sqlAuthRequestDB struct {
|
||||||
|
@ -27,26 +25,25 @@ type sqlAuthRequestDB struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *sqlAuthRequestDB) GetAuthRequestByID(ctx context.Context, id uuid.UUID) (*model.AuthRequest, error) {
|
func (db *sqlAuthRequestDB) GetAuthRequestByID(ctx context.Context, id uuid.UUID) (*model.AuthRequest, error) {
|
||||||
logger.L.Debugf("Getting auth request with id %s", id)
|
|
||||||
query := fmt.Sprintf(`SELECT %s FROM "auth_request_2" WHERE "id" = ?`, authRequestRows)
|
query := fmt.Sprintf(`SELECT %s FROM "auth_request_2" WHERE "id" = ?`, authRequestRows)
|
||||||
row := db.db.QueryRowContext(ctx, query, id)
|
row := db.db.QueryRowContext(ctx, query, id)
|
||||||
|
|
||||||
var res model.AuthRequest
|
var res model.AuthRequest
|
||||||
var scopesStr []byte
|
var scopesStr []byte
|
||||||
|
|
||||||
fmt.Println(query)
|
if err := row.Scan(&res.ID, &res.ClientID, &res.BackendID, &scopesStr, &res.RedirectURI, &res.State, &res.Nonce, &res.ResponseType, &res.CreationDate); err != nil {
|
||||||
if err := row.Scan(&res.ID, &res.ClientID, &res.BackendID, &scopesStr, &res.RedirectURI, &res.State, &res.Nonce, &res.ResponseType, &res.CreationDate, &res.DoneVal); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to get auth request from DB: %w", err)
|
return nil, fmt.Errorf("failed to get auth request from DB: %w", err)
|
||||||
}
|
}
|
||||||
if err := json.Unmarshal(scopesStr, &res.Scopes); err != nil {
|
if err := json.Unmarshal(scopesStr, &res.Scopes); err != nil {
|
||||||
return nil, fmt.Errorf("invalid format for scopes: %w", err)
|
return nil, fmt.Errorf("invalid format for scopes: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fmt.Println(res)
|
||||||
|
|
||||||
return &res, nil
|
return &res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *sqlAuthRequestDB) CreateAuthRequest(ctx context.Context, req model.AuthRequest) error {
|
func (db *sqlAuthRequestDB) CreateAuthRequest(ctx context.Context, req model.AuthRequest) error {
|
||||||
logger.L.Debugf("Creating a new auth request between client app %s and backend %s", req.ClientID, req.BackendID)
|
|
||||||
tx, err := db.db.BeginTx(ctx, nil)
|
tx, err := db.db.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to start transaction: %w", err)
|
return fmt.Errorf("failed to start transaction: %w", err)
|
||||||
|
@ -58,12 +55,11 @@ func (db *sqlAuthRequestDB) CreateAuthRequest(ctx context.Context, req model.Aut
|
||||||
return fmt.Errorf("failed to serialize scopes: %w", err)
|
return fmt.Errorf("failed to serialize scopes: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: when the old table is done, rename into auth_request
|
query := fmt.Sprintf(`INSERT INTO "auth_request_2" (%s) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)`, authRequestRows)
|
||||||
query := fmt.Sprintf(`INSERT INTO "auth_request_2" (%s) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)`, authRequestRows)
|
|
||||||
_, err = tx.ExecContext(ctx, query,
|
_, err = tx.ExecContext(ctx, query,
|
||||||
req.ID, req.ClientID, req.BackendID,
|
req.ID, req.ClientID, req.BackendID,
|
||||||
scopesStr, req.RedirectURI, req.State,
|
scopesStr, req.RedirectURI, req.State,
|
||||||
req.Nonce, req.ResponseType, req.CreationDate, false,
|
req.Nonce, req.ResponseType, req.CreationDate, req.AuthTime,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to insert in DB: %w", err)
|
return fmt.Errorf("failed to insert in DB: %w", err)
|
||||||
|
@ -76,33 +72,6 @@ func (db *sqlAuthRequestDB) CreateAuthRequest(ctx context.Context, req model.Aut
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *sqlAuthRequestDB) ValidateAuthRequest(ctx context.Context, reqID uuid.UUID) error {
|
|
||||||
logger.L.Debugf("Validating auth request %s", reqID)
|
|
||||||
tx, err := db.db.BeginTx(ctx, nil)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to start transaction: %w", err)
|
|
||||||
}
|
|
||||||
defer func() { _ = tx.Rollback() }()
|
|
||||||
|
|
||||||
res, err := tx.ExecContext(ctx, `UPDATE "auth_request_2" SET done = true WHERE id = $1`, reqID.String())
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to update in DB: %w", err)
|
|
||||||
}
|
|
||||||
affectedRows, err := res.RowsAffected()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to check number of affected rows: %w", err)
|
|
||||||
}
|
|
||||||
if affectedRows != 1 {
|
|
||||||
return ErrNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := tx.Commit(); err != nil {
|
|
||||||
return fmt.Errorf("failed to commit transaction: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func New(db *sql.DB) *sqlAuthRequestDB {
|
func New(db *sql.DB) *sqlAuthRequestDB {
|
||||||
return &sqlAuthRequestDB{db: db}
|
return &sqlAuthRequestDB{db: db}
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,7 +7,6 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/model"
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/model"
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/logger"
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -37,7 +36,7 @@ type sqlBackendDB struct {
|
||||||
func backendFromRow(row scannable) (*model.Backend, error) {
|
func backendFromRow(row scannable) (*model.Backend, error) {
|
||||||
var res model.Backend
|
var res model.Backend
|
||||||
|
|
||||||
if err := row.Scan(&res.ID, &res.Name, &res.Config.Issuer, &res.Config.ClientID, &res.Config.ClientSecret, &res.Config.RedirectURI); err != nil {
|
if err := row.Scan(&res.ID, &res.Name, &res.OIDCConfig.Issuer, &res.OIDCConfig.ClientID, &res.OIDCConfig.ClientSecret, &res.OIDCConfig.RedirectURI); err != nil {
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
return nil, ErrNotFound
|
return nil, ErrNotFound
|
||||||
}
|
}
|
||||||
|
@ -47,21 +46,18 @@ func backendFromRow(row scannable) (*model.Backend, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *sqlBackendDB) GetBackendByName(ctx context.Context, name string) (*model.Backend, error) {
|
func (db *sqlBackendDB) GetBackendByName(ctx context.Context, name string) (*model.Backend, error) {
|
||||||
logger.L.Debugf("Getting backend with name %s from DB", name)
|
|
||||||
query := fmt.Sprintf(`SELECT %s FROM "backend" WHERE "name" = ?`, backendRows)
|
query := fmt.Sprintf(`SELECT %s FROM "backend" WHERE "name" = ?`, backendRows)
|
||||||
row := db.db.QueryRowContext(ctx, query, name)
|
row := db.db.QueryRowContext(ctx, query, name)
|
||||||
return backendFromRow(row)
|
return backendFromRow(row)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *sqlBackendDB) GetBackendByID(ctx context.Context, id uuid.UUID) (*model.Backend, error) {
|
func (db *sqlBackendDB) GetBackendByID(ctx context.Context, id uuid.UUID) (*model.Backend, error) {
|
||||||
logger.L.Debugf("Getting backend with ID %s from DB", id)
|
|
||||||
query := fmt.Sprintf(`SELECT %s FROM "backend" WHERE "id" = ?`, backendRows)
|
query := fmt.Sprintf(`SELECT %s FROM "backend" WHERE "id" = ?`, backendRows)
|
||||||
row := db.db.QueryRowContext(ctx, query, id)
|
row := db.db.QueryRowContext(ctx, query, id)
|
||||||
return backendFromRow(row)
|
return backendFromRow(row)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *sqlBackendDB) GetAllBackends(ctx context.Context) ([]*model.Backend, error) {
|
func (db *sqlBackendDB) GetAllBackends(ctx context.Context) ([]*model.Backend, error) {
|
||||||
logger.L.Debug("Getting all backends from DB")
|
|
||||||
rows, err := db.db.QueryContext(ctx, fmt.Sprintf(`SELECT %s FROM "backend"`, backendRows))
|
rows, err := db.db.QueryContext(ctx, fmt.Sprintf(`SELECT %s FROM "backend"`, backendRows))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -75,7 +71,7 @@ func (db *sqlBackendDB) GetAllBackends(ctx context.Context) ([]*model.Backend, e
|
||||||
}
|
}
|
||||||
res = append(res, b)
|
res = append(res, b)
|
||||||
}
|
}
|
||||||
return res, rows.Err()
|
return res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *sqlBackendDB) AddBackend(ctx context.Context, newBackend *model.Backend) error {
|
func (db *sqlBackendDB) AddBackend(ctx context.Context, newBackend *model.Backend) error {
|
||||||
|
@ -89,8 +85,8 @@ func (db *sqlBackendDB) AddBackend(ctx context.Context, newBackend *model.Backen
|
||||||
_, err = tx.ExecContext(
|
_, err = tx.ExecContext(
|
||||||
ctx, query,
|
ctx, query,
|
||||||
newBackend.ID, newBackend.Name,
|
newBackend.ID, newBackend.Name,
|
||||||
newBackend.Config.Issuer, newBackend.Config.ClientID,
|
newBackend.OIDCConfig.Issuer, newBackend.OIDCConfig.ClientID,
|
||||||
newBackend.Config.ClientSecret, newBackend.Config.RedirectURI,
|
newBackend.OIDCConfig.ClientSecret, newBackend.OIDCConfig.RedirectURI,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to insert in DB: %w", err)
|
return fmt.Errorf("failed to insert in DB: %w", err)
|
||||||
|
|
|
@ -8,7 +8,6 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/model"
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/model"
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/logger"
|
|
||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -51,7 +50,6 @@ func clientFromRow(row *sql.Row) (*model.Client, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *sqlClientDB) GetClientByID(ctx context.Context, id string) (*model.Client, error) {
|
func (db *sqlClientDB) GetClientByID(ctx context.Context, id string) (*model.Client, error) {
|
||||||
logger.L.Debugf("Getting client app with ID %s from DB", id)
|
|
||||||
query := fmt.Sprintf(`SELECT %s FROM "client" WHERE "id" = ?`, clientRows)
|
query := fmt.Sprintf(`SELECT %s FROM "client" WHERE "id" = ?`, clientRows)
|
||||||
row := db.db.QueryRowContext(ctx, query, id)
|
row := db.db.QueryRowContext(ctx, query, id)
|
||||||
return clientFromRow(row)
|
return clientFromRow(row)
|
||||||
|
|
|
@ -3,6 +3,7 @@ package middlewares
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -15,7 +16,7 @@ type BackendFromRequestMiddleware struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *BackendFromRequestMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (m *BackendFromRequestMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.URL.Path != "/authorize" {
|
if !strings.HasPrefix(r.RequestURI, "/authorize") {
|
||||||
m.h.ServeHTTP(w, r)
|
m.h.ServeHTTP(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,8 +31,7 @@ type AuthRequest struct {
|
||||||
Backend *Backend
|
Backend *Backend
|
||||||
|
|
||||||
UserID uuid.UUID
|
UserID uuid.UUID
|
||||||
|
done bool
|
||||||
DoneVal bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a AuthRequest) GetID() string {
|
func (a AuthRequest) GetID() string {
|
||||||
|
@ -95,7 +94,7 @@ func (a AuthRequest) GetSubject() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a AuthRequest) Done() bool {
|
func (a AuthRequest) Done() bool {
|
||||||
return a.DoneVal
|
return a.done
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *AuthRequest) FromOIDCAuthRequest(req *oidc.AuthRequest, backendID uuid.UUID) {
|
func (a *AuthRequest) FromOIDCAuthRequest(req *oidc.AuthRequest, backendID uuid.UUID) {
|
||||||
|
|
|
@ -10,7 +10,7 @@ type BackendOIDCConfig struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Backend struct {
|
type Backend struct {
|
||||||
ID uuid.UUID
|
ID uuid.UUID
|
||||||
Name string
|
Name string
|
||||||
Config BackendOIDCConfig
|
OIDCConfig BackendOIDCConfig
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,7 +7,6 @@ import (
|
||||||
"github.com/zitadel/oidc/v3/pkg/op"
|
"github.com/zitadel/oidc/v3/pkg/op"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ClientConfig represents the configuration for a OIDC client app
|
|
||||||
type ClientConfig struct {
|
type ClientConfig struct {
|
||||||
ID string
|
ID string
|
||||||
Secret string
|
Secret string
|
||||||
|
@ -17,7 +16,6 @@ type ClientConfig struct {
|
||||||
AuthRequest *AuthRequest
|
AuthRequest *AuthRequest
|
||||||
}
|
}
|
||||||
|
|
||||||
// Client represents an OIDC client app
|
|
||||||
type Client struct {
|
type Client struct {
|
||||||
ClientConfig
|
ClientConfig
|
||||||
}
|
}
|
||||||
|
@ -50,14 +48,12 @@ func (c Client) GrantTypes() []oidc.GrantType {
|
||||||
return []oidc.GrantType{oidc.GrantTypeCode}
|
return []oidc.GrantType{oidc.GrantTypeCode}
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoginURL returns the login URL for a given client app and auth request.
|
|
||||||
// This login url should be the authorization URL for the selected OIDC backend
|
|
||||||
func (c Client) LoginURL(authRequestID string) string {
|
func (c Client) LoginURL(authRequestID string) string {
|
||||||
if c.AuthRequest == nil {
|
if c.AuthRequest == nil {
|
||||||
return "" // we don't have a request, let's return nothing
|
return "" // we don't have a request, let's return nothing
|
||||||
}
|
}
|
||||||
|
|
||||||
return "/perform_auth?request_id=" + authRequestID
|
return c.AuthRequest.Backend.OIDCConfig.Issuer
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c Client) AccessTokenType() op.AccessTokenType {
|
func (c Client) AccessTokenType() op.AccessTokenType {
|
||||||
|
|
|
@ -6,10 +6,8 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/client"
|
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db"
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db"
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/model"
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/model"
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/logger"
|
|
||||||
"github.com/go-jose/go-jose/v4"
|
"github.com/go-jose/go-jose/v4"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||||
|
@ -22,19 +20,13 @@ func ErrNotImplemented(name string) error {
|
||||||
|
|
||||||
// Storage implements the Storage interface from zitadel/oidc/op
|
// Storage implements the Storage interface from zitadel/oidc/op
|
||||||
type Storage struct {
|
type Storage struct {
|
||||||
LocalStorage db.Storage
|
LocalStorage db.Storage
|
||||||
InitializedBackends map[uuid.UUID]*client.OIDCClient
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
Auth storage interface
|
Auth storage interface
|
||||||
*/
|
*/
|
||||||
func (s *Storage) CreateAuthRequest(ctx context.Context, req *oidc.AuthRequest, userID string) (op.AuthRequest, error) {
|
func (s *Storage) CreateAuthRequest(ctx context.Context, req *oidc.AuthRequest, userID string) (op.AuthRequest, error) {
|
||||||
|
|
||||||
// userID should normally be an empty string (to verify), we don't get it in our workflow from what I saw
|
|
||||||
// TODO: check this is indeed not needed / never present
|
|
||||||
logger.L.Debug("Creating a new auth request")
|
|
||||||
|
|
||||||
// validate that the connector is correct
|
// validate that the connector is correct
|
||||||
backendName, ok := stringFromCtx(ctx, "backendName")
|
backendName, ok := stringFromCtx(ctx, "backendName")
|
||||||
if !ok {
|
if !ok {
|
||||||
|
@ -52,20 +44,11 @@ func (s *Storage) CreateAuthRequest(ctx context.Context, req *oidc.AuthRequest,
|
||||||
return nil, fmt.Errorf("failed to save auth request: %w", err)
|
return nil, fmt.Errorf("failed to save auth request: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.L.Debugf("Created a new auth request for backend %s", backendName)
|
|
||||||
|
|
||||||
return opReq, nil
|
return opReq, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Storage) AuthRequestByID(ctx context.Context, requestID string) (op.AuthRequest, error) {
|
func (s *Storage) AuthRequestByID(ctx context.Context, requestID string) (op.AuthRequest, error) {
|
||||||
logger.L.Debugf("Getting auth request with ID %s", requestID)
|
return nil, ErrNotImplemented("AuthRequestByID")
|
||||||
|
|
||||||
id, err := uuid.Parse(requestID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("invalid format for uuid: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.LocalStorage.AuthRequestStorage().GetAuthRequestByID(ctx, id)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Storage) AuthRequestByCode(ctx context.Context, requestCode string) (op.AuthRequest, error) {
|
func (s *Storage) AuthRequestByCode(ctx context.Context, requestCode string) (op.AuthRequest, error) {
|
||||||
|
@ -73,7 +56,6 @@ func (s *Storage) AuthRequestByCode(ctx context.Context, requestCode string) (op
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Storage) SaveAuthCode(ctx context.Context, id string, code string) error {
|
func (s *Storage) SaveAuthCode(ctx context.Context, id string, code string) error {
|
||||||
logger.L.Debugf("Saving auth code %s for request %s", code, id)
|
|
||||||
return ErrNotImplemented("SaveAuthCode")
|
return ErrNotImplemented("SaveAuthCode")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -122,8 +104,6 @@ func (s *Storage) KeySet(ctx context.Context) ([]op.Key, error) {
|
||||||
*/
|
*/
|
||||||
|
|
||||||
func (s *Storage) getClientWithDetails(ctx context.Context, authRequestID uuid.UUID) (op.Client, error) {
|
func (s *Storage) getClientWithDetails(ctx context.Context, authRequestID uuid.UUID) (op.Client, error) {
|
||||||
logger.L.Debug("Trying to get client details from auth request")
|
|
||||||
|
|
||||||
authRequest, err := s.LocalStorage.AuthRequestStorage().GetAuthRequestByID(ctx, authRequestID)
|
authRequest, err := s.LocalStorage.AuthRequestStorage().GetAuthRequestByID(ctx, authRequestID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get authRequest from local storage: %w", err)
|
return nil, fmt.Errorf("failed to get authRequest from local storage: %w", err)
|
||||||
|
@ -137,11 +117,6 @@ func (s *Storage) getClientWithDetails(ctx context.Context, authRequestID uuid.U
|
||||||
return nil, fmt.Errorf("failed to get associated client from local storage: %w", err)
|
return nil, fmt.Errorf("failed to get associated client from local storage: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// oidcClient, ok := s.InitializedBackends[backend.ID]
|
|
||||||
// if !ok {
|
|
||||||
// return nil, fmt.Errorf("no initialized backend for ID %s", backend.ID)
|
|
||||||
// }
|
|
||||||
|
|
||||||
authRequest.Backend = backend
|
authRequest.Backend = backend
|
||||||
client.AuthRequest = authRequest
|
client.AuthRequest = authRequest
|
||||||
|
|
||||||
|
@ -152,7 +127,6 @@ func (s *Storage) getClientWithDetails(ctx context.Context, authRequestID uuid.U
|
||||||
// but a request is always associated to a backend, and we really need both, so we have no
|
// but a request is always associated to a backend, and we really need both, so we have no
|
||||||
// choice here. I'll maybe need to have a more elegant solution later, but not choice for now
|
// choice here. I'll maybe need to have a more elegant solution later, but not choice for now
|
||||||
func (s *Storage) GetClientByClientID(ctx context.Context, id string) (op.Client, error) {
|
func (s *Storage) GetClientByClientID(ctx context.Context, id string) (op.Client, error) {
|
||||||
logger.L.Debugf("Selecting client app with ID %s", id)
|
|
||||||
|
|
||||||
authRequestID, err := uuid.Parse(id)
|
authRequestID, err := uuid.Parse(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -1 +0,0 @@
|
||||||
ALTER TABLE "auth_request_2" DROP COLUMN done;
|
|
|
@ -1 +0,0 @@
|
||||||
ALTER TABLE "auth_request_2" ADD COLUMN done INTEGER NOT NULL DEFAULT 0;
|
|
|
@ -1 +0,0 @@
|
||||||
DROP TABLE "auth_code_2";
|
|
|
@ -1,5 +0,0 @@
|
||||||
CREATE TABLE "auth_code_2" (
|
|
||||||
id TEXT NOT NULL PRIMARY KEY,
|
|
||||||
code TEXT NOT NULL,
|
|
||||||
auth_request_id TEXT NOT NULL
|
|
||||||
);
|
|
Binary file not shown.
Binary file not shown.
|
@ -9,10 +9,8 @@ import (
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/config"
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/config"
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/controller/auth"
|
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/controller/ui"
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/controller/ui"
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/middlewares"
|
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/middlewares"
|
||||||
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/storage"
|
|
||||||
"github.com/zitadel/oidc/v3/pkg/op"
|
"github.com/zitadel/oidc/v3/pkg/op"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
@ -44,7 +42,7 @@ func newUnixListener(sockPath string) (net.Listener, error) {
|
||||||
return sock, nil
|
return sock, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(appConf *config.AppConfig, oidcHandler *op.Provider, st *storage.Storage, logger *zap.SugaredLogger) (*Server, error) {
|
func New(appConf *config.AppConfig, oidcHandler *op.Provider, logger *zap.SugaredLogger) (*Server, error) {
|
||||||
var listener net.Listener
|
var listener net.Listener
|
||||||
var addr string
|
var addr string
|
||||||
var err error
|
var err error
|
||||||
|
@ -66,10 +64,8 @@ func New(appConf *config.AppConfig, oidcHandler *op.Provider, st *storage.Storag
|
||||||
}
|
}
|
||||||
|
|
||||||
controllers := map[string]http.Handler{
|
controllers := map[string]http.Handler{
|
||||||
auth.AuthCallbackRoute: middlewares.WithLogger(auth.NewAuthCallbackController(logger, st), logger),
|
ui.StaticRoute: middlewares.WithLogger(ui.NewStaticController(appConf.StaticDir), logger),
|
||||||
auth.AuthRedirectRoute: middlewares.WithLogger(auth.NewAuthRedirectController(logger, st), logger),
|
"/": middlewares.WithLogger(ui.NewIndexController(logger, oidcHandler, appConf.StaticDir), logger),
|
||||||
ui.StaticRoute: middlewares.WithLogger(ui.NewStaticController(appConf.StaticDir), logger),
|
|
||||||
"/": middlewares.WithLogger(ui.NewIndexController(logger, oidcHandler, appConf.StaticDir), logger),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
m := http.NewServeMux()
|
m := http.NewServeMux()
|
||||||
|
|
Loading…
Reference in a new issue