
Commit 684bbb6

feat(message): add MessageMetadata TypedDict for token/cost tracking (#943)
* feat(message): add MessageMetadata TypedDict for token/cost tracking

  Per Erik's review on PR #939, adding Message.metadata field to store:
  - Token usage (input_tokens, output_tokens, cache_read_tokens, cache_creation_tokens)
  - Cost in USD
  - Model used for the response

  Design decisions:
  - TypedDict with total=False for type safety + optional fields
  - Only non-None metadata is serialized (compact JSONL storage)
  - Full support for JSON/JSONL and TOML roundtrips

  Part of the cost awareness feature. Follow-up to PR #939.

  🤖 Generated with [gptme](https://github.com/gptme/gptme)

* refactor(message): restructure MessageMetadata to use nested tokens format

  Address Erik's review: https://github.com/gptme/gptme/pull/943/files#r2603233028

  Changes:
  - Replace flat token fields with nested tokens structure
  - tokens.input can be dict (base, cache_read, cache_write) or int
  - tokens.output is int
  - Add TokensInput and Tokens TypedDicts
  - Add helper functions for TOML serialization of nested dicts
  - Update tests to use new structure with proper type narrowing

  New format:
  {
    "model": "claude-sonnet",
    "tokens": {
      "input": {"base": 100, "cache_read": 80},
      "output": 50
    },
    "cost": 0.005
  }

* refactor(message): use flat token format per review feedback

  Per Erik's feedback: #943 (comment)

  Changes:
  - Remove nested TokensInput and Tokens TypedDicts
  - Flatten MessageMetadata to use: input_tokens, output_tokens, cache_read_tokens, cache_creation_tokens (matches cost_tracker)
  - Simplify _format_toml_value (no longer needs dict handling)
  - Update tests for new flat format

  This aligns with cost_tracker.py and common industry conventions.

* feat(llm): integrate MessageMetadata into message generation

  Per Erik's review: integrate metadata into gptme/llm/ where messages are constructed/generated.

  Changes:
  - Modify _record_usage() in both providers to return MessageMetadata
  - Update chat() functions to return tuple[str, MessageMetadata | None]
  - Update stream() generators to return metadata via generator return value
  - Add _StreamWithMetadata wrapper to capture generator return values
  - Update _reply_stream() to attach metadata to returned Messages
  - Update _chat_complete() to handle tuple returns and propagate metadata
  - Fix callers in util/prompt.py and util/auto_naming.py

  The MessageMetadata now flows from provider usage tracking through to the final Message object, enabling token/cost tracking per message.

* fix: update callers of _chat_complete for tuple return type

  After integrating MessageMetadata into _chat_complete (returning tuple[str, MessageMetadata | None]), several callers were not updated:
  - gptme/tools/morph.py: Unpack tuple to get string for .strip()
  - gptme/hooks/form_autodetect.py: Unpack tuple for re.search/json.loads
  - gptme/server/api_v2_sessions.py: Handle non-streaming case by wrapping response in list for iteration compatibility

  All three files now correctly unpack the tuple return value.

  Co-authored-by: Bob <bob@superuserlabs.org>

* fix(test): update mock to match new _chat_complete tuple return format

  The test mock was returning the old format [["response"]] but _chat_complete now returns (str, MessageMetadata | None) tuple.
1 parent ed409cb commit 684bbb6
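For reference, a minimal sketch of the flat metadata shape the commit message describes, with field names taken from the commit message and the diffs below; the actual definition lives in gptme/message.py and may differ in detail:

from typing import TypedDict


class MessageMetadata(TypedDict, total=False):
    """Per-message token/cost metadata; total=False makes every key optional."""

    model: str
    input_tokens: int
    output_tokens: int
    cache_read_tokens: int
    cache_creation_tokens: int
    cost: float  # USD


# The kind of dict a provider could attach to an assistant message:
example: MessageMetadata = {
    "model": "claude-sonnet",
    "input_tokens": 100,
    "cache_read_tokens": 80,
    "output_tokens": 50,
    "cost": 0.005,
}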


11 files changed (+268, -50 lines)

gptme/hooks/form_autodetect.py

Lines changed: 1 addition & 1 deletion
@@ -123,7 +123,7 @@ def _parse_options_with_llm(content: str) -> dict | None:
             "user", PARSE_PROMPT.format(message=content[:2000])
         )  # Limit context
     ]
-    response = _chat_complete(messages, model=model, tools=None)
+    response, _metadata = _chat_complete(messages, model=model, tools=None)

     # Parse JSON from response
     import json

gptme/llm/__init__.py

Lines changed: 48 additions & 20 deletions
@@ -2,7 +2,7 @@
 import shutil
 import sys
 import time
-from collections.abc import Iterator
+from collections.abc import Generator, Iterator
 from functools import lru_cache
 from pathlib import Path
 from typing import cast
@@ -11,7 +11,7 @@

 from ..config import Config, get_config
 from ..constants import prompt_assistant
-from ..message import Message, format_msgs, len_tokens
+from ..message import Message, MessageMetadata, format_msgs, len_tokens
 from ..telemetry import trace_function
 from ..tools import ToolSpec, ToolUse
 from ..util import console
@@ -99,12 +99,12 @@ def reply(
         )
     else:
         rprint(f"{prompt_assistant(agent_name)}: Thinking...", end="\r")
-        response = _chat_complete(
+        response, metadata = _chat_complete(
             generation_msgs, model, tools, output_schema=output_schema
         )
         rprint(" " * shutil.get_terminal_size().columns, end="\r")
         rprint(f"{prompt_assistant(agent_name)}: {response}")
-        return Message("assistant", response)
+        return Message("assistant", response, metadata=metadata)


 def get_provider_from_model(model: str) -> Provider:
@@ -141,7 +141,7 @@ def _chat_complete(
     tools: list[ToolSpec] | None,
     output_schema: type | None = None,
     max_retries: int = 3,
-) -> str:
+) -> tuple[str, MessageMetadata | None]:
     from pydantic import BaseModel, ValidationError

     provider = get_provider_from_model(model)
@@ -156,14 +156,17 @@
     )

     # Validation-only fallback for unsupported providers
+    metadata: MessageMetadata | None = None
     if output_schema is not None:
         logger = logging.getLogger(__name__)
         for attempt in range(max_retries):
             # Generate without constraints
             if provider in PROVIDERS_OPENAI:
-                response = chat_openai(messages, model, tools)
+                response, metadata = chat_openai(messages, model, tools)
             elif provider == "anthropic":
-                response = chat_anthropic(messages, _get_base_model(model), tools)
+                response, metadata = chat_anthropic(
+                    messages, _get_base_model(model), tools
+                )
             else:
                 raise ValueError(f"Unsupported provider: {provider}")

@@ -173,7 +176,7 @@
                     output_schema, BaseModel
                 ):
                     output_schema.model_validate_json(response)
-                return response  # Validation succeeded
+                return response, metadata  # Validation succeeded
             except ValidationError as e:
                 if attempt < max_retries - 1:
                     # Add validation error to context for retry
@@ -191,7 +194,7 @@
                 logger.warning(
                     f"Failed to validate response after {max_retries} attempts: {e}"
                 )
-                return response
+                return response, metadata

     # No schema requested, generate normally
     if provider in PROVIDERS_OPENAI:
@@ -202,21 +205,44 @@
         raise ValueError(f"Unsupported provider: {provider}")


+class _StreamWithMetadata:
+    """Wrapper that captures a generator's return value (metadata)."""
+
+    def __init__(self, gen: Generator[str, None, MessageMetadata | None], model: str):
+        self.gen = gen
+        self.model = model
+        self.metadata: MessageMetadata | None = None
+
+    def __iter__(self) -> Iterator[str]:
+        try:
+            while True:
+                yield next(self.gen)
+        except StopIteration as e:
+            self.metadata = e.value
+            # Ensure model is set in metadata even if provider didn't include it
+            if self.metadata is None:
+                self.metadata = {"model": self.model}
+            elif "model" not in self.metadata:
+                self.metadata["model"] = self.model
+
+
 @trace_function(name="llm.stream", attributes={"component": "llm"})
 def _stream(
     messages: list[Message],
     model: str,
     tools: list[ToolSpec] | None,
     output_schema: type | None = None,
-) -> Iterator[str]:
+) -> _StreamWithMetadata:
     provider = get_provider_from_model(model)
     # Custom providers are OpenAI-compatible, so route them through the OpenAI path
     if provider in PROVIDERS_OPENAI or is_custom_provider(provider):
-        return stream_openai(messages, model, tools, output_schema=output_schema)
+        gen = stream_openai(messages, model, tools, output_schema=output_schema)
+        return _StreamWithMetadata(gen, model)
     elif provider == "anthropic":
-        return stream_anthropic(
+        gen = stream_anthropic(
             messages, _get_base_model(model), tools, output_schema=output_schema
         )
+        return _StreamWithMetadata(gen, model)
     else:
         # Note: Validation-only fallback for streaming is complex
         # For now, unsupported providers don't support output_schema in streaming mode
@@ -247,12 +273,12 @@ def print_clear(length: int = 0):
     start_time = time.time()
     first_token_time = None
     are_thinking = False
+
+    # Create stream wrapper to capture metadata
+    stream = _stream(messages, model, tools, output_schema=output_schema)
+
     try:
-        for char in (
-            char
-            for chunk in _stream(messages, model, tools, output_schema=output_schema)
-            for char in chunk
-        ):
+        for char in (char for chunk in stream for char in chunk):
             if not output:  # first character
                 first_token_time = time.time()
                 print_clear()
@@ -310,7 +336,9 @@ def print_clear(length: int = 0):
                 break

     except KeyboardInterrupt:
-        return Message("assistant", output + "... ^C Interrupted")
+        return Message(
+            "assistant", output + "... ^C Interrupted", metadata=stream.metadata
+        )
     finally:
         print_clear()
         if first_token_time:
@@ -322,7 +350,7 @@ def print_clear(length: int = 0):
                 f"tok/s: {len_tokens(output, model)/(end_time - first_token_time):.1f})"
             )

-    return Message("assistant", output)
+    return Message("assistant", output, metadata=stream.metadata)


 @trace_function(name="llm.summarize", attributes={"component": "llm"})
@@ -349,7 +377,7 @@ def _summarize_str(content: str) -> str:
             f"Cannot summarize more than {model.context} tokens, got {len_tokens(messages, model.model)}"
         )

-    summary = _chat_complete(messages, model.full, None)
+    summary, _metadata = _chat_complete(messages, model.full, None)
     assert summary
     logger.debug(
         f"Summarized long output ({len_tokens(content, model.model)} -> {len_tokens(summary, model.model)} tokens): "
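The _StreamWithMetadata wrapper above leans on a general Python property: a generator's return value is delivered as StopIteration.value once iteration is exhausted, and a plain for loop silently discards it. A standalone sketch of that pattern (names here are illustrative, not gptme APIs):

from collections.abc import Generator


def fake_stream() -> Generator[str, None, dict | None]:
    # Yield text chunks, then return metadata when the stream ends.
    yield "Hello, "
    yield "world!"
    return {"model": "example-model", "output_tokens": 2}


gen = fake_stream()
chunks: list[str] = []
metadata: dict | None = None
try:
    while True:
        chunks.append(next(gen))
except StopIteration as stop:
    metadata = stop.value  # the generator's return value

print("".join(chunks))  # Hello, world!
print(metadata)         # {'model': 'example-model', 'output_tokens': 2}

This is why _reply_stream() iterates the wrapper rather than the raw generator: the wrapper's __iter__ drives the generator with next() and stashes the StopIteration value in .metadata.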

gptme/llm/llm_anthropic.py

Lines changed: 27 additions & 8 deletions
@@ -16,7 +16,7 @@
 from pydantic import BaseModel  # fmt: skip

 from ..constants import TEMPERATURE, TOP_P
-from ..message import Message, msgs2dicts
+from ..message import Message, MessageMetadata, msgs2dicts
 from ..telemetry import record_llm_request
 from ..tools.base import ToolSpec
 from .models import ModelMeta, get_model
@@ -95,8 +95,8 @@ def _extract_schema_result(content_blocks):
 def _record_usage(
     usage: Union["anthropic.types.Usage", "anthropic.types.MessageDeltaUsage"],
     model: str,
-) -> None:
-    """Record usage metrics as telemetry."""
+) -> MessageMetadata | None:
+    """Record usage metrics as telemetry and return MessageMetadata."""
     if not usage:
         return None

@@ -125,6 +125,18 @@ def _record_usage(
         total_tokens=total_tokens if total_tokens > 0 else None,
     )

+    # Return MessageMetadata for attachment to Message
+    metadata: MessageMetadata = {"model": model}
+    if input_tokens is not None:
+        metadata["input_tokens"] = input_tokens
+    if output_tokens is not None:
+        metadata["output_tokens"] = output_tokens
+    if cache_read_tokens is not None:
+        metadata["cache_read_tokens"] = cache_read_tokens
+    if cache_creation_tokens is not None:
+        metadata["cache_creation_tokens"] = cache_creation_tokens
+    return metadata
+

 def _should_use_thinking(model_meta: ModelMeta, tools: list[ToolSpec] | None) -> bool:
     # Support environment variable to override reasoning behavior
@@ -305,7 +317,7 @@ def chat(
     model: str,
     tools: list[ToolSpec] | None,
     output_schema: type[BaseModel] | None = None,
-) -> str:
+) -> tuple[str, MessageMetadata | None]:
     from anthropic import NOT_GIVEN  # fmt: skip

     assert _anthropic, "LLM not initialized"
@@ -361,7 +373,7 @@
         timeout=60,
     )
     content = response.content
-    _record_usage(response.usage, model)
+    metadata = _record_usage(response.usage, model)

     parsed_block = []
     for block in content:
@@ -374,7 +386,7 @@
         else:
             logger.warning("Unknown block: %s", str(block))

-    return "\n".join(parsed_block)
+    return "\n".join(parsed_block), metadata


 @retry_generator_on_overloaded()
@@ -383,10 +395,13 @@ def stream(
     model: str,
     tools: list[ToolSpec] | None,
     output_schema: type[BaseModel] | None = None,
-) -> Generator[str, None, None]:
+) -> Generator[str, None, MessageMetadata | None]:
     import anthropic.types  # fmt: skip
     from anthropic import NOT_GIVEN  # fmt: skip

+    # Variable to capture metadata from usage recording
+    captured_metadata: MessageMetadata | None = None
+
     assert _anthropic, "LLM not initialized"
     messages_dicts, system_messages, tools_dict = _prepare_messages_for_api(
         messages, tools
@@ -493,13 +508,17 @@
             case "message_delta":
                 chunk = cast(anthropic.types.MessageDeltaEvent, chunk)
                 # Record usage from message_delta which contains the final/cumulative usage
-                _record_usage(chunk.usage, model)
+                # and capture metadata for message attachment
+                captured_metadata = _record_usage(chunk.usage, model)
             case "message_stop":
                 pass
             case _:
                 # print(f"Unknown chunk type: {chunk.type}")
                 pass

+    # Return the captured metadata (accessible via StopIteration.value)
+    return captured_metadata
+

 def _handle_tools(message_dicts: Iterable[dict]) -> Generator[dict, None, None]:
     for message in message_dicts:

gptme/llm/llm_openai.py

Lines changed: 26 additions & 9 deletions
@@ -8,7 +8,7 @@

 from ..config import Config, get_config
 from ..constants import TEMPERATURE, TOP_P
-from ..message import Message, msgs2dicts
+from ..message import Message, MessageMetadata, msgs2dicts
 from ..telemetry import record_llm_request
 from ..tools import ToolSpec
 from .models import ModelMeta, Provider, is_custom_provider
@@ -36,10 +36,10 @@
 }


-def _record_usage(usage, model: str) -> None:
-    """Record usage metrics as telemetry."""
+def _record_usage(usage, model: str) -> MessageMetadata | None:
+    """Record usage metrics as telemetry and return MessageMetadata."""
     if not usage:
-        return
+        return None

     # Extract token counts (OpenAI uses different field names than Anthropic)
     prompt_tokens = getattr(usage, "prompt_tokens", None)
@@ -67,6 +67,16 @@ def _record_usage(usage, model: str) -> None:
         total_tokens=total_tokens,
     )

+    # Return MessageMetadata for attachment to Message
+    metadata: MessageMetadata = {"model": model}
+    if input_tokens is not None:
+        metadata["input_tokens"] = input_tokens
+    if output_tokens is not None:
+        metadata["output_tokens"] = output_tokens
+    if cache_read_tokens is not None:
+        metadata["cache_read_tokens"] = cache_read_tokens
+    return metadata
+

 # TODO: improve provider routing for openrouter: https://openrouter.ai/docs/provider-routing
 # TODO: set required-parameters: https://openrouter.ai/docs/provider-routing#required-parameters-_beta_
@@ -260,7 +270,7 @@ def chat(
     model: str,
     tools: list[ToolSpec] | None,
     output_schema=None,
-) -> str:
+) -> tuple[str, MessageMetadata | None]:
     # This will generate code and such, so we need appropriate temperature and top_p params
     # top_p controls diversity, temperature controls randomness

@@ -294,7 +304,7 @@
         extra_headers=extra_headers(provider),
         extra_body=extra_body(provider, model_meta),
     )
-    _record_usage(response.usage, model)
+    metadata = _record_usage(response.usage, model)
     choice = response.choices[0]
     result = []
     if choice.finish_reason == "tool_calls":
@@ -313,7 +323,7 @@
         result.append(choice.message.content)

     assert result
-    return "\n".join(result)
+    return "\n".join(result), metadata


 def extra_headers(provider: Provider) -> dict[str, str]:
@@ -345,10 +355,13 @@ def stream(
     model: str,
     tools: list[ToolSpec] | None,
     output_schema=None,
-) -> Generator[str, None, None]:
+) -> Generator[str, None, MessageMetadata | None]:
     from . import _get_base_model, get_provider_from_model  # fmt: skip
     from .models import get_model  # fmt: skip

+    # Variable to capture metadata from usage recording
+    captured_metadata: MessageMetadata | None = None
+
     provider = get_provider_from_model(model)
     client = get_client(provider)
     is_proxy = _is_proxy(client)
@@ -389,8 +402,9 @@
         chunk = cast(ChatCompletionChunk, chunk_raw)

         # Record usage if available (typically in final chunk)
+        # and capture metadata for message attachment
         if hasattr(chunk, "usage") and chunk.usage:
-            _record_usage(chunk.usage, model)
+            captured_metadata = _record_usage(chunk.usage, model)

         if not chunk.choices:
             continue
@@ -441,6 +455,9 @@

     logger.debug(f"Stop reason: {stop_reason}")

+    # Return the captured metadata (accessible via StopIteration.value)
+    return captured_metadata
+

 def _handle_tools(message_dicts: Iterable[dict]) -> Generator[dict, None, None]:
     for message in message_dicts:
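Putting the pieces together, a simplified end-to-end sketch of the flow the commit describes: a provider-style call returns (text, metadata), the metadata rides along on the resulting message, and only non-None metadata is included when serializing. This is a stand-in with hypothetical names, not the actual gptme Message implementation:

from dataclasses import dataclass


@dataclass
class SimpleMessage:
    role: str
    content: str
    metadata: dict | None = None

    def to_dict(self) -> dict:
        d = {"role": self.role, "content": self.content}
        if self.metadata is not None:  # keep JSONL compact: omit missing metadata
            d["metadata"] = self.metadata
        return d


def fake_chat_complete() -> tuple[str, dict | None]:
    # Stand-in for _chat_complete(): response text plus whatever usage recording produced.
    return "42", {"model": "example-model", "input_tokens": 10, "output_tokens": 3}


text, metadata = fake_chat_complete()
msg = SimpleMessage("assistant", text, metadata=metadata)
print(msg.to_dict())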
