-
Notifications
You must be signed in to change notification settings - Fork 78
Expand file tree
/
Copy pathcache-redis.go
More file actions
294 lines (251 loc) · 7.63 KB
/
cache-redis.go
File metadata and controls
294 lines (251 loc) · 7.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
package rdns
import (
"context"
"encoding/binary"
"encoding/json"
"errors"
"expvar"
"fmt"
"strings"
"sync"
"time"
"github.com/miekg/dns"
"github.com/redis/go-redis/v9"
)
const (
	// redisAsyncWriteSemCapacity limits the number of concurrent background
	// Redis writes. When every slot is in use, further asynchronous stores are
	// skipped (and counted) rather than queued.
	redisAsyncWriteSemCapacity = 256
)
// redisBackend is a CacheBackend that keeps cached DNS answers in Redis.
type redisBackend struct {
	// client is the go-redis client used for every command.
	client *redis.Client
	// opt holds the options the backend was created with; KeyPrefix and
	// SyncSet are consulted at runtime.
	opt RedisBackendOptions
	// asyncWriteSem is a counting semaphore bounding the number of
	// background writes in flight; its buffer size is the limit.
	asyncWriteSem chan struct{}
	// asyncSkipped counts asynchronous writes dropped because the
	// semaphore was full.
	asyncSkipped *expvar.Int
}

// RedisBackendOptions configures a Redis-backed cache.
type RedisBackendOptions struct {
	// RedisOptions is handed to redis.NewClient to build the client.
	RedisOptions redis.Options
	// KeyPrefix is prepended to every cache key and is also used by Flush.
	KeyPrefix string
	SyncSet bool // When true, perform Redis SET synchronously. Default is false (async writes).
}

