Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/138718.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 138718
summary: Add `project_routing` option
area: SQL
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
import static org.elasticsearch.xpack.sql.action.Protocol.REQUEST_TIMEOUT_NAME;
import static org.elasticsearch.xpack.sql.action.Protocol.TIME_ZONE_NAME;
import static org.elasticsearch.xpack.sql.action.Protocol.VERSION_NAME;
import static org.elasticsearch.xpack.sql.proto.CoreProtocol.PROJECT_ROUTING_NAME;

/**
* Base class for requests that contain sql queries (Query and Translate)
Expand Down Expand Up @@ -91,6 +92,7 @@ public abstract class AbstractSqlQueryRequest extends AbstractSqlRequest impleme
private QueryBuilder filter = null;
private List<SqlTypedParamValue> params = emptyList();
private Map<String, Object> runtimeMappings = emptyMap();
private String projectRouting;

static final ParseField QUERY = new ParseField(QUERY_NAME);
static final ParseField CURSOR = new ParseField(CURSOR_NAME);
Expand All @@ -104,6 +106,7 @@ public abstract class AbstractSqlQueryRequest extends AbstractSqlRequest impleme
static final ParseField MODE = new ParseField(MODE_NAME);
static final ParseField CLIENT_ID = new ParseField(CLIENT_ID_NAME);
static final ParseField VERSION = new ParseField(VERSION_NAME);
static final ParseField PROJECT_ROUTING = new ParseField(PROJECT_ROUTING_NAME);

public AbstractSqlQueryRequest() {
super();
Expand Down Expand Up @@ -155,6 +158,7 @@ protected static <R extends AbstractSqlQueryRequest> ObjectParser<R, Void> objec
);
parser.declareObject(AbstractSqlQueryRequest::filter, (p, c) -> AbstractQueryBuilder.parseTopLevelQuery(p), FILTER);
parser.declareObject(AbstractSqlQueryRequest::runtimeMappings, (p, c) -> p.map(), SearchSourceBuilder.RUNTIME_MAPPINGS_FIELD);
parser.declareString(AbstractSqlQueryRequest::projectRouting, PROJECT_ROUTING);
return parser;
}

Expand Down Expand Up @@ -423,6 +427,14 @@ public AbstractSqlQueryRequest runtimeMappings(Map<String, Object> runtimeMappin
return this;
}

public String projectRouting() {
return projectRouting;
}

public void projectRouting(String projectRouting) {
this.projectRouting = projectRouting;
}

public AbstractSqlQueryRequest(StreamInput in) throws IOException {
super(in);
query = in.readString();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,9 @@ public SqlTranslateRequestBuilder zoneId(ZoneId zoneId) {
request.zoneId(zoneId);
return this;
}

public SqlTranslateRequestBuilder projectRouting(String projectRouting) {
request.projectRouting(projectRouting);
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ public class CoreProtocol {
public static final String INDEX_INCLUDE_FROZEN_NAME = "index_include_frozen";
public static final String RUNTIME_MAPPINGS_NAME = "runtime_mappings";
public static final String ALLOW_PARTIAL_SEARCH_RESULTS_NAME = "allow_partial_search_results";
public static final String PROJECT_ROUTING_NAME = "project_routing";
// async
public static final String WAIT_FOR_COMPLETION_TIMEOUT_NAME = "wait_for_completion_timeout";
public static final String KEEP_ON_COMPLETION_NAME = "keep_on_completion";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
import org.elasticsearch.rest.Scope;
import org.elasticsearch.rest.ServerlessScope;
import org.elasticsearch.rest.action.RestCancellableNodeClient;
import org.elasticsearch.search.crossproject.CrossProjectModeDecider;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.ql.InvalidArgumentException;
import org.elasticsearch.xpack.sql.action.SqlQueryAction;
import org.elasticsearch.xpack.sql.action.SqlQueryRequest;

Expand All @@ -34,10 +36,11 @@
@ServerlessScope(Scope.PUBLIC)
public class RestSqlQueryAction extends BaseRestHandler {
private static final Logger LOGGER = LogManager.getLogger(RestSqlQueryAction.class);
private final Settings settings;

private final CrossProjectModeDecider crossProjectModeDecider;

public RestSqlQueryAction(Settings settings) {
this.settings = settings;
crossProjectModeDecider = new CrossProjectModeDecider(settings);
}

@Override
Expand All @@ -47,15 +50,21 @@ public List<Route> routes() {

@Override
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
if (settings != null && settings.getAsBoolean("serverless.cross_project.enabled", false)) {
// accept but drop project_routing param until fully supported
request.param("project_routing");
}

SqlQueryRequest sqlRequest;
try (XContentParser parser = request.contentOrSourceParamParser()) {
sqlRequest = SqlQueryRequest.fromXContent(parser);
}

String routingParam = request.param("project_routing");
if (routingParam != null) {
// takes precedence on the parameter in the body
sqlRequest.projectRouting(routingParam);
}
if (sqlRequest.projectRouting() != null && crossProjectModeDecider.crossProjectEnabled() == false) {
throw new InvalidArgumentException("[project_routing] is only allowed when cross-project search is enabled");
}

return channel -> {
RestCancellableNodeClient cancellableClient = new RestCancellableNodeClient(client, request.getHttpChannel());
cancellableClient.execute(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@
package org.elasticsearch.xpack.sql.plugin;

import org.elasticsearch.client.internal.node.NodeClient;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.rest.BaseRestHandler;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.Scope;
import org.elasticsearch.rest.ServerlessScope;
import org.elasticsearch.rest.action.RestToXContentListener;
import org.elasticsearch.search.crossproject.CrossProjectModeDecider;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.ql.InvalidArgumentException;
import org.elasticsearch.xpack.sql.action.SqlTranslateAction;
import org.elasticsearch.xpack.sql.action.SqlTranslateRequest;

Expand All @@ -29,6 +32,12 @@
@ServerlessScope(Scope.PUBLIC)
public class RestSqlTranslateAction extends BaseRestHandler {

private final CrossProjectModeDecider crossProjectModeDecider;

public RestSqlTranslateAction(Settings settings) {
this.crossProjectModeDecider = new CrossProjectModeDecider(settings);
}

@Override
public List<Route> routes() {
return List.of(new Route(GET, SQL_TRANSLATE_REST_ENDPOINT), new Route(POST, SQL_TRANSLATE_REST_ENDPOINT));
Expand All @@ -40,7 +49,14 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
try (XContentParser parser = request.contentOrSourceParamParser()) {
sqlRequest = SqlTranslateRequest.fromXContent(parser);
}

String routingParam = request.param("project_routing");
if (routingParam != null) {
// takes precedence on the parameter in the body
sqlRequest.projectRouting(routingParam);
}
if (sqlRequest.projectRouting() != null && crossProjectModeDecider.crossProjectEnabled() == false) {
throw new InvalidArgumentException("[project_routing] is only allowed when cross-project search is enabled");
}
return channel -> client.executeLocally(SqlTranslateAction.INSTANCE, sqlRequest, new RestToXContentListener<>(channel));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ public List<RestHandler> getRestHandlers(

return Arrays.asList(
new RestSqlQueryAction(settings),
new RestSqlTranslateAction(),
new RestSqlTranslateAction(settings),
new RestSqlClearCursorAction(),
new RestSqlStatsAction(),
new RestSqlAsyncGetResultsAction(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,8 @@ public static void operation(
request.indexIncludeFrozen(),
new TaskId(clusterService.localNode().getId(), task.getId()),
task,
request.allowPartialSearchResults()
request.allowPartialSearchResults(),
request.projectRouting()
);

if (Strings.hasText(request.cursor()) == false) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ protected void doExecute(Task task, SqlTranslateRequest request, ActionListener<
Protocol.INDEX_INCLUDE_FROZEN,
null,
null,
Protocol.ALLOW_PARTIAL_SEARCH_RESULTS
Protocol.ALLOW_PARTIAL_SEARCH_RESULTS,
request.projectRouting()
);

planExecutor.searchSource(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ public class SqlConfiguration extends org.elasticsearch.xpack.ql.session.Configu
@Nullable
private final Map<String, Object> runtimeMappings;
private final boolean allowPartialSearchResults;
private final String projectRouting;

public SqlConfiguration(
ZoneId zi,
Expand All @@ -61,7 +62,8 @@ public SqlConfiguration(
boolean includeFrozen,
@Nullable TaskId taskId,
@Nullable SqlQueryTask task,
boolean allowPartialSearchResults
boolean allowPartialSearchResults,
String projectRouting
) {
super(zi, username, clusterName);

Expand All @@ -79,6 +81,7 @@ public SqlConfiguration(
this.taskId = taskId;
this.task = task;
this.allowPartialSearchResults = allowPartialSearchResults;
this.projectRouting = projectRouting;
}

public String catalog() {
Expand Down Expand Up @@ -136,4 +139,8 @@ public SqlQueryTask task() {
public boolean allowPartialSearchResults() {
return allowPartialSearchResults;
}

public String projectRouting() {
return projectRouting;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ private SqlTestUtils() {}
false,
null,
null,
false
false,
null
);

public static SqlConfiguration randomConfiguration(ZoneId providedZoneId, SqlVersion sqlVersion) {
Expand All @@ -82,7 +83,8 @@ public static SqlConfiguration randomConfiguration(ZoneId providedZoneId, SqlVer
randomBoolean(),
new TaskId(randomAlphaOfLength(10), taskId),
randomTask(taskId, mode, sqlVersion),
randomBoolean()
randomBoolean(),
null
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ public void testDatabaseFunctionOutput() {
randomBoolean(),
null,
null,
randomBoolean()
randomBoolean(),
null
);
Analyzer analyzer = analyzer(sqlConfig, IndexResolution.valid(test));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ public void testNoUsernameFunctionOutput() {
randomBoolean(),
null,
null,
randomBoolean()
randomBoolean(),
null
);
Analyzer analyzer = analyzer(sqlConfig, IndexResolution.valid(test));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,8 @@ private int executeCommandInOdbcModeAndCountRows(String sql) {
false,
null,
null,
false
false,
null
);
Tuple<Command, SqlSession> tuple = sql(sql, emptyList(), config, MAPPING1);

Expand Down Expand Up @@ -350,7 +351,8 @@ private void executeCommand(
false,
null,
null,
false
false,
null
);
Tuple<Command, SqlSession> tuple = sql(sql, params, config, mapping);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ public class SysTablesTests extends ESTestCase {
true,
null,
null,
false
false,
null
);

//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ private Tuple<Command, SqlSession> sql(String sql, Mode mode, @Nullable SqlVersi
false,
null,
null,
false
false,
null
);
EsIndex test = new EsIndex("test", SqlTypesTests.loadMapping("mapping-multi-field-with-nested.json", true));
Analyzer analyzer = analyzer(configuration, IndexResolution.valid(test));
Expand Down