diff --git a/MODULE.bazel b/MODULE.bazel index 43d0485d2..187d68164 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -31,6 +31,7 @@ bazel_dep( name = "rules_python", version = "1.6.3", ) +bazel_dep(name = "rules_license", version = "1.0.0") bazel_dep( name = "protobuf", version = "34.1", @@ -96,3 +97,16 @@ bazel_dep( name = "yaml-cpp", version = "0.9.0", ) + +_CEL_POLICY_TAG = "ebfb2361f47080af643c14cf4da4c2b551a68740" + +_CEL_POLICY_SHA = "ea69e9c6b7bd5bc37d358148aebd2fcca38bc7c45a23feb635de72338e0327c1" + +http_archive = use_repo_rule("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +http_archive( + name = "cel_policy", + sha256 = _CEL_POLICY_SHA, + strip_prefix = "cel-policy-%s" % _CEL_POLICY_TAG, + url = "https://github.com/cel-expr/cel-policy/archive/%s.tar.gz" % _CEL_POLICY_TAG, +) diff --git a/conformance/policy/BUILD b/conformance/policy/BUILD new file mode 100644 index 000000000..57657a777 --- /dev/null +++ b/conformance/policy/BUILD @@ -0,0 +1,78 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+load("@rules_cc//cc:cc_library.bzl", "cc_library")
+load(
+    "//conformance/policy:policy_conformance_test.bzl",
+    "cel_policy_conformance_test",
+)
+
+package(default_visibility = ["//visibility:public"])
+
+cc_library(
+    name = "policy_conformance_test_lib",
+    testonly = True,
+    srcs = ["policy_conformance_test.cc"],
+    deps = [
+        "//common:ast",
+        "//common:source",
+        "//common:value",
+        "//common/internal:value_conversion",
+        "//compiler",
+        "//env",
+        "//env:config",
+        "//env:env_runtime",
+        "//env:env_std_extensions",
+        "//env:env_yaml",
+        "//env:runtime_std_extensions",
+        "//extensions/protobuf:enum_adapter",
+        "//internal:runfiles",
+        "//internal:status_macros",
+        "//internal:testing_descriptor_pool",
+        "//internal:testing_no_main",
+        "//policy:cel_policy",
+        "//policy:cel_policy_parser",
+        "//policy:cel_policy_validation_result",
+        "//policy:compiler",
+        "//policy:test_util",
+        "//policy:yaml_policy_parser",
+        "//runtime",
+        "//runtime:activation",
+        "//runtime:function_adapter",
+        "@com_google_absl//absl/flags:flag",
+        "@com_google_absl//absl/log:absl_check",
+        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_cel_spec//proto/cel/expr:value_cc_proto",
+        "@com_google_cel_spec//proto/cel/expr/conformance/test:suite_cc_proto",
+        "@com_google_protobuf//:protobuf",
+    ],
+    alwayslink = True,
+)
+
+cel_policy_conformance_test(
+    name = "policy_conformance_test",
+    skip_tests = [
+        # TODO(b/506179116): Fix these.
+        # Need to add k8s custom yaml parser and mock runtime.
+        "k8s",
+        # Need to add support for context proto from yaml.
+        "context_pb",
+    ],
+    test_files = [
+        "@cel_policy//conformance:testdata",
+    ],
+)
diff --git a/conformance/policy/policy_conformance_test.bzl b/conformance/policy/policy_conformance_test.bzl
new file mode 100644
index 000000000..d39c8f5ef
--- /dev/null
+++ b/conformance/policy/policy_conformance_test.bzl
@@ -0,0 +1,47 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This module contains build rules for generating policy conformance test targets.
+"""
+
+load("@rules_cc//cc:cc_test.bzl", "cc_test")
+
+_TESTDATA_DIR = "cel_policy/conformance/testdata"
+
+def cel_policy_conformance_test(name, test_files, skip_tests = [], testdata_dir = _TESTDATA_DIR, **kwargs):
+    """Generates a cc_test that runs the policy conformance suites found in the test data.
+
+    Args:
+        name: Name of the test target.
+        test_files: List of targets or files representing the test data (passed as `data`).
+        skip_tests: List of suite or test name prefixes to skip; joined into --skip_tests.
+        testdata_dir: Path to testdata directory under runfiles. Defaults to _TESTDATA_DIR.
+        **kwargs: Additional arguments passed to the underlying cc_test.
+    """
+    args = ["--gunit_fail_if_no_test_linked"]
+    args.append("--testdata_dir='%s'" % testdata_dir)
+
+    if skip_tests:
+        args.append("--skip_tests=" + ",".join(skip_tests))
+
+    cc_test(
+        name = name,
+        data = test_files,
+        deps = [
+            "//conformance/policy:policy_conformance_test_lib",
+        ],
+        args = args,
+        **kwargs
+    )
diff --git a/conformance/policy/policy_conformance_test.cc b/conformance/policy/policy_conformance_test.cc
new file mode 100644
index 000000000..0087216d5
--- /dev/null
+++ b/conformance/policy/policy_conformance_test.cc
@@ -0,0 +1,562 @@
+// Copyright 2026 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include +#include +#include +#include +#include +// NOLINTNEXTLINE(build/c++17) for OSS compatibility +#include + +#include "cel/expr/eval.pb.h" +#include "absl/flags/flag.h" +#include "absl/log/absl_check.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" +#include "common/ast.h" +#include "common/internal/value_conversion.h" +#include "common/source.h" +#include "common/value.h" +#include "compiler/compiler.h" +#include "env/config.h" +#include "env/env.h" +#include "env/env_runtime.h" +#include "env/env_std_extensions.h" +#include "env/env_yaml.h" +#include "env/runtime_std_extensions.h" +#include "extensions/protobuf/enum_adapter.h" +#include "internal/runfiles.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_result.h" +#include "policy/cel_policy_validation_result.h" +#include "policy/compiler.h" +#include "policy/test_util.h" +#include "policy/yaml_policy_parser.h" +#include "runtime/activation.h" +#include "runtime/function_adapter.h" +#include "runtime/runtime.h" +#include "cel/expr/conformance/test/suite.pb.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" +#include "google/protobuf/text_format.h" + +ABSL_FLAG(std::string, testdata_dir, "", "Path to testdata directory."); +ABSL_FLAG(std::vector, skip_tests, {}, + "Comma-separated list of tests to skip."); + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::cel::expr::conformance::test::TestSuite; +using ::cel::internal::GetSharedTestingDescriptorPool; +using ::testing::HasSubstr; + +// Implementations for extension functions referenced in conformance tests. 
+cel::Value LocationCode(const cel::StringValue& ip, + const google::protobuf::DescriptorPool* pool, + google::protobuf::MessageFactory* factory, google::protobuf::Arena* arena) { + std::string ip_str = ip.ToString(); + if (ip_str == "10.0.0.1") return cel::StringValue(arena, "us"); + if (ip_str == "10.0.0.2") return cel::StringValue(arena, "de"); + return cel::StringValue(arena, "ir"); +} + +// TODO(uncreated-issue/92): This should be migrated to use the testrunner utility +// after adding support for reading the yaml specification for envs/tests. +class InputEvaluator { + public: + static absl::StatusOr> Create( + const std::shared_ptr& pool) { + cel::Env env; + env.SetDescriptorPool(pool); + cel::RegisterStandardExtensions(env); + + cel::EnvRuntime env_runtime; + env_runtime.SetDescriptorPool(pool); + cel::RegisterStandardExtensions(env_runtime); + env_runtime.mutable_runtime_options().enable_qualified_type_identifiers = + true; + + // Enable default extensions (optional, bindings) + cel::Config config; + CEL_RETURN_IF_ERROR(config.AddExtensionConfig( + "optional", cel::Config::ExtensionConfig::kLatest)); + CEL_RETURN_IF_ERROR(config.AddExtensionConfig( + "bindings", cel::Config::ExtensionConfig::kLatest)); + env.SetConfig(config); + env_runtime.SetConfig(config); + + auto compiler_builder_or = env.NewCompilerBuilder(); + CEL_ASSIGN_OR_RETURN(auto compiler_builder, std::move(compiler_builder_or)); + compiler_builder->GetParserBuilder().GetOptions().enable_optional_syntax = + true; + CEL_ASSIGN_OR_RETURN(auto compiler, compiler_builder->Build()); + + auto runtime_builder_or = env_runtime.CreateRuntimeBuilder(); + CEL_ASSIGN_OR_RETURN(auto runtime_builder, std::move(runtime_builder_or)); + + // Register conformance enums + for (const auto& enum_name : + {"cel.expr.conformance.proto2.GlobalEnum", + "cel.expr.conformance.proto3.GlobalEnum", + "cel.expr.conformance.proto2.TestAllTypes.NestedEnum", + "cel.expr.conformance.proto3.TestAllTypes.NestedEnum"}) { + auto* 
enum_desc = pool->FindEnumTypeByName(enum_name); + if (enum_desc != nullptr) { + CEL_RETURN_IF_ERROR(cel::extensions::RegisterProtobufEnum( + runtime_builder.type_registry(), enum_desc)); + } + } + + CEL_ASSIGN_OR_RETURN(auto runtime, std::move(runtime_builder).Build()); + + return absl::WrapUnique( + new InputEvaluator(std::move(compiler), std::move(runtime))); + } + + absl::StatusOr Evaluate(absl::string_view expr_str, + google::protobuf::Arena* arena) const { + CEL_ASSIGN_OR_RETURN(auto validation_result, compiler_->Compile(expr_str)); + if (!validation_result.IsValid()) { + return absl::InvalidArgumentError( + absl::StrCat("Failed to compile input expr: ", expr_str)); + } + CEL_ASSIGN_OR_RETURN(auto ast, validation_result.ReleaseAst()); + CEL_ASSIGN_OR_RETURN( + auto program, + runtime_->CreateProgram(std::make_unique(std::move(*ast)))); + cel::Activation activation; + return program->Evaluate(arena, activation); + } + + private: + InputEvaluator(std::unique_ptr compiler, + std::unique_ptr runtime) + : compiler_(std::move(compiler)), runtime_(std::move(runtime)) {} + + std::unique_ptr compiler_; + std::unique_ptr runtime_; +}; + +absl::StatusOr EvaluateInputValue( + const cel::expr::conformance::test::InputValue& input_val, + const InputEvaluator& evaluator, + const google::protobuf::DescriptorPool* descriptor_pool, + google::protobuf::MessageFactory* message_factory, google::protobuf::Arena* arena) { + if (input_val.has_expr()) { + return evaluator.Evaluate(input_val.expr(), arena); + } + if (input_val.has_value()) { + return cel::test::FromExprValue(input_val.value(), descriptor_pool, + message_factory, arena); + } + return absl::InvalidArgumentError("Empty InputValue"); +} + +class CelValueMatcherImpl + : public testing::MatcherInterface { + public: + CelValueMatcherImpl(cel::Value expected_val, + const google::protobuf::DescriptorPool* pool, + google::protobuf::MessageFactory* message_factory, + google::protobuf::Arena* arena) + : 
expected_val_(std::move(expected_val)), + pool_(pool), + message_factory_(message_factory), + arena_(arena) {} + + bool MatchAndExplain(const cel::Value& actual_val, + testing::MatchResultListener* listener) const override { + cel::Value actual = actual_val; + if (actual.IsOptional() && !expected_val_.IsOptional()) { + auto opt_val = actual.AsOptional(); + if (opt_val->HasValue()) { + actual = opt_val->Value(); + } + } + cel::Value eq_result; + auto eq_status = actual.Equal(expected_val_, pool_, message_factory_, + arena_, &eq_result); + if (!eq_status.ok()) { + *listener << "equality check failed with status: " << eq_status; + return false; + } + if (!eq_result.IsTrue()) { + *listener << "expected: " << expected_val_.DebugString() + << "\nactual: " << actual.DebugString(); + return false; + } + return true; + } + + void DescribeTo(std::ostream* os) const override { + *os << "is equal to " << expected_val_.DebugString(); + } + + void DescribeNegationTo(std::ostream* os) const override { + *os << "is not equal to " << expected_val_.DebugString(); + } + + private: + cel::Value expected_val_; + const google::protobuf::DescriptorPool* pool_; + google::protobuf::MessageFactory* message_factory_; + google::protobuf::Arena* arena_; +}; + +absl::StatusOr> MakeExpectedValueMatcher( + const cel::expr::conformance::test::TestOutput& output, + const InputEvaluator& input_evaluator, const google::protobuf::DescriptorPool* pool, + google::protobuf::MessageFactory* message_factory, google::protobuf::Arena* arena) { + cel::Value expected_val; + if (output.has_result_expr()) { + CEL_ASSIGN_OR_RETURN(expected_val, + input_evaluator.Evaluate(output.result_expr(), arena)); + } else if (output.has_result_value()) { + CEL_ASSIGN_OR_RETURN(expected_val, + cel::test::FromExprValue(output.result_value(), pool, + message_factory, arena)); + } else { + return absl::InvalidArgumentError("Unsupported output kind"); + } + return testing::Matcher( + new CelValueMatcherImpl(expected_val, pool, 
message_factory, arena)); +} + +bool ShouldRunTest(absl::string_view test_name, + const std::vector& skip_tests) { + for (const std::string& skip : skip_tests) { + if (absl::StartsWith(test_name, skip)) { + return false; + } + } + return true; +} + +class PolicyTestSuiteRunner { + public: + PolicyTestSuiteRunner(std::string suite_name, + std::unique_ptr compiler, + std::unique_ptr runtime, + std::shared_ptr policy_source, + CelPolicyValidationResult compile_result, + std::shared_ptr pool, + std::shared_ptr input_evaluator, + bool expect_compile_fail = false) + : suite_name_(std::move(suite_name)), + compiler_(std::move(compiler)), + runtime_(std::move(runtime)), + policy_source_(std::move(policy_source)), + compile_result_(std::move(compile_result)), + pool_(std::move(pool)), + input_evaluator_(std::move(input_evaluator)), + expect_compile_fail_(expect_compile_fail) {} + + void RunTest(const cel::expr::conformance::test::TestCase& test, + absl::string_view full_test_name) { + const auto& output = test.output(); + + if (expect_compile_fail_) { + ASSERT_FALSE(compile_result_.IsValid()) + << "Expected compilation to fail in " << full_test_name; + ASSERT_TRUE(output.has_eval_error()) + << "Expected eval_error to be present in compile error test " + << full_test_name; + std::string err_msg = compile_result_.FormatIssues(); + for (const auto& expected_err : output.eval_error().errors()) { + EXPECT_THAT(err_msg, HasSubstr(expected_err.message())) + << "Did not find expected compile time error"; + } + return; + } + + // Compilation should have succeeded for evaluation tests + ASSERT_TRUE(compile_result_.IsValid()) + << "Compilation has validation errors in " << full_test_name << ": " + << compile_result_.FormatIssues(); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + runtime_->CreateProgram(std::make_unique( + *compile_result_.GetAst()))); + + // Parse Inputs and evaluate them + google::protobuf::Arena arena; + Activation activation; + for (const auto& [var_name, 
input_val] : test.input()) { + auto val_or = EvaluateInputValue( + input_val, *input_evaluator_, pool_.get(), + google::protobuf::MessageFactory::generated_factory(), &arena); + ASSERT_THAT(val_or.status(), IsOk()) + << "Failed to evaluate input '" << var_name << "' in " + << full_test_name; + activation.InsertOrAssignValue(var_name, *std::move(val_or)); + } + + // Evaluate Policy + auto eval_result_or = program->Evaluate(&arena, activation); + ASSERT_THAT(eval_result_or.status(), IsOk()) + << "Evaluation failed in " << full_test_name; + cel::Value actual_val = *eval_result_or; + + ASSERT_OK_AND_ASSIGN( + auto matcher, MakeExpectedValueMatcher( + output, *input_evaluator_, pool_.get(), + google::protobuf::MessageFactory::generated_factory(), &arena)); + + // Apply matcher to the output of evaluation + EXPECT_THAT(actual_val, matcher) << "Test failed: " << full_test_name; + } + + private: + std::string suite_name_; + std::unique_ptr compiler_; + std::unique_ptr runtime_; + std::shared_ptr policy_source_; + CelPolicyValidationResult compile_result_; + std::shared_ptr pool_; + std::shared_ptr input_evaluator_; + bool expect_compile_fail_; +}; + +class CelPolicyTest : public testing::Test { + public: + explicit CelPolicyTest(std::shared_ptr runner, + cel::expr::conformance::test::TestCase test_case, + std::string full_test_name, bool skip) + : runner_(std::move(runner)), + test_case_(std::move(test_case)), + full_test_name_(std::move(full_test_name)), + skip_(skip) {} + + void TestBody() override { + if (skip_) { + GTEST_SKIP() << "Skipping test: " << full_test_name_; + } + EXPECT_NO_FATAL_FAILURE(runner_->RunTest(test_case_, full_test_name_)); + } + + private: + std::shared_ptr runner_; + cel::expr::conformance::test::TestCase test_case_; + std::string full_test_name_; + bool skip_; +}; + + +absl::Status RegisterTestSuite( + const std::filesystem::path& dir_path, const std::string& suite_name, + const std::shared_ptr& input_evaluator, + const std::shared_ptr& pool, + 
const std::vector& skip_tests) { + // Check if the entire suite should be skipped (prefix match) + for (const auto& skip : skip_tests) { + if (suite_name == skip || + absl::StartsWith(suite_name, absl::StrCat(skip, "/"))) { + std::cout << "[ SKIPPED SUITE ] " << suite_name << std::endl; + return absl::OkStatus(); + } + } + + std::filesystem::path policy_path = dir_path / "policy.yaml"; + std::filesystem::path tests_path = dir_path / "tests.yaml"; + bool is_yaml = true; + if (!std::filesystem::exists(tests_path)) { + tests_path = dir_path / "tests.textproto"; + is_yaml = false; + } + std::filesystem::path config_path = dir_path / "config.yaml"; + + if (!std::filesystem::exists(policy_path) || + !std::filesystem::exists(tests_path)) { + // Not a valid test suite, assume it's a directory we don't care about. + return absl::OkStatus(); + } + + // Parse Environment Config + cel::Config config; + if (std::filesystem::exists(config_path)) { + std::string config_content; + CEL_RETURN_IF_ERROR( + cel::internal::GetFileContents(config_path.string(), &config_content)); + CEL_ASSIGN_OR_RETURN(config, cel::EnvConfigFromYaml(config_content)); + } + + // Enable default extensions (optional, bindings) in the config + CEL_RETURN_IF_ERROR(config.AddExtensionConfig( + "optional", cel::Config::ExtensionConfig::kLatest)); + CEL_RETURN_IF_ERROR(config.AddExtensionConfig( + "bindings", cel::Config::ExtensionConfig::kLatest)); + + // Set up compiler & runtime environments + cel::Env env; + env.SetDescriptorPool(pool); + cel::RegisterStandardExtensions(env); + env.SetConfig(config); + + cel::EnvRuntime env_runtime; + env_runtime.SetDescriptorPool(pool); + cel::RegisterStandardExtensions(env_runtime); + env_runtime.SetConfig(config); + env_runtime.mutable_runtime_options().enable_qualified_type_identifiers = + true; + + CEL_ASSIGN_OR_RETURN(auto compiler_builder, env.NewCompilerBuilder()); + compiler_builder->GetParserBuilder().GetOptions().enable_optional_syntax = + true; + + 
CEL_ASSIGN_OR_RETURN(auto compiler, compiler_builder->Build()); + + CEL_ASSIGN_OR_RETURN(auto runtime_builder, + env_runtime.CreateRuntimeBuilder()); + + // Register conformance enums + for (const auto& enum_name : + {"cel.expr.conformance.proto2.GlobalEnum", + "cel.expr.conformance.proto3.GlobalEnum", + "cel.expr.conformance.proto2.TestAllTypes.NestedEnum", + "cel.expr.conformance.proto3.TestAllTypes.NestedEnum"}) { + auto* enum_desc = pool->FindEnumTypeByName(enum_name); + if (enum_desc != nullptr) { + CEL_RETURN_IF_ERROR(cel::extensions::RegisterProtobufEnum( + runtime_builder.type_registry(), enum_desc)); + } + } + + // Register locationCode in runtime + CEL_RETURN_IF_ERROR( + (cel::UnaryFunctionAdapter:: + RegisterGlobalOverload("locationCode", LocationCode, + runtime_builder.function_registry()))); + + CEL_ASSIGN_OR_RETURN(auto runtime, std::move(runtime_builder).Build()); + + // Parse Policy + std::string policy_content; + CEL_RETURN_IF_ERROR( + cel::internal::GetFileContents(policy_path.string(), &policy_content)); + CEL_ASSIGN_OR_RETURN(auto source, + cel::NewSource(policy_content, "policy.yaml")); + auto policy_source = std::make_shared(std::move(source)); + CEL_ASSIGN_OR_RETURN(CelPolicyParseResult parse_result, + cel::ParseYamlCelPolicy(policy_source)); + if (!parse_result.IsValid()) { + return absl::InvalidArgumentError( + absl::StrCat("Failed to parse policy.yaml in ", suite_name, + "\nIssues:\n", parse_result.FormattedIssues())); + } + const CelPolicy* policy = parse_result.GetPolicy(); + + // Compile Policy (unexpected non-ok status represents a bug) + CEL_ASSIGN_OR_RETURN(CelPolicyValidationResult compile_result, + CompilePolicy(*compiler, *policy)); + + std::string tests_content; + CEL_RETURN_IF_ERROR( + cel::internal::GetFileContents(tests_path.string(), &tests_content)); + TestSuite test_suite; + if (is_yaml) { + CEL_ASSIGN_OR_RETURN(test_suite, + cel::test::ParsePolicyTestSuiteYaml(tests_content)); + } else { + if 
(!google::protobuf::TextFormat::ParseFromString(tests_content, &test_suite)) { + return absl::InvalidArgumentError( + absl::StrCat("Failed to parse text proto in ", tests_path.string())); + } + } + + auto runner = std::make_shared( + suite_name, std::move(compiler), std::move(runtime), + std::move(policy_source), std::move(compile_result), pool, + input_evaluator, absl::StrContains(suite_name, "compile_errors")); + + for (const auto& section : test_suite.sections()) { + std::string section_name = section.name(); + for (const auto& test : section.tests()) { + std::string test_name = test.name(); + std::string full_test_name = + absl::StrCat(suite_name, "/", section_name, "/", test_name); + + bool skip = !ShouldRunTest(full_test_name, skip_tests); + + testing::RegisterTest( + suite_name.c_str(), + absl::StrCat(section_name, "/", test_name).c_str(), nullptr, + test_name.c_str(), __FILE__, __LINE__, + [runner, test, full_test_name, skip]() -> CelPolicyTest* { + return new CelPolicyTest(runner, test, full_test_name, skip); + }); + } + } + return absl::OkStatus(); +} + +void RegisterAllTests() { + std::string testdata_dir_flag = absl::GetFlag(FLAGS_testdata_dir); + std::vector skip_tests = absl::GetFlag(FLAGS_skip_tests); + + std::string abs_testdata_dir = + cel::internal::ResolveRunfilesPath(testdata_dir_flag); + ABSL_CHECK(!abs_testdata_dir.empty()) + << "Could not find testdata directory: " << testdata_dir_flag; + + auto evaluator_or = InputEvaluator::Create(GetSharedTestingDescriptorPool()); + ABSL_CHECK_OK(evaluator_or.status()) << "Failed to create input evaluator"; + std::shared_ptr evaluator = std::move(evaluator_or.value()); + + // Walk the testdata directory + std::filesystem::path testdata_path(abs_testdata_dir); + ABSL_CHECK(std::filesystem::exists(testdata_path)) + << "Testdata path does not exist: " << testdata_path; + + for (const auto& entry : + std::filesystem::recursive_directory_iterator(testdata_path)) { + if (!entry.is_directory()) { + continue; + } 
+ std::filesystem::path dir_path = entry.path(); + // Check if this directory has policy.yaml and tests.yaml (or + // tests.textproto) + if (std::filesystem::exists(dir_path / "policy.yaml") && + (std::filesystem::exists(dir_path / "tests.yaml") || + std::filesystem::exists(dir_path / "tests.textproto"))) { + std::string suite_name = absl::StrReplaceAll( + std::filesystem::relative(dir_path, testdata_path).string(), + {{"\\", "/"}}); + + ABSL_CHECK_OK(RegisterTestSuite(dir_path, suite_name, evaluator, + GetSharedTestingDescriptorPool(), + skip_tests)); + } + } +} + +} // namespace +} // namespace cel + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + cel::RegisterAllTests(); + return RUN_ALL_TESTS(); +} diff --git a/internal/BUILD b/internal/BUILD index 0ac5c4e46..6d0efab72 100644 --- a/internal/BUILD +++ b/internal/BUILD @@ -92,6 +92,7 @@ cc_library( hdrs = ["runfiles.h"], deps = [ "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@rules_cc//cc/runfiles", ], diff --git a/internal/runfiles.cc b/internal/runfiles.cc index 259e2e7ca..0e1ff045d 100644 --- a/internal/runfiles.cc +++ b/internal/runfiles.cc @@ -16,10 +16,13 @@ #include -#include "rules_cc/cc/runfiles/runfiles.h" +#include +#include "rules_cc/cc/runfiles/runfiles.h" #include "absl/log/absl_check.h" #include "absl/strings/str_cat.h" + +#include "absl/status/status.h" #include "absl/strings/string_view.h" namespace cel::internal { @@ -37,4 +40,15 @@ std::string ResolveRunfilesPath(absl::string_view path) { return runfiles->Rlocation(std::string(path)); } +absl::Status GetFileContents(absl::string_view path, std::string* out) { + std::ifstream file{std::string(path)}; + if (!file.is_open()) { + return absl::NotFoundError( + absl::StrCat("Failed to open file: ", path)); + } + out->append((std::istreambuf_iterator(file)), + std::istreambuf_iterator()); + return absl::OkStatus(); +} + } // namespace cel::internal diff --git 
a/internal/runfiles.h b/internal/runfiles.h index 643c677b4..11fdcf337 100644 --- a/internal/runfiles.h +++ b/internal/runfiles.h @@ -11,12 +11,15 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. +// +// Utilities for working with bazel runfiles. #ifndef THIRD_PARTY_CEL_CPP_INTERNAL_RUNFILES_H_ #define THIRD_PARTY_CEL_CPP_INTERNAL_RUNFILES_H_ #include +#include "absl/status/status.h" #include "absl/strings/string_view.h" namespace cel::internal { @@ -25,6 +28,9 @@ namespace cel::internal { // Intended for resolving test cases from cel-spec and cel-policy. std::string ResolveRunfilesPath(absl::string_view path); +// Read contents of a file at a resolved path to a string. +absl::Status GetFileContents(absl::string_view path, std::string* out); + } // namespace cel::internal #endif // THIRD_PARTY_CEL_CPP_INTERNAL_RUNFILES_H_ diff --git a/policy/BUILD b/policy/BUILD new file mode 100644 index 000000000..19195be2b --- /dev/null +++ b/policy/BUILD @@ -0,0 +1,239 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "cel_policy", + srcs = [ + "cel_policy.cc", + ], + hdrs = [ + "cel_policy.h", + ], + deps = [ + "//common:source", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "cel_policy_test", + srcs = ["cel_policy_test.cc"], + deps = [ + ":cel_policy", + "//common:source", + "//internal:testing", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "cel_policy_parser", + srcs = [ + "cel_policy_parse_context.cc", + "cel_policy_parse_result.cc", + ], + hdrs = [ + "cel_policy_parse_context.h", + "cel_policy_parse_result.h", + "cel_policy_parser.h", + ], + deps = [ + ":cel_policy", + "//common:source", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "yaml_policy_parser", + srcs = [ + "yaml_policy_parser.cc", + ], + hdrs = ["yaml_policy_parser.h"], + copts = ["-fexceptions"], + features = ["-use_header_modules"], + deps = [ + ":cel_policy", + ":cel_policy_parser", + "//common:source", + "//internal:status_macros", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@yaml-cpp", + ], +) + +cc_library( + name = "cel_policy_validation_result", + srcs = [ + "cel_policy_validation_result.cc", + ], + hdrs = [ + "cel_policy_validation_result.h", + ], + deps = [ + ":cel_policy", + ":cel_policy_parser", + "//common:ast", 
+ "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "compiler", + srcs = ["compiler.cc"], + hdrs = ["compiler.h"], + deps = [ + ":cel_policy", + ":cel_policy_parser", + ":cel_policy_validation_result", + "//checker:type_check_issue", + "//checker:validation_result", + "//common:ast", + "//common:ast_rewrite", + "//common:constant", + "//common:container", + "//common:decl", + "//common:expr", + "//common:format_type_name", + "//common:navigable_ast", + "//common:source", + "//common:type", + "//common:type_kind", + "//compiler", + "//internal:status_macros", + "//policy/internal:issue_reporter", + "//policy/internal:optimizer_expr_factory", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "yaml_policy_parser_test", + srcs = [ + "test_custom_yaml_policy_parser.cc", + "yaml_policy_parser_test.cc", + ], + data = [ + "//policy/testdata:policy_testdata", + ], + deps = [ + ":cel_policy", + ":cel_policy_parser", + ":yaml_policy_parser", + "//common:source", + "//internal:runfiles", + "//internal:status_macros", + "//internal:testing", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@yaml-cpp", + ], +) + +cc_test( + name = "compiler_test", + srcs = 
["compiler_test.cc"], + data = [ + "//policy/testdata:policy_testdata", + ], + deps = [ + ":cel_policy", + ":cel_policy_parser", + ":cel_policy_validation_result", + ":compiler", + ":yaml_policy_parser", + "//common:ast", + "//common:decl", + "//common:navigable_ast", + "//common:source", + "//common:type", + "//common:value", + "//common:value_testing", + "//compiler", + "//compiler:compiler_factory", + "//compiler:optional", + "//compiler:standard_library", + "//extensions:bindings_ext", + "//internal:runfiles", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//runtime", + "//runtime:activation", + "//runtime:optional_types", + "//runtime:runtime_builder", + "//runtime:runtime_options", + "//runtime:standard_runtime_builder_factory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "test_util", + testonly = True, + srcs = ["test_util.cc"], + hdrs = ["test_util.h"], + copts = ["-fexceptions"], + features = ["-use_header_modules"], + deps = [ + "//internal:status_macros", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_cel_spec//proto/cel/expr:value_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/test:suite_cc_proto", + "@com_google_protobuf//:struct_cc_proto", + "@yaml-cpp", + ], +) diff --git a/policy/cel_policy.cc b/policy/cel_policy.cc new file mode 100644 index 000000000..c2d97edeb --- /dev/null +++ b/policy/cel_policy.cc @@ -0,0 +1,273 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/cel_policy.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "common/source.h" + +namespace cel { + +namespace { + +std::string IdDebugString(CelPolicyElementId id) { + if (id == -1) { + return ""; + } + return absl::StrCat("#", id, "> "); +} + +std::string IndentBlock(absl::string_view text) { + if (text.empty()) { + return ""; + } + std::vector lines; + for (absl::string_view line : absl::StrSplit(text, '\n')) { + if (line.empty()) { + lines.push_back(""); + } else { + lines.push_back(absl::StrCat(" ", line)); + } + } + return absl::StrJoin(lines, "\n"); +} + +} // namespace + +void CelPolicySource::NoteSourcePosition(CelPolicyElementId id, + SourcePosition position) { + source_positions_[id] = position; +} + +std::optional CelPolicySource::GetSourcePosition( + CelPolicyElementId id) const { + auto it = source_positions_.find(id); + if (it == source_positions_.end()) { + return std::nullopt; + } + return it->second; +} + +std::optional CelPolicySource::GetSourceLocation( + CelPolicyElementId id) const { + auto it = source_positions_.find(id); + if (it == source_positions_.end()) { + return std::nullopt; + } + return policy_source_->GetLocation(it->second); +} + +std::string CelPolicySource::DebugString() const { + std::string result; + + // Sort the source elements in descending order of position + std::vector> sorted_positions; + for 
(const auto& pair : source_positions_) { + sorted_positions.push_back(pair); + } + std::sort(sorted_positions.begin(), sorted_positions.end(), + [](const auto& a, const auto& b) { + if (a.second == b.second) { + return a.first < b.first; + } + return a.second > b.second; + }); + + result = policy_source_->content().ToString(); + for (const auto& [id, position] : sorted_positions) { + result.insert(position, IdDebugString(id)); + } + return result; +} + +std::string ValueString::DebugString() const { + return absl::StrCat(IdDebugString(id_), "\"", value_, "\""); +} + +std::string Import::DebugString() const { + std::string result; + absl::StrAppend(&result, IdDebugString(id_), "name: ", name_.DebugString()); + return result; +} + +std::string OutputBlock::DebugString() const { + std::string result; + absl::StrAppend(&result, "output: ", output_.DebugString()); + if (explanation_.has_value()) { + absl::StrAppend(&result, "\nexplanation: ", explanation_->DebugString()); + } + return result; +} + +Match::Match(const Match& other) + : id_(other.id_), condition_(other.condition_) { + if (std::holds_alternative(other.result_)) { + result_ = std::get(other.result_); + } else if (std::holds_alternative(other.result_)) { + result_ = std::get(other.result_); + } else { + result_ = + std::make_unique(*std::get>(other.result_)); + } +} + +Match& Match::operator=(const Match& other) { + if (this != &other) { + id_ = other.id_; + condition_ = other.condition_; + if (std::holds_alternative(other.result_)) { + result_ = std::get(other.result_); + } else if (std::holds_alternative(other.result_)) { + result_ = std::get(other.result_); + } else { + result_ = std::make_unique( + *std::get>(other.result_)); + } + } + return *this; +} + +std::string Match::DebugString() const { + std::string result; + absl::StrAppend(&result, IdDebugString(id_), "match: {\n"); + if (condition_.has_value()) { + absl::StrAppend(&result, " condition: ", condition_->DebugString(), "\n"); + } + if 
(has_rule()) { + absl::StrAppend(&result, " result:\n", + IndentBlock(IndentBlock(rule().DebugString())), "\n"); + } else { + absl::StrAppend(&result, " result: {\n", + IndentBlock(IndentBlock(output_block().DebugString())), + "\n }\n"); + } + absl::StrAppend(&result, "}"); + return result; +} + +std::string Variable::DebugString() const { + std::string result; + absl::StrAppend(&result, "variable: {\n"); + absl::StrAppend(&result, " name: ", name_.DebugString(), "\n"); + absl::StrAppend(&result, " expression: ", expression_.DebugString(), "\n"); + if (description_.has_value()) { + absl::StrAppend(&result, " description: ", description_->DebugString(), + "\n"); + } + if (display_name_.has_value()) { + absl::StrAppend(&result, " display_name: ", display_name_->DebugString(), + "\n"); + } + absl::StrAppend(&result, "}"); + return result; +} + +std::string Rule::DebugString() const { + std::string result; + absl::StrAppend(&result, IdDebugString(id_), "rule: {\n"); + if (rule_id_.has_value()) { + absl::StrAppend(&result, " rule_id: ", rule_id_->DebugString(), "\n"); + } + if (description_.has_value()) { + absl::StrAppend(&result, " description: ", description_->DebugString(), + "\n"); + } + for (const Variable& variable : variables_) { + absl::StrAppend(&result, IndentBlock(variable.DebugString()), "\n"); + } + for (const Match& match : matches_) { + absl::StrAppend(&result, IndentBlock(match.DebugString()), "\n"); + } + absl::StrAppend(&result, "}"); + return result; +} + +std::string MetadataValueDebugString(std::any value) { + if (value.type() == typeid(std::monostate)) { + return "null"; + } + if (value.type() == typeid(ValueString)) { + return std::any_cast(value).DebugString(); + } + if (value.type() == typeid(bool)) { + return std::any_cast(value) ? 
"true" : "false"; + } + if (value.type() == typeid(int)) { + return absl::StrCat(std::any_cast(value)); + } + if (value.type() == typeid(std::string)) { + return std::any_cast(value); + } + return absl::StrCat("typeid: ", value.type().name()); +} + +std::string CelPolicy::DebugString() const { + std::string result; + absl::StrAppend(&result, "CelPolicy{\n"); + absl::StrAppend( + &result, + " ===========================================================\n"); + absl::StrAppend(&result, IndentBlock(IndentBlock(source_->DebugString())), + "\n"); + absl::StrAppend( + &result, + " ===========================================================\n"); + absl::StrAppend(&result, " name: ", name_.DebugString(), "\n"); + if (description_.has_value()) { + absl::StrAppend(&result, " description: ", description_->DebugString(), + "\n"); + } + if (display_name_.has_value()) { + absl::StrAppend(&result, " display_name: ", display_name_->DebugString(), + "\n"); + } + if (!metadata_.empty()) { + std::vector sorted_keys; + for (const auto& [key, _] : metadata_) { + sorted_keys.push_back(key); + } + std::sort(sorted_keys.begin(), sorted_keys.end()); + + absl::StrAppend(&result, " metadata: {\n"); + for (const auto& key : sorted_keys) { + const auto& value = metadata_.at(key); + absl::StrAppend(&result, " ", key, ": ", + MetadataValueDebugString(value), "\n"); + } + absl::StrAppend(&result, " }\n"); + } + if (!imports_.empty()) { + absl::StrAppend(&result, " imports:\n"); + for (const Import& import : imports_) { + absl::StrAppend(&result, " ", import.DebugString(), "\n"); + } + } + absl::StrAppend(&result, IndentBlock(rule_.DebugString()), "\n"); + absl::StrAppend(&result, "}"); + return result; +} + +} // namespace cel diff --git a/policy/cel_policy.h b/policy/cel_policy.h new file mode 100644 index 000000000..af8f7c977 --- /dev/null +++ b/policy/cel_policy.h @@ -0,0 +1,320 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may 
not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/log/absl_check.h" +#include "absl/strings/string_view.h" +#include "common/source.h" + +namespace cel { + +using CelPolicyElementId = int32_t; + +class CelPolicySource { + public: + explicit CelPolicySource(cel::SourcePtr policy_source) + : policy_source_(std::move(policy_source)) {} + + const Source* absl_nonnull content() const { return policy_source_.get(); } + + void NoteSourcePosition(CelPolicyElementId id, SourcePosition position); + + std::optional GetSourcePosition(CelPolicyElementId id) const; + + std::optional GetSourceLocation(CelPolicyElementId id) const; + + std::string DebugString() const; + + private: + cel::SourcePtr policy_source_; + absl::flat_hash_map source_positions_; +}; + +class ValueString { + public: + ValueString() : id_(-1) {} + + explicit ValueString(CelPolicyElementId id, absl::string_view value) + : id_(id), value_(value) {} + + CelPolicyElementId id() const { return id_; } + absl::string_view value() const { return value_; } + + std::string DebugString() const; + + private: + CelPolicyElementId id_; + std::string value_; +}; + +class Import { + public: + Import(CelPolicyElementId id, ValueString name) + : id_(id), name_(std::move(name)) {} + CelPolicyElementId id() const { return id_; 
} + const ValueString& name() const { return name_; } + + std::string DebugString() const; + + private: + CelPolicyElementId id_; + ValueString name_; +}; + +// Defines a variable that can be used in CEL expressions within the policy. +// Variables are evaluated once and stored in the activation context. +class Variable { + public: + const ValueString& name() const { return name_; } + void set_name(ValueString name) { name_ = std::move(name); } + + const ValueString& expression() const { return expression_; } + void set_expression(ValueString expression) { + expression_ = std::move(expression); + } + + std::optional description() const { return description_; } + void set_description(ValueString description) { + description_ = std::move(description); + } + + std::optional display_name() const { return display_name_; } + void set_display_name(ValueString display_name) { + display_name_ = std::move(display_name); + } + + std::string DebugString() const; + + private: + ValueString name_; + ValueString expression_; + std::optional description_; + std::optional display_name_; +}; + +class Rule; + +class OutputBlock { + public: + OutputBlock() = default; + OutputBlock(ValueString output, std::optional explanation) + : output_(std::move(output)), explanation_(std::move(explanation)) {} + + const ValueString& output() const { return output_; } + void set_output(ValueString output) { output_ = std::move(output); } + + const std::optional& explanation() const { return explanation_; } + void set_explanation(ValueString explanation) { + explanation_ = std::move(explanation); + } + + std::string DebugString() const; + + private: + ValueString output_; + std::optional explanation_; +}; + +// Defines a match condition and result. +// If the result is a Rule, it is considered a sub-rule and will be evaluated +// only if the match condition evaluates to true. 
+class Match { + public: + Match() = default; + Match(const Match& other); + Match& operator=(const Match& other); + + CelPolicyElementId id() const; + void set_id(CelPolicyElementId id); + + bool has_condition() const; + std::optional condition() const; + void set_condition(ValueString condition); + + bool has_output_block() const; + const OutputBlock& output_block() const; + OutputBlock& mutable_output_block(); + + bool has_rule() const; + const Rule& rule() const; + Rule& mutable_rule(); + + void set_result(OutputBlock result); + void set_result(std::unique_ptr result); + + std::string DebugString() const; + + private: + CelPolicyElementId id_ = -1; + std::optional condition_; + std::variant> result_; +}; + +// Rule is the body of the policy and contains a list of variables and matches. +// Variables are evaluated once and stored in the activation context. +// Matches are evaluated in order and the first match is returned. If the +// match contains a sub-rule, the sub-rule is evaluated only if the match +// condition evaluates to true. 
+class Rule { + public: + Rule() = default; + Rule(const Rule& other) = default; + + CelPolicyElementId id() const { return id_; } + void set_id(CelPolicyElementId id) { id_ = id; } + + const std::optional& rule_id() const { return rule_id_; } + void set_rule_id(ValueString rule_id) { rule_id_ = std::move(rule_id); } + + const std::optional& description() const { return description_; } + void set_description(ValueString description) { + description_ = std::move(description); + } + + const std::vector& variables() const { return variables_; } + std::vector& mutable_variables() { return variables_; } + + const std::vector& matches() const { return matches_; } + std::vector& mutable_matches() { return matches_; } + + std::string DebugString() const; + + private: + CelPolicyElementId id_ = -1; + std::optional rule_id_; + std::optional description_; + std::vector variables_; + std::vector matches_; +}; + +// CelPolicy is the top-level policy object. +// It contains a source, name, description, display name, imports, and a rule. +// The source is the CEL policy source code. +// The name, description, and display name are metadata about the policy. +// The rule is the main body of the policy. 
+class CelPolicy { + public: + explicit CelPolicy(std::shared_ptr source) + : source_(std::move(source)) {} + + CelPolicy(const CelPolicy& other) = default; + CelPolicy& operator=(const CelPolicy& other) = default; + + const CelPolicySource* absl_nullable source() const { return source_.get(); } + const std::shared_ptr& source_ptr() const { return source_; } + + const ValueString& name() const { return name_; } + void set_name(ValueString name) { name_ = std::move(name); } + + std::optional description() const { return description_; } + void set_description(ValueString description) { + description_ = std::move(description); + } + std::optional display_name() const { return display_name_; } + void set_display_name(ValueString display_name) { + display_name_ = std::move(display_name); + } + const absl::flat_hash_map& metadata() const { + return metadata_; + } + absl::flat_hash_map& mutable_metadata() { + return metadata_; + } + const std::vector& imports() const { return imports_; } + std::vector& mutable_imports() { return imports_; } + + const Rule& rule() const { return rule_; } + Rule& mutable_rule() { return rule_; } + + std::string DebugString() const; + + private: + std::shared_ptr source_; + ValueString name_; + std::optional description_; + std::optional display_name_; + absl::flat_hash_map metadata_; + std::vector imports_; + Rule rule_; +}; + +// Implementation details. 
+ +inline CelPolicyElementId Match::id() const { return id_; } +inline void Match::set_id(CelPolicyElementId id) { id_ = id; } + +inline bool Match::has_condition() const { return condition_.has_value(); } + +inline std::optional Match::condition() const { + return condition_; +} + +inline void Match::set_condition(ValueString condition) { + condition_ = std::move(condition); +} + +inline bool Match::has_output_block() const { + return std::holds_alternative(result_); +} + +inline const OutputBlock& Match::output_block() const { + ABSL_DCHECK(std::holds_alternative(result_)); + return std::get(result_); +} + +inline OutputBlock& Match::mutable_output_block() { + if (!std::holds_alternative(result_)) { + result_ = OutputBlock(); + } + return std::get(result_); +} + +inline bool Match::has_rule() const { + return std::holds_alternative>(result_); +} + +inline const Rule& Match::rule() const { + ABSL_DCHECK(std::holds_alternative>(result_)); + return *std::get>(result_); +} + +inline Rule& Match::mutable_rule() { + ABSL_DCHECK(std::holds_alternative>(result_)); + return *std::get>(result_); +} + +inline void Match::set_result(OutputBlock result) { + result_ = std::move(result); +} + +inline void Match::set_result(std::unique_ptr result) { + result_ = std::move(result); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_H_ diff --git a/policy/cel_policy_parse_context.cc b/policy/cel_policy_parse_context.cc new file mode 100644 index 000000000..66861d085 --- /dev/null +++ b/policy/cel_policy_parse_context.cc @@ -0,0 +1,49 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/cel_policy_parse_context.h" + +#include +#include +#include +#include + +#include "absl/log/absl_check.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_result.h" + +namespace cel { + +CelPolicy& CelPolicyParseContext::policy() const { + ABSL_CHECK(policy_ != nullptr) + << "CelPolicyParseContext::policy() called after GetResult()"; + return *policy_; +} + +CelPolicyParseResult CelPolicyParseContext::GetResult() { + if (policy_ != nullptr && issues_.empty()) { + return CelPolicyParseResult(std::move(policy_source_), std::move(policy_), + std::move(issues_)); + } + policy_.reset(); + return CelPolicyParseResult(std::move(policy_source_), nullptr, + std::move(issues_)); +} + +void CelPolicyParseContext::ReportError(CelPolicyElementId element_id, + std::string_view message) { + issues_.push_back(CelPolicyIssue(element_id, std::string(message))); +} + +} // namespace cel diff --git a/policy/cel_policy_parse_context.h b/policy/cel_policy_parse_context.h new file mode 100644 index 000000000..6482fa1ae --- /dev/null +++ b/policy/cel_policy_parse_context.h @@ -0,0 +1,65 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_PARSE_CONTEXT_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_PARSE_CONTEXT_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_result.h" + +namespace cel { + +// A mutable context for parsing a CelPolicy. An instance of this class is +// created for each policy parse and is passed to the parser, which is meant to +// be stateless. +// +// Parsers call methods on this class to report issues and populate the policy +// being parsed. Call GetResult() to obtain the resulting CelPolicyParseResult, +// which takes ownership of the parsed policy. Do not use the context after +// calling GetResult(). +class CelPolicyParseContext { + public: + explicit CelPolicyParseContext(std::shared_ptr policy_source) + : policy_source_(std::move(policy_source)), + policy_(std::make_unique(policy_source_)) {} + + CelPolicySource& policy_source() const { return *policy_source_; } + + // Returns the policy being parsed. It should not be used after + // calling GetResult(). + CelPolicy& policy() const; + + // The context should not be used after calling GetResult(). + CelPolicyParseResult GetResult(); + + // Reports an error for the given element with the given error message. 
+ void ReportError(CelPolicyElementId id, std::string_view message); + + CelPolicyElementId next_element_id() { return next_element_id_++; } + + private: + std::shared_ptr policy_source_; + CelPolicyElementId next_element_id_ = 0; + std::vector issues_; + std::unique_ptr policy_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_PARSE_CONTEXT_H_ diff --git a/policy/cel_policy_parse_result.cc b/policy/cel_policy_parse_result.cc new file mode 100644 index 000000000..32d6431bb --- /dev/null +++ b/policy/cel_policy_parse_result.cc @@ -0,0 +1,91 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "policy/cel_policy_parse_result.h" + +#include +#include + +#include "absl/base/nullability.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/source.h" +#include "policy/cel_policy.h" + +namespace cel { +namespace { + +absl::string_view SeverityString(CelPolicyIssue::Severity severity) { + switch (severity) { + case CelPolicyIssue::Severity::kInformation: + return "INFORMATION"; + case CelPolicyIssue::Severity::kWarning: + return "WARNING"; + case CelPolicyIssue::Severity::kError: + return "ERROR"; + case CelPolicyIssue::Severity::kDeprecated: + return "DEPRECATED"; + default: + return "SEVERITY_UNSPECIFIED"; + } +} + +} // namespace + +std::string CelPolicyIssue::ToDisplayString( + const CelPolicySource* absl_nullable source) const { + SourceLocation location; + std::string description; + std::string snippet; + if (source != nullptr) { + if (relative_position_) { + std::optional base = + source->GetSourcePosition(element_id_); + if (element_id_ == -1) { + base.emplace(0); + } + if (base) { + location = source->content() + ->GetLocation(*base + *relative_position_) + .value_or(SourceLocation{}); + } + } else { + location = + source->GetSourceLocation(element_id_).value_or(SourceLocation{}); + } + description = std::string(source->content()->description()); + snippet = source->content()->DisplayErrorLocation(location); + } + + const int display_column = location.column >= 0 ? 
location.column + 1 : -1; + + return absl::StrFormat("%s: %s:%d:%d: %s%s", SeverityString(severity_), + description, location.line, display_column, message_, + snippet); +} + +std::string CelPolicyParseResult::FormattedIssues() const { + std::string formatted_issues; + for (const CelPolicyIssue& issue : issues_) { + if (!formatted_issues.empty()) { + absl::StrAppend(&formatted_issues, "\n"); + } + absl::StrAppend(&formatted_issues, issue.ToDisplayString(*policy_source_)); + } + return formatted_issues; +} + +} // namespace cel diff --git a/policy/cel_policy_parse_result.h b/policy/cel_policy_parse_result.h new file mode 100644 index 000000000..2bf80b1ce --- /dev/null +++ b/policy/cel_policy_parse_result.h @@ -0,0 +1,105 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_PARSE_RESULT_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_PARSE_RESULT_H_ + +#include +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "common/source.h" +#include "policy/cel_policy.h" + +namespace cel { + +class CelPolicyIssue { + public: + enum class Severity { kInformation, kDeprecated, kWarning, kError }; + + CelPolicyIssue(CelPolicyElementId element_id, absl::string_view message) + : element_id_(element_id), message_(message) {} + CelPolicyIssue(CelPolicyElementId element_id, Severity severity, + absl::string_view message) + : element_id_(element_id), severity_(severity), message_(message) {} + CelPolicyIssue(CelPolicyElementId element_id, + SourcePosition relative_position, absl::string_view message) + : element_id_(element_id), + relative_position_(relative_position), + message_(message) {} + CelPolicyIssue(CelPolicyElementId element_id, + SourcePosition relative_position, Severity severity, + absl::string_view message) + : element_id_(element_id), + relative_position_(relative_position), + severity_(severity), + message_(message) {} + + std::string ToDisplayString( + const CelPolicySource* absl_nullable source) const; + std::string ToDisplayString(const CelPolicySource& source) const { + return ToDisplayString(&source); + } + + Severity severity() const { return severity_; } + absl::string_view message() const { return message_; } + + private: + CelPolicyElementId element_id_; + std::optional relative_position_; + Severity severity_ = Severity::kError; + std::string message_; +}; + +class CelPolicyParseResult { + public: + explicit CelPolicyParseResult(std::shared_ptr policy_source, + std::unique_ptr policy, + std::vector issues) + : policy_source_(std::move(policy_source)), + policy_(std::move(policy)), + 
issues_(std::move(issues)) {} + + bool IsValid() const { return policy_ != nullptr; } + + const CelPolicy* absl_nullable GetPolicy() const { return policy_.get(); } + + absl::StatusOr> ReleasePolicy() { + if (policy_ == nullptr) { + return absl::FailedPreconditionError( + "CelPolicyParseResult is empty. Check for Issues."); + } + return std::move(policy_); + } + + absl::Span GetIssues() const { return issues_; } + + std::string FormattedIssues() const; + + private: + std::shared_ptr policy_source_; + absl_nullable std::unique_ptr policy_; + std::vector issues_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_PARSE_RESULT_H_ diff --git a/policy/cel_policy_parser.h b/policy/cel_policy_parser.h new file mode 100644 index 000000000..0a11c9e68 --- /dev/null +++ b/policy/cel_policy_parser.h @@ -0,0 +1,40 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_PARSER_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_PARSER_H_ + +#include "absl/status/status.h" +#include "policy/cel_policy_parse_context.h" + +namespace cel { + +// A policy parser for a given policy format. The type `T` parameter is the +// representation of the input file format, such as `` for YAML. +// +// Parsers are intended to be stateless: all state, including the resulting +// policy and any issues encountered, should be kept in the context passed to +// the `ParsePolicy` method. 
+template +class CelPolicyParser { + public: + virtual ~CelPolicyParser() = default; + + // Parses the input and populates a CelPolicy in the context. + virtual absl::Status ParsePolicy(CelPolicyParseContext& ctx) const = 0; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_PARSER_H_ diff --git a/policy/cel_policy_test.cc b/policy/cel_policy_test.cc new file mode 100644 index 000000000..640247e7f --- /dev/null +++ b/policy/cel_policy_test.cc @@ -0,0 +1,220 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "policy/cel_policy.h" + +#include +#include +#include +#include + +#include "absl/strings/str_replace.h" +#include "common/source.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using testing::Field; +using testing::Optional; +using testing::SizeIs; + +TEST(CelPolicyBuilderTest, Build) { + CelPolicyElementId next_id = 1; + ASSERT_OK_AND_ASSIGN(SourcePtr source, NewSource("CEL\n policy\n source")); + std::shared_ptr policy_source = + std::make_shared(std::move(source)); + CelPolicy policy(policy_source); + policy.set_name(ValueString(next_id++, "test_policy")); + policy.set_description(ValueString(next_id++, "test_description")); + policy.set_display_name(ValueString(next_id++, "test_display_name")); + ValueString import1_name = ValueString(next_id++, "test_import1"); + policy.mutable_imports().push_back(Import(next_id++, import1_name)); + ValueString import2_name = ValueString(next_id++, "test_import2"); + policy.mutable_imports().push_back(Import(next_id++, import2_name)); + + Rule& rule = policy.mutable_rule(); + rule.set_id(next_id++); + rule.set_rule_id(ValueString(next_id++, "test_rule_id")); + rule.set_description(ValueString(next_id++, "test_rule_description")); + + Variable variable; + variable.set_name(ValueString(next_id++, "test_variable")); + variable.set_expression(ValueString(next_id++, "test_expression")); + variable.set_description(ValueString(next_id++, "test_variable_description")); + variable.set_display_name( + ValueString(next_id++, "test_variable_display_name")); + + Match match1; + match1.set_id(next_id++); + match1.set_condition(ValueString(next_id++, "test_condition")); + CelPolicyElementId output_id = next_id++; + CelPolicyElementId explanation_id = next_id++; + match1.set_result( + OutputBlock(ValueString(output_id, "test_result"), + ValueString(explanation_id, "test_explanation"))); + + Match match2; + match2.set_id(next_id++); + match2.set_condition(ValueString(next_id++, "test_condition2")); + + auto 
sub_rule = std::make_unique(); + sub_rule->set_id(next_id++); + sub_rule->set_rule_id(ValueString(next_id++, "sub_rule_id")); + sub_rule->set_description(ValueString(next_id++, "sub_rule_description")); + Match sub_rule_match; + sub_rule_match.set_id(next_id++); + sub_rule_match.set_condition(ValueString(next_id++, "sub_rule_condition")); + sub_rule_match.set_result( + OutputBlock(ValueString(next_id++, "sub_rule_result"), std::nullopt)); + sub_rule->mutable_matches().push_back(sub_rule_match); + + match2.set_result(std::move(sub_rule)); + + rule.mutable_variables().push_back(variable); + rule.mutable_matches().push_back(match1); + rule.mutable_matches().push_back(match2); + + EXPECT_EQ(policy.name().value(), "test_policy"); + ASSERT_TRUE(policy.description().has_value()); + EXPECT_EQ(policy.description()->value(), "test_description"); + ASSERT_TRUE(policy.display_name().has_value()); + EXPECT_EQ(policy.display_name()->value(), "test_display_name"); + + ASSERT_THAT(policy.imports(), SizeIs(2)); + + EXPECT_EQ(policy.imports()[0].name().value(), "test_import1"); + EXPECT_EQ(policy.imports()[1].name().value(), "test_import2"); + ASSERT_TRUE(policy.rule().rule_id().has_value()); + EXPECT_EQ(policy.rule().rule_id()->value(), "test_rule_id"); + ASSERT_TRUE(policy.rule().description().has_value()); + EXPECT_EQ(policy.rule().description()->value(), "test_rule_description"); + + ASSERT_THAT(policy.rule().variables(), SizeIs(1)); + + EXPECT_EQ(policy.rule().variables()[0].name().value(), "test_variable"); + EXPECT_EQ(policy.rule().variables()[0].expression().value(), + "test_expression"); + ASSERT_TRUE(policy.rule().variables()[0].description().has_value()); + EXPECT_EQ(policy.rule().variables()[0].description()->value(), + "test_variable_description"); + ASSERT_TRUE(policy.rule().variables()[0].display_name().has_value()); + EXPECT_EQ(policy.rule().variables()[0].display_name()->value(), + "test_variable_display_name"); + + ASSERT_THAT(policy.rule().matches(), SizeIs(2)); + 
+ EXPECT_EQ(policy.rule().matches()[0].condition().value().value(), + "test_condition"); + ASSERT_TRUE(policy.rule().matches()[0].has_output_block()); + EXPECT_EQ(policy.rule().matches()[0].output_block().output().value(), + "test_result"); + ASSERT_TRUE( + policy.rule().matches()[0].output_block().explanation().has_value()); + EXPECT_EQ(policy.rule().matches()[0].output_block().explanation()->value(), + "test_explanation"); + + EXPECT_EQ(policy.rule().matches()[1].condition().value().value(), + "test_condition2"); + ASSERT_TRUE(policy.rule().matches()[1].has_rule()); + ASSERT_TRUE(policy.rule().matches()[1].rule().rule_id().has_value()); + EXPECT_EQ(policy.rule().matches()[1].rule().rule_id()->value(), + "sub_rule_id"); + ASSERT_TRUE(policy.rule().matches()[1].rule().description().has_value()); + EXPECT_EQ(policy.rule().matches()[1].rule().description()->value(), + "sub_rule_description"); + ASSERT_THAT(policy.rule().matches()[1].rule().matches(), SizeIs(1)); + EXPECT_EQ(policy.rule() + .matches()[1] + .rule() + .matches()[0] + .condition() + .value() + .value(), + "sub_rule_condition"); + + std::string actual = policy.DebugString(); + EXPECT_EQ(actual, absl::StrReplaceAll(R"(CelPolicy{ + =========================================================== + CEL + policy + source + =========================================================== + name: #1> "test_policy" + description: #2> "test_description" + display_name: #3> "test_display_name" + imports: + #5> name: #4> "test_import1" + #7> name: #6> "test_import2" + #8> rule: { + rule_id: #9> "test_rule_id" + description: #10> "test_rule_description" + variable: { + name: #11> "test_variable" + expression: #12> "test_expression" + description: #13> "test_variable_description" + display_name: #14> "test_variable_display_name" + } + #15> match: { + condition: #16> "test_condition" + result: { + output: #17> "test_result" + explanation: #18> "test_explanation" + } + } + #19> match: { + condition: #20> "test_condition2" + 
result: + #21> rule: { + rule_id: #22> "sub_rule_id" + description: #23> "sub_rule_description" + #24> match: { + condition: #25> "sub_rule_condition" + result: { + output: #26> "sub_rule_result" + } + } + } + } + } + })", + {{"\n ", "\n"}})); +} + +TEST(CelPolicySourceTest, Build) { + std::string source = + "name: test_policy\n imports:\n - name: test_import\n"; + + ASSERT_OK_AND_ASSIGN(SourcePtr source_ptr, NewSource(source)); + CelPolicySource policy_source(std::move(source_ptr)); + policy_source.NoteSourcePosition(1, source.find("test_policy")); + policy_source.NoteSourcePosition(2, source.find("test_import")); + + EXPECT_THAT(policy_source.GetSourcePosition(1), Optional(6)); + EXPECT_THAT(policy_source.GetSourceLocation(1), + Optional(AllOf(Field(&SourceLocation::line, 1), + Field(&SourceLocation::column, 6)))); + EXPECT_THAT(policy_source.GetSourcePosition(2), Optional(44)); + EXPECT_THAT(policy_source.GetSourceLocation(2), + Optional(AllOf(Field(&SourceLocation::line, 3), + Field(&SourceLocation::column, 13)))); + EXPECT_EQ(policy_source.GetSourcePosition(3), std::nullopt); + EXPECT_EQ(policy_source.GetSourceLocation(3), std::nullopt); + EXPECT_EQ( + policy_source.DebugString(), + "name: #1> test_policy\n imports:\n - name: #2> test_import\n"); +} + +} // namespace +} // namespace cel diff --git a/policy/cel_policy_validation_result.cc b/policy/cel_policy_validation_result.cc new file mode 100644 index 000000000..e257f064c --- /dev/null +++ b/policy/cel_policy_validation_result.cc @@ -0,0 +1,32 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/cel_policy_validation_result.h" + +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "policy/cel_policy_parse_result.h" + +namespace cel { + +std::string CelPolicyValidationResult::FormatIssues() const { + return absl::StrJoin( + issues_, "\n", [this](std::string* out, const CelPolicyIssue& issue) { + absl::StrAppend(out, issue.ToDisplayString(source_.get())); + }); +} + +} // namespace cel diff --git a/policy/cel_policy_validation_result.h b/policy/cel_policy_validation_result.h new file mode 100644 index 000000000..bddb9a3ca --- /dev/null +++ b/policy/cel_policy_validation_result.h @@ -0,0 +1,84 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_VALIDATION_RESULT_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_VALIDATION_RESULT_H_ + +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "common/ast.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_result.h" + +namespace cel { + +// CelPolicyValidationResult holds the result of policy compilation. +// +// Policy compilation/validation errors are captured in issues. +class CelPolicyValidationResult { + public: + CelPolicyValidationResult( + std::unique_ptr ast, std::vector issues, + std::shared_ptr source = nullptr) + : ast_(std::move(ast)), + issues_(std::move(issues)), + source_(std::move(source)) {} + + explicit CelPolicyValidationResult( + std::vector issues, + std::shared_ptr source = nullptr) + : ast_(nullptr), issues_(std::move(issues)), source_(std::move(source)) {} + + // Returns true if validation succeeded and an AST is present. + bool IsValid() const { return ast_ != nullptr; } + + // Returns the AST if validation was successful. + const Ast* absl_nullable GetAst() const { return ast_.get(); } + + // Moves out and returns the AST. + absl::StatusOr> ReleaseAst() { + if (ast_ == nullptr) { + return absl::FailedPreconditionError( + "CelPolicyValidationResult is empty. Check for CelPolicyIssues."); + } + return std::move(ast_); + } + + // Returns the list of issues encountered during compilation. + absl::Span GetIssues() const { return issues_; } + + // Returns the contained policy source, if any. + const CelPolicySource* absl_nullable GetSource() const { + return source_.get(); + } + + // Returns a formatted error string of the compiled issues. 
+ std::string FormatIssues() const; + + private: + absl_nullable std::unique_ptr ast_; + std::vector issues_; + std::shared_ptr source_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_VALIDATION_RESULT_H_ diff --git a/policy/compiler.cc b/policy/compiler.cc new file mode 100644 index 000000000..7a892447c --- /dev/null +++ b/policy/compiler.cc @@ -0,0 +1,1058 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/compiler.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/nullability.h" +#include "absl/cleanup/cleanup.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "checker/type_check_issue.h" +#include "checker/validation_result.h" +#include "common/ast.h" +#include "common/ast_rewrite.h" +#include "common/constant.h" +#include "common/container.h" +#include "common/decl.h" +#include "common/expr.h" +#include "common/format_type_name.h" +#include "common/navigable_ast.h" +#include "common/source.h" +#include "common/type.h" +#include "common/type_kind.h" +#include "compiler/compiler.h" +#include 
"internal/status_macros.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_result.h" +#include "policy/cel_policy_validation_result.h" +#include "policy/internal/issue_reporter.h" +#include "policy/internal/optimizer_expr_factory.h" +#include "google/protobuf/arena.h" + +namespace cel { +namespace { + +constexpr absl::string_view kCelBlock = "cel.@block"; + +enum class RuleSemantics { + // TODO(b/506179116): will also need "aggregate" or similar concept. + kFirstMatch, + + kNotForUseWithExhaustiveSwitchStatements, +}; + +template +void AbslStringify(Sink& s, RuleSemantics semantics) { + switch (semantics) { + case RuleSemantics::kFirstMatch: + s.Append("first_match"); + return; + default: + s.Append(""); + return; + } +} + +struct EmbeddedAst { + CelPolicyElementId id; + std::unique_ptr ast; +}; + +struct CompiledVariable { + std::string ident; + EmbeddedAst ast; +}; + +struct CompiledOutputBlock { + EmbeddedAst output_ast; + cel::Type result_type; + std::optional explanation_ast; +}; + +struct CompiledRule; + +struct CompiledMatch { + using Production = + std::variant absl_nonnull, + CompiledOutputBlock>; + + CelPolicyElementId id; + std::optional condition; + Production production; +}; + +struct CompiledRule { + CelPolicyElementId id; + std::vector variables; + std::vector matches; + // Not set if cannot be determined. + std::optional result_type; +}; + +std::optional GetOutputType( + const CompiledMatch::Production& production) { + return std::visit( + [](const auto& production) -> std::optional { + if constexpr (std::is_same_v, + CompiledOutputBlock>) { + return production.result_type; + } else if constexpr (std::is_same_v, + std::unique_ptr>) { + return production->result_type; + } + return std::nullopt; + }, + production); +} + +// Internal representation of the compiled policy elements. +// +// This is used for checking the component expression before composing into the +// final AST based on the provided rule semantics. 
+class IntermediateCompiledPolicy {
+ public:
+  // Root of the compiled rule tree. The mutable accessor lets the compiler
+  // fill the rule in place.
+  CompiledRule& mutable_root_rule() { return root_rule_; }
+
+  const CompiledRule& root_rule() const { return root_rule_; }
+
+  // Policy metadata. Setters copy the viewed characters into owned strings;
+  // getters return views into those strings.
+  void set_name(absl::string_view name) { name_ = name; }
+  absl::string_view name() const { return name_; }
+  void set_display_name(absl::string_view display_name) {
+    display_name_ = display_name;
+  }
+  absl::string_view display_name() const { return display_name_; }
+  void set_description(absl::string_view description) {
+    description_ = description;
+  }
+  absl::string_view description() const { return description_; }
+
+  // Rule evaluation semantics; defaults to first-match.
+  void set_semantics(RuleSemantics semantics) { semantics_ = semantics; }
+  RuleSemantics semantics() const { return semantics_; }
+
+ private:
+  std::string name_;
+  std::string display_name_;
+  std::string description_;
+  RuleSemantics semantics_ = RuleSemantics::kFirstMatch;
+
+  CompiledRule root_rule_;
+};
+
+// Translates a type-checker issue severity into the matching policy issue
+// severity. Severities without an explicit mapping are reported as errors.
+CelPolicyIssue::Severity MapSeverity(cel::TypeCheckIssue::Severity severity) {
+  switch (severity) {
+    case cel::TypeCheckIssue::Severity::kError:
+      return CelPolicyIssue::Severity::kError;
+    case cel::TypeCheckIssue::Severity::kWarning:
+      return CelPolicyIssue::Severity::kWarning;
+    case cel::TypeCheckIssue::Severity::kDeprecated:
+      return CelPolicyIssue::Severity::kDeprecated;
+    default:
+      return CelPolicyIssue::Severity::kError;
+  }
+}
+
+// Returns true if `wrapper_kind` is the wrapper kind whose underlying
+// primitive is `primitive_kind` (e.g. kIntWrapper and kInt). Returns false
+// when `wrapper_kind` is not a wrapper kind at all.
+bool IsWrapperOf(cel::TypeKind wrapper_kind, cel::TypeKind primitive_kind) {
+  switch (wrapper_kind) {
+    case cel::TypeKind::kBoolWrapper:
+      return primitive_kind == cel::TypeKind::kBool;
+    case cel::TypeKind::kIntWrapper:
+      return primitive_kind == cel::TypeKind::kInt;
+    case cel::TypeKind::kUintWrapper:
+      return primitive_kind == cel::TypeKind::kUint;
+    case cel::TypeKind::kDoubleWrapper:
+      return primitive_kind == cel::TypeKind::kDouble;
+    case cel::TypeKind::kStringWrapper:
+      return primitive_kind == cel::TypeKind::kString;
+    case cel::TypeKind::kBytesWrapper:
+      return primitive_kind == cel::TypeKind::kBytes;
+    default:
+      return false;
+  }
+}
+
+// Normalizes a type before output-type comparison: type params and error
+// types become dyn, enums become int, and type-of-type drops its parameter.
+// Every other type is returned unchanged.
+cel::Type FilterSpecialTypes(cel::Type type) {
+  if (type.IsTypeParam()) {
+    // Free type param should not appear in the output type, but if it does,
+    // force it to dyn.
+    return DynType();
+  }
+  if (type.IsEnum()) {
+    return IntType{};
+  }
+  if (type.IsError()) {
+    return DynType();
+  }
+  if (type.IsType()) {
+    // drop parameters so all type types are compatible.
+    return TypeType{};
+  }
+  return type;
+}
+
+// Returns true if `from` is assignable to `to`.
+//
+// Slightly adjusted from the standard routine to cover some edge cases around
+// null and wrappers.
+//
+// TODO(b/522391716): try to standardize assignability checks.
+bool OutputTypeIsAssignable(cel::Type from, cel::Type to) {
+  from = FilterSpecialTypes(from);
+  to = FilterSpecialTypes(to);
+
+  const cel::TypeKind from_kind = from.kind();
+  const cel::TypeKind to_kind = to.kind();
+
+  // Any and dyn are assignable to/from everything.
+  if (from_kind == cel::TypeKind::kAny || from_kind == cel::TypeKind::kDyn ||
+      to_kind == cel::TypeKind::kAny || to_kind == cel::TypeKind::kDyn) {
+    return true;
+  }
+
+  // Wrappers auto-unwrap.
+  if (IsWrapperOf(from_kind, to_kind) || IsWrapperOf(to_kind, from_kind)) {
+    return true;
+  }
+
+  // Null is assignable to anything that is message-like.
+  if (from_kind == cel::TypeKind::kNull) {
+    switch (to_kind) {
+      case cel::TypeKind::kNull:
+      case cel::TypeKind::kStruct:
+      case cel::TypeKind::kOpaque:
+      case cel::TypeKind::kTimestamp:
+      case cel::TypeKind::kDuration:
+      case cel::TypeKind::kBytesWrapper:
+      case cel::TypeKind::kBoolWrapper:
+      case cel::TypeKind::kIntWrapper:
+      case cel::TypeKind::kUintWrapper:
+      case cel::TypeKind::kDoubleWrapper:
+      case cel::TypeKind::kStringWrapper:
+        return true;
+      default:
+        return false;
+    }
+  }
+
+  if (from_kind != to_kind) {
+    return false;
+  }
+
+  if (from.name() != to.name()) {
+    return false;
+  }
+
+  // Fetch each parameter list once; the index is unsigned to match size().
+  const auto& from_params = from.GetParameters();
+  const auto& to_params = to.GetParameters();
+  if (from_params.size() != to_params.size()) {
+    return false;
+  }
+
+  // Parameterized types are assignable only if every parameter is.
+  for (size_t i = 0; i < from_params.size(); ++i) {
+    if (!OutputTypeIsAssignable(from_params[i], to_params[i])) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
+// Returns true if the two output types are assignable in either direction.
+bool OutputTypeIsCompatible(cel::Type from, cel::Type to) {
+  // We don't handle widening like in a self-contained CEL expression, but
+  // permit some cases where one type is more specific than the other.
+  return OutputTypeIsAssignable(from, to) || OutputTypeIsAssignable(to, from);
+}
+
+// Returns true if at least one reported issue has error severity.
+bool HasErrors(const policy_internal::IssueReporter& issues) {
+  return absl::c_any_of(issues.issues(), [](const auto& issue) {
+    return issue.severity() == CelPolicyIssue::Severity::kError;
+  });
+}
+
+// Note on lifetime safety:
+//
+// The output policy will contain references to types that are owned by the
+// arena member of this class. This is safe as long as the policy compiler lives
+// as long as the output policies.
+class PolicyCompiler { + public: + explicit PolicyCompiler(policy_internal::IssueReporter* issues, + std::unique_ptr base_compiler) + : issues_(*issues), base_compiler_(std::move(base_compiler)) {} + + absl::string_view GetSourceDescription() const { + if (src_ == nullptr) { + return ""; + } + return src_->content()->description(); + } + + void AdaptTypeCheckIssues(CelPolicyElementId id, const ValidationResult& r) { + const Source* source = r.GetSource(); + + for (const auto& iss : r.GetIssues()) { + std::optional offset; + if (source != nullptr) { + offset = source->GetPosition(iss.location()); + } + if (offset.has_value()) { + issues_.ReportOffsetIssue(id, offset.value(), + MapSeverity(iss.severity()), iss.message()); + continue; + } + issues_.ReportIssue(id, MapSeverity(iss.severity()), iss.message()); + } + } + + absl::StatusOr CompileOutputBlock( + const cel::OutputBlock& output_block, const Compiler* env) { + CompiledOutputBlock output; + CEL_ASSIGN_OR_RETURN(auto output_validation, + env->Compile(output_block.output().value(), + GetSourceDescription(), &arena_)); + AdaptTypeCheckIssues(output_block.output().id(), output_validation); + + cel::Type result_type = DynType(); + if (output_validation.IsValid()) { + CEL_ASSIGN_OR_RETURN(auto ast, output_validation.ReleaseAst()); + auto root_expr_id = ast->root_expr().id(); + output.output_ast = + EmbeddedAst{output_block.output().id(), std::move(ast)}; + if (auto it = output_validation.GetResolvedTypeMap().find(root_expr_id); + it != output_validation.GetResolvedTypeMap().end()) { + result_type = it->second; + } + } + if (output_block.explanation().has_value()) { + CEL_ASSIGN_OR_RETURN(auto explanation_validation, + env->Compile(output_block.explanation()->value(), + GetSourceDescription(), &arena_)); + AdaptTypeCheckIssues(output_block.explanation()->id(), + explanation_validation); + if (explanation_validation.IsValid()) { + CEL_ASSIGN_OR_RETURN(auto ast, explanation_validation.ReleaseAst()); + if 
(ast->GetReturnType().primitive() != PrimitiveType::kString) { + issues_.ReportError(output_block.explanation()->id(), + "explanation must evaluate to string"); + } else { + output.explanation_ast = + EmbeddedAst{output_block.explanation()->id(), std::move(ast)}; + } + } + } + output.result_type = result_type; + return output; + } + + absl::Status CompileMatch(const Match& match, const Compiler* env, + CompiledRule* out) { + CompiledMatch c_match; + c_match.id = match.id(); + if (match.condition().has_value()) { + CEL_ASSIGN_OR_RETURN(auto validation, + env->Compile(match.condition()->value(), + GetSourceDescription(), &arena_)); + AdaptTypeCheckIssues(match.condition()->id(), validation); + if (validation.IsValid()) { + CEL_ASSIGN_OR_RETURN(auto ast, validation.ReleaseAst()); + if (ast->GetReturnType().primitive() != PrimitiveType::kBool) { + issues_.ReportError(match.condition()->id(), + "condition must evaluate to bool"); + } + c_match.condition = + EmbeddedAst{match.condition()->id(), std::move(ast)}; + } + } + + if (match.has_output_block()) { + CEL_ASSIGN_OR_RETURN(c_match.production, + CompileOutputBlock(match.output_block(), env)); + } else if (match.has_rule()) { + auto rule = std::make_unique(); + CEL_RETURN_IF_ERROR(CompileRule(match.rule(), env, rule.get())); + c_match.production = std::move(rule); + } else { + issues_.ReportError(match.id(), "match must specify an output or rule"); + } + out->matches.push_back(std::move(c_match)); + return absl::OkStatus(); + } + + absl::Status CompileRule(const Rule& rule, const cel::Compiler* env, + CompiledRule* out) { + out->id = rule.id(); + std::unique_ptr buf; + + absl::flat_hash_set seen_variables; + for (const auto& variable : rule.variables()) { + std::string name(variable.name().value()); + if (!seen_variables.insert(name).second) { + issues_.ReportError( + variable.expression().id(), + absl::StrCat("overlapping identifier for name 'variables.", name, + "'")); + continue; + } + std::string ident = 
absl::StrCat("variables.", name); + CEL_ASSIGN_OR_RETURN(auto validation, + env->Compile(variable.expression().value(), + GetSourceDescription(), &arena_)); + AdaptTypeCheckIssues(variable.expression().id(), validation); + if (!validation.IsValid()) { + continue; + } + CEL_ASSIGN_OR_RETURN(auto ast, validation.ReleaseAst()); + cel::Type result_type = DynType(); + + if (auto it = validation.GetResolvedTypeMap().find(ast->root_expr().id()); + it != validation.GetResolvedTypeMap().end()) { + result_type = it->second; + } + out->variables.push_back(CompiledVariable{ + ident, + EmbeddedAst{variable.expression().id(), std::move(ast)}, + }); + auto next = env->ToBuilder(); + auto status = next->GetCheckerBuilder().AddOrReplaceVariable( + MakeVariableDecl(ident, result_type)); + if (!status.ok()) { + issues_.ReportError(variable.expression().id(), status.message()); + continue; + } + CEL_ASSIGN_OR_RETURN(buf, next->Build()); + env = buf.get(); + } + + std::optional overall_type; + for (const auto& match : rule.matches()) { + CEL_RETURN_IF_ERROR(CompileMatch(match, env, out)); + if (!overall_type.has_value()) { + overall_type = GetOutputType(out->matches.back().production); + continue; + } + + if (std::optional match_type = + GetOutputType(out->matches.back().production); + match_type.has_value()) { + if (!OutputTypeIsCompatible(*match_type, *overall_type)) { + issues_.ReportError( + match.id(), + absl::StrCat("incompatible output types: block has output type ", + FormatTypeName(*match_type), + ", but previous outputs have type ", + FormatTypeName(*overall_type))); + } + } + } + + out->result_type = overall_type; + return absl::OkStatus(); + } + + absl::Status CompilePolicy(const CelPolicy& policy, + IntermediateCompiledPolicy* out) { + src_ = policy.source(); + out->set_semantics(RuleSemantics::kFirstMatch); + out->set_name(policy.name().value()); + out->set_display_name( + policy.display_name().value_or(ValueString{}).value()); + 
out->set_description(policy.description().value_or(ValueString{}).value()); + + return CompileRule(policy.rule(), base_compiler_.get(), + &out->mutable_root_rule()); + } + + private: + google::protobuf::Arena arena_; + const CelPolicySource* absl_nullable src_; + policy_internal::IssueReporter& issues_; + std::unique_ptr base_compiler_; +}; + +bool IsExhaustive(const CompiledRule& rule); + +class FirstMatchComposer { + public: + FirstMatchComposer(const IntermediateCompiledPolicy& icp, + const Compiler& compiler, + policy_internal::IssueReporter& issues) + : issues_(issues), icp_(icp), compiler_(compiler) {} + + absl::Status Compose(); + + bool success() const { return ast_ != nullptr; } + + std::unique_ptr ReleaseAst() { return std::move(ast_); } + + private: + using VariableScope = absl::flat_hash_map; + + std::optional ResolvePolicyVariable(absl::string_view reference); + + absl::flat_hash_map ResolveBlockIndexes(const Ast& ast); + + bool CheckMatchStructure(const CompiledRule& rule); + + // Returns true if already optional wrapped. + absl::StatusOr ComposeRule(const CompiledRule& rule, Expr& init, + Expr& insertion_expr); + + // returns true if already optional wrapped. 
+ absl::StatusOr ComposeProduction( + const CompiledRule& rule, const CompiledMatch::Production& production, + Expr& init, Expr& insertion_expr); + + void MapVariables(Ast& ast); + + void ComposeRuleVariables(const CompiledRule& rule, Expr& init, + Expr& insertion_expr); + + policy_internal::IssueReporter& issues_; + OptimizerExprFactory factory_; + const IntermediateCompiledPolicy& icp_; + const Compiler& compiler_; + std::vector scopes_; + bool optionalize_ = false; + std::unique_ptr ast_; +}; + +absl::Status FirstMatchComposer::Compose() { + ABSL_DCHECK(icp_.semantics() == RuleSemantics::kFirstMatch); + + factory_.mutable_ast().mutable_root_expr() = factory_.NewCall( + "cel.@block", factory_.NewList(), factory_.NewUnspecified()); + auto& block_init_list = factory_.mutable_ast() + .mutable_root_expr() + .mutable_call_expr() + .mutable_args()[0]; + auto& insertion_expr = factory_.mutable_ast() + .mutable_root_expr() + .mutable_call_expr() + .mutable_args()[1]; + optionalize_ = !IsExhaustive(icp_.root_rule()); + if (!CheckMatchStructure(icp_.root_rule())) { + return absl::OkStatus(); + } + CEL_ASSIGN_OR_RETURN( + bool optional_wrapped, + ComposeRule(icp_.root_rule(), block_init_list, insertion_expr)); + + if (optional_wrapped != optionalize_) { + return absl::InternalError( + "composition failed to handle non-exhaustive rules"); + } + + CEL_ASSIGN_OR_RETURN(cel::ValidationResult result, + compiler_.GetTypeChecker().Check(factory_.ast())); + if (!result.IsValid()) { + for (const auto& iss : result.GetIssues()) { + issues_.ReportError(icp_.root_rule().id, iss.message()); + } + return absl::OkStatus(); + } + + CEL_ASSIGN_OR_RETURN(ast_, result.ReleaseAst()); + + return absl::OkStatus(); +} + +bool IsTriviallyTrueCondition(const CompiledMatch& match) { + if (!match.condition.has_value() || match.condition->ast == nullptr) { + return true; + } + const cel::Expr& expr = match.condition->ast->root_expr(); + if (expr.has_const_expr()) { + const cel::Constant& const_expr = 
expr.const_expr(); + if (const_expr.has_bool_value() && const_expr.bool_value()) { + return true; + } + } + return false; +} + +bool IsExhaustive(const CompiledRule& rule); + +bool IsExhaustive(const CompiledMatch& match) { + if (std::holds_alternative(match.production)) { + return true; + } + + const auto* nested_rule_ptr = + std::get_if>(&match.production); + ABSL_DCHECK(nested_rule_ptr != nullptr); + const CompiledRule& nested_rule = **nested_rule_ptr; + return IsExhaustive(nested_rule); +} + +bool IsExhaustive(const CompiledRule& rule) { + if (rule.matches.empty()) { + // Validation should fail, but generalization would be false. + return false; + } + bool has_default = false; + for (const auto& match : rule.matches) { + if (IsTriviallyTrueCondition(match) && IsExhaustive(match)) { + // If this isn't the last match in the rule, it should get flagged + // during validation since it means there are trivially unreachable + // matches. + has_default = true; + } + if (!IsTriviallyTrueCondition(match) && !IsExhaustive(match)) { + // There is a nested rule that might return an optional.none(). + return false; + } + } + // Otherwise, everything in this branch is exhaustive so we can defer + // wrapping. 
+ return has_default; +} + +bool FirstMatchComposer::CheckMatchStructure(const CompiledRule& rule) { + if (rule.matches.empty()) { + issues_.ReportError(rule.id, "rule does not specify match conditions"); + return false; + } + + bool valid = true; + bool seen_trivially_true = false; + + for (const auto& match : rule.matches) { + if (seen_trivially_true) { + if (std::holds_alternative(match.production)) { + issues_.ReportError(match.id, "match creates unreachable outputs"); + } else if (std::holds_alternative>( + match.production)) { + issues_.ReportError(match.id, "rule creates unreachable outputs"); + } + valid = false; + } + + if (IsTriviallyTrueCondition(match) && IsExhaustive(match)) { + seen_trivially_true = true; + } + + if (auto* nested_rule = + std::get_if>(&match.production); + nested_rule != nullptr) { + ABSL_DCHECK(*nested_rule != nullptr); + if (!CheckMatchStructure(**nested_rule)) { + valid = false; + } + } + } + + return valid; +} + +std::optional FirstMatchComposer::ResolvePolicyVariable( + absl::string_view reference) { + for (auto scope_iter = scopes_.rbegin(); scope_iter != scopes_.rend(); + ++scope_iter) { + if (auto it = scope_iter->find(reference); it != scope_iter->end()) { + return it->second; + } + } + return std::nullopt; +} + +class IndexRewrite : public AstRewriterBase { + public: + explicit IndexRewrite(absl::flat_hash_map expr_id_to_index, + OptimizerExprFactory& factory) + : expr_id_to_index_(std::move(expr_id_to_index)), factory_(factory) {} + + bool PreVisitRewrite(Expr& e) override { + if (auto it = expr_id_to_index_.find(e.id()); + it != expr_id_to_index_.end()) { + e.mutable_ident_expr().set_name(absl::StrCat("@index", it->second)); + factory_.RecordReplacement(e.id(), e); + return true; + } + return false; + } + + private: + absl::flat_hash_map expr_id_to_index_; + OptimizerExprFactory& factory_; +}; + +absl::StatusOr FirstMatchComposer::ComposeRule(const CompiledRule& rule, + Expr& init, + Expr& insertion_expr) { + 
scopes_.emplace_back(); + auto pop_scope = absl::MakeCleanup([this]() { scopes_.pop_back(); }); + ComposeRuleVariables(rule, init, insertion_expr); + Expr* insertion_point = &insertion_expr; + const bool has_default = IsTriviallyTrueCondition(rule.matches.back()); + const bool needs_wrap = !IsExhaustive(rule); + size_t end = rule.matches.size() - (has_default ? 1 : 0); + for (size_t i = 0; i < end; i++) { + const auto& match = rule.matches[i]; + if (IsTriviallyTrueCondition(match) && IsExhaustive(match)) { + return absl::InternalError("detected unreachable match after validation"); + } + + Expr production; + CEL_ASSIGN_OR_RETURN( + bool is_wrapped, + ComposeProduction(rule, match.production, init, production)); + if (needs_wrap && !is_wrapped) { + production = factory_.NewCall("optional.of", std::move(production)); + } + + if (!IsTriviallyTrueCondition(match)) { + Ast condition = *match.condition->ast; + MapVariables(condition); + factory_.StartCopyContext(); + auto copy = factory_.Copy(condition.root_expr()); + auto source_info = factory_.RemapSourceInfo(condition.source_info()); + factory_.MergeSourceInfo(source_info); + *insertion_point = factory_.NewCall("_?_:_", std::move(copy)); + insertion_point->mutable_call_expr().mutable_args().push_back( + std::move(production)); + ABSL_DCHECK(!(!needs_wrap && is_wrapped)) + << "unexpected wrapping in exhaustive policy."; + insertion_point = &insertion_point->mutable_call_expr().add_args(); + continue; + } + + if (!is_wrapped) { + return absl::InternalError( + "composition failed. expected optional wrapped rule but got a plain " + "value"); + } + auto fn = needs_wrap ? 
"or" : "orValue"; + *insertion_point = factory_.NewMemberCall(fn, std::move(production)); + insertion_point = &insertion_point->mutable_call_expr().add_args(); + } + + if (has_default) { + const auto& match = rule.matches.back(); + Expr production; + CEL_ASSIGN_OR_RETURN( + bool is_wrapped, + ComposeProduction(rule, match.production, init, production)); + if (needs_wrap && !is_wrapped) { + production = factory_.NewCall("optional.of", std::move(production)); + } + *insertion_point = std::move(production); + ABSL_DCHECK(!(!needs_wrap && is_wrapped)) + << "unexpected wrapping in exhaustive policy."; + + return needs_wrap; + } + + // Otherwise, we fell through a non-exhaustive rule. + *insertion_point = factory_.NewCall("optional.none"); + return true; +} + +absl::StatusOr FirstMatchComposer::ComposeProduction( + const CompiledRule& rule, const CompiledMatch::Production& production, + Expr& init, Expr& insertion_expr) { + if (auto* nested_rule = + std::get_if>(&production); + nested_rule != nullptr) { + return ComposeRule(**nested_rule, init, insertion_expr); + } + auto* output = std::get_if(&production); + if (output == nullptr) { + return absl::InternalError("unexpected rule production type"); + } + const EmbeddedAst& output_ast = output->output_ast; + Ast ast = *output_ast.ast; + MapVariables(ast); + factory_.StartCopyContext(); + Expr to_insert = factory_.Copy(ast.root_expr()); + auto source_info = factory_.RemapSourceInfo(ast.source_info()); + factory_.MergeSourceInfo(source_info); + insertion_expr = std::move(to_insert); + + return false; +} + +absl::flat_hash_map FirstMatchComposer::ResolveBlockIndexes( + const Ast& ast) { + absl::flat_hash_map out; + for (auto it = ast.reference_map().begin(); it != ast.reference_map().end(); + it++) { + const Reference& ref = it->second; + if (!it->second.overload_id().empty()) { + continue; + } + if (!absl::StartsWith(ref.name(), "variable")) { + continue; + } + if (auto index = ResolvePolicyVariable(ref.name()); 
index.has_value()) { + out[it->first] = *index; + } + } + return out; +} + +void FirstMatchComposer::MapVariables(Ast& ast) { + absl::flat_hash_map edit_map = ResolveBlockIndexes(ast); + IndexRewrite rewriter(std::move(edit_map), factory_); + AstRewrite(ast.mutable_root_expr(), rewriter); +} + +void FirstMatchComposer::ComposeRuleVariables(const CompiledRule& rule, + Expr& init, + Expr& insertion_expr) { + for (const auto& variable : rule.variables) { + Ast ast = *variable.ast.ast; + MapVariables(ast); + factory_.StartCopyContext(); + auto insertion = factory_.Copy(ast.root_expr()); + // TODO(b/506179116): apply the position offsets here. + auto info = factory_.RemapSourceInfo(ast.source_info()); + ABSL_DCHECK(init.has_list_expr()); + int index = init.mutable_list_expr().elements().size(); + init.mutable_list_expr().mutable_elements().push_back( + factory_.NewListElement(std::move(insertion))); + scopes_.back()[variable.ident] = index; + } +} + +bool HasComprehensionParent(const NavigableAstNode& node) { + const NavigableAstNode* curr = &node; + while (curr != nullptr) { + if (curr->node_kind() == NodeKind::kComprehension) { + return true; + } + curr = curr->parent(); + } + return false; +} + +// Unnester implementation. +class Unnester { + public: + Unnester(Ast ast, int height, policy_internal::IssueReporter& issues) + : factory_(std::move(ast)), height_(height), issues_(issues) {} + + // Run the unnesting. + // The class cannot be reused after this is called. + absl::StatusOr Unnest() { + if (height_ > 0) { + CEL_RETURN_IF_ERROR(Slice()); + } + CEL_RETURN_IF_ERROR(Cleanup()); + return std::move(factory_.mutable_ast()); + } + + private: + // The core unnest routine. + absl::Status Slice(); + // Fixup the AST post-unnesting. 
+ absl::Status Cleanup(); + + void ReportErrorAtId(int64_t id, absl::string_view message); + + OptimizerExprFactory factory_; + int height_; + policy_internal::IssueReporter& issues_; +}; + +class UnnestRewriter : public AstRewriterBase { + public: + explicit UnnestRewriter(OptimizerExprFactory& f, Expr& block_list_expr, + absl::Span cuts) + : factory_(f), cuts_(cuts), block_list_expr_(block_list_expr) {} + + bool PostVisitRewrite(Expr& expr) override { + using std::swap; + // Post order so we always see children before parents. + // No need to copy metadata since we're only moving exprs or minting + // new ones. + if (absl::c_contains(cuts_, expr.id())) { + size_t idx = block_list_expr_.list_expr().elements().size(); + Expr value = factory_.NewIdent(absl::StrCat("@index", idx)); + factory_.RecordReplacement(expr.id(), value, /*keep_metadata=*/true); + swap(value, expr); + block_list_expr_.mutable_list_expr().mutable_elements().push_back( + factory_.NewListElement(std::move(value))); + return true; + } + return false; + } + + private: + OptimizerExprFactory& factory_; + absl::Span cuts_; + Expr& block_list_expr_; +}; + +absl::Status Unnester::Slice() { + Expr& root = factory_.mutable_ast().mutable_root_expr(); + if (root.call_expr().function() != kCelBlock || + root.call_expr().args().size() != 2 || + !root.call_expr().args()[0].has_list_expr()) { + return absl::InternalError("malformed AST detected during unnesting"); + } + // Two passes, we identify the slice points (bottom up), then cut + // and paste the leaves into the block list. + NavigableAst nav_ast = NavigableAst::Build(factory_.ast().root_expr()); + + ABSL_DCHECK(nav_ast.IdsAreUnique()); + bool can_cut = true; + std::vector cuts; + for (const NavigableAstNode& node : nav_ast.Root().DescendantsPostorder()) { + // Subsequent cuts will be height_ + 1 in the block, indices. Within the + // error margin we specified. 
+ if (node.height() % height_ == 0) { + if (HasComprehensionParent(node)) { + ReportErrorAtId( + node.expr()->id(), + absl::StrCat( + "cannot unnest AST due to comprehension. cannot accommodate " + "height limit of ", + height_)); + can_cut = false; + continue; + } + if (&node == &nav_ast.Root()) { + // If evenly divisible by height, don't cut since it will net a taller + // AST. + continue; + } + cuts.push_back(node.expr()->id()); + } + } + + if (!can_cut || cuts.empty()) { + return absl::OkStatus(); + } + + Expr& block_list_expr = root.mutable_call_expr().mutable_args()[0]; + Expr& insertion_expr = root.mutable_call_expr().mutable_args()[1]; + + UnnestRewriter rewriter(factory_, block_list_expr, cuts); + AstRewrite(insertion_expr, rewriter); + + return absl::OkStatus(); +} + +absl::Status Unnester::Cleanup() { + using std::swap; + + const auto& ast = factory_.ast(); + if (ast.root_expr().call_expr().function() != kCelBlock || + ast.root_expr().call_expr().args().size() != 2 || + !ast.root_expr().call_expr().args()[0].has_list_expr()) { + return absl::InternalError("malformed AST detected during unnesting"); + } + if (ast.root_expr().call_expr().args()[0].list_expr().elements().empty()) { + Expr value = std::move(factory_.mutable_ast() + .mutable_root_expr() + .mutable_call_expr() + .mutable_args()[1]); + factory_.mutable_ast().mutable_root_expr() = std::move(value); + } + + return absl::OkStatus(); +} + +void Unnester::ReportErrorAtId(int64_t id, absl::string_view message) { + int32_t position = 0; + auto it = factory_.ast().source_info().positions().find(id); + if (it != factory_.ast().source_info().positions().end()) { + position = it->second; + } + issues_.ReportError(-1, position, message); +} +} // namespace + +// Compiles a CEL policy using the provided CEL compiler as a base environment. 
+absl::StatusOr CompilePolicy( + const Compiler& compiler, const CelPolicy& policy, + const CompilePolicyOptions& options) { + policy_internal::IssueReporter issues; + if (options.unnesting_height_limit != 0 && + options.unnesting_height_limit < 2) { + return absl::InvalidArgumentError( + "unnesting_height_limit must be at least 2"); + } + auto builder = compiler.ToBuilder(); + ExpressionContainer cont; + for (const auto& import : policy.imports()) { + auto status = cont.AddAbbreviation(import.name().value()); + if (!status.ok()) { + issues.ReportError( + import.name().id(), + absl::StrCat("'", import.name().value(), "': ", status.message())); + } + } + + builder->GetCheckerBuilder().SetExpressionContainer(cont); + CEL_ASSIGN_OR_RETURN(auto base_compiler, builder->Build()); + + PolicyCompiler policy_compiler(&issues, std::move(base_compiler)); + + IntermediateCompiledPolicy icp; + CEL_RETURN_IF_ERROR(policy_compiler.CompilePolicy(policy, &icp)); + + if (HasErrors(issues)) { + return CelPolicyValidationResult(issues.ReleaseIssues(), + policy.source_ptr()); + } + + CEL_ASSIGN_OR_RETURN(base_compiler, builder->Build()); + switch (icp.semantics()) { + case RuleSemantics::kFirstMatch: { + FirstMatchComposer composer(icp, *base_compiler, issues); + CEL_RETURN_IF_ERROR(composer.Compose()); + if (!composer.success()) { + return CelPolicyValidationResult(issues.ReleaseIssues(), + policy.source_ptr()); + } + + auto ast = composer.ReleaseAst(); + Unnester unnester(std::move(*ast), options.unnesting_height_limit, + issues); + CEL_ASSIGN_OR_RETURN(Ast unnested_ast, unnester.Unnest()); + + if (HasErrors(issues)) { + return CelPolicyValidationResult(issues.ReleaseIssues(), + policy.source_ptr()); + } + + return CelPolicyValidationResult( + std::make_unique(std::move(unnested_ast)), {}, + policy.source_ptr()); + } + default: + return absl::UnimplementedError( + absl::StrCat("Unsupported RuleSemantics: ", icp.semantics())); + } +} + +} // namespace cel diff --git 
a/policy/compiler.h b/policy/compiler.h new file mode 100644 index 000000000..0187bd1a2 --- /dev/null +++ b/policy/compiler.h @@ -0,0 +1,50 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_POLICY_COMPILER_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_COMPILER_H_ + +#include "absl/status/statusor.h" +#include "compiler/compiler.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_validation_result.h" + +namespace cel { + +struct CompilePolicyOptions { + // If 0, unnesting is disabled. If 2 or greater, the compiler will attempt + // to unnest rule branches at the specified height; other values return + // InvalidArgument. The final AST may exceed it by a small, fixed margin. + // + // To avoid slicing comprehensions, subexpressions within comprehensions + // are not eligible for unnesting. If the height limit cannot be accommodated, + // an error issue is reported and the returned validation result is invalid. + // + // If the AST is converted to proto, even relatively low levels of nesting + // can cause problems in serialization/deserialization. This does not apply + // if the AST is used directly by the runtime. + int unnesting_height_limit = 0; +}; + +// Compiles a CEL policy using the provided CEL compiler as a base environment. +// +// TODO(b/506179116): Implementation in progress. Functionally complete, +// but errors are not consistent with other implementations. 
+absl::StatusOr CompilePolicy( + const Compiler& compiler, const CelPolicy& policy, + const CompilePolicyOptions& options = {}); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_POLICY_COMPILER_H_ diff --git a/policy/compiler_test.cc b/policy/compiler_test.cc new file mode 100644 index 000000000..8db494b45 --- /dev/null +++ b/policy/compiler_test.cc @@ -0,0 +1,946 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/compiler.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "common/ast.h" +#include "common/decl.h" +#include "common/navigable_ast.h" +#include "common/source.h" +#include "common/type.h" +#include "common/types/message_type.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/optional.h" +#include "compiler/standard_library.h" +#include "extensions/bindings_ext.h" +#include "internal/runfiles.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_result.h" +#include "policy/cel_policy_validation_result.h" +#include "policy/yaml_policy_parser.h" +#include "runtime/activation.h" +#include 
"runtime/optional_types.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" +#include "runtime/standard_runtime_builder_factory.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::cel::test::IntValueIs; +using ::cel::test::OptionalValueIs; +using ::cel::test::OptionalValueIsEmpty; +using ::cel::test::StringValueIs; +using ::cel::test::ValueMatcher; + +constexpr absl::string_view kTestPolicyFilePath = +"_main/policy/testdata/cel_policy.yaml"; + +absl::StatusOr> BuildTestCompiler() { + CompilerOptions opts; + opts.adapt_parser_errors = true; + opts.parser_options.enable_optional_syntax = true; + CEL_ASSIGN_OR_RETURN( + auto builder, + NewCompilerBuilder(internal::GetSharedTestingDescriptorPool(), opts)); + + CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::StandardCompilerLibrary())); + CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::OptionalCompilerLibrary())); + CEL_RETURN_IF_ERROR( + builder->AddLibrary(cel::extensions::BindingsCompilerLibrary())); + + CEL_RETURN_IF_ERROR(builder->GetCheckerBuilder().AddVariable( + cel::MakeVariableDecl("x", IntType()))); + CEL_RETURN_IF_ERROR(builder->GetCheckerBuilder().AddVariable( + cel::MakeVariableDecl("y", IntType()))); + + const google::protobuf::Descriptor* descriptor = + cel::internal::GetSharedTestingDescriptorPool()->FindMessageTypeByName( + "cel.expr.conformance.proto3.TestAllTypes"); + if (descriptor == nullptr) { + return absl::InternalError("Failed to find TestAllTypes descriptor"); + } + CEL_RETURN_IF_ERROR(builder->GetCheckerBuilder().AddVariable( + cel::MakeVariableDecl("spec", cel::MessageType(descriptor)))); + + return builder->Build(); +} + +absl::StatusOr> ParsePolicyFromYaml( + absl::string_view yaml_content) { + CEL_ASSIGN_OR_RETURN(auto source, cel::NewSource(yaml_content, "test.yaml")); + + std::shared_ptr policy_source = + 
std::make_shared(std::move(source)); + CEL_ASSIGN_OR_RETURN(auto parse_result, + cel::ParseYamlCelPolicy(policy_source)); + + if (!parse_result.IsValid()) { + return absl::InvalidArgumentError("Invalid policy YAML structure"); + } + return parse_result.ReleasePolicy(); +} + +TEST(CompilerTest, SmokeTest) { + std::string contents; + std::string test_file = + cel::internal::ResolveRunfilesPath(kTestPolicyFilePath); + auto read_status = cel::internal::GetFileContents(test_file, &contents); + ASSERT_THAT(read_status, IsOk()); + + auto source_or = cel::NewSource(contents, "cel_policy.yaml"); + ASSERT_THAT(source_or.status(), IsOk()); + auto source = *std::move(source_or); + + std::shared_ptr policy_source = + std::make_shared(std::move(source)); + auto parse_result_or = cel::ParseYamlCelPolicy(policy_source); + ASSERT_THAT(parse_result_or.status(), IsOk()); + auto parse_result = *std::move(parse_result_or); + + ASSERT_TRUE(parse_result.IsValid()); + const CelPolicy* policy = parse_result.GetPolicy(); + + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + ASSERT_TRUE(result.IsValid()); +} + +TEST(CompilerTest, VariableOutOfScopeReportsError) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + match: + - condition: variables.non_existent == 10 + output: '"error"' +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatIssues(), + testing::HasSubstr("undeclared reference")); +} + +TEST(CompilerTest, ConditionNotBoolReportsError) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + match: + - condition: 10 + output: '"error"' +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto 
compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatIssues(), + testing::HasSubstr("condition must evaluate to bool")); +} + +TEST(CompilerTest, InvalidOutputExpressionReportsError) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + match: + - condition: true + output: undeclared_var +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatIssues(), + testing::HasSubstr("undeclared reference")); +} + +TEST(CompilerTest, UnreachableMatchAfterTriviallyTrueCondition) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + match: + - condition: true + output: '"first"' + - condition: true + output: '"second"' +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatIssues(), + testing::HasSubstr("match creates unreachable outputs")); +} + +TEST(CompilerTest, UnreachableMatchAfterUnconditionalExhaustiveSubRule) { + absl::string_view yaml = R"yaml( +name: dead_branch +rule: + match: + - rule: + match: + - output: 1 + - output: 2 +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatIssues(), + testing::HasSubstr("match creates unreachable outputs")); +} + +TEST(CompilerTest, RuleWithoutMatchesReportsError) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule 
+)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatIssues(), + testing::HasSubstr("rule does not specify match conditions")); +} + +TEST(CompilerTest, ExhaustivePolicyCompiles) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + variables: + - name: test_var + expression: 10 + match: + - condition: variables.test_var > 15 + output: '"greater than 15"' + - condition: variables.test_var > 5 + output: '"greater than 5"' + - output: '"default"' +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + ASSERT_TRUE(result.IsValid()); + EXPECT_TRUE(result.GetAst()->is_checked()); +} + +TEST(CompilerTest, NonExhaustivePolicyCompiles) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + variables: + - name: test_var + expression: 10 + match: + - condition: variables.test_var > 5 + output: '"greater than 5"' +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + ASSERT_TRUE(result.IsValid()); +} + +TEST(CompilerTest, PolicyReferencesEnvInput) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + match: + - condition: spec.single_int32 > 10 + output: '"greater than 10"' + - output: '"default"' +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + ASSERT_TRUE(result.IsValid()); + EXPECT_TRUE(result.GetAst()->is_checked()); +} + +struct 
EvaluationTestCase { + std::string name; + std::string yaml_policy; + struct Input { + int64_t x; + int64_t y; + } input; + ValueMatcher expected_result_matcher; +}; + +class PolicyEvaluationTest : public testing::TestWithParam { +}; + +TEST_P(PolicyEvaluationTest, Evaluate) { + const auto& test_case = GetParam(); + + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(test_case.yaml_policy)); + ASSERT_OK_AND_ASSIGN(auto validation_result, + CompilePolicy(*compiler, *policy)); + ASSERT_TRUE(validation_result.IsValid()); + ASSERT_OK_AND_ASSIGN(auto ast, validation_result.ReleaseAst()); + + // Set up runtime + cel::RuntimeOptions opts; + opts.enable_qualified_type_identifiers = true; + ASSERT_OK_AND_ASSIGN( + cel::RuntimeBuilder rt_builder, + cel::CreateStandardRuntimeBuilder( + cel::internal::GetSharedTestingDescriptorPool(), opts)); + ASSERT_THAT(cel::extensions::EnableOptionalTypes(rt_builder), IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + std::move(rt_builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto program, runtime->CreateProgram(std::move(ast))); + + // Set up activation + cel::Activation activation; + activation.InsertOrAssignValue("x", cel::IntValue(test_case.input.x)); + activation.InsertOrAssignValue("y", cel::IntValue(test_case.input.y)); + + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(cel::Value result, + program->Evaluate(&arena, activation)); + + EXPECT_THAT(result, test_case.expected_result_matcher); +} + +constexpr absl::string_view kEvalPolicyYaml = R"yaml( +name: cel_policy +rule: + id: test_rule + match: + - condition: x > 10 && y > 10 + output: '"both greater than 10"' + - condition: x > 10 + output: '"x greater than 10"' + - condition: y > 10 + output: '"y greater than 10"' + - output: '"default"' +)yaml"; + +INSTANTIATE_TEST_SUITE_P( + PolicyEvaluationTest, PolicyEvaluationTest, + testing::Values( + EvaluationTestCase{ + .name = "BothGreaterThan10", + 
.yaml_policy = std::string(kEvalPolicyYaml), + .input = {.x = 15, .y = 15}, + .expected_result_matcher = StringValueIs("both greater than 10"), + }, + EvaluationTestCase{ + .name = "XGreaterThan10", + .yaml_policy = std::string(kEvalPolicyYaml), + .input = {.x = 15, .y = 5}, + .expected_result_matcher = StringValueIs("x greater than 10"), + }, + EvaluationTestCase{ + .name = "YGreaterThan10", + .yaml_policy = std::string(kEvalPolicyYaml), + .input = {.x = 5, .y = 15}, + .expected_result_matcher = StringValueIs("y greater than 10"), + }, + EvaluationTestCase{ + .name = "Default", + .yaml_policy = std::string(kEvalPolicyYaml), + .input = {.x = 5, .y = 5}, + .expected_result_matcher = StringValueIs("default"), + }), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +constexpr absl::string_view kNonExhaustivePolicyYaml = R"yaml( +name: nested_rule4 +rule: + match: + - condition: x > 0 + rule: + match: + - condition: x < 3 + output: 1 + - condition: x < 5 + output: 2 + - condition: x < 0 + rule: + match: + - condition: x > -2 + output: 3 + - condition: x > -4 + output: 4 + - output: 5 +)yaml"; + +INSTANTIATE_TEST_SUITE_P( + NonExhaustivePolicyEvaluation, PolicyEvaluationTest, + testing::Values( + EvaluationTestCase{ + .name = "XEquals0_FallthroughTopLevel", + .yaml_policy = std::string(kNonExhaustivePolicyYaml), + .input = {.x = 0, .y = 0}, + .expected_result_matcher = OptionalValueIsEmpty(), + }, + EvaluationTestCase{ + .name = "XEquals2_MatchesFirstNested", + .yaml_policy = std::string(kNonExhaustivePolicyYaml), + .input = {.x = 2, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(1)), + }, + EvaluationTestCase{ + .name = "XEquals6_FallthroughNested", + .yaml_policy = std::string(kNonExhaustivePolicyYaml), + .input = {.x = 6, .y = 0}, + .expected_result_matcher = OptionalValueIsEmpty(), + }, + EvaluationTestCase{ + .name = "XEqualsMinus1_MatchesMinus2", + .yaml_policy = std::string(kNonExhaustivePolicyYaml), + .input = {.x = 
-1, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(3)), + }, + EvaluationTestCase{ + .name = "XEqualsMinus3_MatchesMinus4", + .yaml_policy = std::string(kNonExhaustivePolicyYaml), + .input = {.x = -3, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(4)), + }, + EvaluationTestCase{ + .name = "XEqualsMinus5_MatchesDefault", + .yaml_policy = std::string(kNonExhaustivePolicyYaml), + .input = {.x = -5, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(5)), + }), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +constexpr absl::string_view kNestedVariablePolicyYaml = R"yaml( +name: nested_rule4 +rule: + variables: + - name: i + expression: "1" + - name: j + expression: "2" + match: + - condition: x > 0 + rule: + variables: + - name: k + expression: "3" + match: + - output: "variables.i + variables.j + variables.k" + - condition: x < 0 + rule: + variables: + - name: j + expression: "5" + - name: k + expression: "4" + match: + - output: "variables.i + variables.j + variables.k" + - output: "variables.i + variables.j" +)yaml"; + +INSTANTIATE_TEST_SUITE_P( + NestedVariablePolicyEvaluation, PolicyEvaluationTest, + testing::Values( + EvaluationTestCase{ + .name = "XGreaterThan0", + .yaml_policy = std::string(kNestedVariablePolicyYaml), + .input = {.x = 1, .y = 0}, + .expected_result_matcher = IntValueIs(6), + }, + EvaluationTestCase{ + .name = "XLessThan0", + .yaml_policy = std::string(kNestedVariablePolicyYaml), + .input = {.x = -1, .y = 0}, + .expected_result_matcher = IntValueIs(10), + }, + EvaluationTestCase{ + .name = "XEquals0", + .yaml_policy = std::string(kNestedVariablePolicyYaml), + .input = {.x = 0, .y = 0}, + .expected_result_matcher = IntValueIs(3), + }), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +constexpr absl::string_view + kOptionalChainingUnconditionalSubRuleOptionalParentYaml = R"yaml( +name: optional_chaining +rule: + match: + - rule: + id: r2 + 
match: + - condition: x > 0 + output: 1 + - output: 2 + condition: x < 0 +)yaml"; + +INSTANTIATE_TEST_SUITE_P( + OptionalChainingUnconditionalSubRuleOptionalParent, PolicyEvaluationTest, + testing::Values( + EvaluationTestCase{ + .name = "XEquals1", + .yaml_policy = std::string( + kOptionalChainingUnconditionalSubRuleOptionalParentYaml), + .input = {.x = 1, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(1)), + }, + EvaluationTestCase{ + .name = "XEqualsMinus1", + .yaml_policy = std::string( + kOptionalChainingUnconditionalSubRuleOptionalParentYaml), + .input = {.x = -1, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(2)), + }), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +constexpr absl::string_view kOptionalChainingUnconditionalSubRuleYaml = R"yaml( +name: optional_chaining +rule: + id: r1 + match: + - rule: + id: r2 + match: + - condition: x > 0 + output: 1 + - output: 2 +)yaml"; + +INSTANTIATE_TEST_SUITE_P( + OptionalChainingUnconditionalSubRule, PolicyEvaluationTest, + testing::Values( + EvaluationTestCase{ + .name = "XEquals1", + .yaml_policy = + std::string(kOptionalChainingUnconditionalSubRuleYaml), + .input = {.x = 1, .y = 0}, + .expected_result_matcher = IntValueIs(1), + }, + EvaluationTestCase{ + .name = "XEqualsMinus1", + .yaml_policy = + std::string(kOptionalChainingUnconditionalSubRuleYaml), + .input = {.x = -1, .y = 0}, + .expected_result_matcher = IntValueIs(2), + }), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +constexpr absl::string_view kOptionalChainingUnconditionalComplexYaml = R"yaml( +name: optional_chaining +rule: + match: + - condition: x > 0 + rule: + match: + - rule: + match: + - condition: x == 1 + output: 1 + - output: 2 + - rule: + match: + - condition: x == -1 + output: 3 + - condition: x == -2 + output: 4 +)yaml"; + +INSTANTIATE_TEST_SUITE_P( + OptionalChainingUnconditionalComplex, PolicyEvaluationTest, + testing::Values( + 
EvaluationTestCase{ + .name = "XEquals1", + .yaml_policy = + std::string(kOptionalChainingUnconditionalComplexYaml), + .input = {.x = 1, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(1)), + }, + EvaluationTestCase{ + .name = "XEqualsMinus1", + .yaml_policy = + std::string(kOptionalChainingUnconditionalComplexYaml), + .input = {.x = -1, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(3)), + }, + EvaluationTestCase{ + .name = "XEqualsMinus2", + .yaml_policy = + std::string(kOptionalChainingUnconditionalComplexYaml), + .input = {.x = -2, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(4)), + }, + EvaluationTestCase{ + .name = "XEqualsMinus3", + .yaml_policy = + std::string(kOptionalChainingUnconditionalComplexYaml), + .input = {.x = -3, .y = 0}, + .expected_result_matcher = OptionalValueIsEmpty(), + }), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +constexpr absl::string_view kUnconditionalExhaustiveSubRuleAsLastMatchYaml = + R"yaml( +name: exhaustive_unconditional_subrule +rule: + match: + - condition: x > 0 + output: 1 + - rule: + match: + - condition: y > 0 + output: 2 + - output: 3 +)yaml"; + +INSTANTIATE_TEST_SUITE_P( + UnconditionalExhaustiveSubRuleAsLastMatch, PolicyEvaluationTest, + testing::Values( + EvaluationTestCase{ + .name = "XEquals1", + .yaml_policy = + std::string(kUnconditionalExhaustiveSubRuleAsLastMatchYaml), + .input = {.x = 1, .y = 0}, + .expected_result_matcher = IntValueIs(1), + }, + EvaluationTestCase{ + .name = "XEquals0_YEquals1", + .yaml_policy = + std::string(kUnconditionalExhaustiveSubRuleAsLastMatchYaml), + .input = {.x = 0, .y = 1}, + .expected_result_matcher = IntValueIs(2), + }, + EvaluationTestCase{ + .name = "XEquals0_YEquals0", + .yaml_policy = + std::string(kUnconditionalExhaustiveSubRuleAsLastMatchYaml), + .input = {.x = 0, .y = 0}, + .expected_result_matcher = IntValueIs(3), + }), + [](const testing::TestParamInfo& info) { + return 
info.param.name; + }); + +constexpr absl::string_view kUnconditionalNonExhaustiveSubRuleAsLastMatchYaml = + R"yaml( +name: non_exhaustive_unconditional_subrule +rule: + match: + - condition: x > 0 + output: 1 + - rule: + match: + - condition: y > 0 + output: 2 +)yaml"; + +INSTANTIATE_TEST_SUITE_P( + UnconditionalNonExhaustiveSubRuleAsLastMatch, PolicyEvaluationTest, + testing::Values( + EvaluationTestCase{ + .name = "XEquals1", + .yaml_policy = + std::string(kUnconditionalNonExhaustiveSubRuleAsLastMatchYaml), + .input = {.x = 1, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(1)), + }, + EvaluationTestCase{ + .name = "XEquals0_YEquals1", + .yaml_policy = + std::string(kUnconditionalNonExhaustiveSubRuleAsLastMatchYaml), + .input = {.x = 0, .y = 1}, + .expected_result_matcher = OptionalValueIs(IntValueIs(2)), + }, + EvaluationTestCase{ + .name = "XEquals0_YEquals0", + .yaml_policy = + std::string(kUnconditionalNonExhaustiveSubRuleAsLastMatchYaml), + .input = {.x = 0, .y = 0}, + .expected_result_matcher = OptionalValueIsEmpty(), + }), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +TEST(CompilerTest, ImportsAndAbbreviations) { + absl::string_view yaml = R"yaml( +name: imports_test +imports: + - name: cel.expr.conformance.proto3.TestAllTypes +rule: + match: + - condition: 'spec == TestAllTypes{single_int32: 10}' + output: '"matched"' + - output: '"default"' +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + auto ast_or = CompilePolicy(*compiler, *policy); + ASSERT_THAT(ast_or, IsOk()); +} + +TEST(CompilerTest, MatchWithoutProductionReportsError) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + match: + - condition: true +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, 
*policy)); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatIssues(), + testing::HasSubstr("match must specify an output or rule")); +} + +int GetAstHeight(const cel::Ast& ast) { + auto nav_ast = cel::NavigableAst::Build(ast.root_expr()); + return nav_ast.Root().height(); +} + +TEST(CompilerTest, UnnestHeightValidation) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + match: + - condition: true + output: '"ok"' +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + + CompilePolicyOptions options; + options.unnesting_height_limit = 1; + auto status_or = CompilePolicy(*compiler, *policy, options); + EXPECT_THAT(status_or.status(), + StatusIs(absl::StatusCode::kInvalidArgument, + testing::HasSubstr( + "unnesting_height_limit must be at least 2"))); + + options.unnesting_height_limit = 2; + EXPECT_THAT(CompilePolicy(*compiler, *policy, options), IsOk()); +} + +constexpr absl::string_view kDeepPolicyYaml = R"yaml( +name: deep_policy +rule: + match: + - condition: x > 0 + rule: + match: + - condition: x > 1 + rule: + match: + - condition: x > 2 + rule: + match: + - condition: x > 3 + rule: + match: + - condition: x > 4 + rule: + match: + - condition: x > 5 + output: 6 + - output: 5 + - output: 4 + - output: 3 + - output: 2 + - output: 1 + - output: 0 +)yaml"; + +TEST(CompilerTest, UnnestHeightReduction) { + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(kDeepPolicyYaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + + // Compile without unnesting + CompilePolicyOptions options_no_unnest; + options_no_unnest.unnesting_height_limit = 0; + ASSERT_OK_AND_ASSIGN(auto result_no_unnest, + CompilePolicy(*compiler, *policy, options_no_unnest)); + ASSERT_TRUE(result_no_unnest.IsValid()); + ASSERT_OK_AND_ASSIGN(auto ast_no_unnest, result_no_unnest.ReleaseAst()); + int height_no_unnest = GetAstHeight(*ast_no_unnest); + + 
CompilePolicyOptions options_unnest; + options_unnest.unnesting_height_limit = 2; + ASSERT_OK_AND_ASSIGN(auto result_unnest, + CompilePolicy(*compiler, *policy, options_unnest)); + ASSERT_TRUE(result_unnest.IsValid()); + ASSERT_OK_AND_ASSIGN(auto ast_unnest, result_unnest.ReleaseAst()); + int height_unnest = GetAstHeight(*ast_unnest); + + EXPECT_EQ(height_no_unnest, 8); + EXPECT_EQ(height_unnest, 5); + EXPECT_LT(height_unnest, height_no_unnest); +} + +TEST(CompilerTest, UnnestComprehensionFailure) { + absl::string_view yaml = R"yaml( +name: comprehension_policy +rule: + match: + - condition: x > 0 + rule: + match: + - condition: "[1, 2].all(i, i > x)" + output: 1 + - output: 2 + - output: 0 +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + + CompilePolicyOptions options; + options.unnesting_height_limit = 2; + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy, options)); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatIssues(), + testing::HasSubstr("cannot unnest AST due to comprehension")); +} + +struct UnnestEvaluationTestCase { + std::string name; + int64_t x; + ValueMatcher expected; +}; + +class UnnestedDeepPolicyEvaluationTest + : public testing::TestWithParam {}; + +TEST_P(UnnestedDeepPolicyEvaluationTest, Evaluate) { + const auto& tc = GetParam(); + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(kDeepPolicyYaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + + CompilePolicyOptions options; + options.unnesting_height_limit = 2; + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy, options)); + ASSERT_TRUE(result.IsValid()); + ASSERT_OK_AND_ASSIGN(auto ast, result.ReleaseAst()); + + // Set up runtime + cel::RuntimeOptions opts; + opts.enable_qualified_type_identifiers = true; + ASSERT_OK_AND_ASSIGN( + cel::RuntimeBuilder rt_builder, + cel::CreateStandardRuntimeBuilder( + 
// Instantiates UnnestedDeepPolicyEvaluationTest once per value of `x`.
//
// Each x in [0, 6] is expected to evaluate to itself. x = -1 is expected to
// evaluate to 0, the same result as x = 0. The generated test name is taken
// from the case's `name` field.
INSTANTIATE_TEST_SUITE_P(
    UnnestedDeepPolicyEvaluation, UnnestedDeepPolicyEvaluationTest,
    testing::Values(UnnestEvaluationTestCase{"XEquals6", 6, IntValueIs(6)},
                    UnnestEvaluationTestCase{"XEquals5", 5, IntValueIs(5)},
                    UnnestEvaluationTestCase{"XEquals4", 4, IntValueIs(4)},
                    UnnestEvaluationTestCase{"XEquals3", 3, IntValueIs(3)},
                    UnnestEvaluationTestCase{"XEquals2", 2, IntValueIs(2)},
                    UnnestEvaluationTestCase{"XEquals1", 1, IntValueIs(1)},
                    UnnestEvaluationTestCase{"XEquals0", 0, IntValueIs(0)},
                    UnnestEvaluationTestCase{"XEqualsMinus1", -1,
                                             IntValueIs(0)}),
    [](const testing::TestParamInfo<
        UnnestedDeepPolicyEvaluationTest::ParamType>& info) {
      return info.param.name;
    });
# Library providing cel::policy_internal::IssueReporter, which accumulates
# policy issues (element id, optional relative source position, severity,
# message) for later retrieval.
cc_library(
    name = "issue_reporter",
    srcs = ["issue_reporter.cc"],
    hdrs = ["issue_reporter.h"],
    deps = [
        "//common:source",
        "//policy:cel_policy",
        "//policy:cel_policy_parser",
        "@com_google_absl//absl/strings",
    ],
)
"//compiler:standard_library", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//testutil:expr_printer", + "//tools:cel_unparser", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) diff --git a/policy/internal/issue_reporter.cc b/policy/internal/issue_reporter.cc new file mode 100644 index 000000000..944e687d6 --- /dev/null +++ b/policy/internal/issue_reporter.cc @@ -0,0 +1,45 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "policy/internal/issue_reporter.h" + +#include "absl/strings/string_view.h" +#include "common/source.h" +#include "policy/cel_policy.h" + +namespace cel::policy_internal { + +void IssueReporter::ReportIssue(CelPolicyElementId element, Severity severity, + absl::string_view message) { + issues_.push_back({element, severity, message}); +} + +void IssueReporter::ReportOffsetIssue(CelPolicyElementId element, + cel::SourcePosition relative_position, + Severity severity, + absl::string_view message) { + issues_.push_back({element, relative_position, severity, message}); +} + +void IssueReporter::ReportError(CelPolicyElementId element, + absl::string_view message) { + ReportIssue(element, Severity::kError, message); +} + +void IssueReporter::ReportError(CelPolicyElementId element, SourcePosition pos, + absl::string_view message) { + ReportOffsetIssue(element, pos, Severity::kError, message); +} + +} // namespace cel::policy_internal diff --git a/policy/internal/issue_reporter.h b/policy/internal/issue_reporter.h new file mode 100644 index 000000000..3f88806ef --- /dev/null +++ b/policy/internal/issue_reporter.h @@ -0,0 +1,57 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef THIRD_PARTY_CEL_CPP_POLICY_INTERNAL_ISSUE_REPORTER_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_INTERNAL_ISSUE_REPORTER_H_ + +#include + +#include "absl/strings/string_view.h" +#include "common/source.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_result.h" + +namespace cel::policy_internal { + +class IssueReporter { + private: + using Severity = CelPolicyIssue::Severity; + + public: + void ReportIssue(CelPolicyElementId element, Severity severity, + absl::string_view message); + + void ReportOffsetIssue(CelPolicyElementId element, + cel::SourcePosition relative_position, + Severity severity, absl::string_view message); + + void ReportError(CelPolicyElementId element, absl::string_view message); + void ReportError(CelPolicyElementId element, SourcePosition relative_pos, + absl::string_view message); + + std::vector ReleaseIssues() { + using std::swap; + std::vector out; + swap(out, issues_); + return out; + } + const std::vector& issues() const { return issues_; } + + private: + std::vector issues_; +}; + +} // namespace cel::policy_internal + +#endif // THIRD_PARTY_CEL_CPP_POLICY_INTERNAL_ISSUE_REPORTER_H_ diff --git a/policy/internal/optimizer_expr_factory.cc b/policy/internal/optimizer_expr_factory.cc new file mode 100644 index 000000000..6c89ae958 --- /dev/null +++ b/policy/internal/optimizer_expr_factory.cc @@ -0,0 +1,373 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "policy/internal/optimizer_expr_factory.h" + +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/functional/any_invocable.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/ast.h" +#include "common/ast_rewrite.h" +#include "common/ast_traverse.h" +#include "common/ast_visitor_base.h" +#include "common/constant.h" +#include "common/expr.h" +#include "common/expr_factory.h" +#include "common/source.h" + +namespace cel { + +namespace { + +class MaxIdVisitor final : public AstVisitorBase { + public: + ExprId max_id() const { return max_id_; } + + void PreVisitExpr(const Expr& expr) override { + max_id_ = std::max(max_id_, expr.id()); + } + + void PostVisitExpr(const Expr&) override {} + + void PostVisitStruct(const Expr&, const StructExpr& struct_expr) override { + for (const auto& field : struct_expr.fields()) { + max_id_ = std::max(max_id_, field.id()); + } + } + + void PostVisitMap(const Expr&, const MapExpr& map_expr) override { + for (const auto& entry : map_expr.entries()) { + max_id_ = std::max(max_id_, entry.id()); + } + } + + private: + ExprId max_id_ = 0; +}; + +ExprId GetMaxId(const Expr& expr) { + MaxIdVisitor visitor; + AstTraverse(expr, visitor); + return visitor.max_id(); +} + +ExprId GetMaxId(const Ast& ast) { + ExprId max_id = GetMaxId(ast.root_expr()); + for (const auto& [id, _] : ast.source_info().positions()) { + max_id = std::max(max_id, id); + } + for (const auto& [id, expr] : ast.source_info().macro_calls()) { + max_id = std::max(max_id, id); + max_id = std::max(max_id, GetMaxId(expr)); + } + return max_id; +} + +// Replaces nested macros in a macro_calls expr with reference nodes. +// +// The macro_calls map is used for retaining the original structure of the +// parsed expression before macro expansion. 
When a macro appears inside another +// macro, the parser will replace the inner macro expr node with an unspecified +// expr with the inner macro's ID in the macro_calls map to save space. +class MakeMacroCallRewrite final : public AstRewriterBase { + public: + explicit MakeMacroCallRewrite(const SourceInfo& source_info) + : source_info_(source_info) {} + + bool PreVisitRewrite(Expr& expr) override { + if (source_info_.macro_calls().find(expr.id()) != + source_info_.macro_calls().end()) { + ExprId id = expr.id(); + expr.mutable_kind() = UnspecifiedExpr(); + expr.set_id(id); + return true; + } + return false; + } + + private: + const SourceInfo& source_info_; +}; + +// Updates macro_calls map entries to reflect a replaced expression in the +// main AST. +class ReplaceMacroCallRewrite final : public AstRewriterBase { + public: + ReplaceMacroCallRewrite(ExprId old_id, const Expr& replacement, + const SourceInfo& source_info) + : old_id_(old_id), replacement_(replacement), source_info_(source_info) {} + + bool PreVisitRewrite(Expr& expr) override { + if (expr.id() == old_id_) { + expr = macro_replacement(); + return true; + } + return false; + } + + Expr macro_replacement() { + if (!macro_replacement_) { + macro_replacement_.emplace(replacement_); + MakeMacroCallRewrite hole_creator(source_info_); + AstRewrite(*macro_replacement_, hole_creator); + } + return *macro_replacement_; + } + + private: + ExprId old_id_; + const Expr& replacement_; + absl::optional macro_replacement_; + const SourceInfo& source_info_; +}; + +void ReplaceSubExpr(Expr& expr, ExprId old_id, const Expr& replacement, + const SourceInfo& source_info) { + ReplaceMacroCallRewrite rewriter(old_id, replacement, source_info); + AstRewrite(expr, rewriter); +} + +class IdRewriter : public AstRewriterBase { + using CopyIdFn = absl::AnyInvocable; + + public: + explicit IdRewriter(CopyIdFn copy_id) : copy_id_(std::move(copy_id)) {} + + // No structure changes just ids. 
// Builds a new SourceInfo holding the macro calls and positions of `info`,
// re-keyed from old ids to new ids using the current renumbering map.
//
// Only ids present in `renumbers_` are carried over. `offset` is added to
// every copied position. Nothing other than macro calls and positions is
// written to the result.
//
// NOTE(review): `info.extensions()` is not copied here, yet MergeSourceInfo
// merges the extensions of the SourceInfo it is given. Confirm that dropping
// them for remapped infos is intended.
SourceInfo OptimizerExprFactory::RemapSourceInfo(const SourceInfo& info,
                                                 SourcePosition offset) {
  SourceInfo out;

  // NOTE(review): a macro call is copied only if its id is already in
  // `renumbers_` when it is visited, and Copy() below adds ids as it goes. If
  // macro_calls() has no defined iteration order, the set of copied macros
  // could depend on that order. Confirm all relevant ids are renumbered
  // before this is called.
  for (const auto& [old_id, macro_expr] : info.macro_calls()) {
    if (auto it = renumbers_.find(old_id); it != renumbers_.end()) {
      // Read the mapped id before calling Copy(): Copy() renumbers the ids
      // inside the macro expression and may insert into `renumbers_`, after
      // which `it` must not be used.
      ExprId new_id = it->second;
      out.mutable_macro_calls()[new_id] = Copy(macro_expr);
    }
  }

  // Runs after the macro loop so that ids first renumbered while copying the
  // macro expressions also get their positions carried over.
  for (const auto& [old_id, new_id] : renumbers_) {
    if (auto it = info.positions().find(old_id); it != info.positions().end()) {
      out.mutable_positions()[new_id] = it->second + offset;
    }
  }

  return out;
}
+ for (const auto& ext : info.extensions()) { + auto& target_exts = target_info.mutable_extensions(); + if (!absl::c_linear_search(target_exts, ext)) { + target_exts.push_back(ext); + } + } +} + +void OptimizerExprFactory::RecordReplacement(ExprId id, const Expr& replacement, + bool keep_metadata) { + auto& source_info = ast_.mutable_source_info(); + if (!keep_metadata) { + source_info.mutable_positions().erase(id); + source_info.mutable_macro_calls().erase(id); + } + + for (auto& [macro_id, macro_expr] : source_info.mutable_macro_calls()) { + ReplaceSubExpr(macro_expr, id, replacement, source_info); + } +} + +Expr OptimizerExprFactory::ReportError(absl::string_view message) { + ExprId id = NextId(); + issues_.push_back(Issue{id, std::string(message)}); + return NewUnspecified(id); +} + +Expr OptimizerExprFactory::ReportErrorAt(const Expr& expr, + absl::string_view message) { + issues_.push_back(Issue{expr.id(), std::string(message)}); + return NewUnspecified(NextId()); +} + +Expr OptimizerExprFactory::ReportErrorAtCopy(const Expr& expr, + absl::string_view message) { + issues_.push_back(Issue{CopyId(expr.id()), std::string(message)}); + return NewUnspecified(NextId()); +} + +Expr OptimizerExprFactory::NewUnspecified() { return NewUnspecified(NextId()); } + +Expr OptimizerExprFactory::NewNullConst() { return NewNullConst(NextId()); } + +Expr OptimizerExprFactory::NewBoolConst(bool value) { + return NewBoolConst(NextId(), value); +} + +Expr OptimizerExprFactory::NewIntConst(int64_t value) { + return NewIntConst(NextId(), value); +} + +Expr OptimizerExprFactory::NewUintConst(uint64_t value) { + return NewUintConst(NextId(), value); +} + +Expr OptimizerExprFactory::NewDoubleConst(double value) { + return NewDoubleConst(NextId(), value); +} + +Expr OptimizerExprFactory::NewBytesConst(std::string value) { + return NewBytesConst(NextId(), std::move(value)); +} + +Expr OptimizerExprFactory::NewBytesConst(absl::string_view value) { + return NewBytesConst(NextId(), value); 
+} + +Expr OptimizerExprFactory::NewBytesConst(const char* value) { + return NewBytesConst(NextId(), value); +} + +Expr OptimizerExprFactory::NewStringConst(std::string value) { + return NewStringConst(NextId(), std::move(value)); +} + +Expr OptimizerExprFactory::NewStringConst(absl::string_view value) { + return NewStringConst(NextId(), value); +} + +Expr OptimizerExprFactory::NewStringConst(const char* value) { + return NewStringConst(NextId(), value); +} + +absl::flat_hash_map OptimizerExprFactory::ConsumeRenumbers() { + using std::swap; + absl::flat_hash_map out; + swap(out, renumbers_); + return out; +} + +void OptimizerExprFactory::StartCopyContext() { renumbers_.clear(); } + +const std::vector& OptimizerExprFactory::issues() + const { + return issues_; +} + +const Ast& OptimizerExprFactory::ast() const { return ast_; } + +Ast& OptimizerExprFactory::mutable_ast() { return ast_; } + +absl::string_view OptimizerExprFactory::AccuVarName() { + return ExprFactory::AccuVarName(); +} + +Expr OptimizerExprFactory::NewAccuIdent() { return NewAccuIdent(NextId()); } + +ExprId OptimizerExprFactory::CopyId(const Expr& expr) { + return CopyId(expr.id()); +} + +} // namespace cel diff --git a/policy/internal/optimizer_expr_factory.h b/policy/internal/optimizer_expr_factory.h new file mode 100644 index 000000000..6f63f1485 --- /dev/null +++ b/policy/internal/optimizer_expr_factory.h @@ -0,0 +1,419 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_POLICY_INTERNAL_OPTIMIZER_EXPR_FACTORY_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_INTERNAL_OPTIMIZER_EXPR_FACTORY_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "common/ast.h" +#include "common/expr.h" +#include "common/expr_factory.h" +#include "common/source.h" + +namespace cel { + +class ParserMacroExprFactory; +class TestOptimizerExprFactory; + +// `OptimizerExprFactory` is a specialization of `ExprFactory` used for AST +// optimization. It provides utilities for correcting metadata for modified +// ASTs. +class OptimizerExprFactory : protected ExprFactory { + public: + struct Issue { + ExprId location = 0; + std::string message; + }; + + explicit OptimizerExprFactory(Ast basis); + OptimizerExprFactory(); + + protected: + using ExprFactory::IsArrayLike; + using ExprFactory::IsExprLike; + using ExprFactory::IsStringLike; + + template + struct IsRValue + : std::bool_constant< + std::disjunction_v, std::is_same>> {}; + + public: + // Consume the current set of renumberings. + absl::flat_hash_map ConsumeRenumbers(); + + // Starts a new copy context. The current set of renumberings are cleared. + void StartCopyContext(); + + const std::vector& issues() const; + + // Record that a node in the working AST was replaced. This is used to correct + // metadata referencing the old ID. + void RecordReplacement(ExprId id, const Expr& replacement, + bool keep_metadata = false); + + // Makes a copy of source metadata that is remapped to new expr Ids using + // current renumberings. This is suitable for merging into the main source + // info. 
+ SourceInfo RemapSourceInfo(const SourceInfo& info, SourcePosition offset = 0); + + // Merge a remapped SourceInfo into the current one. + void MergeSourceInfo(const SourceInfo& info); + + const Ast& ast() const; + Ast& mutable_ast(); + + absl::string_view AccuVarName(); + + ABSL_MUST_USE_RESULT Expr Copy(const Expr& expr); + + ABSL_MUST_USE_RESULT ListExprElement Copy(const ListExprElement& element); + + ABSL_MUST_USE_RESULT StructExprField Copy(const StructExprField& field); + + ABSL_MUST_USE_RESULT MapExprEntry Copy(const MapExprEntry& entry); + + ABSL_MUST_USE_RESULT Expr NewUnspecified(); + + ABSL_MUST_USE_RESULT Expr NewNullConst(); + + ABSL_MUST_USE_RESULT Expr NewBoolConst(bool value); + + ABSL_MUST_USE_RESULT Expr NewIntConst(int64_t value); + + ABSL_MUST_USE_RESULT Expr NewUintConst(uint64_t value); + + ABSL_MUST_USE_RESULT Expr NewDoubleConst(double value); + + ABSL_MUST_USE_RESULT Expr NewBytesConst(std::string value); + + ABSL_MUST_USE_RESULT Expr NewBytesConst(absl::string_view value); + + ABSL_MUST_USE_RESULT Expr NewBytesConst(const char* absl_nullable value); + + ABSL_MUST_USE_RESULT Expr NewStringConst(std::string value); + + ABSL_MUST_USE_RESULT Expr NewStringConst(absl::string_view value); + + ABSL_MUST_USE_RESULT Expr NewStringConst(const char* absl_nullable value); + + template ::value>> + ABSL_MUST_USE_RESULT Expr NewIdent(Name name); + + ABSL_MUST_USE_RESULT Expr NewAccuIdent(); + + template ::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT Expr NewSelect(Operand operand, Field field); + + template ::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT Expr NewPresenceTest(Operand operand, Field field); + + template < + typename Function, typename... Args, + typename = std::enable_if_t::value>, + typename = std::enable_if_t...>>> + ABSL_MUST_USE_RESULT Expr NewCall(Function function, Args&&... 
args); + + template ::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT Expr NewCall(Function function, Args args); + + template < + typename Function, typename Target, typename... Args, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t...>>> + ABSL_MUST_USE_RESULT Expr NewMemberCall(Function function, Target target, + Args&&... args); + + template ::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT Expr NewMemberCall(Function function, Target target, + Args args); + + using ExprFactory::NewListElement; + + template ...>>> + ABSL_MUST_USE_RESULT Expr NewList(Elements&&... elements); + + template ::value>> + ABSL_MUST_USE_RESULT Expr NewList(Elements elements); + + template ::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT StructExprField NewStructField(Name name, Value value, + bool optional = false); + + template ::value>, + typename = std::enable_if_t< + std::conjunction_v...>>> + ABSL_MUST_USE_RESULT Expr NewStruct(Name name, Fields&&... fields); + + template < + typename Name, typename Fields, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT Expr NewStruct(Name name, Fields fields); + + template ::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT MapExprEntry NewMapEntry(Key key, Value value, + bool optional = false); + + template ...>>> + ABSL_MUST_USE_RESULT Expr NewMap(Entries&&... 
entries); + + template ::value>> + ABSL_MUST_USE_RESULT Expr NewMap(Entries entries); + + template ::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT Expr NewComprehension(IterVar iter_var, + IterRange iter_range, + AccuVar accu_var, + AccuInit accu_init, + LoopCondition loop_condition, + LoopStep loop_step, Result result); + + template ::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT Expr NewComprehension( + IterVar iter_var, IterVar2 iter_var2, IterRange iter_range, + AccuVar accu_var, AccuInit accu_init, LoopCondition loop_condition, + LoopStep loop_step, Result result); + + ABSL_MUST_USE_RESULT Expr ReportError(absl::string_view message); + + // Reports an error at the id in the optimized AST. + ABSL_MUST_USE_RESULT Expr ReportErrorAt(const Expr& expr, + absl::string_view message); + // Reports an error at the mapped id of the copy of expr in the optimized AST. 
+ ABSL_MUST_USE_RESULT Expr ReportErrorAtCopy(const Expr& expr, + absl::string_view message); + + protected: + ABSL_MUST_USE_RESULT ExprId NextId(); + + ABSL_MUST_USE_RESULT ExprId CopyId(ExprId id); + + ABSL_MUST_USE_RESULT ExprId CopyId(const Expr& expr); + + using ExprFactory::AccuVarName; + using ExprFactory::NewAccuIdent; + using ExprFactory::NewBoolConst; + using ExprFactory::NewBytesConst; + using ExprFactory::NewCall; + using ExprFactory::NewComprehension; + using ExprFactory::NewConst; + using ExprFactory::NewDoubleConst; + using ExprFactory::NewIdent; + using ExprFactory::NewIntConst; + using ExprFactory::NewList; + using ExprFactory::NewMap; + using ExprFactory::NewMapEntry; + using ExprFactory::NewMemberCall; + using ExprFactory::NewNullConst; + using ExprFactory::NewPresenceTest; + using ExprFactory::NewSelect; + using ExprFactory::NewStringConst; + using ExprFactory::NewStruct; + using ExprFactory::NewStructField; + using ExprFactory::NewUintConst; + using ExprFactory::NewUnspecified; + + private: + Ast ast_; + absl::flat_hash_map renumbers_; + std::vector issues_; + + ExprId next_id_ = 1; +}; + +// Implementation details. + +template +Expr OptimizerExprFactory::NewIdent(Name name) { + return NewIdent(NextId(), std::move(name)); +} + +template +Expr OptimizerExprFactory::NewSelect(Operand operand, Field field) { + return NewSelect(NextId(), std::move(operand), std::move(field)); +} + +template +Expr OptimizerExprFactory::NewPresenceTest(Operand operand, Field field) { + return NewPresenceTest(NextId(), std::move(operand), std::move(field)); +} + +template +Expr OptimizerExprFactory::NewCall(Function function, Args&&... 
args) { + std::vector array; + array.reserve(sizeof...(Args)); + (array.push_back(std::forward(args)), ...); + return NewCall(NextId(), std::move(function), std::move(array)); +} + +template +Expr OptimizerExprFactory::NewCall(Function function, Args args) { + return NewCall(NextId(), std::move(function), std::move(args)); +} + +template +Expr OptimizerExprFactory::NewMemberCall(Function function, Target target, + Args&&... args) { + std::vector array; + array.reserve(sizeof...(Args)); + (array.push_back(std::forward(args)), ...); + return NewMemberCall(NextId(), std::move(function), std::move(target), + std::move(array)); +} + +template +Expr OptimizerExprFactory::NewMemberCall(Function function, Target target, + Args args) { + return NewMemberCall(NextId(), std::move(function), std::move(target), + std::move(args)); +} + +template +Expr OptimizerExprFactory::NewList(Elements&&... elements) { + std::vector array; + array.reserve(sizeof...(Elements)); + (array.push_back(std::forward(elements)), ...); + return NewList(NextId(), std::move(array)); +} + +template +Expr OptimizerExprFactory::NewList(Elements elements) { + return NewList(NextId(), std::move(elements)); +} + +template +StructExprField OptimizerExprFactory::NewStructField(Name name, Value value, + bool optional) { + return NewStructField(NextId(), std::move(name), std::move(value), optional); +} + +template +Expr OptimizerExprFactory::NewStruct(Name name, Fields&&... 
fields) { + std::vector array; + array.reserve(sizeof...(Fields)); + (array.push_back(std::forward(fields)), ...); + return NewStruct(NextId(), std::move(name), std::move(array)); +} + +template +Expr OptimizerExprFactory::NewStruct(Name name, Fields fields) { + return NewStruct(NextId(), std::move(name), std::move(fields)); +} + +template +MapExprEntry OptimizerExprFactory::NewMapEntry(Key key, Value value, + bool optional) { + return NewMapEntry(NextId(), std::move(key), std::move(value), optional); +} + +template +Expr OptimizerExprFactory::NewMap(Entries&&... entries) { + std::vector array; + array.reserve(sizeof...(Entries)); + (array.push_back(std::forward(entries)), ...); + return NewMap(NextId(), std::move(array)); +} + +template +Expr OptimizerExprFactory::NewMap(Entries entries) { + return NewMap(NextId(), std::move(entries)); +} + +template +Expr OptimizerExprFactory::NewComprehension(IterVar iter_var, + IterRange iter_range, + AccuVar accu_var, + AccuInit accu_init, + LoopCondition loop_condition, + LoopStep loop_step, Result result) { + return NewComprehension(NextId(), std::move(iter_var), std::move(iter_range), + std::move(accu_var), std::move(accu_init), + std::move(loop_condition), std::move(loop_step), + std::move(result)); +} + +template +Expr OptimizerExprFactory::NewComprehension( + IterVar iter_var, IterVar2 iter_var2, IterRange iter_range, + AccuVar accu_var, AccuInit accu_init, LoopCondition loop_condition, + LoopStep loop_step, Result result) { + return NewComprehension(NextId(), std::move(iter_var), std::move(iter_var2), + std::move(iter_range), std::move(accu_var), + std::move(accu_init), std::move(loop_condition), + std::move(loop_step), std::move(result)); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_POLICY_INTERNAL_OPTIMIZER_EXPR_FACTORY_H_ diff --git a/policy/internal/optimizer_expr_factory_test.cc b/policy/internal/optimizer_expr_factory_test.cc new file mode 100644 index 000000000..1b14b5628 --- /dev/null +++ 
b/policy/internal/optimizer_expr_factory_test.cc @@ -0,0 +1,570 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/internal/optimizer_expr_factory.h" + +#include +#include +#include + +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "common/ast.h" +#include "common/ast_proto.h" +#include "common/ast_rewrite.h" +#include "common/decl.h" +#include "common/expr.h" +#include "common/expr_factory.h" +#include "common/source.h" +#include "common/type.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "testutil/expr_printer.h" +#include "tools/cel_unparser.h" + +namespace cel { + +using ::testing::SizeIs; + +// Expose protected members of OptimizerExprFactory for use in tests +// +// These allow setting explicit IDs which is not safe for the optimizing +// factory. 
+class TestOptimizerExprFactory final : public OptimizerExprFactory { + public: + using OptimizerExprFactory::OptimizerExprFactory; + + using OptimizerExprFactory::NewBoolConst; + using OptimizerExprFactory::NewCall; + using OptimizerExprFactory::NewComprehension; + using OptimizerExprFactory::NewIdent; + using OptimizerExprFactory::NewList; + using OptimizerExprFactory::NewListElement; + using OptimizerExprFactory::NewMap; + using OptimizerExprFactory::NewMapEntry; + using OptimizerExprFactory::NewMemberCall; + using OptimizerExprFactory::NewSelect; + using OptimizerExprFactory::NewStruct; + using OptimizerExprFactory::NewStructField; + using OptimizerExprFactory::NewUnspecified; + using OptimizerExprFactory::NextId; +}; + +namespace { + +class ReplaceExprRewriter final : public AstRewriterBase { + public: + ReplaceExprRewriter(ExprId old_id, const Expr& replacement) + : old_id_(old_id), replacement_(replacement) {} + + bool PreVisitRewrite(Expr& expr) override { + if (expr.id() == old_id_) { + expr = replacement_; + return true; + } + return false; + } + + private: + ExprId old_id_; + const Expr& replacement_; +}; + +void ReplaceExprInTree(Expr& expr, ExprId old_id, const Expr& replacement) { + ReplaceExprRewriter rewriter(old_id, replacement); + AstRewrite(expr, rewriter); +} + +absl::StatusOr> CreateTestCompiler() { + CompilerOptions opts; + opts.parser_options.add_macro_calls = true; + CEL_ASSIGN_OR_RETURN( + auto builder, cel::NewCompilerBuilder( + cel::internal::GetSharedTestingDescriptorPool(), opts)); + CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::StandardCompilerLibrary())); + CEL_RETURN_IF_ERROR(builder->GetCheckerBuilder().AddVariable( + cel::MakeVariableDecl("to_replace", cel::DynType()))); + return builder->Build(); +} + +TEST(OptimizerExprFactory, CopyUnspecified) { + TestOptimizerExprFactory factory{Ast()}; + EXPECT_EQ(factory.Copy(factory.NewUnspecified()), factory.NewUnspecified(2)); +} + +TEST(OptimizerExprFactory, CopyIdent) { + 
TestOptimizerExprFactory factory{Ast()}; + EXPECT_EQ(factory.Copy(factory.NewIdent("foo")), factory.NewIdent(2, "foo")); +} + +TEST(OptimizerExprFactory, CopyConst) { + TestOptimizerExprFactory factory{Ast()}; + EXPECT_EQ(factory.Copy(factory.NewBoolConst(true)), + factory.NewBoolConst(2, true)); +} + +TEST(OptimizerExprFactory, CopySelect) { + TestOptimizerExprFactory factory{Ast()}; + EXPECT_EQ(factory.Copy(factory.NewSelect(factory.NewIdent("foo"), "bar")), + factory.NewSelect(3, factory.NewIdent(4, "foo"), "bar")); +} + +TEST(OptimizerExprFactory, CopyCall) { + TestOptimizerExprFactory factory{Ast()}; + std::vector copied_args; + copied_args.reserve(1); + copied_args.push_back(factory.NewIdent(6, "baz")); + EXPECT_EQ(factory.Copy(factory.NewMemberCall("bar", factory.NewIdent("foo"), + factory.NewIdent("baz"))), + factory.NewMemberCall(4, "bar", factory.NewIdent(5, "foo"), + absl::MakeSpan(copied_args))); +} + +TEST(OptimizerExprFactory, CopyList) { + TestOptimizerExprFactory factory{Ast()}; + std::vector copied_elements; + copied_elements.reserve(1); + copied_elements.push_back(factory.NewListElement(factory.NewIdent(4, "foo"))); + EXPECT_EQ(factory.Copy(factory.NewList( + factory.NewListElement(factory.NewIdent("foo")))), + factory.NewList(3, absl::MakeSpan(copied_elements))); +} + +TEST(OptimizerExprFactory, CopyStruct) { + TestOptimizerExprFactory factory{Ast()}; + std::vector copied_fields; + copied_fields.reserve(1); + copied_fields.push_back( + factory.NewStructField(5, "bar", factory.NewIdent(6, "baz"))); + EXPECT_EQ(factory.Copy(factory.NewStruct( + "foo", factory.NewStructField("bar", factory.NewIdent("baz")))), + factory.NewStruct(4, "foo", absl::MakeSpan(copied_fields))); +} + +TEST(OptimizerExprFactory, CopyMap) { + TestOptimizerExprFactory factory{Ast()}; + std::vector copied_entries; + copied_entries.reserve(1); + copied_entries.push_back(factory.NewMapEntry(6, factory.NewIdent(7, "bar"), + factory.NewIdent(8, "baz"))); + 
EXPECT_EQ(factory.Copy(factory.NewMap(factory.NewMapEntry( + factory.NewIdent("bar"), factory.NewIdent("baz")))), + factory.NewMap(5, absl::MakeSpan(copied_entries))); +} + +TEST(OptimizerExprFactory, CopyComprehension) { + TestOptimizerExprFactory factory{Ast()}; + EXPECT_EQ( + factory.Copy(factory.NewComprehension( + "foo", factory.NewList(), "bar", factory.NewBoolConst(true), + factory.NewIdent("baz"), factory.NewIdent("foo"), + factory.NewIdent("bar"))), + factory.NewComprehension( + 7, "foo", factory.NewList(8, std::vector()), "bar", + factory.NewBoolConst(9, true), factory.NewIdent(10, "baz"), + factory.NewIdent(11, "foo"), factory.NewIdent(12, "bar"))); +} + +TEST(OptimizerExprFactory, RemapSourceInfo) { + TestOptimizerExprFactory factory{Ast()}; + Expr orig = factory.NewIdent("foo"); // allocates ID 1 + Expr copied = factory.Copy(orig); // copies ID 1 to mapped ID 2 + + SourceInfo info; + info.mutable_positions()[1] = 42; // old ID 1 has position 42 + + SourceInfo remapped = factory.RemapSourceInfo(info, 10); + + // remapped should have ID 2 mapped to position 42 + 10 = 52 + auto it = remapped.positions().find(2); + ASSERT_NE(it, remapped.positions().end()); + EXPECT_EQ(it->second, 52); +} + +TEST(OptimizerExprFactory, RemapSourceInfoWithMacroCalls) { + TestOptimizerExprFactory factory{Ast()}; + Expr orig = factory.NewIdent("foo"); // allocates ID 1 + Expr copied = factory.Copy(orig); // copies ID 1 to mapped ID 2 + + SourceInfo info; + // old ID 1 has macro call with ID 3 + info.mutable_macro_calls()[1] = factory.NewIdent("bar"); + + SourceInfo remapped = factory.RemapSourceInfo(info, 10); + + // remapped should have ID 2 mapped to the copied macro call + // since "bar" has ID 3, Copy(bar) should map ID 3 to ID 4 + + auto it = remapped.macro_calls().find(2); + ASSERT_NE(it, remapped.macro_calls().end()); + + // The macro call should be an Ident with new ID 4 + EXPECT_EQ(it->second.id(), 4); + EXPECT_TRUE(it->second.has_ident_expr()); + 
EXPECT_EQ(it->second.ident_expr().name(), "bar"); +} + +TEST(OptimizerExprFactory, ReportError) { + TestOptimizerExprFactory factory{Ast()}; + Expr err_expr = factory.ReportError("something went wrong"); + + // err_expr should be unspecified with ID 1 + EXPECT_EQ(err_expr.id(), 1); + EXPECT_EQ(err_expr.kind_case(), ExprKindCase::kUnspecifiedExpr); + + // issues_ should have 1 entry with ID 1 and correct message + ASSERT_EQ(factory.issues().size(), 1); + EXPECT_EQ(factory.issues()[0].location, 1); + EXPECT_EQ(factory.issues()[0].message, "something went wrong"); +} + +TEST(OptimizerExprFactory, ReportErrorAt) { + TestOptimizerExprFactory factory{Ast()}; + Expr orig = factory.NewIdent("foo"); // allocates ID 1 + Expr copied = factory.Copy(orig); // copies ID 1 to mapped ID 2 + + Expr err_expr = factory.ReportErrorAtCopy(orig, "error on foo"); + + // err_expr should be unspecified with ID 3 (NextId) + EXPECT_EQ(err_expr.id(), 3); + EXPECT_EQ(err_expr.kind_case(), ExprKindCase::kUnspecifiedExpr); + + // issues_ should have 1 entry with mapped ID 2 and correct message + ASSERT_EQ(factory.issues().size(), 1); + EXPECT_EQ(factory.issues()[0].location, 2); + EXPECT_EQ(factory.issues()[0].message, "error on foo"); +} + +TEST(OptimizerExprFactory, MergeSourceInfo) { + // Create a base AST with some source info + SourceInfo base_info; + base_info.set_syntax_version("cel1"); + base_info.set_location("test.cel"); + base_info.mutable_positions()[1] = 10; + + Ast base_ast(Expr(), std::move(base_info)); + + TestOptimizerExprFactory factory{std::move(base_ast)}; + + // Create a new source info to merge + SourceInfo new_info; + new_info.mutable_positions()[2] = 20; + + factory.MergeSourceInfo(new_info); + + // The merged source info should have both positions + const auto& merged_info = factory.ast().source_info(); + EXPECT_EQ(merged_info.syntax_version(), "cel1"); + EXPECT_EQ(merged_info.location(), "test.cel"); + + auto it1 = merged_info.positions().find(1); + ASSERT_NE(it1, 
merged_info.positions().end()); + EXPECT_EQ(it1->second, 10); + + auto it2 = merged_info.positions().find(2); + ASSERT_NE(it2, merged_info.positions().end()); + EXPECT_EQ(it2->second, 20); +} + +TEST(OptimizerExprFactory, MergeSourceInfoConflict) { + SourceInfo base_info; + base_info.mutable_positions()[1] = 10; + + Ast base_ast(Expr(), std::move(base_info)); + TestOptimizerExprFactory factory{std::move(base_ast)}; + + SourceInfo new_info; + new_info.mutable_positions()[1] = 20; // conflicting ID 1 + + factory.MergeSourceInfo(new_info); + + // Should report an error for the conflict + ASSERT_EQ(factory.issues().size(), 1); + EXPECT_EQ(factory.issues()[0].location, 1); + EXPECT_EQ(factory.issues()[0].message, "conflicting ID in positions merge"); +} + +TEST(OptimizerExprFactory, RecordReplacement) { + SourceInfo base_info; + base_info.mutable_positions()[1] = 10; + base_info.mutable_positions()[2] = 20; + + TestOptimizerExprFactory factory{Ast()}; + + // macro_calls[1] maps ID 1 to macro call "bar(foo)" (where "foo" has ID 1) + base_info.mutable_macro_calls()[1] = + factory.NewCall("bar", factory.NewIdent(1, "foo")); + + // macro_calls[2] maps ID 2 to macro call "baz(foo)" (where "foo" has ID 1) + base_info.mutable_macro_calls()[2] = + factory.NewCall("baz", factory.NewIdent(1, "foo")); + + Ast base_ast(Expr(), std::move(base_info)); + TestOptimizerExprFactory optimizer{std::move(base_ast)}; + + // Record the replacement of ID 1 by a new Ident "replacement" with ID 3 + optimizer.RecordReplacement(1, factory.NewIdent(3, "replacement")); + + const auto& result_info = optimizer.ast().source_info(); + + // 1. ID 1 should be erased from positions + EXPECT_EQ(result_info.positions().find(1), result_info.positions().end()); + EXPECT_NE(result_info.positions().find(2), result_info.positions().end()); + + // 2. ID 1 should be erased from macro_calls keys + EXPECT_EQ(result_info.macro_calls().find(1), result_info.macro_calls().end()); + + // 3. 
macro_calls[2] should still exist, but its argument referencing ID 1 + // should be replaced with the Ident "replacement" with ID 3 inline + auto it = result_info.macro_calls().find(2); + ASSERT_NE(it, result_info.macro_calls().end()); + + const Expr& macro_expr = it->second; + ASSERT_TRUE(macro_expr.has_call_expr()); + ASSERT_EQ(macro_expr.call_expr().args().size(), 1); + + const Expr& arg = macro_expr.call_expr().args()[0]; + EXPECT_EQ(arg.id(), 3); + EXPECT_TRUE(arg.has_ident_expr()); + EXPECT_EQ(arg.ident_expr().name(), "replacement"); +} + +class IdAdorner : public cel::test::ExpressionAdorner { + public: + std::string Adorn(const cel::Expr& e) const override { + return absl::StrCat("#", e.id()); + } + + std::string AdornStructField(const cel::StructExprField& e) const override { + return absl::StrCat("#", e.id()); + } + + std::string AdornMapEntry(const cel::MapExprEntry& e) const override { + return absl::StrCat("#", e.id()); + } +}; + +TEST(OptimizerExprFactory, UnparseCopiedMacroCall) { + // Arrange: create an template expression and one to inline. + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, + CreateTestCompiler()); + + ASSERT_OK_AND_ASSIGN(auto basis_result, + compiler->Compile("[1].map(x, x + to_replace)")); + ASSERT_TRUE(basis_result.IsValid()); + ASSERT_OK_AND_ASSIGN(auto basis_ast, basis_result.ReleaseAst()); + + ASSERT_OK_AND_ASSIGN(auto copy_result, + compiler->Compile("[1].filter(x, x > 2).size()")); + ASSERT_TRUE(copy_result.IsValid()); + ASSERT_OK_AND_ASSIGN(auto copy_ast, copy_result.ReleaseAst()); + + // Locate the "to_replace" IdentExpr node in reference_map + ExprId to_replace_id = 0; + for (const auto& [id, ref] : basis_ast->reference_map()) { + if (ref.name() == "to_replace") { + to_replace_id = id; + break; + } + } + ASSERT_NE(to_replace_id, 0); + + // Act: implement the optimization. 
+ TestOptimizerExprFactory factory{std::move(*basis_ast)}; + Expr copied_expr = factory.Copy(copy_ast->root_expr()); + SourceInfo remapped_info = factory.RemapSourceInfo(copy_ast->source_info()); + factory.MergeSourceInfo(remapped_info); + + ReplaceExprInTree(factory.mutable_ast().mutable_root_expr(), to_replace_id, + copied_expr); + factory.RecordReplacement(to_replace_id, copied_expr); + + // Test AST structure. + EXPECT_EQ( + cel::test::ExprPrinter(IdAdorner()).Print(factory.ast().root_expr()), + R"(__comprehension__( + // Variable + x, + // Target + [ + 1#2 + ]#1, + // Accumulator + @result, + // Init + []#8, + // LoopCondition + true#9, + // LoopStep + _+_( + @result#10, + [ + _+_( + x#5, + __comprehension__( + // Variable + x, + // Target + [ + 1#18 + ]#17, + // Accumulator + @result, + // Init + []#19, + // LoopCondition + true#20, + // LoopStep + _?_:_( + _>_( + x#23, + 2#24 + )#22, + _+_( + @result#26, + [ + x#28 + ]#27 + )#25, + @result#29 + )#21, + // Result + @result#30)#16.size()#15 + )#6 + ]#11 + )#12, + // Result + @result#13)#14)"); + + // Check that the structure is compatible with unparser. 
+ cel::expr::ParsedExpr optimized_parsed; + auto status = AstToParsedExpr(factory.ast(), &optimized_parsed); + ASSERT_THAT(status, absl_testing::IsOk()); + ASSERT_OK_AND_ASSIGN(std::string unparsed, + google::api::expr::Unparse(optimized_parsed)); + + EXPECT_EQ(unparsed, "[1].map(x, x + [1].filter(x, x > 2).size())"); + + const CallExpr& call_expr = factory.mutable_ast() + .mutable_source_info() + .mutable_macro_calls()[14] + .mutable_call_expr(); + ASSERT_THAT(call_expr.args(), SizeIs(2)); + ASSERT_THAT(call_expr.args()[1].call_expr().args(), SizeIs(2)); + EXPECT_EQ(call_expr.args()[1].call_expr().args()[1].id(), 15); + + EXPECT_EQ(call_expr.args()[1].call_expr().args()[1].call_expr().target().id(), + 16); + EXPECT_EQ(call_expr.args()[1] + .call_expr() + .args()[1] + .call_expr() + .target() + .kind_case(), + ExprKindCase::kUnspecifiedExpr); +} + +TEST(OptimizerExprFactory, CopyMultipleAstsWithConsumeRenumbers) { + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, + CreateTestCompiler()); + + ASSERT_OK_AND_ASSIGN(auto ast1_result, compiler->Compile("[1]")); + ASSERT_TRUE(ast1_result.IsValid()); + ASSERT_OK_AND_ASSIGN(auto ast1, ast1_result.ReleaseAst()); + + ASSERT_OK_AND_ASSIGN(auto ast2_result, compiler->Compile("2")); + ASSERT_TRUE(ast2_result.IsValid()); + ASSERT_OK_AND_ASSIGN(auto ast2, ast2_result.ReleaseAst()); + + TestOptimizerExprFactory factory{Ast()}; + + Expr copied1 = factory.Copy(ast1->root_expr()); + auto renumbers1 = factory.ConsumeRenumbers(); + + Expr copied2 = factory.Copy(ast2->root_expr()); + auto renumbers2 = factory.ConsumeRenumbers(); + + EXPECT_EQ(renumbers1.size(), 2); + EXPECT_EQ(renumbers2.size(), 1); + + EXPECT_NE(copied1.id(), copied2.id()); + EXPECT_GT(copied2.id(), copied1.id()); +} + +TEST(OptimizerExprFactory, MaxIdVisitorExprKinds) { + ASSERT_OK_AND_ASSIGN(auto compiler, CreateTestCompiler()); + + // Expression that covers all the kinds. 
+ ASSERT_OK_AND_ASSIGN(auto source, NewSource(R"cel( + Struct{field : 1} || + {'key' : 'value'} || [1].exists(x, x) || foo(bar))cel")); + ASSERT_OK_AND_ASSIGN(auto ast, compiler->GetParser().Parse(*source)); + + TestOptimizerExprFactory factory{std::move(*ast)}; + + EXPECT_EQ(factory.NextId(), 26); +} + +TEST(OptimizerExprFactory, CopyListElement) { + TestOptimizerExprFactory factory{Ast()}; + ListExprElement orig = factory.NewListElement(factory.NewIdent("foo")); + ListExprElement copied = factory.Copy(orig); + EXPECT_EQ(copied.expr(), factory.NewIdent(2, "foo")); +} + +TEST(OptimizerExprFactory, CopyStructField) { + TestOptimizerExprFactory factory{Ast()}; + StructExprField orig = factory.NewStructField("bar", factory.NewIdent("baz")); + StructExprField copied = factory.Copy(orig); + EXPECT_EQ(copied.id(), 3); + EXPECT_EQ(copied.name(), "bar"); + EXPECT_EQ(copied.value(), factory.NewIdent(4, "baz")); +} + +TEST(OptimizerExprFactory, CopyMapEntry) { + TestOptimizerExprFactory factory{Ast()}; + MapExprEntry orig = + factory.NewMapEntry(factory.NewIdent("bar"), factory.NewIdent("baz")); + MapExprEntry copied = factory.Copy(orig); + EXPECT_EQ(copied.id(), 4); + EXPECT_EQ(copied.key(), factory.NewIdent(5, "bar")); + EXPECT_EQ(copied.value(), factory.NewIdent(6, "baz")); +} + +TEST(OptimizerExprFactory, MergeSourceInfoMacroConflict) { + SourceInfo base_info; + base_info.mutable_macro_calls()[1] = Expr(); + + Ast base_ast(Expr(), std::move(base_info)); + TestOptimizerExprFactory factory{std::move(base_ast)}; + + SourceInfo new_info; + new_info.mutable_macro_calls()[1] = Expr(); + + factory.MergeSourceInfo(new_info); + + ASSERT_EQ(factory.issues().size(), 1); + EXPECT_EQ(factory.issues()[0].location, 1); + EXPECT_EQ(factory.issues()[0].message, "conflicting ID in macro calls merge"); +} + +} // namespace +} // namespace cel diff --git a/policy/test_custom_yaml_policy_parser.cc b/policy/test_custom_yaml_policy_parser.cc new file mode 100644 index 000000000..faced6952 --- 
/dev/null +++ b/policy/test_custom_yaml_policy_parser.cc @@ -0,0 +1,188 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "internal/status_macros.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_context.h" +#include "policy/cel_policy_parser.h" +#include "policy/yaml_policy_parser.h" +#include "yaml-cpp/node/node.h" +#include "yaml-cpp/yaml.h" // IWYU pragma: keep + +namespace cel::internal { + +// TestCustomYamlPolicyParser is used to support unit tests for custom tags +// and custom policy structures. It demonstrates the versatility of the +// cel::YamlPolicyParser framework API by implementing custom tag and block +// parsing without needing to modify the core parser. 
+class TestCustomYamlPolicyParser : public cel::YamlPolicyParser { + absl::StatusOr ParsePolicyTag(CelPolicyParseContext& ctx, + const ValueString& tag_name, + const YAML::Node& node) const override { + if (tag_name.value() == "name" || tag_name.value() == "description" || + tag_name.value() == "imports") { + return cel::YamlPolicyParser::ParsePolicyTag(ctx, tag_name, node); + } + if (tag_name.value() == "purpose") { + std::optional purpose = + GetValueString(ctx, node, "Policy purpose is not a string"); + if (purpose.has_value()) { + ctx.policy().mutable_metadata()["purpose"] = *purpose; + } + return true; + } + if (tag_name.value() == "version") { + std::optional version = + GetValueString(ctx, node, "Policy version is not a string"); + if (!version.has_value()) { + return true; + } + int version_int; + if (!absl::SimpleAtoi(version->value(), &version_int)) { + ctx.ReportError(version->id(), + absl::StrCat("Policy version is not an integer: ", + version->value())); + return true; + } + ctx.policy().mutable_metadata()["version"] = version_int; + return true; + } + + if (tag_name.value() == "conditions") { + if (!node.IsSequence()) { + ctx.ReportError(tag_name.id(), "Policy 'conditions' is not a sequence"); + return true; + } + for (const YAML::Node& condition : node) { + // Track the number of existing matches before parsing. When ParseMatch + // evaluates an 'else' block, it recursively triggers parsing and adds + // internal inner matches directly to the rule's match vector. + // Inserting the outer match at begin() + size_before ensures that the + // primary outer 'if' condition is always evaluated before its nested + // 'else' fallbacks. + // + // Example: + // if: x > 0 + // then: "positive" + // else: "negative" + // + // The inner "negative" match is parsed and appended to rule.matches() + // by the inner recursive call, before the outer "x > 0" match finishes. + // Inserting at size_before places the "x > 0" match ahead of the inner + // one. 
+ size_t size_before = ctx.policy().rule().matches().size(); + CEL_ASSIGN_OR_RETURN(Match match, + cel::YamlPolicyParser::ParseMatch( + ctx, condition, ctx.policy().mutable_rule())); + ctx.policy().mutable_rule().mutable_matches().insert( + ctx.policy().mutable_rule().mutable_matches().begin() + size_before, + std::move(match)); + } + + return true; + } + return false; + } + + absl::Status ParseThenBlock(CelPolicyParseContext& ctx, + const YAML::Node& value_node, + Match& match) const { + if (value_node.IsScalar()) { + std::optional val = GetValueString( + ctx, value_node, "Policy condition 'then' is not a string"); + if (val.has_value()) { + OutputBlock output; + output.set_output(*val); + match.set_result(output); + } + } else if (value_node.IsMap()) { + auto nested_rule = std::make_unique(); + CEL_ASSIGN_OR_RETURN( + Match nested_match, + cel::YamlPolicyParser::ParseMatch(ctx, value_node, *nested_rule)); + nested_rule->mutable_matches().insert( + nested_rule->mutable_matches().begin(), std::move(nested_match)); + match.set_result(std::move(nested_rule)); + } else { + ctx.ReportError(CollectMetadata(ctx, value_node), + "Bad syntax in 'if/then' block"); + } + return absl::OkStatus(); + } + + absl::Status ParseElseBlock(CelPolicyParseContext& ctx, + const YAML::Node& value_node, Rule& rule) const { + if (value_node.IsScalar()) { + std::optional val = GetValueString( + ctx, value_node, "Policy condition 'else' is not a string"); + if (val.has_value()) { + Match else_match; + else_match.set_id(CollectMetadata(ctx, value_node)); + OutputBlock output; + output.set_output(*val); + else_match.set_result(output); + rule.mutable_matches().push_back(std::move(else_match)); + } + } else if (value_node.IsMap()) { + size_t size_before = rule.matches().size(); + CEL_ASSIGN_OR_RETURN(Match match, cel::YamlPolicyParser::ParseMatch( + ctx, value_node, rule)); + rule.mutable_matches().insert( + rule.mutable_matches().begin() + size_before, std::move(match)); + } else { + 
ctx.ReportError(CollectMetadata(ctx, value_node), + "Bad syntax in 'if/then' block"); + } + return absl::OkStatus(); + } + + absl::StatusOr ParseMatchTag(CelPolicyParseContext& ctx, + const ValueString& tag_name, + const YAML::Node& node, Match& match, + Rule& rule) const override { + if (tag_name.value() == "if") { + std::optional condition = + GetValueString(ctx, node, "Policy 'if' condition is not a string"); + if (condition.has_value()) { + match.set_condition(*condition); + } + return true; + } + if (tag_name.value() == "then") { + CEL_RETURN_IF_ERROR(ParseThenBlock(ctx, node, match)); + return true; + } + if (tag_name.value() == "else") { + CEL_RETURN_IF_ERROR(ParseElseBlock(ctx, node, rule)); + return true; + } + return false; + } +}; + +const CelPolicyParser& GetTestCustomYamlPolicyParser() { + static const auto* const parser = new TestCustomYamlPolicyParser(); + return *parser; +} + +} // namespace cel::internal diff --git a/policy/test_util.cc b/policy/test_util.cc new file mode 100644 index 000000000..9fe1e43d1 --- /dev/null +++ b/policy/test_util.cc @@ -0,0 +1,221 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +#include "policy/test_util.h" + +#include +#include +#include +#include + +#include "cel/expr/eval.pb.h" +#include "cel/expr/value.pb.h" +#include "google/protobuf/struct.pb.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "internal/status_macros.h" +#include "yaml-cpp/yaml.h" + +namespace cel::test { + +namespace { + +absl::Status YamlToExprValue(const YAML::Node& node, + cel::expr::Value* proto) { + if (node.IsNull()) { + proto->set_null_value(google::protobuf::NULL_VALUE); + return absl::OkStatus(); + } + if (node.IsScalar()) { + // Try bool + try { + proto->set_bool_value(node.as()); + return absl::OkStatus(); + } catch (...) { + } + // Try int64 + try { + int64_t val; + if (YAML::convert::decode(node, val)) { + proto->set_int64_value(val); + return absl::OkStatus(); + } + } catch (...) { + } + // Try double + try { + double val; + if (YAML::convert::decode(node, val)) { + proto->set_double_value(val); + return absl::OkStatus(); + } + } catch (...) 
{ + } + // Fallback to string + proto->set_string_value(node.as()); + return absl::OkStatus(); + } + if (node.IsSequence()) { + auto* list = proto->mutable_list_value(); + for (const auto& elem : node) { + CEL_RETURN_IF_ERROR(YamlToExprValue(elem, list->add_values())); + } + return absl::OkStatus(); + } + if (node.IsMap()) { + auto* map_val = proto->mutable_map_value(); + for (auto it = node.begin(); it != node.end(); ++it) { + auto* entry = map_val->add_entries(); + CEL_RETURN_IF_ERROR(YamlToExprValue(it->first, entry->mutable_key())); + CEL_RETURN_IF_ERROR(YamlToExprValue(it->second, entry->mutable_value())); + } + return absl::OkStatus(); + } + return absl::InvalidArgumentError("Unknown YAML node type"); +} + +absl::Status ParseInputValue( + const YAML::Node& node, + cel::expr::conformance::test::InputValue* input_val) { + if (node.IsMap() && node["expr"].IsDefined()) { + input_val->set_expr(node["expr"].as()); + return absl::OkStatus(); + } + if (node.IsMap() && node["value"].IsDefined()) { + return YamlToExprValue(node["value"], input_val->mutable_value()); + } + return YamlToExprValue(node, input_val->mutable_value()); +} + +absl::Status ParseTestOutput(const YAML::Node& node, + cel::expr::conformance::test::TestOutput* output) { + if (!node.IsDefined()) { + return absl::InvalidArgumentError("Missing output node"); + } + if (node.IsMap()) { + if (node["expr"].IsDefined()) { + output->set_result_expr(node["expr"].as()); + return absl::OkStatus(); + } + if (node["value"].IsDefined()) { + return YamlToExprValue(node["value"], output->mutable_result_value()); + } + if (node["error"].IsDefined()) { + auto* eval_error = output->mutable_eval_error(); + eval_error->add_errors()->set_message(node["error"].as()); + return absl::OkStatus(); + } + if (node["error_set"].IsDefined()) { + auto* eval_error = output->mutable_eval_error(); + for (const auto& err : node["error_set"]) { + eval_error->add_errors()->set_message(err.as()); + } + return absl::OkStatus(); + } + if 
(node["unknown"].IsDefined()) { + auto* unknown = output->mutable_unknown(); + for (const auto& expr_id_node : node["unknown"]) { + unknown->add_exprs(expr_id_node.as()); + } + return absl::OkStatus(); + } + } + return YamlToExprValue(node, output->mutable_result_value()); +} + +absl::StatusOr +ParsePolicyTestSuiteYamlImpl(absl::string_view yaml_content) { + YAML::Node tests_node; + try { + tests_node = YAML::Load(std::string(yaml_content)); + } catch (const std::exception& e) { + return absl::InvalidArgumentError( + absl::StrCat("Failed to parse YAML: ", e.what())); + } + + cel::expr::conformance::test::TestSuite test_suite; + if (tests_node["description"].IsDefined()) { + test_suite.set_description(tests_node["description"].as()); + } + + YAML::Node sections = tests_node["sections"]; + if (!sections.IsDefined()) { + sections = tests_node["section"]; // support singular format + } + if (!sections.IsDefined()) { + return absl::InvalidArgumentError( + "Missing 'sections' or 'section' in tests YAML"); + } + + for (const auto& section_node : sections) { + auto* section = test_suite.add_sections(); + if (section_node["name"].IsDefined()) { + section->set_name(section_node["name"].as()); + } + if (section_node["description"].IsDefined()) { + section->set_description(section_node["description"].as()); + } + + YAML::Node tests = section_node["tests"]; + if (!tests.IsDefined()) { + tests = section_node["test"]; // support singular format + } + if (!tests.IsDefined()) { + continue; + } + + for (const auto& test_node : tests) { + auto* test_case = section->add_tests(); + if (test_node["name"].IsDefined()) { + test_case->set_name(test_node["name"].as()); + } + if (test_node["description"].IsDefined()) { + test_case->set_description(test_node["description"].as()); + } + if (test_node["context_expr"].IsDefined()) { + test_case->mutable_input_context()->set_context_expr( + test_node["context_expr"].as()); + } + + YAML::Node input_node = test_node["input"]; + if 
(input_node.IsDefined() && input_node.IsMap()) { + auto* input_map = test_case->mutable_input(); + for (auto it = input_node.begin(); it != input_node.end(); ++it) { + std::string var_name = it->first.as(); + cel::expr::conformance::test::InputValue input_val; + CEL_RETURN_IF_ERROR(ParseInputValue(it->second, &input_val)); + (*input_map)[var_name] = std::move(input_val); + } + } + + YAML::Node output_node = test_node["output"]; + if (output_node.IsDefined()) { + CEL_RETURN_IF_ERROR( + ParseTestOutput(output_node, test_case->mutable_output())); + } + } + } + + return test_suite; +} + +} // namespace + +absl::StatusOr +ParsePolicyTestSuiteYaml(absl::string_view yaml_content) { + try { + return ParsePolicyTestSuiteYamlImpl(yaml_content); + } catch (...) { + return absl::InvalidArgumentError("Failed to parse YAML"); + } +} + +} // namespace cel::test diff --git a/policy/test_util.h b/policy/test_util.h new file mode 100644 index 000000000..5fe306050 --- /dev/null +++ b/policy/test_util.h @@ -0,0 +1,33 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef THIRD_PARTY_CEL_CPP_POLICY_TEST_UTIL_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_TEST_UTIL_H_ + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "cel/expr/conformance/test/suite.pb.h" + +namespace cel::test { + +// Parses a YAML content representing a policy test suite (tests.yaml) +// and adapts it to the cel.expr.conformance.test.TestSuite protobuf message. +// +// TODO(uncreated-issue/92): Move to the testrunner library. +absl::StatusOr +ParsePolicyTestSuiteYaml(absl::string_view yaml_content); + +} // namespace cel::test + +#endif // THIRD_PARTY_CEL_CPP_POLICY_TEST_UTIL_H_ diff --git a/policy/testdata/BUILD b/policy/testdata/BUILD new file mode 100644 index 000000000..10a26fa0b --- /dev/null +++ b/policy/testdata/BUILD @@ -0,0 +1,19 @@ +package( + default_testonly = True, + default_visibility = ["//visibility:public"], +) + +filegroup( + name = "policy_testdata", + srcs = glob([ + "*.yaml", + "*.baseline", + ]), +) + +exports_files( + srcs = glob([ + "*.yaml", + "*.baseline", + ]), +) diff --git a/policy/testdata/cel_policy.yaml b/policy/testdata/cel_policy.yaml new file mode 100644 index 000000000..010ad8855 --- /dev/null +++ b/policy/testdata/cel_policy.yaml @@ -0,0 +1,42 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# Environment: +# spec: TestAllTypes +name: cel_policy +description: A test policy for CEL +display_name: Cel Policy +imports: +- name: cel.expr.conformance.proto3.TestAllTypes +- name: cel.expr.conformance.proto3.TestAllTypes.NestedEnum +rule: + id: test_rule + description: test rule description + variables: + - name: test_var + expression: > + TestAllTypes{single_int64: 10}.single_int64 + match: + - condition: > + spec.single_int32 > TestAllTypes{single_int64: 10}.single_int64 + output: | + "invalid spec, got single_int32=" + string(spec.single_int32) + ", wanted <= 10" + explanation: | + "invalid spec, spec is greater than 10" + - condition: > + spec.standalone_enum == NestedEnum.BAR + output: | + "invalid spec, reference to BAR is not allowed" + - condition: spec.single_int64 == variables.test_var + output: '"invalid spec: exactly matches test_var"' + explanation: '"the spec cannot have single_int64 set to a known bad value"' \ No newline at end of file diff --git a/policy/testdata/cel_policy_parser.baseline b/policy/testdata/cel_policy_parser.baseline new file mode 100644 index 000000000..7a6678bfe --- /dev/null +++ b/policy/testdata/cel_policy_parser.baseline @@ -0,0 +1,89 @@ +POLICY SOURCE: cel_policy.yaml +-------------------------------------------------------------------- +PARSED POLICY: +CelPolicy{ + =========================================================== + # Copyright 2026 Google LLC + # + # Licensed under the Apache License, Version 2.0 (the "License"); + # you may not use this file except in compliance with the License. + # You may obtain a copy of the License at + # + # https://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. 
+ # Environment: + # spec: TestAllTypes + #0> name: #1> cel_policy + #2> description: #3> A test policy for CEL + #4> display_name: #5> Cel Policy + #6> imports: + - #7> name: #8> cel.expr.conformance.proto3.TestAllTypes + - #9> name: #10> cel.expr.conformance.proto3.TestAllTypes.NestedEnum + #11> rule: + #13> #12> id: #14> test_rule + #15> description: #16> test rule description + #17> variables: + - #18> name: #19> test_var + #20> expression: #21> > + TestAllTypes{single_int64: 10}.single_int64 + #22> match: + - #24> #23> condition: #25> > + spec.single_int32 > TestAllTypes{single_int64: 10}.single_int64 + #26> output: #27> | + "invalid spec, got single_int32=" + string(spec.single_int32) + ", wanted <= 10" + #28> explanation: #29> | + "invalid spec, spec is greater than 10" + - #31> #30> condition: #32> > + spec.standalone_enum == NestedEnum.BAR + #33> output: #34> | + "invalid spec, reference to BAR is not allowed" + - #36> #35> condition: #37> spec.single_int64 == variables.test_var + #38> output: #39> '"invalid spec: exactly matches test_var"' + #40> explanation: #41> '"the spec cannot have single_int64 set to a known bad value"' + =========================================================== + name: #1> "cel_policy" + description: #3> "A test policy for CEL" + display_name: #5> "Cel Policy" + imports: + #7> name: #8> "cel.expr.conformance.proto3.TestAllTypes" + #9> name: #10> "cel.expr.conformance.proto3.TestAllTypes.NestedEnum" + #12> rule: { + rule_id: #14> "test_rule" + description: #16> "test rule description" + variable: { + name: #19> "test_var" + expression: #21> "TestAllTypes{single_int64: 10}.single_int64 + " + } + #23> match: { + condition: #25> "spec.single_int32 > TestAllTypes{single_int64: 10}.single_int64 + " + result: { + output: #27> ""invalid spec, got single_int32=" + string(spec.single_int32) + ", wanted <= 10" + " + explanation: #29> ""invalid spec, spec is greater than 10" + " + } + } + #30> match: { + condition: #32> "spec.standalone_enum 
== NestedEnum.BAR + " + result: { + output: #34> ""invalid spec, reference to BAR is not allowed" + " + } + } + #35> match: { + condition: #37> "spec.single_int64 == variables.test_var" + result: { + output: #39> ""invalid spec: exactly matches test_var"" + explanation: #41> ""the spec cannot have single_int64 set to a known bad value"" + } + } + } +} diff --git a/policy/testdata/custom_policy_format.yaml b/policy/testdata/custom_policy_format.yaml new file mode 100644 index 000000000..a67356906 --- /dev/null +++ b/policy/testdata/custom_policy_format.yaml @@ -0,0 +1,29 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+name: cel_policy_custom_tags +description: A custom policy format +imports: +- name: cel.expr.conformance.proto3.TestAllTypes +purpose: test +version: 42 +conditions: +- if: spec.single_string == "none" + then: "'zero'" + else: + if: spec.single_string == "integer" + then: + if: spec.single_int32 > 0 + then: "'positive integer'" + else: "'negative integer'" + else: "'not an integer'" diff --git a/policy/testdata/custom_policy_format_parser.baseline b/policy/testdata/custom_policy_format_parser.baseline new file mode 100644 index 000000000..d5b1a2235 --- /dev/null +++ b/policy/testdata/custom_policy_format_parser.baseline @@ -0,0 +1,75 @@ +POLICY SOURCE: custom_policy_format.yaml +-------------------------------------------------------------------- +PARSED POLICY: +CelPolicy{ + =========================================================== + # Copyright 2026 Google LLC + # + # Licensed under the Apache License, Version 2.0 (the "License"); + # you may not use this file except in compliance with the License. + # You may obtain a copy of the License at + # + # https://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. 
+ #0> name: #1> cel_policy_custom_tags + #2> description: #3> A custom policy format + #4> imports: + - #5> name: #6> cel.expr.conformance.proto3.TestAllTypes + #7> purpose: #8> test + #9> version: #10> 42 + #11> conditions: + - #13> #12> if: #14> spec.single_string == "none" + #15> then: #16> "'zero'" + #17> else: + #19> #18> if: #20> spec.single_string == "integer" + #21> then: + #23> #22> if: #24> spec.single_int32 > 0 + #25> then: #26> "'positive integer'" + #27> else: #29> #28> "'negative integer'" + #30> else: #32> #31> "'not an integer'" + + =========================================================== + name: #1> "cel_policy_custom_tags" + description: #3> "A custom policy format" + metadata: { + purpose: #8> "test" + version: 42 + } + imports: + #5> name: #6> "cel.expr.conformance.proto3.TestAllTypes" + rule: { + #12> match: { + condition: #14> "spec.single_string == "none"" + result: { + output: #16> "'zero'" + } + } + #18> match: { + condition: #20> "spec.single_string == "integer"" + result: + rule: { + #22> match: { + condition: #24> "spec.single_int32 > 0" + result: { + output: #26> "'positive integer'" + } + } + #29> match: { + result: { + output: #28> "'negative integer'" + } + } + } + } + #32> match: { + result: { + output: #31> "'not an integer'" + } + } + } +} diff --git a/policy/testdata/custom_policy_format_with_errors.yaml b/policy/testdata/custom_policy_format_with_errors.yaml new file mode 100644 index 000000000..594747c60 --- /dev/null +++ b/policy/testdata/custom_policy_format_with_errors.yaml @@ -0,0 +1,33 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +name: cel_policy_custom_tags +description: A custom policy format +imports: +- name: cel.expr.conformance.proto3.TestAllTypes +purpose: + - testing +version: new +conditions: +- if: + spec.single_string: "none" + then: "'zero'" + else: "'not zero'" +- if: spec.single_string == "number" + then: + if: spec.single_int32 > 0 + then: "'positive integer'" + else: + - ignore +- else: "'negative integer'" + diff --git a/policy/testdata/custom_policy_format_with_errors_parser.baseline b/policy/testdata/custom_policy_format_with_errors_parser.baseline new file mode 100644 index 000000000..978d27bda --- /dev/null +++ b/policy/testdata/custom_policy_format_with_errors_parser.baseline @@ -0,0 +1,16 @@ +POLICY SOURCE: custom_policy_format_with_errors.yaml +-------------------------------------------------------------------- +-------------------------------------------------------------------- +PARSER ISSUES: +ERROR: custom_policy_format_with_errors.yaml:19:3: Policy purpose is not a string + | - testing + | ..^ +ERROR: custom_policy_format_with_errors.yaml:20:10: Policy version is not an integer: new + | version: new + | .........^ +ERROR: custom_policy_format_with_errors.yaml:23:5: Policy 'if' condition is not a string + | spec.single_string: "none" + | ....^ +ERROR: custom_policy_format_with_errors.yaml:31:7: Bad syntax in 'if/then' block + | - ignore + | ......^ diff --git a/policy/testdata/nested_rule.yaml b/policy/testdata/nested_rule.yaml new file mode 100644 index 000000000..2b07faa64 --- /dev/null +++ b/policy/testdata/nested_rule.yaml @@ -0,0 +1,37 @@ +# 
Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +name: nested_rule +rule: + variables: + - name: "permitted_regions" + expression: "['us', 'uk', 'es']" + match: + - rule: + id: "banned regions" + description: > + determine whether the resource origin is in the banned + list. If the region is also in the permitted list, the + ban has no effect. + variables: + - name: "banned_regions" + expression: "{'us': false, 'ru': false, 'ir': false}" + match: + - condition: | + resource.origin in variables.banned_regions && + !(resource.origin in variables.permitted_regions) + output: "{'banned': true}" + - condition: resource.origin in variables.permitted_regions + output: "{'banned': false}" + - output: "{'banned': true}" + explanation: "'resource is in the banned region ' + resource.origin" \ No newline at end of file diff --git a/policy/testdata/nested_rule_parser.baseline b/policy/testdata/nested_rule_parser.baseline new file mode 100644 index 000000000..128f81bda --- /dev/null +++ b/policy/testdata/nested_rule_parser.baseline @@ -0,0 +1,84 @@ +POLICY SOURCE: nested_rule.yaml +-------------------------------------------------------------------- +PARSED POLICY: +CelPolicy{ + =========================================================== + # Copyright 2024 Google LLC + # + # Licensed under the Apache License, Version 2.0 (the "License"); + # you may not use this file except in compliance with the License. 
+ # You may obtain a copy of the License at + # + # https://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. + #0> name: #1> nested_rule + #2> rule: + #4> #3> variables: + - #5> name: #6> "permitted_regions" + #7> expression: #8> "['us', 'uk', 'es']" + #9> match: + - #11> #10> rule: + #13> #12> id: #14> "banned regions" + #15> description: #16> > + determine whether the resource origin is in the banned + list. If the region is also in the permitted list, the + ban has no effect. + #17> variables: + - #18> name: #19> "banned_regions" + #20> expression: #21> "{'us': false, 'ru': false, 'ir': false}" + #22> match: + - #24> #23> condition: #25> | + resource.origin in variables.banned_regions && + !(resource.origin in variables.permitted_regions) + #26> output: #27> "{'banned': true}" + - #29> #28> condition: #30> resource.origin in variables.permitted_regions + #31> output: #32> "{'banned': false}" + - #34> #33> output: #35> "{'banned': true}" + #36> explanation: #37> "'resource is in the banned region ' + resource.origin" + =========================================================== + name: #1> "nested_rule" + description: "nested_rule.yaml" + #3> rule: { + variable: { + name: #6> "permitted_regions" + expression: #8> "['us', 'uk', 'es']" + } + #10> match: { + result: + #12> rule: { + rule_id: #14> "banned regions" + description: #16> "determine whether the resource origin is in the banned list. If the region is also in the permitted list, the ban has no effect. 
+ " + variable: { + name: #19> "banned_regions" + expression: #21> "{'us': false, 'ru': false, 'ir': false}" + } + #23> match: { + condition: #25> "resource.origin in variables.banned_regions && + !(resource.origin in variables.permitted_regions) + " + result: { + output: #27> "{'banned': true}" + } + } + } + } + #28> match: { + condition: #30> "resource.origin in variables.permitted_regions" + result: { + output: #32> "{'banned': false}" + } + } + #33> match: { + result: { + output: #35> "{'banned': true}" + explanation: #37> "'resource is in the banned region ' + resource.origin" + } + } + } +} diff --git a/policy/yaml_policy_parser.cc b/policy/yaml_policy_parser.cc new file mode 100644 index 000000000..c838cff33 --- /dev/null +++ b/policy/yaml_policy_parser.cc @@ -0,0 +1,411 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "policy/yaml_policy_parser.h"
+
+#include
+#include
+#include
+#include
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "common/source.h"
+#include "internal/status_macros.h"
+#include "policy/cel_policy.h"
+#include "policy/cel_policy_parse_context.h"
+#include "policy/cel_policy_parse_result.h"
+#include "policy/cel_policy_parser.h"
+#include "yaml-cpp/exceptions.h"
+#include "yaml-cpp/node/node.h"
+#include "yaml-cpp/node/parse.h"
+#include "yaml-cpp/null.h"
+#include "yaml-cpp/yaml.h"  // IWYU pragma: keep
+
+namespace cel {
+
+// Conventions used throughout this file:
+//  * Problems in the policy text are recorded on `ctx` with ReportError() and
+//    parsing continues; the functions still return OK in that case. A non-OK
+//    status only surfaces when a nested Parse* call returns one.
+//  * A new element id is requested from ctx.next_element_id() every time
+//    CollectMetadata() runs (GetValueString() calls it too), so reordering
+//    those calls changes which id each YAML element receives. The golden
+//    `*.baseline` files print these ids.
+
+// Requests a new element id and, when the node carries a source mark, records
+// the node's position under that id so issues can point at it.
+CelPolicyElementId YamlPolicyParser::CollectMetadata(
+    CelPolicyParseContext& ctx, const YAML::Node& node) const {
+  CelPolicyElementId element_id = ctx.next_element_id();
+  if (!node.Mark().is_null()) {
+    ctx.policy_source().NoteSourcePosition(element_id, node.Mark().pos);
+  }
+  return element_id;
+}
+
+// Reads `node` as a scalar and returns it as a ValueString tagged with a newly
+// collected element id.
+//
+// Returns std::nullopt when:
+//  * the node is undefined (nothing is reported, no id is consumed), or
+//  * the node is not a scalar (`error_message` is reported against its id), or
+//  * the scalar-to-string conversion throws (nothing is reported).
+std::optional YamlPolicyParser::GetValueString(
+    CelPolicyParseContext& ctx, const YAML::Node& node,
+    std::string_view error_message) const {
+  if (!node.IsDefined()) {
+    // This should never happen since the YAML syntax has already been checked.
+    return std::nullopt;
+  }
+
+  CelPolicyElementId id = CollectMetadata(ctx, node);
+  if (!node.IsScalar()) {
+    ctx.ReportError(id, error_message);
+    return std::nullopt;
+  }
+
+  try {
+    return ValueString(id, node.as());
+  // NOTE(review): `e` is unused and the exception is caught by non-const
+  // reference; `catch (const YAML::Exception&)` would express the intent.
+  } catch (YAML::Exception& e) {
+    // This should never happen since we already checked that the node is a
+    // scalar and all scalars can be converted to strings.
+    return std::nullopt;
+  }
+}
+
+// Parses the policy source held by `ctx` and dispatches each top-level
+// key/value pair to ParsePolicyTag(). Malformed input is reported on `ctx`;
+// the returned status is OK unless a nested parser returns an error.
+absl::Status YamlPolicyParser::ParsePolicy(CelPolicyParseContext& ctx) const {
+  const Source* source = ctx.policy_source().content();
+  if (source == nullptr) {
+    // Nothing to parse; this is not reported as an issue.
+    return absl::OkStatus();
+  }
+
+  // Default the policy description to the source's own description; an
+  // explicit `description` tag overwrites it below.
+  // NOTE(review): the id -1 presumably means "no source element" - confirm.
+  ctx.policy().set_description(ValueString(-1, source->description()));
+  std::string text = source->content().ToString();
+  YAML::Node node;
+  try {
+    node = YAML::Load(text);
+  } catch (YAML::Exception& e) {
+    // A syntax error is reported against element id 0, positioned at the mark
+    // carried by the exception when there is one.
+    if (!e.mark.is_null()) {
+      ctx.policy_source().NoteSourcePosition(0, e.mark.pos);
+    }
+    ctx.ReportError(0, "Invalid CEL policy YAML syntax");
+    return absl::OkStatus();
+  }
+
+  if (!node.IsMap()) {
+    ctx.ReportError(CollectMetadata(ctx, node), "Policy is not a map");
+    return absl::OkStatus();
+  }
+
+  for (auto it = node.begin(); it != node.end(); ++it) {
+    const YAML::Node key_node = it->first;
+    const YAML::Node value_node = it->second;
+    std::optional key =
+        GetValueString(ctx, key_node, "Policy tag is not a string");
+    if (!key.has_value()) {
+      // The bad key has already been reported; skip its value.
+      continue;
+    }
+    CEL_ASSIGN_OR_RETURN(bool handled, ParsePolicyTag(ctx, *key, value_node));
+    if (!handled) {
+      ctx.ReportError(
+          key->id(),
+          absl::StrCat("Unrecognized top-level policy tag: ", key->value()));
+    }
+  }
+
+  return absl::OkStatus();
+}
+
+// Handles the top-level tags `imports`, `name`, `description`, `display_name`
+// and `rule`. Returns true when the tag is one of these (even if its value
+// was invalid and only produced an issue), false for any other tag so the
+// caller can report it as unrecognized.
+absl::StatusOr YamlPolicyParser::ParsePolicyTag(
+    CelPolicyParseContext& ctx, const ValueString& tag_name,
+    const YAML::Node& node) const {
+  if (tag_name.value() == "imports") {
+    CEL_RETURN_IF_ERROR(ParseImports(ctx, node));
+    return true;
+  }
+  if (tag_name.value() == "name") {
+    std::optional name =
+        GetValueString(ctx, node, "Policy 'name' is not a string");
+    if (name.has_value()) {
+      ctx.policy().set_name(*name);
+    }
+    return true;
+  }
+  if (tag_name.value() == "description") {
+    std::optional description =
+        GetValueString(ctx, node, "Policy 'description' is not a string");
+    if (description.has_value()) {
+      ctx.policy().set_description(*description);
+    }
+    return true;
+  }
+  if (tag_name.value() == "display_name") {
+    std::optional display_name =
+        GetValueString(ctx, node, "Policy 'display_name' is not a string");
+    if (display_name.has_value()) {
+      ctx.policy().set_display_name(*display_name);
+    }
+    return true;
+  }
+  if (tag_name.value() == "rule") {
+    CEL_RETURN_IF_ERROR(ParseRule(ctx, node, ctx.policy().mutable_rule()));
+    return true;
+  }
+  return false;
+}
+
+// Parses the `imports` sequence. Each entry must be a map with a string
+// `name`; valid entries are appended to the policy's imports, invalid ones
+// are reported and skipped.
+absl::Status YamlPolicyParser::ParseImports(CelPolicyParseContext& ctx,
+                                            const YAML::Node& node) const {
+  if (!node.IsSequence()) {
+    ctx.ReportError(CollectMetadata(ctx, node),
+                    "Policy 'imports' is not a sequence");
+    return absl::OkStatus();
+  }
+
+  for (const YAML::Node& import : node) {
+    // The id of the import entry itself is collected before the id of its
+    // name value.
+    CelPolicyElementId import_id = CollectMetadata(ctx, import);
+    if (!import.IsMap()) {
+      ctx.ReportError(import_id, "Import is not a map");
+      continue;
+    }
+    const YAML::Node& name_node = import["name"];
+    if (!name_node.IsDefined()) {
+      ctx.ReportError(import_id, "No 'name' tag in import");
+      continue;
+    }
+    std::optional import_name =
+        GetValueString(ctx, name_node, "Import name is not a string");
+    if (import_name.has_value()) {
+      ctx.policy().mutable_imports().push_back(Import(import_id, *import_name));
+    }
+  }
+  return absl::OkStatus();
+}
+
+// Parses a rule map (the top-level rule or a rule nested in a match) into
+// `rule`, dispatching each key/value pair to ParseRuleTag().
+absl::Status YamlPolicyParser::ParseRule(CelPolicyParseContext& ctx,
+                                         const YAML::Node& node,
+                                         Rule& rule) const {
+  if (!node.IsMap()) {
+    ctx.ReportError(CollectMetadata(ctx, node), "Policy 'rule' is not a map");
+    return absl::OkStatus();
+  }
+  // The rule's own id is taken before any of its keys are visited.
+  rule.set_id(CollectMetadata(ctx, node));
+
+  for (auto it = node.begin(); it != node.end(); ++it) {
+    const YAML::Node key_node = it->first;
+    const YAML::Node value_node = it->second;
+    std::optional key =
+        GetValueString(ctx, key_node, "Policy rule tag is not a string");
+    if (!key.has_value()) {
+      continue;
+    }
+    CEL_ASSIGN_OR_RETURN(bool handled,
+                         ParseRuleTag(ctx, *key, value_node, rule));
+    if (!handled) {
+      ctx.ReportError(key->id(), absl::StrCat("Unrecognized policy rule tag: ",
+                                              key->value()));
+    }
+  }
+  return absl::OkStatus();
+}
+
+// Handles the rule tags `id`, `description`, `variables` and `match`.
+// Returns true when the tag is one of these, false otherwise.
+absl::StatusOr YamlPolicyParser::ParseRuleTag(CelPolicyParseContext& ctx,
+                                              const ValueString& tag_name,
+                                              const YAML::Node& node,
+                                              Rule& rule) const {
+  if (tag_name.value() == "id") {
+    std::optional rule_id =
+        GetValueString(ctx, node, "Policy rule 'id' is not a string");
+    if (rule_id.has_value()) {
+      rule.set_rule_id(*rule_id);
+    }
+    return true;
+  }
+  if (tag_name.value() == "description") {
+    std::optional description =
+        GetValueString(ctx, node, "Policy rule 'description' is not a string");
+    if (description.has_value()) {
+      rule.set_description(*description);
+    }
+    return true;
+  }
+  if (tag_name.value() == "variables") {
+    if (!node.IsSequence()) {
+      ctx.ReportError(CollectMetadata(ctx, node),
+                      "Policy rule 'variables' is not a sequence");
+      return true;
+    }
+    // Every entry is appended, including a default-constructed Variable for
+    // an entry that was reported as malformed.
+    for (const YAML::Node& variable_node : node) {
+      CEL_ASSIGN_OR_RETURN(Variable variable,
+                           ParseVariable(ctx, variable_node, rule));
+      rule.mutable_variables().push_back(std::move(variable));
+    }
+    return true;
+  }
+  if (tag_name.value() == "match") {
+    if (!node.IsSequence()) {
+      ctx.ReportError(CollectMetadata(ctx, node),
+                      "Policy rule 'match' is not a sequence");
+      return true;
+    }
+    for (const YAML::Node& match_node : node) {
+      CEL_ASSIGN_OR_RETURN(Match match, ParseMatch(ctx, match_node, rule));
+      rule.mutable_matches().push_back(std::move(match));
+    }
+    return true;
+  }
+  return false;
+}
+
+// Parses one entry of a rule's `variables` sequence. A non-map entry is
+// reported and yields a default-constructed Variable.
+// `rule` is not used by this implementation.
+absl::StatusOr YamlPolicyParser::ParseVariable(
+    CelPolicyParseContext& ctx, const YAML::Node& node, Rule& rule) const {
+  Variable variable;
+  if (!node.IsMap()) {
+    ctx.ReportError(CollectMetadata(ctx, node),
+                    "Policy rule 'variable' is not a map");
+    return variable;
+  }
+  for (auto it = node.begin(); it != node.end(); ++it) {
+    const YAML::Node key_node = it->first;
+    const YAML::Node value_node = it->second;
+    std::optional key =
+        GetValueString(ctx, key_node, "Policy variable tag is not a string");
+    if (!key.has_value()) {
+      continue;
+    }
+    CEL_ASSIGN_OR_RETURN(bool handled,
+                         ParseVariableTag(ctx, *key, value_node, variable));
+    if (!handled) {
+      ctx.ReportError(
+          key->id(),
+          absl::StrCat("Unrecognized policy variable tag: ", key->value()));
+    }
+  }
+  return variable;
+}
+
+// Handles the variable tags `name` and `expression`. Returns true when the
+// tag is one of these, false otherwise.
+absl::StatusOr YamlPolicyParser::ParseVariableTag(
+    CelPolicyParseContext& ctx, const ValueString& tag_name,
+    const YAML::Node& node, Variable& variable) const {
+  if (tag_name.value() == "name") {
+    std::optional name =
+        GetValueString(ctx, node, "Policy variable 'name' is not a string");
+    if (name.has_value()) {
+      variable.set_name(*name);
+    }
+    return true;
+  }
+  if (tag_name.value() == "expression") {
+    std::optional expression = GetValueString(
+        ctx, node, "Policy variable 'expression' is not a string");
+    if (expression.has_value()) {
+      variable.set_expression(*expression);
+    }
+    return true;
+  }
+  return false;
+}
+
+// Parses one entry of a rule's `match` sequence. The match id is collected
+// before the map check, so even a malformed entry carries a position for its
+// error. `rule` is only forwarded to ParseMatchTag().
+absl::StatusOr YamlPolicyParser::ParseMatch(CelPolicyParseContext& ctx,
+                                            const YAML::Node& node,
+                                            Rule& rule) const {
+  Match match;
+  match.set_id(CollectMetadata(ctx, node));
+  if (!node.IsMap()) {
+    ctx.ReportError(match.id(), "Policy rule 'match' is not a map");
+    return match;
+  }
+  for (auto it = node.begin(); it != node.end(); ++it) {
+    const YAML::Node key_node = it->first;
+    const YAML::Node value_node = it->second;
+    std::optional key =
+        GetValueString(ctx, key_node, "Policy match tag is not a string");
+    if (!key.has_value()) {
+      continue;
+    }
+    CEL_ASSIGN_OR_RETURN(bool handled,
+                         ParseMatchTag(ctx, *key, value_node, match, rule));
+    if (!handled) {
+      ctx.ReportError(key->id(), absl::StrCat("Unrecognized policy match tag: ",
+                                              key->value()));
+    }
+  }
+
+  // This check runs only after all tags were seen, because `explanation` may
+  // appear before `output` in the YAML map.
+  if (match.has_output_block()) {
+    if (match.output_block().output().value().empty() &&
+        match.output_block().explanation().has_value()) {
+      ctx.ReportError(match.id(), "Match specifies explanation but no output");
+    }
+  }
+
+  return match;
+}
+
+// Handles the match tags `condition`, `explanation`, `output` and `rule`.
+// An output block (`output`/`explanation`) and a nested `rule` are mutually
+// exclusive; whichever conflicting tag comes second is reported against the
+// tag's own id. Returns true when the tag is one of the four, else false.
+// `rule` (the enclosing rule) is not used by this implementation.
+absl::StatusOr YamlPolicyParser::ParseMatchTag(
+    CelPolicyParseContext& ctx, const ValueString& tag_name,
+    const YAML::Node& node, Match& match, Rule& rule) const {
+  if (tag_name.value() == "condition") {
+    std::optional condition =
+        GetValueString(ctx, node, "Policy match 'condition' is not a string");
+    if (condition.has_value()) {
+      match.set_condition(*condition);
+    }
+    return true;
+  }
+  if (tag_name.value() == "explanation") {
+    std::optional explanation =
+        GetValueString(ctx, node, "Policy match 'explanation' is not a string");
+    if (explanation.has_value()) {
+      if (match.has_rule()) {
+        ctx.ReportError(
+            tag_name.id(),
+            "Cannot specify explanation when a nested rule is present");
+      } else {
+        match.mutable_output_block().set_explanation(*explanation);
+      }
+    }
+    return true;
+  }
+  if (tag_name.value() == "output") {
+    std::optional output =
+        GetValueString(ctx, node, "Policy match 'output' is not a string");
+    if (output.has_value()) {
+      if (match.has_rule()) {
+        ctx.ReportError(tag_name.id(),
+                        "Cannot specify output when a nested rule is present");
+      } else {
+        match.mutable_output_block().set_output(*output);
+      }
+    }
+    return true;
+  }
+  if (tag_name.value() == "rule") {
+    if (match.has_output_block()) {
+      ctx.ReportError(tag_name.id(),
+                      "Cannot specify nested rule when output/explanation is "
+                      "present");
+    }
+    // NOTE(review): unlike the `output`/`explanation` branches there is no
+    // `else` here - the nested rule is still parsed (so its own issues get
+    // reported) and then stored via set_result() even after the conflict was
+    // reported. Presumably harmless because the policy is already invalid;
+    // confirm.
+    auto nested_rule = std::make_unique();
+    CEL_RETURN_IF_ERROR(ParseRule(ctx, node, *nested_rule));
+    match.set_result(std::move(nested_rule));
+    return true;
+  }
+  return false;
+}
+
+// Returns a process-wide YamlPolicyParser. The instance is created on first
+// use and never deleted.
+const CelPolicyParser& GetDefaultYamlPolicyParser() {
+  static const auto* const parser = new YamlPolicyParser();
+  return *parser;
+}
+
+// Parses `policy_source` with the default YAML policy parser.
+absl::StatusOr ParseYamlCelPolicy(
+    std::shared_ptr policy_source) {
+  return ParseYamlCelPolicy(std::move(policy_source),
+                            GetDefaultYamlPolicyParser());
+}
+
+// Runs `parser` over a new parse context built from `policy_source` and
+// returns the context's result. A non-OK status is returned only when
+// `parser.ParsePolicy()` itself fails.
+absl::StatusOr ParseYamlCelPolicy(
+    std::shared_ptr policy_source,
+    const CelPolicyParser& parser) {
+  CelPolicyParseContext ctx(std::move(policy_source));
+  CEL_RETURN_IF_ERROR(parser.ParsePolicy(ctx));
+  return ctx.GetResult();
+}
+
+}  // namespace cel
diff --git
a/policy/yaml_policy_parser.h b/policy/yaml_policy_parser.h
new file mode 100644
index 000000000..469209333
--- /dev/null
+++ b/policy/yaml_policy_parser.h
@@ -0,0 +1,135 @@
+// Copyright 2026 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef THIRD_PARTY_CEL_CPP_POLICY_YAML_POLICY_PARSER_H_
+#define THIRD_PARTY_CEL_CPP_POLICY_YAML_POLICY_PARSER_H_
+
+#include
+#include
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "policy/cel_policy.h"
+#include "policy/cel_policy_parse_context.h"
+#include "policy/cel_policy_parse_result.h"
+#include "policy/cel_policy_parser.h"
+#include "yaml-cpp/node/node.h"
+
+namespace cel {
+
+// A parser for YAML-based CEL policies.
+//
+// To support additional or alternative YAML elements, subclass
+// `YamlPolicyParser` and override specific parsing methods, `Parse*`.
+class YamlPolicyParser : public CelPolicyParser {
+ public:
+  // Reads `node` as a scalar string and returns it as a ValueString tagged
+  // with a newly collected element ID.
+  //
+  // Returns std::nullopt if the node is undefined, or if it is not a scalar,
+  // in which case `error_message` is reported against the node.
+  std::optional GetValueString(
+      CelPolicyParseContext& ctx, const YAML::Node& node,
+      std::string_view error_message) const;
+
+  // Parses the policy source held by `ctx`, storing the parsed policy and any
+  // issues in the context.
+  //
+  // Note that an OkStatus does not necessarily mean that parsing was successful
+  // - only that it could run to completion.
+  absl::Status ParsePolicy(CelPolicyParseContext& ctx) const override;
+
+ protected:
+  // Collects metadata (e.g. source position) for the given YAML node, stores it
+  // in the context, and returns an ID that can be used to refer to it.
+  virtual CelPolicyElementId CollectMetadata(CelPolicyParseContext& ctx,
+                                             const YAML::Node& node) const;
+
+  // Parses a top-level tag in the policy YAML.
+  // Returns true if the tag was handled.
+  //
+  // Note that an OkStatus does not necessarily mean that parsing was successful
+  // - only that it can continue.
+  virtual absl::StatusOr ParsePolicyTag(CelPolicyParseContext& ctx,
+                                        const ValueString& tag_name,
+                                        const YAML::Node& node) const;
+
+  // Parses the imports section of the policy YAML.
+  //
+  // Note that an OkStatus does not necessarily mean that parsing was successful
+  // - only that it can continue.
+  virtual absl::Status ParseImports(CelPolicyParseContext& ctx,
+                                    const YAML::Node& node) const;
+
+  // Parses a rule element of the policy YAML, which may be the top-level rule
+  // or a sub-rule of a match.
+  //
+  // Note that an OkStatus does not necessarily mean that parsing was successful
+  // - only that it can continue.
+  virtual absl::Status ParseRule(CelPolicyParseContext& ctx,
+                                 const YAML::Node& node, Rule& rule) const;
+
+  // Parses a tag in a policy YAML rule.
+  // Returns true if the tag was handled.
+  //
+  // Note that an OkStatus does not necessarily mean that parsing was successful
+  // - only that it can continue.
+  virtual absl::StatusOr ParseRuleTag(CelPolicyParseContext& ctx,
+                                      const ValueString& tag_name,
+                                      const YAML::Node& node,
+                                      Rule& rule) const;
+
+  // Parses a variable element of the policy YAML.
+  //
+  // Note that an OkStatus does not necessarily mean that parsing was successful
+  // - only that it can continue.
+  virtual absl::StatusOr ParseVariable(CelPolicyParseContext& ctx,
+                                       const YAML::Node& node,
+                                       Rule& rule) const;
+
+  // Parses a tag in a policy YAML variable.
+  // Returns true if the tag was handled.
+  //
+  // Note that an OkStatus does not necessarily mean that parsing was successful
+  // - only that it can continue.
+  virtual absl::StatusOr ParseVariableTag(CelPolicyParseContext& ctx,
+                                          const ValueString& tag_name,
+                                          const YAML::Node& node,
+                                          Variable& variable) const;
+
+  // Parses a match element of the policy YAML.
+  //
+  // Note that an OkStatus does not necessarily mean that parsing was successful
+  // - only that it can continue.
+  virtual absl::StatusOr ParseMatch(CelPolicyParseContext& ctx,
+                                    const YAML::Node& node,
+                                    Rule& rule) const;
+
+  // Parses a tag in a policy YAML match.
+  // Returns true if the tag was handled.
+  //
+  // Note that an OkStatus does not necessarily mean that parsing was successful
+  // - only that it can continue.
+  virtual absl::StatusOr ParseMatchTag(CelPolicyParseContext& ctx,
+                                       const ValueString& tag_name,
+                                       const YAML::Node& node,
+                                       Match& match, Rule& rule) const;
+};
+
+// Returns a default implementation of YamlPolicyParser.
+const CelPolicyParser& GetDefaultYamlPolicyParser();
+
+// Parses `policy_source` with the supplied `parser`, e.g. a subclass of
+// `YamlPolicyParser` that understands a custom policy format. Issues found in
+// the policy are carried in the returned parse result.
+absl::StatusOr ParseYamlCelPolicy(
+    std::shared_ptr policy_source,
+    const CelPolicyParser& parser);
+
+// YAML CelPolicy parser that uses the default format as implemented by
+// `YamlPolicyParser`.
+absl::StatusOr ParseYamlCelPolicy(
+    std::shared_ptr policy_source);
+
+}  // namespace cel
+
+#endif  // THIRD_PARTY_CEL_CPP_POLICY_YAML_POLICY_PARSER_H_
diff --git a/policy/yaml_policy_parser_test.cc b/policy/yaml_policy_parser_test.cc
new file mode 100644
index 000000000..4e7dfc49c
--- /dev/null
+++ b/policy/yaml_policy_parser_test.cc
@@ -0,0 +1,305 @@
+// Copyright 2026 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/yaml_policy_parser.h" + +#include +#include +#include +#include +#include + +#include "absl/log/absl_log.h" +#include "absl/status/status_matchers.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "common/source.h" +#include "internal/runfiles.h" +#include "internal/testing.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_result.h" +#include "policy/cel_policy_parser.h" +#include "yaml-cpp/node/node.h" + +namespace cel { + +namespace internal { +const CelPolicyParser& GetTestCustomYamlPolicyParser(); +} // namespace internal + +namespace { + +using ::absl_testing::IsOk; +using ::testing::HasSubstr; +using ::testing::IsNull; + +constexpr absl::string_view kTestPolicyFilePath = +"_main/policy/testdata/"; + +constexpr absl::string_view kBaselineSeparator = + "--------------------------------------------------------------------\n"; + +struct YamlPolicyParserTestCase { + std::string policy_source_file; + std::string baseline_file; + const cel::CelPolicyParser& (*parser_factory)(); +}; + +using YamlPolicyParserTest = testing::TestWithParam; + +TEST_P(YamlPolicyParserTest, Parse) { + std::string contents; + std::string test_file = cel::internal::ResolveRunfilesPath( + absl::StrCat(kTestPolicyFilePath, GetParam().policy_source_file)); + ASSERT_THAT(cel::internal::GetFileContents(test_file, &contents), IsOk()); + + std::string baseline; + std::string baseline_file = cel::internal::ResolveRunfilesPath( + absl::StrCat(kTestPolicyFilePath, GetParam().baseline_file)); + ASSERT_THAT(cel::internal::GetFileContents(baseline_file, &baseline), IsOk()); + baseline = absl::StripAsciiWhitespace(baseline); + + std::ostringstream out; + out << "POLICY SOURCE: " << GetParam().policy_source_file << "\n"; + + ASSERT_OK_AND_ASSIGN(cel::SourcePtr source, + 
cel::NewSource(contents, GetParam().policy_source_file)); + std::shared_ptr policy_source = + std::make_shared(std::move(source)); + + ASSERT_OK_AND_ASSIGN( + CelPolicyParseResult parse_result, + cel::ParseYamlCelPolicy(policy_source, GetParam().parser_factory())); + + out << kBaselineSeparator; + if (parse_result.IsValid()) { + out << "PARSED POLICY:\n"; + out << parse_result.GetPolicy()->DebugString(); + } else { + ASSERT_THAT(parse_result.GetPolicy(), IsNull()); + out << kBaselineSeparator; + out << "PARSER ISSUES:\n"; + for (const auto& issue : parse_result.GetIssues()) { + out << issue.ToDisplayString(*policy_source) << "\n"; + } + } + + std::string actual(absl::StripAsciiWhitespace(out.str())); + if (actual != baseline) { + // Log the actual result to make it easier to copy/paste into the baseline + // file when updating the tests. + ABSL_LOG(INFO) << "Actual:\n" << actual; + EXPECT_EQ(actual, baseline); + } +} + +INSTANTIATE_TEST_SUITE_P( + Formats, YamlPolicyParserTest, + testing::ValuesIn({ + YamlPolicyParserTestCase{ + .policy_source_file = "cel_policy.yaml", + .baseline_file = "cel_policy_parser.baseline", + .parser_factory = GetDefaultYamlPolicyParser, + }, + YamlPolicyParserTestCase{ + .policy_source_file = "nested_rule.yaml", + .baseline_file = "nested_rule_parser.baseline", + .parser_factory = GetDefaultYamlPolicyParser, + }, + YamlPolicyParserTestCase{ + .policy_source_file = "custom_policy_format.yaml", + .baseline_file = "custom_policy_format_parser.baseline", + .parser_factory = internal::GetTestCustomYamlPolicyParser, + }, + YamlPolicyParserTestCase{ + .policy_source_file = "custom_policy_format_with_errors.yaml", + .baseline_file = "custom_policy_format_with_errors_parser.baseline", + .parser_factory = internal::GetTestCustomYamlPolicyParser, + }, + })); + +struct ParseTestCase { + std::string yaml; + std::string expected_error; +}; + +using YamlPolicyParseErrorTest = testing::TestWithParam; + +TEST_P(YamlPolicyParseErrorTest, YamlSyntaxError) 
{ + const ParseTestCase& param = GetParam(); + ASSERT_OK_AND_ASSIGN(cel::SourcePtr source, + cel::NewSource(param.yaml, "test")); + std::shared_ptr policy_source = + std::make_shared(std::move(source)); + ASSERT_OK_AND_ASSIGN(CelPolicyParseResult parse_result, + cel::ParseYamlCelPolicy(policy_source)); + EXPECT_THAT(parse_result.FormattedIssues(), HasSubstr(param.expected_error)); +} + +std::vector GetParseTestCases() { + return { + ParseTestCase{ + .yaml = R"yaml( ? [ John, Doe ]: age: 30 )yaml", + .expected_error = "1:22: Invalid CEL policy YAML syntax\n" + " | ? [ John, Doe ]: age: 30 \n" + " | .....................^", + }, + ParseTestCase{ + .yaml = R"yaml( invalid yaml )yaml", + .expected_error = "1:2: Policy is not a map\n" + " | invalid yaml \n" + " | .^", + }, + ParseTestCase{ + .yaml = R"yaml( + ? [1, 2, 3] + : "Prime numbers sequence" + )yaml", + .expected_error = "2:23: Policy tag is not a string\n" + " | ? [1, 2, 3]\n" + " | ......................^", + }, + ParseTestCase{ + .yaml = R"yaml( + imports: N/A + )yaml", + .expected_error = "2:28: Policy 'imports' is not a sequence\n" + " | imports: N/A\n" + " | ...........................^", + }, + ParseTestCase{ + .yaml = R"yaml( + imports: + - cel.expr.conformance + )yaml", + .expected_error = "3:21: Import is not a map\n" + " | - cel.expr.conformance\n" + " | ....................^", + }, + ParseTestCase{ + .yaml = R"yaml( + imports: + - name: + - cel.expr.conformance + )yaml", + .expected_error = "4:21: Import name is not a string\n" + " | - cel.expr.conformance\n" + " | ....................^", + }, + ParseTestCase{ + .yaml = R"yaml( + rule: do something + )yaml", + .expected_error = "2:25: Policy 'rule' is not a map\n" + " | rule: do something\n" + " | ........................^", + }, + ParseTestCase{ + .yaml = R"yaml( + rule: + id: + - 22 + )yaml", + .expected_error = "4:21: Policy rule 'id' is not a string\n" + " | - 22\n" + " | ....................^", + }, + ParseTestCase{ + .yaml = R"yaml( + rule: + 
variables: + no vars + )yaml", + .expected_error = "4:23: Policy rule 'variables' is not a sequence\n" + " | no vars\n" + " | ......................^", + }, + ParseTestCase{ + .yaml = R"yaml( + rule: + variables: + - name: + foo: bar + )yaml", + .expected_error = "5:25: Policy variable 'name' is not a string\n" + " | foo: bar\n" + " | ........................^", + }, + ParseTestCase{ + .yaml = R"yaml( + rule: + variables: + - name: test_var + expression: + - 22 + )yaml", + .expected_error = + "6:23: Policy variable 'expression' is not a string\n" + " | - 22\n" + " | ......................^", + }, + ParseTestCase{ + .yaml = R"yaml( + rule: + variables: + - name: '\u0041\u00a9\u20ac\U0001f680' + - '\u0041\u00a9\u20ac\U0001f680': name + )yaml", + .expected_error = + "5:23: Unrecognized policy variable tag: " + "\\u0041\\u00a9\\u20ac\\U0001f680\n" + " | - '\\u0041\\u00a9\\u20ac\\U0001f680': " + "name\n" + " | ......................^", + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(YamlPolicyParseErrorTest, YamlPolicyParseErrorTest, + ::testing::ValuesIn(GetParseTestCases())); + +TEST(YamlPolicyParserTest, OffsetIssueFormatting) { + // TODO(b/506179116): will need to copy the go implementation in extracting + // the source string from the YAML document instead of the interpreted string + // value to fix up error locations in folded and block literals. 
+ std::string contents; + std::string test_file = cel::internal::ResolveRunfilesPath( + absl::StrCat(kTestPolicyFilePath, "cel_policy.yaml")); + ASSERT_THAT(cel::internal::GetFileContents(test_file, &contents), IsOk()); + + ASSERT_OK_AND_ASSIGN(cel::SourcePtr source, + cel::NewSource(contents, "cel_policy.yaml")); + std::shared_ptr policy_source = + std::make_shared(std::move(source)); + ASSERT_OK_AND_ASSIGN(CelPolicyParseResult parse_result, + cel::ParseYamlCelPolicy(policy_source)); + + ASSERT_TRUE(parse_result.IsValid()); + const CelPolicy* policy = parse_result.GetPolicy(); + + CelPolicyElementId name_id = policy->name().id(); + + CelPolicyIssue issue(name_id, 4, CelPolicyIssue::Severity::kError, + "Test error"); + + std::string formatted = issue.ToDisplayString(*policy_source); + + EXPECT_THAT(formatted, HasSubstr("ERROR: cel_policy.yaml:16:11: Test error")); + EXPECT_THAT(formatted, HasSubstr(" | name: cel_policy")); + EXPECT_THAT(formatted, HasSubstr(" | ..........^")); +} + +} // namespace +} // namespace cel