Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 104 additions & 33 deletions pkg/db/sql_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,62 @@ var statusFieldMappings = map[string]string{
"status.conditions": "status_conditions",
}

// OrderByAllowedFields defines valid columns for orderBy and search operations per resource type.
// When adding new GORM columns to model structs, update this map to make them sortable/searchable.
var OrderByAllowedFields = map[string]map[string]bool{
"Cluster": {
"id": true,
"name": true,
"created_time": true,
"updated_time": true,
"deleted_time": true,
"kind": true,
"created_by": true,
"updated_by": true,
"deleted_by": true,
"generation": true,
"href": true,
"status_conditions": true, // mapped from status.conditions
},
"NodePool": {
"id": true,
"name": true,
"created_time": true,
"updated_time": true,
"deleted_time": true,
"kind": true,
"created_by": true,
"updated_by": true,
"deleted_by": true,
"generation": true,
"href": true,
"owner_id": true,
"owner_kind": true,
"owner_href": true,
"status_conditions": true,
},
"Resource": {
"id": true,
"name": true,
"created_time": true,
"updated_time": true,
"deleted_time": true,
"kind": true,
"created_by": true,
"updated_by": true,
"deleted_by": true,
"generation": true,
"href": true,
"owner_id": true,
"owner_kind": true,
"owner_href": true,
},
}

