4 changes: 1 addition & 3 deletions xinference/api/restful_api.py
@@ -1682,9 +1682,7 @@ async def create_chat_completion(self, request: Request) -> Response:
 
         model_family = desc.get("model_family", "")
         function_call_models = (
-            ["chatglm3", "gorilla-openfunctions-v1"]
-            + QWEN_TOOL_CALL_FAMILY
-            + GLM4_TOOL_CALL_FAMILY
+            ["gorilla-openfunctions-v1"] + QWEN_TOOL_CALL_FAMILY + GLM4_TOOL_CALL_FAMILY
         )
 
         is_qwen = desc.get("model_format") == "ggmlv3" and "qwen-chat" == model_family
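For orientation, a minimal sketch of the gate this hunk leaves behind. Only the list shape and the constant names come from the diff; the concrete family values below are illustrative assumptions.

    # Sketch: which families pass the server's tool-call gate after this PR.
    # The *_TOOL_CALL_FAMILY contents are assumed for illustration.
    QWEN_TOOL_CALL_FAMILY = ["qwen-chat", "qwen1.5-chat"]
    GLM4_TOOL_CALL_FAMILY = ["glm4-chat", "glm4-chat-1m"]

    def supports_tool_calls(model_family: str) -> bool:
        function_call_models = (
            ["gorilla-openfunctions-v1"] + QWEN_TOOL_CALL_FAMILY + GLM4_TOOL_CALL_FAMILY
        )
        return model_family in function_call_models

    assert not supports_tool_calls("chatglm3")  # dropped by this PR
    assert supports_tool_calls("glm4-chat")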
48 changes: 0 additions & 48 deletions xinference/core/tests/test_restful_api.py
@@ -465,54 +465,6 @@ def test_restful_api_for_tool_calls(setup, model_format, quantization):
     response_data = response.json()
     assert len(response_data["data"]) == 1
 
-    # glm4-chat fail response: "OK, please tell me: would you like the temperature in Celsius or Fahrenheit?"
-    if "glm4" not in model_name:
-        tools = [
-            {
-                "type": "function",
-                "function": {
-                    "name": "get_current_weather",
-                    "description": "获取当前天气",
-                    "parameters": {
-                        "type": "object",
-                        "properties": {
-                            "location": {"type": "string", "description": "城市,例如北京"},
-                            "format": {
-                                "type": "string",
-                                "enum": ["celsius", "fahrenheit"],
-                                "description": "使用的温度单位。从所在的城市进行推断。",
-                            },
-                        },
-                        "required": ["location", "format"],
-                    },
-                },
-            }
-        ]
-
-        url = f"{endpoint}/v1/chat/completions"
-        payload = {
-            "model": model_uid_res,
-            "messages": [
-                {"role": "system", "content": "你是一个有用的助手。不要对要函数调用的值做出假设。"},
-                {"role": "user", "content": "上海现在的天气怎么样?"},
-            ],
-            "temperature": 0.7,
-            "tools": tools,
-            "stop": ["\n"],
-        }
-        response = requests.post(url, json=payload)
-        completion = response.json()
-
-        assert (
-            "get_current_weather"
-            == completion["choices"][0]["message"]["tool_calls"][0]["function"]["name"]
-        ), completion
-        arguments = completion["choices"][0]["message"]["tool_calls"][0]["function"][
-            "arguments"
-        ]
-        arg = json.loads(arguments)
-        assert arg == {"location": "上海", "format": "celsius"}
-
     # tool
     tools = [
         {
4 changes: 2 additions & 2 deletions xinference/model/llm/llm_family.json
@@ -819,7 +819,7 @@
                 "none"
             ],
             "model_id": "THUDM/glm-4-9b-chat",
-            "model_revision": "76f3474a854145aa4a9ed2612fee9bc8d4a8966b"
+            "model_revision": "aae8bd74af5c6dff63a49d7fbdcc89349ebf87aa"
         },
         {
             "model_format": "ggufv2",
@@ -890,7 +890,7 @@
                 "none"
             ],
             "model_id": "THUDM/glm-4-9b-chat-1m",
-            "model_revision": "715ddbe91082f976ff6a4ca06d59e5bbff6c3642"
+            "model_revision": "0aa722c7e0745dd21453427dd44c257dd253304f"
         },
         {
            "model_format": "ggufv2",
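The model_revision bumps pin the exact Hugging Face snapshots that get downloaded. As a hedged illustration of what such a pin resolves to, using the real huggingface_hub.snapshot_download API with the revision copied from the diff:

    from huggingface_hub import snapshot_download

    # Fetch exactly the snapshot this PR pins for glm-4-9b-chat.
    local_dir = snapshot_download(
        repo_id="THUDM/glm-4-9b-chat",
        revision="aae8bd74af5c6dff63a49d7fbdcc89349ebf87aa",
    )
    print(local_dir)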
19 changes: 6 additions & 13 deletions xinference/model/llm/pytorch/chatglm.py
@@ -344,7 +344,7 @@ def _get_generate_args(
         return kwargs, tools
 
     @torch.inference_mode()
-    def stream_chat(
+    def _stream_chat(
         self,
         tokenizer,
         query: str,
@@ -399,7 +399,7 @@ def stream_chat(
             yield new_response, new_history
 
     @torch.inference_mode()
-    def non_stream_chat(
+    def _non_stream_chat(
         self,
         tokenizer,
         query: str,
@@ -475,10 +475,6 @@ def chat(
         if stream and (
             not tools or self.model_family.model_name in GLM4_TOOL_CALL_FAMILY
         ):
-            if self.model_family.model_name in GLM4_TOOL_CALL_FAMILY:
-                stream_chat = self.stream_chat
-            else:
-                stream_chat = self._model.stream_chat
 
             def _stream_generator():
                 last_chunk_text_length = 0
@@ -487,7 +483,7 @@ def _stream_generator():
                 inputs = self._tokenizer([prompt], return_tensors="pt")
                 inputs = inputs.to(self._model.device)
                 prompt_tokens = len(inputs["input_ids"][0])
-                for chunk_text, _ in stream_chat(
+                for chunk_text, _ in self._stream_chat(
                     self._tokenizer, prompt, chat_history, **kwargs
                 ):
                     if tools and isinstance(chunk_text, dict):
@@ -548,12 +544,9 @@ def _stream_generator():
 
             return self._to_chat_completion_chunks(_stream_generator())
         else:
-            if self.model_family.model_name in GLM4_TOOL_CALL_FAMILY:
-                chat = self.non_stream_chat
-            else:
-                chat = self._model.chat
-
-            response = chat(self._tokenizer, prompt, chat_history, **kwargs)
+            response = self._non_stream_chat(
+                self._tokenizer, prompt, chat_history, **kwargs
+            )
             if tools:
                 return self._tool_calls_completion(
                     self.model_family, self.model_uid, response, tools
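Net effect of this file: chat() no longer picks between the class's own helpers and the underlying transformers model's stream_chat/chat; both modes always go through self._stream_chat or self._non_stream_chat. A self-contained toy of the simplified dispatch, where the class body is a stand-in and only the names and the stream/tools condition mirror the diff:

    GLM4_TOOL_CALL_FAMILY = ["glm4-chat", "glm4-chat-1m"]  # assumed values

    class ToyChatGLM:
        model_name = "glm4-chat"

        def _stream_chat(self, prompt):
            yield from ("chunk-1", "chunk-2")  # stand-in for token chunks

        def _non_stream_chat(self, prompt):
            return "full response"  # stand-in for a completion

        def chat(self, prompt, tools=None, stream=False):
            # One helper per mode; no fallback to self._model.* anymore.
            if stream and (not tools or self.model_name in GLM4_TOOL_CALL_FAMILY):
                return self._stream_chat(prompt)
            return self._non_stream_chat(prompt)

    print(list(ToyChatGLM().chat("hi", stream=True)))  # ['chunk-1', 'chunk-2']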
2 changes: 1 addition & 1 deletion xinference/model/llm/utils.py
@@ -706,7 +706,7 @@ def _eval_tool_arguments(cls, model_family, c, tools):
         family = model_family.model_family or model_family.model_name
         if family in ["gorilla-openfunctions-v1", "gorilla-openfunctions-v2"]:
             content, func, args = cls._eval_gorilla_openfunctions_arguments(c, tools)
-        elif family in ["chatglm3"] + GLM4_TOOL_CALL_FAMILY:
+        elif family in GLM4_TOOL_CALL_FAMILY:
             content, func, args = cls._eval_glm_chat_arguments(c, tools)
         elif family in QWEN_TOOL_CALL_FAMILY:
             content, func, args = cls._eval_qwen_chat_arguments(c, tools)
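With chatglm3 gone, _eval_glm_chat_arguments only ever receives GLM4-family output, where a tool call surfaces as a structured dict rather than plain text (the isinstance(chunk_text, dict) check in chatglm.py above relies on the same distinction). A hedged sketch of that split; the "name"/"parameters" keys are assumptions for illustration:

    def eval_glm_chat_arguments(c):
        # Structured dict -> tool call; anything else -> ordinary content.
        if isinstance(c, dict):
            return None, c.get("name"), c.get("parameters")
        return c, None, None

    print(eval_glm_chat_arguments("No tools needed."))
    print(eval_glm_chat_arguments(
        {"name": "get_current_weather", "parameters": {"location": "上海"}}
    ))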