Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sagemaker-core/src/sagemaker/core/user_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,5 @@ def get_user_agent_extra_suffix():
suffix = "{} md/{}#{}".format(suffix, STUDIO_PREFIX, studio_app_type)

return suffix

# Trigger PR check: run full integ test suite.
54 changes: 52 additions & 2 deletions sagemaker-serve/tests/integ/test_bedrock_provisioned_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import time
import random
import logging
from datetime import datetime, timezone, timedelta
from urllib.parse import urlparse

import boto3
Expand All @@ -43,10 +44,59 @@ def role_arn():
return get_execution_role()


# Prefix used for all provisioned throughputs created by this test module.
PT_TEST_PREFIX = "test-pt-integ-"
# Provisioned throughputs older than this are considered leaked and reaped on setup.
PT_STALE_AGE = timedelta(hours=2)


@pytest.fixture(scope="module")
def bedrock_client():
"""Create Bedrock client."""
return boto3.client("bedrock", region_name=AWS_REGION)
"""Create Bedrock client and eagerly reap leaked test provisioned throughputs.

Provisioned throughputs cost money and consume a small, easily-exhausted
model-unit quota. A test process killed before its teardown runs (CodeBuild
timeout, worker crash, etc.) leaks its PT, and these accumulate across runs
until the quota is full and CreateProvisionedModelThroughput starts failing.

To stay self-healing, on setup we delete any ``test-pt-integ-*`` PT older
than PT_STALE_AGE. The age guard avoids racing a PT that another concurrent
run just created.
"""
client = boto3.client("bedrock", region_name=AWS_REGION)

try:
cutoff = datetime.now(timezone.utc) - PT_STALE_AGE
paginator_token = None
while True:
params = {"maxResults": 100}
if paginator_token:
params["nextToken"] = paginator_token
response = client.list_provisioned_model_throughputs(**params)
for pt in response.get("provisionedModelSummaries", []):
name = pt.get("provisionedModelName", "")
if not name.startswith(PT_TEST_PREFIX):
continue
created = pt.get("creationTime")
if created and created >= cutoff:
continue
# Only InService/Failed PTs can be deleted.
if pt.get("status") not in ("InService", "Failed"):
continue
try:
logger.info("Eager cleanup of stale provisioned throughput: %s", name)
client.delete_provisioned_model_throughput(
provisionedModelId=pt["provisionedModelArn"]
)
except Exception as e:
logger.warning("Eager cleanup failed for %s: %s", name, e)
paginator_token = response.get("nextToken")
if not paginator_token:
break
except Exception as e:
logger.warning("Failed to list provisioned throughputs for eager cleanup: %s", e)

return client


@pytest.fixture(scope="module")
Expand Down
45 changes: 37 additions & 8 deletions sagemaker-train/tests/integ/train/test_mtrl_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,18 +60,37 @@ def test_config():


def _ensure_model_package_group_exists(sm_client, group_name):
"""Create the model package group if it doesn't already exist."""
"""Create the model package group if it doesn't already exist.

Race-safe: with pytest-xdist (`-n auto`) multiple workers run this
concurrently, so a plain check-then-create races. If another worker wins
the create, CreateModelPackageGroup raises "already exists"; treat that as
success rather than letting the fixture error out.
"""
try:
sm_client.describe_model_package_group(ModelPackageGroupName=group_name)
return
except Exception:
pass

try:
sm_client.create_model_package_group(
ModelPackageGroupName=group_name,
ModelPackageGroupDescription="Auto-created for MTRL evaluator integ tests",
)
except Exception as e:
# Another concurrent worker created it between our describe and create.
if "already exists" in str(e):
return
raise


def _ensure_model_package_exists(sm_client, group_name, base_model_name):
"""Create a model package in the group if none exists, for test purposes."""
"""Create a model package in the group if none exists, for test purposes.

Race-safe: if a concurrent worker creates one between our list and create,
fall back to listing again and reusing whatever package now exists.
"""
resp = sm_client.list_model_packages(
ModelPackageGroupName=group_name,
MaxResults=1,
Expand All @@ -80,12 +99,22 @@ def _ensure_model_package_exists(sm_client, group_name, base_model_name):
return resp["ModelPackageSummaryList"][0]["ModelPackageArn"]

# Create a minimal unversioned model package (no InferenceSpecification needed)
resp = sm_client.create_model_package(
ModelPackageGroupName=group_name,
ModelPackageDescription="Test model package for MTRL evaluator integ tests",
ModelApprovalStatus="Approved",
)
return resp["ModelPackageArn"]
try:
resp = sm_client.create_model_package(
ModelPackageGroupName=group_name,
ModelPackageDescription="Test model package for MTRL evaluator integ tests",
ModelApprovalStatus="Approved",
)
return resp["ModelPackageArn"]
except Exception:
# A concurrent worker may have created one; reuse the existing package.
resp = sm_client.list_model_packages(
ModelPackageGroupName=group_name,
MaxResults=1,
)
if resp.get("ModelPackageSummaryList"):
return resp["ModelPackageSummaryList"][0]["ModelPackageArn"]
raise


@pytest.fixture(scope="module")
Expand Down
Loading