Files
2026-06-10 16:53:28 +02:00

414 lines
10 KiB
Go

package main
import (
	"crypto/rand"
	"crypto/subtle"
	"embed"
	"encoding/json"
	"errors"
	"flag"
	"fmt"
	"io/fs"
	"log"
	"net/http"
	"os"
	"path/filepath"
	"regexp"
	"sort"
	"strings"
	"sync"
	"time"
)
// staticFS holds the contents of the static directory, embedded into the
// binary at build time. main serves index.html and poll.html from it and
// exposes the rest under /static/.
//go:embed static
var staticFS embed.FS
// ---------- Models ----------
// Option is one candidate date, optionally with a time of day, that
// participants can select in a poll. createPoll validates Date and Time.
type Option struct {
	ID   string `json:"id"`             // random ID; votes refer to options by it
	Date string `json:"date"`           // YYYY-MM-DD
	Time string `json:"time,omitempty"` // HH:MM, optional
}
// Vote is one participant's answer to a poll. The vote handler keeps only
// option IDs that exist in the poll, each at most once. A later vote under
// the same name (compared case-insensitively) replaces an earlier one.
type Vote struct {
	Name      string    `json:"name"`
	OptionIDs []string  `json:"optionIds"` // IDs of the selected Options
	CreatedAt time.Time `json:"createdAt"` // UTC time of the (latest) submission
}
// Poll is a scheduling poll: a title, the candidate dates, and the votes cast
// so far. AdminToken authorizes closing and deleting the poll. It is tagged
// json:"-" so it never appears in API responses; pollFile carries it to disk.
type Poll struct {
	ID          string    `json:"id"`
	AdminToken  string    `json:"-"` // never serialized to clients
	Title       string    `json:"title"`
	Description string    `json:"description,omitempty"`
	Options     []Option  `json:"options"` // sorted by date, then time, at creation
	Votes       []Vote    `json:"votes"`
	Closed      bool      `json:"closed"`    // once true, new votes are rejected
	CreatedAt   time.Time `json:"createdAt"` // UTC; persist orders the data file by it
}
// pollFile is the on-disk representation of a Poll, including the admin token.
// Its own AdminToken field shadows the embedded Poll.AdminToken, which is
// tagged json:"-". The token is therefore written to and read from the data
// file, but is never emitted when a bare Poll is encoded.
type pollFile struct {
	Poll
	AdminToken string `json:"adminToken"`
}
// ---------- Store ----------
// Store keeps every poll in memory and mirrors the whole set to one JSON file.
// All methods take mu, so writes to the file are serialized.
type Store struct {
	mu    sync.Mutex       // guards polls; must be held while calling persist
	path  string           // JSON data file
	polls map[string]*Poll // keyed by Poll.ID
}
// NewStore loads all polls from the JSON data file at path.
//
// A missing file is not an error: the store starts empty, and the file is
// written on the first change. Any other read error, or malformed JSON, is
// returned.
func NewStore(path string) (*Store, error) {
	store := &Store{path: path, polls: make(map[string]*Poll)}

	raw, err := os.ReadFile(path)
	switch {
	case errors.Is(err, os.ErrNotExist):
		return store, nil
	case err != nil:
		return nil, err
	}

	var entries []pollFile
	if err := json.Unmarshal(raw, &entries); err != nil {
		return nil, fmt.Errorf("parsing %s: %w", path, err)
	}
	for _, entry := range entries {
		// Move the token from the file-only field back onto the Poll itself.
		poll := entry.Poll
		poll.AdminToken = entry.AdminToken
		store.polls[poll.ID] = &poll
	}
	return store, nil
}
// persist writes all polls, including their admin tokens, to s.path.
// Caller must hold s.mu.
//
// The JSON is written to s.path+".tmp", flushed with Sync, and then renamed
// over s.path. A crash mid-write therefore leaves the previous file intact
// rather than a truncated one (rename replaces atomically on POSIX
// filesystems). A newly created temp file gets mode 0600 because the data
// includes every poll's admin token. Polls are ordered by CreatedAt, with
// ties broken by ID, so the output does not depend on map iteration order.
func (s *Store) persist() error {
	list := make([]pollFile, 0, len(s.polls))
	for _, p := range s.polls {
		list = append(list, pollFile{Poll: *p, AdminToken: p.AdminToken})
	}
	sort.Slice(list, func(i, j int) bool {
		a, b := &list[i], &list[j]
		if !a.CreatedAt.Equal(b.CreatedAt) {
			return a.CreatedAt.Before(b.CreatedAt)
		}
		return a.ID < b.ID
	})
	data, err := json.MarshalIndent(list, "", " ")
	if err != nil {
		return err
	}

	tmp := s.path + ".tmp"
	f, err := os.OpenFile(tmp, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600)
	if err != nil {
		return err
	}
	_, err = f.Write(data)
	if err == nil {
		err = f.Sync()
	}
	if cerr := f.Close(); err == nil {
		err = cerr
	}
	if err != nil {
		// Best-effort cleanup; s.path itself has not been touched yet.
		_ = os.Remove(tmp)
		return err
	}
	return os.Rename(tmp, s.path)
}
// Create stores p under p.ID and writes the data file.
//
// If the write fails, the in-memory map is restored to its previous state:
// p is removed, or the poll it replaced is put back. Otherwise the unsaved
// poll would stay visible and be written out by some later, unrelated save,
// even though the caller was told creation failed.
func (s *Store) Create(p *Poll) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	prev, existed := s.polls[p.ID]
	s.polls[p.ID] = p
	if err := s.persist(); err != nil {
		if existed {
			s.polls[p.ID] = prev
		} else {
			delete(s.polls, p.ID)
		}
		return err
	}
	return nil
}
// Get returns a copy of the poll with the given id, or (nil, false) if there
// is none.
//
// The copy owns its Options and Votes slices, and each vote's OptionIDs.
// A plain struct copy would share those backing arrays with the stored poll.
// Callers read the result after the lock is released (getPoll JSON-encodes
// it), so a shared array would race with Update changing the stored poll
// under s.mu.
func (s *Store) Get(id string) (*Poll, bool) {
	s.mu.Lock()
	defer s.mu.Unlock()
	p, ok := s.polls[id]
	if !ok {
		return nil, false
	}
	cp := *p
	// x[:0:0] keeps a nil slice nil and an empty slice empty (null vs [] in
	// JSON), while forcing append to allocate a fresh backing array.
	cp.Options = append(p.Options[:0:0], p.Options...)
	cp.Votes = append(p.Votes[:0:0], p.Votes...)
	for i := range cp.Votes {
		cp.Votes[i].OptionIDs = append(cp.Votes[i].OptionIDs[:0:0], cp.Votes[i].OptionIDs...)
	}
	return &cp, true
}
// Update runs fn on the poll with the given id under lock and persists the
// result.
//
// fn works on a private copy of the poll that has its own Options, Votes and
// per-vote OptionIDs slices. The copy replaces the stored poll only if fn
// succeeds. If writing the data file then fails, the previous version is
// restored. An error from fn or from disk therefore never leaves a
// half-applied change in memory, and the stored poll's slices are never
// modified in place.
//
// It returns errNotFound for an unknown id, fn's error unchanged, or the
// error from writing the data file.
func (s *Store) Update(id string, fn func(*Poll) error) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	prev, ok := s.polls[id]
	if !ok {
		return errNotFound
	}
	next := *prev
	// x[:0:0] preserves nil vs empty (JSON null vs []) while making append
	// allocate a new backing array.
	next.Options = append(prev.Options[:0:0], prev.Options...)
	next.Votes = append(prev.Votes[:0:0], prev.Votes...)
	for i := range next.Votes {
		next.Votes[i].OptionIDs = append(next.Votes[i].OptionIDs[:0:0], next.Votes[i].OptionIDs...)
	}
	if err := fn(&next); err != nil {
		return err
	}
	s.polls[id] = &next
	if err := s.persist(); err != nil {
		s.polls[id] = prev
		return err
	}
	return nil
}
// Delete removes the poll with the given id and writes the data file.
// It returns errNotFound for an unknown id.
//
// If the write fails, the poll is put back. Otherwise a delete reported as
// failed would silently take effect on the next successful save.
func (s *Store) Delete(id string) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	p, ok := s.polls[id]
	if !ok {
		return errNotFound
	}
	delete(s.polls, id)
	if err := s.persist(); err != nil {
		s.polls[id] = p
		return err
	}
	return nil
}
// errNotFound is returned by Store.Update and Store.Delete when no poll has
// the requested ID. Handlers map it to HTTP 404.
var errNotFound = errors.New("not found")
// ---------- Helpers ----------
// idAlphabet is the character set for generated IDs. It omits l, o, I, O, 0
// and 1, which are easily confused with one another.
const idAlphabet = "abcdefghijkmnpqrstuvwxyzABCDEFGHJKLMNPQRSTUVWXYZ23456789"

