Skip to content

Commit 10ddeb6

Browse files
fix(tools): preserve full type information in tool signatures (#865)
* fix(tools): preserve full type information in tool signatures
  - Add support for typing.Union (not just types.UnionType)
  - Handle generic types with arguments (list[int], dict[str, int], etc.)
  - Special-case NoneType so it displays as 'None'
  - Add comprehensive test cases for generic and union types
  Fixes #349
* fix(tests): use modern union syntax for type annotations
  Use list[int] | None instead of Optional[list[int]] to satisfy ruff rule UP007.
Co-authored-by: Bob <bob@superuserlabs.org>
1 parent 58cce70 commit 10ddeb6

File tree

2 files changed (+48 additions, -6 deletions)

gptme/tools/base.py

Lines changed: 25 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,9 @@
1717
Literal,
1818
Protocol,
1919
TypeAlias,
20+
Union,
2021
cast,
22+
get_args,
2123
get_origin,
2224
)
2325

@@ -129,14 +131,31 @@ class Parameter:
129131

130132
# TODO: there must be a better way?
131133
def derive_type(t) -> str:
132-
if get_origin(t) == Literal:
133-
v = ", ".join(f'"{a}"' for a in t.__args__)
134+
origin = get_origin(t)
135+
136+
# Handle Literal types
137+
if origin == Literal:
138+
v = ", ".join(f'"{a}"' for a in get_args(t))
134139
return f"Literal[{v}]"
135-
elif get_origin(t) == types.UnionType:
136-
v = ", ".join(derive_type(a) for a in t.__args__)
140+
141+
# Handle Union types (both typing.Union and types.UnionType)
142+
if origin == Union or origin == types.UnionType:
143+
v = ", ".join(derive_type(a) for a in get_args(t))
137144
return f"Union[{v}]"
138-
else:
139-
return t.__name__
145+
146+
# Handle other generic types (list[int], dict[str, int], etc.)
147+
if origin is not None:
148+
args = get_args(t)
149+
if args:
150+
type_args = ", ".join(derive_type(arg) for arg in args)
151+
return f"{origin.__name__}[{type_args}]"
152+
153+
# Special case for NoneType
154+
if t is type(None):
155+
return "None"
156+
157+
# Fallback to type name
158+
return t.__name__
140159

141160

142161
def callable_signature(func: Callable) -> str:

tests/test_tools_python.py

Lines changed: 23 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -44,3 +44,26 @@ def h(a: TestType) -> str:
4444
return str(a)
4545

4646
assert callable_signature(h) == 'h(a: Literal["a", "b"]) -> str'
47+
48+
# Test generic types
# A parameterized builtin such as list[int] should keep its type argument in the signature
49+
50+
def i(a: list[int]) -> str:
51+
return str(a)
52+
53+
assert callable_signature(i) == "i(a: list[int]) -> str"
54+
55+
# An `X | None` annotation renders in Union[...] form, with NoneType shown as "None"
def j(a: list[int] | None) -> str:
56+
return str(a)
57+
58+
assert callable_signature(j) == "j(a: Union[list[int], None]) -> str"
59+
60+
# Multi-parameter generics render all of their type arguments, in order
def k(a: dict[str, int]) -> str:
61+
return str(a)
62+
63+
assert callable_signature(k) == "k(a: dict[str, int]) -> str"
64+
65+
# Test union types with | syntax
# PEP 604 unions are displayed in Union[...] form, not with `|`
66+
def m(a: int | str) -> str:
67+
return str(a)
68+
69+
assert callable_signature(m) == "m(a: Union[int, str]) -> str"

0 commit comments

Comments (0)