Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions ini.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,10 @@ type LoadOptions struct {
AllowNonUniqueSections bool
// AllowDuplicateShadowValues indicates whether values for shadowed keys should be deduplicated.
AllowDuplicateShadowValues bool
// ParseBool is a function that parses boolean strings. When nil, the built-in boolean parser is used.
ParseBool func(string) (bool, error)
// FormatBool is a function that formats boolean values as strings. When nil, strconv.FormatBool is used.
FormatBool func(bool) string
}

// DebugFunc is the type of function called to log parse events.
Expand Down
72 changes: 72 additions & 0 deletions ini_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,78 @@ key3`))
})
})

// Exercises the ParseBool / FormatBool load options end to end on a loaded
// source: single values, delimited lists, parse failures, and the write-back
// performed by MustBool when it falls back to a default.
t.Run("custom bool handling", func(t *testing.T) {
	// Only "enabled" and "disabled" are valid booleans for this parser.
	parseBool := func(raw string) (bool, error) {
		if raw == "enabled" {
			return true, nil
		}
		if raw == "disabled" {
			return false, nil
		}
		return false, assert.AnError
	}
	formatBool := func(on bool) string {
		if !on {
			return "disabled"
		}
		return "enabled"
	}

	opts := LoadOptions{
		ParseBool:  parseBool,
		FormatBool: formatBool,
	}
	f, err := LoadSources(opts, []byte(`
