diff --git a/banditcallback/banditcallback.go b/banditcallback/banditcallback.go index 9161f7e2..1b2d9bfd 100644 --- a/banditcallback/banditcallback.go +++ b/banditcallback/banditcallback.go @@ -15,6 +15,7 @@ package banditcallback import ( "context" + "net" "net/http" "net/url" "sync" @@ -76,10 +77,14 @@ func (e *Emitter) Enabled() bool { // EmitIfFirstSeen records the device-id and, if it hasn't been seen // within the TTL window, fires an async best-effort GET to the -// configured callback URL. Returns immediately — the HTTP request runs -// in its own goroutine and any failure is logged at debug only (the -// callback is a hint to the bandit, not a correctness signal). -func (e *Emitter) EmitIfFirstSeen(ctx context.Context, deviceID string) { +// configured callback URL. clientIP, when non-empty, is forwarded as a +// True-Client-IP header so the API can attribute the reward to the +// device's ASN rather than to the proxy VPS's egress IP — the source +// address of our outbound request. Returns immediately — the HTTP +// request runs in its own goroutine and any failure is logged at debug +// only (the callback is a hint to the bandit, not a correctness +// signal). +func (e *Emitter) EmitIfFirstSeen(ctx context.Context, deviceID, clientIP string) { if !e.Enabled() || deviceID == "" { return } @@ -90,7 +95,7 @@ func (e *Emitter) EmitIfFirstSeen(ctx context.Context, deviceID string) { return } - go e.fire(ctx, deviceID) + go e.fire(ctx, deviceID, clientIP) } // checkAndRecord returns true if the deviceID is first-seen within the @@ -120,7 +125,7 @@ func (e *Emitter) checkAndRecord(deviceID string, now time.Time) bool { return true } -func (e *Emitter) fire(ctx context.Context, deviceID string) { +func (e *Emitter) fire(ctx context.Context, deviceID, clientIP string) { // Detach from the request context so a closing client connection // doesn't cancel the outbound HTTP request. The callback is for // the bandit's benefit, not the client's; we want it to complete @@ -143,6 +148,9 @@ func (e *Emitter) fire(ctx context.Context, deviceID string) { log.Debugf("banditcallback: build request: %v", err) return } + if clientIP != "" { + req.Header.Set("True-Client-IP", clientIP) + } resp, err := e.client.Do(req) if err != nil { @@ -183,10 +191,17 @@ func NewFilter(headerName string, emitter *Emitter) *Filter { } // Apply implements filters.Filter. Forwards unconditionally (the -// emitter is a side-effect; failures are non-fatal). +// emitter is a side-effect; failures are non-fatal). The client IP is +// taken from req.RemoteAddr (same source opsfilter uses for measured +// reporting) so the API receives the device's real IP via the +// True-Client-IP header instead of our VPS egress. func (f *Filter) Apply(cs *filters.ConnectionState, req *http.Request, next filters.Next) (*http.Response, *filters.ConnectionState, error) { if f.emitter != nil && f.emitter.Enabled() { - f.emitter.EmitIfFirstSeen(req.Context(), req.Header.Get(f.deviceIDHeader)) + clientIP, _, err := net.SplitHostPort(req.RemoteAddr) + if err != nil { + clientIP = req.RemoteAddr + } + f.emitter.EmitIfFirstSeen(req.Context(), req.Header.Get(f.deviceIDHeader), clientIP) } return next(cs, req) } diff --git a/banditcallback/banditcallback_test.go b/banditcallback/banditcallback_test.go index d9f386f4..1563c77e 100644 --- a/banditcallback/banditcallback_test.go +++ b/banditcallback/banditcallback_test.go @@ -29,7 +29,7 @@ func TestEmitter_DisabledWhenUnconfigured(t *testing.T) { if e.Enabled() { t.Fatal("expected disabled") } - e.EmitIfFirstSeen(context.Background(), "did-1") + e.EmitIfFirstSeen(context.Background(), "did-1", "") emitted, _ := e.Stats() if emitted != 0 { t.Fatalf("expected 0 emits, got %d", emitted) @@ -48,12 +48,15 @@ func TestEmitter_FirstSeenFires(t *testing.T) { if got := r.URL.Query().Get("did"); got == "" { t.Error("missing did") } + if got := r.Header.Get("True-Client-IP"); got != "203.0.113.7" { + t.Errorf("True-Client-IP mismatch: %q", got) + } w.WriteHeader(http.StatusNoContent) })) defer srv.Close() e := New("arm-test", srv.URL, time.Minute) - e.EmitIfFirstSeen(context.Background(), "device-a") + e.EmitIfFirstSeen(context.Background(), "device-a", "203.0.113.7") // Emission is async; wait for the goroutine. 1s ceiling is generous. deadline := time.Now().Add(time.Second) @@ -83,7 +86,7 @@ func TestEmitter_DedupSuppressesRepeat(t *testing.T) { e := New("arm-test", srv.URL, time.Minute) for i := 0; i < 10; i++ { - e.EmitIfFirstSeen(context.Background(), "device-a") + e.EmitIfFirstSeen(context.Background(), "device-a", "203.0.113.7") } // One emit, nine suppressed. Wait for async emit to complete. @@ -123,7 +126,7 @@ func TestEmitter_ConcurrentFirstSeenIsSingleFire(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - e.EmitIfFirstSeen(context.Background(), "race-device") + e.EmitIfFirstSeen(context.Background(), "race-device", "") }() } wg.Wait() @@ -146,14 +149,36 @@ func TestEmitter_ReEmitsAfterTTL(t *testing.T) { defer srv.Close() e := New("arm-test", srv.URL, 50*time.Millisecond) - e.EmitIfFirstSeen(context.Background(), "ttl-device") + e.EmitIfFirstSeen(context.Background(), "ttl-device", "") time.Sleep(20 * time.Millisecond) // within TTL — suppressed - e.EmitIfFirstSeen(context.Background(), "ttl-device") + e.EmitIfFirstSeen(context.Background(), "ttl-device", "") time.Sleep(80 * time.Millisecond) // past TTL — fires again - e.EmitIfFirstSeen(context.Background(), "ttl-device") + e.EmitIfFirstSeen(context.Background(), "ttl-device", "") time.Sleep(100 * time.Millisecond) if got := atomic.LoadInt32(&hits); got != 2 { t.Fatalf("expected 2 hits across TTL window, got %d", got) } } + +func TestEmitter_OmitsTrueClientIPWhenEmpty(t *testing.T) { + // Empty clientIP must NOT set the header — leaving it absent + // lets the API fall through to its existing RemoteAddr-based + // MaxMind lookup. Setting an empty header would shadow that + // fallback with a junk value. + var sawHeader int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if _, ok := r.Header["True-Client-Ip"]; ok { + atomic.StoreInt32(&sawHeader, 1) + } + w.WriteHeader(http.StatusNoContent) + })) + defer srv.Close() + + e := New("arm-test", srv.URL, time.Minute) + e.EmitIfFirstSeen(context.Background(), "device-x", "") + time.Sleep(100 * time.Millisecond) + if atomic.LoadInt32(&sawHeader) != 0 { + t.Fatal("True-Client-IP must be absent when clientIP is empty") + } +}