Implementing a Gateway Resolver Extension
Grafbase Gateway extensions give you powerful ways to extend your gateway's functionality. You can add custom logic like authentication, authorization, and data transformation.
This guide shows you how to implement a custom resolver extension for your Grafbase gateway. You'll create two new directives, @restEndpoint and @rest, which let you define REST endpoints and map GraphQL fields to REST API calls.
Before you begin, you need:
- A working Rust installation with rustup
- The latest version of the Grafbase CLI
- The latest version of the Grafbase Gateway
The grafbase command-line tool helps you create a new extension. Run this command:
grafbase extension init --type resolver rest
This creates a new rest directory with these files and directories:
rest
├── Cargo.toml
├── definitions.graphql
├── extension.toml
├── src
│   └── lib.rs
└── tests
    └── integration_tests.rs
- Cargo.toml configures the Rust project. Add dependencies and build configuration here.
- definitions.graphql contains the GraphQL directives and types for your extension.
- extension.toml configures the extension with its name, version, kind, and other options.
- src/lib.rs contains your extension's main implementation code.
- tests/integration_tests.rs contains integration tests for your extension.
The rest extension uses two directives to define the REST API endpoint and the HTTP request to send:
directive @restEndpoint(
name: String,
http: HttpEndpointDefinition
) repeatable on SCHEMA
directive @rest(
endpoint: String,
http: HttpRequestDefinition,
selection: String
) on FIELD_DEFINITION
input HttpEndpointDefinition {
baseURL: String
}
input HttpRequestDefinition {
method: HttpMethod
path: String
}
enum HttpMethod {
GET
}
- @restEndpoint defines a REST API endpoint. The directive is repeatable, so you can apply it multiple times to the schema definition.
- @rest selects the endpoint to use and defines the REST API path and HTTP method. Add this directive to any field definition that should trigger the extension.
Save the directive and input type definitions in the definitions.graphql file.
An example of a subgraph schema using the directives:
extend schema
@restEndpoint(
name: "restCountries",
http: {
baseURL: "https://restcountries.com/v3.1"
}
)
type Country {
name: String!
}
type Query {
listAllCountries: [Country!]! @rest(
endpoint: "restCountries",
http: {
method: GET,
path: "/all"
},
selection: "[.[] | { name: .name.official }]"
)
}
We created one REST endpoint called restCountries. On the listAllCountries field definition, the @rest directive selects that endpoint and defines the REST API path and the HTTP method, GET.
The selection argument is a jq filter that transforms the REST API response into the shape the GraphQL field expects. The extension evaluates it with jaq, a Rust implementation of the jq language.
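If you are new to jq, the selection in the example above works roughly like the plain-Rust function below. This is an illustrative analogy, not part of the extension. The RestCountry, CountryName and Country types are hypothetical and model only the fields that the selection touches.

```rust
// Hypothetical model of one element of the REST response. The selection only
// reads `name.official`; `common` shows that other fields are dropped.
#[allow(dead_code)]
struct CountryName {
    official: String,
    common: String,
}

struct RestCountry {
    name: CountryName,
}

// The GraphQL shape `type Country { name: String! }`.
#[derive(Debug, PartialEq)]
struct Country {
    name: String,
}

// Rust analogue of `[.[] | { name: .name.official }]`:
// `.[]` iterates over the array, `{ name: .name.official }` builds one new
// object per element, and the outer `[ ... ]` collects the results into an array.
fn select(response: &[RestCountry]) -> Vec<Country> {
    response
        .iter()
        .map(|country| Country { name: country.name.official.clone() })
        .collect()
}
```

Calling select on a response with one country yields a one-element Vec containing only the official name. This is the same array-of-objects shape that [Country!]! expects.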
Add the following configuration to extension.toml:
[extension]
name = "rest"
version = "0.1.0"
kind = "resolver"
[directives]
definitions = "definitions.graphql"
field_resolvers = ["rest"]
- name identifies the extension.
- version specifies the extension version.
- kind sets the extension type, either resolver or auth.
- definitions points to the GraphQL file with the directive definitions we created.
- field_resolvers lists the field resolver names that the extension provides and that trigger it.
Find up-to-date information on the Grafbase Rust SDK in the official documentation.
First, install the required dependencies:
cargo add jaq-parse serde_json http jaq-core jaq-std
cargo add serde --features derive
cargo add url --features serde
cargo add jaq-json --features serde_json
Open src/lib.rs in an editor of your choice. We recommend an editor with Rust IDE features, such as RustRover, or any editor with Language Server Protocol support combined with rust-analyzer.
First create a struct that acts as a gateway resolver extension:
#[derive(ResolverExtension)]
struct RestExtension {
endpoints: Vec<RestEndpoint>,
filters: HashMap<String, Filter<Native<Val>>>,
arena: Arena,
}
Our extension stores all @restEndpoint directives in the endpoints field and all compiled selection filters in the filters field. The arena field holds an arena allocator that the selection parser uses later to allocate memory.
Implement the Extension trait for RestExtension:
impl Extension for RestExtension {
fn new(schema_directives: Vec<Directive>, _: Configuration) -> Result<Self, Box<dyn std::error::Error>> {
let mut endpoints = Vec::<RestEndpoint>::new();
// Iterates over all `@restEndpoint` directives from the subgraph.
for directive in schema_directives {
// Creates a new `RestEndpoint` from the given `Directive`.
let endpoint = RestEndpoint {
subgraph_name: directive.subgraph_name().to_string(),
args: directive.arguments()?,
};
endpoints.push(endpoint);
}
// Sorts the endpoints by name and subgraph name. This allows fast lookup
// and good cache efficiency.
endpoints.sort_by(|a, b| {
let by_name = a.args.name.cmp(&b.args.name);
let by_subgraph = a.subgraph_name.cmp(&b.subgraph_name);
by_name.then(by_subgraph)
});
// Arena allocator needed later for allocating memory for the selection parser.
let arena = Arena::default();
Ok(Self {
endpoints,
filters: HashMap::new(),
arena,
})
}
}
The RestEndpoint struct stores the subgraph name together with the @restEndpoint arguments, which Serde deserializes:
use url::Url;
#[derive(Debug)]
pub struct RestEndpoint {
pub subgraph_name: String,
pub args: RestEndpointArgs,
}
#[derive(serde::Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct RestEndpointArgs {
pub name: String,
pub http: HttpSettings,
}
#[derive(serde::Deserialize, Debug)]
pub struct HttpSettings {
#[serde(rename = "baseURL")]
pub base_url: Url,
}
When the constructor calls the arguments() method on the Directive, it deserializes the directive arguments into the RestEndpointArgs and HttpSettings structs.
Next, implement the Resolver trait for the RestExtension struct:
impl Resolver for RestExtension {
fn resolve_field(
&mut self,
_: SharedContext,
directive: Directive,
_: FieldDefinition,
_: FieldInputs,
) -> Result<FieldOutput, Error> {
...
}
}
We only need the directive argument, which contains the arguments provided in the @rest directive. We access them with the arguments() method on the Directive struct:
let rest: Rest<'_> = directive.arguments().map_err(|e| Error {
extensions: Vec::new(),
message: format!("Could not parse directive arguments: {e}"),
})?;
This deserializes the arguments into the Rest struct, defined as:
#[derive(Debug, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Rest<'a> {
pub endpoint: &'a str,
pub http: HttpCall<'a>,
pub selection: &'a str,
}
#[derive(Debug, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct HttpCall<'a> {
pub method: HttpMethod,
pub path: &'a str,
}
#[derive(Debug, serde::Deserialize)]
#[serde(rename_all = "UPPERCASE")]
pub enum HttpMethod {
Get,
}
Next, look up the endpoint named in the @rest directive among the @restEndpoint directives. Every subgraph defines its own endpoints, but the extension sees the endpoints of all subgraphs. The lookup therefore uses both the endpoint name and the subgraph name:
let Some(endpoint) = self.get_endpoint(rest.endpoint, directive.subgraph_name()) else {
return Err(Error {
extensions: Vec::new(),
message: format!("Endpoint not found: {}", rest.endpoint),
});
};
The simplest lookup would use a HashMap. With a small number of endpoints, however, a sorted vector fits in the CPU cache, and a binary search over it is often faster:
impl RestExtension {
pub fn get_endpoint(&self, name: &str, subgraph_name: &str) -> Option<&RestEndpoint> {
self.endpoints
.binary_search_by(|e| {
let by_name = e.args.name.as_str().cmp(name);
let by_subgraph = e.subgraph_name.as_str().cmp(subgraph_name);
by_name.then(by_subgraph)
})
.map(|i| &self.endpoints[i])
.ok()
}
}
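The sort in the constructor and the comparator in get_endpoint must order endpoints by exactly the same key. Otherwise binary_search_by can miss entries that exist. The std-only sketch below shares one comparison function between sorting and searching. The Endpoint type and the compare, sort_endpoints and find functions are hypothetical stand-ins for RestEndpoint and its methods.

```rust
use std::cmp::Ordering;

// Hypothetical stand-in for `RestEndpoint`, reduced to the two sort-key fields.
#[derive(Debug)]
struct Endpoint {
    name: String,
    subgraph_name: String,
}

// One comparator for both sorting and searching: by endpoint name, then by subgraph name.
fn compare(e: &Endpoint, name: &str, subgraph_name: &str) -> Ordering {
    e.name
        .as_str()
        .cmp(name)
        .then(e.subgraph_name.as_str().cmp(subgraph_name))
}

fn sort_endpoints(endpoints: &mut [Endpoint]) {
    endpoints.sort_by(|a, b| compare(a, &b.name, &b.subgraph_name));
}

// Binary search over the sorted slice; returns None if no endpoint matches both names.
fn find<'a>(endpoints: &'a [Endpoint], name: &str, subgraph_name: &str) -> Option<&'a Endpoint> {
    endpoints
        .binary_search_by(|e| compare(e, name, subgraph_name))
        .ok()
        .map(|i| &endpoints[i])
}
```

Routing both operations through compare makes it impossible for the sort order and the search order to drift apart.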
Next combine the base URL with the endpoint path:
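let mut url = endpoint.args.http.base_url.clone();
let path = rest.http.path.strip_prefix('/').unwrap_or(rest.http.path);
if !path.is_empty() {
let mut path_segments = url.path_segments_mut().map_err(|_| Error {
extensions: Vec::new(),
message: "The base URL cannot be used as a base for a path".to_string(),
})?;
// `push` would percent-encode any `/` inside the path, so append each segment separately.
// `pop_if_empty` removes a trailing slash from the base URL first.
path_segments.pop_if_empty().extend(path.split('/'));
}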
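The url crate's PathSegmentsMut::push percent-encodes any / inside the segment it appends. A multi-segment path such as /users/active therefore has to be split on / before it is appended. The hypothetical join_path function below shows the same splitting idea on plain strings, using only the standard library. Unlike the url crate, it does no percent-encoding, and it collapses empty segments.

```rust
// Hypothetical std-only helper: appends an endpoint path to a base path one
// segment at a time, so "/v3.1" + "/all" becomes "/v3.1/all" without double slashes.
fn join_path(base_path: &str, path: &str) -> String {
    // Keep the non-empty segments of the base path ("/" and "/api/" are handled alike).
    let mut segments: Vec<&str> = base_path.split('/').filter(|s| !s.is_empty()).collect();
    // Append each non-empty segment of the endpoint path.
    segments.extend(path.split('/').filter(|s| !s.is_empty()));
    format!("/{}", segments.join("/"))
}
```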
Now call the REST endpoint and parse the response:
// Build a new HTTP request, using the method from our deserialized `@rest` directive.
let request = HttpRequest::builder(url, rest.http.method.into()).build();
// Execute the HTTP request and handle any errors.
let result = http::execute(&request).map_err(|e| Error {
extensions: Vec::new(),
message: format!("HTTP request failed: {e}"),
})?;
// Check if the HTTP request succeeded.
if !result.status().is_success() {
return Err(Error {
extensions: Vec::new(),
message: format!("HTTP request failed with status: {}", result.status()),
});
}
// Parse the response body as arbitrary JSON.
let data: serde_json::Value = result.json().map_err(|e| Error {
extensions: Vec::new(),
message: format!("Error deserializing response: {e}"),
})?;
To convert our HttpMethod into http::Method, implement the From trait:
#[derive(Debug, serde::Deserialize)]
#[serde(rename_all = "UPPERCASE")]
pub enum HttpMethod {
Get,
}
impl From<HttpMethod> for ::http::Method {
fn from(method: HttpMethod) -> Self {
match method {
HttpMethod::Get => Self::GET,
}
}
}
Now implement the jq selection filter, following the jaq-core documentation:
impl RestExtension {
pub fn create_filter<'a>(&'a mut self, selection: &str) -> Result<&'a Filter<Native<Val>>, Error> {
if !self.filters.contains_key(selection) {
let program = File {
code: selection,
path: (),
};
let loader = Loader::new(jaq_std::defs().chain(jaq_json::defs()));
let modules = loader.load(&self.arena, program).map_err(|e| {
let error = e.first().map(|e| e.0.code).unwrap_or_default();
Error {
extensions: Vec::new(),
message: format!("The selection is not valid jq syntax: `{error}`"),
}
})?;
let filter = Compiler::default()
.with_funs(jaq_std::funs().chain(jaq_json::funs()))
.compile(modules)
.map_err(|e| {
let error = e.first().map(|e| e.0.code).unwrap_or_default();
Error {
extensions: Vec::new(),
message: format!("The selection is not valid jq syntax: `{error}`"),
}
})?;
self.filters.insert(selection.to_string(), filter);
}
Ok(self.filters.get(selection).unwrap())
}
}
Each selection is compiled only once and cached in the filters map, so later requests with the same selection reuse the compiled filter.
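Here is the caching pattern on its own, without jaq. Compiled and FilterCache are hypothetical stand-ins for the compiled filter and the filters map. The compilations counter exists only to show that each distinct selection is compiled once.

```rust
use std::collections::HashMap;

// Hypothetical stand-in for a compiled jq filter.
struct Compiled {
    source: String,
}

// Compile-once cache keyed by the selection string.
#[derive(Default)]
struct FilterCache {
    filters: HashMap<String, Compiled>,
    compilations: usize,
}

impl FilterCache {
    fn get_or_compile(&mut self, selection: &str) -> &Compiled {
        if !self.filters.contains_key(selection) {
            // Placeholder for the expensive parse + compile step.
            self.compilations += 1;
            let compiled = Compiled { source: selection.to_string() };
            self.filters.insert(selection.to_string(), compiled);
        }
        &self.filters[selection]
    }
}
```

Checking contains_key before inserting, instead of calling entry(selection.to_string()), avoids allocating a new String on every cache hit.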
Now continue the resolve_field implementation: run the filter we just created and add each resulting JSON value to the response:
let filter = self.create_filter(rest.selection)?;
let inputs = RcIter::new(core::iter::empty());
let filtered = filter.run((Ctx::new([], &inputs), Val::from(data)));
let mut results = FieldOutput::new();
for result in filtered {
match result {
Ok(result) => results.push_value(serde_json::Value::from(result)),
Err(e) => results.push_error(Error {
extensions: Vec::new(),
message: format!("Error parsing result value: {e}"),
}),
}
}
Ok(results)
Find the full example in the grafbase repository.
The Grafbase SDK includes utilities for testing your extensions. The project generated by grafbase extension init already includes the required dependencies, but we need one more crate to start testing:
cargo add --dev wiremock
Find the tests in the tests/integration_tests.rs file. See the full example in the grafbase repository.
Let's write a test that:
- Starts a mock REST server in a separate thread
- Creates a virtual subgraph and configures the REST extension
- Creates a test runner that composes the subgraph and starts the Grafbase Gateway
- Runs a GraphQL query against the gateway and asserts the response
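The guide's tests use the wiremock crate for the mock REST server, through a mock_server helper defined in the full example. The std-only sketch below shows the underlying idea, a mock REST server running in a separate thread. It is a swapped-in technique, not wiremock's API. The spawn_mock_server and get helpers are hypothetical: the server answers exactly one request with a fixed JSON body and ignores the request path, which wiremock would match.

```rust
use std::io::{Read, Write};
use std::net::{TcpListener, TcpStream};
use std::thread::{self, JoinHandle};

// Hypothetical std-only mock server: binds to a random local port, answers
// exactly one HTTP request with a fixed JSON body, then exits.
fn spawn_mock_server(body: &'static str) -> (String, JoinHandle<()>) {
    let listener = TcpListener::bind("127.0.0.1:0").expect("bind mock server");
    let uri = format!("http://{}", listener.local_addr().expect("local address"));
    let handle = thread::spawn(move || {
        let (mut stream, _) = listener.accept().expect("accept connection");
        // Read until the end of the request headers (a GET request has no body).
        let mut request = Vec::new();
        let mut buf = [0u8; 512];
        while !request.windows(4).any(|w| w == &b"\r\n\r\n"[..]) {
            let n = stream.read(&mut buf).expect("read request");
            if n == 0 {
                break;
            }
            request.extend_from_slice(&buf[..n]);
        }
        let response = format!(
            "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
            body.len(),
            body
        );
        stream.write_all(response.as_bytes()).expect("write response");
        // Dropping `stream` here closes the connection, which ends the client's read.
    });
    (uri, handle)
}

// Minimal client for trying the mock out: sends `GET path` and returns the raw response.
fn get(uri: &str, path: &str) -> String {
    let addr = uri.trim_start_matches("http://");
    let mut stream = TcpStream::connect(addr).expect("connect to mock server");
    let request = format!("GET {path} HTTP/1.1\r\nHost: {addr}\r\nConnection: close\r\n\r\n");
    stream.write_all(request.as_bytes()).expect("send request");
    let mut response = String::new();
    stream.read_to_string(&mut response).expect("read response");
    response
}
```

Binding to port 0 lets the OS pick a free port, so parallel tests do not collide. wiremock's MockServer follows the same approach.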
Here's the function that generates a virtual subgraph:
fn subgraph(rest_endpoint: &str) -> ExtensionOnlySubgraph {
let extension_path = std::env::current_dir().unwrap().join("build");
let path_str = format!("file://{}", extension_path.display());
let schema = formatdoc! {r#"
extend schema
@link(url: "https://specs.apollo.dev/federation/v2.0", import: ["@key", "@shareable"])
@link(url: "{path_str}", import: ["@restEndpoint", "@rest"])
@restEndpoint(
name: "endpoint",
http: {{
baseURL: "{rest_endpoint}"
}}
)
type Query {{
users: [User!]! @rest(
endpoint: "endpoint",
http: {{
method: GET,
path: "/users"
}},
selection: "[.[] | {{ id, name, age }}]"
)
}}
type User {{
id: ID!
name: String!
age: Int!
}}
"#};
DynamicSchema::builder(schema)
.into_extension_only_subgraph("test", &extension_path)
.unwrap()
}
The function takes the mock REST service URL as a parameter and links the schema to the extension build directory.
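formatdoc! from the indoc crate accepts the same format string syntax as format!. The literal braces of the GraphQL input objects in the template are therefore doubled ({{ and }}), while {path_str} and {rest_endpoint} are substituted. The small std-only check below shows that escaping, using a hypothetical endpoint_directive helper.

```rust
// Hypothetical helper: `{{` and `}}` render as literal braces, while
// `{rest_endpoint}` captures the function argument by name.
fn endpoint_directive(rest_endpoint: &str) -> String {
    format!(r#"@restEndpoint(name: "endpoint", http: {{ baseURL: "{rest_endpoint}" }})"#)
}
```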
With this function ready, we can write our first test:
#[tokio::test]
async fn get_all_fields() {
// Our mock response from the REST service
let response_body = json!([
{
"id": "1",
"name": "John Doe",
"age": 30,
},
{
"id": "2",
"name": "Jane Doe",
"age": 25,
}
]);
let template = ResponseTemplate::new(200).set_body_json(response_body);
let mock_server = mock_server("/users", template).await;
// Generate the virtual subgraph with the URL our mock server is running on
let subgraph = subgraph(&mock_server.uri());
// Build a test config with the subgraph and enable networking. This looks for
// the grafbase-gateway and grafbase binaries in the PATH. You can also point it to specific
// binaries by calling the `with_gateway` and `with_cli` methods.
let config = TestConfig::builder()
.with_subgraph(subgraph)
.enable_networking()
.build("")
.unwrap();
// Start the Grafbase Gateway and return a runner instance to run queries against.
let runner = TestRunner::new(config).await.unwrap();
let query = indoc! {r#"
query {
users {
id
name
age
}
}
"#};
let result: serde_json::Value = runner.graphql_query(query).send().await.unwrap();
insta::assert_json_snapshot!(result, @r#"
{
"data": {
"users": [
{
"id": "1",
"name": "John Doe",
"age": 30
},
{
"id": "2",
"name": "Jane Doe",
"age": 25
}
]
}
}
"#);
}
Run the tests with cargo:
cargo test
Make gateway errors visible by calling the enable_stdout and enable_stderr methods on the TestConfig builder.
Build the extension after completing development:
grafbase extension build
Copy the build directory to your Grafbase project:
cp -r build/* path/to/your/grafbase/project/rest/
Add the extension to the gateway configuration:
[extensions]
rest = { path = "rest" }
Launch the gateway:
grafbase-gateway --schema federated-schema.graphql --config grafbase.toml