Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions .github/workflows/medcat-trainer_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,22 @@ jobs:
medcat-trainer/client/htmlcov
if-no-files-found: ignore

- name: Set release version from tag
if: startsWith(github.ref, 'refs/tags')
run: |
VERSION="${GITHUB_REF_NAME#medcat-trainer/v}"
sed -i "s/^version = .*/version = \"$VERSION\"/" client/pyproject.toml
echo "Release version: $VERSION"

- name: Set dev version for TestPyPI
if: github.ref == 'refs/heads/main'
run: sed -i "s/^version = .*/version = \"0.0.0.dev$(date +%s)\"/" client/pyproject.toml

- name: Build client package
run: |
cd client
python -m build

- name: Bump version for TestPyPI
if: github.ref == 'refs/heads/main'
run: sed -i "s/^version = .*/version = \"1.0.0.dev$(date +%s)\"/" client/pyproject.toml

- name: Publish dev distribution to Test PyPI
uses: pypa/gh-action-pypi-publish@release/v1
if: github.ref == 'refs/heads/main'
Expand All @@ -93,7 +100,6 @@ jobs:
- name: Publish production distribution to PyPI
if: startsWith(github.ref, 'refs/tags') && ! github.event.release.prerelease
uses: pypa/gh-action-pypi-publish@release/v1
continue-on-error: true
with:
packages_dir: medcat-trainer/client/dist

Expand Down
21 changes: 18 additions & 3 deletions medcat-trainer/client/mctclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from datetime import datetime
import json
import os
import time
from abc import ABC
from typing import Any, Dict, List, Optional, Tuple, Union

Expand Down Expand Up @@ -287,6 +288,9 @@ def __init__(self, server=None, username=None, password=None, keycloak_settings=
self.username = username or os.getenv("MCTRAINER_USERNAME")
self.password = password or os.getenv("MCTRAINER_PASSWORD")
self.server = server or 'http://localhost:8001'
self._keycloak_settings: Optional[KeycloakSettings] = None
# positive infinity so that token refresh is never triggered for non-OIDC sessions
self._token_expiry: float = float('inf')

env_use_oidc = os.getenv("MCTRAINER_USE_OIDC", "")
env_use_oidc_truthy = env_use_oidc.strip() == "1"
Expand All @@ -299,9 +303,8 @@ def __init__(self, server=None, username=None, password=None, keycloak_settings=
if not isinstance(keycloak_settings, KeycloakSettings):
raise TypeError("keycloak_settings must be a KeycloakSettings instance")
kc_settings = keycloak_settings

token = get_keycloak_access_token(kc_settings)
self.headers = {"Authorization": f"Bearer {token}"}
self._keycloak_settings = kc_settings
self._refresh_oidc_token()
return

payload = {"username": self.username, "password": self.password}
Expand All @@ -314,6 +317,18 @@ def __init__(self, server=None, username=None, password=None, keycloak_settings=
else:
raise MCTUtilsException(f"Failed to login to MedCATtrainer instance running at: {self.server}")

def _refresh_oidc_token(self) -> None:
"""Fetch a new OIDC access token and update the Authorization header."""
token = get_keycloak_access_token(self._keycloak_settings)
self.headers = {"Authorization": f"Bearer {token}"}
# Refresh 60s before the typical 5-minute Keycloak access token lifetime (i.e. 4 minutes from now)
self._token_expiry = time.monotonic() + 240

def ensure_token_fresh(self) -> None:
"""Refresh the OIDC token if it is near expiry. No-op for non-OIDC sessions."""
if self._keycloak_settings is not None and time.monotonic() >= self._token_expiry:
self._refresh_oidc_token()

def create_project(self, name: str,
description: str,
members: Union[List[MCTUser], List[str]],
Expand Down
2 changes: 1 addition & 1 deletion medcat-trainer/client/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "medcattrainer-client"
version = "1.3.0"
version = "1.4.0"
description = "Python client for interacting with a MedCATTrainer instance"
readme = "client/README.md"
requires-python = ">=3.10"
Expand Down
56 changes: 56 additions & 0 deletions medcat-trainer/client/tests/test_mctclient.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import time
import unittest
from unittest.mock import patch, MagicMock
from mctclient import (
Expand Down Expand Up @@ -748,5 +749,60 @@ def post_side_effect(url, *args, **kwargs):
)
self.assertEqual(result, mock_upload_response)

@patch('mctclient.requests.post')
def test_ensure_token_fresh_is_noop_for_non_oidc_session(self, mock_post):
mock_post.return_value = MagicMock(status_code=200, text='{"token": "abc"}')
session = MedCATTrainerSession(server='http://localhost', username='u', password='p')
original_headers = dict(session.headers)
session.ensure_token_fresh()
self.assertEqual(session.headers, original_headers)
# Only the initial DRF auth call should have been made
mock_post.assert_called_once()

@patch('mctclient.requests.post')
def test_ensure_token_fresh_refreshes_expired_oidc_token(self, mock_post):
call_count = [0]

def post_side_effect(url, *args, **kwargs):
if url.endswith('/protocol/openid-connect/token'):
call_count[0] += 1
return MagicMock(status_code=200, json=lambda: {"access_token": f"token-{call_count[0]}"})
return MagicMock(status_code=404, text='')

mock_post.side_effect = post_side_effect

kc = KeycloakSettings(keycloak_url='http://kc', realm='r', client_id='c', client_secret='s')
session = MedCATTrainerSession(server='http://localhost', use_oidc=True, keycloak_settings=kc)
self.assertEqual(session.headers['Authorization'], 'Bearer token-1')

# Simulate expiry by setting _token_expiry to the past
session._token_expiry = time.monotonic() - 1

session.ensure_token_fresh()
self.assertEqual(session.headers['Authorization'], 'Bearer token-2')
self.assertEqual(call_count[0], 2)

@patch('mctclient.requests.post')
def test_ensure_token_fresh_does_not_refresh_valid_oidc_token(self, mock_post):
call_count = [0]

def post_side_effect(url, *args, **kwargs):
if url.endswith('/protocol/openid-connect/token'):
call_count[0] += 1
return MagicMock(status_code=200, json=lambda: {"access_token": f"token-{call_count[0]}"})
return MagicMock(status_code=404, text='')

mock_post.side_effect = post_side_effect

kc = KeycloakSettings(keycloak_url='http://kc', realm='r', client_id='c', client_secret='s')
session = MedCATTrainerSession(server='http://localhost', use_oidc=True, keycloak_settings=kc)
self.assertEqual(call_count[0], 1)

# Token is still fresh (_token_expiry is in the future)
session.ensure_token_fresh()
self.assertEqual(call_count[0], 1)
self.assertEqual(session.headers['Authorization'], 'Bearer token-1')


if __name__ == '__main__':
unittest.main()
Loading