Make config.Config more unit-test friendly (#586)

This commit is contained in:
Duco van Amstel 2018-11-13 22:30:56 +00:00 committed by Wim
parent e19ba5a06a
commit 16d5aeac7c
7 changed files with 190 additions and 72 deletions

View File

@ -22,7 +22,7 @@ type Bridge struct {
Channels map[string]config.ChannelInfo Channels map[string]config.ChannelInfo
Joined map[string]bool Joined map[string]bool
Log *log.Entry Log *log.Entry
Config *config.Config Config config.Config
General *config.Protocol General *config.Protocol
} }
@ -69,36 +69,41 @@ func (b *Bridge) joinChannels(channels map[string]config.ChannelInfo, exists map
} }
func (b *Bridge) GetBool(key string) bool { func (b *Bridge) GetBool(key string) bool {
if b.Config.GetBool(b.Account + "." + key) { val, ok := b.Config.GetBool(b.Account + "." + key)
return b.Config.GetBool(b.Account + "." + key) if !ok {
val, _ = b.Config.GetBool("general." + key)
} }
return b.Config.GetBool("general." + key) return val
} }
func (b *Bridge) GetInt(key string) int { func (b *Bridge) GetInt(key string) int {
if b.Config.GetInt(b.Account+"."+key) != 0 { val, ok := b.Config.GetInt(b.Account + "." + key)
return b.Config.GetInt(b.Account + "." + key) if !ok {
val, _ = b.Config.GetInt("general." + key)
} }
return b.Config.GetInt("general." + key) return val
} }
func (b *Bridge) GetString(key string) string { func (b *Bridge) GetString(key string) string {
if b.Config.GetString(b.Account+"."+key) != "" { val, ok := b.Config.GetString(b.Account + "." + key)
return b.Config.GetString(b.Account + "." + key) if !ok {
val, _ = b.Config.GetString("general." + key)
} }
return b.Config.GetString("general." + key) return val
} }
func (b *Bridge) GetStringSlice(key string) []string { func (b *Bridge) GetStringSlice(key string) []string {
if len(b.Config.GetStringSlice(b.Account+"."+key)) != 0 { val, ok := b.Config.GetStringSlice(b.Account + "." + key)
return b.Config.GetStringSlice(b.Account + "." + key) if !ok {
val, _ = b.Config.GetStringSlice("general." + key)
} }
return b.Config.GetStringSlice("general." + key) return val
} }
func (b *Bridge) GetStringSlice2D(key string) [][]string { func (b *Bridge) GetStringSlice2D(key string) [][]string {
if len(b.Config.GetStringSlice2D(b.Account+"."+key)) != 0 { val, ok := b.Config.GetStringSlice2D(b.Account + "." + key)
return b.Config.GetStringSlice2D(b.Account + "." + key) if !ok {
val, _ = b.Config.GetStringSlice2D("general." + key)
} }
return b.Config.GetStringSlice2D("general." + key) return val
} }

View File

