commit 1d1c0a9e8200061469911f020f36e5d1193bffe3 Author: Jordan Hotmann Date: Thu Oct 12 11:23:35 2023 -0600 Initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c9ef98a --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +.envrc +temp/ +.vscode/ diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..dd3cc7b --- /dev/null +++ b/Dockerfile @@ -0,0 +1,16 @@ +FROM golang:1 as builder + +WORKDIR /app +COPY go.mod go.sum ./ +RUN go mod download +COPY . ./ +RUN CGO_ENABLED=0 GOOS=linux go build -o /hats + +FROM builder as tester + +RUN go test -v + +FROM scratch + +COPY --from=builder --chmod=755 /hats /hats +ENTRYPOINT [ "/hats" ] \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..dddeeaf --- /dev/null +++ b/README.md @@ -0,0 +1,3 @@ +# HATS + +Push Home Assistant websocket events to a NATS message queue. Additionally acts as a caching proxy for Home Assistant API requests. \ No newline at end of file diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..f1c1694 --- /dev/null +++ b/go.mod @@ -0,0 +1,30 @@ +module code.jhot.me/jhot/hats + +go 1.21.1 + +require ( + github.com/gorilla/websocket v1.5.0 + github.com/nats-io/nats.go v1.30.2 +) + +require ( + github.com/ajg/form v1.5.1 // indirect + golang.org/x/net v0.15.0 // indirect + golang.org/x/time v0.3.0 // indirect +) + +require ( + github.com/go-chi/chi v1.5.5 + github.com/go-chi/chi/v5 v5.0.10 + github.com/go-chi/render v1.0.3 + github.com/go-resty/resty/v2 v2.9.1 + github.com/golang/protobuf v1.5.3 // indirect + github.com/klauspost/compress v1.17.0 // indirect + github.com/nats-io/nats-server/v2 v2.10.2 // indirect + github.com/nats-io/nkeys v0.4.5 // indirect + github.com/nats-io/nuid v1.0.1 // indirect + golang.org/x/crypto v0.13.0 // indirect + golang.org/x/sys v0.12.0 // indirect + golang.org/x/text v0.13.0 // indirect + google.golang.org/protobuf v1.31.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..b61817f --- /dev/null +++ b/go.sum @@ -0,0 +1,81 @@ +github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU= +github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY= +github.com/go-chi/chi v1.5.5 h1:vOB/HbEMt9QqBqErz07QehcOKHaWFtuj87tTDVz2qXE= +github.com/go-chi/chi v1.5.5/go.mod h1:C9JqLr3tIYjDOZpzn+BCuxY8z8vmca43EeMgyZt7irw= +github.com/go-chi/chi/v5 v5.0.10 h1:rLz5avzKpjqxrYwXNfmjkrYYXOyLJd37pz53UFHC6vk= +github.com/go-chi/chi/v5 v5.0.10/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= +github.com/go-chi/render v1.0.3 h1:AsXqd2a1/INaIfUSKq3G5uA8weYx20FOsM7uSoCyyt4= +github.com/go-chi/render v1.0.3/go.mod h1:/gr3hVkmYR0YlEy3LxCuVRFzEu9Ruok+gFqbIofjao0= +github.com/go-resty/resty/v2 v2.9.1 h1:PIgGx4VrHvag0juCJ4dDv3MiFRlDmP0vicBucwf+gLM= +github.com/go-resty/resty/v2 v2.9.1/go.mod h1:4/GYJVjh9nhkhGR6AUNW3XhpDYNUr+Uvy9gV/VGZIy4= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= +github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= +github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/klauspost/compress v1.17.0 h1:Rnbp4K9EjcDuVuHtd0dgA4qNuv9yKDYKK1ulpJwgrqM= +github.com/klauspost/compress v1.17.0/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= +github.com/minio/highwayhash v1.0.2 h1:Aak5U0nElisjDCfPSG79Tgzkn2gl66NxOMspRrKnA/g= +github.com/minio/highwayhash v1.0.2/go.mod h1:BQskDq+xkJ12lmlUUi7U0M5Swg3EWR+dLTk+kldvVxY= +github.com/nats-io/jwt/v2 v2.5.2 h1:DhGH+nKt+wIkDxM6qnVSKjokq5t59AZV5HRcFW0zJwU= +github.com/nats-io/jwt/v2 v2.5.2/go.mod h1:24BeQtRwxRV8ruvC4CojXlx/WQ/VjuwlYiH+vu/+ibI= +github.com/nats-io/nats-server/v2 v2.10.2 h1:2o/OOyc/dxeMCQtrF1V/9er0SU0A3LKhDlv/+rqreBM= +github.com/nats-io/nats-server/v2 v2.10.2/go.mod h1:lzrskZ/4gyMAh+/66cCd+q74c6v7muBypzfWhP/MAaM= +github.com/nats-io/nats.go v1.30.2 h1:aloM0TGpPorZKQhbAkdCzYDj+ZmsJDyeo3Gkbr72NuY= +github.com/nats-io/nats.go v1.30.2/go.mod h1:dcfhUgmQNN4GJEfIb2f9R7Fow+gzBF4emzDHrVBd5qM= +github.com/nats-io/nkeys v0.4.5 h1:Zdz2BUlFm4fJlierwvGK+yl20IAKUm7eV6AAZXEhkPk= +github.com/nats-io/nkeys v0.4.5/go.mod h1:XUkxdLPTufzlihbamfzQ7mw/VGx6ObUs+0bN5sNvt64= +github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= +github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck= +golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/net v0.15.0 h1:ugBLEUaxABaB5AJqW9enI0ACdci2RUd4eP51NTBvuJ8= +golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= +golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= +golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= +golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= +golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= +google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= diff --git a/internal/api/api.go b/internal/api/api.go new file mode 100644 index 0000000..8b01b9a --- /dev/null +++ b/internal/api/api.go @@ -0,0 +1,96 @@ +package api + +import ( + "fmt" + "net/http" + "time" + + "log/slog" + + "code.jhot.me/jhot/hats/internal/nats" + "code.jhot.me/jhot/hats/pkg/config" + "code.jhot.me/jhot/hats/pkg/homeassistant" + "github.com/go-chi/chi/middleware" + "github.com/go-chi/chi/v5" + "github.com/go-chi/render" +) + +var ( + cfg *config.HatsConfig + logger *slog.Logger + server http.Server + haClient *homeassistant.RestClient +) + +const ( + HA_STATE_PREFIX = "homeassistant.states" +) + +func Listen(parentLogger *slog.Logger) { + logger = parentLogger + cfg = config.FromEnvironment() + haClient = homeassistant.NewRestClient(cfg.HomeAssistantBaseUrl, cfg.HomeAssistantToken) + router := chi.NewRouter() + + router.Use(middleware.RequestID) + router.Use(middleware.RealIP) + router.Use(middleware.Recoverer) + router.Use(middleware.Timeout(60 * time.Second)) + + router.Get(`/api/state/{entityId}`, func(w http.ResponseWriter, r *http.Request) { + logger.Debug(fmt.Sprintf("%s %s", r.Method, r.URL.Path), "method", r.Method, "path", r.URL.Path, "address", r.RemoteAddr) + entityId := chi.URLParam(r, "entityId") + + kvVal, err := nats.GetKeyValue(fmt.Sprintf("%s.%s", HA_STATE_PREFIX, entityId)) + if err == nil && len(kvVal) > 0 { + w.Write(kvVal) + return + } + + data, err := haClient.GetState(entityId) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + nats.SetKeyValueString(fmt.Sprintf("%s.%s", HA_STATE_PREFIX, entityId), data.State) + render.PlainText(w, r, data.State) + }) + + router.Post("/api/state/{entityId}/{service}", func(w http.ResponseWriter, r *http.Request) { + logger.Debug(fmt.Sprintf("%s %s", r.Method, r.URL.Path), "method", r.Method, "path", r.URL.Path, "address", r.RemoteAddr) + entityId := chi.URLParam(r, "entityId") + service := chi.URLParam(r, "service") + + var extras map[string]string + err := render.DecodeJSON(r.Body, &extras) + var haErr error + if err == nil && len(extras) > 0 { + haErr = haClient.CallService(entityId, service, extras) + } else { + haErr = haClient.CallService(entityId, service) + } + + if haErr != nil { + logger.Error("Error setting state", "error", haErr) + http.Error(w, fmt.Sprintf("error proxying request: %s", haErr.Error()), http.StatusInternalServerError) + return + } + + render.Status(r, http.StatusOK) + render.PlainText(w, r, "OK") + }) + + server = http.Server{ + Addr: ":8888", + Handler: router, + } + + go server.ListenAndServe() +} + +func Close() { + if server.Addr != "" { + server.Close() + } +} diff --git a/internal/homeassistant/subscriber.go b/internal/homeassistant/subscriber.go new file mode 100644 index 0000000..a9d3af9 --- /dev/null +++ b/internal/homeassistant/subscriber.go @@ -0,0 +1,144 @@ +package homeassistant + +import ( + "encoding/json" + "errors" + "fmt" + "time" + + "log/slog" + + "code.jhot.me/jhot/hats/internal/nats" + "code.jhot.me/jhot/hats/pkg/config" + ha "code.jhot.me/jhot/hats/pkg/homeassistant" + "github.com/gorilla/websocket" +) + +var ( + cfg *config.HatsConfig + logger *slog.Logger + haWebsocketConn *websocket.Conn + done chan struct{} +) + +const ( + stateChangeEventId = 1001 + zhaEventId = 1002 + qrEventId = 1003 +) + +func CloseSubscription() error { + if haWebsocketConn != nil { + logger.Debug("Closing Home Assistant subscription") + haWebsocketConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + close(done) + return haWebsocketConn.Close() + } + return nil +} + +func Subscribe(parentLogger *slog.Logger) error { + logger = parentLogger + cfg = config.FromEnvironment() + var err error + + url := fmt.Sprintf(cfg.HomeAssistantWebsocketUrl) + + logger.Debug("Dialing Home Assistant websocket API", "url", url) + + haWebsocketConn, _, err = websocket.DefaultDialer.Dial(url, nil) + if err != nil { + return fmt.Errorf("%w: error dialing Home Assistant websocket", err) + } + + done = make(chan struct{}) + + handleMessages() + return nil +} + +func reconnect() { + haWebsocketConn.Close() + attempts := 1 + for { + if attempts > 10 { + panic(errors.New("unable to reconnect to Home Assistant")) + } + + time.Sleep(time.Duration(attempts) * 5 * time.Second) + logger.Info("Trying to reconnect to Home Assistant", "attempt", attempts) + + err := Subscribe(logger) + if err == nil { + break + } + + attempts += 1 + } +} + +func handleMessages() { + go func() { + defer close(done) + for { + _, rawMessage, err := haWebsocketConn.ReadMessage() + if err != nil { + logger.Error("Error reading Home Assistant websocket message", "error", err) + reconnect() + break + } + if len(rawMessage) == 0 { + continue + } + var message ha.HassMessage + err = json.Unmarshal(rawMessage, &message) + if err != nil { + logger.Error("Error parsing HASS message", "message", string(rawMessage), "error", err) + continue + } + + switch message.Type { + case ha.MessageType.AuthRequired: + logger.Debug("Logging in to HomeAssistant websocket API") + haWebsocketConn.WriteJSON(ha.AuthMessage{Type: "auth", AccessToken: cfg.HomeAssistantToken}) + case ha.MessageType.AuthOk: + logger.Debug("Subscribing to Events") + haWebsocketConn.WriteJSON(ha.SubscribeEventsMessage{ + Type: ha.MessageType.SubscribeEvents, + EventType: ha.MessageType.StateChanged, + Id: stateChangeEventId}) + haWebsocketConn.WriteJSON(ha.SubscribeEventsMessage{ + Type: ha.MessageType.SubscribeEvents, + EventType: ha.MessageType.ZhaEvent, + Id: zhaEventId}) + haWebsocketConn.WriteJSON(ha.SubscribeEventsMessage{ + Type: ha.MessageType.SubscribeEvents, + EventType: ha.MessageType.TagScanned, + Id: qrEventId}) + case ha.MessageType.Result: + if !message.Success { + logger.Error("Non-Success Result:", "message", message) + reconnect() + return + } + case ha.MessageType.Event: + logger.Debug("Event received", "event", message.Event) + switch message.Id { + case stateChangeEventId: + data, marshallErr := json.Marshal(message.Event.Data) + if marshallErr != nil { + logger.Error("Error marshalling event data", "error", marshallErr) + } + nats.Publish(fmt.Sprintf("homeassistant.states.%s.%s", message.Event.Data.EntityId, message.Event.Data.NewState.State), data) + nats.SetKeyValueString(fmt.Sprintf("homeassistant.states.%s", message.Event.Data.EntityId), message.Event.Data.NewState.State) + case zhaEventId: + data, _ := json.Marshal(message.Event.Data) + nats.Publish(fmt.Sprintf("homeassistant.zha.%s", message.Event.Data.DeviceIeee), data) + case qrEventId: + data, _ := json.Marshal(message.Event.Data) + nats.Publish(fmt.Sprintf("homeassistant.qr.%s", message.Event.Data.TagId), data) + } + } + } + }() +} diff --git a/internal/nats/client.go b/internal/nats/client.go new file mode 100644 index 0000000..8244b5b --- /dev/null +++ b/internal/nats/client.go @@ -0,0 +1,102 @@ +package nats + +import ( + "context" + "errors" + "fmt" + "log/slog" + "time" + + "code.jhot.me/jhot/hats/pkg/config" + n "code.jhot.me/jhot/hats/pkg/nats" + "github.com/nats-io/nats.go" + "github.com/nats-io/nats.go/jetstream" +) + +var ( + cfg *config.HatsConfig + client *n.NatsConnection + kv jetstream.KeyValue + ctx context.Context + logger *slog.Logger +) + +func Close() { + client.Close() +} + +func JetstreamConnect(parentContext context.Context, parentLogger *slog.Logger) error { + ctx = parentContext + logger = parentLogger + cfg = config.FromEnvironment() + var err error + + client = n.DefaultNatsConnection().WithHostName(cfg.NatsHost).WithPort(cfg.NatsPort).WithConnectionOption(nats.Name(cfg.NatsClientName)) + + if cfg.NatsToken != "" { + client.WithConnectionOption(nats.Token(cfg.NatsToken)) + } + + logger.Debug("Connecting to nats") + err = client.Connect() + if err != nil { + return fmt.Errorf("%w: error connecting to nats server", err) + } + + return nil +} + +func KvConnect() error { + if client.JS == nil { + return errors.New("jetstream must be connected first") + } + + logger.Debug("Looking for KV store") + listener := client.JS.KeyValueStoreNames(ctx) + found := false + for name := range listener.Name() { + if name == "KV_hats" { + found = true + } + } + + var err error + if found { + logger.Debug("Connecting for KV store") + kv, err = client.JS.KeyValue(ctx, "hats") + } else { + logger.Debug("Creating for KV store") + kv, err = client.JS.CreateKeyValue(ctx, jetstream.KeyValueConfig{ + Bucket: "hats", + TTL: 2 * time.Hour, + }) + } + + return err +} + +func Publish(subject string, message []byte) { + client.Publish(ctx, subject, message, jetstream.WithRetryAttempts(2), jetstream.WithRetryWait(500*time.Millisecond)) +} + +func PublishString(subject, message string) { + Publish(subject, []byte(message)) +} + +func GetKeyValue(key string) ([]byte, error) { + value, err := kv.Get(ctx, key) + if err != nil { + return []byte{}, err + } + return value.Value(), nil +} + +func SetKeyValue(key string, value []byte) error { + _, err := kv.Put(ctx, key, value) + return err +} + +func SetKeyValueString(key, value string) error { + _, err := kv.PutString(ctx, key, value) + return err +} diff --git a/internal/util/os.go b/internal/util/os.go new file mode 100644 index 0000000..b2b7f06 --- /dev/null +++ b/internal/util/os.go @@ -0,0 +1,11 @@ +package util + +import "os" + +func GetEnvWithDefault(key, defaultValue string) string { + val := os.Getenv(key) + if val != "" { + return val + } + return defaultValue +} diff --git a/main.go b/main.go new file mode 100644 index 0000000..42820ef --- /dev/null +++ b/main.go @@ -0,0 +1,53 @@ +package main + +import ( + "context" + "log/slog" + "os" + "os/signal" + + "code.jhot.me/jhot/hats/internal/api" + "code.jhot.me/jhot/hats/internal/homeassistant" + "code.jhot.me/jhot/hats/internal/nats" +) + +var ( + interrupt chan os.Signal + logger *slog.Logger + ctx context.Context + cancel context.CancelFunc +) + +func main() { + ctx, cancel = context.WithCancel(context.Background()) + logger = slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{ + Level: slog.LevelDebug, + })) + interrupt = make(chan os.Signal, 1) + signal.Notify(interrupt, os.Interrupt) + + err := nats.JetstreamConnect(ctx, logger) + if err != nil { + panic(err) + } + err = nats.KvConnect() + if err != nil { + panic(err) + } + defer nats.Close() + + err = homeassistant.Subscribe(logger) + if err != nil { + panic(err) + } + defer homeassistant.CloseSubscription() + + api.Listen(logger) + defer api.Close() + + for sig := range interrupt { + logger.Debug("Interrupt:", "signal", sig.String()) + cancel() + break + } +} diff --git a/pkg/client/client.go b/pkg/client/client.go new file mode 100644 index 0000000..5672f3b --- /dev/null +++ b/pkg/client/client.go @@ -0,0 +1,66 @@ +package client + +import ( + "fmt" + + "code.jhot.me/jhot/hats/pkg/homeassistant" + "github.com/go-resty/resty/v2" +) + +type HatsClient struct { + client *resty.Client +} + +func NewHatsClient(baseUrl string) *HatsClient { + client := resty.New().SetBaseURL(baseUrl) + return &HatsClient{ + client: client, + } +} + +func (c *HatsClient) GetState(entityId string) (string, error) { + resp, err := c.client.R().Get(fmt.Sprintf("api/state/%s", entityId)) + if err == nil && !resp.IsSuccess() { + err = fmt.Errorf("%d status code received: %s", resp.StatusCode(), resp.String()) + } + + if err != nil { + return "", err + } + + return resp.String(), nil +} + +func (c *HatsClient) GetStateBool(entityId string) (bool, error) { + stateString, err := c.GetState(entityId) + + if err != nil { + return false, err + } + + return homeassistant.StateToBool(stateString), nil +} + +func (c *HatsClient) CallService(entityId string, service string, extras ...map[string]string) error { + req := c.client.R() + if len(extras) > 0 { + data := map[string]interface{}{} + for _, extra := range extras { + for k, v := range extra { + data[k] = v + } + } + req.SetBody(data) + } + + resp, err := req.Post(fmt.Sprintf("api/state/%s/%s", entityId, service)) + if err == nil && !resp.IsSuccess() { + err = fmt.Errorf("%d status code received: %s", resp.StatusCode(), resp.String()) + } + + if err != nil { + return err + } + + return nil +} diff --git a/pkg/config/config.go b/pkg/config/config.go new file mode 100644 index 0000000..fd7fd54 --- /dev/null +++ b/pkg/config/config.go @@ -0,0 +1,48 @@ +package config + +import ( + "fmt" + "strconv" + + "code.jhot.me/jhot/hats/internal/util" +) + +type HatsConfig struct { + HomeAssistantHost string + HomeAssistantPort string + HomeAssistantSecure bool + HomeAssistantBaseUrl string + HomeAssistantWebsocketUrl string + HomeAssistantToken string + + NatsHost string + NatsPort string + NatsBaseUrl string + NatsToken string + NatsClientName string +} + +func FromEnvironment() *HatsConfig { + config := &HatsConfig{ + HomeAssistantHost: util.GetEnvWithDefault("HASS_HOST", "127.0.0.1"), + HomeAssistantPort: util.GetEnvWithDefault("HASS_PORT", "8123"), + HomeAssistantToken: util.GetEnvWithDefault("HASS_TOKEN", ""), + NatsHost: util.GetEnvWithDefault("NATS_HOST", "127.0.0.1"), + NatsPort: util.GetEnvWithDefault("NATS_PORT", "4222"), + NatsToken: util.GetEnvWithDefault("NATS_TOKEN", ""), + NatsClientName: util.GetEnvWithDefault("NATS_CLIENT_NAME", "hats"), + } + + config.HomeAssistantSecure, _ = strconv.ParseBool(util.GetEnvWithDefault("HASS_SECURE", "false")) + hassProtocol := "http" + HassWsProtocol := "ws" + if config.HomeAssistantSecure { + hassProtocol += "s" + HassWsProtocol += "s" + } + config.HomeAssistantBaseUrl = fmt.Sprintf("%s://%s:%s", hassProtocol, config.HomeAssistantHost, config.HomeAssistantPort) + config.HomeAssistantWebsocketUrl = fmt.Sprintf("%s://%s:%s/api/websocket", HassWsProtocol, config.HomeAssistantHost, config.HomeAssistantPort) + config.NatsBaseUrl = fmt.Sprintf("nats://%s:%s", config.NatsHost, config.NatsPort) + + return config +} diff --git a/pkg/homeassistant/rest.go b/pkg/homeassistant/rest.go new file mode 100644 index 0000000..4a8604d --- /dev/null +++ b/pkg/homeassistant/rest.go @@ -0,0 +1,49 @@ +package homeassistant + +import ( + "fmt" + "strings" + + "github.com/go-resty/resty/v2" +) + +type RestClient struct { + client *resty.Client +} + +func NewRestClient(baseUrl, token string) *RestClient { + client := resty.New().SetBaseURL(baseUrl) + client.SetHeaders(map[string]string{ + "Authorization": fmt.Sprintf("Bearer %s", token), + "Accept": "application/json", + }) + return &RestClient{ + client: client, + } +} + +func (c *RestClient) GetState(entityId string) (StateData, error) { + var data StateData + resp, err := c.client.R().SetResult(&data).Get(fmt.Sprintf("api/states/%s", entityId)) + if err == nil && !resp.IsSuccess() { + err = fmt.Errorf("%d status code received: %s", resp.StatusCode(), resp.String()) + } + return data, err +} + +func (c *RestClient) CallService(entityId string, service string, extras ...map[string]string) error { + domain := strings.Split(entityId, ".")[0] + data := map[string]interface{}{ + "entity_id": entityId, + } + for _, extra := range extras { + for k, v := range extra { + data[k] = v + } + } + resp, err := c.client.R().SetBody(data).Post(fmt.Sprintf("api/services/%s/%s", domain, service)) + if err == nil && !resp.IsSuccess() { + err = fmt.Errorf("%d status code received: %s", resp.StatusCode(), resp.String()) + } + return err +} diff --git a/pkg/homeassistant/structs.go b/pkg/homeassistant/structs.go new file mode 100644 index 0000000..0a3c817 --- /dev/null +++ b/pkg/homeassistant/structs.go @@ -0,0 +1,127 @@ +package homeassistant + +// Message types that can be returned from Home Assistants websocket API +var MessageType = struct { + AuthRequired string + AuthOk string + AuthInvalid string + Result string + Event string + ZhaEvent string + SubscribeEvents string + StateChanged string + TagScanned string +}{ + AuthRequired: "auth_required", + AuthOk: "auth_ok", + AuthInvalid: "auth_invalid", + Result: "result", + Event: "event", + ZhaEvent: "zha_event", + SubscribeEvents: "subscribe_events", + StateChanged: "state_changed", + TagScanned: "tag_scanned", +} + +// Home Assistant device domains +var Domains = struct { + Light string + Switch string + Lock string + Cover string +}{ + Light: "light", + Switch: "switch", + Lock: "lock", + Cover: "cover", +} + +// Home Assistant services +var Services = struct { + TurnOn string + TurnOff string + Toggle string + Reload string + Lock string + Unlock string + OpenCover string + CloseCover string +}{ + TurnOn: "turn_on", + TurnOff: "turn_off", + Toggle: "toggle", + Reload: "reload", + Lock: "lock", + Unlock: "unlock", + OpenCover: "open_cover", + CloseCover: "close_cover", +} + +// Extra props that can be sent when calling a Home Assistant service +var ExtraProps = struct { + Transition string + Brightness string + BrightnessPercent string +}{ + Transition: "transition", + Brightness: "brightness", + BrightnessPercent: "brightness_pct", +} + +type ResultContext struct { + Id string `json:"id,omitempty"` +} + +type Result struct { + Context ResultContext `json:"context,omitempty"` +} + +type StateData struct { + LastChanged string `json:"last_changed,omitempty"` + LastUpdated string `json:"last_updated,omitempty"` + State string `json:"state,omitempty"` + Attributes map[string]interface{} `json:"attributes,omitempty"` + Context interface{} `json:"context,omitempty"` +} + +type EventData struct { + EntityId string `json:"entity_id,omitempty"` + NewState StateData `json:"new_state,omitempty"` + OldState StateData `json:"old_state,omitempty"` + DeviceIeee string `json:"device_ieee,omitempty"` + DeviceId string `json:"device_id,omitempty"` + Command string `json:"command,omitempty"` + Args interface{} `json:"args,omitempty"` + Params interface{} `json:"params,omitempty"` + TagId string `json:"tag_id,omitempty"` +} + +type Event struct { + Data EventData `json:"data,omitempty"` + EventType string `json:"event_type,omitempty"` + TimeFired string `json:"time_fired,omitempty"` + Origin string `json:"origin,omitempty"` +} + +type HassMessage struct { + Type string `json:"type"` + Version string `json:"ha_version,omitempty"` + AccessToken string `json:"access_token,omitempty"` + Message string `json:"message,omitempty"` + Success bool `json:"success,omitempty"` + Result Result `json:"result,omitempty"` + EventType string `json:"event_type,omitempty"` + Event Event `json:"event,omitempty"` + Id int `json:"id,omitempty"` +} + +type AuthMessage struct { + Type string `json:"type"` + AccessToken string `json:"access_token,omitempty"` +} + +type SubscribeEventsMessage struct { + Type string `json:"type"` + EventType string `json:"event_type"` + Id int `json:"id"` +} diff --git a/pkg/homeassistant/util.go b/pkg/homeassistant/util.go new file mode 100644 index 0000000..b793b6b --- /dev/null +++ b/pkg/homeassistant/util.go @@ -0,0 +1,45 @@ +package homeassistant + +import ( + "regexp" + "strings" +) + +// StateToBool converts a state string into a boolean +// +// States that return true: "on", "home", "open", "playing", non-zero numbers, etc. +// All others return false +func StateToBool(state string) bool { + trueRegex := regexp.MustCompile(`^(on|home|open(ing)?|unlocked|playing|good|walking|charging|alive|heat|cool|heat_cool|[1-9][\d\.]*|0\.0*[1-9]\d*)$`) + return trueRegex.MatchString(state) +} + +// BoolToService converts a boolean into the appropriate service string +// +// For locks: true becomes "unlock" and false becomes "lock" +// For covers: true becomes "open_cover" and false becomes "close_cover" +// For all others: true becomes "turn_on" and false becomes "turn_off" +func BoolToService(entityId string, desiredState bool) string { + domain := strings.Split(entityId, ".")[0] + + switch domain { + case Domains.Lock: + if desiredState { + return Services.Unlock + } else { + return Services.Lock + } + case Domains.Cover: + if desiredState { + return Services.OpenCover + } else { + return Services.CloseCover + } + default: + if desiredState { + return Services.TurnOn + } else { + return Services.TurnOff + } + } +} diff --git a/pkg/nats/client.go b/pkg/nats/client.go new file mode 100644 index 0000000..264dd8f --- /dev/null +++ b/pkg/nats/client.go @@ -0,0 +1,118 @@ +package nats + +import ( + "context" + "errors" + "fmt" + "strings" + + "github.com/nats-io/nats.go" + "github.com/nats-io/nats.go/jetstream" +) + +type NatsConnection struct { + HostName string + Port string + UseJetstream bool + connOpts []nats.Option + jetOpts []jetstream.JetStreamOpt + Conn *nats.Conn + JS jetstream.JetStream +} + +func DefaultNatsConnection() *NatsConnection { + return &NatsConnection{ + HostName: "127.0.0.1", + Port: "4222", + UseJetstream: true, + connOpts: []nats.Option{}, + jetOpts: []jetstream.JetStreamOpt{}, + } +} + +func (n *NatsConnection) WithHostName(hostname string) *NatsConnection { + n.HostName = hostname + return n +} + +func (n *NatsConnection) WithPort(port string) *NatsConnection { + n.Port = port + return n +} + +func (n *NatsConnection) WithJetstream(jetstream bool) *NatsConnection { + n.UseJetstream = jetstream + return n +} + +func (n *NatsConnection) WithConnectionOption(opt nats.Option) *NatsConnection { + n.connOpts = append(n.connOpts, opt) + return n +} + +func (n *NatsConnection) WithJetstreamOption(opt jetstream.JetStreamOpt) *NatsConnection { + n.jetOpts = append(n.jetOpts, opt) + return n +} + +func (n *NatsConnection) Connect() error { + var err error + + n.Conn, err = nats.Connect(fmt.Sprintf("nats://%s:%s", n.HostName, n.Port), n.connOpts...) + if err != nil { + return err + } + + if n.UseJetstream { + n.JS, err = jetstream.New(n.Conn, n.jetOpts...) + if err != nil { + return err + } + } + + return nil +} + +func (n *NatsConnection) Close() { + if n.Conn != nil { + n.Conn.Close() + } +} + +func (n *NatsConnection) Publish(ctx context.Context, subject string, payload []byte, opts ...jetstream.PublishOpt) { + if n.UseJetstream { + n.JS.PublishAsync(subject, payload, opts...) + // n.JS.Publish(ctx, subject, payload, opts...) + } else { + n.Conn.Publish(subject, payload) + } +} + +func (n *NatsConnection) Subscribe(subject string) (sub *nats.Subscription, ch chan *nats.Msg, err error) { + if !n.UseJetstream { + ch = make(chan *nats.Msg, 64) + sub, err = n.Conn.ChanSubscribe(subject, ch) + return sub, ch, err + } + return nil, nil, errors.New("jetstream in use, you should use Stream instead") +} + +func (n *NatsConnection) Stream(ctx context.Context, subject string) (stream jetstream.Stream, consumer jetstream.Consumer, err error) { + if n.UseJetstream { + stream, err = n.JS.CreateStream(ctx, jetstream.StreamConfig{ + Name: strings.ReplaceAll(strings.ReplaceAll(strings.ReplaceAll(subject, ".", "_"), "*", "any"), ">", "arrow"), + Subjects: []string{subject}, + }) + if err != nil { + return nil, nil, err + } + + consumer, err = stream.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{}) + if err != nil { + return nil, nil, err + } + + return stream, consumer, nil + } + return nil, nil, errors.New("jetstream not in use, you should use Subscribe instead") +}