From 0677534ee50182bc3c5857230891b5bf2d21baf7 Mon Sep 17 00:00:00 2001 From: Georgy Moshkin <gmoshkin@picodata.io> Date: Thu, 1 Feb 2024 13:40:21 +0300 Subject: [PATCH] test: used to crash on non utf-8 output --- test/conftest.py | 18 ++++++++++++------ test/int/test_shutdown.py | 4 ++-- test/manual/test_scaling.py | 4 ++-- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/test/conftest.py b/test/conftest.py index b9f68f48f3..bac5c48520 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -478,7 +478,7 @@ class Instance: env: dict[str, str] = field(default_factory=dict) process: subprocess.Popen | None = None raft_id: int = INVALID_RAFT_ID - _on_output_callbacks: list[Callable[[str], None]] = field(default_factory=list) + _on_output_callbacks: list[Callable[[bytes], None]] = field(default_factory=list) @property def listen(self): @@ -662,22 +662,28 @@ class Instance: finally: self.kill() - def _process_output(self, src, out): + def _process_output(self, src, out: io.TextIOWrapper): id = self.instance_id or f":{self.port}" prefix = f"{id:<3} | " if sys.stdout.isatty(): prefix = self.color(prefix) - for line in io.TextIOWrapper(src, line_buffering=True): + prefix_bytes = prefix.encode("utf-8") + + # `iter(callable, sentinel)` form: calls callable until it returns sentinel + for line in iter(src.readline, b""): with OUT_LOCK: - out.write(prefix) - out.write(line) + out.buffer.write(prefix_bytes) + out.buffer.write(line) out.flush() for cb in self._on_output_callbacks: cb(line) - def on_output_line(self, cb: Callable[[str], None]): + # Close the stream, because `Instance.fail_to_start` is waiting for it + src.close() + + def on_output_line(self, cb: Callable[[bytes], None]): self._on_output_callbacks.append(cb) def start(self, peers=[]): diff --git a/test/int/test_shutdown.py b/test/int/test_shutdown.py index d2071feef2..3794941d5c 100644 --- a/test/int/test_shutdown.py +++ b/test/int/test_shutdown.py @@ -21,10 +21,10 @@ def cluster3(cluster: Cluster): class log_crawler: def __init__(self, instance: Instance, search_str: str) -> None: self.matched = False - self.search_str = search_str + self.search_str = search_str.encode("utf-8") instance.on_output_line(self._cb) - def _cb(self, line): + def _cb(self, line: bytes): if self.search_str in line: self.matched = True diff --git a/test/manual/test_scaling.py b/test/manual/test_scaling.py index d90ea40b9e..b6d762f327 100644 --- a/test/manual/test_scaling.py +++ b/test/manual/test_scaling.py @@ -90,9 +90,9 @@ def test_cas_conflicts(binary_path, tmpdir_factory, cluster_ids, port_range): start = time.time() cluster.deploy(instance_count=1) - def count_conflicts(line: str): + def count_conflicts(line: bytes): nonlocal cas_conflicts - if line.find("compare-and-swap: ConflictFound") != -1: + if line.find(b"compare-and-swap: ConflictFound") != -1: cas_conflicts += 1 cluster[0].on_output_line(count_conflicts) -- GitLab