diff --git a/.github/workflows/medcat-trainer_ci.yml b/.github/workflows/medcat-trainer_ci.yml index 4a7c15951..0074b64b3 100644 --- a/.github/workflows/medcat-trainer_ci.yml +++ b/.github/workflows/medcat-trainer_ci.yml @@ -73,15 +73,22 @@ jobs: medcat-trainer/client/htmlcov if-no-files-found: ignore + - name: Set release version from tag + if: startsWith(github.ref, 'refs/tags') + run: | + VERSION="${GITHUB_REF_NAME#medcat-trainer/v}" + sed -i "s/^version = .*/version = \"$VERSION\"/" client/pyproject.toml + echo "Release version: $VERSION" + + - name: Set dev version for TestPyPI + if: github.ref == 'refs/heads/main' + run: sed -i "s/^version = .*/version = \"0.0.0.dev$(date +%s)\"/" client/pyproject.toml + - name: Build client package run: | cd client python -m build - - name: Bump version for TestPyPI - if: github.ref == 'refs/heads/main' - run: sed -i "s/^version = .*/version = \"1.0.0.dev$(date +%s)\"/" client/pyproject.toml - - name: Publish dev distribution to Test PyPI uses: pypa/gh-action-pypi-publish@release/v1 if: github.ref == 'refs/heads/main' @@ -93,7 +100,6 @@ jobs: - name: Publish production distribution to PyPI if: startsWith(github.ref, 'refs/tags') && ! github.event.release.prerelease uses: pypa/gh-action-pypi-publish@release/v1 - continue-on-error: true with: packages_dir: medcat-trainer/client/dist diff --git a/medcat-trainer/client/mctclient.py b/medcat-trainer/client/mctclient.py index 1af6b17a8..b081505ed 100644 --- a/medcat-trainer/client/mctclient.py +++ b/medcat-trainer/client/mctclient.py @@ -2,6 +2,7 @@ from datetime import datetime import json import os +import time from abc import ABC from typing import Any, Dict, List, Optional, Tuple, Union @@ -287,6 +288,9 @@ def __init__(self, server=None, username=None, password=None, keycloak_settings= self.username = username or os.getenv("MCTRAINER_USERNAME") self.password = password or os.getenv("MCTRAINER_PASSWORD") self.server = server or 'http://localhost:8001' + self._keycloak_settings: Optional[KeycloakSettings] = None + # positive infinity so that token refresh is never triggered for non-OIDC sessions + self._token_expiry: float = float('inf') env_use_oidc = os.getenv("MCTRAINER_USE_OIDC", "") env_use_oidc_truthy = env_use_oidc.strip() == "1" @@ -299,9 +303,8 @@ def __init__(self, server=None, username=None, password=None, keycloak_settings= if not isinstance(keycloak_settings, KeycloakSettings): raise TypeError("keycloak_settings must be a KeycloakSettings instance") kc_settings = keycloak_settings - - token = get_keycloak_access_token(kc_settings) - self.headers = {"Authorization": f"Bearer {token}"} + self._keycloak_settings = kc_settings + self._refresh_oidc_token() return payload = {"username": self.username, "password": self.password} @@ -314,6 +317,18 @@ def __init__(self, server=None, username=None, password=None, keycloak_settings= else: raise MCTUtilsException(f"Failed to login to MedCATtrainer instance running at: {self.server}") + def _refresh_oidc_token(self) -> None: + """Fetch a new OIDC access token and update the Authorization header.""" + token = get_keycloak_access_token(self._keycloak_settings) + self.headers = {"Authorization": f"Bearer {token}"} + # Refresh 60s before the typical 5-minute Keycloak access token lifetime (i.e. 4 minutes from now) + self._token_expiry = time.monotonic() + 240 + + def ensure_token_fresh(self) -> None: + """Refresh the OIDC token if it is near expiry. No-op for non-OIDC sessions.""" + if self._keycloak_settings is not None and time.monotonic() >= self._token_expiry: + self._refresh_oidc_token() + def create_project(self, name: str, description: str, members: Union[List[MCTUser], List[str]], diff --git a/medcat-trainer/client/pyproject.toml b/medcat-trainer/client/pyproject.toml index a327f45de..ebbe9d280 100644 --- a/medcat-trainer/client/pyproject.toml +++ b/medcat-trainer/client/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "medcattrainer-client" -version = "1.3.0" +version = "1.4.0" description = "Python client for interacting with a MedCATTrainer instance" readme = "client/README.md" requires-python = ">=3.10" diff --git a/medcat-trainer/client/tests/test_mctclient.py b/medcat-trainer/client/tests/test_mctclient.py index 7512e97a3..e1d8a91fe 100644 --- a/medcat-trainer/client/tests/test_mctclient.py +++ b/medcat-trainer/client/tests/test_mctclient.py @@ -1,4 +1,5 @@ import json +import time import unittest from unittest.mock import patch, MagicMock from mctclient import ( @@ -748,5 +749,60 @@ def post_side_effect(url, *args, **kwargs): ) self.assertEqual(result, mock_upload_response) + @patch('mctclient.requests.post') + def test_ensure_token_fresh_is_noop_for_non_oidc_session(self, mock_post): + mock_post.return_value = MagicMock(status_code=200, text='{"token": "abc"}') + session = MedCATTrainerSession(server='http://localhost', username='u', password='p') + original_headers = dict(session.headers) + session.ensure_token_fresh() + self.assertEqual(session.headers, original_headers) + # Only the initial DRF auth call should have been made + mock_post.assert_called_once() + + @patch('mctclient.requests.post') + def test_ensure_token_fresh_refreshes_expired_oidc_token(self, mock_post): + call_count = [0] + + def post_side_effect(url, *args, **kwargs): + if url.endswith('/protocol/openid-connect/token'): + call_count[0] += 1 + return MagicMock(status_code=200, json=lambda: {"access_token": f"token-{call_count[0]}"}) + return MagicMock(status_code=404, text='') + + mock_post.side_effect = post_side_effect + + kc = KeycloakSettings(keycloak_url='http://kc', realm='r', client_id='c', client_secret='s') + session = MedCATTrainerSession(server='http://localhost', use_oidc=True, keycloak_settings=kc) + self.assertEqual(session.headers['Authorization'], 'Bearer token-1') + + # Simulate expiry by setting _token_expiry to the past + session._token_expiry = time.monotonic() - 1 + + session.ensure_token_fresh() + self.assertEqual(session.headers['Authorization'], 'Bearer token-2') + self.assertEqual(call_count[0], 2) + + @patch('mctclient.requests.post') + def test_ensure_token_fresh_does_not_refresh_valid_oidc_token(self, mock_post): + call_count = [0] + + def post_side_effect(url, *args, **kwargs): + if url.endswith('/protocol/openid-connect/token'): + call_count[0] += 1 + return MagicMock(status_code=200, json=lambda: {"access_token": f"token-{call_count[0]}"}) + return MagicMock(status_code=404, text='') + + mock_post.side_effect = post_side_effect + + kc = KeycloakSettings(keycloak_url='http://kc', realm='r', client_id='c', client_secret='s') + session = MedCATTrainerSession(server='http://localhost', use_oidc=True, keycloak_settings=kc) + self.assertEqual(call_count[0], 1) + + # Token is still fresh (_token_expiry is in the future) + session.ensure_token_fresh() + self.assertEqual(call_count[0], 1) + self.assertEqual(session.headers['Authorization'], 'Bearer token-1') + + if __name__ == '__main__': unittest.main() \ No newline at end of file