diff --git a/changes/unreleased/Fixed-20260616-150152.yaml b/changes/unreleased/Fixed-20260616-150152.yaml new file mode 100644 index 00000000..23098c1e --- /dev/null +++ b/changes/unreleased/Fixed-20260616-150152.yaml @@ -0,0 +1,3 @@ +kind: Fixed +body: Database specs with duplicate allocated ports on a single host are now rejected by the API. +time: 2026-06-16T15:01:52.449574-04:00 diff --git a/server/internal/api/apiv1/errors.go b/server/internal/api/apiv1/errors.go index 1f4d6243..72e05323 100644 --- a/server/internal/api/apiv1/errors.go +++ b/server/internal/api/apiv1/errors.go @@ -10,6 +10,7 @@ import ( "github.com/pgEdge/control-plane/server/internal/database" "github.com/pgEdge/control-plane/server/internal/etcd" "github.com/pgEdge/control-plane/server/internal/task" + "github.com/pgEdge/control-plane/server/internal/validation" "github.com/pgEdge/control-plane/server/internal/workflows" ) @@ -51,7 +52,7 @@ func ErrHostAlreadyExistsWithID(hostID string) *api.APIError { func apiErr(err error) error { var goaErr *goa.ServiceError var apiErr *api.APIError - var vErr *validationError + var vErr *validation.Error switch { case err == nil: return nil diff --git a/server/internal/api/apiv1/validate.go b/server/internal/api/apiv1/validate.go index 585b41b5..947d94bc 100644 --- a/server/internal/api/apiv1/validate.go +++ b/server/internal/api/apiv1/validate.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "maps" "path/filepath" "regexp" "slices" @@ -20,63 +21,21 @@ import ( "github.com/pgEdge/control-plane/server/internal/postgres/hba" "github.com/pgEdge/control-plane/server/internal/storage" "github.com/pgEdge/control-plane/server/internal/utils" + "github.com/pgEdge/control-plane/server/internal/validation" ) -type validationError struct { - path []string - err error -} - -func newValidationError(err error, path []string) *validationError { - return &validationError{ - path: path, - err: err, - } -} - -func (v *validationError) Unwrap() error { - return v.err -} - -func (v *validationError) Error() string { - if len(v.path) == 0 { - return v.err.Error() - } - - var path strings.Builder - for i, ele := range v.path { - if i > 0 && !strings.HasPrefix(ele, "[") { - path.WriteString(".") - } - path.WriteString(ele) - } - return fmt.Sprintf("%s: %s", path.String(), v.err.Error()) -} - -func arrayIndexPath(idx int) string { - return fmt.Sprintf("[%d]", idx) -} - -func mapKeyPath(key string) string { - return fmt.Sprintf("[%s]", key) -} - -func appendPath(path []string, new ...string) []string { - return append(slices.Clone(path), new...) -} - // validateAuthFileGUCs rejects postgresql_conf settings that would make // user-supplied pg_hba_conf/pg_ident_conf entries ineffective. When hba_file // or ident_file is set, Patroni ignores the pg_hba/pg_ident arrays it manages, // so the control-plane-generated file (including user entries) would never be // written. GUC names are case-insensitive in PostgreSQL, so we compare lower. -func validateAuthFileGUCs(conf map[string]any, path []string) []error { +func validateAuthFileGUCs(conf map[string]any, path validation.Path) []error { var errs []error for key := range conf { switch strings.ToLower(strings.TrimSpace(key)) { case "hba_file", "ident_file": err := fmt.Errorf("%q is not allowed: it overrides the control-plane-managed pg_hba.conf/pg_ident.conf and would make pg_hba_conf/pg_ident_conf entries ineffective", key) - errs = append(errs, newValidationError(err, appendPath(path, mapKeyPath(key)))) + errs = append(errs, validation.NewError(err, path.AppendMapKey(key))) } } return errs @@ -85,7 +44,7 @@ func validateAuthFileGUCs(conf map[string]any, path []string) []error { // validatePgHbaConf checks that every non-comment pg_hba_conf entry parses. // Blank and comment lines are allowed and skipped. Validation is intentionally // minimal — see server/internal/postgres/hba/parse.go. -func validatePgHbaConf(lines []string, path []string) []error { +func validatePgHbaConf(lines []string, path validation.Path) []error { var errs []error for i, line := range lines { if hba.IsComment(line) { @@ -93,14 +52,14 @@ func validatePgHbaConf(lines []string, path []string) []error { } if _, err := hba.ParseEntry(line); err != nil { wrapped := fmt.Errorf("invalid pg_hba entry %q: %w", line, err) - errs = append(errs, newValidationError(wrapped, appendPath(path, arrayIndexPath(i)))) + errs = append(errs, validation.NewError(wrapped, path.AppendArrayIndex(i))) } } return errs } // validatePgIdentConf checks that every non-comment pg_ident_conf entry parses. -func validatePgIdentConf(lines []string, path []string) []error { +func validatePgIdentConf(lines []string, path validation.Path) []error { var errs []error for i, line := range lines { if hba.IsComment(line) { @@ -108,7 +67,7 @@ func validatePgIdentConf(lines []string, path []string) []error { } if _, err := hba.ParseIdent(line); err != nil { wrapped := fmt.Errorf("invalid pg_ident entry %q: %w", line, err) - errs = append(errs, newValidationError(wrapped, appendPath(path, arrayIndexPath(i)))) + errs = append(errs, validation.NewError(wrapped, path.AppendArrayIndex(i))) } } return errs @@ -117,23 +76,24 @@ func validatePgIdentConf(lines []string, path []string) []error { func validateDatabaseSpec(orchestrator config.Orchestrator, databaseID string, spec *api.DatabaseSpec) error { var errs []error - errs = append(errs, validateCPUs(spec.Cpus, []string{"cpus"})...) - errs = append(errs, validateMemory(spec.Memory, []string{"memory"})...) - errs = append(errs, validatePorts(spec.Port, spec.PatroniPort, []string{"port"})) - errs = append(errs, validateUsers(spec.DatabaseUsers, []string{"database_users"})...) - errs = append(errs, validateScripts(spec.Scripts, []string{"scripts"})...) + errs = append(errs, validateCPUs(spec.Cpus, validation.NewPath("cpus"))...) + errs = append(errs, validateMemory(spec.Memory, validation.NewPath("memory"))...) + errs = append(errs, validateUniquePorts(spec)...) + errs = append(errs, validateUsers(spec.DatabaseUsers, validation.NewPath("database_users"))...) + errs = append(errs, validateScripts(spec.Scripts, validation.NewPath("scripts"))...) // Track node-name uniqueness and prepare set for cross-node checks. seenNodeNames := make(ds.Set[string], len(spec.Nodes)) // Track nodes that themselves have a source_node (treated as "new" nodes). newNodesWithSource := make(ds.Set[string], len(spec.Nodes)) + nodesPath := validation.NewPath("nodes") for i, node := range spec.Nodes { - nodePath := []string{"nodes", arrayIndexPath(i)} + nodePath := nodesPath.AppendArrayIndex(i) if seenNodeNames.Has(node.Name) { err := errors.New("node names must be unique within a database") - errs = append(errs, newValidationError(err, nodePath)) + errs = append(errs, validation.NewError(err, nodePath)) } seenNodeNames.Add(node.Name) @@ -154,11 +114,11 @@ func validateDatabaseSpec(orchestrator config.Orchestrator, databaseID string, s continue } - srcPath := []string{"nodes", arrayIndexPath(i), "source_node"} + srcPath := nodesPath.AppendArrayIndex(i).Append("source_node") if !seenNodeNames.Has(src) { // Attach error to the specific field path - errs = append(errs, newValidationError(errors.New("source node does not exist"), + errs = append(errs, validation.NewError(errors.New("source node does not exist"), srcPath)) continue } @@ -166,7 +126,7 @@ func validateDatabaseSpec(orchestrator config.Orchestrator, databaseID string, s // prevent using a "new" node (one that has its own source_node) // as the source for another node. if newNodesWithSource.Has(src) { - errs = append(errs, newValidationError( + errs = append(errs, validation.NewError( errors.New("source node must refer to an existing node"), srcPath, )) @@ -175,44 +135,39 @@ func validateDatabaseSpec(orchestrator config.Orchestrator, databaseID string, s // Reject postgresql_conf GUCs that would make user-supplied pg_hba/pg_ident // entries ineffective, then validate the entries themselves. - errs = append(errs, validateAuthFileGUCs(spec.PostgresqlConf, []string{"postgresql_conf"})...) - errs = append(errs, validatePgHbaConf(spec.PgHbaConf, []string{"pg_hba_conf"})...) - errs = append(errs, validatePgIdentConf(spec.PgIdentConf, []string{"pg_ident_conf"})...) + errs = append(errs, validateAuthFileGUCs(spec.PostgresqlConf, validation.NewPath("postgresql_conf"))...) + errs = append(errs, validatePgHbaConf(spec.PgHbaConf, validation.NewPath("pg_hba_conf"))...) + errs = append(errs, validatePgIdentConf(spec.PgIdentConf, validation.NewPath("pg_ident_conf"))...) if spec.BackupConfig != nil { - errs = append(errs, validateBackupConfig(spec.BackupConfig, []string{"backup_config"})...) + errs = append(errs, validateBackupConfig(spec.BackupConfig, validation.NewPath("backup_config"))...) } if spec.RestoreConfig != nil { - errs = append(errs, validateRestoreConfig(spec.RestoreConfig, []string{"restore_config"})...) + errs = append(errs, validateRestoreConfig(spec.RestoreConfig, validation.NewPath("restore_config"))...) } // Validate orchestrator_opts (spec-level) - errs = append(errs, validateOrchestratorOpts(spec.OrchestratorOpts, []string{"orchestrator_opts"})...) - - // Validate services — seed portOwner with Postgres ports so services can't collide with the database. - portOwner := make(servicePortOwnerMap) - seedPostgresPorts(spec, portOwner) + errs = append(errs, validateOrchestratorOpts(spec.OrchestratorOpts, validation.NewPath("orchestrator_opts"))...) - servicesPath := []string{"services"} + servicesPath := validation.NewPath("services") switch orchestrator { case config.OrchestratorSystemD: if len(spec.Services) != 0 { - errs = append(errs, newValidationError(errors.New("services are not yet supported for systemd clusters"), servicesPath)) + errs = append(errs, validation.NewError(errors.New("services are not yet supported for systemd clusters"), servicesPath)) } default: seenServiceIDs := make(ds.Set[string], len(spec.Services)) for i, svc := range spec.Services { - svcPath := appendPath(servicesPath, arrayIndexPath(i)) + svcPath := servicesPath.AppendArrayIndex(i) // Check for duplicate service IDs if seenServiceIDs.Has(string(svc.ServiceID)) { err := errors.New("service IDs must be unique within a database") - errs = append(errs, newValidationError(err, svcPath)) + errs = append(errs, validation.NewError(err, svcPath)) } seenServiceIDs.Add(string(svc.ServiceID)) - errs = append(errs, validateServicePortConflicts(svc, svcPath, portOwner)...) errs = append(errs, validateServiceSpec(svc, svcPath, false, databaseID, spec.DatabaseUsers, seenNodeNames)...) } } @@ -243,8 +198,8 @@ func validateDatabaseUpdate(old *database.Spec, new *api.DatabaseSpec) error { if !existingNodeNames.Has(src) { // Newly added node is trying to use a new/non-existing node as source. - path := []string{"nodes", arrayIndexPath(i), "source_node"} - errs = append(errs, newValidationError( + path := validation.NewPath("nodes", validation.ArrayIndexElement(i), "source_node") + errs = append(errs, validation.NewError( errors.New("source node must refer to an existing node"), path, )) @@ -269,18 +224,13 @@ func validateDatabaseUpdate(old *database.Spec, new *api.DatabaseSpec) error { existingServiceIDs.Add(svc.ServiceID) } - // Seed portOwner with Postgres ports so services can't collide with the database. - portOwner := make(servicePortOwnerMap) - seedPostgresPorts(new, portOwner) - // Validate each service. Pass isUpdate=false for services being added for the // first time so that bootstrap-only fields are accepted. For service types that // have no bootstrap fields (e.g. postgrest) the flag has no effect. for i, svc := range new.Services { - svcPath := []string{"services", arrayIndexPath(i)} + svcPath := validation.NewPath("services", validation.ArrayIndexElement(i)) isExistingService := existingServiceIDs.Has(string(svc.ServiceID)) - errs = append(errs, validateServicePortConflicts(svc, svcPath, portOwner)...) errs = append(errs, validateServiceSpec(svc, svcPath, isExistingService, old.DatabaseID, new.DatabaseUsers, newNodeNames)...) } @@ -291,37 +241,26 @@ func validateNode( orchestrator config.Orchestrator, db *api.DatabaseSpec, node *api.DatabaseNodeSpec, - path []string, + path validation.Path, ) []error { var errs []error - cpusPath := appendPath(path, "cpus") + cpusPath := path.Append("cpus") errs = append(errs, validateCPUs(node.Cpus, cpusPath)...) - memPath := appendPath(path, "memory") + memPath := path.Append("memory") errs = append(errs, validateMemory(node.Memory, memPath)...) - port := db.Port - if node.Port != nil { - port = node.Port - } - patroniPort := db.PatroniPort - if node.PatroniPort != nil { - patroniPort = node.PatroniPort - } - portPath := appendPath(path, "port") - errs = append(errs, validatePorts(port, patroniPort, portPath)) - seenHostIDs := make(ds.Set[string], len(node.HostIds)) for i, h := range node.HostIds { hostID := string(h) - hostPath := appendPath(path, "host_ids", arrayIndexPath(i)) + hostPath := path.Append("host_ids").AppendArrayIndex(i) errs = append(errs, validateIdentifier(hostID, hostPath)) if seenHostIDs.Has(hostID) { err := errors.New("host IDs must be unique within a node") - errs = append(errs, newValidationError(err, hostPath)) + errs = append(errs, validation.NewError(err, hostPath)) } seenHostIDs.Add(hostID) @@ -329,88 +268,88 @@ func validateNode( // source_node + restore_config validation (field-level) src := utils.FromPointer(node.SourceNode) - srcPath := appendPath(path, "source_node") + srcPath := path.Append("source_node") // If restore_config is provided, source_node must be empty if node.RestoreConfig != nil && src != "" { - errs = append(errs, newValidationError(errors.New("specify either source_node or restore_config"), srcPath)) + errs = append(errs, validation.NewError(errors.New("specify either source_node or restore_config"), srcPath)) } else if src != "" { // Self-reference is invalid if src == node.Name { - errs = append(errs, newValidationError(errors.New("a node cannot use itself as a source node"), srcPath)) + errs = append(errs, validation.NewError(errors.New("a node cannot use itself as a source node"), srcPath)) } } - errs = append(errs, validateAuthFileGUCs(node.PostgresqlConf, appendPath(path, "postgresql_conf"))...) - errs = append(errs, validatePgHbaConf(node.PgHbaConf, appendPath(path, "pg_hba_conf"))...) - errs = append(errs, validatePgIdentConf(node.PgIdentConf, appendPath(path, "pg_ident_conf"))...) + errs = append(errs, validateAuthFileGUCs(node.PostgresqlConf, path.Append("postgresql_conf"))...) + errs = append(errs, validatePgHbaConf(node.PgHbaConf, path.Append("pg_hba_conf"))...) + errs = append(errs, validatePgIdentConf(node.PgIdentConf, path.Append("pg_ident_conf"))...) if node.BackupConfig != nil { - backupConfigPath := appendPath(path, "backup_config") + backupConfigPath := path.Append("backup_config") errs = append(errs, validateBackupConfig(node.BackupConfig, backupConfigPath)...) } if node.RestoreConfig != nil { - restoreConfigPath := appendPath(path, "restore_config") + restoreConfigPath := path.Append("restore_config") errs = append(errs, validateRestoreConfig(node.RestoreConfig, restoreConfigPath)...) } switch orchestrator { case config.OrchestratorSystemD: if db.Port == nil && node.Port == nil { - portPath := appendPath(path, "port") - errs = append(errs, newValidationError(errors.New("port must be defined"), portPath)) + portPath := path.Append("port") + errs = append(errs, validation.NewError(errors.New("port must be defined"), portPath)) } if db.PatroniPort == nil && node.PatroniPort == nil { - portPath := appendPath(path, "patroni_port") - errs = append(errs, newValidationError(errors.New("patroni_port must be defined"), portPath)) + portPath := path.Append("patroni_port") + errs = append(errs, validation.NewError(errors.New("patroni_port must be defined"), portPath)) } } // Validate orchestrator_opts (per-node) - errs = append(errs, validateOrchestratorOpts(node.OrchestratorOpts, appendPath(path, "orchestrator_opts"))...) + errs = append(errs, validateOrchestratorOpts(node.OrchestratorOpts, path.Append("orchestrator_opts"))...) return errs } -func validateServiceSpec(svc *api.ServiceSpec, path []string, isUpdate bool, databaseID string, dbUsers []*api.DatabaseUserSpec, nodeNames ...ds.Set[string]) []error { +func validateServiceSpec(svc *api.ServiceSpec, path validation.Path, isUpdate bool, databaseID string, dbUsers []*api.DatabaseUserSpec, nodeNames ...ds.Set[string]) []error { var errs []error // Validate service_id - serviceIDPath := appendPath(path, "service_id") + serviceIDPath := path.Append("service_id") errs = append(errs, validateIdentifier(string(svc.ServiceID), serviceIDPath)) // Enforce Docker Swarm service name budget: "{databaseID}-{serviceID}-{8charHash}" must be ≤63 chars. if len(databaseID)+len(string(svc.ServiceID)) > 53 { err := fmt.Errorf("database ID and service ID combined must not exceed 53 characters (got %d)", len(databaseID)+len(string(svc.ServiceID))) - errs = append(errs, newValidationError(err, serviceIDPath)) + errs = append(errs, validation.NewError(err, serviceIDPath)) } // Validate service_type allowlist - supportedServiceTypes := []string{"mcp", "postgrest", "rag"} + supportedServiceTypes := validation.NewPath("mcp", "postgrest", "rag") if !slices.Contains(supportedServiceTypes, svc.ServiceType) { err := fmt.Errorf("unsupported service type %q (supported: %s)", svc.ServiceType, strings.Join(supportedServiceTypes, ", ")) - errs = append(errs, newValidationError(err, appendPath(path, "service_type"))) + errs = append(errs, validation.NewError(err, path.Append("service_type"))) } // Validate version (semver pattern or "latest") if svc.Version != "latest" && !semverPattern.MatchString(svc.Version) { err := errors.New("version must be in semver format (e.g., '1.0.0') or 'latest'") - errs = append(errs, newValidationError(err, appendPath(path, "version"))) + errs = append(errs, validation.NewError(err, path.Append("version"))) } // Validate host_ids (uniqueness and format) seenHostIDs := make(ds.Set[string], len(svc.HostIds)) for i, hostID := range svc.HostIds { hostIDStr := string(hostID) - hostIDPath := appendPath(path, "host_ids", arrayIndexPath(i)) + hostIDPath := path.Append("host_ids").AppendArrayIndex(i) errs = append(errs, validateIdentifier(hostIDStr, hostIDPath)) // may need to relax this if there is a use-case for multiple service instances on the same host if seenHostIDs.Has(hostIDStr) { err := errors.New("host IDs must be unique within a service") - errs = append(errs, newValidationError(err, hostIDPath)) + errs = append(errs, validation.NewError(err, hostIDPath)) } seenHostIDs.Add(hostIDStr) } @@ -418,16 +357,16 @@ func validateServiceSpec(svc *api.ServiceSpec, path []string, isUpdate bool, dat // Validate config based on service_type switch svc.ServiceType { case "mcp": - errs = append(errs, validateMCPServiceConfig(svc.Config, appendPath(path, "config"), isUpdate)...) + errs = append(errs, validateMCPServiceConfig(svc.Config, path.Append("config"), isUpdate)...) case "postgrest": - errs = append(errs, validatePostgRESTServiceConfig(svc.Config, appendPath(path, "config"))...) + errs = append(errs, validatePostgRESTServiceConfig(svc.Config, path.Append("config"))...) case "rag": - errs = append(errs, validateRAGServiceConfig(svc.Config, appendPath(path, "config"), isUpdate)...) + errs = append(errs, validateRAGServiceConfig(svc.Config, path.Append("config"), isUpdate)...) } // Validate database_connection if provided if svc.DatabaseConnection != nil { - dcPath := appendPath(path, "database_connection") + dcPath := path.Append("database_connection") var nn ds.Set[string] if len(nodeNames) > 0 { nn = nodeNames[0] @@ -445,31 +384,31 @@ func validateServiceSpec(svc *api.ServiceSpec, path []string, isUpdate bool, dat writeSafe := map[string]bool{database.TargetSessionAttrsPrimary: true, database.TargetSessionAttrsReadWrite: true} if tsa != "" && !writeSafe[tsa] { err := fmt.Errorf("allow_writes requires target_session_attrs 'primary' or 'read-write', got '%s'", tsa) - errs = append(errs, newValidationError(err, appendPath(path, "database_connection", "target_session_attrs"))) + errs = append(errs, validation.NewError(err, path.Append("database_connection", "target_session_attrs"))) } } } // Validate cpus if provided if svc.Cpus != nil { - errs = append(errs, validateCPUs(svc.Cpus, appendPath(path, "cpus"))...) + errs = append(errs, validateCPUs(svc.Cpus, path.Append("cpus"))...) } // Validate memory if provided if svc.Memory != nil { - errs = append(errs, validateMemory(svc.Memory, appendPath(path, "memory"))...) + errs = append(errs, validateMemory(svc.Memory, path.Append("memory"))...) } // Validate orchestrator_opts (service-specific restrictions on top of shared checks) - errs = append(errs, validateServiceOrchestratorOpts(svc.OrchestratorOpts, appendPath(path, "orchestrator_opts"))...) + errs = append(errs, validateServiceOrchestratorOpts(svc.OrchestratorOpts, path.Append("orchestrator_opts"))...) return errs } -func validateConnectAs(svc *api.ServiceSpec, dbUsers []*api.DatabaseUserSpec, path []string) []error { - connectAsPath := appendPath(path, "connect_as") +func validateConnectAs(svc *api.ServiceSpec, dbUsers []*api.DatabaseUserSpec, path validation.Path) []error { + connectAsPath := path.Append("connect_as") if svc.ConnectAs == "" { - return []error{newValidationError(errors.New("connect_as is required"), connectAsPath)} + return []error{validation.NewError(errors.New("connect_as is required"), connectAsPath)} } for _, u := range dbUsers { @@ -479,42 +418,42 @@ func validateConnectAs(svc *api.ServiceSpec, dbUsers []*api.DatabaseUserSpec, pa } err := fmt.Errorf("connect_as %q does not match any database_users entry", svc.ConnectAs) - return []error{newValidationError(err, connectAsPath)} + return []error{validation.NewError(err, connectAsPath)} } -func validateMCPServiceConfig(config map[string]any, path []string, isUpdate bool) []error { +func validateMCPServiceConfig(config map[string]any, path validation.Path, isUpdate bool) []error { _, errs := database.ParseMCPServiceConfig(config, isUpdate) var result []error for _, err := range errs { - result = append(result, newValidationError(err, path)) + result = append(result, validation.NewError(err, path)) } return result } -func validatePostgRESTServiceConfig(config map[string]any, path []string) []error { +func validatePostgRESTServiceConfig(config map[string]any, path validation.Path) []error { _, errs := database.ParsePostgRESTServiceConfig(config) var result []error for _, err := range errs { - result = append(result, newValidationError(err, path)) + result = append(result, validation.NewError(err, path)) } return result } -func validateDatabaseConnection(dc *api.DatabaseConnection, path []string, nodeNames ds.Set[string]) []error { +func validateDatabaseConnection(dc *api.DatabaseConnection, path validation.Path, nodeNames ds.Set[string]) []error { var errs []error // Validate target_nodes: no duplicates, no empty strings, must exist in spec if dc.TargetNodes != nil { seen := make(ds.Set[string], len(dc.TargetNodes)) for i, node := range dc.TargetNodes { - nodePath := appendPath(path, "target_nodes", arrayIndexPath(i)) + nodePath := path.Append("target_nodes").AppendArrayIndex(i) if node == "" { - errs = append(errs, newValidationError(errors.New("node name must not be empty"), nodePath)) + errs = append(errs, validation.NewError(errors.New("node name must not be empty"), nodePath)) } else if nodeNames != nil && !nodeNames.Has(node) { - errs = append(errs, newValidationError(fmt.Errorf("node %q does not exist in the database spec", node), nodePath)) + errs = append(errs, validation.NewError(fmt.Errorf("node %q does not exist in the database spec", node), nodePath)) } if seen.Has(node) { - errs = append(errs, newValidationError(fmt.Errorf("duplicate node name %q", node), nodePath)) + errs = append(errs, validation.NewError(fmt.Errorf("duplicate node name %q", node), nodePath)) } seen.Add(node) } @@ -531,74 +470,117 @@ func validateDatabaseConnection(dc *api.DatabaseConnection, path []string, nodeN } if !valid[*dc.TargetSessionAttrs] { err := fmt.Errorf("invalid target_session_attrs %q (must be primary, prefer-standby, standby, read-write, or any)", *dc.TargetSessionAttrs) - errs = append(errs, newValidationError(err, appendPath(path, "target_session_attrs"))) + errs = append(errs, validation.NewError(err, path.Append("target_session_attrs"))) } } return errs } -func validateRAGServiceConfig(config map[string]any, path []string, isUpdate bool) []error { +func validateRAGServiceConfig(config map[string]any, path validation.Path, isUpdate bool) []error { _, errs := database.ParseRAGServiceConfig(config, isUpdate) var result []error for _, err := range errs { - result = append(result, newValidationError(err, path)) + result = append(result, validation.NewError(err, path)) } return result } -func validateCPUs(value *string, path []string) []error { +func validateCPUs(value *string, path validation.Path) []error { var errs []error cpus, err := parseCPUs(value) if err != nil { - errs = append(errs, newValidationError(err, path)) + errs = append(errs, validation.NewError(err, path)) } if cpus != 0 && cpus < 0.001 { err := errors.New("cannot be less than 1 millicpu") - errs = append(errs, newValidationError(err, path)) + errs = append(errs, validation.NewError(err, path)) } return errs } -func validateMemory(value *string, path []string) []error { +func validateMemory(value *string, path validation.Path) []error { var errs []error _, err := parseBytes(value) if err != nil { - errs = append(errs, newValidationError(err, path)) + errs = append(errs, validation.NewError(err, path)) } return errs } -func validatePorts(postgresPort, patroniPort *int, path []string) error { - postgres := utils.FromPointer(postgresPort) - patroni := utils.FromPointer(patroniPort) +func validateUniquePorts(spec *api.DatabaseSpec) []error { + hostPorts := map[string]*validation.Unique[int]{} + specPostgresPort := utils.FromPointer(spec.Port) + specPatroniPort := utils.FromPointer(spec.PatroniPort) - if postgres > 0 && postgres == patroni { - return newValidationError(errors.New("postgres and patroni ports must not conflict"), path) + nodesPath := validation.NewPath("nodes") + for i, node := range spec.Nodes { + postgresPort := utils.FromPointer(node.Port) + if postgresPort == 0 { + postgresPort = specPostgresPort + } + patroniPort := utils.FromPointer(node.PatroniPort) + if patroniPort == 0 { + patroniPort = specPatroniPort + } + + nodePath := nodesPath.AppendArrayIndex(i) + for _, h := range node.HostIds { + hostID := string(h) + if _, ok := hostPorts[hostID]; !ok { + hostPorts[hostID] = validation.NewUnique[int]() + } + if postgresPort != 0 { + hostPorts[hostID].RecordSeen(nodePath.Append("port"), postgresPort) + } + if patroniPort != 0 { + hostPorts[hostID].RecordSeen(nodePath.Append("patroni_port"), patroniPort) + } + } } - return nil + servicesPath := validation.NewPath("services") + for i, service := range spec.Services { + servicePath := servicesPath.AppendArrayIndex(i) + + for _, h := range service.HostIds { + hostID := string(h) + if _, ok := hostPorts[hostID]; !ok { + hostPorts[hostID] = validation.NewUnique[int]() + } + if port := utils.FromPointer(service.Port); port != 0 { + hostPorts[hostID].RecordSeen(servicePath.Append("port"), port) + } + } + } + + var errs []error + for _, hostID := range slices.Sorted(maps.Keys(hostPorts)) { + errs = append(errs, hostPorts[hostID].Validate(fmt.Errorf("duplicate ports allocated on host '%s'", hostID))...) + } + + return errs } -func validateUsers(users []*api.DatabaseUserSpec, path []string) []error { +func validateUsers(users []*api.DatabaseUserSpec, path validation.Path) []error { var errs []error seenNames := ds.NewSet[string]() var hasOwner bool for i, user := range users { - userPath := appendPath(path, arrayIndexPath(i)) + userPath := path.AppendArrayIndex(i) if seenNames.Has(user.Username) { err := errors.New("usernames must be unique within a database") - errs = append(errs, newValidationError(err, userPath)) + errs = append(errs, validation.NewError(err, userPath)) } if user.DbOwner != nil && *user.DbOwner && hasOwner { err := errors.New("cannot have multiple users with db_owner = true") - errs = append(errs, newValidationError(err, userPath)) + errs = append(errs, validation.NewError(err, userPath)) } seenNames.Add(user.Username) @@ -611,83 +593,33 @@ func validateUsers(users []*api.DatabaseUserSpec, path []string) []error { return errs } -// seedPostgresPorts registers each node's effective Postgres port in the -// portOwner map so that service port validation can detect collisions with -// the database. A node-level port override (node.Port) takes precedence -// over the spec-level default (spec.Port). -func seedPostgresPorts(spec *api.DatabaseSpec, owner servicePortOwnerMap) { - for _, node := range spec.Nodes { - pgPort := utils.FromPointer(spec.Port) - if node.Port != nil { - pgPort = *node.Port - } - if pgPort > 0 { - for _, hostID := range node.HostIds { - owner[hostPort{hostID: string(hostID), port: pgPort}] = "postgres" - } - } - } -} - -// hostPort identifies a unique (host, port) binding for cross-service -// port conflict detection. -type hostPort struct { - hostID string - port int -} - -// servicePortOwnerMap tracks which service owns a given (host, port) pair. -// Callers create one map and pass it to validateServicePortConflicts for -// each service in the spec. -type servicePortOwnerMap map[hostPort]string - -// validateServicePortConflicts checks that the service's explicit port (if any) -// does not collide with a port already claimed by another service on the same host. -func validateServicePortConflicts(svc *api.ServiceSpec, path []string, owner servicePortOwnerMap) []error { - if svc.Port == nil || *svc.Port <= 0 { - return nil - } - - var errs []error - for _, hostID := range svc.HostIds { - key := hostPort{hostID: string(hostID), port: *svc.Port} - if prev, exists := owner[key]; exists { - err := fmt.Errorf("port %d conflicts with service %q on the same host", *svc.Port, prev) - errs = append(errs, newValidationError(err, appendPath(path, "port"))) - } else { - owner[key] = string(svc.ServiceID) - } - } - return errs -} - -func validateBackupConfig(cfg *api.BackupConfigSpec, path []string) []error { +func validateBackupConfig(cfg *api.BackupConfigSpec, path validation.Path) []error { var errs []error for i, repo := range cfg.Repositories { - repoPath := appendPath(path, "repositories", arrayIndexPath(i)) + repoPath := path.Append("repositories").AppendArrayIndex(i) errs = append(errs, validateBackupRepository(repo, repoPath)...) } return errs } -func validateRestoreConfig(cfg *api.RestoreConfigSpec, path []string) []error { +func validateRestoreConfig(cfg *api.RestoreConfigSpec, path validation.Path) []error { var errs []error - sourceDbIdPath := appendPath(path, "source_database_id") + sourceDbIdPath := path.Append("source_database_id") errs = append(errs, validateIdentifier(string(cfg.SourceDatabaseID), sourceDbIdPath)) - repoPath := appendPath(path, "repository") + repoPath := path.Append("repository") errs = append(errs, validateRestoreRepository(cfg.Repository, repoPath)...) - restoreOptsPath := appendPath(path, "restore_options") + restoreOptsPath := path.Append("restore_options") errs = append(errs, validatePgBackRestOptions(cfg.RestoreOptions, restoreOptsPath)...) return errs } -func validateBackupRepository(cfg *api.BackupRepositorySpec, path []string) []error { +func validateBackupRepository(cfg *api.BackupRepositorySpec, path validation.Path) []error { props := repoProperties{ id: cfg.ID, repoType: cfg.Type, @@ -704,7 +636,7 @@ func validateBackupRepository(cfg *api.BackupRepositorySpec, path []string) []er return validateRepoProperties(props, path) } -func validateRestoreRepository(cfg *api.RestoreRepositorySpec, path []string) []error { +func validateRestoreRepository(cfg *api.RestoreRepositorySpec, path validation.Path) []error { props := repoProperties{ id: cfg.ID, repoType: cfg.Type, @@ -734,12 +666,12 @@ type repoProperties struct { customOptions map[string]string } -func validateRepoProperties(props repoProperties, path []string) []error { +func validateRepoProperties(props repoProperties, path validation.Path) []error { var errs []error id := utils.FromPointer(props.id) if id != "" { - idPath := appendPath(path, "id") + idPath := path.Append("id") errs = append(errs, validateIdentifier(string(id), idPath)) } @@ -754,70 +686,70 @@ func validateRepoProperties(props repoProperties, path []string) []error { case pgbackrest.RepositoryTypeS3: errs = append(errs, validateS3RepoProperties(props, path)...) default: - err := newValidationError( + err := validation.NewError( fmt.Errorf("unsupported repo type '%s'", repoType), - appendPath(path, "type"), + path.Append("type"), ) errs = append(errs, err) } - customOptsPath := appendPath(path, "custom_options") + customOptsPath := path.Append("custom_options") errs = append(errs, validatePgBackRestOptions(props.customOptions, customOptsPath)...) return errs } -func validateAzureRepoProperties(props repoProperties, path []string) []error { +func validateAzureRepoProperties(props repoProperties, path validation.Path) []error { var errs []error if utils.FromPointer(props.azureAccount) == "" { err := errors.New("azure_account is required for azure repositories") - errs = append(errs, newValidationError(err, appendPath(path, "azure_account"))) + errs = append(errs, validation.NewError(err, path.Append("azure_account"))) } if utils.FromPointer(props.azureContainer) == "" { err := errors.New("azure_container is required for azure repositories") - errs = append(errs, newValidationError(err, appendPath(path, "azure_container"))) + errs = append(errs, validation.NewError(err, path.Append("azure_container"))) } if utils.FromPointer(props.azureKey) == "" { err := errors.New("azure_key is required for azure repositories") - errs = append(errs, newValidationError(err, appendPath(path, "azure_key"))) + errs = append(errs, validation.NewError(err, path.Append("azure_key"))) } return errs } -func validateFSRepoProperties(props repoProperties, path []string) []error { +func validateFSRepoProperties(props repoProperties, path validation.Path) []error { var errs []error basePath := utils.FromPointer(props.basePath) if basePath == "" { err := fmt.Errorf("base_path is required for %s repositories", props.repoType) - errs = append(errs, newValidationError(err, appendPath(path, "base_path"))) + errs = append(errs, validation.NewError(err, path.Append("base_path"))) } else if !filepath.IsAbs(*props.basePath) { err := fmt.Errorf("base_path must be absolute for %s repositories", props.repoType) - errs = append(errs, newValidationError(err, appendPath(path, "base_path"))) + errs = append(errs, validation.NewError(err, path.Append("base_path"))) } return errs } -func validateGCSRepoProperties(props repoProperties, path []string) []error { +func validateGCSRepoProperties(props repoProperties, path validation.Path) []error { var errs []error if utils.FromPointer(props.gcsBucket) == "" { err := errors.New("gcs_bucket is required for gcs repositories") - errs = append(errs, newValidationError(err, appendPath(path, "gcs_bucket"))) + errs = append(errs, validation.NewError(err, path.Append("gcs_bucket"))) } return errs } -func validateS3RepoProperties(props repoProperties, path []string) []error { +func validateS3RepoProperties(props repoProperties, path validation.Path) []error { var errs []error if utils.FromPointer(props.s3Bucket) == "" { err := errors.New("s3_bucket is required for s3 repositories") - errs = append(errs, newValidationError(err, appendPath(path, "s3_bucket"))) + errs = append(errs, validation.NewError(err, path.Append("s3_bucket"))) } return errs @@ -829,7 +761,7 @@ var semverPattern = regexp.MustCompile(`^\d+\.\d+(\.\d+)?$`) // reservedLabelPrefix is the label key prefix reserved for system use. const reservedLabelPrefix = "pgedge." -func validateOrchestratorOpts(opts *api.OrchestratorOpts, path []string) []error { +func validateOrchestratorOpts(opts *api.OrchestratorOpts, path validation.Path) []error { if opts == nil || opts.Swarm == nil { return nil } @@ -837,9 +769,9 @@ func validateOrchestratorOpts(opts *api.OrchestratorOpts, path []string) []error var errs []error for key := range opts.Swarm.ExtraLabels { if strings.HasPrefix(key, reservedLabelPrefix) { - labelPath := appendPath(path, "swarm", "extra_labels", mapKeyPath(key)) + labelPath := path.Append("swarm", "extra_labels").AppendMapKey(key) err := fmt.Errorf("labels starting with %q are reserved for system use", reservedLabelPrefix) - errs = append(errs, newValidationError(err, labelPath)) + errs = append(errs, validation.NewError(err, labelPath)) } } return errs @@ -848,7 +780,7 @@ func validateOrchestratorOpts(opts *api.OrchestratorOpts, path []string) []error // validateServiceOrchestratorOpts runs the shared orchestrator_opts checks and // adds service-specific restrictions. Services do not support extra_volumes // (bind mounts are configured per service type) or driver_opts on extra_networks. -func validateServiceOrchestratorOpts(opts *api.OrchestratorOpts, path []string) []error { +func validateServiceOrchestratorOpts(opts *api.OrchestratorOpts, path validation.Path) []error { errs := validateOrchestratorOpts(opts, path) if opts == nil || opts.Swarm == nil { @@ -857,28 +789,28 @@ func validateServiceOrchestratorOpts(opts *api.OrchestratorOpts, path []string) if len(opts.Swarm.ExtraVolumes) > 0 { err := errors.New("extra_volumes is not supported for services") - errs = append(errs, newValidationError(err, appendPath(path, "swarm", "extra_volumes"))) + errs = append(errs, validation.NewError(err, path.Append("swarm", "extra_volumes"))) } for i, net := range opts.Swarm.ExtraNetworks { if len(net.DriverOpts) > 0 { - netPath := appendPath(path, "swarm", "extra_networks", arrayIndexPath(i), "driver_opts") + netPath := path.Append("swarm", "extra_networks").AppendArrayIndex(i).Append("driver_opts") err := errors.New("driver_opts is not supported for services") - errs = append(errs, newValidationError(err, netPath)) + errs = append(errs, validation.NewError(err, netPath)) } } return errs } -func validatePgBackRestOptions(opts map[string]string, path []string) []error { +func validatePgBackRestOptions(opts map[string]string, path validation.Path) []error { var errs []error for key := range opts { if !pgBackRestOptionPattern.MatchString(key) { - optPath := appendPath(path, mapKeyPath(key)) + optPath := path.AppendMapKey(key) err := errors.New("invalid option name") - errs = append(errs, newValidationError(err, optPath)) + errs = append(errs, validation.NewError(err, optPath)) } } @@ -888,15 +820,15 @@ func validatePgBackRestOptions(opts map[string]string, path []string) []error { func validateBackupOptions(opts *api.BackupOptions) error { var errs []error - optsPath := []string{"backup_options"} + optsPath := validation.NewPath("backup_options") errs = append(errs, validatePgBackRestOptions(opts.BackupOptions, optsPath)...) return errors.Join(errs...) } -func validateIdentifier(ident string, path []string) error { +func validateIdentifier(ident string, path validation.Path) error { if err := utils.ValidateID(ident); err != nil { - return newValidationError(err, path) + return validation.NewError(err, path) } return nil @@ -919,20 +851,20 @@ func validateHostIDUniqueness(ctx context.Context, hostSvc *host.Service, hostID } } -func validateScripts(scripts *api.DatabaseScripts, path []string) []error { +func validateScripts(scripts *api.DatabaseScripts, path validation.Path) []error { if scripts == nil { return nil } return slices.Concat( - validateScript(scripts.PostInit, appendPath(path, "post_init")), - validateScript(scripts.PostDatabaseCreate, appendPath(path, "post_database_create")), + validateScript(scripts.PostInit, path.Append("post_init")), + validateScript(scripts.PostDatabaseCreate, path.Append("post_database_create")), ) } -func validateScript(statements []string, path []string) []error { +func validateScript(statements []string, path validation.Path) []error { var errs []error for i, statement := range statements { - statementPath := appendPath(path, arrayIndexPath(i)) + statementPath := path.AppendArrayIndex(i) if err := validateSQLStatement(statement, statementPath); err != nil { errs = append(errs, err) } @@ -940,11 +872,11 @@ func validateScript(statements []string, path []string) []error { return errs } -func validateSQLStatement(statement string, path []string) error { +func validateSQLStatement(statement string, path validation.Path) error { _, err := postgresparser.ParseSQLStrict(statement) if err != nil { err = fmt.Errorf("failed to parse SQL statement: %w", err) - return newValidationError(err, path) + return validation.NewError(err, path) } return nil } diff --git a/server/internal/api/apiv1/validate_test.go b/server/internal/api/apiv1/validate_test.go index 97f1a84e..b43aa701 100644 --- a/server/internal/api/apiv1/validate_test.go +++ b/server/internal/api/apiv1/validate_test.go @@ -12,25 +12,6 @@ import ( "github.com/stretchr/testify/assert" ) -func TestValidationError(t *testing.T) { - t.Run("with path", func(t *testing.T) { - err := newValidationError(errors.New("test error"), []string{ - "array", - arrayIndexPath(0), - "map", - mapKeyPath("key"), - }) - - assert.ErrorContains(t, err, "array[0].map[key]: test error") - }) - - t.Run("without path", func(t *testing.T) { - err := newValidationError(errors.New("test error"), nil) - - assert.ErrorContains(t, err, "test error") - }) -} - func TestValidateCPUs(t *testing.T) { for _, tc := range []struct { name string @@ -98,61 +79,124 @@ func TestValidateMemory(t *testing.T) { } } -func TestValidatePorts(t *testing.T) { +func TestValidateUniquePorts(t *testing.T) { for _, tc := range []struct { - name string - postgresPort *int - patroniPort *int - expected string + name string + spec *api.DatabaseSpec + expected []string }{ { - name: "both nil", - postgresPort: nil, - patroniPort: nil, - }, - { - name: "postgres port nil", - postgresPort: nil, - patroniPort: utils.PointerTo(8888), - }, - { - name: "patroni port nil", - postgresPort: utils.PointerTo(8888), - patroniPort: nil, - }, - { - name: "both zero", - postgresPort: utils.PointerTo(0), - patroniPort: utils.PointerTo(0), + name: "no ports", + spec: &api.DatabaseSpec{ + Nodes: []*api.DatabaseNodeSpec{ + { + Name: "n1", + HostIds: []api.Identifier{ + api.Identifier("host-1"), + }, + }, + }, + }, }, { - name: "postgres port zero", - postgresPort: nil, - patroniPort: utils.PointerTo(0), + name: "patroni and postgres port conflict", + spec: &api.DatabaseSpec{ + Port: utils.PointerTo(5432), + PatroniPort: utils.PointerTo(5432), + Nodes: []*api.DatabaseNodeSpec{ + { + Name: "n1", + HostIds: []api.Identifier{ + api.Identifier("host-1"), + }, + }, + }, + }, + expected: []string{ + `duplicate ports allocated on host 'host-1': '5432' duplicated in: nodes[0].patroni_port, nodes[0].port`, + }, }, { - name: "patroni port zero", - postgresPort: utils.PointerTo(0), - patroniPort: nil, + name: "patroni and postgres port conflict with override", + spec: &api.DatabaseSpec{ + Port: utils.PointerTo(8888), + PatroniPort: utils.PointerTo(8888), + Nodes: []*api.DatabaseNodeSpec{ + { + Name: "n1", + Port: utils.PointerTo(5432), + PatroniPort: utils.PointerTo(5432), + HostIds: []api.Identifier{ + api.Identifier("host-1"), + }, + }, + }, + }, + expected: []string{ + `duplicate ports allocated on host 'host-1': '5432' duplicated in: nodes[0].patroni_port, nodes[0].port`, + }, }, { - name: "both defined non-equal", - postgresPort: utils.PointerTo(5432), - patroniPort: utils.PointerTo(8888), + name: "service port conflict", + spec: &api.DatabaseSpec{ + Port: utils.PointerTo(5432), + Nodes: []*api.DatabaseNodeSpec{ + { + Name: "n1", + HostIds: []api.Identifier{ + api.Identifier("host-1"), + }, + }, + }, + Services: []*api.ServiceSpec{ + { + ServiceID: "mcp", + ServiceType: "mcp", + Port: utils.PointerTo(5432), + HostIds: []api.Identifier{ + api.Identifier("host-1"), + }, + }, + }, + }, + expected: []string{ + `duplicate ports allocated on host 'host-1': '5432' duplicated in: nodes[0].port, services[0].port`, + }, }, { - name: "conflicting", - postgresPort: utils.PointerTo(5432), - patroniPort: utils.PointerTo(5432), - expected: "postgres and patroni ports must not conflict", + name: "two nodes on same host", + spec: &api.DatabaseSpec{ + Port: utils.PointerTo(5432), + PatroniPort: utils.PointerTo(8888), + Nodes: []*api.DatabaseNodeSpec{ + { + Name: "n1", + HostIds: []api.Identifier{ + api.Identifier("host-1"), + }, + }, + { + Name: "n2", + HostIds: []api.Identifier{ + api.Identifier("host-1"), + }, + }, + }, + }, + expected: []string{ + `duplicate ports allocated on host 'host-1': '5432' duplicated in: nodes[0].port, nodes[1].port`, + `duplicate ports allocated on host 'host-1': '8888' duplicated in: nodes[0].patroni_port, nodes[1].patroni_port`, + }, }, } { t.Run(tc.name, func(t *testing.T) { - err := validatePorts(tc.postgresPort, tc.patroniPort, nil) - if tc.expected == "" { + err := errors.Join(validateUniquePorts(tc.spec)...) + if len(tc.expected) < 1 { assert.NoError(t, err) } else { - assert.ErrorContains(t, err, tc.expected) + for _, expected := range tc.expected { + assert.ErrorContains(t, err, expected) + } } }) } @@ -493,71 +537,6 @@ func TestValidateNode(t *testing.T) { "patroni_port: patroni_port must be defined", }, }, - { - name: "invalid inherited ports", - orchestrator: config.OrchestratorSystemD, - db: &api.DatabaseSpec{ - Port: utils.PointerTo(5432), - PatroniPort: utils.PointerTo(5432), - }, - node: &api.DatabaseNodeSpec{ - HostIds: []api.Identifier{ - api.Identifier("host-1"), - }, - }, - expected: []string{ - "port: postgres and patroni ports must not conflict", - }, - }, - { - name: "invalid inherited db port", - orchestrator: config.OrchestratorSystemD, - db: &api.DatabaseSpec{ - Port: utils.PointerTo(5432), - PatroniPort: utils.PointerTo(8888), - }, - node: &api.DatabaseNodeSpec{ - PatroniPort: utils.PointerTo(5432), - HostIds: []api.Identifier{ - api.Identifier("host-1"), - }, - }, - expected: []string{ - "port: postgres and patroni ports must not conflict", - }, - }, - { - name: "invalid inherited patroni port", - orchestrator: config.OrchestratorSystemD, - db: &api.DatabaseSpec{ - Port: utils.PointerTo(5432), - PatroniPort: utils.PointerTo(8888), - }, - node: &api.DatabaseNodeSpec{ - Port: utils.PointerTo(8888), - HostIds: []api.Identifier{ - api.Identifier("host-1"), - }, - }, - expected: []string{ - "port: postgres and patroni ports must not conflict", - }, - }, - { - name: "invalid node ports", - orchestrator: config.OrchestratorSystemD, - db: &api.DatabaseSpec{}, - node: &api.DatabaseNodeSpec{ - Port: utils.PointerTo(5432), - PatroniPort: utils.PointerTo(5432), - HostIds: []api.Identifier{ - api.Identifier("host-1"), - }, - }, - expected: []string{ - "port: postgres and patroni ports must not conflict", - }, - }, { name: "invalid", orchestrator: config.OrchestratorSwarm, @@ -802,7 +781,7 @@ func TestValidateDatabaseSpec(t *testing.T) { }, }, expected: []string{ - `services[1].port: port 8080 conflicts with service "mcp-server" on the same host`, + `duplicate ports allocated on host 'host-1': '8080' duplicated in: services[0].port, services[1].port`, }, }, { @@ -991,7 +970,7 @@ func TestValidateDatabaseSpec(t *testing.T) { }, }, expected: []string{ - `port 5432 conflicts with service "postgres" on the same host`, + `duplicate ports allocated on host 'host-1': '5432' duplicated in: nodes[0].port, services[0].port`, }, }, { @@ -1027,7 +1006,7 @@ func TestValidateDatabaseSpec(t *testing.T) { }, }, expected: []string{ - `port 5433 conflicts with service "postgres" on the same host`, + `duplicate ports allocated on host 'host-1': '5433' duplicated in: nodes[0].port, services[0].port`, }, }, { @@ -1177,8 +1156,9 @@ func TestValidateDatabaseSpec(t *testing.T) { "nodes[2]: node names must be unique within a database", "backup_config.repositories[0].base_path: base_path must be absolute for posix repositories", "restore_config.repository.base_path: base_path must be absolute for posix repositories", - "port: postgres and patroni ports must not conflict", - "nodes[1].port: postgres and patroni ports must not conflict", + `duplicate ports allocated on host 'host-1': '5432' duplicated in: nodes[0].patroni_port, nodes[0].port`, + `duplicate ports allocated on host 'host-2': '8888' duplicated in: nodes[1].patroni_port, nodes[1].port`, + `duplicate ports allocated on host 'host-3': '5432' duplicated in: nodes[2].patroni_port, nodes[2].port`, }, }, { @@ -2342,38 +2322,6 @@ func TestValidateDatabaseUpdate_ServiceBootstrapFields(t *testing.T) { }, }, }, - { - name: "port conflict on update-database", - old: &database.Spec{}, - new: &api.DatabaseSpec{ - DatabaseUsers: []*api.DatabaseUserSpec{ - {Username: "app", DbOwner: utils.PointerTo(true)}, - }, - Services: []*api.ServiceSpec{ - { - ServiceID: "mcp-server", - ServiceType: "mcp", - Version: "latest", - HostIds: []api.Identifier{"host-1"}, - ConnectAs: "app", - Port: utils.PointerTo(8080), - Config: validMCPConfig, - }, - { - ServiceID: "postgrest-server", - ServiceType: "postgrest", - Version: "latest", - HostIds: []api.Identifier{"host-1"}, - ConnectAs: "app", - Port: utils.PointerTo(8080), - Config: map[string]any{}, - }, - }, - }, - expected: []string{ - `port 8080 conflicts with service "mcp-server" on the same host`, - }, - }, } { t.Run(tc.name, func(t *testing.T) { err := validateDatabaseUpdate(tc.old, tc.new) diff --git a/server/internal/ds/set.go b/server/internal/ds/set.go index ef829c77..004a77a5 100644 --- a/server/internal/ds/set.go +++ b/server/internal/ds/set.go @@ -1,7 +1,9 @@ package ds import ( + "cmp" "slices" + "strings" ) // Set is a generic set type. @@ -131,3 +133,17 @@ func SetDifference[T comparable](a, b []T) Set[T] { func SetSymmetricDifference[T comparable](a, b []T) Set[T] { return NewSet(a...).SymmetricDifference(NewSet(b...)) } + +// SetToString is a shortcut for producing a sorted, comma-separated string +// representation of a Set of string-ish values. +func SetToString[T ~string](s Set[T]) string { + lastIdx := s.Size() - 1 + var builder strings.Builder + for i, element := range s.ToSortedSlice(cmp.Compare) { + builder.WriteString(string(element)) + if i < lastIdx { + builder.WriteString(", ") + } + } + return builder.String() +} diff --git a/server/internal/validation/error.go b/server/internal/validation/error.go new file mode 100644 index 00000000..4f6e1c25 --- /dev/null +++ b/server/internal/validation/error.go @@ -0,0 +1,64 @@ +package validation + +import ( + "fmt" + "slices" + "strings" +) + +type Path []string + +func NewPath(elems ...string) Path { + return elems +} + +func ArrayIndexElement(idx int) string { + return fmt.Sprintf("[%d]", idx) +} + +func MapKeyElement(key string) string { + return fmt.Sprintf("[%s]", key) +} + +func (p Path) String() string { + var path strings.Builder + for i, ele := range p { + if i > 0 && !strings.HasPrefix(ele, "[") { + path.WriteString(".") + } + path.WriteString(ele) + } + return path.String() +} + +func (p Path) Append(elem ...string) Path { + return append(slices.Clone(p), elem...) +} + +func (p Path) AppendArrayIndex(idx int) Path { + return p.Append(ArrayIndexElement(idx)) +} + +func (p Path) AppendMapKey(key string) Path { + return p.Append(MapKeyElement(key)) +} + +type Error struct { + Path Path + Err error +} + +func NewError(err error, path Path) *Error { + return &Error{Err: err, Path: path} +} + +func (e *Error) Unwrap() error { + return e.Err +} + +func (e *Error) Error() string { + if len(e.Path) == 0 { + return e.Err.Error() + } + return fmt.Sprintf("%s: %s", e.Path.String(), e.Err.Error()) +} diff --git a/server/internal/validation/error_test.go b/server/internal/validation/error_test.go new file mode 100644 index 00000000..6e24d84e --- /dev/null +++ b/server/internal/validation/error_test.go @@ -0,0 +1,28 @@ +package validation_test + +import ( + "errors" + "testing" + + "github.com/pgEdge/control-plane/server/internal/validation" + "github.com/stretchr/testify/assert" +) + +func TestValidationError(t *testing.T) { + t.Run("with path", func(t *testing.T) { + err := validation.NewError(errors.New("test error"), validation.NewPath( + "array", + validation.ArrayIndexElement(0), + "map", + validation.MapKeyElement("key"), + )) + + assert.ErrorContains(t, err, "array[0].map[key]: test error") + }) + + t.Run("without path", func(t *testing.T) { + err := validation.NewError(errors.New("test error"), nil) + + assert.ErrorContains(t, err, "test error") + }) +} diff --git a/server/internal/validation/validators.go b/server/internal/validation/validators.go new file mode 100644 index 00000000..7603b012 --- /dev/null +++ b/server/internal/validation/validators.go @@ -0,0 +1,50 @@ +package validation + +import ( + "cmp" + "errors" + "fmt" + "maps" + "slices" + + "github.com/pgEdge/control-plane/server/internal/ds" +) + +var ErrUnique = errors.New("must be unique") + +type Unique[T cmp.Ordered] struct { + seen map[T]ds.Set[string] +} + +func NewUnique[T cmp.Ordered]() *Unique[T] { + return &Unique[T]{ + seen: map[T]ds.Set[string]{}, + } +} + +func (u *Unique[T]) RecordSeen(path Path, value T) { + if u.seen == nil { + u.seen = make(map[T]ds.Set[string]) + } + if _, ok := u.seen[value]; !ok { + u.seen[value] = ds.NewSet[string]() + } + u.seen[value].Add(path.String()) +} + +func (u *Unique[T]) Validate(base error) []error { + if base == nil { + base = ErrUnique + } + var errs []error + for _, key := range slices.Sorted(maps.Keys(u.seen)) { + paths := u.seen[key] + if len(paths) <= 1 { + continue + } + errs = append(errs, &Error{ + Err: fmt.Errorf("%w: '%v' duplicated in: %s", base, key, ds.SetToString(paths)), + }) + } + return errs +}