@ -2,7 +2,9 @@ package config
import ( import (
"bytes" "bytes"
"fmt"
"io/ioutil" "io/ioutil"
"os"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -177,13 +179,23 @@ type ConfigValues struct {
SameChannelGateway []SameChannelGateway SameChannelGateway []SameChannelGateway
} }
type Config struct { type Config interface {
v *viper.Viper ConfigValues() *ConfigValues
*ConfigValues GetBool(key string) (bool, bool)
sync.RWMutex GetInt(key string) (int, bool)
GetString(key string) (string, bool)
GetStringSlice(key string) ([]string, bool)
GetStringSlice2D(key string) ([][]string, bool)
} }
func NewConfig(cfgfile string) *Config { type config struct {
v *viper.Viper
sync.RWMutex
cv *ConfigValues
}
func NewConfig(cfgfile string) Config {
log.SetFormatter(&prefixed.TextFormatter{PrefixPadding: 13, DisableColors: true, FullTimestamp: false}) log.SetFormatter(&prefixed.TextFormatter{PrefixPadding: 13, DisableColors: true, FullTimestamp: false})
flog := log.WithFields(log.Fields{"prefix": "config"}) flog := log.WithFields(log.Fields{"prefix": "config"})
viper.SetConfigFile(cfgfile) viper.SetConfigFile(cfgfile)
@ -191,9 +203,9 @@ func NewConfig(cfgfile string) *Config {
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
mycfg := NewConfigFromString(input) mycfg := newConfigFromString(input)
if mycfg.ConfigValues.General.MediaDownloadSize == 0 { if mycfg.cv.General.MediaDownloadSize == 0 {
mycfg.ConfigValues.General.MediaDownloadSize = 1000000 mycfg.cv.General.MediaDownloadSize = 1000000
} }
viper.WatchConfig() viper.WatchConfig()
viper.OnConfigChange(func(e fsnotify.Event) { viper.OnConfigChange(func(e fsnotify.Event) {
@ -211,8 +223,11 @@ func getFileContents(filename string) ([]byte, error) {
return input, nil return input, nil
} }
func NewConfigFromString(input []byte) *Config { func NewConfigFromString(input []byte) Config {
var cfg ConfigValues return newConfigFromString(input)
}
func newConfigFromString(input []byte) *config {
viper.SetConfigType("toml") viper.SetConfigType("toml")
viper.SetEnvPrefix("matterbridge") viper.SetEnvPrefix("matterbridge")
viper.AddConfigPath(".") viper.AddConfigPath(".")
@ -222,45 +237,51 @@ func NewConfigFromString(input []byte) *Config {
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
err = viper.Unmarshal(&cfg)
cfg := &ConfigValues{}
err = viper.Unmarshal(cfg)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
mycfg := new(Config) return &config{
mycfg.v = viper.GetViper() v: viper.GetViper(),
mycfg.ConfigValues = &cfg cv: cfg,
return mycfg }
} }
func (c *Config) GetBool(key string) bool { func (c *config) ConfigValues() *ConfigValues {
return c.cv
}
func (c *config) GetBool(key string) (bool, bool) {
c.RLock() c.RLock()
defer c.RUnlock() defer c.RUnlock()
// log.Debugf("getting bool %s = %#v", key, c.v.GetBool(key)) // log.Debugf("getting bool %s = %#v", key, c.v.GetBool(key))
return c.v.GetBool(key) return c.v.GetBool(key), c.v.IsSet(key)
} }
func (c *Config) GetInt(key string) int { func (c *config) GetInt(key string) (int, bool) {
c.RLock() c.RLock()
defer c.RUnlock() defer c.RUnlock()
// log.Debugf("getting int %s = %d", key, c.v.GetInt(key)) // log.Debugf("getting int %s = %d", key, c.v.GetInt(key))
return c.v.GetInt(key) return c.v.GetInt(key), c.v.IsSet(key)
} }
func (c *Config) GetString(key string) string { func (c *config) GetString(key string) (string, bool) {
c.RLock() c.RLock()
defer c.RUnlock() defer c.RUnlock()
// log.Debugf("getting String %s = %s", key, c.v.GetString(key)) // log.Debugf("getting String %s = %s", key, c.v.GetString(key))
return c.v.GetString(key) return c.v.GetString(key), c.v.IsSet(key)
} }
func (c *Config) GetStringSlice(key string) []string { func (c *config) GetStringSlice(key string) ([]string, bool) {
c.RLock() c.RLock()
defer c.RUnlock() defer c.RUnlock()
// log.Debugf("getting StringSlice %s = %#v", key, c.v.GetStringSlice(key)) // log.Debugf("getting StringSlice %s = %#v", key, c.v.GetStringSlice(key))
return c.v.GetStringSlice(key) return c.v.GetStringSlice(key), c.v.IsSet(key)
} }
func (c *Config) GetStringSlice2D(key string) [][]string { func (c *config) GetStringSlice2D(key string) ([][]string, bool) {
c.RLock() c.RLock()
defer c.RUnlock() defer c.RUnlock()
result := [][]string{} result := [][]string{}
@ -272,9 +293,9 @@ func (c *Config) GetStringSlice2D(key string) [][]string {
} }
result = append(result, result2) result = append(result, result2)
} }
return result return result, true
} }
return result return result, false
} }
func GetIconURL(msg *Message, iconURL string) string { func GetIconURL(msg *Message, iconURL string) string {
@ -286,3 +307,46 @@ func GetIconURL(msg *Message, iconURL string) string {
iconURL = strings.Replace(iconURL, "{PROTOCOL}", protocol, -1) iconURL = strings.Replace(iconURL, "{PROTOCOL}", protocol, -1)
return iconURL return iconURL
} }
type TestConfig struct {
Config
Overrides map[string]interface{}
}
func (c *TestConfig) GetBool(key string) (bool, bool) {
val, ok := c.Overrides[key]
fmt.Fprintln(os.Stderr, "DEBUG:", c.Overrides, key, ok, val)
if ok {
return val.(bool), true
}
return c.Config.GetBool(key)
}
func (c *TestConfig) GetInt(key string) (int, bool) {
if val, ok := c.Overrides[key]; ok {
return val.(int), true
}
return c.Config.GetInt(key)
}
func (c *TestConfig) GetString(key string) (string, bool) {
if val, ok := c.Overrides[key]; ok {
return val.(string), true
}
return c.Config.GetString(key)
}
func (c *TestConfig) GetStringSlice(key string) ([]string, bool) {
if val, ok := c.Overrides[key]; ok {
return val.([]string), true
}
return c.Config.GetStringSlice(key)
}
func (c *TestConfig) GetStringSlice2D(key string) ([][]string, bool) {
if val, ok := c.Overrides[key]; ok {
return val.([][]string), true
}
return c.Config.GetStringSlice2D(key)
}

View File

@ -33,7 +33,8 @@ import (
) )
type Gateway struct { type Gateway struct {
*config.Config config.Config
Router *Router Router *Router
MyConfig *config.Gateway MyConfig *config.Gateway
Bridges map[string]*bridge.Bridge Bridges map[string]*bridge.Bridge
@ -107,7 +108,7 @@ func (gw *Gateway) AddBridge(cfg *config.Bridge) error {
if br == nil { if br == nil {
br = bridge.New(cfg) br = bridge.New(cfg)
br.Config = gw.Router.Config br.Config = gw.Router.Config
br.General = &gw.General br.General = &gw.ConfigValues().General
// set logging // set logging
br.Log = log.WithFields(log.Fields{"prefix": "bridge"}) br.Log = log.WithFields(log.Fields{"prefix": "bridge"})
brconfig := &bridge.Config{Remote: gw.Message, Log: log.WithFields(log.Fields{"prefix": br.Protocol}), Bridge: br} brconfig := &bridge.Config{Remote: gw.Message, Log: log.WithFields(log.Fields{"prefix": br.Protocol}), Bridge: br}
@ -278,7 +279,7 @@ func (gw *Gateway) handleMessage(msg config.Message, dest *bridge.Bridge) []*BrM
// Get the ID of the parent message in thread // Get the ID of the parent message in thread
var canonicalParentMsgID string var canonicalParentMsgID string
if msg.ParentID != "" && (gw.Config.General.PreserveThreading || dest.GetBool("PreserveThreading")) { if msg.ParentID != "" && (gw.ConfigValues().General.PreserveThreading || dest.GetBool("PreserveThreading")) {
thisParentMsgID := dest.Protocol + " " + msg.ParentID thisParentMsgID := dest.Protocol + " " + msg.ParentID
canonicalParentMsgID = gw.FindCanonicalMsgID(thisParentMsgID) canonicalParentMsgID = gw.FindCanonicalMsgID(thisParentMsgID)
} }
@ -391,13 +392,13 @@ func (gw *Gateway) ignoreMessage(msg *config.Message) bool {
func (gw *Gateway) modifyUsername(msg config.Message, dest *bridge.Bridge) string { func (gw *Gateway) modifyUsername(msg config.Message, dest *bridge.Bridge) string {
br := gw.Bridges[msg.Account] br := gw.Bridges[msg.Account]
msg.Protocol = br.Protocol msg.Protocol = br.Protocol
if gw.Config.General.StripNick || dest.GetBool("StripNick") { if gw.ConfigValues().General.StripNick || dest.GetBool("StripNick") {
re := regexp.MustCompile("[^a-zA-Z0-9]+") re := regexp.MustCompile("[^a-zA-Z0-9]+")
msg.Username = re.ReplaceAllString(msg.Username, "") msg.Username = re.ReplaceAllString(msg.Username, "")
} }
nick := dest.GetString("RemoteNickFormat") nick := dest.GetString("RemoteNickFormat")
if nick == "" { if nick == "" {
nick = gw.Config.General.RemoteNickFormat nick = gw.ConfigValues().General.RemoteNickFormat
} }
// loop to replace nicks // loop to replace nicks
@ -436,7 +437,7 @@ func (gw *Gateway) modifyUsername(msg config.Message, dest *bridge.Bridge) strin
} }
func (gw *Gateway) modifyAvatar(msg config.Message, dest *bridge.Bridge) string { func (gw *Gateway) modifyAvatar(msg config.Message, dest *bridge.Bridge) string {
iconurl := gw.Config.General.IconURL iconurl := gw.ConfigValues().General.IconURL
if iconurl == "" { if iconurl == "" {
iconurl = dest.GetString("IconURL") iconurl = dest.GetString("IconURL")
} }
@ -477,7 +478,9 @@ func (gw *Gateway) handleFiles(msg *config.Message) {
reg := regexp.MustCompile("[^a-zA-Z0-9]+") reg := regexp.MustCompile("[^a-zA-Z0-9]+")
// If we don't have a attachfield or we don't have a mediaserver configured return // If we don't have a attachfield or we don't have a mediaserver configured return
if msg.Extra == nil || (gw.Config.General.MediaServerUpload == "" && gw.Config.General.MediaDownloadPath == "") { if msg.Extra == nil ||
(gw.ConfigValues().General.MediaServerUpload == "" &&
gw.ConfigValues().General.MediaDownloadPath == "") {
return return
} }
@ -499,10 +502,10 @@ func (gw *Gateway) handleFiles(msg *config.Message) {
sha1sum := fmt.Sprintf("%x", sha1.Sum(*fi.Data))[:8] sha1sum := fmt.Sprintf("%x", sha1.Sum(*fi.Data))[:8]
if gw.Config.General.MediaServerUpload != "" { if gw.ConfigValues().General.MediaServerUpload != "" {
// Use MediaServerUpload. Upload using a PUT HTTP request and basicauth. // Use MediaServerUpload. Upload using a PUT HTTP request and basicauth.
url := gw.Config.General.MediaServerUpload + "/" + sha1sum + "/" + fi.Name url := gw.ConfigValues().General.MediaServerUpload + "/" + sha1sum + "/" + fi.Name
req, err := http.NewRequest("PUT", url, bytes.NewReader(*fi.Data)) req, err := http.NewRequest("PUT", url, bytes.NewReader(*fi.Data))
if err != nil { if err != nil {
@ -521,7 +524,7 @@ func (gw *Gateway) handleFiles(msg *config.Message) {
} else { } else {
// Use MediaServerPath. Place the file on the current filesystem. // Use MediaServerPath. Place the file on the current filesystem.
dir := gw.Config.General.MediaDownloadPath + "/" + sha1sum dir := gw.ConfigValues().General.MediaDownloadPath + "/" + sha1sum
err := os.Mkdir(dir, os.ModePerm) err := os.Mkdir(dir, os.ModePerm)
if err != nil && !os.IsExist(err) { if err != nil && !os.IsExist(err) {
flog.Errorf("mediaserver path failed, could not mkdir: %s %#v", err, err) flog.Errorf("mediaserver path failed, could not mkdir: %s %#v", err, err)
@ -539,7 +542,7 @@ func (gw *Gateway) handleFiles(msg *config.Message) {
} }
// Download URL. // Download URL.
durl := gw.Config.General.MediaServerDownload + "/" + sha1sum + "/" + fi.Name durl := gw.ConfigValues().General.MediaServerDownload + "/" + sha1sum + "/" + fi.Name
flog.Debugf("mediaserver download URL = %s", durl) flog.Debugf("mediaserver download URL = %s", durl)

View File

@ -2,27 +2,32 @@ package gateway
import ( import (
"fmt" "fmt"
"time"
"github.com/42wim/matterbridge/bridge" "github.com/42wim/matterbridge/bridge"
"github.com/42wim/matterbridge/bridge/config" "github.com/42wim/matterbridge/bridge/config"
samechannelgateway "github.com/42wim/matterbridge/gateway/samechannel" samechannelgateway "github.com/42wim/matterbridge/gateway/samechannel"
// "github.com/davecgh/go-spew/spew"
"time"
) )
type Router struct { type Router struct {
config.Config
Gateways map[string]*Gateway Gateways map[string]*Gateway
Message chan config.Message Message chan config.Message
MattermostPlugin chan config.Message MattermostPlugin chan config.Message
*config.Config
} }
func NewRouter(cfg *config.Config) (*Router, error) { func NewRouter(cfg config.Config) (*Router, error) {
r := &Router{Message: make(chan config.Message), MattermostPlugin: make(chan config.Message), Gateways: make(map[string]*Gateway), Config: cfg} r := &Router{
Config: cfg,
Message: make(chan config.Message),
MattermostPlugin: make(chan config.Message),
Gateways: make(map[string]*Gateway),
}
sgw := samechannelgateway.New(cfg) sgw := samechannelgateway.New(cfg)
gwconfigs := sgw.GetConfig() gwconfigs := sgw.GetConfig()
for _, entry := range append(gwconfigs, cfg.Gateway...) { for _, entry := range append(gwconfigs, cfg.ConfigValues().Gateway...) {
if !entry.Enable { if !entry.Enable {
continue continue
} }

View File

@ -5,17 +5,17 @@ import (
) )
type SameChannelGateway struct { type SameChannelGateway struct {
*config.Config config.Config
} }
func New(cfg *config.Config) *SameChannelGateway { func New(cfg config.Config) *SameChannelGateway {
return &SameChannelGateway{Config: cfg} return &SameChannelGateway{Config: cfg}
} }
func (sgw *SameChannelGateway) GetConfig() []config.Gateway { func (sgw *SameChannelGateway) GetConfig() []config.Gateway {
var gwconfigs []config.Gateway var gwconfigs []config.Gateway
cfg := sgw.Config cfg := sgw.Config
for _, gw := range cfg.SameChannelGateway { for _, gw := range cfg.ConfigValues().SameChannelGateway {
gwconfig := config.Gateway{Name: gw.Name, Enable: gw.Enable} gwconfig := config.Gateway{Name: gw.Name, Enable: gw.Enable}
for _, account := range gw.Accounts { for _, account := range gw.Accounts {
for _, channel := range gw.Channels { for _, channel := range gw.Channels {

View File

@ -1,16 +1,13 @@
package samechannelgateway package samechannelgateway
import ( import (
"fmt"
"github.com/42wim/matterbridge/bridge/config" "github.com/42wim/matterbridge/bridge/config"
"github.com/BurntSushi/toml"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"testing" "testing"
) )
var testconfig = ` const testConfig = `
[mattermost.test] [mattermost.test]
[slack.test] [slack.test]
@ -21,12 +18,56 @@ var testconfig = `
channels = [ "testing","testing2","testing10"] channels = [ "testing","testing2","testing10"]
` `
func TestGetConfig(t *testing.T) { var (
var cfg *config.Config expectedConfig = config.Gateway{
if _, err := toml.Decode(testconfig, &cfg); err != nil { Name: "blah",
fmt.Println(err) Enable: true,
In: []config.Bridge(nil),
Out: []config.Bridge(nil),
InOut: []config.Bridge{
{
Account: "mattermost.test",
Channel: "testing",
Options: config.ChannelOptions{Key: ""},
SameChannel: true,
},
{
Account: "mattermost.test",
Channel: "testing2",
Options: config.ChannelOptions{Key: ""},
SameChannel: true,
},
{
Account: "mattermost.test",
Channel: "testing10",
Options: config.ChannelOptions{Key: ""},
SameChannel: true,
},
{
Account: "slack.test",
Channel: "testing",
Options: config.ChannelOptions{Key: ""},
SameChannel: true,
},
{
Account: "slack.test",
Channel: "testing2",
Options: config.ChannelOptions{Key: ""},
SameChannel: true,
},
{
Account: "slack.test",
Channel: "testing10",
Options: config.ChannelOptions{Key: ""},
SameChannel: true,
},
},
} }
)
func TestGetConfig(t *testing.T) {
cfg := config.NewConfigFromString([]byte(testConfig))
sgw := New(cfg) sgw := New(cfg)
configs := sgw.GetConfig() configs := sgw.GetConfig()
assert.Equal(t, []config.Gateway{{Name: "blah", Enable: true, In: []config.Bridge(nil), Out: []config.Bridge(nil), InOut: []config.Bridge{{Account: "mattermost.test", Channel: "testing", Options: config.ChannelOptions{Key: ""}, SameChannel: true}, {Account: "mattermost.test", Channel: "testing2", Options: config.ChannelOptions{Key: ""}, SameChannel: true}, {Account: "mattermost.test", Channel: "testing10", Options: config.ChannelOptions{Key: ""}, SameChannel: true}, {Account: "slack.test", Channel: "testing", Options: config.ChannelOptions{Key: ""}, SameChannel: true}, {Account: "slack.test", Channel: "testing2", Options: config.ChannelOptions{Key: ""}, SameChannel: true}, {Account: "slack.test", Channel: "testing10", Options: config.ChannelOptions{Key: ""}, SameChannel: true}}}}, configs) assert.Equal(t, []config.Gateway{expectedConfig}, configs)
} }

View File

@ -44,7 +44,7 @@ func main() {
flog.Println("WARNING: THIS IS A DEVELOPMENT VERSION. Things may break.") flog.Println("WARNING: THIS IS A DEVELOPMENT VERSION. Things may break.")
} }
cfg := config.NewConfig(*flagConfig) cfg := config.NewConfig(*flagConfig)
cfg.General.Debug = *flagDebug cfg.ConfigValues().General.Debug = *flagDebug
r, err := gateway.NewRouter(cfg) r, err := gateway.NewRouter(cfg)
if err != nil { if err != nil {
flog.Fatalf("Starting gateway failed: %s", err) flog.Fatalf("Starting gateway failed: %s", err)