package storage

import (
	"context"
	"crypto/subtle"
	"database/sql"
	"errors"
	"fmt"
	"time"

	"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/client"
	"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db"
	"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/model"
	"git.faercol.me/faercol/polyculeconnect/polyculeconnect/logger"
	"github.com/go-jose/go-jose/v4"
	"github.com/google/uuid"
	"github.com/zitadel/oidc/v3/pkg/oidc"
	"github.com/zitadel/oidc/v3/pkg/op"
)

// accessTokenLifetime is how long an access token issued by this storage stays valid.
const accessTokenLifetime = 5 * time.Minute

// Compile-time check that *Storage satisfies the zitadel op.Storage interface.
var _ op.Storage = (*Storage)(nil)

// ErrNotImplemented returns an error stating that the named storage method is not implemented.
func ErrNotImplemented(name string) error {
	return fmt.Errorf("%s is not implemented", name)
}

// Storage implements the Storage interface from zitadel/oidc/op.
type Storage struct {
	// LocalStorage is the persistence layer used for backends, clients, auth requests, codes, tokens and users.
	LocalStorage db.Storage
	// InitializedBackends is not read by this file at the moment (only referenced in commented-out code);
	// presumably the ready-to-use OIDC clients of each backend, keyed by backend ID — TODO confirm.
	InitializedBackends map[uuid.UUID]*client.OIDCClient
	// Key is the signing key exposed through SigningKey, SignatureAlgorithms and KeySet.
	Key *model.Key
}

/* Auth storage interface */

// CreateAuthRequest persists a new auth request bound to the backend named by the
// "backendName" context value, and returns it.
func (s *Storage) CreateAuthRequest(ctx context.Context, req *oidc.AuthRequest, userID string) (op.AuthRequest, error) {
	// userID should normally be an empty string (to verify), we don't get it in our workflow from what I saw
	// TODO: check this is indeed not needed / never present
	logger.L.Debugf("Creating a new auth request")

	// validate that the connector is correct
	backendName, ok := stringFromCtx(ctx, "backendName")
	if !ok {
		return nil, errors.New("no backend name provided")
	}
	selectedBackend, err := s.LocalStorage.BackendStorage().GetBackendByName(ctx, backendName)
	if err != nil {
		return nil, fmt.Errorf("failed to get backend: %w", err)
	}

	var opReq model.AuthRequest
	opReq.FromOIDCAuthRequest(req, selectedBackend.ID)

	if err := s.LocalStorage.AuthRequestStorage().CreateAuthRequest(ctx, opReq); err != nil {
		return nil, fmt.Errorf("failed to save auth request: %w", err)
	}

	logger.L.Debugf("Created a new auth request for backend %s", backendName)
	return opReq, nil
}

// AuthRequestByID loads the auth request with the given UUID and, when it is already
// bound to a user, attaches that user's information.
func (s *Storage) AuthRequestByID(ctx context.Context, requestID string) (op.AuthRequest, error) {
	logger.L.Debugf("Getting auth request with ID %s", requestID)

	id, err := uuid.Parse(requestID)
	if err != nil {
		return nil, fmt.Errorf("invalid format for uuid: %w", err)
	}

	req, err := s.LocalStorage.AuthRequestStorage().GetAuthRequestByID(ctx, id)
	if err != nil {
		return nil, fmt.Errorf("failed to get auth request from DB: %w", err)
	}
	if req.UserID == "" {
		return req, nil
	}

	user, err := s.LocalStorage.UserStorage().GetUserBySubject(ctx, req.UserID)
	if err != nil {
		return nil, fmt.Errorf("failed to get user information from DB: %w", err)
	}
	req.User = user

	return req, nil
}

// AuthRequestByCode resolves an authorization code to its auth request and, when the
// request is already bound to a user, attaches that user's information.
func (s *Storage) AuthRequestByCode(ctx context.Context, requestCode string) (op.AuthRequest, error) {
	logger.L.Debugf("Getting auth request from code %s", requestCode)

	authCode, err := s.LocalStorage.AuthCodeStorage().GetAuthCodeByCode(ctx, requestCode)
	if err != nil {
		return nil, fmt.Errorf("failed to get auth code from DB: %w", err)
	}

	req, err := s.LocalStorage.AuthRequestStorage().GetAuthRequestByID(ctx, authCode.RequestID)
	if err != nil {
		return nil, fmt.Errorf("failed to get auth request from DB: %w", err)
	}
	if req.UserID == "" {
		return req, nil
	}

	user, err := s.LocalStorage.UserStorage().GetUserBySubject(ctx, req.UserID)
	if err != nil {
		return nil, fmt.Errorf("failed to get user information from DB: %w", err)
	}
	req.User = user

	return req, nil
}

