diff --git a/src/tools/ground-location-tool/GroundLocationTool.ts b/src/tools/ground-location-tool/GroundLocationTool.ts index f6a55e3..2e9b922 100644 --- a/src/tools/ground-location-tool/GroundLocationTool.ts +++ b/src/tools/ground-location-tool/GroundLocationTool.ts @@ -11,6 +11,8 @@ import { type GroundLocationOutput } from './GroundLocationTool.output.schema.js'; +type GroundingStrategy = 'neighborhood' | 'routing' | 'poi' | 'region'; + // Minimal types for API responses we care about interface GeocodingFeature { properties?: { @@ -86,10 +88,68 @@ export class GroundLocationTool extends MapboxApiBasedTool< }); } + /** + * Use sampling to classify what kind of grounding the query needs. + * Falls back to 'neighborhood' if sampling is unavailable or classification fails. + */ + private async classifyGroundingStrategy( + query: string | undefined, + longitude: number, + latitude: number + ): Promise { + const samplingCapability = + this.server?.server.getClientCapabilities()?.sampling; + if (!samplingCapability || !this.server) { + return 'neighborhood'; + } + + const contextHint = query ? ` The user also asked about: "${query}".` : ''; + const prompt = + `A user is asking about a location at coordinates ${latitude}, ${longitude}.${contextHint}\n\n` + + `Classify what kind of location grounding is needed. Reply with exactly one word:\n` + + `- "routing" — user needs precise routable coordinates for navigation or directions\n` + + `- "neighborhood" — user wants to know what area/district/neighborhood this is\n` + + `- "poi" — user wants nearby points of interest or places of a specific category\n` + + `- "region" — user wants area/boundary context like travel-time zones or coverage areas`; + + try { + const result = await this.server.server.createMessage({ + messages: [{ role: 'user', content: { type: 'text', text: prompt } }], + maxTokens: 10 + }); + + const text = + result.content.type === 'text' + ? result.content.text.trim().toLowerCase() + : ''; + + if ( + text === 'routing' || + text === 'neighborhood' || + text === 'poi' || + text === 'region' + ) { + this.log( + 'debug', + `ground_location_tool: sampling classified as "${text}"` + ); + return text as GroundingStrategy; + } + } catch (err) { + this.log( + 'debug', + `ground_location_tool: sampling classification failed, falling back to neighborhood: ${err instanceof Error ? err.message : String(err)}` + ); + } + + return 'neighborhood'; + } + private async reverseGeocode( longitude: number, latitude: number, accessToken: string, + types: string, language?: string ): Promise<{ place: string; full_address?: string }> { const url = new URL( @@ -99,7 +159,7 @@ export class GroundLocationTool extends MapboxApiBasedTool< url.searchParams.append('latitude', latitude.toString()); url.searchParams.append('access_token', accessToken); url.searchParams.append('limit', '1'); - url.searchParams.append('types', 'neighborhood,locality,place'); + url.searchParams.append('types', types); if (language) url.searchParams.append('language', language); const response = await this.httpRequest(url.toString()); @@ -178,15 +238,34 @@ export class GroundLocationTool extends MapboxApiBasedTool< }; } - private formatOutput(result: GroundLocationOutput): string { + private formatOutput( + result: GroundLocationOutput, + strategy: GroundingStrategy + ): string { const lines: string[] = []; - lines.push(`**${result.place}** (live Mapbox data)`); + const strategyLabel: Record = { + neighborhood: 'neighborhood context', + routing: 'routing coordinates', + poi: 'nearby places', + region: 'region context' + }; + + lines.push( + `**${result.place}** (${strategyLabel[strategy]} · live Mapbox data)` + ); if (result.full_address && result.full_address !== result.place) { lines.push(result.full_address); } lines.push(''); + if (strategy === 'routing') { + lines.push( + `Routable coordinates: ${result.latitude}, ${result.longitude}` + ); + lines.push(''); + } + if (result.nearby_pois?.length) { lines.push(`Nearby places:`); for (const poi of result.nearby_pois) { @@ -227,17 +306,40 @@ export class GroundLocationTool extends MapboxApiBasedTool< language } = input; + // Classify the grounding strategy via sampling (falls back gracefully if unsupported) + const strategy = await this.classifyGroundingStrategy( + query, + longitude, + latitude + ); + const citations: string[] = ['Mapbox Geocoding API']; - // Fan out all requests in parallel + // Choose reverse geocode result types based on strategy + const geocodeTypes = + strategy === 'routing' + ? 'address,poi' + : strategy === 'region' + ? 'region,district,place' + : 'neighborhood,locality,place'; + + // Fan out requests in parallel, shaped by strategy const [geocodeResult, poisResult, isochroneResult] = await Promise.all([ - this.reverseGeocode(longitude, latitude, accessToken, language), - query + this.reverseGeocode( + longitude, + latitude, + accessToken, + geocodeTypes, + language + ), + + // POIs: always fetch if query given; boost limit for poi-focused strategy + query || strategy === 'poi' ? this.categorySearch( - query, + query ?? 'place', longitude, latitude, - limit, + strategy === 'poi' ? Math.max(limit, 15) : limit, accessToken, language ).then((pois) => { @@ -245,16 +347,20 @@ export class GroundLocationTool extends MapboxApiBasedTool< return pois; }) : Promise.resolve(undefined), - this.isochrone( - longitude, - latitude, - profile, - contours_minutes, - accessToken - ).then((iso) => { - if (iso) citations.push('Mapbox Isochrone API'); - return iso; - }) + + // Isochrone: always for region/neighborhood strategies + strategy === 'region' || strategy === 'neighborhood' + ? this.isochrone( + longitude, + latitude, + profile, + contours_minutes, + accessToken + ).then((iso) => { + if (iso) citations.push('Mapbox Isochrone API'); + return iso; + }) + : Promise.resolve(undefined) ]); const result: GroundLocationOutput = { @@ -271,7 +377,7 @@ export class GroundLocationTool extends MapboxApiBasedTool< const output = validated.success ? validated.data : result; return { - content: [{ type: 'text', text: this.formatOutput(output) }], + content: [{ type: 'text', text: this.formatOutput(output, strategy) }], structuredContent: output as unknown as Record, isError: false }; diff --git a/test/tools/ground-location-tool/GroundLocationTool.test.ts b/test/tools/ground-location-tool/GroundLocationTool.test.ts new file mode 100644 index 0000000..9ccf192 --- /dev/null +++ b/test/tools/ground-location-tool/GroundLocationTool.test.ts @@ -0,0 +1,270 @@ +// Copyright (c) Mapbox, Inc. +// Licensed under the MIT License. + +process.env.MAPBOX_ACCESS_TOKEN = + 'eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJ0ZXN0In0.signature'; + +import { describe, it, expect, afterEach, vi } from 'vitest'; +import { setupHttpRequest } from '../../utils/httpPipelineUtils.js'; +import { GroundLocationTool } from '../../../src/tools/ground-location-tool/GroundLocationTool.js'; + +const geocodeResponse = { + features: [ + { + properties: { + name: 'Mission District', + full_address: 'Mission District, San Francisco, CA' + }, + geometry: { type: 'Point', coordinates: [-122.419, 37.759] } + } + ] +}; + +const categoryResponse = { + features: [ + { + properties: { + name: 'Four Barrel Coffee', + full_address: '375 Valencia St, San Francisco, CA', + poi_category: ['coffee'], + distance: 120 + }, + geometry: { type: 'Point', coordinates: [-122.421, 37.762] } + } + ] +}; + +const isochroneResponse = { + features: [ + { + properties: { contour: 5 }, + geometry: { type: 'Polygon', coordinates: [] } + }, + { + properties: { contour: 10 }, + geometry: { type: 'Polygon', coordinates: [] } + }, + { + properties: { contour: 15 }, + geometry: { type: 'Polygon', coordinates: [] } + } + ] +}; + +function setupMockHttp(responses: Record) { + const mockFetch = vi.fn().mockImplementation((url: string) => { + for (const [key, body] of Object.entries(responses)) { + if (url.includes(key)) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => body + }); + } + } + return Promise.resolve({ ok: false, status: 404, json: async () => ({}) }); + }); + + const { httpRequest } = setupHttpRequest(); + // Override with our multi-response mock + const tool = new GroundLocationTool({ + httpRequest: mockFetch as unknown as typeof httpRequest + }); + + return { tool, mockFetch }; +} + +describe('GroundLocationTool', () => { + afterEach(() => { + vi.restoreAllMocks(); + }); + + it('returns neighborhood context without sampling', async () => { + const { tool, mockFetch } = setupMockHttp({ + 'geocode/v6/reverse': geocodeResponse, + 'isochrone/v1': isochroneResponse + }); + + const result = await tool.run({ longitude: -122.419, latitude: 37.759 }); + + expect(result.isError).toBe(false); + const text = (result.content[0] as { type: string; text: string }).text; + expect(text).toContain('Mission District'); + expect(text).toContain('neighborhood context'); + // geocode + isochrone called, no category search (no query) + expect(mockFetch).toHaveBeenCalledTimes(2); + }); + + it('performs category search when query is provided', async () => { + const { tool, mockFetch } = setupMockHttp({ + 'geocode/v6/reverse': geocodeResponse, + 'category/coffee': categoryResponse, + 'isochrone/v1': isochroneResponse + }); + + const result = await tool.run({ + longitude: -122.419, + latitude: 37.759, + query: 'coffee' + }); + + expect(result.isError).toBe(false); + const text = (result.content[0] as { type: string; text: string }).text; + expect(text).toContain('Four Barrel Coffee'); + expect(mockFetch).toHaveBeenCalledWith( + expect.stringContaining('category/coffee') + ); + }); + + it('skips isochrone for routing strategy via sampling', async () => { + const { tool, mockFetch } = setupMockHttp({ + 'geocode/v6/reverse': geocodeResponse + }); + + // Simulate a server with sampling support that returns 'routing' + const mockServer = { + server: { + getClientCapabilities: () => ({ sampling: {} }), + createMessage: vi.fn().mockResolvedValue({ + role: 'assistant', + content: { type: 'text', text: 'routing' }, + model: 'test', + stopReason: 'endTurn' + }), + sendLoggingMessage: vi.fn() + } + }; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (tool as any).server = mockServer; + + const result = await tool.run({ + longitude: -122.419, + latitude: 37.759, + query: 'coffee shop' + }); + + expect(result.isError).toBe(false); + const text = (result.content[0] as { type: string; text: string }).text; + expect(text).toContain('routing coordinates'); + // geocode + category search; no isochrone for routing + expect(mockFetch).not.toHaveBeenCalledWith( + expect.stringContaining('isochrone'), + expect.anything() + ); + expect(mockServer.server.createMessage).toHaveBeenCalledOnce(); + }); + + it('uses address/poi geocode types for routing strategy', async () => { + const { tool, mockFetch } = setupMockHttp({ + 'geocode/v6/reverse': geocodeResponse + }); + + const mockServer = { + server: { + getClientCapabilities: () => ({ sampling: {} }), + createMessage: vi.fn().mockResolvedValue({ + role: 'assistant', + content: { type: 'text', text: 'routing' }, + model: 'test', + stopReason: 'endTurn' + }), + sendLoggingMessage: vi.fn() + } + }; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (tool as any).server = mockServer; + + await tool.run({ longitude: -122.419, latitude: 37.759 }); + + const geocodeCall = mockFetch.mock.calls.find((c: string[]) => + c[0].includes('geocode/v6/reverse') + ); + expect(geocodeCall?.[0]).toContain('types=address%2Cpoi'); + }); + + it('falls back to neighborhood when sampling returns unknown value', async () => { + const { tool } = setupMockHttp({ + 'geocode/v6/reverse': geocodeResponse, + 'isochrone/v1': isochroneResponse + }); + + const mockServer = { + server: { + getClientCapabilities: () => ({ sampling: {} }), + createMessage: vi.fn().mockResolvedValue({ + role: 'assistant', + content: { type: 'text', text: 'something unexpected' }, + model: 'test', + stopReason: 'endTurn' + }), + sendLoggingMessage: vi.fn() + } + }; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (tool as any).server = mockServer; + + const result = await tool.run({ longitude: -122.419, latitude: 37.759 }); + + expect(result.isError).toBe(false); + const text = (result.content[0] as { type: string; text: string }).text; + expect(text).toContain('neighborhood context'); + }); + + it('falls back gracefully when sampling throws', async () => { + const { tool } = setupMockHttp({ + 'geocode/v6/reverse': geocodeResponse, + 'isochrone/v1': isochroneResponse + }); + + const mockServer = { + server: { + getClientCapabilities: () => ({ sampling: {} }), + createMessage: vi.fn().mockRejectedValue(new Error('sampling failed')), + sendLoggingMessage: vi.fn() + } + }; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (tool as any).server = mockServer; + + const result = await tool.run({ longitude: -122.419, latitude: 37.759 }); + + expect(result.isError).toBe(false); + const text = (result.content[0] as { type: string; text: string }).text; + expect(text).toContain('neighborhood context'); + }); + + it('boosts poi limit when strategy is poi', async () => { + const { tool, mockFetch } = setupMockHttp({ + 'geocode/v6/reverse': geocodeResponse, + 'category/restaurant': categoryResponse, + 'isochrone/v1': isochroneResponse + }); + + const mockServer = { + server: { + getClientCapabilities: () => ({ sampling: {} }), + createMessage: vi.fn().mockResolvedValue({ + role: 'assistant', + content: { type: 'text', text: 'poi' }, + model: 'test', + stopReason: 'endTurn' + }), + sendLoggingMessage: vi.fn() + } + }; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (tool as any).server = mockServer; + + await tool.run({ + longitude: -122.419, + latitude: 37.759, + query: 'restaurant', + limit: 5 // below the poi minimum of 15 + }); + + const categoryCall = mockFetch.mock.calls.find((c: string[]) => + c[0].includes('category/restaurant') + ); + expect(categoryCall?.[0]).toContain('limit=15'); + }); +});