llm_test.go
//go:build integration

package server

import (
	"context"
	"os"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/jmorganca/ollama/api"
	"github.com/jmorganca/ollama/llm"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TODO - this would ideally be in the llm package, but that would require some
// refactoring of interfaces in the server package to avoid circular dependencies.
//
// WARNING - these tests will fail on macOS if you don't manually copy
// ggml-metal.metal to this dir (./server).
//
// TODO - fix this ^^

var (
	req = [2]api.GenerateRequest{
		{
			Model:   "orca-mini",
			Prompt:  "tell me a short story about agi?",
			Options: map[string]interface{}{},
		}, {
			Model:   "orca-mini",
			Prompt:  "what is the origin of the us thanksgiving holiday?",
			Options: map[string]interface{}{},
		},
	}
	resp = [2]string{
		"once upon a time",
		"united states thanksgiving",
	}
)

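// TestIntegrationSimpleOrcaMini runs a single deterministic prompt through a
// freshly prepared runner and checks the response for an expected substring.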
func TestIntegrationSimpleOrcaMini(t *testing.T) {
	SkipIFNoTestData(t)
	workDir, err := os.MkdirTemp("", "ollama")
	require.NoError(t, err)
	defer os.RemoveAll(workDir)
	require.NoError(t, llm.Init(workDir))
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
	defer cancel()
	opts := api.DefaultOptions()
	// Fixed seed and zero temperature so the output is reproducible.
	opts.Seed = 42
	opts.Temperature = 0.0
	model, llmRunner := PrepareModelForPrompts(t, req[0].Model, opts)
	defer llmRunner.Close()
	response := OneShotPromptResponse(t, ctx, req[0], model, llmRunner)
	assert.Contains(t, strings.ToLower(response), resp[0])
}

// TODO
// The server always loads a new runner and closes the old one, which forces
// serial execution. At present this test case fails with concurrency problems.
// Eventually we should try to get true concurrency working with n_parallel
// support in the backend.
func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
	SkipIFNoTestData(t)
	t.Skip("concurrent prediction on single runner not currently supported")
	workDir, err := os.MkdirTemp("", "ollama")
	require.NoError(t, err)
	defer os.RemoveAll(workDir)
	require.NoError(t, llm.Init(workDir))
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
	defer cancel()
	opts := api.DefaultOptions()
	opts.Seed = 42
	opts.Temperature = 0.0
	var wg sync.WaitGroup
	wg.Add(len(req))
	// A single model and runner is shared by all goroutines.
	model, llmRunner := PrepareModelForPrompts(t, req[0].Model, opts)
	defer llmRunner.Close()
	for i := 0; i < len(req); i++ {
		go func(i int) {
			defer wg.Done()
			response := OneShotPromptResponse(t, ctx, req[i], model, llmRunner)
			t.Logf("Prompt: %s\nResponse: %s", req[i].Prompt, response)
			assert.Contains(t, strings.ToLower(response), resp[i], "error in thread %d (%s)", i, req[i].Prompt)
		}(i)
	}
	wg.Wait()
}

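// TestIntegrationConcurrentRunnersOrcaMini loads a separate runner per
// goroutine, one for each prompt, and verifies each response independently.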
func TestIntegrationConcurrentRunnersOrcaMini(t *testing.T) {
	SkipIFNoTestData(t)
	workDir, err := os.MkdirTemp("", "ollama")
	require.NoError(t, err)
	defer os.RemoveAll(workDir)
	require.NoError(t, llm.Init(workDir))
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
	defer cancel()
	opts := api.DefaultOptions()
	opts.Seed = 42
	opts.Temperature = 0.0
	var wg sync.WaitGroup
	wg.Add(len(req))
	t.Logf("Running %d concurrently", len(req))
	for i := 0; i < len(req); i++ {
		go func(i int) {
			defer wg.Done()
			// Each goroutine prepares its own model and runner for its request.
			model, llmRunner := PrepareModelForPrompts(t, req[i].Model, opts)
			defer llmRunner.Close()
			response := OneShotPromptResponse(t, ctx, req[i], model, llmRunner)
			t.Logf("Prompt: %s\nResponse: %s", req[i].Prompt, response)
			assert.Contains(t, strings.ToLower(response), resp[i], "error in thread %d (%s)", i, req[i].Prompt)
		}(i)
	}
	wg.Wait()
}

// TODO - create a parallel test with 2 different models once we support concurrency