diff --git a/src/ska/logging/transactions.py b/src/ska/logging/transactions.py
index a29cdf60b7b7f4ae16e7f1b3b565757bc0fd3445..e6c1a1068e576b449774ecd970fbc71d97c64db8 100644
--- a/src/ska/logging/transactions.py
+++ b/src/ska/logging/transactions.py
@@ -6,45 +6,56 @@ import logging
 import os
 import inspect
 from random import randint
+import threading
 
 from typing import Mapping, Optional, Text
 
 from ska.skuid.client import SkuidClient
 
+thread_local_data = threading.local()
+
 
 class TransactionIDTagsFilter(logging.Filter):
     """Adds the transaction ID as a tag to the log. Updates module and line number
     for the Enter and Exit log messages.
     """
 
-    def __init__(self, *args, **kwargs):
-        """Override logging.Filter.__init__ to keep track of the transaction ID and callstack"""
-        self.transaction_id = kwargs.pop("transaction_id", None)
-        self.callstack = kwargs.pop("call_stack", None)
-        super(TransactionIDTagsFilter, self).__init__(*args, **kwargs)
+    def get_transaction_id(self):
+        if hasattr(thread_local_data, "transaction_ids"):
+            thread_id = threading.get_ident()
+            return thread_local_data.transaction_ids.get(thread_id, None)
+        return None
+
+    def get_frame(self):
+        if hasattr(thread_local_data, "frames"):
+            thread_id = threading.get_ident()
+            return thread_local_data.frames.get(thread_id, None)
+        return None
 
     def filter(self, record):
         # Add the transaction ID to the tags
-        if self.transaction_id:
+        transaction_id = self.get_transaction_id()
+        if transaction_id:
             if hasattr(record, "tags") and record.tags:
-                if self.transaction_id not in record.tags:
-                    record.tags = f"{record.tags},transaction_id:{self.transaction_id}"
+                if transaction_id not in record.tags:
+                    record.tags = f"{record.tags},transaction_id:{transaction_id}"
             else:
-                record.tags = f"transaction_id:{self.transaction_id}"
+                record.tags = f"transaction_id:{transaction_id}"
 
         # Override the calling module and line number since the log would have logged
         # `transactions.py#X` on `__enter__` and `__exit__` of `Transaction`. This makes it
         # difficult to debug where the `Enter` and `Exit` of a transaction log message
         # originated.
         # From Python 3.8 we should rather use `stacklevel`
-        if self.callstack:
+        frame = self.get_frame()
+        if frame:
             if record.filename.startswith("transactions.py") and record.funcName in [
                 "__enter__",
                 "__exit__",
             ]:
-                record.filename = os.path.basename(self.callstack.filename)
-                record.lineno = self.callstack.lineno
-                record.funcName = self.callstack.function
+                record.filename = os.path.basename(frame.filename)
+                record.lineno = frame.lineno
+                record.funcName = frame.function
 
         return True
 
@@ -88,11 +99,11 @@ class Transaction:
 
     Log message formats:
        On Entry:
-           Transaction [id]: Enter [name] with parameters [arguments]
+           Transaction[id]: Enter [name] with parameters [arguments]
        On Exit:
-           Transaction [id]: Exit [name]
+           Transaction[id]: Exit [name]
        On exception:
-           Transaction [id]: Exception [name]
+           Transaction[id]: Exception [name]
            Stacktrace
     """
 
@@ -152,6 +163,7 @@ class Transaction:
         self._transaction_id_key = transaction_id_key
 
         self._transaction_id = self._get_id_from_params_or_generate_new_id(transaction_id)
+        self._frame = inspect.stack()[1]
 
         if transaction_id and params.get(self._transaction_id_key):
             self.logger.info(
@@ -162,19 +174,35 @@ class Transaction:
         # Used to match enter and exit when multiple devices calls the same command
         # on a shared device simultaneously
         self._random_marker = str(randint(0, 99999)).zfill(5)
-        self._transaction_filter = TransactionIDTagsFilter(
-            transaction_id=self._transaction_id, call_stack=inspect.stack()[1]
-        )
+        self._transaction_filter = TransactionIDTagsFilter()
+
+    def store_thread_data(self):
+        thread_id = threading.get_ident()
+        if not hasattr(thread_local_data, "transaction_ids"):
+            thread_local_data.transaction_ids = {}
+        if not hasattr(thread_local_data, "frames"):
+            thread_local_data.frames = {}
+        thread_local_data.transaction_ids[thread_id] = self._transaction_id
+        thread_local_data.frames[thread_id] = self._frame
+
+    def clear_thread_data(self):
+        thread_id = threading.get_ident()
+        if hasattr(thread_local_data, "transaction_ids"):
+            if thread_id in thread_local_data.transaction_ids:
+                del thread_local_data.transaction_ids[thread_id]
+        if hasattr(thread_local_data, "frames"):
+            if thread_id in thread_local_data.frames:
+                del thread_local_data.frames[thread_id]
 
     def __enter__(self):
-
+        self.store_thread_data()
         self.logger.addFilter(self._transaction_filter)
         params_json = json.dumps(self._params)
 
         self.logger.info(
-            f"Transaction [{self._transaction_id}]: Enter[{self._name}] "
+            f"Transaction[{self._transaction_id}]: Enter[{self._name}] "
             f"with parameters [{params_json}] "
-            f"marker [{self._random_marker}]"
+            f"marker[{self._random_marker}]"
         )
 
         return self._transaction_id
@@ -191,6 +219,7 @@ class Transaction:
         )
 
         self.logger.removeFilter(self._transaction_filter)
+        self.clear_thread_data()
 
         if exc_type:
             raise
diff --git a/tests/test_transactions.py b/tests/test_transactions.py
index 21ebb9d49836ed927163aaac1fb32033cd1b0ec6..445df7470b56fcf4c332b6ed9218c36416fcb811 100644
--- a/tests/test_transactions.py
+++ b/tests/test_transactions.py
@@ -155,7 +155,7 @@ class TestTransactionLogging:
                 return
         assert 0, f"Could not get a log message with `Inner Log` and `transaction_id` in {records}"
 
-    def test_log_override_enter_exit(self, recording_logger):
+    def test_log_override_enter_exit_passed_logger(self, recording_logger):
         parameters = {}
         with transaction("name", parameters, logger=recording_logger) as transaction_id:
             recording_logger.info("Inner Log")
@@ -164,13 +164,13 @@ class TestTransactionLogging:
         last_log_message, _ = get_last_record_and_log_message(recording_logger)
 
         assert "Enter" in second_log_record.message
-        assert second_log_record.funcName == "test_log_override_enter_exit"
+        assert second_log_record.funcName == "test_log_override_enter_exit_passed_logger"
         assert second_log_record.filename == "test_transactions.py"
         assert "Exit" in last_log_message.message
-        assert last_log_message.funcName == "test_log_override_enter_exit"
+        assert last_log_message.funcName == "test_log_override_enter_exit_passed_logger"
         assert last_log_message.filename == "test_transactions.py"
 
-    def test_log_override_enter_exit(self, recording_logger):
+    def test_log_override_enter_exit_no_logger(self, recording_logger):
         parameters = {}
         with transaction("name", parameters) as transaction_id:
             recording_logger.info("Inner Log")
@@ -180,15 +180,15 @@ class TestTransactionLogging:
         last_log_message, _ = get_last_record_and_log_message(recording_logger)
 
         assert "Generated" in first_log_record.message
-        assert first_log_record.funcName != "test_log_override_enter_exit"
+        assert first_log_record.funcName != "test_log_override_enter_exit_no_logger"
         assert first_log_record.filename != "test_transactions.py"
 
         assert "Enter" in second_log_record.message
-        assert second_log_record.funcName == "test_log_override_enter_exit"
+        assert second_log_record.funcName == "test_log_override_enter_exit_no_logger"
         assert second_log_record.filename == "test_transactions.py"
 
         assert "Exit" in last_log_message.message
-        assert last_log_message.funcName == "test_log_override_enter_exit"
+        assert last_log_message.funcName == "test_log_override_enter_exit_no_logger"
         assert last_log_message.filename == "test_transactions.py"
 
     def test_specified_logger(self):
diff --git a/tests/test_transactions_threaded.py b/tests/test_transactions_threaded.py
index 62c00ada916c0e601d9f818c82c8b5bad0a7212c..b6b35fb2b6ada06ba727318c281e2a541382c09a 100644
--- a/tests/test_transactions_threaded.py
+++ b/tests/test_transactions_threaded.py
@@ -2,6 +2,7 @@
 import pytest
 
 from threading import Thread
+from collections import Counter
 
 from ska.logging import transaction
 from tests.conftest import get_all_record_logs, clear_logger_logs
@@ -11,6 +12,7 @@ class ThreadingLogsGenerator:
     """Generate logs by spawning a number of threads and logging in them
     Some uses the transaction context and some not.
     """
+
     def __init__(self, logger=None, pass_logger=False):
         self.logger = logger
         self.pass_logger = pass_logger
@@ -86,7 +88,6 @@ def threaded_logs_global_logger(ensures_tags_logger):
 
 
 class TestThreadScenarios:
-    @pytest.mark.xfail
     def test_logs_outside_transaction_has_no_transaction_ids(
         self, threaded_logs_global_logger, threaded_logs_local_logger
     ):
@@ -96,7 +97,6 @@ class TestThreadScenarios:
         for log in outside_transaction_logs:
             assert "transaction_id:txn" not in log, f"transaction_id should not be in log {log}"
 
-    @pytest.mark.xfail
     def test_no_duplicate_transaction_ids(self, threaded_logs_local_logger):
         all_logs = threaded_logs_local_logger
         transaction_logs = [log for log in all_logs if "Transaction thread" in log]
@@ -129,36 +129,53 @@ class TestThreadScenarios:
         assert enter_exit_logs
         assert len(enter_exit_logs) % 2 == 0
 
-        # Group enter exit by marker
-        markers = {}
+        transaction_id_marker = []
         for log in enter_exit_logs:
-            marker = log[-6:-1]
-            assert marker.isdigit()
-            markers.setdefault(marker, []).append(log)
-        assert markers
-
-        for key, log_pair in markers.items():
-            assert len(log_pair) == 2
-            assert "Enter" in log_pair[0]
-            assert "Exit" in log_pair[1]
-            assert log_pair[0].endswith(f"[{key}]")
-            assert log_pair[1].endswith(f"[{key}]")
+            transaction_id_marker.append(
+                (get_marker_in_message(log), get_transaction_id_in_message(log))
+            )
+        # Group enter exit by (transaction_id, marker)
+        # Make sure there is only 2 of each
+        counter = dict(Counter(transaction_id_marker))
+        for items, count in counter.items():
+            assert count == 2, f"Found {count} of {items} instead of 2"
 
+        # Make sure there's a enter/exit for every exception
         exception_logs = [log for log in all_logs if "RuntimeError" in log]
+        assert exception_logs
         for log in exception_logs:
-            marker_start = log.index("marker[")
-            marker = log[marker_start + 7 : marker_start + 12]
-            assert marker.isdigit()
-            assert marker in markers, "An exception marker has no match with a start/end log"
-            markers[marker].append(log)
-
-        # There is an equal number test transactions that has exceptions to those that do not
-        count_exceptions = 0
-        count_no_exceptions = 0
-        for key, logs in markers.items():
-            if len(logs) == 2:
-                count_no_exceptions += 1
-            if len(logs) == 3:
-                count_exceptions += 1
-        assert count_no_exceptions != 0
-        assert count_no_exceptions == count_exceptions
+            assert (
+                get_marker_in_message(log),
+                get_transaction_id_in_message(log),
+            ) in transaction_id_marker
+
+        # Make sure all the transaction ids in the tags match that in the message
+        for log in enter_exit_logs + exception_logs:
+            tag_id = get_transaction_id_in_tag(log)
+            message_id = get_transaction_id_in_message(log)
+            assert tag_id
+            assert tag_id == message_id
+
+
+def get_transaction_id_in_tag(log_message):
+    tags = log_message.split("|")[-2]
+    if tags:
+        tags_list = tags.split(",")
+        for tag in tags_list:
+            if "transaction_id" in tag:
+                return tag.split(":")[1]
+    return None
+
+
+def get_transaction_id_in_message(log_message):
+    if "Transaction[" in log_message:
+        transaction_index = log_message.index("Transaction[")
+        return log_message[transaction_index + 12 : transaction_index + 40]
+    return None
+
+
+def get_marker_in_message(log_message):
+    if "marker[" in log_message:
+        marker_index = log_message.index("marker[")
+        return log_message[marker_index + 7 : marker_index + 12]
+    return None