// getField gets the sql field associated with a name.
func getField(name string, disallowedFields map[string]string) (field string, err *errors.ServiceError) {
func getField(
name string, disallowedFields map[string]string, allowedFields map[string]bool,
) (field string, err *errors.ServiceError) {
// We want to accept names with trailing and leading spaces
trimmedName := strings.Trim(name, " ")

Expand Down Expand Up @@ -126,6 +180,13 @@ func getField(name string, disallowedFields map[string]string) (field string, er
err = errors.BadRequest("%s is not a valid field name", name)
return
}

// Validate field name against allowlist to prevent SQL injection
if allowedFields != nil && !allowedFields[trimmedName] {
err = errors.BadRequest("field '%s' is not valid for ordering or searching", name)
return
}

field = trimmedName
return
}
Expand Down Expand Up @@ -569,7 +630,8 @@ func extractConditionsWalk(n tsl.Node, conditions *[]sq.Sqlizer) (tsl.Node, *err
// b. replace the field name with the SQL column name.
func FieldNameWalk(
n tsl.Node,
disallowedFields map[string]string) (newNode tsl.Node, err *errors.ServiceError) {
disallowedFields map[string]string,
allowedFields map[string]bool) (newNode tsl.Node, err *errors.ServiceError) {

var field string
var l, r tsl.Node
Expand All @@ -591,7 +653,7 @@ func FieldNameWalk(
}

// Check field name in the disallowedFields field names.
field, err = getField(userFieldName, disallowedFields)
field, err = getField(userFieldName, disallowedFields, allowedFields)
if err != nil {
return
}
Expand All @@ -609,7 +671,7 @@ func FieldNameWalk(
err = errors.BadRequest("invalid node structure")
return
}
l, err = FieldNameWalk(leftNode, disallowedFields)
l, err = FieldNameWalk(leftNode, disallowedFields, allowedFields)
if err != nil {
return
}
Expand All @@ -620,7 +682,7 @@ func FieldNameWalk(
switch v := n.Right.(type) {
case tsl.Node:
// It's a regular node, just add it.
r, err = FieldNameWalk(v, disallowedFields)
r, err = FieldNameWalk(v, disallowedFields, allowedFields)
if err != nil {
return
}
Expand All @@ -634,7 +696,7 @@ func FieldNameWalk(

// Add all nodes in the right side array.
for _, e := range v {
r, err = FieldNameWalk(e, disallowedFields)
r, err = FieldNameWalk(e, disallowedFields, allowedFields)
if err != nil {
return
}
Expand All @@ -661,51 +723,60 @@ func FieldNameWalk(
}

// cleanOrderBy takes the orderBy arg and cleans it.
func cleanOrderBy(userArg string, disallowedFields map[string]string) (orderBy string, err *errors.ServiceError) {
var orderField string

func cleanOrderBy(
userArg string, disallowedFields map[string]string, allowedFields map[string]bool,
) (orderBy string, err *errors.ServiceError) {
// We want to accept user params with trailing and leading spaces
trimedName := strings.Trim(userArg, " ")
trimmedName := strings.Trim(userArg, " ")

// Each OrderBy can be a "<field-name>" or a "<field-name> asc|desc"
order := strings.Split(trimedName, " ")
direction := "none valid"
order := strings.Split(trimmedName, " ")

if len(order) == 1 {
orderField, err = getField(order[0], disallowedFields)
direction = "asc"
} else if len(order) == 2 {
orderField, err = getField(order[0], disallowedFields)
direction = order[1]
}
if err != nil || (direction != "asc" && direction != "desc") {
// Reject invalid format (e.g., subqueries with multiple spaces)
if len(order) != 1 && len(order) != 2 {
err = errors.BadRequest("bad order value '%s'", userArg)
return
}

// Validate field name
orderField, err := getField(order[0], disallowedFields, allowedFields)
if err != nil {
return "", err
}

// Determine direction (default to asc)
direction := "asc"
if len(order) == 2 {
direction = strings.ToLower(order[1])
if direction != "asc" && direction != "desc" {
err = errors.BadRequest("bad order value '%s'", userArg)
return
}
}

orderBy = fmt.Sprintf("%s %s", orderField, direction)
return
}

// ArgsToOrderBy returns cleaned orderBy list.
func ArgsToOrderBy(
orderByArgs []string,
disallowedFields map[string]string) (orderBy []string, err *errors.ServiceError) {

var order string
if len(orderByArgs) != 0 {
orderBy = []string{}
for _, o := range orderByArgs {
order, err = cleanOrderBy(o, disallowedFields)
if err != nil {
return
}
disallowedFields map[string]string,
allowedFields map[string]bool,
) (orderBy []string, err *errors.ServiceError) {
if len(orderByArgs) == 0 {
return nil, nil
}

// If valid add the user entered order by, to the order by list
orderBy = append(orderBy, order)
orderBy = make([]string, 0, len(orderByArgs))
for _, arg := range orderByArgs {
order, err := cleanOrderBy(arg, disallowedFields, allowedFields)
if err != nil {
return nil, err
}
orderBy = append(orderBy, order)
}
return
return orderBy, nil
}

func GetTableName(g2 *gorm.DB) string {
Expand Down
4 changes: 2 additions & 2 deletions pkg/db/sql_helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -703,7 +703,7 @@ func TestGetField_SpecMapping(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
RegisterTestingT(t)

field, err := getField(tt.input, map[string]string{})
field, err := getField(tt.input, map[string]string{}, nil)
if tt.expectError {
Expect(err).ToNot(BeNil())
} else {
Expand All @@ -719,7 +719,7 @@ func TestGetField_SpecDisallowed(t *testing.T) {

disallowed := map[string]string{"spec": "spec"}

_, err := getField("spec.is_default", disallowed)
_, err := getField("spec.is_default", disallowed, nil)
Expect(err).ToNot(BeNil())
Expect(err.Reason).To(ContainSubstring("not a valid field name"))
}
Expand Down
13 changes: 11 additions & 2 deletions pkg/services/generic.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ type listContext struct {
args *ListArguments
pagingMeta *api.PagingMeta
disallowedFields *map[string]string
allowedFields *map[string]bool
joins map[string]dao.TableRelation
set map[string]bool
resourceType string
Expand All @@ -79,13 +80,21 @@ func (s *sqlGenericService) newListContext(
if disallowedFields == nil {
disallowedFields = allFieldsAllowed
}

allowedFields := db.OrderByAllowedFields[resourceTypeStr]

if allowedFields == nil {
return nil, nil, errors.GeneralError("Could not determine what resource type to order by")
}

args.Search = strings.Trim(args.Search, " ")
return &listContext{
ctx: ctx,
args: args,
pagingMeta: &api.PagingMeta{Page: args.Page},
resourceList: resourceList,
disallowedFields: &disallowedFields,
allowedFields: &allowedFields,
resourceType: resourceTypeStr,
}, reflect.New(resourceModel).Interface(), nil
}
Expand Down Expand Up @@ -150,7 +159,7 @@ func (s *sqlGenericService) buildPreload(listCtx *listContext, d *dao.GenericDao

func (s *sqlGenericService) buildOrderBy(listCtx *listContext, d *dao.GenericDao) (bool, *errors.ServiceError) {
if len(listCtx.args.OrderBy) != 0 {
orderByArgs, serviceErr := db.ArgsToOrderBy(listCtx.args.OrderBy, *listCtx.disallowedFields)
orderByArgs, serviceErr := db.ArgsToOrderBy(listCtx.args.OrderBy, *listCtx.disallowedFields, *listCtx.allowedFields)
if serviceErr != nil {
return false, serviceErr
}
Expand Down Expand Up @@ -196,7 +205,7 @@ func (s *sqlGenericService) buildSearchValues(

// apply field name mapping first (status.xxx -> status_xxx, labels.xxx -> labels->>'xxx')
// this must happen before treeWalkForRelatedTables to prevent treating "status" and "labels" as related resources
tslTree, serviceErr = db.FieldNameWalk(tslTree, *listCtx.disallowedFields)
tslTree, serviceErr = db.FieldNameWalk(tslTree, *listCtx.disallowedFields, *listCtx.allowedFields)
if serviceErr != nil {
return "", nil, serviceErr
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/services/generic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ func TestSQLTranslation(t *testing.T) {
// tests for sql parsing
tests = []map[string]interface{}{
{
"search": "username in ('ooo.openshift')",
"sql": "username IN (?)",
"search": "created_by in ('ooo.openshift')",
"sql": "created_by IN (?)",
"values": ConsistOf("ooo.openshift"),
},
// Test status.conditions field mapping (use status.conditions.<Type>='<Status>' syntax for condition queries)
Expand Down Expand Up @@ -85,7 +85,7 @@ func TestSQLTranslation(t *testing.T) {
Expect(err).ToNot(HaveOccurred())
// Apply field name mapping (status.xxx -> status_xxx, labels.xxx -> labels->>'xxx')
// This must happen before converting to sqlizer
tslTree, serviceErr = db.FieldNameWalk(tslTree, *listCtx.disallowedFields)
tslTree, serviceErr = db.FieldNameWalk(tslTree, *listCtx.disallowedFields, *listCtx.allowedFields)
Expect(serviceErr).ToNot(HaveOccurred())
sqlizer, serviceErr := genericService.treeWalkForSqlizer(listCtx, tslTree)
Expect(serviceErr).ToNot(HaveOccurred())
Expand Down
Loading