feature = enabled
disabled_feature = disabled
feature_list = enabled,disabled,enabled
invalid_feature = unknown
invalid_feature_false = unknown
invalid_feature_list = enabled,unknown`))
	require.NoError(t, err)
	require.NotNil(t, f)

	// lookup resolves a key in the unnamed section at call time.
	lookup := func(name string) *Key {
		return f.Section("").Key(name)
	}

	t.Run("parse single true bool", func(t *testing.T) {
		got, err := lookup("feature").Bool()
		require.NoError(t, err)
		assert.True(t, got)
	})

	t.Run("parse single false bool", func(t *testing.T) {
		got, err := lookup("disabled_feature").Bool()
		require.NoError(t, err)
		assert.False(t, got)
	})

	t.Run("parse bool slice", func(t *testing.T) {
		want := []bool{true, false, true}
		assert.Equal(t, want, lookup("feature_list").Bools(","))
	})

	// Must run before the "format fallback" subtests: MustBool below rewrites
	// invalid_feature to a value the custom parser accepts.
	t.Run("fail to parse invalid bool", func(t *testing.T) {
		got, err := lookup("invalid_feature").Bool()
		assert.Error(t, err)
		assert.False(t, got)
	})

	t.Run("fail to parse invalid bool slice", func(t *testing.T) {
		got, err := lookup("invalid_feature_list").StrictBools(",")
		assert.Empty(t, got)
		assert.Error(t, err)
	})

	t.Run("format fallback true bool default", func(t *testing.T) {
		k := lookup("invalid_feature")
		assert.True(t, k.MustBool(true))
		// The stored text is the custom formatter's rendering of the default.
		assert.Equal(t, "enabled", k.String())
	})

	t.Run("format fallback false bool default", func(t *testing.T) {
		k := lookup("invalid_feature_false")
		assert.False(t, k.MustBool(false))
		assert.Equal(t, "disabled", k.String())
	})
})

t.Run("allow shadow keys", func(t *testing.T) {
f, err := LoadSources(LoadOptions{AllowShadows: true, AllowPythonMultilineValues: true}, []byte(`
[remote "origin"]
Expand Down
25 changes: 22 additions & 3 deletions key.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,18 +191,26 @@ func (k *Key) Validate(fn func(string) string) string {
// It accepts 1, t, T, TRUE, true, True, YES, yes, Yes, y, ON, on, On,
// 0, f, F, FALSE, false, False, NO, no, No, n, OFF, off, Off.
// Any other value returns an error.
//
// When customParseBool carries a non-nil function in its first position, that
// function parses str instead and the built-in vocabulary is not consulted.
// Any further variadic arguments are ignored. A missing or nil custom parser
// leaves the built-in behavior in effect.
func parseBool(str string, customParseBool ...func(string) (bool, error)) (value bool, err error) {
	// The variadic slot exists so existing single-argument callers keep
	// compiling; only the first entry is ever looked at.
	if len(customParseBool) > 0 {
		if custom := customParseBool[0]; custom != nil {
			return custom(str)
		}
	}

	switch str {
	case "1", "t", "T", "true", "TRUE", "True", "YES", "yes", "Yes", "y", "ON", "on", "On":
		return true, nil
	case "0", "f", "F", "false", "FALSE", "False", "NO", "no", "No", "n", "OFF", "off", "Off":
		return false, nil
	}

	return false, fmt.Errorf("parsing \"%s\": invalid syntax", str)
}

// Bool parses the key's string value as a bool.
// When the key is attached to a section and file, the file's ParseBool option
// is handed to the parser; otherwise the value is parsed without one.
func (k *Key) Bool() (bool, error) {
	// Guard the k.s.f.options chain against nil links before dereferencing it.
	if k == nil || k.s == nil || k.s.f == nil {
		return parseBool(k.String())
	}
	return parseBool(k.String(), k.s.f.options.ParseBool)
}

Expand Down Expand Up @@ -248,6 +256,13 @@ func (k *Key) Time() (time.Time, error) {
return k.TimeFormat(time.RFC3339)
}

// formatBool renders value as text. It uses the FormatBool option of the file
// that owns the key when one is set, and strconv.FormatBool otherwise
// (including when the key, its section, or its file is nil).
func (k *Key) formatBool(value bool) string {
	if k == nil || k.s == nil || k.s.f == nil {
		return strconv.FormatBool(value)
	}
	if format := k.s.f.options.FormatBool; format != nil {
		return format(value)
	}
	return strconv.FormatBool(value)
}

// MustString returns default value if key value is empty.
func (k *Key) MustString(defaultVal string) string {
val := k.String()
Expand All @@ -263,7 +278,7 @@ func (k *Key) MustString(defaultVal string) string {
func (k *Key) MustBool(defaultVal ...bool) bool {
val, err := k.Bool()
if len(defaultVal) > 0 && err != nil {
k.value = strconv.FormatBool(defaultVal[0])
k.value = k.formatBool(defaultVal[0])
return defaultVal[0]
}
return val
Expand Down Expand Up @@ -697,8 +712,12 @@ func (k *Key) StrictTimes(delim string) ([]time.Time, error) {
// parseBools transforms strings to bools.
func (k *Key) parseBools(strs []string, addInvalid, returnOnInvalid bool) ([]bool, error) {
vals := make([]bool, 0, len(strs))
var customParseBool func(string) (bool, error)
if k != nil && k.s != nil && k.s.f != nil {
customParseBool = k.s.f.options.ParseBool
}
parser := func(str string) (interface{}, error) {
val, err := parseBool(str)
val, err := parseBool(str, customParseBool)
return val, err
}
rawVals, err := k.doParse(strs, addInvalid, returnOnInvalid, parser)
Expand Down
94 changes: 94 additions & 0 deletions key_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,100 @@ func TestKey_Helpers(t *testing.T) {
})
}

func TestKey_FormatBool(t *testing.T) {
formatBool := func(value bool) string {
if value {
return "enabled"
}
return "disabled"
}

t.Run("must bool falls back to strconv format when formatter is nil", func(t *testing.T) {
f := Empty()
require.NotNil(t, f)

key, err := f.Section("").NewKey("BOOL", "not-a-bool")
require.NoError(t, err)
require.NotNil(t, key)

assert.True(t, key.MustBool(true))
assert.Equal(t, "true", key.String())
})

t.Run("must bool uses custom formatter when configured", func(t *testing.T) {
f := Empty(LoadOptions{FormatBool: formatBool})
require.NotNil(t, f)

key, err := f.Section("").NewKey("BOOL", "not-a-bool")
require.NoError(t, err)
require.NotNil(t, key)

assert.True(t, key.MustBool(true))
assert.Equal(t, "enabled", key.String())
})
}

func TestKey_ParseBool(t *testing.T) {
parseBool := func(value string) (bool, error) {
switch value {
case "enabled":
return true, nil
case "disabled":
return false, nil
default:
return false, fmt.Errorf("parsing %q: invalid syntax", value)
}
}

t.Run("bool parses custom text", func(t *testing.T) {
f := Empty(LoadOptions{ParseBool: parseBool})
require.NotNil(t, f)

t.Run("success", func(t *testing.T) {
key, err := f.Section("").NewKey("BOOL", "enabled")
require.NoError(t, err)
require.NotNil(t, key)

value, err := key.Bool()
require.NoError(t, err)
assert.True(t, value)
})

t.Run("failure", func(t *testing.T) {
key, err := f.Section("").NewKey("BOOL_INVALID", "unknown")
require.NoError(t, err)
require.NotNil(t, key)

value, err := key.Bool()
assert.Error(t, err)
assert.False(t, value)
})
})

t.Run("bool slices parse custom text", func(t *testing.T) {
f := Empty(LoadOptions{ParseBool: parseBool})
require.NotNil(t, f)

t.Run("success", func(t *testing.T) {
key, err := f.Section("").NewKey("BOOLS", "enabled,disabled,enabled")
require.NoError(t, err)
require.NotNil(t, key)

boolsEqual(t, key.Bools(","), true, false, true)
})

t.Run("failure", func(t *testing.T) {
invalidKey, err := f.Section("").NewKey("BOOLS_INVALID", "enabled,unknown")
require.NoError(t, err)
require.NotNil(t, invalidKey)

vals, err := invalidKey.StrictBools(",")
assert.Empty(t, vals)
assert.Error(t, err)
})
})
}

func TestKey_ValueWithShadows(t *testing.T) {
t.Run("", func(t *testing.T) {
f, err := ShadowLoad([]byte(`
Expand Down
6 changes: 3 additions & 3 deletions struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ func reflectSliceWithProperType(key *Key, field reflect.Value, delim string, all
case reflect.Float64:
val = fmt.Sprint(slice.Index(i).Float())
case reflect.Bool:
val = fmt.Sprint(slice.Index(i).Bool())
val = key.formatBool(slice.Index(i).Bool())
case reflectTime:
val = slice.Index(i).Interface().(time.Time).Format(time.RFC3339)
default:
Expand Down Expand Up @@ -506,7 +506,7 @@ func reflectSliceWithProperType(key *Key, field reflect.Value, delim string, all
case reflect.Float64:
fmt.Fprint(&buf, slice.Index(i).Float())
case reflect.Bool:
fmt.Fprint(&buf, slice.Index(i).Bool())
buf.WriteString(key.formatBool(slice.Index(i).Bool()))
case reflectTime:
buf.WriteString(slice.Index(i).Interface().(time.Time).Format(time.RFC3339))
default:
Expand All @@ -524,7 +524,7 @@ func reflectWithProperType(t reflect.Type, key *Key, field reflect.Value, delim
case reflect.String:
key.SetValue(field.String())
case reflect.Bool:
key.SetValue(fmt.Sprint(field.Bool()))
key.SetValue(key.formatBool(field.Bool()))
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
key.SetValue(fmt.Sprint(field.Int()))
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
Expand Down