// SaveAuthCode stores the authorization code issued for the auth request with the given ID.
func (s *Storage) SaveAuthCode(ctx context.Context, id string, code string) error {
	logger.L.Debugf("Saving auth code %s for request %s", code, id)

	requestID, err := uuid.Parse(id)
	if err != nil {
		// Report the raw input: requestID is the zero UUID when parsing fails.
		return fmt.Errorf("invalid requestID %s: %w", id, err)
	}

	savedCode := model.AuthCode{
		CodeID:    uuid.New(),
		RequestID: requestID,
		Code:      code,
	}
	return s.LocalStorage.AuthCodeStorage().CreateAuthCode(ctx, savedCode)
}

// DeleteAuthRequest removes the auth request with the given UUID.
func (s *Storage) DeleteAuthRequest(ctx context.Context, id string) error {
	reqID, err := uuid.Parse(id)
	if err != nil {
		return fmt.Errorf("invalid id format: %w", err)
	}
	return s.LocalStorage.AuthRequestStorage().DeleteAuthRequest(ctx, reqID)
}

// tokenRequestInfo extracts the client ID and the original authentication time from a
// token request, which must be a *model.AuthRequest or a *model.RefreshTokenRequest.
func tokenRequestInfo(req op.TokenRequest) (clientID string, authTime time.Time, err error) {
	switch typedReq := req.(type) {
	case *model.AuthRequest:
		logger.L.Debug("Creating access token for new authentication")
		return typedReq.ClientID, typedReq.AuthTime, nil
	case *model.RefreshTokenRequest:
		logger.L.Debug("Handling refresh token request")
		return typedReq.GetClientID(), typedReq.GetAuthTime(), nil
	default:
		logger.L.Errorf("Unexpected type for request %T", req)
		return "", time.Time{}, errors.New("failed to parse auth request")
	}
}

// CreateAccessToken returns a fresh access token ID and its expiration for req.
// The token is not persisted yet.
func (s *Storage) CreateAccessToken(ctx context.Context, req op.TokenRequest) (accessTokenID string, expiration time.Time, err error) {
	if _, _, err = tokenRequestInfo(req); err != nil {
		return "", time.Time{}, err
	}

	// The lifetime is counted from now, not from the original authentication time:
	// for refresh requests that time can be arbitrarily old, which would yield tokens
	// that are already expired when issued.
	// TODO: persist the access token (ID, subject, audiences, scopes, expiration).
	return uuid.New().String(), time.Now().Add(accessTokenLifetime), nil
}

// CreateAccessAndRefreshTokens returns a fresh access token ID with its expiration and
// stores a new refresh token for the request's client and subject.
func (s *Storage) CreateAccessAndRefreshTokens(ctx context.Context, request op.TokenRequest, currentRefreshToken string) (accessTokenID string, newRefreshTokenID string, expiration time.Time, err error) {
	clientID, authTime, err := tokenRequestInfo(request)
	if err != nil {
		return "", "", time.Time{}, err
	}

	// NOTE(review): currentRefreshToken is never revoked here, so a rotated refresh token
	// presumably remains usable — TODO confirm and delete it once TokenStorage supports it.
	// TODO: persist the access token as well.
	refreshTokenUUID := uuid.New()
	refreshToken := model.RefreshToken{
		ID:       refreshTokenUUID,
		ClientID: clientID,
		UserID:   request.GetSubject(),
		Scopes:   request.GetScopes(),
		AuthTime: authTime,
	}
	if err = s.LocalStorage.TokenStorage().AddRefreshToken(ctx, &refreshToken); err != nil {
		return "", "", time.Time{}, fmt.Errorf("failed to insert token in DB: %w", err)
	}

	// Expiration counted from now for the same reason as in CreateAccessToken.
	return uuid.New().String(), refreshTokenUUID.String(), time.Now().Add(accessTokenLifetime), nil
}

