Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions .github/workflows/medcat-trainer_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,22 @@ jobs:
medcat-trainer/client/htmlcov
if-no-files-found: ignore

- name: Set release version from tag
if: startsWith(github.ref, 'refs/tags')
run: |
VERSION="${GITHUB_REF_NAME#medcat-trainer/v}"
sed -i "s/^version = .*/version = \"$VERSION\"/" client/pyproject.toml
echo "Release version: $VERSION"

- name: Set dev version for TestPyPI
if: github.ref == 'refs/heads/main'
run: sed -i "s/^version = .*/version = \"0.0.0.dev$(date +%s)\"/" client/pyproject.toml

- name: Build client package
run: |
cd client
python -m build

- name: Bump version for TestPyPI
if: github.ref == 'refs/heads/main'
run: sed -i "s/^version = .*/version = \"1.0.0.dev$(date +%s)\"/" client/pyproject.toml

- name: Publish dev distribution to Test PyPI
uses: pypa/gh-action-pypi-publish@release/v1
if: github.ref == 'refs/heads/main'
Expand All @@ -93,7 +100,6 @@ jobs:
- name: Publish production distribution to PyPI
if: startsWith(github.ref, 'refs/tags') && ! github.event.release.prerelease
uses: pypa/gh-action-pypi-publish@release/v1
continue-on-error: true
with:
packages_dir: medcat-trainer/client/dist

Expand Down
22 changes: 19 additions & 3 deletions medcat-trainer/client/mctclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from datetime import datetime
import json
import os
import time
from abc import ABC
from typing import Any, Dict, List, Optional, Tuple, Union

Expand Down Expand Up @@ -287,6 +288,9 @@ def __init__(self, server=None, username=None, password=None, keycloak_settings=
self.username = username or os.getenv("MCTRAINER_USERNAME")
self.password = password or os.getenv("MCTRAINER_PASSWORD")
self.server = server or 'http://localhost:8001'
self._keycloak_settings: Optional[KeycloakSettings] = None
# positive infinity so that token refresh is never triggered for non-OIDC sessions
self._token_expiry: float = float('inf')

env_use_oidc = os.getenv("MCTRAINER_USE_OIDC", "")
env_use_oidc_truthy = env_use_oidc.strip() == "1"
Expand All @@ -299,9 +303,8 @@ def __init__(self, server=None, username=None, password=None, keycloak_settings=
if not isinstance(keycloak_settings, KeycloakSettings):
raise TypeError("keycloak_settings must be a KeycloakSettings instance")
kc_settings = keycloak_settings

token = get_keycloak_access_token(kc_settings)
self.headers = {"Authorization": f"Bearer {token}"}
self._keycloak_settings = kc_settings
self._refresh_oidc_token()
return

payload = {"username": self.username, "password": self.password}
Expand All @@ -314,6 +317,19 @@ def __init__(self, server=None, username=None, password=None, keycloak_settings=
else:
raise MCTUtilsException(f"Failed to login to MedCATtrainer instance running at: {self.server}")

def _refresh_oidc_token(self) -> None:
"""Fetch a new OIDC access token and update the Authorization header."""
token = get_keycloak_access_token(self._keycloak_settings)
self.headers = {"Authorization": f"Bearer {token}"}
# By default, refresh 60s before the typical 5-minute Keycloak access token lifetime (i.e., 4 minutes / 240 seconds from now)
interval = int(os.getenv("MCTRAINER_TOKEN_REFRESH_INTERVAL", "240"))
self._token_expiry = time.monotonic() + interval

def ensure_token_fresh(self) -> None:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need to be wired in somewhere? Wondering how this is triggered

@jocelyneholdbrook jocelyneholdbrook Jun 15, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it can be called by whoever uses this client. The idea is that you check whether the current OIDC token is about to expire, and if so, fetches a new one from Keycloak before the next API call goes out. I will be using it in medcattery's medcat-trainer.py. But in an ideal scenario it should be checked in every medcat-trainer HTTP request but I wanted to avoid making a load of changes at this stage.

"""Refresh the OIDC token if it is near expiry. No-op for non-OIDC sessions."""
if self._keycloak_settings is not None and time.monotonic() >= self._token_expiry:
self._refresh_oidc_token()

def create_project(self, name: str,
description: str,
members: Union[List[MCTUser], List[str]],
Expand Down
2 changes: 1 addition & 1 deletion medcat-trainer/client/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "medcattrainer-client"
version = "1.3.0"
version = "1.4.0"
description = "Python client for interacting with a MedCATTrainer instance"
readme = "client/README.md"
requires-python = ">=3.10"
Expand Down
56 changes: 56 additions & 0 deletions medcat-trainer/client/tests/test_mctclient.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import time
import unittest
from unittest.mock import patch, MagicMock
from mctclient import (
Expand Down Expand Up @@ -748,5 +749,60 @@ def post_side_effect(url, *args, **kwargs):
)
self.assertEqual(result, mock_upload_response)

@patch('mctclient.requests.post')
def test_ensure_token_fresh_is_noop_for_non_oidc_session(self, mock_post):
mock_post.return_value = MagicMock(status_code=200, text='{"token": "abc"}')
session = MedCATTrainerSession(server='http://localhost', username='u', password='p')
original_headers = dict(session.headers)
session.ensure_token_fresh()
self.assertEqual(session.headers, original_headers)
# Only the initial DRF auth call should have been made
mock_post.assert_called_once()

@patch('mctclient.requests.post')
def test_ensure_token_fresh_refreshes_expired_oidc_token(self, mock_post):
call_count = [0]

def post_side_effect(url, *args, **kwargs):
if url.endswith('/protocol/openid-connect/token'):
call_count[0] += 1
return MagicMock(status_code=200, json=lambda: {"access_token": f"token-{call_count[0]}"})
return MagicMock(status_code=404, text='')

mock_post.side_effect = post_side_effect

kc = KeycloakSettings(keycloak_url='http://kc', realm='r', client_id='c', client_secret='s')
session = MedCATTrainerSession(server='http://localhost', use_oidc=True, keycloak_settings=kc)
self.assertEqual(session.headers['Authorization'], 'Bearer token-1')

# Simulate expiry by setting _token_expiry to the past
session._token_expiry = time.monotonic() - 1

session.ensure_token_fresh()
self.assertEqual(session.headers['Authorization'], 'Bearer token-2')
self.assertEqual(call_count[0], 2)

@patch('mctclient.requests.post')
def test_ensure_token_fresh_does_not_refresh_valid_oidc_token(self, mock_post):
call_count = [0]

def post_side_effect(url, *args, **kwargs):
if url.endswith('/protocol/openid-connect/token'):
call_count[0] += 1
return MagicMock(status_code=200, json=lambda: {"access_token": f"token-{call_count[0]}"})
return MagicMock(status_code=404, text='')

mock_post.side_effect = post_side_effect

kc = KeycloakSettings(keycloak_url='http://kc', realm='r', client_id='c', client_secret='s')
session = MedCATTrainerSession(server='http://localhost', use_oidc=True, keycloak_settings=kc)
self.assertEqual(call_count[0], 1)

# Token is still fresh (_token_expiry is in the future)
session.ensure_token_fresh()
self.assertEqual(call_count[0], 1)
self.assertEqual(session.headers['Authorization'], 'Bearer token-1')


if __name__ == '__main__':
unittest.main()
Loading