Skip to content

Commit 3e38ec4

Browse files
committed
Add ability to configure status spinner (#59)
* Add ability to configure status spinner
1 parent cf0db35 commit 3e38ec4

File tree

3 files changed

+71
-14
lines changed

3 files changed

+71
-14
lines changed

src/shelloracle/config.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
from __future__ import annotations
22

3+
import logging
34
import os
45
import sys
56
from collections.abc import Mapping, Iterator
67
from pathlib import Path
78
from typing import Any
9+
from yaspin.spinners import SPINNERS_DATA
810

911
if sys.version_info < (3, 11):
1012
import tomli as tomllib
1113
else:
1214
import tomllib
1315

16+
logger = logging.getLogger(__name__)
1417
shelloracle_home = Path.home() / ".shelloracle"
1518

1619

@@ -42,6 +45,16 @@ def __iter__(self) -> Iterator[Any]:
4245
def provider(self) -> str:
4346
return self["shelloracle"]["provider"]
4447

48+
@property
49+
def spinner_style(self) -> str | None:
50+
style = self["shelloracle"].get("spinner_style", None)
51+
if not style:
52+
return None
53+
if style not in SPINNERS_DATA:
54+
logger.warning("invalid spinner style: %s", style)
55+
return None
56+
return style
57+
4558

4659
_config: Configuration | None = None
4760

src/shelloracle/shelloracle.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,22 @@
55
import os
66
import sys
77
from pathlib import Path
8+
from typing import TYPE_CHECKING
89

910
from prompt_toolkit import PromptSession, print_formatted_text
1011
from prompt_toolkit.application import create_app_session_from_tty
1112
from prompt_toolkit.formatted_text import FormattedText
1213
from prompt_toolkit.history import FileHistory
1314
from prompt_toolkit.patch_stdout import patch_stdout
1415
from yaspin import yaspin
16+
from yaspin.spinners import Spinners
1517

1618
from .config import get_config
1719
from .providers import get_provider
1820

21+
if TYPE_CHECKING:
22+
from yaspin.core import Yaspin
23+
1924
logger = logging.getLogger(__name__)
2025

2126

@@ -44,6 +49,18 @@ def get_query_from_pipe() -> str | None:
4449
return lines[0].rstrip()
4550

4651

52+
def spinner() -> Yaspin:
53+
"""Get the correct spinner based on the user's configuration
54+
55+
:returns: yaspin object
56+
"""
57+
config = get_config()
58+
if not config.spinner_style:
59+
return yaspin()
60+
style = getattr(Spinners, config.spinner_style)
61+
return yaspin(style)
62+
63+
4764
async def shelloracle() -> None:
4865
"""ShellOracle program entrypoint
4966
@@ -65,7 +82,7 @@ async def shelloracle() -> None:
6582
logger.info("user prompt: %s", prompt)
6683

6784
shell_command = ""
68-
with create_app_session_from_tty(), patch_stdout(raw=True), yaspin() as sp:
85+
with create_app_session_from_tty(), patch_stdout(raw=True), spinner() as sp:
6986
async for token in provider.generate(prompt):
7087
# some models may erroneously return a newline, which causes issues with the status spinner
7188
token = token.replace("\n", "")

tests/test_shelloracle.py

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,53 @@
1+
from __future__ import annotations
2+
13
import os
24
import sys
5+
from unittest.mock import MagicMock, call
36

47
import pytest
8+
from yaspin.spinners import Spinners
59

6-
from shelloracle.shelloracle import get_query_from_pipe
10+
from shelloracle.shelloracle import get_query_from_pipe, spinner
711

812

9-
def test_get_query_from_pipe(monkeypatch):
10-
# Is a TTY
11-
monkeypatch.setattr(os, "isatty", lambda _: True)
12-
assert get_query_from_pipe() is None
13+
@pytest.fixture
14+
def mock_yaspin(monkeypatch):
15+
mock = MagicMock()
16+
monkeypatch.setattr("shelloracle.shelloracle.yaspin", mock)
17+
return mock
1318

14-
# Not a TTY and no lines in the pipe
15-
monkeypatch.setattr(os, "isatty", lambda _: False)
16-
monkeypatch.setattr(sys.stdin, "readlines", lambda: [])
17-
assert get_query_from_pipe() is None
1819

19-
# Not TTY and one line in the pipe
20-
monkeypatch.setattr(sys.stdin, "readlines", lambda: ["what is up"])
21-
assert get_query_from_pipe() == "what is up"
20+
@pytest.fixture
21+
def mock_config(monkeypatch):
22+
config = MagicMock()
23+
monkeypatch.setattr("shelloracle.config._config", config)
24+
return config
25+
26+
27+
@pytest.mark.parametrize("spinner_style,expected", [(None, call()), ("earth", call(Spinners.earth))])
28+
def test_spinner(spinner_style, expected, mock_config, mock_yaspin):
29+
mock_config.spinner_style = spinner_style
30+
spinner()
31+
assert mock_yaspin.call_args == expected
32+
2233

23-
# Not a TTY and multiple lines in the pipe
34+
def test_spinner_fail(mock_yaspin, mock_config):
35+
mock_config.spinner_style = "not a spinner style"
36+
with pytest.raises(AttributeError):
37+
spinner()
38+
39+
40+
@pytest.mark.parametrize("isatty,readlines,expected", [
41+
(True, None, None), (False, [], None), (False, ["what is up"], "what is up")
42+
])
43+
def test_get_query_from_pipe(isatty, readlines, expected, monkeypatch):
44+
monkeypatch.setattr(os, "isatty", lambda _: isatty)
45+
monkeypatch.setattr(sys.stdin, "readlines", lambda: readlines)
46+
assert get_query_from_pipe() == expected
47+
48+
49+
def test_get_query_from_pipe_fail(monkeypatch):
50+
monkeypatch.setattr(os, "isatty", lambda _: False)
2451
monkeypatch.setattr(sys.stdin, "readlines", lambda: ["what is up", "what is down"])
2552
with pytest.raises(ValueError):
2653
get_query_from_pipe()

0 commit comments

Comments
 (0)