diff --git a/env/config.h b/env/config.h index e427832ff..68e4a1dd9 100644 --- a/env/config.h +++ b/env/config.h @@ -32,6 +32,11 @@ class Config { void SetName(std::string name) { name_ = std::move(name); } std::string GetName() const { return name_; } + void SetContextType(std::string context_type) { + context_type_ = std::move(context_type); + } + std::string GetContextType() const { return context_type_; } + struct ContainerConfig { std::string name; std::vector abbreviations; @@ -150,6 +155,7 @@ class Config { private: std::string name_; + std::string context_type_; ContainerConfig container_config_; std::vector extension_configs_; StandardLibraryConfig standard_library_config_; diff --git a/env/env.cc b/env/env.cc index 6cd3a3cdc..22d24295e 100644 --- a/env/env.cc +++ b/env/env.cc @@ -138,6 +138,11 @@ absl::StatusOr> Env::NewCompilerBuilder() { for (const auto& abbr : config_.GetContainerConfig().abbreviations) { CEL_RETURN_IF_ERROR(container.AddAbbreviation(abbr)); } + + if (!config_.GetContextType().empty()) { + CEL_RETURN_IF_ERROR( + checker_builder.AddContextDeclaration(config_.GetContextType())); + } for (const auto& alias : config_.GetContainerConfig().aliases) { CEL_RETURN_IF_ERROR(container.AddAlias(alias.alias, alias.qualified_name)); } diff --git a/env/env_test.cc b/env/env_test.cc index b599aa569..fda87dfab 100644 --- a/env/env_test.cc +++ b/env/env_test.cc @@ -344,6 +344,25 @@ TEST(ContainerConfigTest, ContainerConfigWithAliases) { EXPECT_THAT(result.GetIssues(), IsEmpty()) << result.FormatError(); } +TEST(ContextVariableConfigTest, Basic) { + Env env; + env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); + Config config; + config.SetContextType("cel.expr.conformance.proto3.TestAllTypes"); + env.SetConfig(config); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + + // Top-level fields of TestAllTypes like "single_int32" should resolve + // successfully. + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile("single_int32 > 10")); + EXPECT_THAT(result.GetIssues(), IsEmpty()); + + ASSERT_OK_AND_ASSIGN(auto result_invalid, + compiler->Compile("non_existent_field > 10")); + EXPECT_THAT(result_invalid.GetIssues(), Not(IsEmpty())); +} + struct VariableConfigWithValueTestCase { Config::VariableConfig variable_config; std::string validate_type_expr; diff --git a/env/env_yaml.cc b/env/env_yaml.cc index 1bbfe6b36..a509412bf 100644 --- a/env/env_yaml.cc +++ b/env/env_yaml.cc @@ -1245,6 +1245,25 @@ void EmitFunctionConfigs(const Config& env_config, YAML::Emitter& out, } out << YAML::EndSeq; } + +absl::Status ParseContextVariableConfig(Config& config, absl::string_view yaml, + const YAML::Node& root) { + const YAML::Node context_variable = root["context_variable"]; + if (!context_variable.IsDefined()) { + return absl::OkStatus(); + } + if (!context_variable.IsMap()) { + return YamlError(yaml, context_variable, + "Node 'context_variable' is not a map"); + } + const YAML::Node type_name = context_variable["type_name"]; + if (!type_name || !type_name.IsScalar()) { + return YamlError(yaml, type_name, "Node 'type_name' is not a string"); + } + config.SetContextType(GetString(yaml, type_name)); + return absl::OkStatus(); +} + } // namespace absl::StatusOr EnvConfigFromYaml(const std::string& yaml) { @@ -1263,6 +1282,7 @@ absl::StatusOr EnvConfigFromYaml(const std::string& yaml) { CEL_RETURN_IF_ERROR(ParseContainerConfig(config, yaml, root)); CEL_RETURN_IF_ERROR(ParseExtensionConfigs(config, yaml, root)); CEL_RETURN_IF_ERROR(ParseStandardLibraryConfig(config, yaml, root)); + CEL_RETURN_IF_ERROR(ParseContextVariableConfig(config, yaml, root)); CEL_RETURN_IF_ERROR(ParseVariableConfigs(config, yaml, root)); CEL_RETURN_IF_ERROR(ParseFunctionConfigs(config, yaml, root)); return config; diff --git a/env/env_yaml_test.cc b/env/env_yaml_test.cc index a60048617..73f83089c 100644 --- a/env/env_yaml_test.cc +++ b/env/env_yaml_test.cc @@ -216,6 +216,16 @@ TEST(EnvYamlTest, ParseVariableConfigWithTypeParams) { EXPECT_THAT(type_info.params[1].params, IsEmpty()); } +TEST(EnvYamlTest, ParseContextVariableConfig) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + context_variable: + type_name: "cel.expr.conformance.proto3.TestAllTypes" + )yaml")); + + EXPECT_EQ(config.GetContextType(), + "cel.expr.conformance.proto3.TestAllTypes"); +} + TEST(EnvYamlTest, ParseVariableConfigWithTypeParamsLegacySyntax) { ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( variables: