diff --git a/go.mod b/go.mod index f1c1694..e56a94b 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,10 @@ require ( require ( github.com/ajg/form v1.5.1 // indirect + github.com/google/uuid v1.3.1 // indirect + github.com/robfig/cron/v3 v3.0.1 // indirect + go.uber.org/atomic v1.9.0 // indirect + golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect golang.org/x/net v0.15.0 // indirect golang.org/x/time v0.3.0 // indirect ) @@ -17,12 +21,14 @@ require ( github.com/go-chi/chi v1.5.5 github.com/go-chi/chi/v5 v5.0.10 github.com/go-chi/render v1.0.3 + github.com/go-co-op/gocron v1.35.2 github.com/go-resty/resty/v2 v2.9.1 github.com/golang/protobuf v1.5.3 // indirect github.com/klauspost/compress v1.17.0 // indirect github.com/nats-io/nats-server/v2 v2.10.2 // indirect github.com/nats-io/nkeys v0.4.5 // indirect github.com/nats-io/nuid v1.0.1 // indirect + github.com/samber/lo v1.38.1 golang.org/x/crypto v0.13.0 // indirect golang.org/x/sys v0.12.0 // indirect golang.org/x/text v0.13.0 // indirect diff --git a/go.sum b/go.sum index b61817f..127a0ab 100644 --- a/go.sum +++ b/go.sum @@ -1,21 +1,34 @@ github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU= github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-chi/chi v1.5.5 h1:vOB/HbEMt9QqBqErz07QehcOKHaWFtuj87tTDVz2qXE= github.com/go-chi/chi v1.5.5/go.mod h1:C9JqLr3tIYjDOZpzn+BCuxY8z8vmca43EeMgyZt7irw= github.com/go-chi/chi/v5 v5.0.10 h1:rLz5avzKpjqxrYwXNfmjkrYYXOyLJd37pz53UFHC6vk= github.com/go-chi/chi/v5 v5.0.10/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= github.com/go-chi/render v1.0.3 h1:AsXqd2a1/INaIfUSKq3G5uA8weYx20FOsM7uSoCyyt4= github.com/go-chi/render v1.0.3/go.mod h1:/gr3hVkmYR0YlEy3LxCuVRFzEu9Ruok+gFqbIofjao0= +github.com/go-co-op/gocron v1.35.2 h1:lG3rdA9TqBBC/PtT2ukQqgLm6jEepnAzz3+OQetvPTE= +github.com/go-co-op/gocron v1.35.2/go.mod h1:NLi+bkm4rRSy1F8U7iacZOz0xPseMoIOnvabGoSe/no= github.com/go-resty/resty/v2 v2.9.1 h1:PIgGx4VrHvag0juCJ4dDv3MiFRlDmP0vicBucwf+gLM= github.com/go-resty/resty/v2 v2.9.1/go.mod h1:4/GYJVjh9nhkhGR6AUNW3XhpDYNUr+Uvy9gV/VGZIy4= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4= +github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/klauspost/compress v1.17.0 h1:Rnbp4K9EjcDuVuHtd0dgA4qNuv9yKDYKK1ulpJwgrqM= github.com/klauspost/compress v1.17.0/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/minio/highwayhash v1.0.2 h1:Aak5U0nElisjDCfPSG79Tgzkn2gl66NxOMspRrKnA/g= github.com/minio/highwayhash v1.0.2/go.mod h1:BQskDq+xkJ12lmlUUi7U0M5Swg3EWR+dLTk+kldvVxY= github.com/nats-io/jwt/v2 v2.5.2 h1:DhGH+nKt+wIkDxM6qnVSKjokq5t59AZV5HRcFW0zJwU= @@ -28,11 +41,30 @@ github.com/nats-io/nkeys v0.4.5 h1:Zdz2BUlFm4fJlierwvGK+yl20IAKUm7eV6AAZXEhkPk= github.com/nats-io/nkeys v0.4.5/go.mod h1:XUkxdLPTufzlihbamfzQ7mw/VGx6ObUs+0bN5sNvt64= github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= +github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= +github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= +github.com/rogpeppe/go-internal v1.8.1/go.mod h1:JeRgkft04UBgHMgCIwADu4Pn6Mtm5d4nPKWu0nJ5d+o= +github.com/samber/lo v1.38.1 h1:j2XEAqXKb09Am4ebOg31SpvzUTTs6EN3VfgeLUhPdXM= +github.com/samber/lo v1.38.1/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= +go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck= golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= +golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= +golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -79,3 +111,9 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0 google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/api/api.go b/internal/api/api.go index 8b01b9a..ac7831a 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -37,49 +37,16 @@ func Listen(parentLogger *slog.Logger) { router.Use(middleware.Recoverer) router.Use(middleware.Timeout(60 * time.Second)) - router.Get(`/api/state/{entityId}`, func(w http.ResponseWriter, r *http.Request) { - logger.Debug(fmt.Sprintf("%s %s", r.Method, r.URL.Path), "method", r.Method, "path", r.URL.Path, "address", r.RemoteAddr) - entityId := chi.URLParam(r, "entityId") + router.Get(`/api/state/{entityId}`, getEntityStateHandler) + router.Post("/api/state/{entityId}/{service}", setEntityStateHandler) - kvVal, err := nats.GetKeyValue(fmt.Sprintf("%s.%s", HA_STATE_PREFIX, entityId)) - if err == nil && len(kvVal) > 0 { - w.Write(kvVal) - return - } + router.Get("/api/timer/{timerName}", getTimerHandler) + router.Post("/api/timer/{timerName}", createTimerHandler) + router.Delete("/api/timer/{timerName}", deleteTimerHandler) - data, err := haClient.GetState(entityId) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - nats.SetKeyValueString(fmt.Sprintf("%s.%s", HA_STATE_PREFIX, entityId), data.State) - render.PlainText(w, r, data.State) - }) - - router.Post("/api/state/{entityId}/{service}", func(w http.ResponseWriter, r *http.Request) { - logger.Debug(fmt.Sprintf("%s %s", r.Method, r.URL.Path), "method", r.Method, "path", r.URL.Path, "address", r.RemoteAddr) - entityId := chi.URLParam(r, "entityId") - service := chi.URLParam(r, "service") - - var extras map[string]string - err := render.DecodeJSON(r.Body, &extras) - var haErr error - if err == nil && len(extras) > 0 { - haErr = haClient.CallService(entityId, service, extras) - } else { - haErr = haClient.CallService(entityId, service) - } - - if haErr != nil { - logger.Error("Error setting state", "error", haErr) - http.Error(w, fmt.Sprintf("error proxying request: %s", haErr.Error()), http.StatusInternalServerError) - return - } - - render.Status(r, http.StatusOK) - render.PlainText(w, r, "OK") - }) + router.Get("/api/schedule/{scheduleName}", getScheduleHandler) + router.Post("/api/schedule/{scheduleName}", createScheduleHandler) + router.Delete("/api/schedule/{scheduleName}", deleteScheduleHandler) server = http.Server{ Addr: ":8888", @@ -94,3 +61,166 @@ func Close() { server.Close() } } + +func logRequest(w http.ResponseWriter, r *http.Request) { + logger.Debug(fmt.Sprintf("%s %s", r.Method, r.URL.Path), "method", r.Method, "path", r.URL.Path, "address", r.RemoteAddr) +} + +// HOME ASSISTANT ENTITIES + +func getEntityStateHandler(w http.ResponseWriter, r *http.Request) { + logRequest(w, r) + entityId := chi.URLParam(r, "entityId") + + kvVal, err := nats.GetKeyValue(fmt.Sprintf("%s.%s", HA_STATE_PREFIX, entityId)) + if err == nil && len(kvVal) > 0 { + w.Write(kvVal) + return + } + + data, err := haClient.GetState(entityId) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + nats.SetKeyValueString(fmt.Sprintf("%s.%s", HA_STATE_PREFIX, entityId), data.State) + render.PlainText(w, r, data.State) +} + +func setEntityStateHandler(w http.ResponseWriter, r *http.Request) { + logRequest(w, r) + entityId := chi.URLParam(r, "entityId") + service := chi.URLParam(r, "service") + + var extras map[string]string + err := render.DecodeJSON(r.Body, &extras) + var haErr error + if err == nil && len(extras) > 0 { + haErr = haClient.CallService(entityId, service, extras) + } else { + haErr = haClient.CallService(entityId, service) + } + + if haErr != nil { + logger.Error("Error setting state", "error", haErr) + http.Error(w, fmt.Sprintf("error proxying request: %s", haErr.Error()), http.StatusInternalServerError) + return + } + + render.Status(r, http.StatusOK) + render.PlainText(w, r, "OK") +} + +// TIMERS + +func getTimerHandler(w http.ResponseWriter, r *http.Request) { + timerName := chi.URLParam(r, "timerName") + logRequest(w, r) + + timer, err := nats.GetTimer(timerName) + if err != nil { + http.Error(w, "Unable to get timer: "+err.Error(), http.StatusInternalServerError) + return + } + + render.PlainText(w, r, string(timer.Marshall())) +} + +type CreateTimerData struct { + Duration string `json:"duration"` + Force bool `json:"force"` +} + +func createTimerHandler(w http.ResponseWriter, r *http.Request) { + timerName := chi.URLParam(r, "timerName") + logRequest(w, r) + + data := &CreateTimerData{} + if err := render.DecodeJSON(r.Body, data); err != nil { + http.Error(w, "Unable to parse timer data", http.StatusNotAcceptable) + return + } + + if data.Duration == "" { + http.Error(w, "duration required", http.StatusNotAcceptable) + return + } + + timer := nats.NewTimerWithDuration(timerName, data.Duration).CalculateNext() + if data.Force { + timer.Activate() + } else { + timer.ActivateIfNotAlready() + } + + getTimerHandler(w, r) +} + +func deleteTimerHandler(w http.ResponseWriter, r *http.Request) { + logRequest(w, r) + timerName := chi.URLParam(r, "timerName") + + timer, err := nats.GetTimer(timerName) + if err != nil { + http.Error(w, "Unable to get timer: "+err.Error(), http.StatusInternalServerError) + return + } + + timer.Cancel() + render.PlainText(w, r, "OK") +} + +// SCHEDULES + +func getScheduleHandler(w http.ResponseWriter, r *http.Request) { + scheduleName := chi.URLParam(r, "scheduleName") + logRequest(w, r) + + schedule, err := nats.GetSchedule(scheduleName) + if err != nil { + http.Error(w, "Unable to get schedule: "+err.Error(), http.StatusInternalServerError) + return + } + + render.PlainText(w, r, string(schedule.GetNext())) +} + +type CreateScheduleData struct { + Cron string `json:"cron"` +} + +func createScheduleHandler(w http.ResponseWriter, r *http.Request) { + scheduleName := chi.URLParam(r, "scheduleName") + logRequest(w, r) + + data := &CreateScheduleData{} + if err := render.DecodeJSON(r.Body, data); err != nil { + http.Error(w, "Unable to parse schedule data", http.StatusNotAcceptable) + return + } + + if data.Cron == "" { + http.Error(w, "cron required", http.StatusNotAcceptable) + return + } + + schedule := nats.NewSchedule(scheduleName, data.Cron) + schedule.Activate() + + getScheduleHandler(w, r) +} + +func deleteScheduleHandler(w http.ResponseWriter, r *http.Request) { + logRequest(w, r) + scheduleName := chi.URLParam(r, "scheduleName") + + schedule, err := nats.GetSchedule(scheduleName) + if err != nil { + http.Error(w, "Unable to get schedule: "+err.Error(), http.StatusInternalServerError) + return + } + + schedule.Cancel() + render.PlainText(w, r, "OK") +} diff --git a/internal/nats/client.go b/internal/nats/client.go index 8244b5b..6fc0fee 100644 --- a/internal/nats/client.go +++ b/internal/nats/client.go @@ -62,10 +62,10 @@ func KvConnect() error { var err error if found { - logger.Debug("Connecting for KV store") + logger.Debug("Connecting to KV store") kv, err = client.JS.KeyValue(ctx, "hats") } else { - logger.Debug("Creating for KV store") + logger.Debug("Creating KV store") kv, err = client.JS.CreateKeyValue(ctx, jetstream.KeyValueConfig{ Bucket: "hats", TTL: 2 * time.Hour, diff --git a/internal/nats/schedule.go b/internal/nats/schedule.go new file mode 100644 index 0000000..d54786f --- /dev/null +++ b/internal/nats/schedule.go @@ -0,0 +1,113 @@ +package nats + +import ( + "errors" + "time" + + "github.com/go-co-op/gocron" + "github.com/nats-io/nats.go/jetstream" +) + +var ( + scheduleStore jetstream.KeyValue + scheduler = gocron.NewScheduler(time.Local) + schedules = map[string]*gocron.Job{} + + fireSchedule = func(name string) { + PublishString("schedules."+name, "fired") + } +) + +func ScheduleStoreConnect() error { + if client.JS == nil { + return errors.New("jetstream must be connected first") + } + + logger.Debug("Looking for schedule KV store") + listener := client.JS.KeyValueStoreNames(ctx) + found := false + for name := range listener.Name() { + if name == "KV_hats_schedules" { + found = true + } + } + + var err error + if found { + logger.Debug("Connecting to Schedules KV store") + scheduleStore, err = client.JS.KeyValue(ctx, "hats_schedules") + } else { + logger.Debug("Creating Schedules KV store") + scheduleStore, err = client.JS.CreateKeyValue(ctx, jetstream.KeyValueConfig{ + Bucket: "hats_schedules", + }) + } + + return err +} + +func GetExistingSchedules() { + scheduler.StartAsync() + existing, _ := scheduleStore.Keys(ctx, jetstream.IgnoreDeletes()) + for _, name := range existing { + sched, err := GetSchedule(name) + if err != nil { + continue + } + sched.Activate() + } +} + +type HatsSchedule struct { + Name string + Cron string +} + +func NewSchedule(name, cron string) *HatsSchedule { + return &HatsSchedule{ + Name: name, + Cron: cron, + } +} + +func GetSchedule(name string) (*HatsSchedule, error) { + value, err := scheduleStore.Get(ctx, name) + if err != nil { + return nil, err + } + + return NewSchedule(name, string(value.Value())), nil +} + +func (t *HatsSchedule) GetNext() string { + if job, exists := schedules[t.Name]; exists { + return job.NextRun().String() + } + return "" +} + +func (t *HatsSchedule) Activate() error { + job, err := scheduler.CronWithSeconds(t.Cron).Do(fireSchedule, t.Name) + if err != nil { + return err + } + if existing, found := schedules[t.Name]; found { + scheduler.RemoveByID(existing) + } + schedules[t.Name] = job + scheduleStore.PutString(ctx, t.Name, t.Cron) + return nil +} + +func (t *HatsSchedule) Cancel() { + if job, exists := schedules[t.Name]; exists { + scheduler.RemoveByID(job) + } + scheduleStore.Purge(ctx, t.Name) +} + +func StopSchedules() { + if scheduler != nil { + scheduler.Stop() + } +} diff --git a/internal/nats/timers.go b/internal/nats/timers.go new file mode 100644 index 0000000..7981124 --- /dev/null +++ b/internal/nats/timers.go @@ -0,0 +1,137 @@ +package nats + +import ( + "errors" + "time" + + "github.com/nats-io/nats.go/jetstream" +) + +var ( + timerStore jetstream.KeyValue + ticker *time.Ticker +) + +func TimerStoreConnect() error { + if client.JS == nil { + return errors.New("jetstream must be connected first") + } + + logger.Debug("Looking for KV store") + listener := client.JS.KeyValueStoreNames(ctx) + found := false + for name := range listener.Name() { + if name == "KV_hats_timers" { + found = true + } + } + + var err error + if found { + logger.Debug("Connecting to Timers KV store") + timerStore, err = client.JS.KeyValue(ctx, "hats_timers") + } else { + logger.Debug("Creating Timers KV store") + timerStore, err = client.JS.CreateKeyValue(ctx, jetstream.KeyValueConfig{ + Bucket: "hats_timers", + }) + } + + return err +} + +type HatsTimer struct { + Name string + Duration time.Duration + NextActivation time.Time +} + +func NewTimerWithDuration(name, duration string) *HatsTimer { + t := &HatsTimer{ + Name: name, + } + + d, err := time.ParseDuration(duration) + if err != nil { + d = 5 * time.Minute + } + t.Duration = d + + return t.CalculateNext() +} + +func NewTimerWithActivation(name string, activation []byte) (*HatsTimer, error) { + t := &HatsTimer{ + Name: name, + Duration: 5 * time.Minute, + } + + a, err := time.Parse(time.RFC3339, string(activation)) + if err != nil { + return t.CalculateNext(), err + } + t.NextActivation = a + + return t, nil +} + +func GetTimer(name string) (*HatsTimer, error) { + value, err := timerStore.Get(ctx, name) + if err != nil { + return nil, err + } + + return NewTimerWithActivation(name, value.Value()) +} + +func (t *HatsTimer) CalculateNext() *HatsTimer { + t.NextActivation = time.Now().Add(t.Duration) + return t +} + +func (t *HatsTimer) Marshall() []byte { + timestamp, _ := t.NextActivation.MarshalText() + return timestamp +} + +func (t *HatsTimer) Activate() { + timerStore.Put(ctx, t.Name, t.Marshall()) +} + +func (t *HatsTimer) ActivateIfNotAlready() { + timerStore.Create(ctx, t.Name, t.Marshall()) +} + +func (t *HatsTimer) Cancel() { + timerStore.Purge(ctx, t.Name) +} + +func (t *HatsTimer) End() { + t.Cancel() + PublishString("timers."+t.Name, "done") +} + +func WatchTimers() { + ticker = time.NewTicker(time.Second) + for { + t := <-ticker.C + timers, _ := timerStore.Keys(ctx, jetstream.IgnoreDeletes()) + for _, timerName := range timers { + timer, err := GetTimer(timerName) + if err != nil { + logger.Error("Error retrieving timer", "timer", timerName, "error", err) + continue + } + + if t.After(timer.NextActivation) { + timer.End() + } + } + } +} + +func StopTimers() { + if ticker != nil { + ticker.Stop() + } +} diff --git a/main.go b/main.go index 42820ef..6755877 100644 --- a/main.go +++ b/main.go @@ -30,11 +30,26 @@ func main() { if err != nil { panic(err) } + defer nats.Close() + err = nats.KvConnect() if err != nil { panic(err) } - defer nats.Close() + + err = nats.TimerStoreConnect() + if err != nil { + panic(err) + } + go nats.WatchTimers() + defer nats.StopTimers() + + err = nats.ScheduleStoreConnect() + if err != nil { + panic(err) + } + nats.GetExistingSchedules() + defer nats.StopSchedules() err = homeassistant.Subscribe(logger) if err != nil { diff --git a/pkg/client/client.go b/pkg/client/client.go index 5672f3b..edd4228 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -3,6 +3,7 @@ package client import ( "fmt" + "code.jhot.me/jhot/hats/internal/api" "code.jhot.me/jhot/hats/pkg/homeassistant" "github.com/go-resty/resty/v2" ) @@ -64,3 +65,84 @@ func (c *HatsClient) CallService(entityId string, service string, extras ...map[ return nil } + +func (c *HatsClient) GetTimer(name string) (string, error) { + resp, err := c.client.R().Get(fmt.Sprintf("api/timer/%s", name)) + if err == nil && !resp.IsSuccess() { + err = fmt.Errorf("%d status code received: %s", resp.StatusCode(), resp.String()) + } + + if err != nil { + return "", err + } + + return resp.String(), nil +} + +func (c *HatsClient) SetTimer(name string, duration string, force bool) (string, error) { + data := api.CreateTimerData{ + Duration: duration, + Force: force, + } + + resp, err := c.client.R().SetBody(data).Post(fmt.Sprintf("api/timer/%s", name)) + if err == nil && !resp.IsSuccess() { + err = fmt.Errorf("%d status code received: %s", resp.StatusCode(), resp.String()) + } + + if err != nil { + return "", err + } + + return resp.String(), nil +} + +func (c *HatsClient) DeleteTimer(name string) error { + resp, err := c.client.R().Delete(fmt.Sprintf("api/timer/%s", name)) + if err == nil && !resp.IsSuccess() { + err = fmt.Errorf("%d status code received: %s", resp.StatusCode(), resp.String()) + } + return err +} + +func (c *HatsClient) GetSchedule(name string) (string, error) { + resp, err := c.client.R().Get(fmt.Sprintf("api/schedule/%s", name)) + if err == nil && !resp.IsSuccess() { + err = fmt.Errorf("%d status code received: %s", resp.StatusCode(), resp.String()) + } + + if err != nil { + return "", err + } + + return resp.String(), nil +} + +// SetSchedule: set a cron schedule +// +// name: a unique identifying string +// cron: a cron expression with seconds, like "0 */5 * * * *" (every 5 minutes) +func (c *HatsClient) SetSchedule(name string, cron string) (string, error) { + data := api.CreateScheduleData{ + Cron: cron, + } + + resp, err := c.client.R().SetBody(data).Post(fmt.Sprintf("api/schedule/%s", name)) + if err == nil && !resp.IsSuccess() { + err = fmt.Errorf("%d status code received: %s", resp.StatusCode(), resp.String()) + } + + if err != nil { + return "", err + } + + return resp.String(), nil +} + +func (c *HatsClient) DeleteSchedule(name string) error { + resp, err := c.client.R().Delete(fmt.Sprintf("api/schedule/%s", name)) + if err == nil && !resp.IsSuccess() { + err = fmt.Errorf("%d status code received: %s", resp.StatusCode(), resp.String()) + } + return err +} diff --git a/pkg/config/config.go b/pkg/config/config.go index fd7fd54..7297dc7 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -20,6 +20,11 @@ type HatsConfig struct { NatsBaseUrl string NatsToken string NatsClientName string + + HatsHost string + HatsPort string + HatsSecure bool + HatsBaseUrl string } func FromEnvironment() *HatsConfig { @@ -31,6 +36,8 @@ func FromEnvironment() *HatsConfig { NatsPort: util.GetEnvWithDefault("NATS_PORT", "4222"), NatsToken: util.GetEnvWithDefault("NATS_TOKEN", ""), NatsClientName: util.GetEnvWithDefault("NATS_CLIENT_NAME", "hats"), + HatsHost: util.GetEnvWithDefault("HATS_HOST", "hats"), + HatsPort: util.GetEnvWithDefault("HATS_PORT", "8888"), } config.HomeAssistantSecure, _ = strconv.ParseBool(util.GetEnvWithDefault("HASS_SECURE", "false")) @@ -42,7 +49,15 @@ func FromEnvironment() *HatsConfig { } config.HomeAssistantBaseUrl = fmt.Sprintf("%s://%s:%s", hassProtocol, config.HomeAssistantHost, config.HomeAssistantPort) config.HomeAssistantWebsocketUrl = fmt.Sprintf("%s://%s:%s/api/websocket", HassWsProtocol, config.HomeAssistantHost, config.HomeAssistantPort) + config.NatsBaseUrl = fmt.Sprintf("nats://%s:%s", config.NatsHost, config.NatsPort) + config.HatsSecure, _ = strconv.ParseBool(util.GetEnvWithDefault("HATS_SECURE", "false")) + hatsProtocol := "http" + if config.HatsSecure { + hatsProtocol += "s" + } + config.HatsBaseUrl = fmt.Sprintf("%s://%s:%s", hatsProtocol, config.HatsHost, config.HatsPort) + return config }