From 51702974454edce8549376c4f04fde2a070f726b Mon Sep 17 00:00:00 2001 From: Alex Goodman Date: Mon, 30 Jun 2025 09:38:29 -0400 Subject: [PATCH] fix hydration trigger for new features Signed-off-by: Alex Goodman --- grype/db/v6/installation/curator.go | 2 +- grype/db/v6/installation/curator_test.go | 10 + internal/schemaver/schema_ver.go | 22 +- internal/schemaver/schema_ver_test.go | 459 ++++++++++++++++++++--- 4 files changed, 444 insertions(+), 49 deletions(-) diff --git a/grype/db/v6/installation/curator.go b/grype/db/v6/installation/curator.go index 1c8c853c2aa..d52c36cb7b7 100644 --- a/grype/db/v6/installation/curator.go +++ b/grype/db/v6/installation/curator.go @@ -343,7 +343,7 @@ func isRehydrationNeeded(fs afero.Fs, dirPath string, currentDBVersion *schemave return false, fmt.Errorf("unable to parse client version from import metadata: %w", err) } - hydratedWithOldClient := clientHydrationVersion.LessThan(*currentDBVersion) + hydratedWithOldClient := clientHydrationVersion.LessThanOrEqualTo(*currentDBVersion) haveNewerClient := clientHydrationVersion.LessThan(currentClientVersion) doRehydrate := hydratedWithOldClient && haveNewerClient diff --git a/grype/db/v6/installation/curator_test.go b/grype/db/v6/installation/curator_test.go index 9d60d34582f..e9d3096401e 100644 --- a/grype/db/v6/installation/curator_test.go +++ b/grype/db/v6/installation/curator_test.go @@ -729,6 +729,16 @@ func Test_isRehydrationNeeded(t *testing.T) { currentClientVer: schemaver.New(6, 2, 0), expectedResult: false, }, + { + // there are cases where new features result in new columns; a DB downloaded and hydrated by an old client + // should still function, however, when the new client is downloaded it should trigger at least a rehydration + // of the existing DB (in cases where the new DB is not yet available for download).
+ name: "rehydration needed - we have a new client version, with an old DB version", + currentDBVersion: schemaver.New(6, 0, 2), + hydrationClientVer: schemaver.New(6, 0, 2), + currentClientVer: schemaver.New(6, 0, 3), + expectedResult: true, + }, } for _, tt := range tests { diff --git a/internal/schemaver/schema_ver.go b/internal/schemaver/schema_ver.go index 83423f3ef3c..16957e0ef60 100644 --- a/internal/schemaver/schema_ver.go +++ b/internal/schemaver/schema_ver.go @@ -83,6 +83,26 @@ func (s SchemaVer) LessThan(other SchemaVer) bool { return s.Addition < other.Addition } +func (s SchemaVer) LessThanOrEqualTo(other SchemaVer) bool { + return s.LessThan(other) || s.Equal(other) +} + +func (s SchemaVer) Equal(other SchemaVer) bool { + return s.Model == other.Model && s.Revision == other.Revision && s.Addition == other.Addition +} + +func (s SchemaVer) GreaterThan(other SchemaVer) bool { + if s.Model != other.Model { + return s.Model > other.Model + } + + if s.Revision != other.Revision { + return s.Revision > other.Revision + } + + return s.Addition > other.Addition +} + func (s SchemaVer) GreaterOrEqualTo(other SchemaVer) bool { - return !s.LessThan(other) + return s.GreaterThan(other) || s.Equal(other) } diff --git a/internal/schemaver/schema_ver_test.go b/internal/schemaver/schema_ver_test.go index 04d5bf84d6d..37de3ea53cd 100644 --- a/internal/schemaver/schema_ver_test.go +++ b/internal/schemaver/schema_ver_test.go @@ -1,78 +1,263 @@ package schemaver import ( + "encoding/json" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -func TestSchemaVerComparisons(t *testing.T) { +func TestSchemaVer_LessThan(t *testing.T) { tests := []struct { - name string - v1 SchemaVer - v2 SchemaVer - lessThan bool - greaterOrEqual bool + name string + v1 SchemaVer + v2 SchemaVer + want bool }{ { - name: "equal versions", - v1: New(1, 0, 0), - v2: New(1, 0, 0), - lessThan: false, - greaterOrEqual: true, + name: "equal versions", + v1: New(1, 
0, 0), + v2: New(1, 0, 0), + want: false, }, { - name: "different model versions", - v1: New(1, 0, 0), - v2: New(2, 0, 0), - lessThan: true, - greaterOrEqual: false, + name: "different model versions", + v1: New(1, 0, 0), + v2: New(2, 0, 0), + want: true, }, { - name: "different revision versions", - v1: New(1, 1, 0), - v2: New(1, 2, 0), - lessThan: true, - greaterOrEqual: false, + name: "different revision versions", + v1: New(1, 1, 0), + v2: New(1, 2, 0), + want: true, }, { - name: "different addition versions", - v1: New(1, 0, 1), - v2: New(1, 0, 2), - lessThan: true, - greaterOrEqual: false, + name: "different addition versions", + v1: New(1, 0, 1), + v2: New(1, 0, 2), + want: true, }, { - name: "inverted addition versions", - v1: New(1, 0, 2), - v2: New(1, 0, 1), - lessThan: false, - greaterOrEqual: true, + name: "inverted addition versions", + v1: New(1, 0, 2), + v2: New(1, 0, 1), + want: false, }, { - name: "greater model overrides lower revision", - v1: New(2, 0, 0), - v2: New(1, 9, 9), - lessThan: false, - greaterOrEqual: true, + name: "greater model overrides lower revision", + v1: New(2, 0, 0), + v2: New(1, 9, 9), + want: false, }, { - name: "greater revision overrides lower addition", - v1: New(1, 2, 0), - v2: New(1, 1, 9), - lessThan: false, - greaterOrEqual: true, + name: "greater revision overrides lower addition", + v1: New(1, 2, 0), + v2: New(1, 1, 9), + want: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := tt.v1.LessThan(tt.v2); got != tt.lessThan { - t.Errorf("LessThan() = %v, want %v", got, tt.lessThan) - } - if got := tt.v1.GreaterOrEqualTo(tt.v2); got != tt.greaterOrEqual { - t.Errorf("GreaterOrEqualTo() = %v, want %v", got, tt.greaterOrEqual) - } + assert.Equal(t, tt.want, tt.v1.LessThan(tt.v2)) + }) + } +} + +func TestSchemaVer_GreaterOrEqualTo(t *testing.T) { + tests := []struct { + name string + v1 SchemaVer + v2 SchemaVer + want bool + }{ + { + name: "equal versions", + v1: New(1, 0, 0), + v2: New(1, 
0, 0), + want: true, + }, + { + name: "different model versions", + v1: New(1, 0, 0), + v2: New(2, 0, 0), + want: false, + }, + { + name: "different revision versions", + v1: New(1, 1, 0), + v2: New(1, 2, 0), + want: false, + }, + { + name: "different addition versions", + v1: New(1, 0, 1), + v2: New(1, 0, 2), + want: false, + }, + { + name: "inverted addition versions", + v1: New(1, 0, 2), + v2: New(1, 0, 1), + want: true, + }, + { + name: "greater model overrides lower revision", + v1: New(2, 0, 0), + v2: New(1, 9, 9), + want: true, + }, + { + name: "greater revision overrides lower addition", + v1: New(1, 2, 0), + v2: New(1, 1, 9), + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, tt.v1.GreaterOrEqualTo(tt.v2)) + }) + } +} + +func TestSchemaVer_LessThanOrEqualTo(t *testing.T) { + tests := []struct { + name string + v1 SchemaVer + v2 SchemaVer + want bool + }{ + { + name: "equal versions", + v1: New(1, 2, 3), + v2: New(1, 2, 3), + want: true, + }, + { + name: "less than version", + v1: New(1, 2, 3), + v2: New(1, 2, 4), + want: true, + }, + { + name: "greater than version", + v1: New(1, 2, 4), + v2: New(1, 2, 3), + want: false, + }, + { + name: "different model - less", + v1: New(1, 9, 9), + v2: New(2, 0, 0), + want: true, + }, + { + name: "different model - greater", + v1: New(2, 0, 0), + v2: New(1, 9, 9), + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, tt.v1.LessThanOrEqualTo(tt.v2)) + }) + } +} + +func TestSchemaVer_Equal(t *testing.T) { + tests := []struct { + name string + v1 SchemaVer + v2 SchemaVer + want bool + }{ + { + name: "equal versions", + v1: New(1, 2, 3), + v2: New(1, 2, 3), + want: true, + }, + { + name: "different addition", + v1: New(1, 2, 3), + v2: New(1, 2, 4), + want: false, + }, + { + name: "different revision", + v1: New(1, 2, 3), + v2: New(1, 3, 3), + want: false, + }, + { + name: "different model", + 
v1: New(1, 2, 3), + v2: New(2, 2, 3), + want: false, + }, + { + name: "zero values equal", + v1: New(1, 0, 0), + v2: New(1, 0, 0), + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, tt.v1.Equal(tt.v2)) + }) + } +} + +func TestSchemaVer_GreaterThan(t *testing.T) { + tests := []struct { + name string + v1 SchemaVer + v2 SchemaVer + want bool + }{ + { + name: "equal versions", + v1: New(1, 2, 3), + v2: New(1, 2, 3), + want: false, + }, + { + name: "greater addition", + v1: New(1, 2, 4), + v2: New(1, 2, 3), + want: true, + }, + { + name: "greater revision", + v1: New(1, 3, 0), + v2: New(1, 2, 9), + want: true, + }, + { + name: "greater model", + v1: New(2, 0, 0), + v2: New(1, 9, 9), + want: true, + }, + { + name: "less than", + v1: New(1, 2, 3), + v2: New(1, 2, 4), + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, tt.v1.GreaterThan(tt.v2)) }) } } @@ -90,6 +275,18 @@ func TestParse(t *testing.T) { want: New(1, 2, 3), wantErr: false, }, + { + name: "valid version with v prefix", + input: "v1.2.3", + want: New(1, 2, 3), + wantErr: false, + }, + { + name: "valid version with v prefix and zeros", + input: "v1.0.0", + want: New(1, 0, 0), + wantErr: false, + }, { name: "valid large numbers", input: "999.888.777", @@ -108,6 +305,12 @@ func TestParse(t *testing.T) { want: New(0, 0, 0), wantErr: true, }, + { + name: "invalid version with v prefix and zero model", + input: "v0.0.0", + want: New(0, 0, 0), + wantErr: true, + }, { name: "invalid empty string", input: "", @@ -243,3 +446,165 @@ func TestSchemaVer_Valid(t *testing.T) { }) } } + +func TestSchemaVer_String(t *testing.T) { + tests := []struct { + name string + schema SchemaVer + want string + }{ + { + name: "basic version", + schema: New(1, 2, 3), + want: "v1.2.3", + }, + { + name: "version with zeros", + schema: New(1, 0, 0), + want: "v1.0.0", + }, + { + name: "large numbers", + 
schema: New(999, 888, 777), + want: "v999.888.777", + }, + { + name: "single digits", + schema: New(5, 4, 3), + want: "v5.4.3", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, tt.schema.String()) + }) + } +} + +func TestSchemaVer_MarshalJSON(t *testing.T) { + tests := []struct { + name string + schema SchemaVer + want string + }{ + { + name: "basic version", + schema: New(1, 2, 3), + want: `"v1.2.3"`, + }, + { + name: "version with zeros", + schema: New(1, 0, 0), + want: `"v1.0.0"`, + }, + { + name: "large numbers", + schema: New(999, 888, 777), + want: `"v999.888.777"`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.schema.MarshalJSON() + require.NoError(t, err) + assert.Equal(t, tt.want, string(got)) + }) + } +} + +func TestSchemaVer_UnmarshalJSON(t *testing.T) { + tests := []struct { + name string + input string + want SchemaVer + wantErr require.ErrorAssertionFunc + }{ + { + name: "valid version", + input: `"v1.2.3"`, + want: New(1, 2, 3), + wantErr: require.NoError, + }, + { + name: "valid version without v prefix", + input: `"1.2.3"`, + want: New(1, 2, 3), + wantErr: require.NoError, + }, + { + name: "valid version with zeros", + input: `"v1.0.0"`, + want: New(1, 0, 0), + wantErr: require.NoError, + }, + { + name: "invalid JSON format", + input: `{"version": "v1.2.3"}`, + wantErr: require.Error, + }, + { + name: "invalid version format", + input: `"invalid"`, + wantErr: require.Error, + }, + { + name: "invalid zero model", + input: `"v0.1.2"`, + wantErr: require.Error, + }, + { + name: "malformed JSON", + input: `"v1.2.3`, + wantErr: require.Error, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got SchemaVer + err := json.Unmarshal([]byte(tt.input), &got) + tt.wantErr(t, err) + if err == nil { + assert.Equal(t, tt.want, got) + } + }) + } +} + +func TestSchemaVer_JSONRoundTrip(t *testing.T) { + tests := []struct { + name 
string + schema SchemaVer + }{ + { + name: "basic version", + schema: New(1, 2, 3), + }, + { + name: "version with zeros", + schema: New(1, 0, 0), + }, + { + name: "large numbers", + schema: New(999, 888, 777), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // marshal + data, err := json.Marshal(tt.schema) + require.NoError(t, err) + + // unmarshal + var got SchemaVer + err = json.Unmarshal(data, &got) + require.NoError(t, err) + + // should be equal + assert.Equal(t, tt.schema, got) + }) + } +}