From 2076619576975b8a3b90617acaf372691e13b33b Mon Sep 17 00:00:00 2001 From: Jordan Hotmann Date: Fri, 17 Nov 2023 11:24:12 -0700 Subject: [PATCH] API token auth, readme updates --- README.md | 91 ++++++++++++++++++++++++++++++++++++++- internal/api/api.go | 39 ++++++++++++----- pkg/client/client.go | 5 ++- pkg/config/config.go | 24 +++++++---- pkg/homeassistant/util.go | 4 +- 5 files changed, 140 insertions(+), 23 deletions(-) diff --git a/README.md b/README.md index dddeeaf..a889c8e 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,92 @@ # HATS -Push Home Assistant websocket events to a NATS message queue. Additionally acts as a caching proxy for Home Assistant API requests. \ No newline at end of file +[Home Assistant](https://www.home-assistant.io/) + [NATS](https://nats.io/) = HATS + +## Features +- Push Home Assistant websocket events to a NATS message queue +- Caching proxy for Home Assistant API +- Clients for some application APIs (limited functionality) + - [Gokapi](https://github.com/Forceu/Gokapi) + - [ntfy](https://github.com/binwiederhier/ntfy) + - [qBittorrent](https://github.com/qbittorrent/qBittorrent) + - [Syncthing](https://github.com/syncthing/syncthing) + - [National Weather Service](https://www.weather.gov/) + +## Example Client + +```golang +package main + +import ( + "encoding/json" + "fmt" + "log/slog" + "os" + "os/signal" + "syscall" + + "code.jhot.me/jhot/hats/pkg/client" + "code.jhot.me/jhot/hats/pkg/config" + ha "code.jhot.me/jhot/hats/pkg/homeassistant" + n "code.jhot.me/jhot/hats/pkg/nats" + "github.com/nats-io/nats.go" +) + +var ( + logger *slog.Logger + hatsClient *client.HatsClient + natsClient *n.NatsConnection +) + +func main() { + cfg := config.FromEnvironment() + logger = slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{ + Level: cfg.GetLogLevel(), + })) + + hatsClient = client.NewHatsClient(cfg.GetHatsBaseUrl(), cfg.HatsToken) + natsClient = n.DefaultNatsConnection().WithJetstream(false).WithHostName(cfg.NatsHost).WithPort(cfg.NatsPort).WithConnectionOption(nats.Name(cfg.NatsClientName)) + defer natsClient.Close() + + go GenericStateListener("sun.sun", SunHandler) + + sigch := make(chan os.Signal, 1) + signal.Notify(sigch, syscall.SIGINT, syscall.SIGQUIT, syscall.SIGTERM) + <-sigch + logger.Info("SIGTERM received") +} + +func SunHandler(state ha.StateData) error { + return hatsClient.CallService("light.some_light", ha.Services.TurnOn) +} + +func GenericStateListener(entityId string, handler func(ha.StateData) error) { + topic := fmt.Sprintf("homeassistant.states.%s.*", entityId) + l := logger.With("topic", topic, "entity_id", entityId) + l.Debug("Subscribing to topic") + sub, ch, err := natsClient.Subscribe(topic) + if err != nil { + l.Error("Error subscribing to topic", "error", err) + return + } + + defer sub.Unsubscribe() + + for msg := range ch { + msg.Ack() + var data ha.EventData + err = json.Unmarshal(msg.Data, &data) + if err != nil { + l.Error("Error parsing message", "error", err) + continue + } + l.Debug("Event state " + data.NewState.State) + err = handler(data.NewState) + if err != nil { + l.Error("Error handling state event", "error", err) + continue + } + } +} + +``` \ No newline at end of file diff --git a/internal/api/api.go b/internal/api/api.go index f3623e1..d3b1211 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -4,6 +4,7 @@ import ( "fmt" "io" "net/http" + "strings" "time" "log/slog" @@ -37,8 +38,10 @@ func Listen(parentLogger *slog.Logger) { router.Use(middleware.RequestID) router.Use(middleware.RealIP) + router.Use(loggerMiddleware) router.Use(middleware.Recoverer) router.Use(middleware.Timeout(60 * time.Second)) + router.Use(tokenAuthMiddleware) router.Get(`/api/state/{entityId}`, getEntityStateHandler) router.Post("/api/state/{entityId}/{service}", setEntityStateHandler) @@ -69,14 +72,36 @@ func Close() { } } -func logRequest(w http.ResponseWriter, r *http.Request) { - logger.Debug(fmt.Sprintf("%s %s", r.Method, r.URL.Path), "method", r.Method, "path", r.URL.Path, "address", r.RemoteAddr) +func tokenAuthMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if cfg.HatsToken == "" { // No token required + next.ServeHTTP(w, r) + return + } + + authHeaderParts := strings.Split(r.Header.Get("Authorization"), "") + switch { + case len(authHeaderParts) != 2: + case authHeaderParts[0] != "Bearer": + case authHeaderParts[1] != cfg.HatsToken: + http.Error(w, "Bearer authorization header doesn't match configured token", http.StatusUnauthorized) + return + default: + next.ServeHTTP(w, r) + } + }) +} + +func loggerMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + logger.Debug(fmt.Sprintf("%s %s", r.Method, r.URL.Path), "method", r.Method, "path", r.URL.Path, "address", r.RemoteAddr) + next.ServeHTTP(w, r) + }) } // HOME ASSISTANT ENTITIES func getEntityStateHandler(w http.ResponseWriter, r *http.Request) { - logRequest(w, r) entityId := chi.URLParam(r, "entityId") full := r.URL.Query().Get("full") == "true" @@ -103,7 +128,6 @@ func getEntityStateHandler(w http.ResponseWriter, r *http.Request) { } func setEntityStateHandler(w http.ResponseWriter, r *http.Request) { - logRequest(w, r) entityId := chi.URLParam(r, "entityId") service := chi.URLParam(r, "service") domain := r.URL.Query().Get("domain") @@ -139,7 +163,6 @@ func setEntityStateHandler(w http.ResponseWriter, r *http.Request) { func getTimerHandler(w http.ResponseWriter, r *http.Request) { timerName := chi.URLParam(r, "timerName") - logRequest(w, r) timer, err := nats.GetTimer(timerName) if err != nil { @@ -157,7 +180,6 @@ type StartTimerData struct { func startTimerHandler(w http.ResponseWriter, r *http.Request) { timerName := chi.URLParam(r, "timerName") - logRequest(w, r) data := &StartTimerData{} if err := render.DecodeJSON(r.Body, data); err != nil { @@ -185,7 +207,6 @@ func startTimerHandler(w http.ResponseWriter, r *http.Request) { } func deleteTimerHandler(w http.ResponseWriter, r *http.Request) { - logRequest(w, r) timerName := chi.URLParam(r, "timerName") timer, err := nats.GetTimer(timerName) @@ -202,7 +223,6 @@ func deleteTimerHandler(w http.ResponseWriter, r *http.Request) { func getScheduleHandler(w http.ResponseWriter, r *http.Request) { scheduleName := chi.URLParam(r, "scheduleName") - logRequest(w, r) schedule, err := nats.GetSchedule(scheduleName) if err != nil { @@ -219,7 +239,6 @@ type CreateScheduleData struct { func createScheduleHandler(w http.ResponseWriter, r *http.Request) { scheduleName := chi.URLParam(r, "scheduleName") - logRequest(w, r) data := &CreateScheduleData{} if err := render.DecodeJSON(r.Body, data); err != nil { @@ -239,7 +258,6 @@ func createScheduleHandler(w http.ResponseWriter, r *http.Request) { } func deleteScheduleHandler(w http.ResponseWriter, r *http.Request) { - logRequest(w, r) scheduleName := chi.URLParam(r, "scheduleName") schedule, err := nats.GetSchedule(scheduleName) @@ -273,7 +291,6 @@ func postNtfyHandler(w http.ResponseWriter, r *http.Request) { func postCommandHandler(w http.ResponseWriter, r *http.Request) { commandName := chi.URLParam(r, "commandName") - logRequest(w, r) switch commandName { // Commands without payloads diff --git a/pkg/client/client.go b/pkg/client/client.go index c3be2af..13fb8c4 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -13,8 +13,11 @@ type HatsClient struct { client *resty.Client } -func NewHatsClient(baseUrl string) *HatsClient { +func NewHatsClient(baseUrl string, token string) *HatsClient { client := resty.New().SetBaseURL(baseUrl) + if token != "" { + client.SetHeader("Authorization", fmt.Sprintf("Bearer %s", token)) + } return &HatsClient{ client: client, } diff --git a/pkg/config/config.go b/pkg/config/config.go index fc2ab7b..84a6625 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -24,6 +24,7 @@ type HatsConfig struct { HatsHost string HatsPort string + HatsToken string HatsSecure bool NtfyHost string @@ -32,18 +33,23 @@ type HatsConfig struct { func FromEnvironment() *HatsConfig { config := &HatsConfig{ - LogLevl: util.GetEnvWithDefault("LOG_LEVEL", "INFO"), + LogLevl: util.GetEnvWithDefault("LOG_LEVEL", "INFO"), + HomeAssistantHost: util.GetEnvWithDefault("HASS_HOST", "127.0.0.1"), HomeAssistantPort: util.GetEnvWithDefault("HASS_PORT", "8123"), HomeAssistantToken: util.GetEnvWithDefault("HASS_TOKEN", ""), - NatsHost: util.GetEnvWithDefault("NATS_HOST", "127.0.0.1"), - NatsPort: util.GetEnvWithDefault("NATS_PORT", "4222"), - NatsToken: util.GetEnvWithDefault("NATS_TOKEN", ""), - NatsClientName: util.GetEnvWithDefault("NATS_CLIENT_NAME", "hats"), - HatsHost: util.GetEnvWithDefault("HATS_HOST", "hats"), - HatsPort: util.GetEnvWithDefault("HATS_PORT", "8888"), - NtfyHost: util.GetEnvWithDefault("NTFY_HOST", "https://ntfy.sh"), - NtfyToken: util.GetEnvWithDefault("NTFY_TOKEN", ""), + + NatsHost: util.GetEnvWithDefault("NATS_HOST", "127.0.0.1"), + NatsPort: util.GetEnvWithDefault("NATS_PORT", "4222"), + NatsToken: util.GetEnvWithDefault("NATS_TOKEN", ""), + NatsClientName: util.GetEnvWithDefault("NATS_CLIENT_NAME", "hats"), + + HatsHost: util.GetEnvWithDefault("HATS_HOST", "hats"), + HatsPort: util.GetEnvWithDefault("HATS_PORT", "8888"), + HatsToken: util.GetEnvWithDefault("HATS_TOKEN", ""), + + NtfyHost: util.GetEnvWithDefault("NTFY_HOST", "https://ntfy.sh"), + NtfyToken: util.GetEnvWithDefault("NTFY_TOKEN", ""), } config.HomeAssistantSecure, _ = strconv.ParseBool(util.GetEnvWithDefault("HASS_SECURE", "false")) diff --git a/pkg/homeassistant/util.go b/pkg/homeassistant/util.go index 285bc9a..c8f5cf5 100644 --- a/pkg/homeassistant/util.go +++ b/pkg/homeassistant/util.go @@ -11,8 +11,10 @@ import ( // // States that return true: "on", "home", "open", "playing", non-zero numbers, etc. // All others return false +// +// regex: ^(on|home|open(ing)?|unlocked|playing|active|good|walking|charging|alive|heat|cool|heat_cool|above_horizon|[1-9][\d\.]*|0\.0*[1-9]\d*)$ func StateToBool(state string) bool { - trueRegex := regexp.MustCompile(`^(on|home|open(ing)?|unlocked|playing|active|good|walking|charging|alive|heat|cool|heat_cool|[1-9][\d\.]*|0\.0*[1-9]\d*)$`) + trueRegex := regexp.MustCompile(`^(on|home|open(ing)?|unlocked|playing|active|good|walking|charging|alive|heat|cool|heat_cool|above_horizon|[1-9][\d\.]*|0\.0*[1-9]\d*)$`) return trueRegex.MatchString(state) }