// randomID returns n characters drawn uniformly at random from idAlphabet,
// using crypto/rand. It is used for poll IDs, option IDs and admin tokens.
// It panics if the system random source fails.
//
// Random bytes at or above the largest multiple of len(idAlphabet) that fits
// in a byte are discarded (rejection sampling). Reducing every byte with % 56
// would make the first 256%56 = 32 characters 25% more likely than the rest,
// which needlessly weakens the admin tokens.
func randomID(n int) string {
	const limit = 256 - 256%len(idAlphabet)
	out := make([]byte, 0, n)
	// Slightly more than n bytes per round, since about 1 in 8 is rejected.
	buf := make([]byte, n+n/4+1)
	for len(out) < n {
		if _, err := rand.Read(buf); err != nil {
			panic(err)
		}
		for _, v := range buf {
			if int(v) >= limit {
				continue
			}
			out = append(out, idAlphabet[int(v)%len(idAlphabet)])
			if len(out) == n {
				break
			}
		}
	}
	return string(out)
}
// Validators for poll option input, compiled once at startup.
// dateRe checks only the YYYY-MM-DD shape; createPoll additionally runs
// time.Parse to reject impossible calendar days. timeRe accepts 24-hour
// HH:MM from 00:00 to 23:59.
var (
	dateRe = regexp.MustCompile(`^\d{4}-\d{2}-\d{2}$`)
	timeRe = regexp.MustCompile(`^([01]\d|2[0-3]):[0-5]\d$`)
)
// writeJSON sends v as a JSON response body with the given status code.
// The status line is committed before encoding, so an encoding or write
// failure can no longer change it. Such an error is logged instead of being
// silently discarded.
func writeJSON(w http.ResponseWriter, status int, v any) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	if err := json.NewEncoder(w).Encode(v); err != nil {
		log.Printf("writing JSON response: %v", err)
	}
}