// TokenRequestByRefreshToken loads the refresh token with the given UUID and returns
// the token request it represents. A missing token maps to op.ErrInvalidRefreshToken.
func (s *Storage) TokenRequestByRefreshToken(ctx context.Context, refreshTokenID string) (op.RefreshTokenRequest, error) {
	parsedID, err := uuid.Parse(refreshTokenID)
	if err != nil {
		return nil, fmt.Errorf("invalid format for refresh token id: %w", err)
	}

	refreshToken, err := s.LocalStorage.TokenStorage().GetRefreshTokenByID(ctx, parsedID)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, op.ErrInvalidRefreshToken
		}
		return nil, fmt.Errorf("failed to get refresh token: %w", err)
	}

	return refreshToken.Request(), nil
}

// TerminateSession is not implemented.
func (s *Storage) TerminateSession(ctx context.Context, userID string, clientID string) error {
	return ErrNotImplemented("TerminateSession")
}

// RevokeToken currently revokes nothing and always reports success.
// NOTE(review): tokens stay valid after a revocation request — TODO implement.
func (s *Storage) RevokeToken(ctx context.Context, tokenOrTokenID string, userID string, clientID string) *oidc.Error {
	return nil
}

// GetRefreshTokenInfo is not implemented.
func (s *Storage) GetRefreshTokenInfo(ctx context.Context, clientID string, stoken string) (string, string, error) {
	return "", "", ErrNotImplemented("GetRefreshTokenInfo")
}

// SigningKey returns the key used to sign tokens.
func (s *Storage) SigningKey(ctx context.Context) (op.SigningKey, error) {
	return s.Key.SigningKey(), nil
}

// SignatureAlgorithms reports the algorithm of the configured signing key, so the
// advertised algorithm always matches the one actually used.
func (s *Storage) SignatureAlgorithms(ctx context.Context) ([]jose.SignatureAlgorithm, error) {
	return []jose.SignatureAlgorithm{s.Key.Algorithm()}, nil
}

// KeySet returns the public keys clients can use to verify token signatures.
func (s *Storage) KeySet(ctx context.Context) ([]op.Key, error) {
	return []op.Key{s.Key}, nil
}

/* OP storage */

// getClientWithDetails loads the client of an auth request and attaches to it the auth
// request together with its backend.
func (s *Storage) getClientWithDetails(ctx context.Context, authRequestID uuid.UUID) (op.Client, error) {
	logger.L.Debug("Trying to get client details from auth request")

	authRequest, err := s.LocalStorage.AuthRequestStorage().GetAuthRequestByID(ctx, authRequestID)
	if err != nil {
		return nil, fmt.Errorf("failed to get authRequest from local storage: %w", err)
	}

	backend, err := s.LocalStorage.BackendStorage().GetBackendByID(ctx, authRequest.BackendID)
	if err != nil {
		return nil, fmt.Errorf("failed to get associated backend from local storage: %w", err)
	}

	appClient, err := s.LocalStorage.ClientStorage().GetClientByID(ctx, authRequest.ClientID)
	if err != nil {
		return nil, fmt.Errorf("failed to get associated client from local storage: %w", err)
	}

	// oidcClient, ok := s.InitializedBackends[backend.ID]
	// if !ok {
	// 	return nil, fmt.Errorf("no initialized backend for ID %s", backend.ID)
	// }

	authRequest.Backend = backend
	appClient.AuthRequest = authRequest

	return appClient, nil
}

// GetClientByClientID returns the client for id. A UUID id is treated as an auth request
// ID and yields the request's client with full details; any other id is looked up
// directly as a client ID.
//
// We're cheating a bit here since we're using the authrequest to get its associated client,
// but a request is always associated with a backend, and we really need both, so we have no
// choice here. I may need a more elegant solution later, but there is no choice for now.
// NOTE(review): a client whose client_id is itself a valid UUID would be mistaken for an
// auth request ID.
func (s *Storage) GetClientByClientID(ctx context.Context, id string) (op.Client, error) {
	logger.L.Debugf("Selecting client app with ID %s", id)

	authRequestID, err := uuid.Parse(id)
	if err != nil {
		// it's not a UUID, it means this was called using client_id, we just return the client without details
		appClient, err := s.LocalStorage.ClientStorage().GetClientByID(ctx, id)
		if err != nil {
			return nil, fmt.Errorf("failed to get client %s from local storage: %w", id, err)
		}
		return appClient, nil
	}

	// we have a UUID, it means we got a requestID, so we can get all details here
	return s.getClientWithDetails(ctx, authRequestID)
}

