From d5ed9de839dd35e4fc56efb0b1bcaedb7f555793 Mon Sep 17 00:00:00 2001 From: Javier Peletier Date: Thu, 17 Dec 2020 14:21:08 +0100 Subject: [PATCH] refactor and cleanup --- config.go | 114 +++++++++++++++++++++++++++++ kn/bridge.go | 50 ++++++------- main.go | 201 ++++++++------------------------------------------- mqtt/mqtt.go | 103 ++++++++++++++++++++++++++ 4 files changed, 271 insertions(+), 197 deletions(-) create mode 100644 config.go create mode 100644 mqtt/mqtt.go diff --git a/config.go b/config.go new file mode 100644 index 0000000..fd9b052 --- /dev/null +++ b/config.go @@ -0,0 +1,114 @@ +package main + +import ( + "flag" + "fmt" + "koolnova2mqtt/kn" + "koolnova2mqtt/modbus" + "koolnova2mqtt/mqtt" + "log" + "os" + "regexp" + "strconv" + "strings" + "time" +) + +type Config struct { + MqttClient *mqtt.Client + slaves map[byte]string + BridgeTemplateConfig *kn.Config +} + +func generateNodeName(slaveID string, port string) string { + reg, err := regexp.Compile("[^a-zA-Z0-9]+") + if err != nil { + log.Fatal(err) + } + hostname, _ := os.Hostname() + + port = strings.Replace(port, "/dev/", "", -1) + port = reg.ReplaceAllString(port, "") + return strings.ToLower(fmt.Sprintf("%s_%s_%s", hostname, port, slaveID)) + +} + +func parseModbusSlaveInfo(slaveIDs, slaveNames string, modbusPort string) map[byte]string { + slaveIDStrList := strings.Split(slaveIDs, ",") + var slaveNameList []string + + if slaveNames == "" { + for _, slaveIDStr := range slaveIDStrList { + slaveNameList = append(slaveNameList, generateNodeName(slaveIDStr, modbusPort)) + } + } else { + slaveNameList = strings.Split(slaveNames, ",") + if len(slaveIDStrList) != len(slaveNameList) { + log.Fatalf("modbusSlaveIDs and modbusSlaveNames lists must have the same length") + } + } + + slaves := make(map[byte]string) + for i, slaveIDStr := range slaveIDStrList { + slaveID, err := strconv.Atoi(slaveIDStr) + if err != nil { + log.Fatalf("Error parsing slaveID list") + } + slaves[byte(slaveID)] = slaveNameList[i] + } + return slaves +} + +func ParseCommandLine() *Config { + hostname, _ := os.Hostname() + + server := flag.String("server", "tcp://127.0.0.1:1883", "The full url of the MQTT server to connect to ex: tcp://127.0.0.1:1883") + clientid := flag.String("clientid", hostname+strconv.Itoa(time.Now().Second()), "A clientid for the connection") + username := flag.String("username", "", "A username to authenticate to the MQTT server") + password := flag.String("password", "", "Password to match username") + prefix := flag.String("prefix", "koolnova2mqtt", "MQTT topic root where to publish/read topics") + hassPrefix := flag.String("hassPrefix", "homeassistant", "Home assistant discovery prefix") + modbusPort := flag.String("modbusPort", "/dev/ttyUSB0", "Serial port where modbus hardware is connected") + modbusPortBaudRate := flag.Int("modbusRate", 9600, "Modbus port data rate") + modbusDataBits := flag.Int("modbusDataBits", 8, "Modbus port data bits") + modbusPortParity := flag.String("modbusParity", "E", "N - None, E - Even, O - Odd (default E) (The use of no parity requires 2 stop bits.)") + modbusStopBits := flag.Int("modbusStopBits", 1, "Modbus port stop bits") + modbusSlaveList := flag.String("modbusSlaveIDs", "49", "Comma-separated list of modbus slave IDs to manage") + modbusSlaveNames := flag.String("modbusSlaveNames", "", "Comma-separated list of modbus slave names. Defaults to 'slave#'") + + flag.Parse() + + slaves := parseModbusSlaveInfo(*modbusSlaveList, *modbusSlaveNames, *modbusPort) + + mb, err := modbus.New(&modbus.Config{ + Port: *modbusPort, + BaudRate: *modbusPortBaudRate, + DataBits: *modbusDataBits, + Parity: *modbusPortParity, + StopBits: *modbusStopBits, + Timeout: 200 * time.Millisecond, + }) + if err != nil { + log.Fatalf("Error initializing modbus: %s", err) + } + defer mb.Close() + + mqttClient := mqtt.New(&mqtt.Config{ + Server: *server, + ClientID: *clientid, + Username: *username, + Password: *password, + }) + + return &Config{ + slaves: slaves, + MqttClient: mqttClient, + BridgeTemplateConfig: &kn.Config{ + Mqtt: mqttClient, + Modbus: mb, + TopicPrefix: *prefix, + HassPrefix: *hassPrefix, + }, + } + +} diff --git a/kn/bridge.go b/kn/bridge.go index 6dd6b84..9b1a23e 100644 --- a/kn/bridge.go +++ b/kn/bridge.go @@ -9,14 +9,15 @@ import ( "strconv" ) -type Publish func(topic string, qos byte, retained bool, payload string) -type Subscribe func(topic string, callback func(message string)) error +type MqttClient interface { + Publish(topic string, qos byte, retained bool, payload string) error + Subscribe(topic string, callback func(message string)) error +} type Config struct { ModuleName string SlaveID byte - Publish Publish - Subscribe Subscribe + Mqtt MqttClient TopicPrefix string HassPrefix string Modbus modbus.Modbus @@ -117,7 +118,7 @@ func (b *Bridge) Start() error { if zone.IsOn() { hvacModeTopic := b.getZoneTopic(zone.ZoneNumber, "hvacMode") mode := getHVACMode() - b.Publish(hvacModeTopic, 0, true, mode) + b.Mqtt.Publish(hvacModeTopic, 0, true, mode) } } } @@ -143,22 +144,22 @@ func (b *Bridge) Start() error { } else { mode = HVAC_MODE_OFF } - b.Publish(hvacModeTopic, 0, true, mode) + b.Mqtt.Publish(hvacModeTopic, 0, true, mode) } zone.OnCurrentTempChange = func(currentTemp float32) { - b.Publish(currentTempTopic, 0, true, fmt.Sprintf("%g", currentTemp)) + b.Mqtt.Publish(currentTempTopic, 0, true, fmt.Sprintf("%g", currentTemp)) } zone.OnTargetTempChange = func(targetTemp float32) { - b.Publish(targetTempTopic, 0, true, fmt.Sprintf("%g", targetTemp)) + b.Mqtt.Publish(targetTempTopic, 0, true, fmt.Sprintf("%g", targetTemp)) } zone.OnFanModeChange = func(fanMode FanMode) { - b.Publish(fanModeTopic, 0, true, FanMode2Str(fanMode)) + b.Mqtt.Publish(fanModeTopic, 0, true, FanMode2Str(fanMode)) } zone.OnKnModeChange = func(knMode KnMode) { } - err = b.Subscribe(targetTempSetTopic, func(message string) { + err = b.Mqtt.Subscribe(targetTempSetTopic, func(message string) { targetTemp, err := strconv.ParseFloat(message, 32) if err != nil { log.Printf("Error parsing targetTemperature in topic %s: %s", targetTempSetTopic, err) @@ -173,7 +174,7 @@ func (b *Bridge) Start() error { return err } - err = b.Subscribe(fanModeSetTopic, func(message string) { + err = b.Mqtt.Subscribe(fanModeSetTopic, func(message string) { fm, err := Str2FanMode(message) if err != nil { log.Printf("Unknown fan mode %q in message to zone %d", message, zone.ZoneNumber) @@ -187,7 +188,7 @@ func (b *Bridge) Start() error { return err } - err = b.Subscribe(hvacModeSetTopic, func(message string) { + err = b.Mqtt.Subscribe(hvacModeSetTopic, func(message string) { if message == HVAC_MODE_OFF { err := zone.SetOn(false) if err != nil { @@ -211,7 +212,7 @@ func (b *Bridge) Start() error { return err } - err = b.Subscribe(holdModeSetTopic, func(message string) { + err = b.Mqtt.Subscribe(holdModeSetTopic, func(message string) { knMode := sys.GetSystemKNMode() knMode = ApplyHoldMode(knMode, message) err := sys.SetSystemKNMode(knMode) @@ -248,56 +249,55 @@ func (b *Bridge) Start() error { configJSON, _ := json.Marshal(config) // //[/]/config - b.Publish(fmt.Sprintf("%s/climate/%s/zone%d/config", b.HassPrefix, b.ModuleName, zone.ZoneNumber), 0, true, string(configJSON)) + b.Mqtt.Publish(fmt.Sprintf("%s/climate/%s/zone%d/config", b.HassPrefix, b.ModuleName, zone.ZoneNumber), 0, true, string(configJSON)) // temperature sensor configuration: name = fmt.Sprintf("%s_zone%d_temp", b.ModuleName, zone.ZoneNumber) config = map[string]interface{}{ "name": name, "device_class": "temperature", - "expire_after": 60, "state_topic": currentTempTopic, "unit_of_measurement": "ÂșC", "unique_id": name, } configJSON, _ = json.Marshal(config) - b.Publish(fmt.Sprintf("%s/sensor/%s/zone%d_temp/config", b.HassPrefix, b.ModuleName, zone.ZoneNumber), 0, true, string(configJSON)) + b.Mqtt.Publish(fmt.Sprintf("%s/sensor/%s/zone%d_temp/config", b.HassPrefix, b.ModuleName, zone.ZoneNumber), 0, true, string(configJSON)) } sys.OnACAirflowChange = func(ac ACMachine) { airflow := sys.GetAirflow(ac) - b.Publish(b.getACTopic(ac, "airflow"), 0, true, strconv.Itoa(airflow)) + b.Mqtt.Publish(b.getACTopic(ac, "airflow"), 0, true, strconv.Itoa(airflow)) } sys.OnACTargetTempChange = func(ac ACMachine) { targetTemp := sys.GetMachineTargetTemp(ac) - b.Publish(b.getACTopic(ac, "targetTemp"), 0, true, fmt.Sprintf("%g", targetTemp)) + b.Mqtt.Publish(b.getACTopic(ac, "targetTemp"), 0, true, fmt.Sprintf("%g", targetTemp)) } sys.OnACTargetFanModeChange = func(ac ACMachine) { targetAirflow := sys.GetTargetFanMode(ac) - b.Publish(b.getACTopic(ac, "fanMode"), 0, true, FanMode2Str(targetAirflow)) + b.Mqtt.Publish(b.getACTopic(ac, "fanMode"), 0, true, FanMode2Str(targetAirflow)) } sys.OnEfficiencyChange = func() { efficiency := sys.GetEfficiency() - b.Publish(b.getSysTopic("efficiency"), 0, true, strconv.Itoa(efficiency)) + b.Mqtt.Publish(b.getSysTopic("efficiency"), 0, true, strconv.Itoa(efficiency)) } sys.OnSystemEnabledChange = func() { enabled := sys.GetSystemEnabled() - b.Publish(b.getSysTopic("enabled"), 0, true, fmt.Sprintf("%t", enabled)) + b.Mqtt.Publish(b.getSysTopic("enabled"), 0, true, fmt.Sprintf("%t", enabled)) publishHvacMode() } sys.OnKnModeChange = func() { publishHvacMode() - b.Publish(holdModeTopic, 0, true, getHoldMode()) + b.Mqtt.Publish(holdModeTopic, 0, true, getHoldMode()) } b.zw.TriggerCallbacks() b.sysw.TriggerCallbacks() - b.Publish(b.getSysTopic("serialBaud"), 0, true, strconv.Itoa(sys.GetBaudRate())) - b.Publish(b.getSysTopic("serialParity"), 0, true, sys.GetParity()) - b.Publish(b.getSysTopic("slaveId"), 0, true, strconv.Itoa(sys.GetSlaveID())) + b.Mqtt.Publish(b.getSysTopic("serialBaud"), 0, true, strconv.Itoa(sys.GetBaudRate())) + b.Mqtt.Publish(b.getSysTopic("serialParity"), 0, true, sys.GetParity()) + b.Mqtt.Publish(b.getSysTopic("slaveId"), 0, true, strconv.Itoa(sys.GetSlaveID())) return nil } diff --git a/main.go b/main.go index 1eae09f..ff83326 100644 --- a/main.go +++ b/main.go @@ -1,35 +1,24 @@ package main import ( - "crypto/tls" - "errors" - "flag" - "fmt" "koolnova2mqtt/kn" - "koolnova2mqtt/modbus" "log" "os" "os/signal" - "regexp" - "strconv" - "strings" "syscall" "time" - - MQTT "github.com/eclipse/paho.mqtt.golang" ) -func generateNodeName(slaveID string, port string) string { - reg, err := regexp.Compile("[^a-zA-Z0-9]+") - if err != nil { - log.Fatal(err) +func NewBridges(slaves map[byte]string, templateConfig *kn.Config) []*kn.Bridge { + var bridges []*kn.Bridge + for id, name := range slaves { + config := *templateConfig + config.ModuleName = name + config.SlaveID = id + bridge := kn.NewBridge(&config) + bridges = append(bridges, bridge) } - hostname, _ := os.Hostname() - - port = strings.Replace(port, "/dev/", "", -1) - port = reg.ReplaceAllString(port, "") - return strings.ToLower(fmt.Sprintf("%s_%s_%s", hostname, port, slaveID)) - + return bridges } func main() { @@ -37,162 +26,28 @@ func main() { ctrlC := make(chan os.Signal, 1) signal.Notify(ctrlC, os.Interrupt, syscall.SIGTERM) - hostname, _ := os.Hostname() + config := ParseCommandLine() - server := flag.String("server", "tcp://127.0.0.1:1883", "The full url of the MQTT server to connect to ex: tcp://127.0.0.1:1883") - //topic := flag.String("topic", "#", "Topic to subscribe to") - //qos := flag.Int("qos", 0, "The QoS to subscribe to messages at") - clientid := flag.String("clientid", hostname+strconv.Itoa(time.Now().Second()), "A clientid for the connection") - username := flag.String("username", "", "A username to authenticate to the MQTT server") - password := flag.String("password", "", "Password to match username") - prefix := flag.String("prefix", "koolnova2mqtt", "MQTT topic root where to publish/read topics") - hassPrefix := flag.String("hassPrefix", "homeassistant", "Home assistant discovery prefix") - modbusPort := flag.String("modbusPort", "/dev/ttyUSB0", "Serial port where modbus hardware is connected") - modbusPortBaudRate := flag.Int("modbusRate", 9600, "Modbus port data rate") - modbusDataBits := flag.Int("modbusDataBits", 8, "Modbus port data bits") - modbusPortParity := flag.String("modbusParity", "E", "N - None, E - Even, O - Odd (default E) (The use of no parity requires 2 stop bits.)") - modbusStopBits := flag.Int("modbusStopBits", 1, "Modbus port stop bits") - modbusSlaveList := flag.String("modbusSlaveIDs", "49", "Comma-separated list of modbus slave IDs to manage") - modbusSlaveNames := flag.String("modbusSlaveNames", "", "Comma-separated list of modbus slave names. Defaults to 'slave#'") - - flag.Parse() - - mb, err := modbus.New(&modbus.Config{ - Port: *modbusPort, - BaudRate: *modbusPortBaudRate, - DataBits: *modbusDataBits, - Parity: *modbusPortParity, - StopBits: *modbusStopBits, - Timeout: 200 * time.Millisecond, - }) - if err != nil { - log.Fatalf("Error initializing modbus: %s", err) - } - defer mb.Close() - - var mqttClient MQTT.Client - publish := func(topic string, qos byte, retained bool, payload string) { - client := mqttClient - if client == nil { - log.Printf("Cannot publish message %q to topic %s. MQTT client is disconnected", payload, topic) - return - } - client.Publish(topic, qos, retained, payload) - } - - subscribe := func(topic string, callback func(message string)) error { - client := mqttClient - if client == nil { - log.Printf("Cannot subscribe to topic %s. MQTT client is disconnected", topic) - return errors.New("Client is disconnected") - } - token := client.Subscribe(topic, 0, func(c MQTT.Client, m MQTT.Message) { - cbclient := mqttClient - if cbclient != client { - log.Printf("Cannot invoke callback to topic %s. MQTT client is disconnected", topic) - } - callback(string(m.Payload())) - }) - token.Wait() - return token.Error() - } - - var snameList []string - slist := strings.Split(*modbusSlaveList, ",") - - if *modbusSlaveNames == "" { - for _, slaveIDStr := range slist { - snameList = append(snameList, generateNodeName(slaveIDStr, *modbusPort)) - } - } else { - snameList = strings.Split(*modbusSlaveNames, ",") - if len(slist) != len(snameList) { - log.Fatalf("modbusSlaveIDs and modbusSlaveNames lists must have the same length") - } - } - - var bridges []*kn.Bridge - for i, slaveIDStr := range slist { - slaveID, err := strconv.Atoi(slaveIDStr) - slaveName := snameList[i] - if err != nil { - log.Fatalf("Error parsing slaveID list") - } - bridge := kn.NewBridge(&kn.Config{ - ModuleName: slaveName, - SlaveID: byte(slaveID), - Publish: publish, - Subscribe: subscribe, - TopicPrefix: *prefix, - HassPrefix: *hassPrefix, - Modbus: mb, - }) - bridges = append(bridges, bridge) - } - - connOpts := MQTT.NewClientOptions().AddBroker(*server).SetClientID(*clientid).SetCleanSession(true) - if *username != "" { - connOpts.SetUsername(*username) - if *password != "" { - connOpts.SetPassword(*password) - } - } - tlsConfig := &tls.Config{InsecureSkipVerify: true, ClientAuth: tls.NoClientCert} - connOpts.SetTLSConfig(tlsConfig) - onConnect := false - connOpts.OnConnect = func(c MQTT.Client) { - onConnect = true - } - var started bool - connOpts.OnConnectionLost = func(c MQTT.Client, err error) { - log.Printf("Connection to MQTT server lost: %s\n", err) - mqttClient = nil - started = false - } - - connectMQTT := func() error { - mqttClient = MQTT.NewClient(connOpts) - - if token := mqttClient.Connect(); token.Wait() && token.Error() != nil { - mqttClient = nil - return token.Error() - } else { - log.Printf("Connected to %s\n", *server) - } - return nil - } - - ticker := time.NewTicker(2 * time.Second) go func() { + ticker := time.NewTicker(2 * time.Second) + var sessionID int + var bridges []*kn.Bridge for range ticker.C { - if mqttClient == nil { - err := connectMQTT() - if err != nil { - log.Printf("Error connecting to MQTT server: %s\n", err) - continue + newSessionID := config.MqttClient.ID + if sessionID != newSessionID { + bridges = NewBridges(config.slaves, config.BridgeTemplateConfig) + for _, b := range bridges { + err := b.Start() + if err != nil { + log.Printf("Error starting bridge: %s\n", err) + break + } else { + sessionID = newSessionID + } } - } - client := mqttClient - if client != nil && client.IsConnected() { - if onConnect { - onConnect = false - for _, b := range bridges { - err := b.Start() - if err != nil { - log.Printf("Error starting bridge: %s\n", err) - client.Disconnect(100) - mqttClient = nil - break - } else { - started = true - } - } - } else { - if started { - for _, b := range bridges { - b.Tick() - } - } + } else { + for _, b := range bridges { + b.Tick() } } } @@ -200,4 +55,6 @@ func main() { <-ctrlC + config.MqttClient.Close() + } diff --git a/mqtt/mqtt.go b/mqtt/mqtt.go new file mode 100644 index 0000000..90767de --- /dev/null +++ b/mqtt/mqtt.go @@ -0,0 +1,103 @@ +package mqtt + +import ( + "crypto/tls" + "errors" + "log" + "time" + + MQTT "github.com/eclipse/paho.mqtt.golang" +) + +type Config struct { + Server string + ClientID string + Username string + Password string +} + +type Client struct { + client MQTT.Client + ID int + closed bool +} + +var ErrNotConnected = errors.New("MQTT client not connected") + +func New(config *Config) *Client { + m := &Client{} + + connOpts := MQTT.NewClientOptions(). + AddBroker(config.Server). + SetClientID(config.ClientID). + SetCleanSession(true). + SetAutoReconnect(false) + + if config.Username != "" { + connOpts.SetUsername(config.Username) + if config.Password != "" { + connOpts.SetPassword(config.Password) + } + } + + tlsConfig := &tls.Config{InsecureSkipVerify: true, ClientAuth: tls.NoClientCert} + connOpts.SetTLSConfig(tlsConfig) + + connOpts.OnConnectionLost = func(c MQTT.Client, err error) { + log.Printf("MQTT disconnected: %s\n", err) + } + + connect := func() { + log.Printf("Trying to connect to MQTT %s ...\n", config.Server) + newClient := MQTT.NewClient(connOpts) + token := newClient.Connect() + token.Wait() + if token.Error() == nil { + m.client = newClient + m.ID++ + log.Printf("Connected to MQTT. Session ID %d\n", m.ID) + } + } + + connect() + go func() { + ticker := time.NewTicker(5 * time.Second) + for range ticker.C { + if m.closed { + return + } + if m.client == nil || !m.client.IsConnectionOpen() { + connect() + } + } + if m.client != nil { + m.client.Disconnect(100) + } + }() + return m +} + +func (m *Client) Publish(topic string, qos byte, retained bool, payload string) error { + if m.client == nil { + return ErrNotConnected + } + token := m.client.Publish(topic, qos, retained, payload) + token.Wait() + return token.Error() +} + +func (m *Client) Subscribe(topic string, callback func(message string)) error { + if m.client == nil { + return ErrNotConnected + } + token := m.client.Subscribe(topic, 0, func(c MQTT.Client, m MQTT.Message) { + callback(string(m.Payload())) + }) + token.Wait() + return token.Error() +} + +func (m *Client) Close() error { + m.closed = true + return nil +}