Skip to content
Snippets Groups Projects
Unverified Commit 569def50 authored by SKAJohanVenter's avatar SKAJohanVenter
Browse files

SAR-150 Resolved failing tests that tests threading scenarios

parent 8fd81ddf
No related branches found
No related tags found
No related merge requests found
......@@ -6,45 +6,56 @@ import logging
import os
import inspect
from random import randint
import threading
from typing import Mapping, Optional, Text
from ska.skuid.client import SkuidClient
thread_local_data = threading.local()
class TransactionIDTagsFilter(logging.Filter):
"""Adds the transaction ID as a tag to the log. Updates module and line number
for the Enter and Exit log messages.
"""
def __init__(self, *args, **kwargs):
"""Override logging.Filter.__init__ to keep track of the transaction ID and callstack"""
self.transaction_id = kwargs.pop("transaction_id", None)
self.callstack = kwargs.pop("call_stack", None)
super(TransactionIDTagsFilter, self).__init__(*args, **kwargs)
def get_transaction_id(self):
if hasattr(thread_local_data, "transaction_ids"):
thread_id = threading.get_ident()
return thread_local_data.transaction_ids.get(thread_id, None)
return None
def get_frame(self):
if hasattr(thread_local_data, "frames"):
thread_id = threading.get_ident()
return thread_local_data.frames.get(thread_id, None)
return None
def filter(self, record):
# Add the transaction ID to the tags
if self.transaction_id:
transaction_id = self.get_transaction_id()
if transaction_id:
if hasattr(record, "tags") and record.tags:
if self.transaction_id not in record.tags:
record.tags = f"{record.tags},transaction_id:{self.transaction_id}"
if transaction_id not in record.tags:
record.tags = f"{record.tags},transaction_id:{transaction_id}"
else:
record.tags = f"transaction_id:{self.transaction_id}"
record.tags = f"transaction_id:{transaction_id}"
# Override the calling module and line number since the log would have logged
# `transactions.py#X` on `__enter__` and `__exit__` of `Transaction`. This makes it
# difficult to debug where the `Enter` and `Exit` of a transaction log message
# originated.
# From Python 3.8 we should rather use `stacklevel`
if self.callstack:
frame = self.get_frame()
if frame:
if record.filename.startswith("transactions.py") and record.funcName in [
"__enter__",
"__exit__",
]:
record.filename = os.path.basename(self.callstack.filename)
record.lineno = self.callstack.lineno
record.funcName = self.callstack.function
record.filename = os.path.basename(frame.filename)
record.lineno = frame.lineno
record.funcName = frame.function
return True
......@@ -152,6 +163,7 @@ class Transaction:
self._transaction_id_key = transaction_id_key
self._transaction_id = self._get_id_from_params_or_generate_new_id(transaction_id)
self._frame = inspect.stack()[1]
if transaction_id and params.get(self._transaction_id_key):
self.logger.info(
......@@ -162,12 +174,28 @@ class Transaction:
# Used to match enter and exit when multiple devices calls the same command
# on a shared device simultaneously
self._random_marker = str(randint(0, 99999)).zfill(5)
self._transaction_filter = TransactionIDTagsFilter(
transaction_id=self._transaction_id, call_stack=inspect.stack()[1]
)
self._transaction_filter = TransactionIDTagsFilter()
def store_thread_data(self):
thread_id = threading.get_ident()
if not hasattr(thread_local_data, "transaction_ids"):
thread_local_data.transaction_ids = {}
if not hasattr(thread_local_data, "frames"):
thread_local_data.frames = {}
thread_local_data.transaction_ids[thread_id] = self._transaction_id
thread_local_data.frames[thread_id] = self._frame
def clear_thread_data(self):
thread_id = threading.get_ident()
if hasattr(thread_local_data, "transaction_ids"):
if thread_id in thread_local_data.transaction_ids:
del thread_local_data.transaction_ids[thread_id]
if hasattr(thread_local_data, "frames"):
if thread_id in thread_local_data.frames:
del thread_local_data.frames[thread_id]
def __enter__(self):
self.store_thread_data()
self.logger.addFilter(self._transaction_filter)
params_json = json.dumps(self._params)
......@@ -191,6 +219,7 @@ class Transaction:
)
self.logger.removeFilter(self._transaction_filter)
self.clear_thread_data()
if exc_type:
raise
......
......@@ -155,7 +155,7 @@ class TestTransactionLogging:
return
assert 0, f"Could not get a log message with `Inner Log` and `transaction_id` in {records}"
def test_log_override_enter_exit(self, recording_logger):
def test_log_override_enter_exit_passed_logger(self, recording_logger):
parameters = {}
with transaction("name", parameters, logger=recording_logger) as transaction_id:
recording_logger.info("Inner Log")
......@@ -164,13 +164,13 @@ class TestTransactionLogging:
last_log_message, _ = get_last_record_and_log_message(recording_logger)
assert "Enter" in second_log_record.message
assert second_log_record.funcName == "test_log_override_enter_exit"
assert second_log_record.funcName == "test_log_override_enter_exit_passed_logger"
assert second_log_record.filename == "test_transactions.py"
assert "Exit" in last_log_message.message
assert last_log_message.funcName == "test_log_override_enter_exit"
assert last_log_message.funcName == "test_log_override_enter_exit_passed_logger"
assert last_log_message.filename == "test_transactions.py"
def test_log_override_enter_exit(self, recording_logger):
def test_log_override_enter_exit_no_logger(self, recording_logger):
parameters = {}
with transaction("name", parameters) as transaction_id:
recording_logger.info("Inner Log")
......@@ -180,15 +180,15 @@ class TestTransactionLogging:
last_log_message, _ = get_last_record_and_log_message(recording_logger)
assert "Generated" in first_log_record.message
assert first_log_record.funcName != "test_log_override_enter_exit"
assert first_log_record.funcName != "test_log_override_enter_exit_no_logger"
assert first_log_record.filename != "test_transactions.py"
assert "Enter" in second_log_record.message
assert second_log_record.funcName == "test_log_override_enter_exit"
assert second_log_record.funcName == "test_log_override_enter_exit_no_logger"
assert second_log_record.filename == "test_transactions.py"
assert "Exit" in last_log_message.message
assert last_log_message.funcName == "test_log_override_enter_exit"
assert last_log_message.funcName == "test_log_override_enter_exit_no_logger"
assert last_log_message.filename == "test_transactions.py"
def test_specified_logger(self):
......
......@@ -2,6 +2,7 @@
import pytest
from threading import Thread
from collections import Counter
from ska.logging import transaction
from tests.conftest import get_all_record_logs, clear_logger_logs
......@@ -11,6 +12,7 @@ class ThreadingLogsGenerator:
"""Generate logs by spawning a number of threads and logging in them
Some uses the transaction context and some not.
"""
def __init__(self, logger=None, pass_logger=False):
self.logger = logger
self.pass_logger = pass_logger
......@@ -86,7 +88,6 @@ def threaded_logs_global_logger(ensures_tags_logger):
class TestThreadScenarios:
@pytest.mark.xfail
def test_logs_outside_transaction_has_no_transaction_ids(
self, threaded_logs_global_logger, threaded_logs_local_logger
):
......@@ -96,7 +97,6 @@ class TestThreadScenarios:
for log in outside_transaction_logs:
assert "transaction_id:txn" not in log, f"transaction_id should not be in log {log}"
@pytest.mark.xfail
def test_no_duplicate_transaction_ids(self, threaded_logs_local_logger):
all_logs = threaded_logs_local_logger
transaction_logs = [log for log in all_logs if "Transaction thread" in log]
......@@ -129,36 +129,53 @@ class TestThreadScenarios:
assert enter_exit_logs
assert len(enter_exit_logs) % 2 == 0
# Group enter exit by marker
markers = {}
transaction_id_marker = []
for log in enter_exit_logs:
marker = log[-6:-1]
assert marker.isdigit()
markers.setdefault(marker, []).append(log)
assert markers
for key, log_pair in markers.items():
assert len(log_pair) == 2
assert "Enter" in log_pair[0]
assert "Exit" in log_pair[1]
assert log_pair[0].endswith(f"[{key}]")
assert log_pair[1].endswith(f"[{key}]")
transaction_id_marker.append(
(get_marker_in_message(log), get_transaction_id_in_message(log))
)
# Group enter exit by (transaction_id, marker)
# Make sure there is only 2 of each
counter = dict(Counter(transaction_id_marker))
for items, count in counter.items():
assert count == 2, f"Found {count} of {items} instead of 2"
# Make sure there's a enter/exit for every exception
exception_logs = [log for log in all_logs if "RuntimeError" in log]
assert exception_logs
for log in exception_logs:
marker_start = log.index("marker[")
marker = log[marker_start + 7 : marker_start + 12]
assert marker.isdigit()
assert marker in markers, "An exception marker has no match with a start/end log"
markers[marker].append(log)
# There is an equal number test transactions that has exceptions to those that do not
count_exceptions = 0
count_no_exceptions = 0
for key, logs in markers.items():
if len(logs) == 2:
count_no_exceptions += 1
if len(logs) == 3:
count_exceptions += 1
assert count_no_exceptions != 0
assert count_no_exceptions == count_exceptions
assert (
get_marker_in_message(log),
get_transaction_id_in_message(log),
) in transaction_id_marker
# Make sure all the transaction ids in the tags match that in the message
for log in enter_exit_logs + exception_logs:
tag_id = get_transaction_id_in_tag(log)
message_id = get_transaction_id_in_message(log)
assert tag_id
assert tag_id == message_id
def get_transaction_id_in_tag(log_message):
tags = log_message.split("|")[-2]
if tags:
tags_list = tags.split(",")
for tag in tags_list:
if "transaction_id" in tag:
return tag.split(":")[1]
return None
def get_transaction_id_in_message(log_message):
if "Transaction[" in log_message:
transaction_index = log_message.index("Transaction[")
return log_message[transaction_index + 12 : transaction_index + 40]
return None
def get_marker_in_message(log_message):
if "marker[" in log_message:
marker_index = log_message.index("marker[")
return log_message[marker_index + 7 : marker_index + 12]
return None
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment