From 0677534ee50182bc3c5857230891b5bf2d21baf7 Mon Sep 17 00:00:00 2001
From: Georgy Moshkin <gmoshkin@picodata.io>
Date: Thu, 1 Feb 2024 13:40:21 +0300
Subject: [PATCH] test: used to crash on non utf-8 output

---
 test/conftest.py            | 18 ++++++++++++------
 test/int/test_shutdown.py   |  4 ++--
 test/manual/test_scaling.py |  4 ++--
 3 files changed, 16 insertions(+), 10 deletions(-)

diff --git a/test/conftest.py b/test/conftest.py
index b9f68f48f3..bac5c48520 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -478,7 +478,7 @@ class Instance:
     env: dict[str, str] = field(default_factory=dict)
     process: subprocess.Popen | None = None
     raft_id: int = INVALID_RAFT_ID
-    _on_output_callbacks: list[Callable[[str], None]] = field(default_factory=list)
+    _on_output_callbacks: list[Callable[[bytes], None]] = field(default_factory=list)
 
     @property
     def listen(self):
@@ -662,22 +662,28 @@ class Instance:
         finally:
             self.kill()
 
-    def _process_output(self, src, out):
+    def _process_output(self, src, out: io.TextIOWrapper):
         id = self.instance_id or f":{self.port}"
         prefix = f"{id:<3} | "
 
         if sys.stdout.isatty():
             prefix = self.color(prefix)
 
-        for line in io.TextIOWrapper(src, line_buffering=True):
+        prefix_bytes = prefix.encode("utf-8")
+
+        # `iter(callable, sentinel)` form: calls callable until it returns sentinel
+        for line in iter(src.readline, b""):
             with OUT_LOCK:
-                out.write(prefix)
-                out.write(line)
+                out.buffer.write(prefix_bytes)
+                out.buffer.write(line)
                 out.flush()
                 for cb in self._on_output_callbacks:
                     cb(line)
 
-    def on_output_line(self, cb: Callable[[str], None]):
+        # Close the stream, because `Instance.fail_to_start` is waiting for it
+        src.close()
+
+    def on_output_line(self, cb: Callable[[bytes], None]):
         self._on_output_callbacks.append(cb)
 
     def start(self, peers=[]):
diff --git a/test/int/test_shutdown.py b/test/int/test_shutdown.py
index d2071feef2..3794941d5c 100644
--- a/test/int/test_shutdown.py
+++ b/test/int/test_shutdown.py
@@ -21,10 +21,10 @@ def cluster3(cluster: Cluster):
 class log_crawler:
     def __init__(self, instance: Instance, search_str: str) -> None:
         self.matched = False
-        self.search_str = search_str
+        self.search_str = search_str.encode("utf-8")
         instance.on_output_line(self._cb)
 
-    def _cb(self, line):
+    def _cb(self, line: bytes):
         if self.search_str in line:
             self.matched = True
 
diff --git a/test/manual/test_scaling.py b/test/manual/test_scaling.py
index d90ea40b9e..b6d762f327 100644
--- a/test/manual/test_scaling.py
+++ b/test/manual/test_scaling.py
@@ -90,9 +90,9 @@ def test_cas_conflicts(binary_path, tmpdir_factory, cluster_ids, port_range):
         start = time.time()
         cluster.deploy(instance_count=1)
 
-        def count_conflicts(line: str):
+        def count_conflicts(line: bytes):
             nonlocal cas_conflicts
-            if line.find("compare-and-swap: ConflictFound") != -1:
+            if line.find(b"compare-and-swap: ConflictFound") != -1:
                 cas_conflicts += 1
 
         cluster[0].on_output_line(count_conflicts)
-- 
GitLab