Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions argus/backend/models/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class ArgusReleasePlan(Model):
creation_time = columns.DateTime(default=lambda: datetime.datetime.now(tz=datetime.UTC))
last_updated = columns.DateTime(default=lambda: datetime.datetime.now(tz=datetime.UTC))
ends_at = columns.DateTime()
key = columns.Text()

def __eq__(self, other):
if isinstance(other, ArgusReleasePlan):
Expand Down
36 changes: 27 additions & 9 deletions argus/backend/service/planner_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,23 @@ class PlanningService:
def version(self):
return "v1"

def _generate_plan_key(self, release_id: UUID | str) -> str:
release: ArgusRelease = ArgusRelease.get(id=release_id)
candidate = f"{release.name}#1"
release_plans = list(ArgusReleasePlan.filter(release_id=release.id).allow_filtering().all())
if len(release_plans) == 0:
return candidate
existing_keys = [int(p.key.split("#")[1]) for p in release_plans]
previous_number = max(existing_keys)

return f"{release.name}#{previous_number+1}"

def _resolve_plan(self, ref: str | UUID) -> ArgusReleasePlan:
try:
return ArgusReleasePlan.get(id=UUID(str(ref)))
except (ValueError, ArgusReleasePlan.DoesNotExist):
return ArgusReleasePlan.filter(key=str(ref)).allow_filtering().get()

def create_plan(self, payload: dict[str, Any]) -> ArgusReleasePlan:
plan_request = CreatePlanPayload(**payload)

Expand Down Expand Up @@ -171,14 +188,15 @@ def create_plan(self, payload: dict[str, Any]) -> ArgusReleasePlan:
plan.view_id = plan_request.view_id
view = self.update_view_for_plan(plan, existing=True)

plan.key = self._generate_plan_key(plan.release_id)
plan.save()
invalidate_release_snapshots(plan.release_id)
return plan

def update_plan(self, payload: dict[str, Any]) -> bool:
plan_request = PlanDiffPayload(**payload)

plan: ArgusReleasePlan = ArgusReleasePlan.get(id=plan_request.id)
plan: ArgusReleasePlan = self._resolve_plan(plan_request.id)

if plan_request.name is not None:
plan.name = plan_request.name
Expand All @@ -196,7 +214,7 @@ def update_plan(self, payload: dict[str, Any]) -> bool:
try:
existing = ArgusReleasePlan.filter(
name=plan.name, target_version=plan.target_version).allow_filtering().get()
if existing and existing.id != UUID(plan_request.id):
if existing and existing.id != plan.id:
raise PlannerServiceException(
f"Found existing plan {existing.name} ({existing.target_version}) with the same name and version", existing, plan_request)
except ArgusReleasePlan.DoesNotExist:
Expand Down Expand Up @@ -339,7 +357,7 @@ def create_view_for_plan(self, plan: ArgusReleasePlan) -> ArgusUserView:

def change_plan_owner(self, plan_id: UUID | str, new_owner: UUID | str) -> bool:
user: User = User.get(id=new_owner)
plan: ArgusReleasePlan = ArgusReleasePlan.get(id=plan_id)
plan: ArgusReleasePlan = self._resolve_plan(plan_id)

plan.owner = user.id
plan.last_updated = datetime.datetime.now(tz=datetime.UTC)
Expand All @@ -348,7 +366,7 @@ def change_plan_owner(self, plan_id: UUID | str, new_owner: UUID | str) -> bool:
return True

def get_plan(self, plan_id: str | UUID) -> ArgusReleasePlan:
return ArgusReleasePlan.get(id=plan_id)
return self._resolve_plan(plan_id)

def get_gridview_for_release(self, release_id: str | UUID) -> dict[str, dict]:
release = ArgusRelease.get(id=release_id)
Expand Down Expand Up @@ -391,8 +409,7 @@ def copy_plan(self, payload: CopyPlanPayload) -> ArgusReleasePlan:
except ArgusReleasePlan.DoesNotExist:
pass