// Compile-time check that *redisBackend satisfies CacheBackend.
var _ CacheBackend = (*redisBackend)(nil)
// packBufPool recycles scratch byte slices for dns.Msg.PackBuffer so that
// encoding a cache record does not allocate a fresh packing buffer each time.
// Entries are *[]byte (a pointer avoids an allocation on Put) starting empty
// with 2048 bytes of capacity.
var packBufPool = sync.Pool{
	New: func() any {
		scratch := make([]byte, 0, 2048)
		return &scratch
	},
}
// Layout constants for the binary cache record produced by encodeCacheAnswer.
const (
	// binaryFormatVersion is written to byte 0 of every record.
	binaryFormatVersion = 1
	// headerSize covers version (1 byte) + flags (1 byte) + timestamp (8 bytes).
	headerSize = 1 + 1 + 8
	// flagPrefetchBit in the flags byte marks a record as prefetch-eligible.
	flagPrefetchBit = 0x01
)
// encodeCacheAnswer encodes a cacheAnswer into a compact binary format:
//   - byte 0: version (binaryFormatVersion)
//   - byte 1: flags (bit0: prefetchEligible)
//   - bytes 2..9: timestamp (Unix seconds as a big-endian uint64; the
//     int64 -> uint64 conversion is undone by decodeCacheAnswer)
//   - bytes 10..N: dns.Msg wire bytes
//
// The returned slice is freshly allocated and safe to retain; the pooled
// scratch buffer used for packing never escapes. An error is returned when
// item carries no DNS message or the message cannot be packed.
func encodeCacheAnswer(item *cacheAnswer) ([]byte, error) {
	// A nil message would make PackBuffer panic. On the asynchronous store
	// path that panic would occur in a bare goroutine and crash the whole
	// process, so report it as an ordinary error instead.
	if item == nil || item.Msg == nil {
		return nil, errors.New("cache record has no DNS message")
	}

	bufPtr := packBufPool.Get().(*[]byte)
	scratch := *bufPtr
	// Return the scratch buffer to the pool, keeping any growth from
	// PackBuffer so later calls can reuse the larger capacity.
	defer func() {
		*bufPtr = scratch[:0]
		packBufPool.Put(bufPtr)
	}()
	if cap(scratch) == 0 {
		scratch = make([]byte, 0, 2048)
	}

	// Offer PackBuffer the full capacity; it may allocate a larger buffer
	// when the message does not fit.
	wire, err := item.Msg.PackBuffer(scratch[:cap(scratch)])
	if err != nil {
		return nil, fmt.Errorf("failed to pack DNS message: %w", err)
	}
	scratch = wire

	var flags byte
	if item.PrefetchEligible {
		flags |= flagPrefetchBit
	}

	result := make([]byte, headerSize+len(wire))
	result[0] = binaryFormatVersion
	result[1] = flags
	binary.BigEndian.PutUint64(result[2:10], uint64(item.Timestamp.Unix()))
	copy(result[headerSize:], wire)
	return result, nil
}
// decodeCacheAnswer parses a record produced by encodeCacheAnswer back into a
// cacheAnswer. It fails when the input is shorter than the header, carries an
// unknown version byte, or holds DNS wire data that does not unpack.
func decodeCacheAnswer(b []byte) (*cacheAnswer, error) {
	switch {
	case len(b) < headerSize:
		return nil, fmt.Errorf("binary data too short: %d bytes", len(b))
	case b[0] != binaryFormatVersion:
		return nil, fmt.Errorf("unsupported binary format version: %d", b[0])
	}

	header, wire := b[:headerSize], b[headerSize:]

	msg := new(dns.Msg)
	if err := msg.Unpack(wire); err != nil {
		return nil, fmt.Errorf("failed to unpack DNS message: %w", err)
	}

	// Reinterpret the stored uint64 as signed seconds since the Unix epoch.
	secs := int64(binary.BigEndian.Uint64(header[2:10]))
	return &cacheAnswer{
		Timestamp:        time.Unix(secs, 0),
		PrefetchEligible: header[1]&flagPrefetchBit != 0,
		Msg:              msg,
	}, nil
}
// NewRedisBackend builds a Redis cache backend from opt. It constructs the
// go-redis client from opt.RedisOptions, allows up to
// redisAsyncWriteSemCapacity concurrent background writes, and counts skipped
// asynchronous writes in the expvar returned by
// getVarInt("cache", "redis", "async-skipped").
func NewRedisBackend(opt RedisBackendOptions) *redisBackend {
	client := redis.NewClient(&opt.RedisOptions)
	sem := make(chan struct{}, redisAsyncWriteSemCapacity)
	skipped := getVarInt("cache", "redis", "async-skipped")
	return &redisBackend{
		client:        client,
		opt:           opt,
		asyncWriteSem: sem,
		asyncSkipped:  skipped,
	}
}
// Store writes item to Redis under the key derived from query, with a Redis
// TTL equal to the time remaining until item.Expiry. Entries that are already
// expired are not written. Depending on opt.SyncSet the write happens inline
// or in the background.
func (b *redisBackend) Store(query *dns.Msg, item *cacheAnswer) {
	remaining := time.Until(item.Expiry)
	if remaining <= 0 {
		return
	}

	store := b.storeAsync
	if b.opt.SyncSet {
		store = b.storeSync
	}
	store(query, item, remaining)
}
// storeSync encodes item and SETs it under the key for query with the given
// TTL, bounded by a 100ms deadline. Encoding and Redis errors are logged, not
// returned: caching is best-effort.
func (b *redisBackend) storeSync(query *dns.Msg, item *cacheAnswer, ttl time.Duration) {
	key := b.keyFromQuery(query)

	payload, err := encodeCacheAnswer(item)
	if err != nil {
		Log.Error("failed to encode cache record", "error", err)
		return
	}

	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	err = b.client.Set(ctx, key, payload, ttl).Err()
	if err != nil {
		Log.Error("failed to write to redis", "error", err)
	}
}
// storeAsync runs storeSync in a new goroutine if a semaphore slot is free.
// It never blocks: when all slots are taken the write is dropped and counted
// in asyncSkipped.
//
// NOTE(review): query and item are read by the goroutine after Store returns,
// so callers presumably must not mutate them afterwards — confirm at call sites.
func (b *redisBackend) storeAsync(query *dns.Msg, item *cacheAnswer, ttl time.Duration) {
	select {
	case b.asyncWriteSem <- struct{}{}:
	default:
		// No free slot; caching is best-effort, so drop this write.
		b.asyncSkipped.Add(1)
		return
	}

	go func() {
		defer func() { <-b.asyncWriteSem }()
		b.storeSync(query, item, ttl)
	}()
}
// Lookup fetches the cached answer for q from Redis.
//
// It returns (answer, prefetchEligible, found). On a hit, the answer's Id and
// Question are replaced with those of q, and every non-OPT record's TTL is
// reduced by the time the entry has spent in the cache. If any such record's
// TTL has run out, the lookup is reported as a miss (the Redis entry itself
// is left alone). Redis errors and undecodable records are logged and
// reported as misses.
func (b *redisBackend) Lookup(q *dns.Msg) (*dns.Msg, bool, bool) {
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	key := b.keyFromQuery(q)

	// Fetch raw bytes to avoid string conversion overhead.
	valueBytes, err := b.client.Get(ctx, key).Bytes()
	if err != nil {
		// redis.Nil only means there is no such key: a plain cache miss.
		if !errors.Is(err, redis.Nil) {
			Log.Error("failed to read from redis", "error", err)
		}
		return nil, false, false
	}

	// Try binary decode first, with JSON fallback for backward compatibility
	// with entries written in the older JSON format.
	a, err := decodeCacheAnswer(valueBytes)
	if err != nil {
		if jsonErr := json.Unmarshal(valueBytes, &a); jsonErr != nil {
			Log.Error("failed to decode cache record from redis", "binary_error", err, "json_error", jsonErr)
			return nil, false, false
		}
	}
	// A JSON "null" or an object without a message decodes without error but
	// cannot be served; dereferencing it below would panic.
	if a == nil || a.Msg == nil {
		Log.Error("cache record from redis has no DNS message", "key", key)
		return nil, false, false
	}

	answer := a.Msg
	answer.Id = q.Id
	answer.Question = q.Question // restore the case used in the question

	// Whole seconds the record has spent in the cache. Clamp explicitly:
	// converting an out-of-range float to uint32 is implementation-defined in
	// Go. A negative age can happen when the record was written by a host
	// whose clock runs ahead of ours; a huge one comes from a zero/very old
	// timestamp.
	var age uint32
	switch elapsed := time.Since(a.Timestamp); {
	case elapsed <= 0:
		// age stays 0
	case elapsed >= time.Duration(^uint32(0))*time.Second:
		age = ^uint32(0)
	default:
		age = uint32(elapsed / time.Second)
	}

	// Subtract the age from the TTL of every answer, NS and extra record.
	// OPT records have a TTL of 0 and are ignored.
	for _, section := range [][]dns.RR{answer.Answer, answer.Ns, answer.Extra} {
		for _, rr := range section {
			if _, ok := rr.(*dns.OPT); ok {
				continue
			}
			h := rr.Header()
			if age >= h.Ttl {
				return nil, false, false
			}
			h.Ttl -= age
		}
	}
	return answer, a.PrefetchEligible, true
}
func (b *redisBackend) Flush() {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
if _, err := b.client.Del(ctx, b.opt.KeyPrefix+"*").Result(); err != nil {
Log.Error("failed to delete keys in redis", "error", err)
}
}
// Size reports the key count returned by Redis DBSIZE. DBSIZE counts every
// key in the selected database, not only those under KeyPrefix, so the value
// over-reports if the database is shared. On error it logs and returns the
// command's value anyway (presumably zero — go-redis leaves it unset).
func (b *redisBackend) Size() int {
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	cmd := b.client.DBSize(ctx)
	if err := cmd.Err(); err != nil {
		Log.Error("failed to run dbsize command on redis", "error", err)
	}
	return int(cmd.Val())
}
// Close shuts down the underlying Redis client and returns its error.
//
// NOTE(review): background writes started by storeAsync are not awaited; any
// still in flight may fail against the closed client (they log their error).
func (b *redisBackend) Close() error {
	err := b.client.Close()
	return err
}
// keyFromQuery derives the Redis key for q. The key has the form
//
//	<KeyPrefix><lowercased name>:<class>:<type>:[cd]:[<DO bit>:[<ECS addr>/<len>]]
//
// The trailing EDNS0 part is present only when q carries an OPT record.
// It indexes q.Question[0], so q must contain at least one question.
func (b *redisBackend) keyFromQuery(q *dns.Msg) string {
	question := q.Question[0]

	var sb strings.Builder
	sb.WriteString(b.opt.KeyPrefix)
	sb.WriteString(strings.ToLower(question.Name))
	sb.WriteByte(':')
	sb.WriteString(dns.Class(question.Qclass).String())
	sb.WriteByte(':')
	sb.WriteString(dns.Type(question.Qtype).String())
	sb.WriteByte(':')

	// CD=1 responses are unvalidated (RFC 4035 §4.7 / RFC 6840 §5.9) and
	// must be keyed separately from CD=0 ones.
	if q.CheckingDisabled {
		sb.WriteString("cd")
	}
	sb.WriteByte(':')

	edns := q.IsEdns0()
	if edns == nil {
		return sb.String()
	}

	fmt.Fprintf(&sb, "%t:", edns.Do())
	// Append any EDNS Client Subnet option as address/prefix-length.
	for _, o := range edns.Option {
		ecs, ok := o.(*dns.EDNS0_SUBNET)
		if !ok {
			continue
		}
		fmt.Fprintf(&sb, "%s/%d", ecs.Address.String(), ecs.SourceNetmask)
	}
	return sb.String()
}