Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions react-sdk/src/__tests__/advance-stream.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import { advanceStreamWithBuffer } from "../util/advance-stream";
import { ReadableStream as NodeReadableStream } from "stream/web";

/**
 * ReadableStream constructor used to build fake response bodies.
 *
 * Uses the global `ReadableStream` when the test environment defines one,
 * otherwise falls back to Node's `stream/web` implementation. The double cast
 * is needed because Node's typings are not structurally identical to the
 * DOM lib's `ReadableStream` type.
 */
const StreamCtor: typeof ReadableStream = (() => {
  if (typeof globalThis.ReadableStream === "undefined") {
    return NodeReadableStream as unknown as typeof ReadableStream;
  }
  return globalThis.ReadableStream;
})();

/** Shared encoder that turns SSE text fragments into bytes. */
const textEncoder = new TextEncoder();

/**
 * Builds a byte stream that emits each string in `chunks` as its own
 * UTF-8-encoded chunk, in order, and then closes.
 *
 * Preserving the chunk boundaries lets tests simulate SSE payloads that are
 * split across separate network reads.
 *
 * @param chunks - Raw text fragments, one per emitted stream chunk.
 * @returns A `ReadableStream` that yields the encoded fragments, then ends.
 */
function createStreamFromChunks(chunks: string[]): ReadableStream<Uint8Array> {
  const encodedChunks = chunks.map((text) => textEncoder.encode(text));
  return new StreamCtor<Uint8Array>({
    start(controller) {
      encodedChunks.forEach((bytes) => controller.enqueue(bytes));
      controller.close();
    },
  });
}

/**
 * Wraps SSE text fragments in a minimal `Response`-shaped object.
 *
 * Only `body` is populated; every other `Response` member is absent, so the
 * fake is only suitable for code that reads the body stream.
 * NOTE(review): confirm advanceStreamWithBuffer never reads `ok`, `status`,
 * or headers from the response.
 *
 * @param chunks - Raw text fragments forwarded to `createStreamFromChunks`.
 * @returns An object cast to `Response` whose `body` streams those fragments.
 */
function createSSEResponse(chunks: string[]): Response {
  const fakeResponse = { body: createStreamFromChunks(chunks) };
  return fakeResponse as unknown as Response;
}

describe("advanceStreamWithBuffer", () => {
  /**
   * Builds a stub client whose `post(...).asResponse()` resolves to
   * `response` (presumably the call chain advanceStreamWithBuffer uses;
   * verify against the implementation). Fresh mocks are created per call.
   */
  const mockClientReturning = (response: Response) => ({
    post: jest.fn().mockReturnValue({
      asResponse: jest.fn().mockResolvedValue(response),
    }),
  });

  it("parses JSON payload split across multiple SSE chunks", async () => {
    // The JSON object is split right after the key, and the blank line that
    // terminates the SSE event arrives in its own chunk.
    const client = mockClientReturning(
      createSSEResponse(['data: {"message":', ' "hello"}\n', "\n"]),
    );

    const iterable = await advanceStreamWithBuffer(client as any, {} as any);
    const received: Record<string, unknown>[] = [];
    for await (const item of iterable) {
      received.push(item as unknown as Record<string, unknown>);
    }

    expect(received).toEqual([{ message: "hello" }]);
  });

  it("throws when JSON cannot be parsed after retries", async () => {
    const client = mockClientReturning(
      createSSEResponse(["data: not-json\n", "\n"]),
    );
    const iterable = await advanceStreamWithBuffer(client as any, {} as any);

    // Draining the iterator is what surfaces the parse failure.
    const drain = async () => {
      for await (const _ of iterable) {
        // Consume iterator
      }
    };

    await expect(drain).rejects.toThrow(
      "Failed to parse JSON after multiple chunks.",
    );
  });
});
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import TamboAI, { advanceStream } from "@tambo-ai/typescript-sdk";
import TamboAI from "@tambo-ai/typescript-sdk";
import { act, renderHook } from "@testing-library/react";
import React from "react";
import { DeepPartial } from "ts-essentials";
Expand All @@ -10,6 +10,7 @@ import { useTamboClient, useTamboQueryClient } from "../tambo-client-provider";
import { TamboContextHelpersProvider } from "../tambo-context-helpers-provider";
import { TamboRegistryProvider } from "../tambo-registry-provider";
import { TamboThreadProvider, useTamboThread } from "../tambo-thread-provider";
import { advanceStreamWithBuffer } from "../../util/advance-stream";