original_plan: ArgusReleasePlan = ArgusReleasePlan.get(
id=payload.plan.id)
original_plan: ArgusReleasePlan = self._resolve_plan(payload.plan.id)
target_release: ArgusRelease = ArgusRelease.get(
id=payload.targetReleaseId)
original_release: ArgusRelease = ArgusRelease.get(
Expand Down Expand Up @@ -452,13 +469,14 @@ def copy_plan(self, payload: CopyPlanPayload) -> ArgusReleasePlan:
view = self.create_view_for_plan(new_plan)
new_plan.view_id = view.id

new_plan.key = self._generate_plan_key(target_release.id)
new_plan.save()
invalidate_release_snapshots(new_plan.release_id)
return new_plan

def check_plan_copy_eligibility(self, plan_id: str | UUID, target_release_id: str | UUID) -> dict:
target_release: ArgusRelease = ArgusRelease.get(id=target_release_id)
plan: ArgusReleasePlan = ArgusReleasePlan.get(id=plan_id)
plan: ArgusReleasePlan = self._resolve_plan(plan_id)
original_release: ArgusRelease = ArgusRelease.get(id=plan.release_id)

original_tests: list[ArgusTest] = ArgusTest.filter(
Expand Down Expand Up @@ -522,7 +540,7 @@ def get_plans_for_release(self, release_id: str | UUID) -> list[ArgusReleasePlan
return list(ArgusReleasePlan.filter(release_id=release_id).all())

def delete_plan(self, plan_id: str | UUID, delete_view: bool = True):
plan: ArgusReleasePlan = ArgusReleasePlan.get(id=plan_id)
plan: ArgusReleasePlan = self._resolve_plan(plan_id)
if plan.view_id:
view: ArgusUserView = ArgusUserView.get(id=plan.view_id)
if delete_view:
Expand Down Expand Up @@ -607,7 +625,7 @@ def complete_plan(self, plan_id: str | UUID) -> bool:
return plan.completed

def resolve_plan(self, plan_id: str | UUID) -> list[dict[str, Any]]:
plan: ArgusReleasePlan = ArgusReleasePlan.get(id=plan_id)
plan: ArgusReleasePlan = self._resolve_plan(plan_id)

release: ArgusRelease = ArgusRelease.get(id=plan.release_id)
tests: list[ArgusTest] = []
Expand Down
93 changes: 68 additions & 25 deletions argus/client/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import re
import logging
from dataclasses import asdict
from pathlib import Path
from typing import Any, Type
from uuid import UUID

Expand All @@ -9,6 +10,7 @@
from argus.common.enums import TestStatus
from argus.client.session import create_session
from argus.client.generic_result import GenericResultTable
from argus.client.replay_log import ReplayLog, ReplayLogOnlyResponse
from argus.client.sct.types import LogLink

JSON = dict[str, Any] | list[Any] | int | str | float | bool | Type[None]
Expand All @@ -34,24 +36,52 @@ class Routes():
FETCH_RESULTS = "/testrun/$type/$id/fetch_results"
FINALIZE = "/testrun/$type/$id/finalize"

def __init__(self, auth_token: str, base_url: str, api_version="v1", extra_headers: dict | None = None,
timeout: int = 60, max_retries: int = 3, use_tunnel: bool | None = None) -> None:
# Subclasses override ``test_type`` as a class attribute; ``run_id`` is
# set on the instance by subclass constructors. Both are surfaced in the
# replay-log filename.
test_type: str | None = None

def __init__(self, auth_token: str, base_url: str, log_dir: str | Path, api_version="v1",
extra_headers: dict | None = None, timeout: int = 60, max_retries: int = 3,
use_tunnel: bool | None = None, replay_log_only: bool = False,
run_id: UUID | str | None = None) -> None:
self._auth_token = auth_token
self._base_url = base_url
self._api_ver = api_version
self._timeout = timeout
self.session = create_session(
auth_token=auth_token,
base_url=base_url,
use_tunnel=use_tunnel,
max_retries=max_retries,
self._replay_log_only = replay_log_only
# Set run_id on the instance so subclasses that read ``self.run_id``
# later see the explicit value, not the class-attribute default.
if run_id is not None:
self.run_id = run_id
# In replay-log-only mode no HTTP calls are made, so skip opening a
# session (and any SSH tunnel that might come with it).
if replay_log_only:
self.session = None
else:
self.session = create_session(
auth_token=auth_token,
base_url=base_url,
use_tunnel=use_tunnel,
max_retries=max_retries,
)
if extra_headers:
self.session.headers.update(extra_headers)

self._replay_log = ReplayLog(
log_dir=log_dir,
run_id=str(run_id) if run_id is not None else None,
test_type=self.test_type,
)

if extra_headers:
self.session.headers.update(extra_headers)
@property
def replay_log_path(self) -> Path:
return self._replay_log.path

def close(self) -> None:
self.session.close()
if self.session is not None:
self.session.close()
self._replay_log.close()

def __enter__(self) -> "ArgusClient":
return self
Expand Down Expand Up @@ -111,6 +141,12 @@ def request_headers(self):
}

def get(self, endpoint: str, location_params: dict[str, str] = None, params: dict = None) -> requests.Response:
# In replay-log-only mode no HTTP call is made; behave like a mock so
# callers (e.g. SCT tests that previously used ``MagicMock``) do not
# have to special-case GETs.
if self._replay_log_only:
LOGGER.debug("GET [replay-log-only] %s params: %s", endpoint, params)
return ReplayLogOnlyResponse(endpoint=endpoint)
url = self.get_url_for_endpoint(
endpoint=endpoint,
location_params=location_params
Expand All @@ -133,21 +169,28 @@ def post(
params: dict = None,
body: dict = None,
) -> requests.Response:
url = self.get_url_for_endpoint(
endpoint=endpoint,
location_params=location_params
)
LOGGER.debug("POST Request: %s, params: %s, body: %s", url, params, body)
response = self.session.post(
url=url,
params=params,
json=body,
headers=self.request_headers,
timeout=self._timeout
)
LOGGER.debug("POST Response: %s %s", response.status_code, response.url)

return response
with self._replay_log.record("POST", endpoint, location_params, params, body) as rec:
if self._replay_log_only:
# Record the request so a future replay can re-send it, but
# skip the HTTP call. ``rec`` stays at success=False (default).
LOGGER.debug("POST [replay-log-only] %s body: %s", endpoint, body)
return ReplayLogOnlyResponse(endpoint=endpoint)

url = self.get_url_for_endpoint(
endpoint=endpoint,
location_params=location_params
)
LOGGER.debug("POST Request: %s, params: %s, body: %s", url, params, body)
response = self.session.post(
url=url,
params=params,
json=body,
headers=self.request_headers,
timeout=self._timeout
)
LOGGER.debug("POST Response: %s %s", response.status_code, response.url)
rec.record(response)
return response

def submit_run(self, run_type: str, run_body: dict) -> requests.Response:
return self.post(endpoint=self.Routes.SUBMIT, location_params={"type": run_type}, body={
Expand Down
Loading
Loading