// AuthorizeClientIDSecret checks that clientSecret matches the stored secret of clientID.
func (s *Storage) AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string) error {
	// The secret itself is deliberately not logged.
	logger.L.Debugf("Validating client secret for client %s", clientID)

	appClient, err := s.LocalStorage.ClientStorage().GetClientByID(ctx, clientID)
	if err != nil {
		return fmt.Errorf("failed to get client %s from local storage: %w", clientID, err)
	}

	// Constant-time comparison so response timing does not reveal how much of the secret matched.
	if subtle.ConstantTimeCompare([]byte(appClient.Secret), []byte(clientSecret)) != 1 {
		return errors.New("invalid secret")
	}
	return nil
}

// SetUserinfoFromScopes fills userinfo with the claims of user userID that the requested
// scopes cover: "openid" sets the subject, "profile" the name fields, nickname, picture and
// update time, and "email" the email address and its verification flag.
func (s *Storage) SetUserinfoFromScopes(ctx context.Context, userinfo *oidc.UserInfo, userID, clientID string, scopes []string) error {
	logger.L.Debugf("Setting user info for user %s", userID)

	user, err := s.LocalStorage.UserStorage().GetUserBySubject(ctx, userID)
	if err != nil {
		return fmt.Errorf("failed to get user from DB: %w", err)
	}

	for _, scope := range scopes {
		switch scope {
		case "openid":
			userinfo.Subject = user.Subject
		case "profile":
			userinfo.Name = user.Name
			userinfo.FamilyName = user.FamilyName
			userinfo.GivenName = user.GivenName
			userinfo.Nickname = user.Nickname
			userinfo.Picture = user.Picture
			userinfo.UpdatedAt = oidc.FromTime(user.UpdatedAt)
		case "email":
			userinfo.Email = user.Email
			userinfo.EmailVerified = oidc.Bool(user.EmailVerified)
		}
	}

	return nil
}

// SetUserinfoFromToken is not implemented.
func (s *Storage) SetUserinfoFromToken(ctx context.Context, userinfo *oidc.UserInfo, tokenID, subject, origin string) error {
	return ErrNotImplemented("SetUserinfoFromToken")
}

// SetIntrospectionFromToken is not implemented.
func (s *Storage) SetIntrospectionFromToken(ctx context.Context, userinfo *oidc.IntrospectionResponse, tokenID, subject, clientID string) error {
	return ErrNotImplemented("SetIntrospectionFromToken")
}

// GetPrivateClaimsFromScopes adds no private claims.
func (s *Storage) GetPrivateClaimsFromScopes(ctx context.Context, userID, clientID string, scopes []string) (map[string]interface{}, error) {
	// For now, let's just return nothing, we don't want to add any private scope
	return nil, nil
}

// GetKeyByIDAndClientID is not implemented.
func (s *Storage) GetKeyByIDAndClientID(ctx context.Context, keyID, clientID string) (*jose.JSONWebKey, error) {
	return nil, ErrNotImplemented("GetKeyByIDAndClientID")
}

// ValidateJWTProfileScopes is not implemented.
func (s *Storage) ValidateJWTProfileScopes(ctx context.Context, userID string, scopes []string) ([]string, error) {
	return nil, ErrNotImplemented("ValidateJWTProfileScopes")
}

// Health is not implemented and therefore always reports an error.
func (s *Storage) Health(ctx context.Context) error {
	return ErrNotImplemented("Health")
}

// stringFromCtx returns the string stored in ctx under key, and whether one was present.
// NOTE(review): a plain string key is flagged by go vet/staticcheck (SA1029); it must stay
// in sync with whatever sets the value elsewhere, so changing it here alone would break lookups.
func stringFromCtx(ctx context.Context, key string) (string, bool) {
	rawVal := ctx.Value(key)
	if rawVal == nil {
		return "", false
	}
	val, ok := rawVal.(string)
	return val, ok
}