diff --git a/src/ska/logging/transactions.py b/src/ska/logging/transactions.py index a29cdf60b7b7f4ae16e7f1b3b565757bc0fd3445..e6c1a1068e576b449774ecd970fbc71d97c64db8 100644 --- a/src/ska/logging/transactions.py +++ b/src/ska/logging/transactions.py @@ -6,45 +6,56 @@ import logging import os import inspect from random import randint +import threading from typing import Mapping, Optional, Text from ska.skuid.client import SkuidClient +thread_local_data = threading.local() + class TransactionIDTagsFilter(logging.Filter): """Adds the transaction ID as a tag to the log. Updates module and line number for the Enter and Exit log messages. """ - def __init__(self, *args, **kwargs): - """Override logging.Filter.__init__ to keep track of the transaction ID and callstack""" - self.transaction_id = kwargs.pop("transaction_id", None) - self.callstack = kwargs.pop("call_stack", None) - super(TransactionIDTagsFilter, self).__init__(*args, **kwargs) + def get_transaction_id(self): + if hasattr(thread_local_data, "transaction_ids"): + thread_id = threading.get_ident() + return thread_local_data.transaction_ids.get(thread_id, None) + return None + + def get_frame(self): + if hasattr(thread_local_data, "frames"): + thread_id = threading.get_ident() + return thread_local_data.frames.get(thread_id, None) + return None def filter(self, record): # Add the transaction ID to the tags - if self.transaction_id: + transaction_id = self.get_transaction_id() + if transaction_id: if hasattr(record, "tags") and record.tags: - if self.transaction_id not in record.tags: - record.tags = f"{record.tags},transaction_id:{self.transaction_id}" + if transaction_id not in record.tags: + record.tags = f"{record.tags},transaction_id:{transaction_id}" else: - record.tags = f"transaction_id:{self.transaction_id}" + record.tags = f"transaction_id:{transaction_id}" # Override the calling module and line number since the log would have logged # `transactions.py#X` on `__enter__` and `__exit__` of `Transaction`. This makes it # difficult to debug where the `Enter` and `Exit` of a transaction log message # originated. # From Python 3.8 we should rather use `stacklevel` - if self.callstack: + frame = self.get_frame() + if frame: if record.filename.startswith("transactions.py") and record.funcName in [ "__enter__", "__exit__", ]: - record.filename = os.path.basename(self.callstack.filename) - record.lineno = self.callstack.lineno - record.funcName = self.callstack.function + record.filename = os.path.basename(frame.filename) + record.lineno = frame.lineno + record.funcName = frame.function return True @@ -88,11 +99,11 @@ class Transaction: Log message formats: On Entry: - Transaction [id]: Enter [name] with parameters [arguments] + Transaction[id]: Enter [name] with parameters [arguments] On Exit: - Transaction [id]: Exit [name] + Transaction[id]: Exit [name] On exception: - Transaction [id]: Exception [name] + Transaction[id]: Exception [name] Stacktrace """ @@ -152,6 +163,7 @@ class Transaction: self._transaction_id_key = transaction_id_key self._transaction_id = self._get_id_from_params_or_generate_new_id(transaction_id) + self._frame = inspect.stack()[1] if transaction_id and params.get(self._transaction_id_key): self.logger.info( @@ -162,19 +174,35 @@ class Transaction: # Used to match enter and exit when multiple devices calls the same command # on a shared device simultaneously self._random_marker = str(randint(0, 99999)).zfill(5) - self._transaction_filter = TransactionIDTagsFilter( - transaction_id=self._transaction_id, call_stack=inspect.stack()[1] - ) + self._transaction_filter = TransactionIDTagsFilter() + + def store_thread_data(self): + thread_id = threading.get_ident() + if not hasattr(thread_local_data, "transaction_ids"): + thread_local_data.transaction_ids = {} + if not hasattr(thread_local_data, "frames"): + thread_local_data.frames = {} + thread_local_data.transaction_ids[thread_id] = self._transaction_id + thread_local_data.frames[thread_id] = self._frame + + def clear_thread_data(self): + thread_id = threading.get_ident() + if hasattr(thread_local_data, "transaction_ids"): + if thread_id in thread_local_data.transaction_ids: + del thread_local_data.transaction_ids[thread_id] + if hasattr(thread_local_data, "frames"): + if thread_id in thread_local_data.frames: + del thread_local_data.frames[thread_id] def __enter__(self): - + self.store_thread_data() self.logger.addFilter(self._transaction_filter) params_json = json.dumps(self._params) self.logger.info( - f"Transaction [{self._transaction_id}]: Enter[{self._name}] " + f"Transaction[{self._transaction_id}]: Enter[{self._name}] " f"with parameters [{params_json}] " - f"marker [{self._random_marker}]" + f"marker[{self._random_marker}]" ) return self._transaction_id @@ -191,6 +219,7 @@ class Transaction: ) self.logger.removeFilter(self._transaction_filter) + self.clear_thread_data() if exc_type: raise diff --git a/tests/test_transactions.py b/tests/test_transactions.py index 21ebb9d49836ed927163aaac1fb32033cd1b0ec6..445df7470b56fcf4c332b6ed9218c36416fcb811 100644 --- a/tests/test_transactions.py +++ b/tests/test_transactions.py @@ -155,7 +155,7 @@ class TestTransactionLogging: return assert 0, f"Could not get a log message with `Inner Log` and `transaction_id` in {records}" - def test_log_override_enter_exit(self, recording_logger): + def test_log_override_enter_exit_passed_logger(self, recording_logger): parameters = {} with transaction("name", parameters, logger=recording_logger) as transaction_id: recording_logger.info("Inner Log") @@ -164,13 +164,13 @@ class TestTransactionLogging: last_log_message, _ = get_last_record_and_log_message(recording_logger) assert "Enter" in second_log_record.message - assert second_log_record.funcName == "test_log_override_enter_exit" + assert second_log_record.funcName == "test_log_override_enter_exit_passed_logger" assert second_log_record.filename == "test_transactions.py" assert "Exit" in last_log_message.message - assert last_log_message.funcName == "test_log_override_enter_exit" + assert last_log_message.funcName == "test_log_override_enter_exit_passed_logger" assert last_log_message.filename == "test_transactions.py" - def test_log_override_enter_exit(self, recording_logger): + def test_log_override_enter_exit_no_logger(self, recording_logger): parameters = {} with transaction("name", parameters) as transaction_id: recording_logger.info("Inner Log") @@ -180,15 +180,15 @@ class TestTransactionLogging: last_log_message, _ = get_last_record_and_log_message(recording_logger) assert "Generated" in first_log_record.message - assert first_log_record.funcName != "test_log_override_enter_exit" + assert first_log_record.funcName != "test_log_override_enter_exit_no_logger" assert first_log_record.filename != "test_transactions.py" assert "Enter" in second_log_record.message - assert second_log_record.funcName == "test_log_override_enter_exit" + assert second_log_record.funcName == "test_log_override_enter_exit_no_logger" assert second_log_record.filename == "test_transactions.py" assert "Exit" in last_log_message.message - assert last_log_message.funcName == "test_log_override_enter_exit" + assert last_log_message.funcName == "test_log_override_enter_exit_no_logger" assert last_log_message.filename == "test_transactions.py" def test_specified_logger(self): diff --git a/tests/test_transactions_threaded.py b/tests/test_transactions_threaded.py index 62c00ada916c0e601d9f818c82c8b5bad0a7212c..b6b35fb2b6ada06ba727318c281e2a541382c09a 100644 --- a/tests/test_transactions_threaded.py +++ b/tests/test_transactions_threaded.py @@ -2,6 +2,7 @@ import pytest from threading import Thread +from collections import Counter from ska.logging import transaction from tests.conftest import get_all_record_logs, clear_logger_logs @@ -11,6 +12,7 @@ class ThreadingLogsGenerator: """Generate logs by spawning a number of threads and logging in them Some uses the transaction context and some not. """ + def __init__(self, logger=None, pass_logger=False): self.logger = logger self.pass_logger = pass_logger @@ -86,7 +88,6 @@ def threaded_logs_global_logger(ensures_tags_logger): class TestThreadScenarios: - @pytest.mark.xfail def test_logs_outside_transaction_has_no_transaction_ids( self, threaded_logs_global_logger, threaded_logs_local_logger ): @@ -96,7 +97,6 @@ class TestThreadScenarios: for log in outside_transaction_logs: assert "transaction_id:txn" not in log, f"transaction_id should not be in log {log}" - @pytest.mark.xfail def test_no_duplicate_transaction_ids(self, threaded_logs_local_logger): all_logs = threaded_logs_local_logger transaction_logs = [log for log in all_logs if "Transaction thread" in log] @@ -129,36 +129,53 @@ class TestThreadScenarios: assert enter_exit_logs assert len(enter_exit_logs) % 2 == 0 - # Group enter exit by marker - markers = {} + transaction_id_marker = [] for log in enter_exit_logs: - marker = log[-6:-1] - assert marker.isdigit() - markers.setdefault(marker, []).append(log) - assert markers - - for key, log_pair in markers.items(): - assert len(log_pair) == 2 - assert "Enter" in log_pair[0] - assert "Exit" in log_pair[1] - assert log_pair[0].endswith(f"[{key}]") - assert log_pair[1].endswith(f"[{key}]") + transaction_id_marker.append( + (get_marker_in_message(log), get_transaction_id_in_message(log)) + ) + # Group enter exit by (transaction_id, marker) + # Make sure there is only 2 of each + counter = dict(Counter(transaction_id_marker)) + for items, count in counter.items(): + assert count == 2, f"Found {count} of {items} instead of 2" + # Make sure there's a enter/exit for every exception exception_logs = [log for log in all_logs if "RuntimeError" in log] + assert exception_logs for log in exception_logs: - marker_start = log.index("marker[") - marker = log[marker_start + 7 : marker_start + 12] - assert marker.isdigit() - assert marker in markers, "An exception marker has no match with a start/end log" - markers[marker].append(log) - - # There is an equal number test transactions that has exceptions to those that do not - count_exceptions = 0 - count_no_exceptions = 0 - for key, logs in markers.items(): - if len(logs) == 2: - count_no_exceptions += 1 - if len(logs) == 3: - count_exceptions += 1 - assert count_no_exceptions != 0 - assert count_no_exceptions == count_exceptions + assert ( + get_marker_in_message(log), + get_transaction_id_in_message(log), + ) in transaction_id_marker + + # Make sure all the transaction ids in the tags match that in the message + for log in enter_exit_logs + exception_logs: + tag_id = get_transaction_id_in_tag(log) + message_id = get_transaction_id_in_message(log) + assert tag_id + assert tag_id == message_id + + +def get_transaction_id_in_tag(log_message): + tags = log_message.split("|")[-2] + if tags: + tags_list = tags.split(",") + for tag in tags_list: + if "transaction_id" in tag: + return tag.split(":")[1] + return None + + +def get_transaction_id_in_message(log_message): + if "Transaction[" in log_message: + transaction_index = log_message.index("Transaction[") + return log_message[transaction_index + 12 : transaction_index + 40] + return None + + +def get_marker_in_message(log_message): + if "marker[" in log_message: + marker_index = log_message.index("marker[") + return log_message[marker_index + 7 : marker_index + 12] + return None