type PartialTamboAI = DeepPartial<TamboAI>;

Expand All @@ -25,8 +26,8 @@ jest.mock("../tambo-client-provider", () => ({
useTamboClient: jest.fn(),
useTamboQueryClient: jest.fn(),
}));
jest.mock("@tambo-ai/typescript-sdk", () => ({
advanceStream: jest.fn(),
jest.mock("../../util/advance-stream", () => ({
advanceStreamWithBuffer: jest.fn(),
}));

// Test utilities
Expand Down Expand Up @@ -80,19 +81,21 @@ describe("TamboThreadProvider with initial messages", () => {
invalidateQueries: jest.fn(),
};
(useTamboQueryClient as jest.Mock).mockReturnValue(mockQueryClient);
(advanceStream as jest.Mock).mockImplementation(async function* () {
yield {
responseMessageDto: {
id: "response-1",
role: "assistant",
content: [{ type: "text", text: "Hello back!" }],
threadId: "new-thread-id",
componentState: {},
createdAt: new Date().toISOString(),
},
generationStage: GenerationStage.COMPLETE,
};
});
(advanceStreamWithBuffer as jest.Mock).mockImplementation(
async function* () {
yield {
responseMessageDto: {
id: "response-1",
role: "assistant",
content: [{ type: "text", text: "Hello back!" }],
threadId: "new-thread-id",
componentState: {},
createdAt: new Date().toISOString(),
},
generationStage: GenerationStage.COMPLETE,
};
},
);
});

it("should initialize with empty messages when no initial messages provided", () => {
Expand Down Expand Up @@ -145,8 +148,8 @@ describe("TamboThreadProvider with initial messages", () => {
await result.current.sendThreadMessage("Test message");
});

// Check that advanceStream was called with initial messages
expect(advanceStream).toHaveBeenCalledWith(
// Check that advanceStreamWithBuffer was called with initial messages
expect(advanceStreamWithBuffer).toHaveBeenCalledWith(
mockClient,
expect.objectContaining({
initialMessages: [
Expand Down Expand Up @@ -183,8 +186,8 @@ describe("TamboThreadProvider with initial messages", () => {
await result.current.sendThreadMessage("Test message");
});

// Check that advanceStream was called without initial messages
expect(advanceStream).toHaveBeenCalledWith(
// Check that advanceStreamWithBuffer was called without initial messages
expect(advanceStreamWithBuffer).toHaveBeenCalledWith(
mockClient,
expect.not.objectContaining({
initialMessages: expect.anything(),
Expand Down
45 changes: 23 additions & 22 deletions react-sdk/src/providers/__tests__/tambo-thread-provider.test.tsx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import TamboAI, { advanceStream } from "@tambo-ai/typescript-sdk";
import TamboAI from "@tambo-ai/typescript-sdk";
import { QueryClient } from "@tanstack/react-query";
import { act, renderHook } from "@testing-library/react";
import React from "react";
Expand All @@ -14,6 +14,7 @@ import { useTamboClient, useTamboQueryClient } from "../tambo-client-provider";
import { TamboContextHelpersProvider } from "../tambo-context-helpers-provider";
import { TamboRegistryProvider } from "../tambo-registry-provider";
import { TamboThreadProvider, useTamboThread } from "../tambo-thread-provider";
import { advanceStreamWithBuffer } from "../../util/advance-stream";

type PartialTamboAI = DeepPartial<TamboAI>;

Expand All @@ -29,8 +30,8 @@ jest.mock("../tambo-client-provider", () => ({
useTamboClient: jest.fn(),
useTamboQueryClient: jest.fn(),
}));
jest.mock("@tambo-ai/typescript-sdk", () => ({
advanceStream: jest.fn(),
jest.mock("../../util/advance-stream", () => ({
advanceStreamWithBuffer: jest.fn(),
}));

// Mock the getCustomContext
Expand Down Expand Up @@ -341,8 +342,8 @@ describe("TamboThreadProvider", () => {
},
};

// Mock advanceStream to return our async iterator
jest.mocked(advanceStream).mockResolvedValue(mockAsyncIterator);
// Mock advanceStreamWithBuffer to return our async iterator
jest.mocked(advanceStreamWithBuffer).mockResolvedValue(mockAsyncIterator);

const { result } = renderHook(() => useTamboThread(), { wrapper });

Expand Down Expand Up @@ -536,7 +537,7 @@ describe("TamboThreadProvider", () => {
});

describe("streaming behavior", () => {
it("should call advanceStream when streamResponse=true", async () => {
it("should call advanceStreamWithBuffer when streamResponse=true", async () => {
// Use wrapper with streaming=true to show that explicit streamResponse=true works
const wrapperWithStreaming = ({
children,
Expand Down Expand Up @@ -577,7 +578,7 @@ describe("TamboThreadProvider", () => {
},
};

jest.mocked(advanceStream).mockResolvedValue(mockAsyncIterator);
jest.mocked(advanceStreamWithBuffer).mockResolvedValue(mockAsyncIterator);

const { result } = renderHook(() => useTamboThread(), {
wrapper: wrapperWithStreaming,
Expand All @@ -595,7 +596,7 @@ describe("TamboThreadProvider", () => {
});
});

expect(advanceStream).toHaveBeenCalledWith(
expect(advanceStreamWithBuffer).toHaveBeenCalledWith(
mockTamboAI,
{
messageToAppend: {
Expand Down Expand Up @@ -675,9 +676,9 @@ describe("TamboThreadProvider", () => {
toolCallCounts: {},
});

// Should not call advance or advanceStream
// Should not call advance or advanceStreamWithBuffer
expect(mockThreadsApi.advance).not.toHaveBeenCalled();
expect(advanceStream).not.toHaveBeenCalled();
expect(advanceStreamWithBuffer).not.toHaveBeenCalled();
});

it("should call advanceById when streamResponse is undefined and provider streaming=false", async () => {
Expand Down Expand Up @@ -734,12 +735,12 @@ describe("TamboThreadProvider", () => {
toolCallCounts: {},
});

// Should not call advance or advanceStream
// Should not call advance or advanceStreamWithBuffer
expect(mockThreadsApi.advance).not.toHaveBeenCalled();
expect(advanceStream).not.toHaveBeenCalled();
expect(advanceStreamWithBuffer).not.toHaveBeenCalled();
});

it("should call advanceStream when streamResponse is undefined and provider streaming=true (default)", async () => {
it("should call advanceStreamWithBuffer when streamResponse is undefined and provider streaming=true (default)", async () => {
// Use wrapper with streaming=true (default) to test that undefined streamResponse respects provider setting
const wrapperWithDefaultStreaming = ({
children,
Expand Down Expand Up @@ -778,7 +779,7 @@ describe("TamboThreadProvider", () => {
},
};

jest.mocked(advanceStream).mockResolvedValue(mockAsyncIterator);
jest.mocked(advanceStreamWithBuffer).mockResolvedValue(mockAsyncIterator);

const { result } = renderHook(() => useTamboThread(), {
wrapper: wrapperWithDefaultStreaming,
Expand All @@ -796,7 +797,7 @@ describe("TamboThreadProvider", () => {
});
});

expect(advanceStream).toHaveBeenCalledWith(
expect(advanceStreamWithBuffer).toHaveBeenCalledWith(
mockTamboAI,
{
messageToAppend: {
Expand Down Expand Up @@ -879,12 +880,12 @@ describe("TamboThreadProvider", () => {
toolCallCounts: {},
});

// Should not call advanceById or advanceStream
// Should not call advanceById or advanceStreamWithBuffer
expect(mockThreadsApi.advanceByID).not.toHaveBeenCalled();
expect(advanceStream).not.toHaveBeenCalled();
expect(advanceStreamWithBuffer).not.toHaveBeenCalled();
});

it("should call advanceStream when streamResponse=true for placeholder thread", async () => {
it("should call advanceStreamWithBuffer when streamResponse=true for placeholder thread", async () => {
// Use wrapper with streaming=false to show that explicit streamResponse=true overrides provider setting
const wrapperWithoutStreaming = ({
children,
Expand Down Expand Up @@ -925,7 +926,7 @@ describe("TamboThreadProvider", () => {
},
};

jest.mocked(advanceStream).mockResolvedValue(mockAsyncIterator);
jest.mocked(advanceStreamWithBuffer).mockResolvedValue(mockAsyncIterator);

const { result } = renderHook(() => useTamboThread(), {
wrapper: wrapperWithoutStreaming,
Expand All @@ -946,7 +947,7 @@ describe("TamboThreadProvider", () => {
});
});

expect(advanceStream).toHaveBeenCalledWith(
expect(advanceStreamWithBuffer).toHaveBeenCalledWith(
mockTamboAI,
{
messageToAppend: {
Expand Down Expand Up @@ -1000,8 +1001,8 @@ describe("TamboThreadProvider", () => {
it("should set generation stage to ERROR when streaming sendThreadMessage fails", async () => {
const testError = new Error("Streaming API call failed");

// Mock advanceStream to throw an error
jest.mocked(advanceStream).mockRejectedValue(testError);
// Mock advanceStreamWithBuffer to throw an error
jest.mocked(advanceStreamWithBuffer).mockRejectedValue(testError);

const { result } = renderHook(() => useTamboThread(), { wrapper });

Expand Down
7 changes: 4 additions & 3 deletions react-sdk/src/providers/tambo-thread-provider.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"use client";
import TamboAI, { advanceStream } from "@tambo-ai/typescript-sdk";
import TamboAI from "@tambo-ai/typescript-sdk";
import { Thread } from "@tambo-ai/typescript-sdk/resources/beta/threads/threads";
import React, {
createContext,
Expand All @@ -25,6 +25,7 @@ import {
getUnassociatedTools,
mapTamboToolToContextTool,
} from "../util/registry";
import { advanceStreamWithBuffer } from "../util/advance-stream";
import { handleToolCall } from "../util/tool-caller";
import { useTamboClient, useTamboQueryClient } from "./tambo-client-provider";
import { useTamboContextHelpers } from "./tambo-context-helpers-provider";
Expand Down Expand Up @@ -721,7 +722,7 @@ export const TamboThreadProvider: React.FC<
chunk.responseMessageDto.threadId,
GenerationStage.STREAMING_RESPONSE,
);
const toolCallResponseStream = await advanceStream(
const toolCallResponseStream = await advanceStreamWithBuffer(
client,
toolCallResponseParams,
chunk.responseMessageDto.threadId,
Expand Down Expand Up @@ -920,7 +921,7 @@ export const TamboThreadProvider: React.FC<
if (streamResponse) {
let advanceStreamResponse: AsyncIterable<TamboAI.Beta.Threads.ThreadAdvanceResponse>;
try {
advanceStreamResponse = await advanceStream(
advanceStreamResponse = await advanceStreamWithBuffer(
client,
params,
threadId === placeholderThread.id ? undefined : threadId,
Expand Down
Loading