Skip to content

Commit ac342fa

Browse files
author
Olivier Chafik
committed
agent: mypy type fixes
mypy examples/agent/__main__.py mypy examples/agent/fastify.py mypy examples/openai/__main__.py
1 parent eced057 commit ac342fa

File tree

11 files changed

+109
-98
lines changed

11 files changed

+109
-98
lines changed

examples/agent/agent.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from time import sleep
44
import typer
55
from pydantic import BaseModel, Json, TypeAdapter
6-
from typing import Annotated, Callable, List, Union, Optional, Type
6+
from typing import Annotated, Any, Callable, Dict, List, Union, Optional, Type
77
import json, requests
88

99
from examples.agent.openapi_client import OpenAPIMethod, openapi_methods_from_endpoint
@@ -13,7 +13,7 @@
1313
from examples.openai.prompting import ToolsPromptStyle
1414
from examples.openai.subprocesses import spawn_subprocess
1515

16-
def _get_params_schema(fn: Callable, verbose):
16+
def _get_params_schema(fn: Callable[[Any], Any], verbose):
1717
if isinstance(fn, OpenAPIMethod):
1818
return fn.parameters_schema
1919

@@ -26,9 +26,9 @@ def _get_params_schema(fn: Callable, verbose):
2626

2727
def completion_with_tool_usage(
2828
*,
29-
response_model: Optional[Union[Json, Type]]=None,
29+
response_model: Optional[Union[Json[Any], type]]=None,
3030
max_iterations: Optional[int]=None,
31-
tools: List[Callable],
31+
tools: List[Callable[..., Any]],
3232
endpoint: str,
3333
messages: List[Message],
3434
auth: Optional[str],
@@ -56,7 +56,7 @@ def completion_with_tool_usage(
5656
type="function",
5757
function=ToolFunction(
5858
name=fn.__name__,
59-
description=fn.__doc__,
59+
description=fn.__doc__ or '',
6060
parameters=_get_params_schema(fn, verbose=verbose)
6161
)
6262
)
@@ -128,15 +128,15 @@ def completion_with_tool_usage(
128128
def main(
129129
goal: Annotated[str, typer.Option()],
130130
tools: Optional[List[str]] = None,
131-
format: Annotated[str, typer.Option(help="The output format: either a Python type (e.g. 'float' or a Pydantic model defined in one of the tool files), or a JSON schema, e.g. '{\"format\": \"date\"}'")] = None,
131+
format: Annotated[Optional[str], typer.Option(help="The output format: either a Python type (e.g. 'float' or a Pydantic model defined in one of the tool files), or a JSON schema, e.g. '{\"format\": \"date\"}'")] = None,
132132
max_iterations: Optional[int] = 10,
133133
std_tools: Optional[bool] = False,
134134
auth: Optional[str] = None,
135135
parallel_calls: Optional[bool] = False,
136136
verbose: bool = False,
137137
style: Optional[ToolsPromptStyle] = None,
138138

139-
model: Annotated[Optional[Path], typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf",
139+
model: Annotated[str, typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf",
140140
endpoint: Optional[str] = None,
141141
context_length: Optional[int] = None,
142142
# endpoint: str = 'http://localhost:8080/v1/chat/completions',
@@ -187,8 +187,8 @@ def main(
187187
sleep(5)
188188

189189
tool_functions = []
190-
types = {}
191-
for f in tools:
190+
types: Dict[str, type] = {}
191+
for f in (tools or []):
192192
if f.startswith('http://') or f.startswith('https://'):
193193
tool_functions.extend(openapi_methods_from_endpoint(f))
194194
else:
@@ -203,7 +203,7 @@ def main(
203203
if std_tools:
204204
tool_functions.extend(collect_functions(StandardTools))
205205

206-
response_model = None #str
206+
response_model: Union[type, Json[Any]] = None #str
207207
if format:
208208
if format in types:
209209
response_model = types[format]
@@ -246,10 +246,7 @@ def main(
246246
seed=seed,
247247
n_probs=n_probs,
248248
min_keep=min_keep,
249-
messages=[{
250-
"role": "user",
251-
"content": goal,
252-
}]
249+
messages=[Message(role="user", content=goal)],
253250
)
254251
print(result if response_model else f'➡️ {result}')
255252
# exit(0)

examples/agent/fastify.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def bind_functions(app, module):
1717
if k == k.capitalize():
1818
continue
1919
v = getattr(module, k)
20-
if not callable(v) or isinstance(v, Type):
20+
if not callable(v) or isinstance(v, type):
2121
continue
2222
if not hasattr(v, '__annotations__'):
2323
continue

examples/agent/openapi_client.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,43 +17,44 @@ def __init__(self, url, name, descriptor, catalog):
1717
request_body = post_descriptor.get('requestBody')
1818

1919
self.parameters = {p['name']: p for p in parameters}
20-
assert all(param['in'] == 'query' for param in self.parameters.values()), f'Only query path parameters are supported (path: {path}, descriptor: {json.dumps(descriptor)})'
20+
assert all(param['in'] == 'query' for param in self.parameters.values()), f'Only query path parameters are supported (path: {url}, descriptor: {json.dumps(descriptor)})'
2121

2222
self.body = None
23-
self.body_name = None
2423
if request_body:
25-
assert 'application/json' in request_body['content'], f'Only application/json is supported for request body (path: {path}, descriptor: {json.dumps(descriptor)})'
24+
assert 'application/json' in request_body['content'], f'Only application/json is supported for request body (path: {url}, descriptor: {json.dumps(descriptor)})'
25+
26+
body_name = 'body'
27+
i = 2
28+
while body_name in self.parameters:
29+
body_name = f'body{i}'
30+
i += 1
31+
2632
self.body = dict(
33+
name=body_name,
2734
required=request_body['required'],
2835
schema=request_body['content']['application/json']['schema'],
2936
)
3037

31-
self.body_name = 'body'
32-
i = 2
33-
while self.body_name in self.parameters:
34-
self.body_name = f'body{i}'
35-
i += 1
36-
3738
self.parameters_schema = dict(
3839
type='object',
3940
properties={
4041
**({
41-
self.body_name: self.body['schema']
42+
self.body['name']: self.body['schema']
4243
} if self.body else {}),
4344
**{
4445
name: param['schema']
4546
for name, param in self.parameters.items()
4647
}
4748
},
4849
components=catalog.get('components'),
49-
required=[name for name, param in self.parameters.items() if param['required']] + ([self.body_name] if self.body and self.body['required'] else [])
50+
required=[name for name, param in self.parameters.items() if param['required']] + ([self.body['name']] if self.body and self.body['required'] else [])
5051
)
5152

5253
def __call__(self, **kwargs):
5354
if self.body:
54-
body = kwargs.pop(self.body_name, None)
55+
body = kwargs.pop(self.body['name'], None)
5556
if self.body['required']:
56-
assert body is not None, f'Missing required body parameter: {self.body_name}'
57+
assert body is not None, f'Missing required body parameter: {self.body["name"]}'
5758
else:
5859
body = None
5960

examples/agent/tools/std_tools.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ class Duration(BaseModel):
1515
months: Optional[int] = None
1616
years: Optional[int] = None
1717

18+
def __str__(self) -> str:
19+
return f"{self.years} years, {self.months} months, {self.days} days, {self.hours} hours, {self.minutes} minutes, {self.seconds} seconds"
20+
1821
@property
1922
def get_total_seconds(self) -> int:
2023
return sum([
@@ -29,6 +32,10 @@ def get_total_seconds(self) -> int:
2932
class WaitForDuration(BaseModel):
3033
duration: Duration
3134

35+
def __call__(self):
36+
sys.stderr.write(f"Waiting for {self.duration}...\n")
37+
time.sleep(self.duration.get_total_seconds)
38+
3239
class WaitForDate(BaseModel):
3340
until: date
3441

@@ -43,7 +50,7 @@ def __call__(self):
4350

4451
days, seconds = time_diff.days, time_diff.seconds
4552

46-
sys.stderr.write(f"Waiting for {days} days and {seconds} seconds until {d}...\n")
53+
sys.stderr.write(f"Waiting for {days} days and {seconds} seconds until {self.until}...\n")
4754
time.sleep(days * 86400 + seconds)
4855
sys.stderr.write(f"Reached the target date: {self.until}\n")
4956

@@ -67,8 +74,8 @@ def wait(_for: Union[WaitForDuration, WaitForDate]) -> None:
6774
return _for()
6875

6976
@staticmethod
70-
def say_out_loud(something: str) -> str:
77+
def say_out_loud(something: str) -> None:
7178
"""
7279
Just says something. Used to say each thought out loud
7380
"""
74-
return subprocess.check_call(["say", something])
81+
subprocess.check_call(["say", something])

examples/agent/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@ def load_source_as_module(source):
99
i += 1
1010

1111
spec = importlib.util.spec_from_file_location(module_name, source)
12+
assert spec, f'Failed to load {source} as module'
1213
module = importlib.util.module_from_spec(spec)
1314
sys.modules[module_name] = module
15+
assert spec.loader, f'{source} spec has no loader'
1416
spec.loader.exec_module(module)
1517
return module
1618

@@ -29,7 +31,7 @@ def collect_functions(module):
2931
if k == k.capitalize():
3032
continue
3133
v = getattr(module, k)
32-
if not callable(v) or isinstance(v, Type):
34+
if not callable(v) or isinstance(v, type):
3335
continue
3436
if not hasattr(v, '__annotations__'):
3537
continue

examples/json_schema_to_grammar.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,9 @@ def opt_repetitions(up_to_n, prefix_with_sep=False):
5555

5656

5757
class BuiltinRule:
58-
def __init__(self, content: str, deps: list = None):
58+
def __init__(self, content: str, deps: List[str]):
5959
self.content = content
60-
self.deps = deps or []
60+
self.deps = deps
6161

6262
_up_to_15_digits = _build_repetition('[0-9]', 0, 15)
6363

@@ -118,7 +118,7 @@ def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern):
118118

119119
def _format_literal(self, literal):
120120
escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub(
121-
lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), literal
121+
lambda m: GRAMMAR_LITERAL_ESCAPES[m.group(0)], literal
122122
)
123123
return f'"{escaped}"'
124124

@@ -157,13 +157,13 @@ def _add_rule(self, name, rule):
157157
self._rules[key] = rule
158158
return key
159159

160-
def resolve_refs(self, schema: dict, url: str):
160+
def resolve_refs(self, schema: Any, url: str):
161161
'''
162162
Resolves all $ref fields in the given schema, fetching any remote schemas,
163163
replacing $ref with absolute reference URL and populating self._refs with the
164164
respective referenced (sub)schema dictionaries.
165165
'''
166-
def visit(n: dict):
166+
def visit(n: Any):
167167
if isinstance(n, list):
168168
return [visit(x) for x in n]
169169
elif isinstance(n, dict):
@@ -223,7 +223,7 @@ def _visit_pattern(self, pattern, name):
223223

224224
assert pattern.startswith('^') and pattern.endswith('$'), 'Pattern must start with "^" and end with "$"'
225225
pattern = pattern[1:-1]
226-
sub_rule_ids = {}
226+
sub_rule_ids: Dict[str, str] = {}
227227

228228
i = 0
229229
length = len(pattern)

examples/openai/api.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from abc import ABC
2-
from typing import Any, Dict, Literal, Optional, Union
2+
from typing import Any, Dict, List, Literal, Optional, Union
33
from pydantic import BaseModel, Json, TypeAdapter
44

55
class FunctionCall(BaseModel):
@@ -16,7 +16,7 @@ class Message(BaseModel):
1616
name: Optional[str] = None
1717
tool_call_id: Optional[str] = None
1818
content: Optional[str]
19-
tool_calls: Optional[list[ToolCall]] = None
19+
tool_calls: Optional[List[ToolCall]] = None
2020

2121
class ToolFunction(BaseModel):
2222
name: str
@@ -29,7 +29,7 @@ class Tool(BaseModel):
2929

3030
class ResponseFormat(BaseModel):
3131
type: Literal["json_object"]
32-
schema: Optional[Dict] = None
32+
schema: Optional[Json[Any]] = None # type: ignore
3333

3434
class LlamaCppParams(BaseModel):
3535
n_predict: Optional[int] = None
@@ -56,8 +56,8 @@ class LlamaCppParams(BaseModel):
5656

5757
class ChatCompletionRequest(LlamaCppParams):
5858
model: str
59-
tools: Optional[list[Tool]] = None
60-
messages: list[Message] = None
59+
tools: Optional[List[Tool]] = None
60+
messages: Optional[List[Message]] = None
6161
prompt: Optional[str] = None
6262
response_format: Optional[ResponseFormat] = None
6363

@@ -67,7 +67,7 @@ class ChatCompletionRequest(LlamaCppParams):
6767
class Choice(BaseModel):
6868
index: int
6969
message: Message
70-
logprobs: Optional[Json] = None
70+
logprobs: Optional[Json[Any]] = None
7171
finish_reason: Union[Literal["stop"], Literal["tool_calls"]]
7272

7373
class Usage(BaseModel):
@@ -84,7 +84,7 @@ class ChatCompletionResponse(BaseModel):
8484
object: Literal["chat.completion"]
8585
created: int
8686
model: str
87-
choices: list[Choice]
87+
choices: List[Choice]
8888
usage: Usage
8989
system_fingerprint: str
9090
error: Optional[CompletionError] = None

examples/openai/llama_cpp_server_api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional
1+
from typing import Any, Optional
22
from pydantic import Json
33

44
from examples.openai.api import LlamaCppParams
@@ -9,4 +9,4 @@ class LlamaCppServerCompletionRequest(LlamaCppParams):
99
cache_prompt: Optional[bool] = None
1010

1111
grammar: Optional[str] = None
12-
json_schema: Optional[Json] = None
12+
json_schema: Optional[Json[Any]] = None

0 commit comments

Comments
 (0)