Skip to content

Commit

Permalink
parse and validate query parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
gernest committed Feb 8, 2024
1 parent 52e5212 commit f4da594
Show file tree
Hide file tree
Showing 8 changed files with 158 additions and 44 deletions.
26 changes: 13 additions & 13 deletions gen/go/staples/v1/stats.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions proto/staples/v1/stats.proto
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ message Timeseries {
}

enum Interval {
minute = 0;
hour = 1;
date = 2;
date = 0;
minute = 1;
hour = 2;
week = 3;
month = 4;
}
Expand Down
9 changes: 9 additions & 0 deletions request/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,15 @@ func Read(w http.ResponseWriter, r *http.Request, o proto.Message) bool {
return true
}

func Validate(ctx context.Context, w http.ResponseWriter, o proto.Message) bool {
if err := Get(ctx).Validate(o); err != nil {
logger.Get(ctx).Error("Failed validating request body", "err", err)
Error(ctx, w, http.StatusBadRequest, err.Error())
return false
}
return true
}

func Write(ctx context.Context, w http.ResponseWriter, o proto.Message) {
data, err := m.Marshal(o)
if err != nil {
Expand Down
14 changes: 9 additions & 5 deletions stats/aggregate.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,16 @@ import (
)

func Aggregate(w http.ResponseWriter, r *http.Request) {
var req v1.Aggregate_Request
req.SiteId = r.URL.Query().Get("site_id")
if !request.Read(w, r, &req) {
ctx := r.Context()
query := r.URL.Query()
req := v1.Aggregate_Request{
SiteId: query.Get("site_id"),
Period: ParsePeriod(ctx, query),
Metrics: ParseMetrics(ctx, query),
}
if !request.Validate(ctx, w, &req) {
return
}
ctx := r.Context()
filters := &v1.Filters{
List: append(req.Filters, &v1.Filter{
Property: v1.Property_domain,
Expand All @@ -31,7 +35,7 @@ func Aggregate(w http.ResponseWriter, r *http.Request) {
metrics := slices.Clone(req.Metrics)
slices.Sort(metrics)
metricsToProjection(filters, metrics)
from, to := PeriodToRange(time.Now, req.Period)
from, to := PeriodToRange(ctx, time.Now, req.Period, r.URL.Query())
resultRecord, err := session.Get(ctx).Scan(ctx, from.UnixMilli(), to.UnixMilli(), filters)
if err != nil {
logger.Get(ctx).Error("Failed scanning", "err", err)
Expand Down
15 changes: 10 additions & 5 deletions stats/breakdown.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,16 @@ import (
)

func BreakDown(w http.ResponseWriter, r *http.Request) {

var req v1.BreakDown_Request
if !request.Read(w, r, &req) {
ctx := r.Context()
query := r.URL.Query()
req := v1.BreakDown_Request{
SiteId: query.Get("site_id"),
Period: ParsePeriod(ctx, query),
Metrics: ParseMetrics(ctx, query),
}
if !request.Validate(ctx, w, &req) {
return
}
ctx := r.Context()
period := req.Period
if period == nil {
period = &v1.TimePeriod{
Expand All @@ -43,7 +47,7 @@ func BreakDown(w http.ResponseWriter, r *http.Request) {
slices.Sort(req.Metrics)
slices.Sort(req.Property)
metricsToProjection(filter, req.Metrics, req.Property...)
from, to := PeriodToRange(time.Now, period)
from, to := PeriodToRange(ctx, time.Now, period, r.URL.Query())
scannedRecord, err := session.Get(ctx).Scan(ctx, from.UnixMilli(), to.UnixMilli(), filter)
if err != nil {
logger.Get(ctx).Error("Failed scanning", "err", err)
Expand Down Expand Up @@ -205,6 +209,7 @@ func take(ctx context.Context, metric v1.Metric, f v1.Filters_Projection, mappin
}
return a, true
}

func hashProp(a arrow.Array) map[string]*roaring.Bitmap {
o := make(map[string]*roaring.Bitmap)
d := a.(*array.Dictionary)
Expand Down
107 changes: 99 additions & 8 deletions stats/common.go
Original file line number Diff line number Diff line change
@@ -1,43 +1,49 @@
package stats

import (
"context"
"net/url"
"strings"
"time"

v1 "github.com/vinceanalytics/vince/gen/go/staples/v1"
"github.com/vinceanalytics/vince/logger"
"github.com/vinceanalytics/vince/timeutil"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
)

// Avoid leaking internal errors to client. The actual error is logged and this
// is returned back to the client.
var InternalError = status.Error(codes.Internal, "Something went wrong")

func PeriodToRange(now func() time.Time, period *v1.TimePeriod) (start, end time.Time) {
func PeriodToRange(ctx context.Context, now func() time.Time, period *v1.TimePeriod, query url.Values) (start, end time.Time) {
date := parseDate(ctx, query, now)
switch e := period.Value.(type) {
case *v1.TimePeriod_Base_:
switch e.Base {
case v1.TimePeriod_day:
end = timeutil.Today()
end = date
start = end
case v1.TimePeriod__7d:
end = timeutil.Today()
end = date
start = end.AddDate(0, 0, -6)
case v1.TimePeriod__30d:
end = timeutil.Today()
end = date
start = end.AddDate(0, 0, -30)
case v1.TimePeriod_mo:
end = timeutil.Today()
end = date
start = timeutil.BeginMonth(end)
end = timeutil.EndMonth(end)
case v1.TimePeriod__6mo:
end = timeutil.EndMonth(timeutil.Today())
end = timeutil.EndMonth(date)
start = timeutil.BeginMonth(end.AddDate(0, -5, 0))
case v1.TimePeriod__12mo:
end = timeutil.EndMonth(timeutil.Today())
end = timeutil.EndMonth(date)
start = timeutil.BeginMonth(end.AddDate(0, -11, 0))
case v1.TimePeriod_year:
end = timeutil.EndYear(timeutil.Today())
end = timeutil.EndYear(date)
start = timeutil.BeginYear(end)
}

Expand Down Expand Up @@ -69,3 +75,88 @@ func ValidByPeriod(period *v1.TimePeriod, i v1.Interval) bool {
return false
}
}

func ParsePeriod(ctx context.Context, query url.Values) *v1.TimePeriod {
value := query.Get("period")
switch value {
case "12mo":
return &v1.TimePeriod{Value: &v1.TimePeriod_Base_{Base: v1.TimePeriod__12mo}}
case "6mo":
return &v1.TimePeriod{Value: &v1.TimePeriod_Base_{Base: v1.TimePeriod__6mo}}
case "month":
return &v1.TimePeriod{Value: &v1.TimePeriod_Base_{Base: v1.TimePeriod_mo}}
case "30day":
return &v1.TimePeriod{Value: &v1.TimePeriod_Base_{Base: v1.TimePeriod__30d}}
case "7day":
return &v1.TimePeriod{Value: &v1.TimePeriod_Base_{Base: v1.TimePeriod__7d}}
case "day":
return &v1.TimePeriod{Value: &v1.TimePeriod_Base_{Base: v1.TimePeriod_day}}
case "custom":
date := query.Get("date")
if date == "" {
logger.Get(ctx).Error("custom period specified with missing date")
return nil
}
from, to, _ := strings.Cut(date, ",")

start, err := time.Parse(time.DateOnly, from)
if err != nil {
logger.Get(ctx).Error("Invalid date for custom period", "date", date, "err", err)
return nil
}
end, err := time.Parse(time.DateOnly, to)
if err != nil {
logger.Get(ctx).Error("Invalid date for custom period", "date", date, "err", err)
return nil
}
return &v1.TimePeriod{
Value: &v1.TimePeriod_Custom_{
Custom: &v1.TimePeriod_Custom{
Start: timestamppb.New(start),
End: timestamppb.New(end),
},
},
}
default:
return nil
}
}

func parseDate(ctx context.Context, query url.Values, now func() time.Time) time.Time {
date := query.Get("date")
if date == "" {
return timeutil.BeginDay(now())
}
v, err := time.Parse(time.DateOnly, date)
if err != nil {
fall := timeutil.BeginDay(now())
logger.Get(ctx).Error("failed parsing date falling back to now",
"date", date, "now", fall.Format(time.DateOnly), "err", err)
return fall
}
return v
}

func ParseMetrics(ctx context.Context, query url.Values) (o []v1.Metric) {
metrics := query.Get("metrics")
for _, m := range strings.Split(metrics, ",") {
v, ok := v1.Metric_value[m]
if !ok {
logger.Get(ctx).Error("Skipping unexpected metric name", "metric", m)
continue
}
o = append(o, v1.Metric(v))
}
return
}

func ParseInterval(ctx context.Context, query url.Values) v1.Interval {
i := query.Get("interval")
v, ok := v1.Interval_value[i]
if !ok {
if i != "" {
logger.Get(ctx).Error("Skipping unexpected interval value", "interval", i)
}
}
return v1.Interval(v)
}
10 changes: 5 additions & 5 deletions stats/current_visitors.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ import (
)

func Realtime(w http.ResponseWriter, r *http.Request) {
var req v1.Realtime_Request
req.SiteId = r.URL.Query().Get("site_id")
ctx := r.Context()
if err := request.Get(r.Context()).Validate(&req); err != nil {
logger.Get(ctx).Error("Failed validating request body", "err", err)
request.Error(ctx, w, http.StatusBadRequest, err.Error())
query := r.URL.Query()
req := v1.Realtime_Request{
SiteId: query.Get("site_id"),
}
if !request.Validate(ctx, w, &req) {
return
}
now := time.Now().UTC()
Expand Down
15 changes: 10 additions & 5 deletions stats/timeseries.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,17 @@ import (
)

func TimeSeries(w http.ResponseWriter, r *http.Request) {
var req v1.Timeseries_Request
req.SiteId = r.URL.Query().Get("site_id")
if !request.Read(w, r, &req) {
ctx := r.Context()
query := r.URL.Query()
req := v1.Timeseries_Request{
SiteId: query.Get("site_id"),
Period: ParsePeriod(ctx, query),
Metrics: ParseMetrics(ctx, query),
Interval: ParseInterval(ctx, query),
}
if !request.Validate(ctx, w, &req) {
return
}
ctx := r.Context()
// make sure we have valid interval
if !ValidByPeriod(req.Period, req.Interval) {
request.Error(ctx, w, http.StatusBadRequest, "Interval out of range")
Expand All @@ -40,7 +45,7 @@ func TimeSeries(w http.ResponseWriter, r *http.Request) {
metrics := slices.Clone(req.Metrics)
slices.Sort(metrics)
metricsToProjection(filters, metrics)
from, to := PeriodToRange(time.Now, req.Period)
from, to := PeriodToRange(ctx, time.Now, req.Period, r.URL.Query())
scanRecord, err := session.Get(ctx).Scan(ctx, from.UnixMilli(), to.UnixMilli(), filters)
if err != nil {
logger.Get(ctx).Error("Failed scanning", "err", err)
Expand Down

0 comments on commit f4da594

Please sign in to comment.