// writeError sends a JSON body of the form {"error": msg} with the given
// status code.
func writeError(w http.ResponseWriter, status int, msg string) {
	writeJSON(w, status, map[string]string{"error": msg})
}
// adminToken returns the admin token presented by the client. The
// X-Admin-Token header takes precedence; otherwise the "admin" query
// parameter is used. It returns "" when neither is set.
func adminToken(r *http.Request) string {
	token := r.Header.Get("X-Admin-Token")
	if token == "" {
		token = r.URL.Query().Get("admin")
	}
	return token
}
// ---------- Handlers ----------
// server holds the dependencies shared by the HTTP handlers.
type server struct {
	store *Store // all poll reads and writes go through the Store
}
// createPollRequest is the JSON body accepted by POST /api/polls.
// createPoll trims, validates and deduplicates it before building a Poll.
type createPollRequest struct {
	Title       string `json:"title"`
	Description string `json:"description"`
	Options     []struct {
		Date string `json:"date"` // expected YYYY-MM-DD
		Time string `json:"time"` // optional HH:MM; empty means no specific time
	} `json:"options"`
}
// createPoll handles POST /api/polls.
//
// It validates the title, description and candidate dates, and drops repeated
// date/time pairs. It then sorts the options chronologically, stores the new
// poll, and replies 201 with the poll ID and the admin token the creator
// needs to close or delete it. Validation problems get a 400 with a
// user-facing message.
func (sv *server) createPoll(w http.ResponseWriter, r *http.Request) {
	var in createPollRequest
	body := http.MaxBytesReader(w, r.Body, 64<<10)
	if err := json.NewDecoder(body).Decode(&in); err != nil {
		writeError(w, http.StatusBadRequest, "Invalid request body.")
		return
	}

	title := strings.TrimSpace(in.Title)
	description := strings.TrimSpace(in.Description)

	// Checks run in a fixed order; only the first failing one is reported.
	var problem string
	switch {
	case title == "":
		problem = "A title is required."
	case len(title) > 120 || len(description) > 600:
		problem = "Title or description is too long."
	case len(in.Options) == 0:
		problem = "Select at least one date."
	case len(in.Options) > 60:
		problem = "Too many dates (max 60)."
	}
	if problem != "" {
		writeError(w, http.StatusBadRequest, problem)
		return
	}

	taken := make(map[string]bool, len(in.Options))
	options := make([]Option, 0, len(in.Options))
	for _, cand := range in.Options {
		dateOK := dateRe.MatchString(cand.Date)
		if dateOK {
			// The regexp only checks the shape; time.Parse also rejects
			// impossible days such as 2023-02-30.
			_, err := time.Parse("2006-01-02", cand.Date)
			dateOK = err == nil
		}
		if !dateOK {
			writeError(w, http.StatusBadRequest, "Invalid date: "+cand.Date)
			return
		}
		if cand.Time != "" && !timeRe.MatchString(cand.Time) {
			writeError(w, http.StatusBadRequest, "Invalid time for "+cand.Date+" (use HH:MM).")
			return
		}
		slot := cand.Date + "T" + cand.Time
		if taken[slot] {
			continue // same date and time submitted more than once
		}
		taken[slot] = true
		options = append(options, Option{ID: randomID(8), Date: cand.Date, Time: cand.Time})
	}

	// Chronological order: by date, then by time ("" sorts before any HH:MM).
	sort.Slice(options, func(i, j int) bool {
		a, b := options[i], options[j]
		if a.Date == b.Date {
			return a.Time < b.Time
		}
		return a.Date < b.Date
	})

	poll := &Poll{
		ID:          randomID(10),
		AdminToken:  randomID(24),
		Title:       title,
		Description: description,
		Options:     options,
		Votes:       []Vote{}, // non-nil so clients receive [] rather than null
		CreatedAt:   time.Now().UTC(),
	}
	if err := sv.store.Create(poll); err != nil {
		log.Printf("create poll: %v", err)
		writeError(w, http.StatusInternalServerError, "Could not save the poll.")
		return
	}
	writeJSON(w, http.StatusCreated, map[string]string{"id": poll.ID, "adminToken": poll.AdminToken})
}
// getPoll handles GET /api/polls/{id}. It returns the poll without its admin
// token, plus an isAdmin flag that tells the UI whether the request carried
// the poll's admin token.
func (sv *server) getPoll(w http.ResponseWriter, r *http.Request) {
	p, ok := sv.store.Get(r.PathValue("id"))
	if !ok {
		writeError(w, http.StatusNotFound, "Poll not found.")
		return
	}
	// Compare in constant time, so response timing does not reveal how much
	// of a guessed token was right. An empty stored token, e.g. from a poll
	// loaded from a file without "adminToken", must never match; otherwise
	// every request without a token would be treated as the admin.
	isAdmin := p.AdminToken != "" &&
		subtle.ConstantTimeCompare([]byte(adminToken(r)), []byte(p.AdminToken)) == 1
	resp := struct {
		*Poll
		IsAdmin bool `json:"isAdmin"`
	}{p, isAdmin}
	writeJSON(w, http.StatusOK, resp)
}
// voteRequest is the JSON body accepted by POST /api/polls/{id}/votes.
// The vote handler ignores OptionIDs that do not belong to the poll and
// drops duplicates.
type voteRequest struct {
	Name      string   `json:"name"`
	OptionIDs []string `json:"optionIds"`
}
// vote handles POST /api/polls/{id}/votes. It records the named participant's
// selection, keeping only option IDs that belong to the poll, each once. A
// later vote under the same name, compared case-insensitively, replaces the
// earlier one.
//
// Replies are:
//   - 400 for a bad body or name
//   - 404 for an unknown poll
//   - 409 if the poll is closed
//   - 500 (logged) if the answer could not be saved
//
// Storage errors are no longer echoed to the client.
func (sv *server) vote(w http.ResponseWriter, r *http.Request) {
	var req voteRequest
	if err := json.NewDecoder(http.MaxBytesReader(w, r.Body, 16<<10)).Decode(&req); err != nil {
		writeError(w, http.StatusBadRequest, "Invalid request body.")
		return
	}
	req.Name = strings.TrimSpace(req.Name)
	if req.Name == "" {
		writeError(w, http.StatusBadRequest, "Enter your name so friends know who answered.")
		return
	}
	if len(req.Name) > 60 {
		writeError(w, http.StatusBadRequest, "Name is too long.")
		return
	}

	// errClosed lets the error handling below tell a closed poll (a client
	// conflict) apart from failures while writing the data file.
	errClosed := errors.New("This poll is closed.")
	id := r.PathValue("id")
	err := sv.store.Update(id, func(p *Poll) error {
		if p.Closed {
			return errClosed
		}
		valid := make(map[string]bool, len(p.Options))
		for _, o := range p.Options {
			valid[o.ID] = true
		}
		ids := []string{} // non-nil so the stored vote encodes as [] not null
		seen := map[string]bool{}
		for _, oid := range req.OptionIDs {
			if valid[oid] && !seen[oid] {
				ids = append(ids, oid)
				seen[oid] = true
			}
		}
		v := Vote{Name: req.Name, OptionIDs: ids, CreatedAt: time.Now().UTC()}
		// Same name (case-insensitive) replaces the earlier answer.
		for i := range p.Votes {
			if strings.EqualFold(p.Votes[i].Name, req.Name) {
				p.Votes[i] = v
				return nil
			}
		}
		p.Votes = append(p.Votes, v)
		return nil
	})

	switch {
	case err == nil:
		writeJSON(w, http.StatusOK, map[string]bool{"ok": true})
	case errors.Is(err, errNotFound):
		writeError(w, http.StatusNotFound, "Poll not found.")
	case errors.Is(err, errClosed):
		writeError(w, http.StatusConflict, errClosed.Error())
	default:
		log.Printf("vote on poll %q: %v", id, err)
		writeError(w, http.StatusInternalServerError, "Could not save your answer.")
	}
}
// withAdmin wraps fn so that it only runs for requests carrying the admin
// token of the poll named by the {id} path value. An unknown poll gets 404,
// and a missing or wrong token gets 403.
func (sv *server) withAdmin(fn func(http.ResponseWriter, *http.Request)) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		p, ok := sv.store.Get(r.PathValue("id"))
		if !ok {
			writeError(w, http.StatusNotFound, "Poll not found.")
			return
		}
		// Constant-time comparison keeps response timing from revealing how
		// much of a guessed token was right. An empty stored token, e.g. from
		// a poll loaded from a file without "adminToken", must never match;
		// otherwise a request with no token would be authorized.
		want := p.AdminToken
		if want == "" || subtle.ConstantTimeCompare([]byte(adminToken(r)), []byte(want)) != 1 {
			writeError(w, http.StatusForbidden, "Only the poll creator can do this. Use your admin link.")
			return
		}
		fn(w, r)
	}
}
// closePoll handles POST /api/polls/{id}/close (mounted behind withAdmin) by
// marking the poll closed, after which the vote handler rejects new answers.
// It replies 404 if the poll no longer exists, and 500 (logged) if the change
// could not be saved. Previously every error was reported as "not found".
func (sv *server) closePoll(w http.ResponseWriter, r *http.Request) {
	id := r.PathValue("id")
	err := sv.store.Update(id, func(p *Poll) error {
		p.Closed = true
		return nil
	})
	switch {
	case err == nil:
		writeJSON(w, http.StatusOK, map[string]bool{"ok": true})
	case errors.Is(err, errNotFound):
		writeError(w, http.StatusNotFound, "Poll not found.")
	default:
		log.Printf("close poll %q: %v", id, err)
		writeError(w, http.StatusInternalServerError, "Could not close the poll.")
	}
}
// deletePoll handles DELETE /api/polls/{id} (mounted behind withAdmin).
// It replies 404 if the poll no longer exists, and 500 (logged) if the data
// file could not be written. Previously every error was reported as
// "not found".
func (sv *server) deletePoll(w http.ResponseWriter, r *http.Request) {
	id := r.PathValue("id")
	err := sv.store.Delete(id)
	switch {
	case err == nil:
		writeJSON(w, http.StatusOK, map[string]bool{"ok": true})
	case errors.Is(err, errNotFound):
		writeError(w, http.StatusNotFound, "Poll not found.")
	default:
		log.Printf("delete poll %q: %v", id, err)
		writeError(w, http.StatusInternalServerError, "Could not delete the poll.")
	}
}
// ---------- main ----------
// main parses flags, loads the poll store from the data directory, and serves
// the embedded web UI and the JSON API until the server fails.
func main() {
	addr := flag.String("addr", ":8080", "listen address")
	dataDir := flag.String("data", "data", "directory for poll storage")
	flag.Parse()
	if err := os.MkdirAll(*dataDir, 0o755); err != nil {
		log.Fatal(err)
	}
	store, err := NewStore(filepath.Join(*dataDir, "polls.json"))
	if err != nil {
		log.Fatal(err)
	}
	sv := &server{store: store}

	// fs.Sub only fails for an invalid directory name, but check it rather
	// than silently serving from a nil FS.
	static, err := fs.Sub(staticFS, "static")
	if err != nil {
		log.Fatal(err)
	}

	mux := http.NewServeMux()
	mux.HandleFunc("GET /{$}", func(w http.ResponseWriter, r *http.Request) {
		http.ServeFileFS(w, r, static, "index.html")
	})
	mux.HandleFunc("GET /p/{id}", func(w http.ResponseWriter, r *http.Request) {
		http.ServeFileFS(w, r, static, "poll.html")
	})
	mux.Handle("GET /static/", http.StripPrefix("/static/", http.FileServerFS(static)))
	mux.HandleFunc("POST /api/polls", sv.createPoll)
	mux.HandleFunc("GET /api/polls/{id}", sv.getPoll)
	mux.HandleFunc("POST /api/polls/{id}/votes", sv.vote)
	mux.HandleFunc("POST /api/polls/{id}/close", sv.withAdmin(sv.closePoll))
	mux.HandleFunc("DELETE /api/polls/{id}", sv.withAdmin(sv.deletePoll))

	// http.ListenAndServe sets no timeouts, so slow or idle clients could
	// hold connections open indefinitely. Bound each phase of a request.
	srv := &http.Server{
		Addr:              *addr,
		Handler:           mux,
		ReadHeaderTimeout: 5 * time.Second,
		ReadTimeout:       15 * time.Second,
		WriteTimeout:      30 * time.Second,
		IdleTimeout:       2 * time.Minute,
	}
	log.Printf("mediator running on http://localhost%s (data in %s)", *addr, *dataDir)
	log.Fatal(srv.ListenAndServe())
}