Skip to content
17 changes: 17 additions & 0 deletions sagemaker-core/sample/sagemaker/2017-07-24/service-2.json
Original file line number Diff line number Diff line change
Expand Up @@ -44132,6 +44132,10 @@
"EvaluatorArn":{
"shape":"EvaluatorArn",
"documentation":"<p> The evaluator Amazon Resource Name (ARN) used as reward function or reward prompt. </p>"
},
"SequenceLength":{
"shape":"SequenceLength",
"documentation":"<p> The sequence length for the training job. </p>"
}
},
"documentation":"<p> The configuration for the serverless training job. </p>"
Expand All @@ -44143,6 +44147,19 @@
"Evaluation"
]
},
"SequenceLength":{
"type":"string",
"enum":[
"1K",
"2K",
"4K",
"8K",
"16K",
"32K",
"64K",
"128K"
]
},
"ServerlessMaxConcurrency":{
"type":"integer",
"box":true,
Expand Down
2 changes: 2 additions & 0 deletions sagemaker-core/src/sagemaker/core/shapes/shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9717,6 +9717,7 @@ class ServerlessJobConfig(Base):
peft: The parameter-efficient fine-tuning configuration.
evaluation_type: The evaluation job type. Required when serverless job type is Evaluation.
evaluator_arn: The evaluator Amazon Resource Name (ARN) used as reward function or reward prompt.
sequence_length: The sequence length for the training job.
"""

base_model_arn: StrPipeVar
Expand All @@ -9726,6 +9727,7 @@ class ServerlessJobConfig(Base):
peft: Optional[StrPipeVar] = Unassigned()
evaluation_type: Optional[StrPipeVar] = Unassigned()
evaluator_arn: Optional[StrPipeVar] = Unassigned()
sequence_length: Optional[StrPipeVar] = Unassigned()


class MlflowConfig(Base):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16206,6 +16206,7 @@
{"name": "Peft", "shape": "Peft", "type": "string"},
{"name": "EvaluationType", "shape": "EvaluationType", "type": "string"},
{"name": "EvaluatorArn", "shape": "EvaluatorArn", "type": "string"},
{"name": "SequenceLength", "shape": "SequenceLength", "type": "string"},
],
"type": "structure",
},
Expand Down
66 changes: 62 additions & 4 deletions sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,10 +407,44 @@ def _resolve_model_package_arn(model_package) -> Optional[str]:
return None


def _get_fine_tuning_options_and_model_arn(model_name: str, customization_technique: str, training_type, sagemaker_session,
hub_name: Optional[str] = None) -> tuple:
def _parse_context_length(value) -> int:
"""Parse a context length value like '8K', '32K', '128K' into an integer (e.g., 8192).

Returns 0 if value is None or unparseable.
"""
if not value:
return 0
value = str(value).strip().upper()
if value.endswith("K"):
try:
return int(value[:-1]) * 1024
except ValueError:
return 0
try:
return int(value)
except ValueError:
return 0


def _get_fine_tuning_options_and_model_arn(
model_name: str,
customization_technique: str,
training_type,
sagemaker_session,
sequence_length=None,
hub_name: Optional[str] = None
) -> tuple:
"""Get fine-tuning options and model ARN for given customization technique.

Args:
model_name: Name of the model in the hub.
customization_technique: Technique (e.g., "SFT", "DPO", "RLVR", "RLAIF").
training_type: TrainingType enum or string ("LORA", "FULL").
sagemaker_session: SageMaker session for API calls.
sequence_length: Optional sequence length (e.g., "8K"). When provided, filters
recipes by MaxContextLength >= the requested value.
hub_name: Hub name (default: "SageMakerPublicHub").

Returns:
tuple: (FineTuningOptions, model_arn, is_gated_model)
"""
Expand Down Expand Up @@ -447,6 +481,27 @@ def _get_fine_tuning_options_and_model_arn(model_name: str, customization_techni
if not recipes_with_template:
raise ValueError(f"No recipes found with Smtj for technique: {customization_technique}")

# Filter by SequenceLength before recipe selection if sequence_length is requested
if sequence_length:
requested = _parse_context_length(sequence_length)
candidates_with_context = [r for r in recipes_with_template if r.get("SequenceLength")]
if candidates_with_context:
filtered = [r for r in candidates_with_context if _parse_context_length(r.get("SequenceLength")) >= requested]
if filtered:
filtered.sort(key=lambda r: _parse_context_length(r.get("SequenceLength")))
recipes_with_template = filtered
else:
available = sorted(set(r.get("SequenceLength") for r in candidates_with_context))
raise ValueError(
f"No recipes found with SequenceLength >= {sequence_length}. "
f"Available sequence lengths: {available}"
)
else:
raise ValueError(
f"No recipes found with Smtj for technique: {customization_technique},training_type:{training_type}, "
f"and sequence length:{sequence_length}"
)

# Select recipe based on training type
# Collect override_params from ALL matching recipes (standard + subscription)
recipe = None
Expand Down Expand Up @@ -608,7 +663,8 @@ def _resolve_model_and_name(model, sagemaker_session=None):


def _create_serverless_config(model_arn, customization_technique,
training_type, accept_eula, evaluator_arn=None, job_type=JOB_TYPE) -> Optional['ServerlessJobConfig']:
training_type, accept_eula, evaluator_arn=None,
sequence_length=None, job_type=JOB_TYPE) -> Optional['ServerlessJobConfig']:
"""Create serverless job configuration for fine-tuning.

Args:
Expand All @@ -617,6 +673,7 @@ def _create_serverless_config(model_arn, customization_technique,
training_type: Training type (TrainingType enum or string)
accept_eula: Boolean indicating if EULA is accepted
evaluator_arn: Optional evaluator ARN for RLVR/RLAIF
sequence_length: Optional sequence length enum value (e.g., "1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K")
job_type: Type of job (default: "FineTuning")

Returns:
Expand All @@ -632,7 +689,8 @@ def _create_serverless_config(model_arn, customization_technique,
customization_technique=customization_technique,
peft=peft,
evaluator_arn=evaluator_arn,
accept_eula=accept_eula
accept_eula=accept_eula,
sequence_length=sequence_length,
)

return serverless_config
Expand Down
36 changes: 22 additions & 14 deletions sagemaker-train/src/sagemaker/train/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ class DPOTrainer(BaseTrainer):
stopping_condition (Optional[StoppingCondition]):
The stopping condition to override training runtime limit.
If not specified, uses SageMaker service default (24 hours for serverless training).
sequence_length (Optional[str]):
The sequence length for the training job. Valid values are
"1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K".
If not specified, the service will use default recipe selection behavior.
"""
def __init__(
self,
Expand All @@ -116,6 +120,7 @@ def __init__(
networking: Optional[VpcConfig] = None,
accept_eula: bool = False,
stopping_condition: Optional[StoppingCondition] = None,
sequence_length: Optional[str] = None,
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -134,16 +139,17 @@ def __init__(
self.kms_key_id = kms_key_id
self.networking = networking
self.stopping_condition = stopping_condition
self.sequence_length = sequence_length

# Initialize fine-tuning options with beta session fallback
self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(self._model_name,
CustomizationTechnique.DPO.value,
self.training_type,
self.sagemaker_session or TrainDefaults.get_sagemaker_session(
sagemaker_session=self.sagemaker_session

))
self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(
self._model_name,
CustomizationTechnique.DPO.value,
self.training_type,
self.sagemaker_session or TrainDefaults.get_sagemaker_session(sagemaker_session=self.sagemaker_session),
sequence_length=self.sequence_length
)

# Process hyperparameters
self._process_hyperparameters()

Expand Down Expand Up @@ -227,12 +233,14 @@ def train(self,
kms_key_id=self.kms_key_id
)

serverless_config = _create_serverless_config(model_arn=self._model_arn,
customization_technique=CustomizationTechnique.DPO.value,
training_type=self.training_type,
accept_eula=self.accept_eula,
job_type=JOB_TYPE
)
serverless_config = _create_serverless_config(
model_arn=self._model_arn,
customization_technique=CustomizationTechnique.DPO.value,
training_type=self.training_type,
accept_eula=self.accept_eula,
sequence_length=self.sequence_length,
job_type=JOB_TYPE
)

mlflow_config = _create_mlflow_config(
sagemaker_session,
Expand Down
35 changes: 22 additions & 13 deletions sagemaker-train/src/sagemaker/train/rlaif_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ class RLAIFTrainer(BaseTrainer):
stopping_condition (Optional[StoppingCondition]):
The stopping condition to override training runtime limit.
If not specified, uses SageMaker service default (24 hours for serverless training).
sequence_length (Optional[str]):
The sequence length for the training job. Valid values are
"1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K".
If not specified, the service will use default recipe selection behavior.
"""

def __init__(
Expand All @@ -135,6 +139,7 @@ def __init__(
networking: Optional[VpcConfig] = None,
accept_eula: bool = False,
stopping_condition: Optional[StoppingCondition] = None,
sequence_length: Optional[str] = None,
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -156,14 +161,16 @@ def __init__(
self.kms_key_id = kms_key_id
self.networking = networking
self.stopping_condition = stopping_condition
self.sequence_length = sequence_length

# Initialize fine-tuning options with beta session fallback
self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(self._model_name,
CustomizationTechnique.RLAIF.value,
self.training_type,
self.sagemaker_session or TrainDefaults.get_sagemaker_session(
sagemaker_session=self.sagemaker_session
))
self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(
self._model_name,
CustomizationTechnique.RLAIF.value,
self.training_type,
self.sagemaker_session or TrainDefaults.get_sagemaker_session(sagemaker_session=self.sagemaker_session),
sequence_length=self.sequence_length
)

# Validate and set EULA acceptance
self.accept_eula = _validate_eula_for_gated_model(model, accept_eula, is_gated_model)
Expand Down Expand Up @@ -242,13 +249,15 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
)

evaluator_arn = getattr(self, '_evaluator_arn', None)
serverless_config = _create_serverless_config(model_arn=self._model_arn,
customization_technique=CustomizationTechnique.RLAIF.value,
training_type=self.training_type,
accept_eula=self.accept_eula,
evaluator_arn=evaluator_arn,
job_type=JOB_TYPE
)
serverless_config = _create_serverless_config(
model_arn=self._model_arn,
customization_technique=CustomizationTechnique.RLAIF.value,
training_type=self.training_type,
accept_eula=self.accept_eula,
evaluator_arn=evaluator_arn,
sequence_length=self.sequence_length,
job_type=JOB_TYPE
)

mlflow_config = _create_mlflow_config(
sagemaker_session,
Expand Down
37 changes: 23 additions & 14 deletions sagemaker-train/src/sagemaker/train/rlvr_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ class RLVRTrainer(BaseTrainer):
stopping_condition (Optional[StoppingCondition]):
The stopping condition to override training runtime limit.
If not specified, uses SageMaker service default (24 hours for serverless training).
sequence_length (Optional[str]):
The sequence length for the training job. Valid values are
"1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K".
If not specified, the service will use default recipe selection behavior.
"""

def __init__(
Expand All @@ -126,6 +130,7 @@ def __init__(
networking: Optional[VpcConfig] = None,
accept_eula: bool = False,
stopping_condition: Optional[StoppingCondition] = None,
sequence_length: Optional[str] = None,
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -146,15 +151,17 @@ def __init__(
self.kms_key_id = kms_key_id
self.networking = networking
self.stopping_condition = stopping_condition
self.sequence_length = sequence_length

# Initialize fine-tuning options with beta session fallback
self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(self._model_name,
CustomizationTechnique.RLVR.value,
self.training_type,
self.sagemaker_session or TrainDefaults.get_sagemaker_session(
sagemaker_session=self.sagemaker_session
))

self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(
self._model_name,
CustomizationTechnique.RLVR.value,
self.training_type,
self.sagemaker_session or TrainDefaults.get_sagemaker_session(sagemaker_session=self.sagemaker_session),
sequence_length=self.sequence_length
)

# Remove constructor-handled hyperparameters
self._process_hyperparameters()

Expand Down Expand Up @@ -233,13 +240,15 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None,

# Extract and validate evaluator ARN
evaluator_arn = _extract_evaluator_arn(self.custom_reward_function) if self.custom_reward_function else None
serverless_config = _create_serverless_config(model_arn=self._model_arn,
customization_technique=CustomizationTechnique.RLVR.value,
training_type=self.training_type,
accept_eula=self.accept_eula,
evaluator_arn=evaluator_arn,
job_type=JOB_TYPE
)
serverless_config = _create_serverless_config(
model_arn=self._model_arn,
customization_technique=CustomizationTechnique.RLVR.value,
training_type=self.training_type,
accept_eula=self.accept_eula,
evaluator_arn=evaluator_arn,
sequence_length=self.sequence_length,
job_type=JOB_TYPE
)
mlflow_config = _create_mlflow_config(
sagemaker_session,
mlflow_resource_arn=self.mlflow_resource_arn,
Expand Down
Loading
Loading