diff --git a/.gitignore b/.gitignore index 6ea7b8b..a096f13 100644 --- a/.gitignore +++ b/.gitignore @@ -109,6 +109,7 @@ src/Backend/test_data/json !src/Backend/test_data/csv/Mental_Health_and_Social_Media_Balance_Dataset.csv !src/Backend/test_data/csv/intergration_test_data_1.csv !src/Backend/test_data/csv/intergration_test_data_2.csv +!src/Backend/test_data/substrait_plans/** # allow parquet file !src/Backend/test_data/parquet/ !src/Backend/test_data/parquet/capitals_clean.parquet \ No newline at end of file diff --git a/README.md b/README.md index 51f4799..dcbe178 100644 --- a/README.md +++ b/README.md @@ -6,19 +6,21 @@ A high-performance, in-memory query execution engine. ![Rust Tests](https://github.com/Rich-T-kid/OptiSQL/actions/workflows/rust-test.yml/badge.svg) ![Frontend Tests](https://github.com/Rich-T-kid/OptiSQL/actions/workflows/frontend-test.yml/badge.svg) - ## Overview OptiSQL is a custom in-memory query execution engine. The backend (physical execution) is built using golang and rust.The front end (query parsing & optimization) is built using C++. **Technologies:** + - Go/Rust (physical optimizer, operators) - Substrait (logical/physical plan representation) - C++ (query parser & optimizer) - ect (make,git,s3) + ## Getting Started ### Prerequisites + - Go 1.24+ - Rust 1.70+ - C++23 @@ -83,6 +85,7 @@ OptiSQL/ Initial development is done in **Go** (`opti-sql-go`), which serves as the primary implementation. The **Rust** version (`opti-sql-rs`) is developed shortly after as a learning exercise and eventual performance-optimized alternative, closely mirroring the Go implementation. 
**Key Directories:** + - `/operators` - SQL operator implementations (filter, join, aggregation, project) - `/physical-optimizer` - Query plan parsing and optimization - `/substrait` - Substrait plan integration @@ -102,6 +105,7 @@ We use a structured branching model to maintain stability and enable smooth coll This approach prevents unstable code from reaching `main`, simplifies rollbacks, and ensures all changes undergo proper testing and review before deployment. Feature branches isolate work, allowing focused reviews and parallel development without conflicts. The `pre-release` branch acts as a staging area where features are bundled together before being released as a new version. **Workflow:** + 1. Create a feature branch from `pre-release` 2. Implement your changes with tests 3. Open a PR to merge into `pre-release` @@ -112,6 +116,7 @@ This approach prevents unstable code from reaching `main`, simplifies rollbacks, ### Code Quality All code quality checks are automated and enforced by CI: + - **Linting** - `golangci-lint` (Go), `clippy` (Rust) - **Formatting** - `go fmt` (Go), `cargo fmt` (Rust) - **Testing** - Unit tests required for all new code @@ -133,15 +138,55 @@ Before pushing, verify your changes pass all checks: make pre-push ``` +## How to build + +```bash +docker buildx build \ + --platform linux/amd64 \ + -t rich239/execution-engine:0.9.4 \ + -t rich239/execution-engine:latest \ + --push \ + . 
+ +``` + +## How to run + +```bash +docker pull rich239/execution-engine +docker run -p 7024:7024 rich239/execution-engine +``` + +## Example GRPC body + +```bash +{ + "id": "97b61a8f-ffe1-4e4a-b6d7-73619698dc7a", + "sql_statement": "select * from table1 where id > 10", + "logical_plan": "ewogICAgIkVtaXQiOiAKICAgIHsKICAgICAgICAiT3BlcmF0b3IiOiAiRmlsdGVyIiwKICAgICAgICAiRmlsdGVyIjogCiAgICAgICAgewogICAgICAgICAgICAiaW5wdXQiOiAKICAgICAgICAgICAgewogICAgICAgICAgICAgICAgIk9wZXJhdG9yIjogIlNvdXJjZSIsCiAgICAgICAgICAgICAgICAiU291cmNlIjogCiAgICAgICAgICAgICAgICB7CiAgICAgICAgICAgICAgICAgICAgImZpbGUtbmFtZSI6ICJ1c2VyX3Rlc3RfZGF0YS5jc3YiLAogICAgICAgICAgICAgICAgICAgICJsb2NhbCI6IGZhbHNlCiAgICAgICAgICAgICAgICB9CiAgICAgICAgICAgIH0sCiAgICAgICAgICAgICJleHByZXNzaW9uIjogCiAgICAgICAgICAgIHsKICAgICAgICAgICAgICAgICJleHByX3R5cGUiOiAiQmluYXJ5RXhwciIsCiAgICAgICAgICAgICAgICAib3AiOiAiR3JlYXRlclRoYW4iLAogICAgICAgICAgICAgICAgImxlZnQiOiAKICAgICAgICAgICAgICAgIHsKICAgICAgICAgICAgICAgICAgICAiZXhwcl90eXBlIjogIkNvbHVtblJlc29sdmUiLAogICAgICAgICAgICAgICAgICAgICJuYW1lIjogImFnZV95ZWFycyIKICAgICAgICAgICAgICAgIH0sCiAgICAgICAgICAgICAgICAicmlnaHQiOiAKICAgICAgICAgICAgICAgIHsKICAgICAgICAgICAgICAgICAgICAiZXhwcl90eXBlIjogIkxpdGVyYWxSZXNvbHZlIiwKICAgICAgICAgICAgICAgICAgICAidmFsdWUiOiAxMCwKICAgICAgICAgICAgICAgICAgICAibGl0X3R5cGUiOiAiaW50IgogICAgICAgICAgICAgICAgfQogICAgICAgICAgICB9CiAgICAgICAgfQogICAgfQp9" +} +``` + This runs formatting, linting, and all tests. ## Contributing Want to contribute? Check out [CONTRIBUTING.md](CONTRIBUTING.md) for detailed guidelines on: + - Writing and running tests - PR format and commit message conventions - Development workflow and tooling - Build and run instructions ## License -This project is licensed under the terms specified in [LICENSE.txt](LICENSE.txt). \ No newline at end of file + +This project is licensed under the terms specified in [LICENSE.txt](LICENSE.txt). 
+ +docker buildx build \ + --platform linux/amd64 \ + -t rich239/execution-engine:0.9.5 \ ## bump major/minor +-t rich239/execution-engine:latest \ + --push \ + . + +# TODO: remove env stuff diff --git a/src/Backend/opti-sql-go/Expr/expr.go b/src/Backend/opti-sql-go/Expr/expr.go index 4899a15..e10af17 100644 --- a/src/Backend/opti-sql-go/Expr/expr.go +++ b/src/Backend/opti-sql-go/Expr/expr.go @@ -5,6 +5,7 @@ import ( "context" "errors" "fmt" + "opti-sql-go/config" "opti-sql-go/operators" "regexp" "strings" @@ -13,6 +14,7 @@ import ( "github.com/apache/arrow/go/v17/arrow/array" "github.com/apache/arrow/go/v17/arrow/compute" "github.com/apache/arrow/go/v17/arrow/memory" + "go.uber.org/zap" ) var ( @@ -24,35 +26,35 @@ var ( } ) -type binaryOperator int +type BinaryOperator int const ( // arithmetic - Addition binaryOperator = 1 - Subtraction binaryOperator = 2 - Multiplication binaryOperator = 3 - Division binaryOperator = 4 + Addition BinaryOperator = 1 + Subtraction BinaryOperator = 2 + Multiplication BinaryOperator = 3 + Division BinaryOperator = 4 // comparison - Equal binaryOperator = 6 - NotEqual binaryOperator = 7 - LessThan binaryOperator = 8 - LessThanOrEqual binaryOperator = 9 - GreaterThan binaryOperator = 10 - GreaterThanOrEqual binaryOperator = 11 + Equal BinaryOperator = 6 + NotEqual BinaryOperator = 7 + LessThan BinaryOperator = 8 + LessThanOrEqual BinaryOperator = 9 + GreaterThan BinaryOperator = 10 + GreaterThanOrEqual BinaryOperator = 11 // logical - And binaryOperator = 12 - Or binaryOperator = 13 + And BinaryOperator = 12 + Or BinaryOperator = 13 // RegEx expressions - Like binaryOperator = 14 // where column_name like "patte%n_with_wi%dcard_" + Like BinaryOperator = 14 // where column_name like "patte%n_with_wi%dcard_" ) -type supportedFunctions int +type SupportedFunctions int const ( - Upper supportedFunctions = 1 - Lower supportedFunctions = 2 - Abs supportedFunctions = 3 - Round supportedFunctions = 4 + Upper SupportedFunctions = 1 + Lower 
SupportedFunctions = 2 + Abs SupportedFunctions = 3 + Round SupportedFunctions = 4 ) type aggFunctions = int @@ -91,6 +93,19 @@ type Expression interface { fmt.Stringer } +// To_aggr_name extracts the column name from an expression for use in aggregation schema building. +// Returns the alias name if present, otherwise the column name. +func To_aggr_name(expr Expression) string { + switch e := expr.(type) { + case *ColumnResolve: + return e.Name + case *Alias: + return e.Name + default: + return expr.String() + } +} + func EvalExpression(expr Expression, batch *operators.RecordBatch) (arrow.Array, error) { switch e := expr.(type) { case *Alias: @@ -199,7 +214,7 @@ func EvalColumn(c *ColumnResolve, batch *operators.RecordBatch) (arrow.Array, er for i, f := range batch.Schema.Fields() { if f.Name == c.Name { col := batch.Columns[i] - col.Retain() + //col.Retain() return col, nil } } @@ -213,8 +228,7 @@ func (c *ColumnResolve) String() string { // Evaluates to a column of length = batch-size, filled with this literal. // sql: select 1 type LiteralResolve struct { - Type arrow.DataType - // dont forget to cast the value. 
so string("hello") not just "hello" + Type arrow.DataType Value any } @@ -425,11 +439,11 @@ func (l *LiteralResolve) String() string { type BinaryExpr struct { Left Expression - Op binaryOperator + Op BinaryOperator Right Expression } -func NewBinaryExpr(left Expression, op binaryOperator, right Expression) *BinaryExpr { +func NewBinaryExpr(left Expression, op BinaryOperator, right Expression) *BinaryExpr { return &BinaryExpr{ Left: left, Op: op, @@ -438,6 +452,7 @@ func NewBinaryExpr(left Expression, op binaryOperator, right Expression) *Binary } func EvalBinary(b *BinaryExpr, batch *operators.RecordBatch) (arrow.Array, error) { + logger := config.GetLogger() leftArr, err := EvalExpression(b.Left, batch) if err != nil { return nil, err @@ -446,6 +461,11 @@ func EvalBinary(b *BinaryExpr, batch *operators.RecordBatch) (arrow.Array, error if err != nil { return nil, err } + logger.Debug("Evaluating binary expression", + zap.String("operator", fmt.Sprintf("%v", b.Op)), + zap.Int("left_len", leftArr.Len()), + zap.Int("right_len", rightArr.Len()), + ) ctx := context.Background() opt := compute.ArithmeticOptions{} switch b.Op { @@ -578,11 +598,11 @@ func unpackDatum(d compute.Datum) (arrow.Array, error) { } type ScalarFunction struct { - Function supportedFunctions + Function SupportedFunctions Arguments Expression // resolve to something you can process IE, literal/coloumn Resolve } -func NewScalarFunction(function supportedFunctions, Argument Expression) *ScalarFunction { +func NewScalarFunction(function SupportedFunctions, Argument Expression) *ScalarFunction { return &ScalarFunction{ Function: function, Arguments: Argument, @@ -736,7 +756,7 @@ func lowerImpl(arr arrow.Array) (arrow.Array, error) { return b.NewArray(), nil } } -func inferScalarFunctionType(fn supportedFunctions, argType arrow.DataType) arrow.DataType { +func inferScalarFunctionType(fn SupportedFunctions, argType arrow.DataType) arrow.DataType { switch fn { case Upper, Lower: @@ -753,7 +773,7 @@ func 
inferScalarFunctionType(fn supportedFunctions, argType arrow.DataType) arro } } -func inferBinaryType(left arrow.DataType, op binaryOperator, right arrow.DataType) arrow.DataType { +func inferBinaryType(left arrow.DataType, op BinaryOperator, right arrow.DataType) arrow.DataType { switch op { case Addition, Subtraction, Multiplication, Division: @@ -816,3 +836,57 @@ func validRegEx(columnValue, regExExpr string) bool { return ok } +func FnToScalarFunction(s string) SupportedFunctions { + switch s { + case "Upper": + return 1 + case "Lower": + return 2 + case "Abs": + return 3 + case "Round": + return 4 + } + return 1 +} + +// matchesBinaryOperator returns true if `name` matches the binaryOperator constant +// represented by `opInt`, using ONLY the exact names in your const block. +func MatchesBinaryOperator(name string, opInt int) bool { + want := BinaryOperator(opInt) + + switch name { + case "Addition": + return want == Addition + case "Subtraction": + return want == Subtraction + case "Multiplication": + return want == Multiplication + case "Division": + return want == Division + + case "Equal": + return want == Equal + case "NotEqual": + return want == NotEqual + case "LessThan": + return want == LessThan + case "LessThanOrEqual": + return want == LessThanOrEqual + case "GreaterThan": + return want == GreaterThan + case "GreaterThanOrEqual": + return want == GreaterThanOrEqual + + case "And": + return want == And + case "Or": + return want == Or + + case "Like": + return want == Like + + default: + return false + } +} diff --git a/src/Backend/opti-sql-go/Expr/expr_test.go b/src/Backend/opti-sql-go/Expr/expr_test.go index f0d2f43..81570b2 100644 --- a/src/Backend/opti-sql-go/Expr/expr_test.go +++ b/src/Backend/opti-sql-go/Expr/expr_test.go @@ -1117,7 +1117,7 @@ func TestInferScalarFunctionType(t *testing.T) { t.Fatalf("expected panic for unknown function, got none") } }() - _ = inferScalarFunctionType(supportedFunctions(9999), arrow.PrimitiveTypes.Int32) + _ = 
inferScalarFunctionType(SupportedFunctions(9999), arrow.PrimitiveTypes.Int32) }) } diff --git a/src/Backend/opti-sql-go/config/config.go b/src/Backend/opti-sql-go/config/config.go index 17154fe..7dbdc45 100644 --- a/src/Backend/opti-sql-go/config/config.go +++ b/src/Backend/opti-sql-go/config/config.go @@ -6,6 +6,7 @@ import ( "os" "strings" + "go.uber.org/zap" "gopkg.in/yaml.v3" ) @@ -27,6 +28,7 @@ type serverConfig struct { Host string `yaml:"host"` Timeout int `yaml:"timeout"` MaxRequestSizeMB uint64 `yaml:"max_request_size_mb"` // max size of a file upload. passed in by grpc request + RedisAddr string `yaml:"redis_addr"` } type batchConfig struct { Size int `yaml:"size"` @@ -63,10 +65,11 @@ type secretesConfig struct { var configInstance *Config = &Config{ Server: serverConfig{ - Port: 8080, - Host: "localhost", + Port: 7024, + Host: "0.0.0.0", Timeout: 30, MaxRequestSizeMB: 15, + RedisAddr: "104.236.210.9", }, Batch: batchConfig{ Size: 1024 * 8, // rows per bathch @@ -106,20 +109,31 @@ func GetConfig() *Config { // overwrite global instance with loaded config func Decode(filePath string) error { + logger := GetLogger() + logger.Info("Loading config file", zap.String("file_path", filePath)) + suffix := strings.Split(filePath, ".")[len(strings.Split(filePath, "."))-1] if suffix != "yaml" && suffix != "yml" { + logger.Error("Invalid config file extension", zap.String("extension", suffix)) return errors.New("file must be a .yaml or .yml file") } r, err := os.Open(filePath) if err != nil { + logger.Error("Failed to open config file", zap.Error(err), zap.String("file_path", filePath)) return err } config := make(map[string]interface{}) decoder := yaml.NewDecoder(r) if err := decoder.Decode(config); err != nil { + logger.Error("Failed to decode YAML config", zap.Error(err)) return fmt.Errorf("failed to decode config: %w", err) } + logger.Info("Config file decoded successfully") mergeConfig(configInstance, config) + logger.Info("Config merged successfully", + 
zap.Int("server_port", configInstance.Server.Port), + zap.Int("batch_size", configInstance.Batch.Size), + ) return nil } func mergeConfig(dst *Config, src map[string]interface{}) { @@ -139,6 +153,9 @@ func mergeConfig(dst *Config, src map[string]interface{}) { if v, ok := server["max_request_size_mb"].(int); ok { dst.Server.MaxRequestSizeMB = uint64(v) } + if v, ok := server["redis_addr"].(string); ok { + dst.Server.RedisAddr = v + } } // ============================= diff --git a/src/Backend/opti-sql-go/config/log.go b/src/Backend/opti-sql-go/config/log.go new file mode 100644 index 0000000..5536fd1 --- /dev/null +++ b/src/Backend/opti-sql-go/config/log.go @@ -0,0 +1,48 @@ +package config + +import ( + "sync" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +const logFileName = "App.log" + +var ( + globLogger *zap.Logger + once sync.Once +) + +func GetLogger() *zap.Logger { + once.Do(func() { + globLogger = createLogger() + }) + return globLogger +} + +func createLogger() *zap.Logger { + encoderCfg := zap.NewProductionEncoderConfig() + encoderCfg.TimeKey = "timestamp" + encoderCfg.EncodeTime = zapcore.ISO8601TimeEncoder + + config := zap.Config{ + Level: zap.NewAtomicLevelAt(zap.InfoLevel), + Development: false, + DisableCaller: false, + DisableStacktrace: false, + Sampling: nil, + Encoding: "json", + EncoderConfig: encoderCfg, + OutputPaths: []string{ + "stdout", + logFileName, + }, + ErrorOutputPaths: []string{ + "stdout", + }, + InitialFields: map[string]any{}, + } + + return zap.Must(config.Build()) +} diff --git a/src/Backend/opti-sql-go/config/log_test.go b/src/Backend/opti-sql-go/config/log_test.go new file mode 100644 index 0000000..c31b830 --- /dev/null +++ b/src/Backend/opti-sql-go/config/log_test.go @@ -0,0 +1,17 @@ +package config + +import ( + "reflect" + "testing" + + "go.uber.org/zap" +) + +func TestLoggerInit(t *testing.T) { + original := GetLogger() + for i := range 100 { + l := GetLogger() + t.Logf("equal to original\t%v\n", 
reflect.DeepEqual(original, l)) + l.Info("msg:", zap.Int("loggers generated", i)) + } +} diff --git a/src/Backend/opti-sql-go/dockerfile b/src/Backend/opti-sql-go/dockerfile new file mode 100644 index 0000000..2f46af8 --- /dev/null +++ b/src/Backend/opti-sql-go/dockerfile @@ -0,0 +1,15 @@ +FROM golang:1.24 + +WORKDIR /app + +COPY go.mod go.sum ./ + +RUN go mod download + +COPY . . + +RUN go build -o execution-engine . + +EXPOSE 7024 + +CMD ["./execution-engine"] \ No newline at end of file diff --git a/src/Backend/opti-sql-go/go.mod b/src/Backend/opti-sql-go/go.mod index 5b872b6..10bd61f 100644 --- a/src/Backend/opti-sql-go/go.mod +++ b/src/Backend/opti-sql-go/go.mod @@ -27,6 +27,8 @@ require ( github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.13 // indirect github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.13 // indirect github.com/aws/smithy-go v1.23.2 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/go-ini/ini v1.67.0 // indirect github.com/go-jose/go-jose/v4 v4.1.3 // indirect github.com/goccy/go-json v0.10.3 // indirect @@ -41,7 +43,10 @@ require ( github.com/minio/minio-go v6.0.14+incompatible // indirect github.com/mitchellh/go-homedir v1.1.0 // indirect github.com/pierrec/lz4/v4 v4.1.21 // indirect + github.com/redis/go-redis/v9 v9.17.3 // indirect github.com/zeebo/xxh3 v1.0.2 // indirect + go.uber.org/multierr v1.11.0 // indirect + go.uber.org/zap v1.27.1 // indirect golang.org/x/crypto v0.24.0 // indirect golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 // indirect golang.org/x/mod v0.18.0 // indirect diff --git a/src/Backend/opti-sql-go/go.sum b/src/Backend/opti-sql-go/go.sum index 7c4ee5c..50406cf 100644 --- a/src/Backend/opti-sql-go/go.sum +++ b/src/Backend/opti-sql-go/go.sum @@ -32,9 +32,13 @@ github.com/aws/aws-sdk-go-v2/service/s3 v1.90.2 h1:DhdbtDl4FdNlj31+xiRXANxEE+eC7 github.com/aws/aws-sdk-go-v2/service/s3 
v1.90.2/go.mod h1:+wArOOrcHUevqdto9k1tKOF5++YTe9JEcPSc9Tx2ZSw= github.com/aws/smithy-go v1.23.2 h1:Crv0eatJUQhaManss33hS5r40CG3ZFH+21XSkqMrIUM= github.com/aws/smithy-go v1.23.2/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/go-ini/ini v1.67.0 h1:z6ZrTEZqSWOTyH2FlglNbNgARyHG8oLW9gMELqKr06A= github.com/go-ini/ini v1.67.0/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3Ies8= github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs= @@ -70,6 +74,8 @@ github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/redis/go-redis/v9 v9.17.3 h1:fN29NdNrE17KttK5Ndf20buqfDZwGNgoUr9qjl1DQx4= +github.com/redis/go-redis/v9 v9.17.3/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= @@ -77,6 +83,10 @@ 
github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ= github.com/zeebo/assert v1.3.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0= github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= +go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= +go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc= +go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 h1:LfspQV/FYTatPTr/3HzIcmiUFH7PGP+OQ6mgDYo3yuQ= diff --git a/src/Backend/opti-sql-go/operators/Join/hashJoin.go b/src/Backend/opti-sql-go/operators/Join/hashJoin.go index 13a6969..3fc9249 100644 --- a/src/Backend/opti-sql-go/operators/Join/hashJoin.go +++ b/src/Backend/opti-sql-go/operators/Join/hashJoin.go @@ -8,6 +8,7 @@ import ( "io" "math" "opti-sql-go/Expr" + "opti-sql-go/config" "opti-sql-go/operators" "strings" @@ -15,10 +16,9 @@ import ( "github.com/apache/arrow/go/v17/arrow/array" "github.com/apache/arrow/go/v17/arrow/compute" "github.com/apache/arrow/go/v17/arrow/memory" + "go.uber.org/zap" ) -// TODO: see ticket #27 - var ( ErrInvalidJoinClauseCount = func(l, r int) error { return fmt.Errorf("mismatched number of join expressions between left and right, left: %d vs right: %d", l, r) @@ -53,23 +53,23 @@ func (j JoinType) String() string { // taking in arrays of expressions allows for multiple join clauses // Example: JOIN t2 ON t1.region = t2.region AND t1.city = t2.city type JoinClause struct { - leftS []Expr.Expression - rightS []Expr.Expression + LeftS []Expr.Expression + RightS []Expr.Expression } func (j 
*JoinClause) String() string { var b bytes.Buffer // defensive: if lengths differ, print whatever pairs exist - n := len(j.leftS) - if len(j.rightS) < n { - n = len(j.rightS) + n := len(j.LeftS) + if len(j.RightS) < n { + n = len(j.RightS) } for i := 0; i < n; i++ { - b.WriteString(j.leftS[i].String()) + b.WriteString(j.LeftS[i].String()) b.WriteString(" = ") - b.WriteString(j.rightS[i].String()) + b.WriteString(j.RightS[i].String()) // add separator between pairs if i < n-1 { @@ -82,8 +82,8 @@ func (j *JoinClause) String() string { func NewJoinClause(leftS, rightS []Expr.Expression) JoinClause { return JoinClause{ - leftS: leftS, - rightS: rightS, + LeftS: leftS, + RightS: rightS, } } @@ -163,8 +163,8 @@ func NewHashJoinExec(left operators.Operator, right operators.Operator, clause J if err != nil { return nil, err } - if len(clause.leftS) != len(clause.rightS) { - return nil, ErrInvalidJoinClauseCount(len(clause.leftS), len(clause.rightS)) + if len(clause.LeftS) != len(clause.RightS) { + return nil, ErrInvalidJoinClauseCount(len(clause.LeftS), len(clause.RightS)) } return &HashJoinExec{ leftSource: left, @@ -178,9 +178,11 @@ func NewHashJoinExec(left operators.Operator, right operators.Operator, clause J } func (hj *HashJoinExec) Next(_ uint16) (*operators.RecordBatch, error) { + logger := config.GetLogger() if hj.done { return nil, io.EOF } + logger.Debug("Hash join starting", zap.String("join_type", hj.joinType.String()), zap.Int("join_conditions", len(hj.clause.LeftS))) mem := memory.NewGoAllocator() leftArr, err := consumeOperator(hj.leftSource, mem) if err != nil { @@ -201,18 +203,20 @@ func (hj *HashJoinExec) Next(_ uint16) (*operators.RecordBatch, error) { } leftRowCount := leftArr[0].Len() rightRowCount := rightArr[0].Len() - leftComp, err := buildComptables(hj.clause.leftS, leftArr, hj.leftSource.Schema()) + leftComp, err := buildComptables(hj.clause.LeftS, leftArr, hj.leftSource.Schema()) if err != nil { return nil, err } - rightComp, err := 
buildComptables(hj.clause.rightS, rightArr, hj.rightSource.Schema()) + rightComp, err := buildComptables(hj.clause.RightS, rightArr, hj.rightSource.Schema()) if err != nil { return nil, err } ht := buildRightHashTable(rightComp, rightRowCount) + logger.Debug("Hash table built", zap.Int("left_rows", leftRowCount), zap.Int("right_rows", rightRowCount)) pairs := probeJoin(leftComp, ht, leftRowCount) if len(pairs) == 0 { + logger.Debug("Join produced no matching pairs") hj.done = true return &operators.RecordBatch{ Schema: hj.Schema(), @@ -229,6 +233,7 @@ func (hj *HashJoinExec) Next(_ uint16) (*operators.RecordBatch, error) { if err != nil { return nil, err } + logger.Info("Hash join complete", zap.Int("output_rows", outArr[0].Len()), zap.Int("matching_pairs", len(pairs))) hj.done = true return &operators.RecordBatch{ Schema: hj.schema, @@ -249,6 +254,9 @@ func (hj *HashJoinExec) Close() error { } return nil } +func (hj *HashJoinExec) Name() string { + return "Join" +} func consumeOperator(o operators.Operator, mem memory.Allocator) ([]arrow.Array, error) { diff --git a/src/Backend/opti-sql-go/operators/Join/hashJoin_test.go b/src/Backend/opti-sql-go/operators/Join/hashJoin_test.go index e22872f..dcdc381 100644 --- a/src/Backend/opti-sql-go/operators/Join/hashJoin_test.go +++ b/src/Backend/opti-sql-go/operators/Join/hashJoin_test.go @@ -9,9 +9,9 @@ import ( "strings" "testing" - "github.com/apache/arrow/go/v15/arrow/memory" "github.com/apache/arrow/go/v17/arrow" "github.com/apache/arrow/go/v17/arrow/array" + "github.com/apache/arrow/go/v17/arrow/memory" ) func generateDataset1WithNulls(mem memory.Allocator) ([]string, []arrow.Array) { diff --git a/src/Backend/opti-sql-go/operators/aggr/groupBy.go b/src/Backend/opti-sql-go/operators/aggr/groupBy.go index 7ca86ea..1be8b89 100644 --- a/src/Backend/opti-sql-go/operators/aggr/groupBy.go +++ b/src/Backend/opti-sql-go/operators/aggr/groupBy.go @@ -5,12 +5,14 @@ import ( "fmt" "io" "opti-sql-go/Expr" + "opti-sql-go/config" 
"opti-sql-go/operators" "strings" "github.com/apache/arrow/go/v17/arrow" "github.com/apache/arrow/go/v17/arrow/array" "github.com/apache/arrow/go/v17/arrow/memory" + "go.uber.org/zap" ) /* @@ -36,10 +38,15 @@ type GroupByExec struct { } func NewGroupByExec(child operators.Operator, groupExpr []AggregateFunctions, groupBy []Expr.Expression) (*GroupByExec, error) { + logger := config.GetLogger() s, err := buildGroupBySchema(child.Schema(), groupBy, groupExpr) if err != nil { return nil, err } + logger.Info("GroupBy schema created", + zap.Strings("input_columns", operators.GetSchemaFieldNames(child.Schema())), + zap.Strings("output_columns", operators.GetSchemaFieldNames(s)), + ) return &GroupByExec{ input: child, @@ -55,9 +62,11 @@ func NewGroupByExec(child operators.Operator, groupExpr []AggregateFunctions, gr grab child rows */ func (g *GroupByExec) Next(batchSize uint16) (*operators.RecordBatch, error) { + logger := config.GetLogger() if g.done { return nil, io.EOF } + logger.Debug("GroupBy operator starting", zap.Int("num_group_by_cols", len(g.groupByExpr)), zap.Int("num_aggregations", len(g.groupExpr))) for { childBatch, err := g.input.Next(batchSize) @@ -92,14 +101,18 @@ func (g *GroupByExec) Next(batchSize uint16) (*operators.RecordBatch, error) { operators.ReleaseArrays(childBatch.Columns) return nil, err } - arr, err = castArrayToFloat64(arr) - if err != nil { - operators.ReleaseArrays(aggrArrays) - operators.ReleaseArrays(groupArrays) - operators.ReleaseArrays(childBatch.Columns) - return nil, err + if agg.AggrFunc != Count { // handle count case + arr, err = castArrayToFloat64(arr) + if err != nil { + operators.ReleaseArrays(aggrArrays) + operators.ReleaseArrays(groupArrays) + operators.ReleaseArrays(childBatch.Columns) + return nil, err + } + aggrArrays[i] = arr + } else { + aggrArrays[i] = arr } - aggrArrays[i] = arr } // 3. 
process rows @@ -129,8 +142,20 @@ func (g *GroupByExec) Next(batchSize uint16) (*operators.RecordBatch, error) { if arr.IsNull(row) { continue } - val := arr.(*array.Float64).Value(row) - g.groups[key][i].Update(val) + // handle count + // if it can be cast to float64 do it otherwise its count + arr, ok := arr.(*array.Float64) + if ok { + val := arr.Value(row) + g.groups[key][i].Update(val) + continue + + } + // otherwise we know its count + countOp, ok := g.groups[key][i].(*countAggrAccumulator) + if ok { + countOp.Update(1) + } } } // 4. release temp arrays @@ -139,6 +164,7 @@ func (g *GroupByExec) Next(batchSize uint16) (*operators.RecordBatch, error) { operators.ReleaseArrays(childBatch.Columns) } + logger.Info("GroupBy aggregation complete", zap.Int("num_groups", len(g.groups))) // 4. Build output RecordBatch batch := buildGroupByOutput(g) @@ -153,6 +179,10 @@ func (g *GroupByExec) Close() error { return g.input.Close() } +func (g *GroupByExec) Name() string { + return "Group By" +} + // handles validation and building of schema for group by func buildGroupBySchema(childSchema *arrow.Schema, groupByExpr []Expr.Expression, aggrExprs []AggregateFunctions) (*arrow.Schema, error) { @@ -166,7 +196,7 @@ func buildGroupBySchema(childSchema *arrow.Schema, groupByExpr []Expr.Expression } fields = append(fields, arrow.Field{ - Name: fmt.Sprintf("group_%s", expr.String()), + Name: Expr.To_aggr_name(expr), Type: dt, Nullable: true, }) @@ -175,14 +205,15 @@ func buildGroupBySchema(childSchema *arrow.Schema, groupByExpr []Expr.Expression // 2. 
Add aggregate columns for _, agg := range aggrExprs { dt, err := Expr.ExprDataType(agg.Child, childSchema) - if err != nil || !validAggrType(dt) { + if err != nil || !validAggrType(agg, dt) { return nil, ErrInvalidAggrColumnType(dt) } // All aggregates produce float64 - fieldName := fmt.Sprintf("%s_%s", + /*fieldName := fmt.Sprintf("%s_%s", strings.ToLower(aggrToString(int(agg.AggrFunc))), agg.Child.String(), - ) + )*/ + fieldName := fmt.Sprintf("%s", Expr.To_aggr_name(agg.Child)) fields = append(fields, arrow.Field{ Name: fieldName, diff --git a/src/Backend/opti-sql-go/operators/aggr/groupBy_test.go b/src/Backend/opti-sql-go/operators/aggr/groupBy_test.go index 41434ac..fdc52d4 100644 --- a/src/Backend/opti-sql-go/operators/aggr/groupBy_test.go +++ b/src/Backend/opti-sql-go/operators/aggr/groupBy_test.go @@ -2,11 +2,9 @@ package aggr import ( "errors" - "fmt" "io" "opti-sql-go/Expr" "opti-sql-go/operators/project" - "strings" "testing" "github.com/apache/arrow/go/v17/arrow" @@ -158,17 +156,14 @@ func TestNewGroupByExecAndSchema(t *testing.T) { // group field f0 := schema.Field(0) - expName := "group_" + groupBy[0].String() + expName := Expr.To_aggr_name(groupBy[0]) if f0.Name != expName { t.Fatalf("expected group field name %q, got %q", expName, f0.Name) } // aggregate field f1 := schema.Field(1) - properAggName := fmt.Sprintf("%s_%s", - strings.ToLower(aggrToString(int(aggs[0].AggrFunc))), - aggs[0].Child.String(), - ) + properAggName := Expr.To_aggr_name(aggs[0].Child) if f1.Name != properAggName { t.Fatalf("expected agg field %q, got %q", properAggName, f1.Name) } @@ -205,7 +200,7 @@ func TestNewGroupByExecAndSchema(t *testing.T) { // group fields first for i, gexpr := range groupBy { f := schema.Field(i) - exp := "group_" + gexpr.String() + exp := Expr.To_aggr_name(gexpr) if f.Name != exp { t.Fatalf("group field[%d] mismatch: want %q got %q", i, exp, f.Name) } @@ -215,10 +210,7 @@ func TestNewGroupByExecAndSchema(t *testing.T) { offset := len(groupBy) for j, 
agg := range aggs { f := schema.Field(offset + j) - expAggName := fmt.Sprintf("%s_%s", - strings.ToLower(aggrToString(int(agg.AggrFunc))), - agg.Child.String(), - ) + expAggName := Expr.To_aggr_name(agg.Child) if f.Name != expAggName { t.Fatalf("agg field name mismatch: want %q got %q", expAggName, f.Name) } @@ -263,7 +255,7 @@ func TestNewGroupByExecAndSchema(t *testing.T) { } f := schema.Field(0) - exp := "group_" + groupBy[0].String() + exp := Expr.To_aggr_name(groupBy[0]) if f.Name != exp { t.Fatalf("wrong group field name: want %q got %q", exp, f.Name) } @@ -316,8 +308,8 @@ func TestNewGroupByExecAndSchema(t *testing.T) { schema := gb.Schema() - expected0 := "group_" + gbExpr[0].String() // group_Column(seniority) - expected1 := "group_" + gbExpr[1].String() // group_Column(region) + expected0 := Expr.To_aggr_name(gbExpr[0]) // seniority + expected1 := Expr.To_aggr_name(gbExpr[1]) // region if schema.Field(0).Name != expected0 { t.Fatalf("wrong field[0] name: want %q got %q", expected0, schema.Field(0).Name) @@ -327,7 +319,7 @@ func TestNewGroupByExecAndSchema(t *testing.T) { } // count column - expectedAgg := "count_" + aggs[0].Child.String() + expectedAgg := Expr.To_aggr_name(aggs[0].Child) if schema.Field(2).Name != expectedAgg { t.Fatalf("wrong agg field name: want %q got %q", expectedAgg, schema.Field(2).Name) } diff --git a/src/Backend/opti-sql-go/operators/aggr/having.go b/src/Backend/opti-sql-go/operators/aggr/having.go index a2a559f..15248d2 100644 --- a/src/Backend/opti-sql-go/operators/aggr/having.go +++ b/src/Backend/opti-sql-go/operators/aggr/having.go @@ -75,3 +75,7 @@ func (h *HavingExec) Schema() *arrow.Schema { func (h *HavingExec) Close() error { return h.input.Close() } + +func (h *HavingExec) Name() string { + return "Having" +} diff --git a/src/Backend/opti-sql-go/operators/aggr/having_test.go b/src/Backend/opti-sql-go/operators/aggr/having_test.go index 45275b2..f56a45c 100644 --- a/src/Backend/opti-sql-go/operators/aggr/having_test.go 
+++ b/src/Backend/opti-sql-go/operators/aggr/having_test.go @@ -3,6 +3,7 @@ package aggr import ( "errors" "io" + "math" "strings" "testing" @@ -31,7 +32,7 @@ func TestHavingExec_OnGroupBy(t *testing.T) { t.Fatalf("unexpected GroupBy error: %v", err) } - sumCol := "sum_Column(salary)" + sumCol := "salary" // SUM(salary) > 600000 havingExpr := Expr.NewBinaryExpr( @@ -76,7 +77,7 @@ func TestHavingExec_OnGroupBy(t *testing.T) { t.Fatalf("unexpected GroupBy err: %v", err) } - countCol := "count_Column(id)" + countCol := "id" havingExpr := Expr.NewBinaryExpr( Expr.NewColumnResolve(countCol), @@ -113,7 +114,7 @@ func TestHavingExec_OnGroupBy(t *testing.T) { gb, _ := NewGroupByExec(child, aggs, groupBy) - sumCol := "sum_Column(salary)" + sumCol := "salary" // Impossible condition havingExpr := Expr.NewBinaryExpr( @@ -148,7 +149,7 @@ func TestHavingExec_OnGroupBy(t *testing.T) { gb, _ := NewGroupByExec(child, aggs, groupBy) // invalid: resolves to float, not boolean - invalidExpr := Expr.NewColumnResolve("sum_Column(salary)") + invalidExpr := Expr.NewColumnResolve("salary") having, _ := NewHavingExec(gb, invalidExpr) @@ -175,7 +176,7 @@ func TestHavingExec_OnGroupBy(t *testing.T) { gb, _ := NewGroupByExec(child, aggs, groupBy) - countCol := "count_Column(id)" + countCol := "id" havingExpr := Expr.NewBinaryExpr( Expr.NewColumnResolve(countCol), @@ -184,9 +185,9 @@ func TestHavingExec_OnGroupBy(t *testing.T) { ) h, _ := NewHavingExec(gb, havingExpr) - h.done = true - _, err := h.Next(10) + _, err := h.Next(math.MaxUint16) + _, err = h.Next(math.MaxUint16) if !errors.Is(err, io.EOF) { t.Fatalf("expected EOF, got: %v", err) } diff --git a/src/Backend/opti-sql-go/operators/aggr/singleAggr.go b/src/Backend/opti-sql-go/operators/aggr/singleAggr.go index 0f7c3b5..ad54933 100644 --- a/src/Backend/opti-sql-go/operators/aggr/singleAggr.go +++ b/src/Backend/opti-sql-go/operators/aggr/singleAggr.go @@ -6,11 +6,13 @@ import ( "fmt" "io" "opti-sql-go/Expr" + "opti-sql-go/config" 
"opti-sql-go/operators" "github.com/apache/arrow/go/v17/arrow" "github.com/apache/arrow/go/v17/arrow/array" "github.com/apache/arrow/go/v17/arrow/compute" + "go.uber.org/zap" ) var ( @@ -161,25 +163,25 @@ func NewGlobalAggrExec(child operators.Operator, aggExprs []AggregateFunctions) fields := make([]arrow.Field, len(aggExprs)) for i, agg := range aggExprs { dt, err := Expr.ExprDataType(agg.Child, child.Schema()) - if err != nil || !validAggrType(dt) { + if err != nil || !validAggrType(agg, dt) { return nil, ErrInvalidAggrColumnType(dt) } var fieldName string switch agg.AggrFunc { case Min: - fieldName = fmt.Sprintf("min_%s", agg.Child.String()) + fieldName = fmt.Sprintf("%s", Expr.To_aggr_name(agg.Child)) accs[i] = newMinAggr() case Max: - fieldName = fmt.Sprintf("max_%s", agg.Child.String()) + fieldName = fmt.Sprintf("%s", Expr.To_aggr_name(agg.Child)) accs[i] = newMaxAggr() case Count: - fieldName = fmt.Sprintf("count_%s", agg.Child.String()) + fieldName = fmt.Sprintf("%s", Expr.To_aggr_name(agg.Child)) accs[i] = newCountAggr() case Sum: - fieldName = fmt.Sprintf("sum_%s", agg.Child.String()) + fieldName = fmt.Sprintf("%s", Expr.To_aggr_name(agg.Child)) accs[i] = newSumAggr() case Avg: - fieldName = fmt.Sprintf("avg_%s", agg.Child.String()) + fieldName = fmt.Sprintf("%s", Expr.To_aggr_name(agg.Child)) accs[i] = newAvgAggr() default: @@ -188,9 +190,20 @@ func NewGlobalAggrExec(child operators.Operator, aggExprs []AggregateFunctions) fields[i] = arrow.Field{ Name: fieldName, Type: arrow.PrimitiveTypes.Float64, - Nullable: true, + Nullable: false, } } + logger := config.GetLogger() + logger.Info("Global aggregation schema created", + zap.Strings("input_columns", operators.GetSchemaFieldNames(child.Schema())), + zap.Strings("output_columns", func() []string { + names := make([]string, len(fields)) + for i, f := range fields { + names[i] = f.Name + } + return names + }()), + ) return &AggrExec{ input: child, schema: arrow.NewSchema(fields, nil), @@ -203,9 +216,11 @@ 
func NewGlobalAggrExec(child operators.Operator, aggExprs []AggregateFunctions) // updates the accumulators for each value, and returns a single output batch containing // the final aggregation results. It returns io.EOF after producing the result batch. func (a *AggrExec) Next(n uint16) (*operators.RecordBatch, error) { + logger := config.GetLogger() if a.done { return nil, io.EOF } + logger.Debug("Global aggregation starting", zap.Int("num_aggregations", len(a.aggExpressions))) for { childBatch, err := a.input.Next(n) if err != nil { @@ -219,6 +234,17 @@ func (a *AggrExec) Next(n uint16) (*operators.RecordBatch, error) { if err != nil { return nil, err } + if aggExpr.AggrFunc == Count { + accumulator := a.accumulators[i] + for j := 0; j < agrArray.Len(); j++ { + if agrArray.IsNull(j) { + continue + } + accumulator.Update(1) // doesnt matter what we pass here + } + continue + } + agrArray, err = castArrayToFloat64(agrArray) if err != nil { return nil, err @@ -254,8 +280,14 @@ func (a *AggrExec) Schema() *arrow.Schema { func (a *AggrExec) Close() error { return a.input.Close() } +func (a *AggrExec) Name() string { + return "Global Aggregate" +} -func validAggrType(dt arrow.DataType) bool { +func validAggrType(aggrT AggregateFunctions, dt arrow.DataType) bool { + if aggrT.AggrFunc == Count { + return true + } switch dt.ID() { case arrow.UINT8, arrow.UINT16, arrow.UINT32, arrow.UINT64, arrow.INT8, arrow.INT16, arrow.INT32, arrow.INT64, arrow.FLOAT16, arrow.FLOAT32, arrow.FLOAT64: diff --git a/src/Backend/opti-sql-go/operators/aggr/singleAggr_test.go b/src/Backend/opti-sql-go/operators/aggr/singleAggr_test.go index 9b5af24..ff669de 100644 --- a/src/Backend/opti-sql-go/operators/aggr/singleAggr_test.go +++ b/src/Backend/opti-sql-go/operators/aggr/singleAggr_test.go @@ -159,7 +159,7 @@ func TestNewAggrExec(t *testing.T) { t.Fatalf("expected 1 schema field, got %d", exec.Schema().NumFields()) } - expectedName := "min_Column(age)" + expectedName := "age" if 
exec.Schema().Field(0).Name != expectedName { t.Fatalf("expected name %s, got %s", expectedName, exec.Schema().Field(0).Name) @@ -184,9 +184,9 @@ func TestNewAggrExec(t *testing.T) { schema := exec.Schema() expected := []string{ - "min_Column(id)", - "max_Column(salary)", - "avg_Column(age)", + "id", + "salary", + "age", } for i, f := range schema.Fields() { @@ -562,9 +562,9 @@ func TestAggregateExecNext(t *testing.T) { s := aggrExec.Schema() expectedNames := []string{ - "min_Column(id)", - "sum_Column(age)", - "count_Column(salary)", + "id", + "age", + "salary", } for i, f := range s.Fields() { diff --git a/src/Backend/opti-sql-go/operators/aggr/sort.go b/src/Backend/opti-sql-go/operators/aggr/sort.go index 1b731f8..423c807 100644 --- a/src/Backend/opti-sql-go/operators/aggr/sort.go +++ b/src/Backend/opti-sql-go/operators/aggr/sort.go @@ -7,6 +7,7 @@ import ( "io" "math" "opti-sql-go/Expr" + "opti-sql-go/config" "opti-sql-go/operators" "sort" @@ -14,6 +15,7 @@ import ( "github.com/apache/arrow/go/v17/arrow/array" "github.com/apache/arrow/go/v17/arrow/compute" "github.com/apache/arrow/go/v17/arrow/memory" + "go.uber.org/zap" ) // order by col asc, col 2 desc .... 
etc @@ -76,10 +78,12 @@ func NewSortExec(child operators.Operator, sortKeys []SortKey) (*SortExec, error // n is the number of records we will return,sortExec will read in 2^16-1 column entries from its child, this is more efficient that trusting the caller to pass in a reasonable // n so that we avoid small/frequent IO operations func (s *SortExec) Next(n uint16) (*operators.RecordBatch, error) { + logger := config.GetLogger() if s.done { return nil, io.EOF } if !s.consumed { + logger.Debug("Sort operator consuming input", zap.Int("sort_keys", len(s.sortKeys))) allColumns := make([]arrow.Array, len(s.schema.Fields())) // concated columns mem := memory.NewGoAllocator() var count uint64 @@ -107,6 +111,7 @@ func (s *SortExec) Next(n uint16) (*operators.RecordBatch, error) { if len(allColumns) > 0 { count = uint64(allColumns[0].Len()) } + logger.Info("Sort operator consumed all input", zap.Uint64("total_rows", count), zap.Int("num_columns", len(allColumns))) idx, err := sortBatches(&operators.RecordBatch{ Schema: s.schema, Columns: allColumns, @@ -159,6 +164,9 @@ func (s *SortExec) Schema() *arrow.Schema { func (s *SortExec) Close() error { return s.input.Close() } +func (s *SortExec) Name() string { + return "Sort" +} func (s *SortExec) consumeSortedBatch(readsize uint64, mem memory.Allocator) ([]arrow.Array, error) { ctx := context.Background() resultColumns := make([]arrow.Array, len(s.schema.Fields())) @@ -261,6 +269,9 @@ func (t *TopKSortExec) Schema() *arrow.Schema { func (t *TopKSortExec) Close() error { return t.input.Close() } +func (t *TopKSortExec) Name() string { + return "Top K Exec" +} type heapRow struct { rowIdx uint64 diff --git a/src/Backend/opti-sql-go/operators/filter/filter.go b/src/Backend/opti-sql-go/operators/filter/filter.go index d09f4a2..d493832 100644 --- a/src/Backend/opti-sql-go/operators/filter/filter.go +++ b/src/Backend/opti-sql-go/operators/filter/filter.go @@ -3,7 +3,6 @@ package filter import ( "context" "errors" - "fmt" "io" 
"opti-sql-go/Expr" "opti-sql-go/operators" @@ -119,6 +118,9 @@ func (f *FilterExec) Schema() *arrow.Schema { func (f *FilterExec) Close() error { return f.input.Close() } +func (f *FilterExec) Name() string { + return "Filter" +} func ApplyBooleanMask(col arrow.Array, mask *array.Boolean) (arrow.Array, error) { datum, err := compute.Filter( @@ -163,11 +165,9 @@ func validPredicates(pred Expr.Expression, schema *arrow.Schema) bool { if err != nil { return false } - fmt.Printf("dt1:\t%v\ndt2:\t%v\n", dt1, dt2) if !arrow.TypeEqual(dt1, dt2) { return false } - fmt.Printf("left:\t%v\nright:\t%v\n", p.Left, p.Right) return validPredicates(p.Left, schema) && validPredicates(p.Right, schema) diff --git a/src/Backend/opti-sql-go/operators/filter/limit.go b/src/Backend/opti-sql-go/operators/filter/limit.go index 6a5aa86..c241862 100644 --- a/src/Backend/opti-sql-go/operators/filter/limit.go +++ b/src/Backend/opti-sql-go/operators/filter/limit.go @@ -23,7 +23,7 @@ var ( type LimitExec struct { input operators.Operator schema *arrow.Schema - remaining uint16 + Remaining uint16 done bool } @@ -31,7 +31,7 @@ func NewLimitExec(input operators.Operator, count uint16) (*LimitExec, error) { return &LimitExec{ input: input, schema: input.Schema(), - remaining: count, + Remaining: count, }, nil } @@ -43,26 +43,26 @@ func (l *LimitExec) Next(n uint16) (*operators.RecordBatch, error) { RowCount: 0, }, nil } - if l.remaining == 0 { + if l.Remaining == 0 { return nil, io.EOF } var childN uint16 switch { - case n < l.remaining: + case n < l.Remaining: // We can fulfill the request fully childN = n - l.remaining -= n + l.Remaining -= n - case n == l.remaining: + case n == l.Remaining: // Exact request - done afterwards childN = n - l.remaining = 0 + l.Remaining = 0 l.done = true - case n > l.remaining: + case n > l.Remaining: // Only have l.remaining left - childN = l.remaining - l.remaining = 0 + childN = l.Remaining + l.Remaining = 0 l.done = true } childBatch, err := l.input.Next(childN) 
@@ -78,6 +78,9 @@ func (l *LimitExec) Schema() *arrow.Schema { func (l *LimitExec) Close() error { return l.input.Close() } +func (l *LimitExec) Name() string { + return "Limit" +} type DistinctExec struct { input operators.Operator @@ -209,6 +212,9 @@ func (d *DistinctExec) Close() error { operators.ReleaseArrays(d.distinctValuesArray) return d.input.Close() } +func (d *DistinctExec) Name() string { + return "Distinct" +} func (d *DistinctExec) consumeDistinctArrays(readSize uint64, mem memory.Allocator) ([]arrow.Array, error) { ctx := context.Background() resultColumns := make([]arrow.Array, len(d.schema.Fields())) diff --git a/src/Backend/opti-sql-go/operators/project/csv.go b/src/Backend/opti-sql-go/operators/project/csv.go index 7f57686..be9c8a4 100644 --- a/src/Backend/opti-sql-go/operators/project/csv.go +++ b/src/Backend/opti-sql-go/operators/project/csv.go @@ -97,6 +97,9 @@ func (csvS *CSVSource) Close() error { func (csvS *CSVSource) Schema() *arrow.Schema { return csvS.schema } +func (csvS *CSVSource) Name() string { + return "CSV Source" +} func (csvS *CSVSource) initBuilders() []array.Builder { fields := csvS.schema.Fields() builders := make([]array.Builder, len(fields)) @@ -204,7 +207,6 @@ func parseDataType(sample string) arrow.DataType { if sample == "" || strings.EqualFold(sample, "NULL") { return arrow.BinaryTypes.String } - // Boolean if sample == "true" || sample == "false" { return arrow.FixedWidthTypes.Boolean diff --git a/src/Backend/opti-sql-go/operators/project/custom.go b/src/Backend/opti-sql-go/operators/project/custom.go index 0816600..de36e17 100644 --- a/src/Backend/opti-sql-go/operators/project/custom.go +++ b/src/Backend/opti-sql-go/operators/project/custom.go @@ -138,6 +138,9 @@ func (ms *InMemorySource) Close() error { func (ms *InMemorySource) Schema() *arrow.Schema { return ms.schema } +func (ms *InMemorySource) Name() string { + return "In Memory Source" +} func unpackColumn(name string, col any) (arrow.Field, arrow.Array, 
error) { // need to not only build the array; but also need the schema var field arrow.Field diff --git a/src/Backend/opti-sql-go/operators/project/parquet.go b/src/Backend/opti-sql-go/operators/project/parquet.go index 50aa856..2c494ca 100644 --- a/src/Backend/opti-sql-go/operators/project/parquet.go +++ b/src/Backend/opti-sql-go/operators/project/parquet.go @@ -14,6 +14,7 @@ import ( "github.com/apache/arrow/go/v17/parquet" "github.com/apache/arrow/go/v17/parquet/file" "github.com/apache/arrow/go/v17/parquet/pqarrow" + "go.uber.org/zap" ) var ( @@ -25,7 +26,9 @@ type ParquetSource struct { schema *arrow.Schema projectionPushDown []string // columns to project up reader pqarrow.RecordReader - done bool // if set to true always return io.EOF + bufferedCols []arrow.Array // internal buffer for excess rows + bufferedSize int64 // number of rows currently buffered + done bool // if set to true always return io.EOF } func NewParquetSource(r parquet.ReaderAtSeeker) (*ParquetSource, error) { @@ -37,7 +40,8 @@ func NewParquetSource(r parquet.ReaderAtSeeker) (*ParquetSource, error) { defer func() { if err := filerReader.Close(); err != nil { - fmt.Printf("warning: failed to close parquet reader: %v\n", err) + logger := config.GetLogger() + logger.Warn("Failed to close parquet reader", zap.Error(err)) } }() @@ -58,6 +62,7 @@ func NewParquetSource(r parquet.ReaderAtSeeker) (*ParquetSource, error) { schema: rdr.Schema(), projectionPushDown: []string{}, reader: rdr, + bufferedCols: make([]arrow.Array, rdr.Schema().NumFields()), }, nil } @@ -75,7 +80,8 @@ func NewParquetSourcePushDown(r parquet.ReaderAtSeeker, columns []string) (*Parq defer func() { if err := filerReader.Close(); err != nil { - fmt.Printf("warning: failed to close parquet reader: %v\n", err) + logger := config.GetLogger() + logger.Warn("Failed to close parquet reader", zap.Error(err)) } }() @@ -107,55 +113,107 @@ func NewParquetSourcePushDown(r parquet.ReaderAtSeeker, columns []string) (*Parq schema: 
rdr.Schema(), projectionPushDown: columns, reader: rdr, + bufferedCols: make([]arrow.Array, rdr.Schema().NumFields()), }, nil } // double check that this return exactly n rows in a column. +// ! buffer in memory if what is read in is too much > n func (ps *ParquetSource) Next(n uint16) (*operators.RecordBatch, error) { - if ps.reader == nil || ps.done || !ps.reader.Next() { + if ps.reader == nil || (ps.done && ps.bufferedSize == 0) { return nil, io.EOF } - columns := make([]arrow.Array, len(ps.schema.Fields())) - curRow := 0 - for curRow < int(n) && ps.reader.Next() { + + mem := memory.NewGoAllocator() + + // Read more data if buffer doesn't have enough rows + for ps.bufferedSize < int64(n) && !ps.done { + if !ps.reader.Next() { + ps.done = true + break + } + err := ps.reader.Err() if err != nil { return nil, err } + record := ps.reader.Record() numCols := int(record.NumCols()) numRows := int(record.NumRows()) for colIdx := 0; colIdx < numCols; colIdx++ { - batchCol := record.Column(colIdx) - existing := columns[colIdx] - // First time seeing this column → just assign it + existing := ps.bufferedCols[colIdx] + if existing == nil { batchCol.Retain() - columns[colIdx] = batchCol - continue + ps.bufferedCols[colIdx] = batchCol + } else { + // Concatenate existing + new + combined, err := array.Concatenate([]arrow.Array{existing, batchCol}, mem) + if err != nil { + return nil, err + } + existing.Release() + ps.bufferedCols[colIdx] = combined } - - // Otherwise combine existing + new batch column - combined := CombineArray(existing, batchCol) - - // Replace - columns[colIdx] = combined - - // Release the old existing array to avoid leaks - existing.Release() } + + ps.bufferedSize += int64(numRows) record.Release() + } + + // If buffer is empty, return EOF + if ps.bufferedSize == 0 { + return nil, io.EOF + } - curRow += numRows + // Emit up to n rows + toEmit := min(int64(n), ps.bufferedSize) + out, err := ps.sliceBufferCols(toEmit, mem) + if err != nil { + return nil, 
err } + return &operators.RecordBatch{ Schema: ps.schema, - Columns: columns, - RowCount: uint64(curRow), + Columns: out, + RowCount: uint64(toEmit), }, nil } + +func (ps *ParquetSource) sliceBufferCols(n int64, mem memory.Allocator) ([]arrow.Array, error) { + out := make([]arrow.Array, len(ps.bufferedCols)) + + total := ps.bufferedSize + limit := n + if limit > total { + limit = total + } + + // For each column: slice out rows to emit and rows to keep + for i, col := range ps.bufferedCols { + // emit slice [0:limit] + sliceOut := array.NewSlice(col, 0, limit) + out[i] = sliceOut + + // keep remaining slice [limit:total] + keepSlice := array.NewSlice(col, limit, total) + + // release old buffer column + col.Release() + + // store updated buffer + ps.bufferedCols[i] = keepSlice + } + + // update size + ps.bufferedSize = total - limit + + return out, nil +} + func (ps *ParquetSource) Close() error { ps.reader.Release() ps.reader = nil @@ -165,6 +223,10 @@ func (ps *ParquetSource) Schema() *arrow.Schema { return ps.schema } +func (ps *ParquetSource) Name() string { + return "Parquet Source" +} + // append arr2 to arr1 so (arr1 + arr2) = arr1-arr2 func CombineArray(a1, a2 arrow.Array) arrow.Array { if a1 == nil { diff --git a/src/Backend/opti-sql-go/operators/project/parquet_test.go b/src/Backend/opti-sql-go/operators/project/parquet_test.go index c051d9f..7da4462 100644 --- a/src/Backend/opti-sql-go/operators/project/parquet_test.go +++ b/src/Backend/opti-sql-go/operators/project/parquet_test.go @@ -11,6 +11,7 @@ import ( ) const ParquetTestDatafile = "../../../test_data/parquet/capitals_clean.parquet" +const ParquetTestDatafile2 = "../../../test_data/parquet/fortune1000_2024.parquet" func getTestParquetFile() *os.File { file, err := os.Open(ParquetTestDatafile) @@ -19,6 +20,13 @@ func getTestParquetFile() *os.File { } return file } +func getTestParquetFile2() *os.File { + file, err := os.Open(ParquetTestDatafile2) + if err != nil { + panic(err) + } + return file +} 
/* schema: @@ -721,3 +729,51 @@ func TestCombineArray_UnsupportedType(t *testing.T) { // Call CombineArray with unsupported type _ = CombineArray(arr, arr) } + +// ! test that you get back the number of records you requested and set +func TestRecordBatchCount(t *testing.T) { + tests := []struct { + id int + expectedCount uint16 + }{ + { + id: 1, + expectedCount: 10, + }, + { + id: 2, + expectedCount: 1, + }, + + { + id: 3, + expectedCount: 500, + }, + { + id: 4, + expectedCount: 27, + }, + { + id: 1, + expectedCount: 909, + }, + } + for _, tt := range tests { + f := getTestParquetFile2() + pq, err := NewParquetSource(f) + if err != nil { + t.Fatalf("failed to create parquet source node: %v\n", err) + } + + rc, err := pq.Next(tt.expectedCount) + if err != nil { + t.Fatalf("failed to read %d record batches:%v\n", tt.expectedCount, err) + } + // should return up to the requested amount, + if rc.RowCount != uint64(tt.expectedCount) { + t.Errorf("test id:%d failed to return %d record batches , returned %d", tt.id, tt.expectedCount, rc.RowCount) + } + + } + +} diff --git a/src/Backend/opti-sql-go/operators/project/projectExec.go b/src/Backend/opti-sql-go/operators/project/projectExec.go index abd3da8..acc13b7 100644 --- a/src/Backend/opti-sql-go/operators/project/projectExec.go +++ b/src/Backend/opti-sql-go/operators/project/projectExec.go @@ -5,9 +5,11 @@ import ( "fmt" "io" "opti-sql-go/Expr" + "opti-sql-go/config" "opti-sql-go/operators" "github.com/apache/arrow/go/v17/arrow" + "go.uber.org/zap" ) var ( @@ -80,10 +82,11 @@ func NewProjectExec(input operators.Operator, exprs []Expr.Expression) (*Project // pretty simple, read from child operator and prune columns // pass through error && handles EOF alike func (p *ProjectExec) Next(n uint16) (*operators.RecordBatch, error) { + logger := config.GetLogger() if p.done { return nil, io.EOF } - + logger.Debug("Project operator fetching from child", zap.String("child_operator", p.input.Name())) childBatch, err := 
p.input.Next(n) if err != nil { return nil, err @@ -105,7 +108,7 @@ func (p *ProjectExec) Next(n uint16) (*operators.RecordBatch, error) { outPutCols[i] = arr arr.Retain() } - operators.ReleaseArrays(childBatch.Columns) + //operators.ReleaseArrays(childBatch.Columns) return &operators.RecordBatch{ Schema: &p.outputschema, Columns: outPutCols, @@ -118,6 +121,9 @@ func (p *ProjectExec) Close() error { func (p *ProjectExec) Schema() *arrow.Schema { return &p.outputschema } +func (p *ProjectExec) Name() string { + return "Project" +} // handle keeping only the request columns but make sure the schema and columns are also aligned // returns error if a column doesnt exist diff --git a/src/Backend/opti-sql-go/operators/project/s3.go b/src/Backend/opti-sql-go/operators/project/s3.go index b418503..16692f7 100644 --- a/src/Backend/opti-sql-go/operators/project/s3.go +++ b/src/Backend/opti-sql-go/operators/project/s3.go @@ -1,11 +1,12 @@ package project import ( + "bytes" "fmt" "io" "opti-sql-go/config" "os" - "time" + "strings" "github.com/minio/minio-go" ) @@ -80,8 +81,8 @@ func (n *NetworkResource) Seek(offset int64, whence int) (int64, error) { return 0, fmt.Errorf("unsupported seek mode for S3: %d", whence) } } -func (n *NetworkResource) DownloadLocally() (*os.File, error) { - f, err := os.Create(fmt.Sprintf("%s-%d", n.key, time.Now().UnixNano())) +func (n *NetworkResource) DownloadLocally(scramble string) (*os.File, error) { + f, err := os.Create(fmt.Sprintf("%s-%s", n.key, strings.ReplaceAll(scramble, " ", "-"))) if err != nil { return nil, err } @@ -103,3 +104,22 @@ func (n *NetworkResource) DownloadLocally() (*os.File, error) { return f, nil } + +func UploadResults(fileName string, content []byte) error { + + accessKey := secretes.AccessKey + secretKey := secretes.SecretKey + endpoint := secretes.EndpointURL + bucket := secretes.BucketName + useSSL := true + + client, err := minio.New(endpoint, accessKey, secretKey, useSSL) + if err != nil { + return err + } + _, 
err = client.PutObject(bucket, fileName, bytes.NewReader(content), int64(len(content)), minio.PutObjectOptions{UserMetadata: map[string]string{ + "x-amz-acl": "public-read", + }}) + return err + +} diff --git a/src/Backend/opti-sql-go/operators/project/source_test.go b/src/Backend/opti-sql-go/operators/project/source_test.go index facce88..bdc763c 100644 --- a/src/Backend/opti-sql-go/operators/project/source_test.go +++ b/src/Backend/opti-sql-go/operators/project/source_test.go @@ -11,6 +11,7 @@ const ( s3CSVFile = "country_full.csv" s3ParquetFile = "userdata.parquet" s3TxtFile = "example.txt" + tmp_scamble = "random_test" ) // test s3 as a source first then run test for other source files here @@ -82,7 +83,7 @@ func TestS3Download(t *testing.T) { if err != nil { t.Fatalf("failed to create s3 object: %v", err) } - newFile, err := nr.DownloadLocally() + newFile, err := nr.DownloadLocally(tmp_scamble) if err != nil { t.Fatalf("failed to download file locally %v", err) } @@ -115,7 +116,7 @@ func TestS3Download(t *testing.T) { if err != nil { t.Fatalf("failed to create s3 object: %v", err) } - newFile, err := nr.DownloadLocally() + newFile, err := nr.DownloadLocally(tmp_scamble) if err != nil { t.Fatalf("failed to download file locally %v", err) } @@ -148,7 +149,7 @@ func TestS3Download(t *testing.T) { if err != nil { t.Fatalf("failed to create s3 object: %v", err) } - newFile, err := nr.DownloadLocally() + newFile, err := nr.DownloadLocally(tmp_scamble) if err != nil { t.Fatalf("failed to download file locally %v", err) } @@ -217,7 +218,7 @@ func TestS3ForSource(t *testing.T) { if err != nil { t.Fatalf("failed to create s3 object: %v", err) } - f, err := nr.DownloadLocally() + f, err := nr.DownloadLocally(tmp_scamble) if err != nil { t.Fatalf("failed to download s3 object locally: %v", err) } @@ -248,7 +249,7 @@ func TestS3ForSource(t *testing.T) { if err != nil { t.Fatalf("failed to create s3 object: %v", err) } - f, err := nr.DownloadLocally() + f, err := 
nr.DownloadLocally(tmp_scamble) if err != nil { t.Fatalf("failed to download s3 object locally: %v", err) } @@ -302,3 +303,12 @@ func TestS3Source(t *testing.T) { t.Logf("read %d bytes from s3 object stream: %s\n", n, string(buf[:n])) }) } + +func TestUploadS3(t *testing.T) { + content := []byte("name,id,age,girth\nrich,1,32,9.32") + fName := "upload-test_1.csv" + err := UploadResults(fName, content) + if err != nil { + t.Errorf("test failed this error %v", err) + } +} diff --git a/src/Backend/opti-sql-go/operators/record.go b/src/Backend/opti-sql-go/operators/record.go index 6678ef4..82d7227 100644 --- a/src/Backend/opti-sql-go/operators/record.go +++ b/src/Backend/opti-sql-go/operators/record.go @@ -1,7 +1,10 @@ package operators import ( + "bytes" + "encoding/csv" "fmt" + "strconv" "strings" "github.com/apache/arrow/go/v17/arrow" @@ -15,11 +18,22 @@ var ( } ) +// GetSchemaFieldNames returns the names of all fields in a schema +// Useful for debugging and logging +func GetSchemaFieldNames(s *arrow.Schema) []string { + names := make([]string, s.NumFields()) + for i := 0; i < s.NumFields(); i++ { + names[i] = s.Field(i).Name + } + return names +} + type Operator interface { Next(uint16) (*RecordBatch, error) Schema() *arrow.Schema // Call Operator.Close() after Next returns an io.EOF to clean up resources Close() error + Name() string } type RecordBatch struct { Schema *arrow.Schema @@ -375,6 +389,67 @@ func (rb *RecordBatch) PrettyPrint() string { return b.String() } +func (rb *RecordBatch) ToCSV() ([]byte, error) { + var buf bytes.Buffer + w := csv.NewWriter(&buf) + + // 1. Write header + headers := make([]string, len(rb.Schema.Fields())) + for i, field := range rb.Schema.Fields() { + headers[i] = field.Name + } + err := w.Write(headers) + if err != nil { + return nil, err + } + + // 2. 
Write rows + for row := 0; row < int(rb.RowCount); row++ { + record := make([]string, len(rb.Columns)) + + for colIdx, col := range rb.Columns { + if col.IsNull(row) { + record[colIdx] = "" + continue + } + + switch arr := col.(type) { + case *array.String: + record[colIdx] = arr.Value(row) + + case *array.LargeString: + record[colIdx] = arr.Value(row) + + case *array.Int64: + record[colIdx] = strconv.FormatInt(arr.Value(row), 10) + + case *array.Int32: + record[colIdx] = strconv.FormatInt(int64(arr.Value(row)), 10) + + case *array.Float64: + record[colIdx] = strconv.FormatFloat(arr.Value(row), 'f', -1, 64) + + case *array.Float32: + record[colIdx] = strconv.FormatFloat(float64(arr.Value(row)), 'f', -1, 32) + + case *array.Boolean: + record[colIdx] = strconv.FormatBool(arr.Value(row)) + + default: + // Fallback — avoid panic, but make debugging obvious + record[colIdx] = fmt.Sprintf("", col) + } + } + + if err = w.Write(record); err != nil { + return nil, err + } + + } + + w.Flush() + return buf.Bytes(), nil +} // ------------------------------- // Helper Functions diff --git a/src/Backend/opti-sql-go/operators/test/intergration_test.go b/src/Backend/opti-sql-go/operators/test/intergration_test.go index 15786a9..5979b23 100644 --- a/src/Backend/opti-sql-go/operators/test/intergration_test.go +++ b/src/Backend/opti-sql-go/operators/test/intergration_test.go @@ -6,9 +6,9 @@ import ( "io" "opti-sql-go/Expr" "opti-sql-go/operators" - join "opti-sql-go/operators/Join" aggr "opti-sql-go/operators/aggr" "opti-sql-go/operators/filter" + join "opti-sql-go/operators/join" "opti-sql-go/operators/project" "os" "testing" @@ -92,7 +92,7 @@ func TestSelectFilterLimit(t *testing.T) { filt, err := filter.NewFilterExec(src, pred) if err != nil { - t.Fatalf("filter init failed: %v", err) + t.Errorf("filter init failed: %v", err) } projExprs := Expr.NewExpressions( @@ -102,17 +102,17 @@ func TestSelectFilterLimit(t *testing.T) { ) proj, err := project.NewProjectExec(filt, projExprs) 
if err != nil { - t.Fatalf("project init failed: %v", err) + t.Errorf("project init failed: %v", err) } lim, err := filter.NewLimitExec(proj, 10) if err != nil { - t.Fatalf("limit init failed: %v", err) + t.Errorf("limit init failed: %v", err) } batch, err := lim.Next(10) if err != nil && !errors.Is(err, io.EOF) { - t.Fatalf("unexpected error: %v", err) + t.Errorf("unexpected error: %v", err) } if batch == nil { @@ -141,7 +141,7 @@ func TestSelectFilterLimit(t *testing.T) { filt, err := filter.NewFilterExec(src, pred) if err != nil { - t.Fatalf("filter init failed: %v", err) + t.Errorf("filter init failed: %v", err) } projExprs := Expr.NewExpressions( @@ -150,17 +150,17 @@ func TestSelectFilterLimit(t *testing.T) { ) proj, err := project.NewProjectExec(filt, projExprs) if err != nil { - t.Fatalf("project init failed: %v", err) + t.Errorf("project init failed: %v", err) } lim, err := filter.NewLimitExec(proj, 3) if err != nil { - t.Fatalf("limit init failed: %v", err) + t.Errorf("limit init failed: %v", err) } batch, err := lim.Next(100) if err != nil && !errors.Is(err, io.EOF) { - t.Fatalf("unexpected error: %v", err) + t.Errorf("unexpected error: %v", err) } if batch == nil { @@ -182,7 +182,7 @@ func TestSelectFilterLimit(t *testing.T) { filt, err := filter.NewFilterExec(src, pred) if err != nil { - t.Fatalf("filter init failed: %v", err) + t.Errorf("filter init failed: %v", err) } projExprs := Expr.NewExpressions( @@ -191,17 +191,17 @@ func TestSelectFilterLimit(t *testing.T) { ) proj, err := project.NewProjectExec(filt, projExprs) if err != nil { - t.Fatalf("project init failed: %v", err) + t.Errorf("project init failed: %v", err) } lim, err := filter.NewLimitExec(proj, 7) if err != nil { - t.Fatalf("limit init failed: %v", err) + t.Errorf("limit init failed: %v", err) } batch, err := lim.Next(100) if err != nil && !errors.Is(err, io.EOF) { - t.Fatalf("unexpected error: %v", err) + t.Errorf("unexpected error: %v", err) } if batch == nil { @@ -231,7 +231,7 @@ 
func TestFilterScalarFunctions(t *testing.T) { filt, err := filter.NewFilterExec(src, pred) if err != nil { - t.Fatalf("filter init failed: %v", err) + t.Errorf("filter init failed: %v", err) } exprs := Expr.NewExpressions( @@ -241,12 +241,12 @@ func TestFilterScalarFunctions(t *testing.T) { ) proj, err := project.NewProjectExec(filt, exprs) if err != nil { - t.Fatalf("project init failed: %v", err) + t.Errorf("project init failed: %v", err) } batch, err := proj.Next(100) if err != nil && !errors.Is(err, io.EOF) { - t.Fatalf("unexpected error: %v", err) + t.Errorf("unexpected error: %v", err) } if batch == nil { t.Logf("(2A) got nil batch (possibly EOF)") @@ -267,7 +267,7 @@ func TestFilterScalarFunctions(t *testing.T) { filt, err := filter.NewFilterExec(src, pred) if err != nil { - t.Fatalf("filter init failed: %v", err) + t.Errorf("filter init failed: %v", err) } exprs := Expr.NewExpressions( @@ -276,14 +276,14 @@ func TestFilterScalarFunctions(t *testing.T) { ) proj, err := project.NewProjectExec(filt, exprs) if err != nil { - t.Fatalf("project init failed: %v", err) + t.Errorf("project init failed: %v", err) } batch, err := proj.Next(100) if err != nil && !errors.Is(err, io.EOF) { - t.Fatalf("unexpected error: %v", err) + t.Errorf("unexpected error: %v", err) } if batch != nil { - t.Fatalf("was expecting an empty batch but recieved %s\n", batch.PrettyPrint()) + t.Errorf("was expecting an empty batch but recieved %s\n", batch.PrettyPrint()) return } }) @@ -304,17 +304,17 @@ func TestSelectSort(t *testing.T) { ) proj, err := project.NewProjectExec(src, exprs) if err != nil { - t.Fatalf("project init failed: %v", err) + t.Errorf("project init failed: %v", err) } sk := aggr.NewSortKey(Expr.NewColumnResolve("account_balance_usd"), true) sortExec, err := aggr.NewSortExec(proj, aggr.CombineSortKeys(sk)) if err != nil { - t.Fatalf("sort init failed: %v", err) + t.Errorf("sort init failed: %v", err) } batch, err := sortExec.Next(100) if err != nil && !errors.Is(err, 
io.EOF) { - t.Fatalf("unexpected error: %v", err) + t.Errorf("unexpected error: %v", err) } if batch == nil { t.Logf("(3A) got nil batch (possibly EOF)") @@ -332,16 +332,16 @@ func TestSelectSort(t *testing.T) { ) proj, err := project.NewProjectExec(src, exprs) if err != nil { - t.Fatalf("project init failed: %v", err) + t.Errorf("project init failed: %v", err) } sk := aggr.NewSortKey(Expr.NewColumnResolve("favorite_color"), true) sortExec, err := aggr.NewSortExec(proj, aggr.CombineSortKeys(sk)) if err != nil { - t.Fatalf("sort init failed: %v", err) + t.Errorf("sort init failed: %v", err) } batch, err := sortExec.Next(100) if err != nil && !errors.Is(err, io.EOF) { - t.Fatalf("unexpected error: %v", err) + t.Errorf("unexpected error: %v", err) } if batch == nil { t.Logf("(3B) got nil batch (possibly EOF)") @@ -366,7 +366,7 @@ func TestJoinSelect(t *testing.T) { ) j, err := join.NewHashJoinExec(src1, src2, clause, join.InnerJoin, nil) if err != nil { - t.Fatalf("join init failed: %v", err) + t.Errorf("join init failed: %v", err) } exprs := Expr.NewExpressions( Expr.NewAlias(Expr.NewColumnResolve("left_id"), "id"), @@ -376,11 +376,11 @@ func TestJoinSelect(t *testing.T) { t.Logf("\t%v\n", j.Schema()) proj, err := project.NewProjectExec(j, exprs) if err != nil { - t.Fatalf("project init failed: %v", err) + t.Errorf("project init failed: %v", err) } batch, err := proj.Next(100) if err != nil && !errors.Is(err, io.EOF) { - t.Fatalf("unexpected error: %v", err) + t.Errorf("unexpected error: %v", err) } if batch == nil { t.Logf("(4A) got nil batch (possibly EOF)") @@ -399,7 +399,7 @@ func TestJoinSelect(t *testing.T) { ) j, err := join.NewHashJoinExec(src1, src2, clause, join.InnerJoin, nil) if err != nil { - t.Fatalf("join init failed: %v", err) + t.Errorf("join init failed: %v", err) } exprs := Expr.NewExpressions( Expr.NewAlias(Expr.NewColumnResolve("left_id"), "cool_guy_id"), @@ -408,11 +408,11 @@ func TestJoinSelect(t *testing.T) { ) proj, err := 
project.NewProjectExec(j, exprs) if err != nil { - t.Fatalf("project init failed: %v", err) + t.Errorf("project init failed: %v", err) } batch, err := proj.Next(100) if err != nil && !errors.Is(err, io.EOF) { - t.Fatalf("unexpected error: %v", err) + t.Errorf("unexpected error: %v", err) } if batch == nil { t.Logf("(4B) got nil batch (possibly EOF)") @@ -423,6 +423,7 @@ func TestJoinSelect(t *testing.T) { } func TestGroupByAggregation(t *testing.T) { + // ! query doesnt match code // (5.A) SELECT favorite_color, AVG(age_years) AS avg_age, SUM(account_balance_usd) AS total_balance FROM source1 GROUP BY favorite_color order by avg_age; t.Run("5A", func(t *testing.T) { src := source1Project() @@ -435,16 +436,16 @@ func TestGroupByAggregation(t *testing.T) { gb, err := aggr.NewGroupByExec(src, aggs, groupBy) if err != nil { - t.Fatalf("groupby init failed: %v", err) + t.Errorf("groupby init failed: %v", err) } - sortExec, err := aggr.NewSortExec(gb, aggr.CombineSortKeys(aggr.NewSortKey(Expr.NewColumnResolve("avg_Column(age_years)"), true))) + sortExec, err := aggr.NewSortExec(gb, aggr.CombineSortKeys(aggr.NewSortKey(Expr.NewColumnResolve("age_years"), true))) if err != nil { - t.Fatalf("sort init failed: %v", err) + t.Errorf("sort init failed: %v", err) } batch, err := sortExec.Next(1000) if err != nil && !errors.Is(err, io.EOF) { - t.Fatalf("unexpected error: %v", err) + t.Errorf("unexpected error: %v", err) } if batch == nil { t.Logf("(5A) got nil batch (possibly EOF)") @@ -464,12 +465,12 @@ func TestGroupByAggregation(t *testing.T) { gb, err := aggr.NewGroupByExec(src, aggs, groupBy) if err != nil { - t.Fatalf("groupby init failed: %v", err) + t.Errorf("groupby init failed: %v", err) } batch, err := gb.Next(1000) if err != nil && !errors.Is(err, io.EOF) { - t.Fatalf("unexpected error: %v", err) + t.Errorf("unexpected error: %v", err) } if batch == nil { t.Logf("(5B) got nil batch (possibly EOF)") @@ -494,22 +495,22 @@ func TestDistinctSort(t *testing.T) { cols := 
[]Expr.Expression{Expr.NewColumnResolve("favorite_color")} distinct, err := filter.NewDistinctExec(src, cols) if err != nil { - t.Fatalf("distinct init failed: %v", err) + t.Errorf("distinct init failed: %v", err) } sk := aggr.NewSortKey(Expr.NewColumnResolve("favorite_color"), false) // DESC sortExec, err := aggr.NewSortExec(distinct, aggr.CombineSortKeys(sk)) if err != nil { - t.Fatalf("sort init failed: %v", err) + t.Errorf("sort init failed: %v", err) } proj, err := project.NewProjectExec(sortExec, Expr.NewExpressions(Expr.NewColumnResolve("favorite_color"))) if err != nil { - t.Fatalf("project init failed: %v", err) + t.Errorf("project init failed: %v", err) } batch, err := proj.Next(100) if err != nil && !errors.Is(err, io.EOF) { - t.Fatalf("unexpected error: %v", err) + t.Errorf("unexpected error: %v", err) } if batch == nil { t.Logf("(6A) got nil batch (possibly EOF)") @@ -525,22 +526,22 @@ func TestDistinctSort(t *testing.T) { cols := []Expr.Expression{Expr.NewColumnResolve("is_active")} distinct, err := filter.NewDistinctExec(src, cols) if err != nil { - t.Fatalf("distinct init failed: %v", err) + t.Errorf("distinct init failed: %v", err) } sk := aggr.NewSortKey(Expr.NewColumnResolve("is_active"), false) // DESC sortExec, err := aggr.NewSortExec(distinct, aggr.CombineSortKeys(sk)) if err != nil { - t.Fatalf("sort init failed: %v", err) + t.Errorf("sort init failed: %v", err) } proj, err := project.NewProjectExec(sortExec, Expr.NewExpressions(Expr.NewColumnResolve("is_active"))) if err != nil { - t.Fatalf("project init failed: %v", err) + t.Errorf("project init failed: %v", err) } batch, err := proj.Next(100) if err != nil && !errors.Is(err, io.EOF) { - t.Fatalf("unexpected error: %v", err) + t.Errorf("unexpected error: %v", err) } if batch == nil { t.Logf("(6B) got nil batch (possibly EOF)") @@ -565,7 +566,7 @@ func TestJoinFilterProjLimit(t *testing.T) { ) j, err := join.NewHashJoinExec(src1, src2, clause, join.InnerJoin, nil) if err != nil { - 
t.Fatalf("join init failed: %v", err) + t.Errorf("join init failed: %v", err) } pred := Expr.NewBinaryExpr( Expr.NewColumnResolve("age_years"), @@ -575,7 +576,7 @@ func TestJoinFilterProjLimit(t *testing.T) { filt, err := filter.NewFilterExec(j, pred) if err != nil { - t.Fatalf("filter init failed: %v", err) + t.Errorf("filter init failed: %v", err) } exprs := Expr.NewExpressions( @@ -585,17 +586,17 @@ func TestJoinFilterProjLimit(t *testing.T) { ) proj, err := project.NewProjectExec(filt, exprs) if err != nil { - t.Fatalf("project init failed: %v", err) + t.Errorf("project init failed: %v", err) } lim, err := filter.NewLimitExec(proj, 5) if err != nil { - t.Fatalf("limit init failed: %v", err) + t.Errorf("limit init failed: %v", err) } batch, err := lim.Next(100) if err != nil && !errors.Is(err, io.EOF) { - t.Fatalf("unexpected error: %v", err) + t.Errorf("unexpected error: %v", err) } if batch == nil { t.Logf("(7A) got nil batch (possibly EOF)") @@ -614,7 +615,7 @@ func TestJoinFilterProjLimit(t *testing.T) { ) j, err := join.NewHashJoinExec(src1, src2, clause, join.InnerJoin, nil) if err != nil { - t.Fatalf("join init failed: %v", err) + t.Errorf("join init failed: %v", err) } pred := Expr.NewBinaryExpr( @@ -625,7 +626,7 @@ func TestJoinFilterProjLimit(t *testing.T) { filt, err := filter.NewFilterExec(j, pred) if err != nil { - t.Fatalf("filter init failed: %v", err) + t.Errorf("filter init failed: %v", err) } exprs := Expr.NewExpressions( @@ -634,17 +635,17 @@ func TestJoinFilterProjLimit(t *testing.T) { ) proj, err := project.NewProjectExec(filt, exprs) if err != nil { - t.Fatalf("project init failed: %v", err) + t.Errorf("project init failed: %v", err) } lim, err := filter.NewLimitExec(proj, 3) if err != nil { - t.Fatalf("limit init failed: %v", err) + t.Errorf("limit init failed: %v", err) } batch, err := lim.Next(100) if err != nil && !errors.Is(err, io.EOF) { - t.Fatalf("unexpected error: %v", err) + t.Errorf("unexpected error: %v", err) } if batch == nil 
{ t.Logf("(7B) got nil batch (possibly EOF)") @@ -663,7 +664,7 @@ func TestJoinFilterProjLimit(t *testing.T) { ) j, err := join.NewHashJoinExec(src1, src2, clause, join.InnerJoin, nil) if err != nil { - t.Fatalf("join init failed: %v", err) + t.Errorf("join init failed: %v", err) } pred := Expr.NewBinaryExpr( @@ -674,7 +675,7 @@ func TestJoinFilterProjLimit(t *testing.T) { filt, err := filter.NewFilterExec(j, pred) if err != nil { - t.Fatalf("filter init failed: %v", err) + t.Errorf("filter init failed: %v", err) } exprs := Expr.NewExpressions( @@ -683,17 +684,17 @@ func TestJoinFilterProjLimit(t *testing.T) { ) proj, err := project.NewProjectExec(filt, exprs) if err != nil { - t.Fatalf("project init failed: %v", err) + t.Errorf("project init failed: %v", err) } lim, err := filter.NewLimitExec(proj, 2) if err != nil { - t.Fatalf("limit init failed: %v", err) + t.Errorf("limit init failed: %v", err) } batch, err := lim.Next(100) if err != nil && !errors.Is(err, io.EOF) { - t.Fatalf("unexpected error: %v", err) + t.Errorf("unexpected error: %v", err) } if batch == nil { t.Logf("(7C) got nil batch (possibly EOF)") @@ -733,7 +734,7 @@ func TestScalarAbsRound(t *testing.T) { filt, err := filter.NewFilterExec(src, pred) if err != nil { - t.Fatalf("filter init failed: %v", err) + t.Errorf("filter init failed: %v", err) } // projection: id, ROUND(ABS(average_session_minutes)) as rounded_session @@ -744,12 +745,12 @@ func TestScalarAbsRound(t *testing.T) { ) proj, err := project.NewProjectExec(filt, exprs) if err != nil { - t.Fatalf("project init failed: %v", err) + t.Errorf("project init failed: %v", err) } batch, err := proj.Next(100) if err != nil && !errors.Is(err, io.EOF) { - t.Fatalf("unexpected error: %v", err) + t.Errorf("unexpected error: %v", err) } if batch == nil { t.Logf("(8A) got nil batch (possibly EOF)") @@ -771,7 +772,7 @@ func TestScalarAbsRound(t *testing.T) { filt, err := filter.NewFilterExec(src, pred) if err != nil { - t.Fatalf("filter init failed: 
%v", err) + t.Errorf("filter init failed: %v", err) } roundExpr := Expr.NewScalarFunction(Expr.Round, Expr.NewColumnResolve("account_balance_usd")) @@ -781,12 +782,12 @@ func TestScalarAbsRound(t *testing.T) { ) proj, err := project.NewProjectExec(filt, exprs) if err != nil { - t.Fatalf("project init failed: %v", err) + t.Errorf("project init failed: %v", err) } batch, err := proj.Next(100) if err != nil && !errors.Is(err, io.EOF) { - t.Fatalf("unexpected error: %v", err) + t.Errorf("unexpected error: %v", err) } if batch == nil { t.Logf("(8B) got nil batch (possibly EOF)") @@ -810,19 +811,19 @@ func TestSelectMultiSort(t *testing.T) { ) proj, err := project.NewProjectExec(src, exprs) if err != nil { - t.Fatalf("project init failed: %v", err) + t.Errorf("project init failed: %v", err) } sk1 := aggr.NewSortKey(Expr.NewColumnResolve("age_years"), false) // DESC sk2 := aggr.NewSortKey(Expr.NewColumnResolve("username"), true) // ASC sortExec, err := aggr.NewSortExec(proj, aggr.CombineSortKeys(sk1, sk2)) if err != nil { - t.Fatalf("sort init failed: %v", err) + t.Errorf("sort init failed: %v", err) } batch, err := sortExec.Next(100) if err != nil && !errors.Is(err, io.EOF) { - t.Fatalf("unexpected error: %v", err) + t.Errorf("unexpected error: %v", err) } if batch == nil { t.Logf("(9A) got nil batch (possibly EOF)") @@ -841,19 +842,19 @@ func TestSelectMultiSort(t *testing.T) { ) proj, err := project.NewProjectExec(src, exprs) if err != nil { - t.Fatalf("project init failed: %v", err) + t.Errorf("project init failed: %v", err) } sk1 := aggr.NewSortKey(Expr.NewColumnResolve("age_years"), true) // ASC sk2 := aggr.NewSortKey(Expr.NewColumnResolve("email_address"), false) // DESC sortExec, err := aggr.NewSortExec(proj, aggr.CombineSortKeys(sk1, sk2)) if err != nil { - t.Fatalf("sort init failed: %v", err) + t.Errorf("sort init failed: %v", err) } batch, err := sortExec.Next(100) if err != nil && !errors.Is(err, io.EOF) { - t.Fatalf("unexpected error: %v", err) + 
t.Errorf("unexpected error: %v", err) } if batch == nil { t.Logf("(9B) got nil batch (possibly EOF)") diff --git a/src/Backend/opti-sql-go/operators/test/regression_test.go b/src/Backend/opti-sql-go/operators/test/regression_test.go new file mode 100644 index 0000000..9954483 --- /dev/null +++ b/src/Backend/opti-sql-go/operators/test/regression_test.go @@ -0,0 +1,591 @@ +package test + +import ( + "errors" + "io" + "opti-sql-go/Expr" + "opti-sql-go/operators" + aggr "opti-sql-go/operators/aggr" + "opti-sql-go/operators/filter" + "opti-sql-go/operators/project" + "testing" + + "github.com/apache/arrow/go/v17/arrow" +) + +// TestAliasRegressionProjectFilter - Test 1: Project with alias, then filter on aliased column +// SQL: SELECT id AS user_id, username FROM source1 WHERE user_id > 5 +func TestAliasRegressionProjectFilter(t *testing.T) { + src := source1Project() + + projExprs := Expr.NewExpressions( + Expr.NewAlias(Expr.NewColumnResolve("id"), "user_id"), + Expr.NewColumnResolve("username"), + ) + proj, err := project.NewProjectExec(src, projExprs) + if err != nil { + t.Errorf("project failed: %v", err) + } + + pred := Expr.NewBinaryExpr( + Expr.NewColumnResolve("user_id"), + Expr.GreaterThan, + Expr.NewLiteralResolve(arrow.PrimitiveTypes.Int64, 5), + ) + + filt, err := filter.NewFilterExec(proj, pred) + if err != nil { + t.Errorf("filter failed (aliasing broken): %v\nAvailable: %v", + err, operators.GetSchemaFieldNames(proj.Schema())) + } + + batch, err := filt.Next(10) + if err != nil && !errors.Is(err, io.EOF) { + t.Errorf("execution failed: %v", err) + } + + if batch != nil { + t.Logf("SUCCESS - Result:\n%v", batch.PrettyPrint()) + } +} + +// TestAliasRegressionFilterBeforeAggr - Test 2a: Filter BEFORE aggregation (WHERE clause) +// SQL: SELECT username, AVG(account_balance_usd) FROM source1 WHERE id > 5 GROUP BY username +func TestAliasRegressionFilterBeforeAggr(t *testing.T) { + src := source1Project() + + // WHERE id > 5 + pred := Expr.NewBinaryExpr( + 
Expr.NewColumnResolve("id"), + Expr.GreaterThan, + Expr.NewLiteralResolve(arrow.PrimitiveTypes.Int64, 5), + ) + + filt, err := filter.NewFilterExec(src, pred) + if err != nil { + t.Errorf("filter before aggregation failed: %v", err) + } + + // GROUP BY username + groupByExprs := []Expr.Expression{ + Expr.NewColumnResolve("username"), + } + + // AVG(account_balance_usd) + aggrExprs := []aggr.AggregateFunctions{ + aggr.NewAggregateFunctions(aggr.Avg, + Expr.NewColumnResolve("account_balance_usd")), + } + + groupByOp, err := aggr.NewGroupByExec(filt, aggrExprs, groupByExprs) + if err != nil { + t.Errorf("group by after filter failed: %v", err) + } + t.Logf("%v\n", groupByOp.Schema()) + + outputCols := operators.GetSchemaFieldNames(groupByOp.Schema()) + t.Logf("Output columns after filter->groupby: %v", outputCols) + + batch, err := groupByOp.Next(10) + if err != nil && !errors.Is(err, io.EOF) { + t.Errorf("execution failed: %v", err) + } + + if batch != nil { + t.Logf("Result:\n%v", batch.PrettyPrint()) + } +} + +// TestAliasRegressionFilterAfterAggr - Test 2b: Filter AFTER aggregation (HAVING clause) +// SQL: SELECT username, AVG(account_balance_usd) FROM source1 GROUP BY username HAVING AVG(account_balance_usd) > 500 +func TestAliasRegressionFilterAfterAggr(t *testing.T) { + src := source1Project() + + // GROUP BY username + groupByExprs := []Expr.Expression{ + Expr.NewColumnResolve("username"), + } + + // AVG(account_balance_usd) + aggrExprs := []aggr.AggregateFunctions{ + aggr.NewAggregateFunctions(aggr.Avg, + Expr.NewColumnResolve("account_balance_usd")), + } + + groupByOp, err := aggr.NewGroupByExec(src, aggrExprs, groupByExprs) + if err != nil { + t.Errorf("group by failed: %v", err) + } + + outputCols := operators.GetSchemaFieldNames(groupByOp.Schema()) + t.Logf("GroupBy columns: %v", outputCols) + + if len(outputCols) < 2 { + t.Error("Expected at least 2 columns (groupby + aggregation)") + } + + // HAVING AVG(account_balance_usd) > 500 + // The aggregation 
column should be at index 1 + avgColName := outputCols[1] + t.Logf("Attempting HAVING on aggregation column: %s", avgColName) + + pred := Expr.NewBinaryExpr( + Expr.NewColumnResolve(avgColName), + Expr.GreaterThan, + Expr.NewLiteralResolve(arrow.PrimitiveTypes.Float64, 500.0), + ) + + havingFilter, err := filter.NewFilterExec(groupByOp, pred) + if err != nil { + t.Errorf("HAVING (filter after aggregation) failed: %v\nColumn: %s\nAvailable: %v", + err, avgColName, outputCols) + } + + batch, err := havingFilter.Next(100) + if err != nil && !errors.Is(err, io.EOF) { + t.Errorf("execution failed: %v", err) + } + + if batch != nil { + t.Logf("Result:\n%v", batch.PrettyPrint()) + } +} + +// TestAliasRegressionGroupBy - Test 3: Group by with HAVING clause +// SQL: SELECT username, SUM(account_balance_usd) FROM source1 GROUP BY username HAVING SUM(account_balance_usd) > 500 +func TestAliasRegressionGroupBy(t *testing.T) { + src := source1Project() + + groupByExprs := []Expr.Expression{ + Expr.NewColumnResolve("username"), + } + + aggrExprs := []aggr.AggregateFunctions{ + aggr.NewAggregateFunctions(aggr.Sum, Expr.NewColumnResolve("account_balance_usd")), + aggr.NewAggregateFunctions(aggr.Count, Expr.NewColumnResolve("id")), + } + + groupByOp, err := aggr.NewGroupByExec(src, aggrExprs, groupByExprs) + if err != nil { + t.Errorf("group by failed: %v", err) + } + + outputCols := operators.GetSchemaFieldNames(groupByOp.Schema()) + t.Logf("GroupBy columns: %v", outputCols) + + if len(outputCols) < 2 { + t.Error("Expected at least 2 columns") + } + + sumColName := outputCols[1] + t.Logf("Attempting HAVING on column: %s", sumColName) + + pred := Expr.NewBinaryExpr( + Expr.NewColumnResolve(sumColName), + Expr.GreaterThan, + Expr.NewLiteralResolve(arrow.PrimitiveTypes.Float64, 500.0), + ) + + havingFilter, err := filter.NewFilterExec(groupByOp, pred) + if err != nil { + t.Errorf("HAVING failed (aliasing broken): %v\nColumn: %s\nAvailable: %v", + err, sumColName, outputCols) + } + + 
batch, err := havingFilter.Next(100) + if err != nil && !errors.Is(err, io.EOF) { + t.Errorf("execution failed: %v", err) + } + + if batch != nil { + t.Logf("SUCCESS - Result:\n%v", batch.PrettyPrint()) + } +} + +// TestAggrWithAliasFilterOnAlias tests that when using an alias in aggregation, +// the filter (HAVING) should reference the alias name, NOT the underlying column +// SQL: SELECT SUM(account_balance_usd) AS total_balance FROM source1 HAVING total_balance > 1000 +func TestAggrWithAliasFilterOnAlias(t *testing.T) { + src := source1Project() + + // SUM(account_balance_usd) AS total_balance + aggrExprs := []aggr.AggregateFunctions{ + aggr.NewAggregateFunctions(aggr.Sum, + Expr.NewAlias(Expr.NewColumnResolve("account_balance_usd"), "total_balance")), + } + + aggrOp, err := aggr.NewGlobalAggrExec(src, aggrExprs) + if err != nil { + t.Errorf("aggregation failed: %v", err) + } + + outputCols := operators.GetSchemaFieldNames(aggrOp.Schema()) + t.Logf("Aggregation output columns: %v", outputCols) + + // Expected: column name should be "total_balance", NOT "sum_account_balance_usd" + if len(outputCols) != 1 || outputCols[0] != "total_balance" { + t.Errorf("Expected column name 'total_balance', got: %v", outputCols) + } + + // HAVING total_balance > 1000 + pred := Expr.NewBinaryExpr( + Expr.NewColumnResolve("total_balance"), + Expr.GreaterThan, + Expr.NewLiteralResolve(arrow.PrimitiveTypes.Float64, 1000.0), + ) + + havingFilter, err := filter.NewFilterExec(aggrOp, pred) + if err != nil { + t.Errorf("HAVING on aliased aggregation failed: %v\nAvailable columns: %v", err, outputCols) + } + + batch, err := havingFilter.Next(100) + if err != nil && !errors.Is(err, io.EOF) { + t.Errorf("execution failed: %v", err) + } + + if batch != nil { + t.Logf("Result:\n%v", batch.PrettyPrint()) + } +} + +// TestAggrWithoutAliasUsesColumnName tests that without an alias, +// the column name should just be the column name (no prefix) +// SQL: SELECT SUM(account_balance_usd) FROM 
source1 HAVING account_balance_usd > 500 +func TestAggrWithoutAliasUsesColumnName(t *testing.T) { + src := source1Project() + + // SUM(account_balance_usd) - no alias + aggrExprs := []aggr.AggregateFunctions{ + aggr.NewAggregateFunctions(aggr.Sum, + Expr.NewColumnResolve("account_balance_usd")), + } + + aggrOp, err := aggr.NewGlobalAggrExec(src, aggrExprs) + if err != nil { + t.Errorf("aggregation failed: %v", err) + } + + outputCols := operators.GetSchemaFieldNames(aggrOp.Schema()) + t.Logf("Aggregation output columns: %v", outputCols) + + // Expected: column name should be just "account_balance_usd" + if len(outputCols) != 1 || outputCols[0] != "account_balance_usd" { + t.Errorf("Expected column name 'account_balance_usd', got: %v", outputCols) + } + + // Filter on the column name + pred := Expr.NewBinaryExpr( + Expr.NewColumnResolve("account_balance_usd"), + Expr.GreaterThan, + Expr.NewLiteralResolve(arrow.PrimitiveTypes.Float64, 500.0), + ) + + havingFilter, err := filter.NewFilterExec(aggrOp, pred) + if err != nil { + t.Errorf("Filter on non-aliased aggregation failed: %v\nAvailable columns: %v", err, outputCols) + } + + batch, err := havingFilter.Next(100) + if err != nil && !errors.Is(err, io.EOF) { + t.Errorf("execution failed: %v", err) + } + + if batch != nil { + t.Logf("Result:\n%v", batch.PrettyPrint()) + } +} + +// TestGroupByWithAliasedAggregation tests GROUP BY with aliased aggregation +// SQL: SELECT username, AVG(account_balance_usd) AS avg_balance FROM source1 GROUP BY username HAVING avg_balance > 500 +func TestGroupByWithAliasedAggregation(t *testing.T) { + src := source1Project() + + groupByExprs := []Expr.Expression{ + Expr.NewColumnResolve("username"), + } + + // AVG(account_balance_usd) AS avg_balance + aggrExprs := []aggr.AggregateFunctions{ + aggr.NewAggregateFunctions(aggr.Avg, + Expr.NewAlias(Expr.NewColumnResolve("account_balance_usd"), "avg_balance")), + } + + groupByOp, err := aggr.NewGroupByExec(src, aggrExprs, groupByExprs) + if err 
!= nil { + t.Errorf("group by failed: %v", err) + } + + outputCols := operators.GetSchemaFieldNames(groupByOp.Schema()) + t.Logf("GroupBy output columns: %v", outputCols) + + // Expected: columns should be ["username", "avg_balance"] + if len(outputCols) != 2 { + t.Errorf("Expected 2 columns, got %d: %v", len(outputCols), outputCols) + } + if outputCols[0] != "username" { + t.Errorf("Expected first column 'username', got: %s", outputCols[0]) + } + if outputCols[1] != "avg_balance" { + t.Errorf("Expected second column 'avg_balance', got: %s", outputCols[1]) + } + + // HAVING avg_balance > 500 + pred := Expr.NewBinaryExpr( + Expr.NewColumnResolve("avg_balance"), + Expr.GreaterThan, + Expr.NewLiteralResolve(arrow.PrimitiveTypes.Float64, 500.0), + ) + + havingFilter, err := filter.NewFilterExec(groupByOp, pred) + if err != nil { + t.Errorf("HAVING on aliased aggregation failed: %v\nAvailable columns: %v", err, outputCols) + } + + batch, err := havingFilter.Next(100) + if err != nil && !errors.Is(err, io.EOF) { + t.Errorf("execution failed: %v", err) + } + + if batch != nil { + t.Logf("Result:\n%v", batch.PrettyPrint()) + } +} + +// TestGroupByWithoutAliasedAggregation tests GROUP BY without alias - should use column name +// SQL: SELECT username, COUNT(id) FROM source1 GROUP BY username HAVING id > 5 +func TestGroupByWithoutAliasedAggregation(t *testing.T) { + src := source1Project() + + groupByExprs := []Expr.Expression{ + Expr.NewColumnResolve("username"), + } + + // COUNT(id) - no alias + aggrExprs := []aggr.AggregateFunctions{ + aggr.NewAggregateFunctions(aggr.Count, + Expr.NewColumnResolve("id")), + } + + groupByOp, err := aggr.NewGroupByExec(src, aggrExprs, groupByExprs) + if err != nil { + t.Errorf("group by failed: %v", err) + } + + outputCols := operators.GetSchemaFieldNames(groupByOp.Schema()) + t.Logf("GroupBy output columns: %v", outputCols) + + // Expected: columns should be ["username", "id"] + if len(outputCols) != 2 { + t.Errorf("Expected 2 columns, got 
%d: %v", len(outputCols), outputCols) + } + if outputCols[0] != "username" { + t.Errorf("Expected first column 'username', got: %s", outputCols[0]) + } + if outputCols[1] != "id" { + t.Errorf("Expected second column 'id', got: %s", outputCols[1]) + } + + // HAVING id > 5 (referencing the COUNT result by the column name) + pred := Expr.NewBinaryExpr( + Expr.NewColumnResolve("id"), + Expr.GreaterThan, + Expr.NewLiteralResolve(arrow.PrimitiveTypes.Float64, 5.0), + ) + + havingFilter, err := filter.NewFilterExec(groupByOp, pred) + if err != nil { + t.Errorf("HAVING on non-aliased aggregation failed: %v\nAvailable columns: %v", err, outputCols) + } + + batch, err := havingFilter.Next(100) + if err != nil && !errors.Is(err, io.EOF) { + t.Errorf("execution failed: %v", err) + } + + if batch != nil { + t.Logf("Result:\n%v", batch.PrettyPrint()) + } +} + +// TestMultipleAggrWithMixedAliasing tests multiple aggregations with some aliased, some not +// SQL: SELECT COUNT(id) AS user_count, SUM(account_balance_usd) FROM source1 HAVING user_count > 5 +func TestMultipleAggrWithMixedAliasing(t *testing.T) { + src := source1Project() + + // COUNT(id) AS user_count, SUM(account_balance_usd) (no alias) + aggrExprs := []aggr.AggregateFunctions{ + aggr.NewAggregateFunctions(aggr.Count, + Expr.NewAlias(Expr.NewColumnResolve("id"), "user_count")), + aggr.NewAggregateFunctions(aggr.Sum, + Expr.NewColumnResolve("account_balance_usd")), + } + + aggrOp, err := aggr.NewGlobalAggrExec(src, aggrExprs) + if err != nil { + t.Errorf("aggregation failed: %v", err) + } + + outputCols := operators.GetSchemaFieldNames(aggrOp.Schema()) + t.Logf("Aggregation output columns: %v", outputCols) + + // Expected: ["user_count", "account_balance_usd"] + if len(outputCols) != 2 { + t.Errorf("Expected 2 columns, got %d: %v", len(outputCols), outputCols) + } + if outputCols[0] != "user_count" { + t.Errorf("Expected first column 'user_count', got: %s", outputCols[0]) + } + if outputCols[1] != "account_balance_usd" 
{ + t.Errorf("Expected second column 'account_balance_usd', got: %s", outputCols[1]) + } + + // Filter on the aliased column + pred := Expr.NewBinaryExpr( + Expr.NewColumnResolve("user_count"), + Expr.GreaterThan, + Expr.NewLiteralResolve(arrow.PrimitiveTypes.Float64, 5.0), + ) + + havingFilter, err := filter.NewFilterExec(aggrOp, pred) + if err != nil { + t.Errorf("HAVING on user_count failed: %v\nAvailable columns: %v", err, outputCols) + } + + batch, err := havingFilter.Next(100) + if err != nil && !errors.Is(err, io.EOF) { + t.Errorf("execution failed: %v", err) + } + + if batch != nil { + t.Logf("Result:\n%v", batch.PrettyPrint()) + } +} + +// TestProjectThenAggrWithAlias tests projection followed by aggregation with alias +// SQL: SELECT id AS user_id FROM source1; then SELECT COUNT(user_id) AS total FROM ... +func TestProjectThenAggrWithAlias(t *testing.T) { + src := source1Project() + + // First project: id AS user_id + projExprs := Expr.NewExpressions( + Expr.NewAlias(Expr.NewColumnResolve("id"), "user_id"), + Expr.NewColumnResolve("username"), + ) + proj, err := project.NewProjectExec(src, projExprs) + if err != nil { + t.Errorf("project failed: %v", err) + } + + // Then aggregate: COUNT(user_id) AS total + aggrExprs := []aggr.AggregateFunctions{ + aggr.NewAggregateFunctions(aggr.Count, + Expr.NewAlias(Expr.NewColumnResolve("user_id"), "total")), + } + + aggrOp, err := aggr.NewGlobalAggrExec(proj, aggrExprs) + if err != nil { + t.Errorf("aggregation failed: %v", err) + } + + outputCols := operators.GetSchemaFieldNames(aggrOp.Schema()) + t.Logf("Aggregation output columns: %v", outputCols) + + // Expected: column name should be "total" + if len(outputCols) != 1 || outputCols[0] != "total" { + t.Errorf("Expected column name 'total', got: %v", outputCols) + } + + batch, err := aggrOp.Next(100) + if err != nil && !errors.Is(err, io.EOF) { + t.Errorf("execution failed: %v", err) + } + + if batch != nil { + t.Logf("Result:\n%v", batch.PrettyPrint()) + } +} + 
+// TestInvalidFilterOnWrongColumnName tests that filtering on the wrong column name fails +// This is a negative test - it should FAIL if the column name is wrong +// SQL: SELECT SUM(balance) AS total FROM source1 HAVING account_balance_usd > 500 (should fail - account_balance_usd doesn't exist after aggregation) +func TestInvalidFilterOnWrongColumnName(t *testing.T) { + src := source1Project() + + // SUM(account_balance_usd) AS total + aggrExprs := []aggr.AggregateFunctions{ + aggr.NewAggregateFunctions(aggr.Sum, + Expr.NewAlias(Expr.NewColumnResolve("account_balance_usd"), "total")), + } + + aggrOp, err := aggr.NewGlobalAggrExec(src, aggrExprs) + if err != nil { + t.Errorf("aggregation failed: %v", err) + } + + outputCols := operators.GetSchemaFieldNames(aggrOp.Schema()) + t.Logf("Aggregation output columns: %v", outputCols) + + // Try to filter on "account_balance_usd" which no longer exists (should be "total") + pred := Expr.NewBinaryExpr( + Expr.NewColumnResolve("account_balance_usd"), + Expr.GreaterThan, + Expr.NewLiteralResolve(arrow.PrimitiveTypes.Float64, 500.0), + ) + + _, err = filter.NewFilterExec(aggrOp, pred) + if err == nil { + t.Error("Expected filter to fail when referencing non-existent column 'account_balance_usd', but it succeeded") + } + + t.Logf("Correctly failed: %v", err) +} + +// TestSimpleAliasWithFilter tests basic aliasing followed by filtering on the alias +// SQL: SELECT age_years AS age FROM source1 WHERE age > 5 +func TestSimpleAliasWithFilter(t *testing.T) { + src := source1Project() + + // SELECT age_years AS age + projExprs := Expr.NewExpressions( + Expr.NewColumnResolve("id"), + Expr.NewAlias(Expr.NewColumnResolve("age_years"), "age"), + ) + + proj, err := project.NewProjectExec(src, projExprs) + if err != nil { + t.Errorf("project failed: %v", err) + return + } + + outputCols := operators.GetSchemaFieldNames(proj.Schema()) + t.Logf("Project output columns: %v", outputCols) + + // Expected: column should be "age" + if 
len(outputCols) != 2 { + t.Errorf("Expected column names `id` 'age', got: %v", outputCols) + } + + // WHERE age > 5 + pred := Expr.NewBinaryExpr( + Expr.NewColumnResolve("age"), + Expr.GreaterThan, + Expr.NewLiteralResolve(arrow.PrimitiveTypes.Int64, 5), + ) + + filt, err := filter.NewFilterExec(proj, pred) + if err != nil { + t.Errorf("filter on aliased column failed: %v\nAvailable columns: %v", err, outputCols) + return + } + + batch, err := filt.Next(100) + if err != nil && !errors.Is(err, io.EOF) { + t.Errorf("execution failed: %v", err) + return + } + + if batch == nil { + t.Errorf("empty record batch was returned") + } +} diff --git a/src/Backend/opti-sql-go/operators/test/t1_test.go b/src/Backend/opti-sql-go/operators/test/t1_test.go index dd728fb..1825fd8 100644 --- a/src/Backend/opti-sql-go/operators/test/t1_test.go +++ b/src/Backend/opti-sql-go/operators/test/t1_test.go @@ -6,9 +6,9 @@ import ( "math" "opti-sql-go/Expr" "opti-sql-go/operators" - join "opti-sql-go/operators/Join" "opti-sql-go/operators/aggr" "opti-sql-go/operators/filter" + join "opti-sql-go/operators/join" "opti-sql-go/operators/project" "strings" "testing" @@ -222,7 +222,7 @@ func TestProjectExec(t *testing.T) { src, err := NewIntegrationSource1(mem) if err != nil { - t.Fatalf("failed to create integration source: %v", err) + t.Errorf("failed to create integration source: %v", err) } exprs := Expr.NewExpressions( Expr.NewColumnResolve("id"), @@ -232,17 +232,17 @@ func TestProjectExec(t *testing.T) { ) basicProj, err := project.NewProjectExec(src, exprs) if err != nil { - t.Fatalf("unexpected error\t%v\n", basicProj) + t.Errorf("unexpected error\t%v\n", basicProj) } //t.Logf("%v\n", basicProj.Schema()) rc, err := basicProj.Next(100) if err != nil { if !errors.Is(err, io.EOF) { - t.Fatalf("unexpected error %v\n", err) + t.Errorf("unexpected error %v\n", err) } } if rc.RowCount != 20 { - t.Fatalf("expected 20 rows, got %d", rc.RowCount) + t.Errorf("expected 20 rows, got %d", rc.RowCount) } 
}) t.Run("projection_with_alias", func(t *testing.T) { @@ -256,14 +256,14 @@ func TestProjectExec(t *testing.T) { proj, err := project.NewProjectExec(src, exprs) if err != nil { - t.Fatalf("error: %v", err) + t.Errorf("error: %v", err) } batch, _ := proj.Next(50) // verify alias appears in schema if batch.Schema.Fields()[1].Name != "emp_salary" { - t.Fatalf("expected alias emp_salary, got %s", batch.Schema.Fields()[1].Name) + t.Errorf("expected alias emp_salary, got %s", batch.Schema.Fields()[1].Name) } }) t.Run("projection_expression_math", func(t *testing.T) { @@ -284,7 +284,7 @@ func TestProjectExec(t *testing.T) { proj, err := project.NewProjectExec(src, exprs) if err != nil { - t.Fatalf("error: %v", err) + t.Errorf("error: %v", err) } batch, _ := proj.Next(50) @@ -294,13 +294,13 @@ func TestProjectExec(t *testing.T) { sal := origin[4].(*array.Float64) // check: for a non-null salary (row 0 = 50000) if adjCol.Len() != sal.Len() { - t.Fatalf("expected adjusted salary length %d, got %d", sal.Len(), adjCol.Len()) + t.Errorf("expected adjusted salary length %d, got %d", sal.Len(), adjCol.Len()) } for i := 0; i < adjCol.Len(); i++ { if !sal.IsNull(i) { expected := sal.Value(i) * 1.10 if adjCol.Value(i) != expected { - t.Fatalf("row %d: expected adjusted salary %f, got %f", i, expected, adjCol.Value(i)) + t.Errorf("row %d: expected adjusted salary %f, got %f", i, expected, adjCol.Value(i)) } } } @@ -310,7 +310,7 @@ func TestProjectExec(t *testing.T) { src, err := NewIntegrationSource1(mem) if err != nil { - t.Fatalf("failed to create integration source: %v", err) + t.Errorf("failed to create integration source: %v", err) } exprs := Expr.NewExpressions( @@ -322,15 +322,15 @@ func TestProjectExec(t *testing.T) { proj, err := project.NewProjectExec(src, exprs) if err != nil { - t.Fatalf("unexpected project exec error: %v", err) + t.Errorf("unexpected project exec error: %v", err) } batch, err := proj.Next(100) // pull all rows at once if err != nil { - 
t.Fatalf("unexpected error on Next: %v", err) + t.Errorf("unexpected error on Next: %v", err) } if batch == nil { - t.Fatalf("expected a batch but got nil") + t.Errorf("expected a batch but got nil") } // ---- get projected column (index 0) ---- @@ -341,7 +341,7 @@ func TestProjectExec(t *testing.T) { firstNameCol := originCols[1].(*array.String) // index 1 is first_name if upperCol.Len() != firstNameCol.Len() { - t.Fatalf("length mismatch: expected %d got %d", + t.Errorf("length mismatch: expected %d got %d", firstNameCol.Len(), upperCol.Len()) } @@ -349,7 +349,7 @@ func TestProjectExec(t *testing.T) { for i := 0; i < upperCol.Len(); i++ { if firstNameCol.IsNull(i) { if !upperCol.IsNull(i) { - t.Fatalf("row %d: expected NULL but got value", i) + t.Errorf("row %d: expected NULL but got value", i) } continue } @@ -358,7 +358,7 @@ func TestProjectExec(t *testing.T) { got := upperCol.Value(i) if expected != got { - t.Fatalf("row %d: expected %q, got %q", i, expected, got) + t.Errorf("row %d: expected %q, got %q", i, expected, got) } } }) @@ -378,7 +378,7 @@ func TestFilterExec(t *testing.T) { names, cols := generateIntegrationDataset1(mem) src, err := project.NewInMemoryProjectExecFromArrays(names, cols) if err != nil { - t.Fatalf("failed to create in-memory source: %v", err) + t.Errorf("failed to create in-memory source: %v", err) } pred := Expr.NewBinaryExpr( Expr.NewColumnResolve("age"), @@ -388,22 +388,22 @@ func TestFilterExec(t *testing.T) { filt, err := filter.NewFilterExec(src, pred) if err != nil { - t.Fatalf("filter init failed: %v", err) + t.Errorf("filter init failed: %v", err) } batch, err := filt.Next(1000) if err != nil && !errors.Is(err, io.EOF) { - t.Fatalf("unexpected error: %v", err) + t.Errorf("unexpected error: %v", err) } if batch == nil { - t.Fatalf("expected rows, got nil batch") + t.Errorf("expected rows, got nil batch") } ageCol, _ := batch.ColumnByName("age") for i := 0; i < ageCol.Len(); i++ { ageValue := ageCol.(*array.Int32).Value(i) if 
ageValue <= 30 { - t.Fatalf("expected age > 30, got %d", ageValue) + t.Errorf("expected age > 30, got %d", ageValue) } } @@ -414,7 +414,7 @@ func TestFilterExec(t *testing.T) { names, cols := generateIntegrationDataset1(mem) src, err := project.NewInMemoryProjectExecFromArrays(names, cols) if err != nil { - t.Fatalf("failed to create in-memory source: %v", err) + t.Errorf("failed to create in-memory source: %v", err) } pred := Expr.NewBinaryExpr( Expr.NewBinaryExpr( @@ -433,15 +433,15 @@ func TestFilterExec(t *testing.T) { filt, err := filter.NewFilterExec(src, pred) if err != nil { - t.Fatalf("filter init failed: %v", err) + t.Errorf("filter init failed: %v", err) } batch, err := filt.Next(1000) if err != nil && !errors.Is(err, io.EOF) { - t.Fatalf("unexpected error: %v", err) + t.Errorf("unexpected error: %v", err) } if batch == nil { - t.Fatalf("expected non-nil batch") + t.Errorf("expected non-nil batch") } // validate @@ -451,10 +451,10 @@ func TestFilterExec(t *testing.T) { salColumn, _ := salCol.(*array.Float64) for i := 0; i < int(batch.RowCount); i++ { if depColumn.Value(i) != "Engineering" { - t.Fatalf("expected department 'Engineering', got %s", depColumn.Value(i)) + t.Errorf("expected department 'Engineering', got %s", depColumn.Value(i)) } if salColumn.Value(i) <= 70000 { - t.Fatalf("expected salary > 70000, got %f", salColumn.Value(i)) + t.Errorf("expected salary > 70000, got %f", salColumn.Value(i)) } } }) @@ -464,24 +464,24 @@ func TestFilterExec(t *testing.T) { names, cols := generateIntegrationDataset1(mem) src, err := project.NewInMemoryProjectExecFromArrays(names, cols) if err != nil { - t.Fatalf("failed to create in-memory source: %v", err) + t.Errorf("failed to create in-memory source: %v", err) } // We're filtering region IS NULL pred := Expr.NewNullCheckExpr(Expr.NewColumnResolve("region")) filt, err := filter.NewFilterExec(src, pred) if err != nil { - t.Fatalf("filter init failed: %v", err) + t.Errorf("filter init failed: %v", err) } batch, 
err := filt.Next(1000) if err != nil && !errors.Is(err, io.EOF) { - t.Fatalf("unexpected error: %v", err) + t.Errorf("unexpected error: %v", err) } if batch == nil { // possible: no NULLS - t.Fatalf("expected atleast one null") + t.Errorf("expected atleast one null") return } t.Logf("batch: \t%v\n", batch.PrettyPrint()) @@ -490,7 +490,7 @@ func TestFilterExec(t *testing.T) { regionArr := regionCol.(*array.String) for i := 0; i < int(batch.RowCount); i++ { if regionArr.IsNull(i) { - t.Fatalf("expected NULL region but got value=%s", regionArr.Value(i)) + t.Errorf("expected NULL region but got value=%s", regionArr.Value(i)) } } }) @@ -514,12 +514,12 @@ func TestSortTest(t *testing.T) { sortExec, err := aggr.NewSortExec(src, sortKeys) if err != nil { - t.Fatalf("failed to create sort exec: %v", err) + t.Errorf("failed to create sort exec: %v", err) } batch, err := sortExec.Next(1000) if err != nil && !errors.Is(err, io.EOF) { - t.Fatalf("unexpected error: %v", err) + t.Errorf("unexpected error: %v", err) } salaryArr := batch.Columns[4].(*array.Float64) @@ -529,7 +529,7 @@ func TestSortTest(t *testing.T) { continue } if salaryArr.Value(i) < salaryArr.Value(i-1) { - t.Fatalf("salary not sorted ASC at row %d: %f < %f", + t.Errorf("salary not sorted ASC at row %d: %f < %f", i, salaryArr.Value(i), salaryArr.Value(i-1)) } } @@ -547,12 +547,12 @@ func TestSortTest(t *testing.T) { sortExec, err := aggr.NewSortExec(src, sortKeys) if err != nil { - t.Fatalf("failed to create sort exec: %v", err) + t.Errorf("failed to create sort exec: %v", err) } batch, err := sortExec.Next(1000) if err != nil && !errors.Is(err, io.EOF) { - t.Fatalf("unexpected error: %v", err) + t.Errorf("unexpected error: %v", err) } lastArr := batch.Columns[2].(*array.String) @@ -564,7 +564,7 @@ func TestSortTest(t *testing.T) { // descending → current <= previous if lastArr.Value(i) > lastArr.Value(i-1) { - t.Fatalf("last_name not sorted DESC at %d: %s > %s", + t.Errorf("last_name not sorted DESC at %d: %s > 
%s", i, lastArr.Value(i), lastArr.Value(i-1)) } } @@ -582,12 +582,12 @@ func TestSortTest(t *testing.T) { sortExec, err := aggr.NewSortExec(src, sortKeys) if err != nil { - t.Fatalf("failed to create sort exec: %v", err) + t.Errorf("failed to create sort exec: %v", err) } batch, err := sortExec.Next(1000) if err != nil && !errors.Is(err, io.EOF) { - t.Fatalf("unexpected error: %v", err) + t.Errorf("unexpected error: %v", err) } deptArr := batch.Columns[5].(*array.String) @@ -603,7 +603,7 @@ func TestSortTest(t *testing.T) { // department ascending grouping if currDept < prevDept { - t.Fatalf("department not sorted ASC at %d: %s < %s", + t.Errorf("department not sorted ASC at %d: %s < %s", i, currDept, prevDept) } @@ -611,7 +611,7 @@ func TestSortTest(t *testing.T) { if currDept == prevDept { if !salaryArr.IsNull(i) && !salaryArr.IsNull(i-1) { if salaryArr.Value(i) > salaryArr.Value(i-1) { - t.Fatalf("salary not DESC within department '%s' at row %d", + t.Errorf("salary not DESC within department '%s' at row %d", currDept, i) } } @@ -666,12 +666,12 @@ func TestIntegrationAggregations(t *testing.T) { aggr.NewAggregateFunctions(aggr.Min, salCol), aggr.NewAggregateFunctions(aggr.Max, salCol)}) if err != nil { - t.Fatalf("aggregation init failed: %v", err) + t.Errorf("aggregation init failed: %v", err) } batch, err := agg.Next(100) if err != nil { - t.Fatalf("aggregation next failed: %v", err) + t.Errorf("aggregation next failed: %v", err) } // Extract columns from result @@ -681,16 +681,16 @@ func TestIntegrationAggregations(t *testing.T) { maxArr := batch.Columns[3].(*array.Float64) if sumArr.Value(0) != sum { - t.Fatalf("SUM mismatch: expected %f, got %f", sum, sumArr.Value(0)) + t.Errorf("SUM mismatch: expected %f, got %f", sum, sumArr.Value(0)) } if avgArr.Value(0) != avg { - t.Fatalf("AVG mismatch: expected %f, got %f", avg, avgArr.Value(0)) + t.Errorf("AVG mismatch: expected %f, got %f", avg, avgArr.Value(0)) } if minArr.Value(0) != min { - t.Fatalf("MIN 
mismatch: expected %f, got %f", min, minArr.Value(0)) + t.Errorf("MIN mismatch: expected %f, got %f", min, minArr.Value(0)) } if maxArr.Value(0) != max { - t.Fatalf("MAX mismatch: expected %f, got %f", max, maxArr.Value(0)) + t.Errorf("MAX mismatch: expected %f, got %f", max, maxArr.Value(0)) } }) @@ -719,14 +719,14 @@ func TestIntegrationAggregations(t *testing.T) { }, ) if err != nil { - t.Fatalf("agg init failed: %v", err) + t.Errorf("agg init failed: %v", err) } batch, _ := agg.Next(100) sumArr := batch.Columns[0].(*array.Float64) // SUM(int32) -> int64 if sumArr.Value(0) != float64(sum) { - t.Fatalf("SUM(age) mismatch: expected %v, got %v", sum, sumArr.Value(0)) + t.Errorf("SUM(age) mismatch: expected %v, got %v", sum, sumArr.Value(0)) } }) @@ -761,7 +761,7 @@ func TestIntegrationAggregations(t *testing.T) { aggr.NewAggregateFunctions(aggr.Max, Expr.NewColumnResolve("age")), }) if err != nil { - t.Fatalf("agg init failed: %v", err) + t.Errorf("agg init failed: %v", err) } batch, _ := agg.Next(100) @@ -770,10 +770,10 @@ func TestIntegrationAggregations(t *testing.T) { maxArr := batch.Columns[1].(*array.Float64) if minArr.Value(0) != float64(min) { - t.Fatalf("MIN(age) mismatch: expected %v, got %v", min, minArr.Value(0)) + t.Errorf("MIN(age) mismatch: expected %v, got %v", min, minArr.Value(0)) } if maxArr.Value(0) != float64(max) { - t.Fatalf("MAX(age) mismatch: expected %v, got %v", max, maxArr.Value(0)) + t.Errorf("MAX(age) mismatch: expected %v, got %v", max, maxArr.Value(0)) } }) } @@ -803,12 +803,12 @@ func TestGroupByExec(t *testing.T) { gb, err := aggr.NewGroupByExec(src, aggs, groupByExpr) if err != nil { - t.Fatalf("gb init failed: %v", err) + t.Errorf("gb init failed: %v", err) } batch, err := gb.Next(1024) if err != nil { - t.Fatalf("group by Next failed: %v", err) + t.Errorf("group by Next failed: %v", err) } deptCol := batch.Columns[0].(*array.String) @@ -835,7 +835,7 @@ func TestGroupByExec(t *testing.T) { want := expected[key] if got != want { - 
t.Fatalf("group %s: expected %d, got %d", key, want, got) + t.Errorf("group %s: expected %d, got %d", key, want, got) } } }) @@ -854,12 +854,12 @@ func TestGroupByExec(t *testing.T) { gb, err := aggr.NewGroupByExec(src, aggs, groupByExpr) if err != nil { - t.Fatalf("init failed: %v", err) + t.Errorf("init failed: %v", err) } batch, err := gb.Next(1024) if err != nil { - t.Fatalf("Next failed: %v", err) + t.Errorf("Next failed: %v", err) } deptCol := batch.Columns[0].(*array.String) @@ -903,7 +903,7 @@ func TestGroupByExec(t *testing.T) { want := expected[key] if got != want { - t.Fatalf("(%s,%s): expected sum=%f, got %f", d, r, want, got) + t.Errorf("(%s,%s): expected sum=%f, got %f", d, r, want, got) } } }) @@ -924,7 +924,7 @@ func TestGroupByExec(t *testing.T) { batch, err := gb.Next(1024) if err != nil { - t.Fatalf("Next failed: %v", err) + t.Errorf("Next failed: %v", err) } regionCol := batch.Columns[0].(*array.String) @@ -951,7 +951,7 @@ func TestGroupByExec(t *testing.T) { want := expected[k] if got != want { - t.Fatalf("region=%s expected %d got %d", k, want, got) + t.Errorf("region=%s expected %d got %d", k, want, got) } } }) @@ -984,7 +984,7 @@ func TestHavingExec(t *testing.T) { gb := buildDeptAvg() having := Expr.NewBinaryExpr( - Expr.NewColumnResolve("avg_Column(salary)"), + Expr.NewColumnResolve("salary"), Expr.GreaterThan, Expr.NewLiteralResolve(arrow.PrimitiveTypes.Float64, 75000.0), ) @@ -992,7 +992,7 @@ func TestHavingExec(t *testing.T) { hv, _ := aggr.NewHavingExec(gb, having) batch, err := hv.Next(500) if err != nil { - t.Fatalf("having next failed: %v", err) + t.Errorf("having next failed: %v", err) } t.Logf("batch:\t%v\n", batch.PrettyPrint()) @@ -1001,7 +1001,7 @@ func TestHavingExec(t *testing.T) { for i := 0; i < int(batch.RowCount); i++ { if avgCol.Value(i) <= 75000 { - t.Fatalf("expected avg > 75k, got %f for dept %s", + t.Errorf("expected avg > 75k, got %f for dept %s", avgCol.Value(i), deptCol.Value(i)) } } @@ -1012,16 +1012,22 @@ func 
TestHavingExec(t *testing.T) { gb := buildDeptAvg() having := Expr.NewBinaryExpr( - Expr.NewColumnResolve("avg_Column(salary)"), + Expr.NewColumnResolve("salary"), Expr.GreaterThan, Expr.NewLiteralResolve(arrow.PrimitiveTypes.Float64, 999999.0), ) - hv, _ := aggr.NewHavingExec(gb, having) - batch, _ := hv.Next(100) + hv, err := aggr.NewHavingExec(gb, having) + if err != nil { + t.Fatalf("failed to to construct havingExec: %v\n", err) + } + batch, err := hv.Next(100) + if err != nil { + t.Fatalf("failed to to grab recordbatch for havingExec: %v\n", err) + } if batch.RowCount != 0 { - t.Fatalf("expected empty result") + t.Errorf("expected empty result") } }) @@ -1030,7 +1036,7 @@ func TestHavingExec(t *testing.T) { gb := buildDeptAvg() having := Expr.NewBinaryExpr( - Expr.NewColumnResolve("avg_Column(salary)"), + Expr.NewColumnResolve("salary"), Expr.GreaterThan, Expr.NewLiteralResolve(arrow.PrimitiveTypes.Float64, float64(0.0)), ) @@ -1039,7 +1045,7 @@ func TestHavingExec(t *testing.T) { batch, _ := hv.Next(1000) if batch.RowCount == 0 { - t.Fatalf("expected some rows") + t.Errorf("expected some rows") } }) } @@ -1056,7 +1062,7 @@ func TestDistinctExec(t *testing.T) { names, cols := generateIntegrationDataset1(mem) src, err := project.NewInMemoryProjectExecFromArrays(names, cols) if err != nil { - t.Fatalf("failed to create source: %v", err) + t.Errorf("failed to create source: %v", err) } // ------------------------------- @@ -1069,12 +1075,12 @@ func TestDistinctExec(t *testing.T) { de, err := filter.NewDistinctExec(src, expr) if err != nil { - t.Fatalf("distinct init failed: %v", err) + t.Errorf("distinct init failed: %v", err) } batch, err := de.Next(100) if err != nil { - t.Fatalf("distinct next failed: %v", err) + t.Errorf("distinct next failed: %v", err) } //deptArr := batch.Columns[5].(*array.String) @@ -1091,7 +1097,7 @@ func TestDistinctExec(t *testing.T) { } if int(batch.RowCount) != len(expected) { - t.Fatalf("expected %d distinct departments, got %d", + 
t.Errorf("expected %d distinct departments, got %d", len(expected), batch.RowCount) } }) @@ -1109,12 +1115,12 @@ func TestDistinctExec(t *testing.T) { de, err := filter.NewDistinctExec(src2, expr) if err != nil { - t.Fatalf("distinct init failed: %v", err) + t.Errorf("distinct init failed: %v", err) } batch, err := de.Next(100) if err != nil { - t.Fatalf("distinct next failed: %v", err) + t.Errorf("distinct next failed: %v", err) } regionArr := batch.Columns[6].(*array.String) @@ -1130,7 +1136,7 @@ func TestDistinctExec(t *testing.T) { } if int(regionArr.Len()) != len(expected) { - t.Fatalf("expected %d distinct regions, got %d", + t.Errorf("expected %d distinct regions, got %d", len(expected), regionArr.Len()) } }) @@ -1147,16 +1153,16 @@ func TestDistinctExec(t *testing.T) { de, err := filter.NewDistinctExec(src3, expr) if err != nil { - t.Fatalf("distinct init failed: %v", err) + t.Errorf("distinct init failed: %v", err) } batch, err := de.Next(100) if err != nil { - t.Fatalf("distinct next failed: %v", err) + t.Errorf("distinct next failed: %v", err) } if batch.RowCount != 20 { - t.Fatalf("expected 20 distinct id rows, got %d", batch.RowCount) + t.Errorf("expected 20 distinct id rows, got %d", batch.RowCount) } }) } @@ -1178,16 +1184,16 @@ func TestLimitExec(t *testing.T) { lim, err := filter.NewLimitExec(src, 5) if err != nil { - t.Fatalf("limit init failed: %v", err) + t.Errorf("limit init failed: %v", err) } batch, err := lim.Next(100) if err != nil { - t.Fatalf("limit next error: %v", err) + t.Errorf("limit next error: %v", err) } if batch.RowCount != 5 { - t.Fatalf("expected 5 rows, got %d", batch.RowCount) + t.Errorf("expected 5 rows, got %d", batch.RowCount) } // verify first 5 IDs match original dataset @@ -1196,7 +1202,7 @@ func TestLimitExec(t *testing.T) { for i := 0; i < 5; i++ { if idArr.Value(i) != origID.Value(i) { - t.Fatalf("row %d: expected id=%d, got id=%d", + t.Errorf("row %d: expected id=%d, got id=%d", i, origID.Value(i), idArr.Value(i)) } 
} @@ -1210,16 +1216,16 @@ func TestLimitExec(t *testing.T) { lim, err := filter.NewLimitExec(src, 20) if err != nil { - t.Fatalf("limit init failed: %v", err) + t.Errorf("limit init failed: %v", err) } batch, err := lim.Next(100) if err != nil { - t.Fatalf("limit error: %v", err) + t.Errorf("limit error: %v", err) } if batch.RowCount != 20 { - t.Fatalf("expected 20 rows, got %d", batch.RowCount) + t.Errorf("expected 20 rows, got %d", batch.RowCount) } }) @@ -1231,16 +1237,16 @@ func TestLimitExec(t *testing.T) { lim, err := filter.NewLimitExec(src, 50) if err != nil { - t.Fatalf("limit init failed: %v", err) + t.Errorf("limit init failed: %v", err) } batch, err := lim.Next(100) if err != nil { - t.Fatalf("limit next failed: %v", err) + t.Errorf("limit next failed: %v", err) } if batch.RowCount != 20 { - t.Fatalf("expected 20 rows when limit > dataset size, got %d", batch.RowCount) + t.Errorf("expected 20 rows when limit > dataset size, got %d", batch.RowCount) } }) } @@ -1266,12 +1272,12 @@ func TestScalarStringFunctions(t *testing.T) { // Evaluate: UPPER(department) batch, err := src.Next(100) if err != nil { - t.Fatalf("unexpected: %v", err) + t.Errorf("unexpected: %v", err) } arr, err := Expr.EvalScalarFunction(upperExpr, batch) if err != nil { - t.Fatalf("upper eval failed: %v", err) + t.Errorf("upper eval failed: %v", err) } out := arr.(*array.String) @@ -1283,13 +1289,13 @@ func TestScalarStringFunctions(t *testing.T) { for i := 0; i < int(out.Len()); i++ { if deptArr.IsNull(i) { if !out.IsNull(i) { - t.Fatalf("expected null at %d", i) + t.Errorf("expected null at %d", i) } continue } expected := strings.ToUpper(deptArr.Value(i)) if out.Value(i) != expected { - t.Fatalf("UPPER mismatch at row %d: got %s, expected %s", + t.Errorf("UPPER mismatch at row %d: got %s, expected %s", i, out.Value(i), expected) } } @@ -1304,12 +1310,12 @@ func TestScalarStringFunctions(t *testing.T) { // Evaluate: LOWER(department) batch, err := src.Next(100) if err != nil { - 
t.Fatalf("unexpected: %v", err) + t.Errorf("unexpected: %v", err) } arr, err := Expr.EvalScalarFunction(lowerExpr, batch) if err != nil { - t.Fatalf("lower eval failed: %v", err) + t.Errorf("lower eval failed: %v", err) } out := arr.(*array.String) @@ -1320,13 +1326,13 @@ func TestScalarStringFunctions(t *testing.T) { for i := 0; i < int(out.Len()); i++ { if deptArr.IsNull(i) { if !out.IsNull(i) { - t.Fatalf("expected null at %d", i) + t.Errorf("expected null at %d", i) } continue } expected := strings.ToLower(deptArr.Value(i)) if out.Value(i) != expected { - t.Fatalf("LOWER mismatch at row %d: got %s, expected %s", + t.Errorf("LOWER mismatch at row %d: got %s, expected %s", i, out.Value(i), expected) } } @@ -1337,12 +1343,12 @@ func TestScalarStringFunctions(t *testing.T) { fn := Expr.NewScalarFunction(Expr.Abs, Expr.NewColumnResolve("salary")) exec, err := project.NewProjectExec(src, []Expr.Expression{fn}) if err != nil { - t.Fatalf("project init failed: %v", err) + t.Errorf("project init failed: %v", err) } batch, err := exec.Next(50) if err != nil { - t.Fatalf("exec failed: %v", err) + t.Errorf("exec failed: %v", err) } out := batch.Columns[0].(*array.Float64) @@ -1350,7 +1356,7 @@ func TestScalarStringFunctions(t *testing.T) { for i := 0; i < out.Len(); i++ { val := out.Value(i) if val < 0 { - t.Fatalf("abs result should never be negative, got %v", val) + t.Errorf("abs result should never be negative, got %v", val) } } }) @@ -1365,12 +1371,12 @@ func TestScalarStringFunctions(t *testing.T) { fn := Expr.NewScalarFunction(Expr.Round, Expr.NewColumnResolve("salary")) exec, err := project.NewProjectExec(src, []Expr.Expression{fn}) if err != nil { - t.Fatalf("project init failed: %v", err) + t.Errorf("project init failed: %v", err) } batch, err := exec.Next(50) if err != nil { - t.Fatalf("exec failed: %v", err) + t.Errorf("exec failed: %v", err) } out := batch.Columns[0].(*array.Float64) @@ -1381,7 +1387,7 @@ func TestScalarStringFunctions(t *testing.T) { got := 
out.Value(i) if expected != got { - t.Fatalf("round mismatch at %d: expected=%v got=%v", i, expected, got) + t.Errorf("round mismatch at %d: expected=%v got=%v", i, expected, got) } } }) @@ -1406,16 +1412,16 @@ func TestHashJoinExec(t *testing.T) { j, err := join.NewHashJoinExec(src1, src2, clause, join.InnerJoin, nil) if err != nil { - t.Fatalf("inner join init failed: %v", err) + t.Errorf("inner join init failed: %v", err) } batch, err := j.Next(1000) if err != nil { - t.Fatalf("unexpected: %v", err) + t.Errorf("unexpected: %v", err) } if batch.RowCount == 0 { - t.Fatalf("inner join returned zero rows (expected matches)") + t.Errorf("inner join returned zero rows (expected matches)") } }) @@ -1430,16 +1436,16 @@ func TestHashJoinExec(t *testing.T) { j, err := join.NewHashJoinExec(src1, src2, clause, join.LeftJoin, nil) if err != nil { - t.Fatalf("left join init failed: %v", err) + t.Errorf("left join init failed: %v", err) } batch, err := j.Next(1000) if err != nil { - t.Fatalf("unexpected: %v", err) + t.Errorf("unexpected: %v", err) } if batch.RowCount < 20 { - t.Fatalf("left join should preserve all 20 left rows, got %d", batch.RowCount) + t.Errorf("left join should preserve all 20 left rows, got %d", batch.RowCount) } }) @@ -1454,16 +1460,16 @@ func TestHashJoinExec(t *testing.T) { j, err := join.NewHashJoinExec(src1, src2, clause, join.RightJoin, nil) if err != nil { - t.Fatalf("right join init failed: %v", err) + t.Errorf("right join init failed: %v", err) } batch, err := j.Next(1000) if err != nil { - t.Fatalf("unexpected: %v", err) + t.Errorf("unexpected: %v", err) } if batch.RowCount < 20 { - t.Fatalf("right join should preserve all 20 right rows, got %d", batch.RowCount) + t.Errorf("right join should preserve all 20 right rows, got %d", batch.RowCount) } }) @@ -1479,16 +1485,16 @@ func TestHashJoinExec(t *testing.T) { j, err := join.NewHashJoinExec(src1, src2, clause, join.InnerJoin, nil) if err != nil { - t.Fatalf("inner join init failed: %v", err) + 
t.Errorf("inner join init failed: %v", err) } batch, err := j.Next(1000) if err != nil { - t.Fatalf("unexpected: %v", err) + t.Errorf("unexpected: %v", err) } if batch.RowCount != 0 { - t.Fatalf("expected zero matches, got %d", batch.RowCount) + t.Errorf("expected zero matches, got %d", batch.RowCount) } }) @@ -1509,16 +1515,16 @@ func TestHashJoinExec(t *testing.T) { j, err := join.NewHashJoinExec(src1, src2, clause, join.InnerJoin, nil) if err != nil { - t.Fatalf("multi-col join init failed: %v", err) + t.Errorf("multi-col join init failed: %v", err) } batch, err := j.Next(1000) if err != nil { - t.Fatalf("unexpected: %v", err) + t.Errorf("unexpected: %v", err) } if batch.RowCount == 0 { - t.Fatalf("multi-column join should match some rows") + t.Errorf("multi-column join should match some rows") } }) @@ -1533,7 +1539,7 @@ func TestHashJoinExec(t *testing.T) { j, err := join.NewHashJoinExec(src1, src2, clause, join.InnerJoin, nil) if err != nil { - t.Fatalf("join init failed: %v", err) + t.Errorf("join init failed: %v", err) } schema := j.Schema() @@ -1552,7 +1558,7 @@ func TestHashJoinExec(t *testing.T) { } if !foundLeft || !foundRight { - t.Fatalf("schema prefixing failed: left_department=%v right_department=%v", foundLeft, foundRight) + t.Errorf("schema prefixing failed: left_department=%v right_department=%v", foundLeft, foundRight) } }) } diff --git a/src/Backend/opti-sql-go/physical-optimizer/optimize.go b/src/Backend/opti-sql-go/physical-optimizer/optimize.go deleted file mode 100644 index 5d6461f..0000000 --- a/src/Backend/opti-sql-go/physical-optimizer/optimize.go +++ /dev/null @@ -1,3 +0,0 @@ -package physicaloptimizer - -// optimize the parsed plan diff --git a/src/Backend/opti-sql-go/physical-optimizer/optimize_test.go b/src/Backend/opti-sql-go/physical-optimizer/optimize_test.go deleted file mode 100644 index 00a5dbb..0000000 --- a/src/Backend/opti-sql-go/physical-optimizer/optimize_test.go +++ /dev/null @@ -1,7 +0,0 @@ -package physicaloptimizer - 
-import "testing" - -func TestOptimize(t *testing.T) { - // Simple passing test -} diff --git a/src/Backend/opti-sql-go/physical-optimizer/parse.go b/src/Backend/opti-sql-go/physical-optimizer/parse.go deleted file mode 100644 index b160f02..0000000 --- a/src/Backend/opti-sql-go/physical-optimizer/parse.go +++ /dev/null @@ -1,3 +0,0 @@ -package physicaloptimizer - -// parse substrait into a format we can work with and optimize diff --git a/src/Backend/opti-sql-go/physical-optimizer/parse_test.go b/src/Backend/opti-sql-go/physical-optimizer/parse_test.go deleted file mode 100644 index e7f3646..0000000 --- a/src/Backend/opti-sql-go/physical-optimizer/parse_test.go +++ /dev/null @@ -1,7 +0,0 @@ -package physicaloptimizer - -import "testing" - -func TestParse(t *testing.T) { - // Simple passing test -} diff --git a/src/Backend/opti-sql-go/substrait/GC.go b/src/Backend/opti-sql-go/substrait/GC.go new file mode 100644 index 0000000..a0cb0fe --- /dev/null +++ b/src/Backend/opti-sql-go/substrait/GC.go @@ -0,0 +1,93 @@ +package substrait + +import ( + "context" + "fmt" + "opti-sql-go/config" + "time" + + "github.com/minio/minio-go" + "github.com/redis/go-redis/v9" + "go.uber.org/zap" +) + +// garbage collection for removing files from s3 storage after expiration +var dontTouchTestFiles = []string{"country_full.csv", "userdata.parquet", "example.txt", "random_test"} + +const ignoreFolder = "result-file-cache" +const loggerPrefix = "Garbage-Collection" +const waitTime = time.Second * 5 + +func garbageCollection() { + logger := config.GetLogger() + logger.Info(fmt.Sprintf("[%v]starting garbage collection, won't touch these files %v", loggerPrefix, dontTouchTestFiles)) + config := config.GetConfig() + redisInstance := redis.NewClient(&redis.Options{ + Addr: config.Server.RedisAddr + ":6379", + Password: "", // no password + DB: 0, // use default DB + Protocol: 2, + }) + secretes := config.Secretes + accessKey := secretes.AccessKey + secretKey := secretes.SecretKey + endpoint := 
secretes.EndpointURL + bucket := secretes.BucketName + useSSL := true + + client, err := minio.New(endpoint, accessKey, secretKey, useSSL) + if err != nil { + logger.Fatal("failed to construct s3 client to delete old files", zap.String("error message", fmt.Sprintf("%v", err))) + } + var failedAttempts = 0 + for { + start: + if failedAttempts > 5 { + logger.Warn("removing files has failed over 5 times, check redis and s3 for issues !!!") + } + fmt.Printf("waiting %v minutes before check for files to clear from s3", waitTime.Minutes()) + time.Sleep(waitTime) + start := time.Now() + entries, err := redisInstance.LRange(context.TODO(), ignoreFolder, 0, -1).Result() + if err != nil { + logger.Error(fmt.Sprintf("failed to read in files from %v", ignoreFolder), zap.Int("fail counter", failedAttempts)) + failedAttempts++ + goto start // try again + } + // read all the files in s3 + doneChan := make(chan struct{}) + readCount := 0 + var nonValidFiles []string + validMap := buildMap(dontTouchTestFiles, entries) + for fileName := range client.ListObjects(bucket, "", true, doneChan) { + if !validMap[fileName.Key] { + nonValidFiles = append(nonValidFiles, fileName.Key) + } + readCount++ + } + var removedFiles = 0 + for _, invalidFile := range nonValidFiles { + err := client.RemoveObject(bucket, invalidFile) + if err != nil { + logger.Warn(fmt.Sprintf("error removing %v from s3: %v", invalidFile, err)) + // log and move on + } else { + removedFiles++ + } + } + failedAttempts = 0 // reset failed attempts back to zero + logger.Info("Garbage Collection metrics", zap.Any("to-keep map", validMap), zap.Int("total-files count", readCount), zap.Int("removed-files count", removedFiles), zap.Any("time-taken", time.Since(start))) + + } + +} +func buildMap(source1 []string, source2 []string) map[string]bool { + result := make(map[string]bool) + for _, k := range source1 { + result[k] = true + } + for _, k := range source2 { + result[k] = true + } + return result +} diff --git 
a/src/Backend/opti-sql-go/substrait/expr.md b/src/Backend/opti-sql-go/substrait/expr.md new file mode 100644 index 0000000..35a4156 --- /dev/null +++ b/src/Backend/opti-sql-go/substrait/expr.md @@ -0,0 +1,188 @@ +--- +## Expressions + +Expressions are encoded as **tagged objects** using `expr_type`. +They are evaluated row-wise against a record batch and return an Arrow array. + +When applicable, expressions **may include an explicit Arrow type** to avoid inference ambiguity. +--- + +## `Valid Literal Types` + +```bash +"int" +"string" +"boolean" +"float64" +``` + +## `Valid Binary Operators` + +```bash +"Addition" +"Subtraction" +"Multiplication" +"Division" +# comparison +"Equal" +"NotEqual" +"LessThan" +"LessThanOrEqual" +"GreaterThan" +"GreaterThanOrEqual" +# logical +"And" +"Or" +``` + +## `Valid Scalar functions` + +```bash +"Upper" +"Lower" +"Abs" +"Round" +``` + +## `Valid Aggregations functions` + +#### note they are lower case + +```bash +"sum" +"count" +"avg" +"min" +"max" +``` + +## `ColumnResolve` + +Resolves a column from the input batch. + +```bash +{ + "expr_type": "ColumnResolve", + "name": "a" +} +``` + +--- + +## `LiteralResolve` + +Represents a constant literal value. + +```bash +{ + "expr_type": "LiteralResolve", + "value": 10, + "lit_type": "int" +} +``` + +### Notes + +- `lit_type` is **optional** +- When provided, it must be a valid Arrow primitive type +- If omitted, the engine may infer the type + +--- + +## `BinaryExpr` + +Applies a binary operator to two expressions. + +```bash +{ + "expr_type": "BinaryExpr", + "op": "GreaterThan", # or any (valid) binary operator + "left": { + "expr_type": "ColumnResolve", + "name": "a" + }, + "right": { + "expr_type": "LiteralResolve", + "value": 10, + "lit_type": "int" + } +} +``` + +- Comparison and logical operators must return a boolean array +- Left and right expressions must resolve to compatible Arrow types + +--- + +## `ScalarFunction` + +Applies a scalar function element-wise. 
+ +```bash +{ + "expr_type": "ScalarFunction", + "func": "Upper", # or any (Valid) scalar function + "expr": + { + "expr_type": "ColumnResolve", + "name": "name" + } + +} +``` + +--- + +## `Alias` + +Attaches a name to an expression. + +```bash +{ + "expr_type": "Alias", + "expr": { + "expr_type": "ColumnResolve", + "name": "a" + }, + "name": "alias_a" +} +``` + +- Alias affects **naming only** +- Evaluation result is unchanged + +--- + +## `CastExpr` + +Casts the result of an expression to a specific Arrow type. + +```bash +{ + "expr_type": "CastExpr", + "expr": { + "expr_type": "ColumnResolve", + "name": "a" + }, + "to_type": "float64" +} +``` + +--- + +## Expression Type Enum + +`expr_type` is a **closed enum**. + +```bash +ColumnResolve +LiteralResolve +BinaryExpr +ScalarFunction +Alias +CastExpr +``` + +Each expression object **must** contain exactly one `expr_type`. + +--- diff --git a/src/Backend/opti-sql-go/substrait/format.md b/src/Backend/opti-sql-go/substrait/format.md new file mode 100644 index 0000000..a76b7ed --- /dev/null +++ b/src/Backend/opti-sql-go/substrait/format.md @@ -0,0 +1,506 @@ +# Custom intermediate in memory representation of sql logical/physical plans + +### why? + +_The primary reason for this layer is flexibility. By decoupling intermediate data representation from Substrait plans, we can accept multiple data formats. As long as we interpret them into this IR, the physical operators work unchanged_ + +## Source operator + +```bash +{ + "Operator": "Source", + "Source": { + "file-name": "link-to-s3", + "local": false + } +} +# file ext must end in .csv or .parquet +# local? 
download to local machine or keep streaming from s3 bucket +``` + +--- + +## Project operator + +**sql** : `select a , b , c` + +```bash +{ + "Operator": "Project", + "Project": { + "input": {operator}, + "expressions": [{Expression},{Expression},{Expression}] + } +} +``` + +#### example + +```bash +{ + "Operator": "Project", + "Project": { + "input": { + "Operator": "Source", + "Source": { + "file-name": "country-full.csv", + "local": false + } + }, + "expressions": [ + { "Expression": "" }, + { "Expression": "" }, + { "Expression": "" } + ] + } +} +``` + +--- + +## Filter Operator + +**sql**: `select a,b from source where a > 10` + +```bash +{ + "Operator": "Filter", + "Filter": { + "input": {operator}, + "expression": {Expression} + } +} +``` + +**Example** + +```bash +{ + "Operator": "Filter", + "Filter": { + "input": { + "Operator": "Source", + "Source": { + "source-node": { + "file-name": "s3://bucket/data.csv", + "local": false + } + } + }, + "expression": { + "expr_type": "LiteralResolve", + "value": 10, + "lit_type": "int" + } + } +} +``` + +--- + +## Distinct Operator + +**sql**: `select distinct a, b from source` + +```bash +{ + "Operator": "Distinct", + "Distinct": { + "input": {operator}, + "expressions": [{Expression},{Expression},{Expression}] + } +} +``` + +- Removes duplicate rows based on the specified columns +- Output includes only the listed columns + +--- + +## Limit Operator + +**sql**: `select a,b from source limit 10` + +```bash +{ + "Operator": "Limit", + "Limit": { + "input": {operator}, + "limit": 10 + } +} +``` + +#### max value for limit is 2^16-1 (max uint16) + +--- + +## Sort Operator + +**sql**: `select a,b from source order by a desc, b asc` + +```bash +{ + "Operator": "Sort", + "Sort": { + "input": {operator}, + "by": [ + { + "expr": {Expression}, + "asc": boolean + }, + { + "expr": {Expression}, + "asc": boolean + } + ] + } +} +``` + +- `order` defaults to `ASC` if omitted + +--- + +## Single Column Aggregation Operator + 
+**sql**: `select sum(a) from source` + +```bash +{ + "Operator": "Aggregate", + "Aggregate": { + "input": {operator}, + "aggrs": [ + { + "function": "sum", + "expr": {Expression}, + } + ] + } +} +``` + +- Operates on exactly one column +- `alias` is **optional** +- Output contains a single row + +--- + +## Having Operator + +**sql**: `select sum(a) from source having sum(a) > 10` + +```bash +{ + "Operator": "Having", + "Having": { + "input": {operator}, + "expression": {Expression} + } +} +``` + +- Semantics identical to `Filter` +- Applied after aggregation +- Expression must resolve to a boolean mask + +--- + +## Join Operator + +**sql**: +`select * from a join b on a.id = b.id` + +```bash +{ + "Operator": "Join", + "Join": { + "left": {operator}, + "right": {operator}, + "join_type": "Inner", + "on": [ + { + "left": { "expr_type": "ColumnResolve", "name": "a.id" }, + "right": { "expr_type": "ColumnResolve", "name": "b.id" } + }, + { + "left": { "expr_type": "ColumnResolve", "name": "a.age" }, + "right": { "expr_type": "ColumnResolve", "name": "b.distance" } + } + ] + } +} +``` + +### Supported Join Types + +- `Inner` + +### Notes + +- `on` is an array to support multi-column joins +- Join condition expressions must be **equality comparisons** + +--- + +## Group By Operator + +**sql**: +`select b, sum(a) from source group by b` + +```bash +{ + "Operator": "GroupBy", + "GroupBy": { + "input": {operator}, + "group_by": [ + { "expr_type": "ColumnResolve", "name": "b" } + ], + "aggrs": [ + { + "function": "Sum", + "expr": {Expression}, + } + ] + } +} +``` + +### Notes + +- `group_by` defines the grouping keys +- Each aggregate operates on **exactly one column** +- `alias` on aggregates is optional +- One output row is produced per group + +--- + +## Example 1 — Source → Filter + +**sql**: `select * from source where a > 10` + +```bash +{ + "Emit": { + "Operator": "Filter", + "Filter": { + "input": { + "Operator": "Source", + "Source": { + "source-node": { + 
"file-name": "s3://bucket/data.csv", + "local": false + } + } + }, + "expression": { + "expr_type": "BinaryExpr", + "op": "GreaterThan", + "left": { "expr_type": "ColumnResolve", "name": "a" }, + "right": { + "expr_type": "LiteralResolve", + "value": 10, + "lit_type": "int" + } + } + } + + } +} +``` + +--- + +## Example 2 — Source → Project → Sort + +**sql**: `select a, b from source order by a` + +```bash +{ + "Emit": { + "Operator": "Sort", + "Sort": { + "input": { + "Operator": "Project", + "Project": { + "input": { + "Operator": "Source", + "Source": { + "source-node": { + "file-name": "s3://bucket/data.csv", + "local": false + } + } + }, + "expressions": [ + { "Expression": "a" }, + { "Expression": "b" } + ] + } + }, + "by": [ + { + "Expr": { "expr_type": "ColumnResolve", "name": "a" }, + "asc": true + } + ] + } + } +} +``` + +--- + +## Example 3 — Source → Group By → Aggregate + +**sql**: `select b, count(a) from source group by b` + +```bash +{ + "Emit": { + "Operator": "GroupBy", + "GroupBy": { + "input": { + "Operator": "Source", + "Source": { + "source-node": { + "file-name": "s3://bucket/data.csv", + "local": false + } + } + }, + "group_by": [ + { "expr_type": "ColumnResolve", "name": "b" } + ], + "aggregates": [ + { + "function": "Count", + "column": "a", + "alias": "count_a" + } + ] + } + } +} +``` + +--- + +## Example 4 — Source → Distinct → Limit + +**sql**: `select distinct a from source limit 5` + +```bash +{ + "Emit": { + "Operator": "Limit", + "Limit": { + "input": { + "Operator": "Distinct", + "Distinct": { + "input": { + "Operator": "Source", + "Source": { + "source-node": { + "file-name": "s3://bucket/data.csv", + "local": false + } + } + }, + "expressions": [ + { "Expression": "a" } + ] + } + }, + "limit": 5 + } + } +} +``` + +--- + +## Example 5 — Join → Filter → Sort → Limit + +**sql**: + +```sql +select u.name, o.amount +from users u +join orders o on u.id = o.user_id +where o.amount > 50 +order by o.amount desc +limit 10 +``` + +```bash 
+{ + "Emit": { + "Operator": "Limit", + "Limit": { + "input": { + "Operator": "Sort", + "Sort": { + "input": { + "Operator": "Filter", + "Filter": { + "input": { + "Operator": "Join", + "Join": { + "left": { + "Operator": "Source", + "Source": { + "source-node": { + "file-name": "s3://bucket/users.csv", + "local": false + } + } + }, + "right": { + "Operator": "Source", + "Source": { + "source-node": { + "file-name": "s3://bucket/orders.csv", + "local": false + } + } + }, + "join_type": "Inner", + "on": [ + { + "left": { "expr_type": "ColumnResolve", "name": "u.id" }, + "right": { "expr_type": "ColumnResolve", "name": "o.user_id" } + } + ] + } + }, + "expression": { + "expr_type": "BinaryExpr", + "op": "GreaterThan", + "left": { "expr_type": "ColumnResolve", "name": "o.amount" }, + "right": { + "expr_type": "LiteralResolve", + "value": 50, + "lit_type": "int" + } + } + } + }, + "by": [ + { + "Expr": { "expr_type": "ColumnResolve", "name": "o.amount" }, + "asc": false + } + ] + } + }, + "limit": 10 + } + } +} +``` + +--- + +SELECT id, age_years as age from integration_test_data WHERE age > 15 LIMIT 5 diff --git a/src/Backend/opti-sql-go/substrait/operation.pb.go b/src/Backend/opti-sql-go/substrait/operation.pb.go index 00f49f2..09a7b52 100644 --- a/src/Backend/opti-sql-go/substrait/operation.pb.go +++ b/src/Backend/opti-sql-go/substrait/operation.pb.go @@ -1,17 +1,18 @@ // Code generated by protoc-gen-go. DO NOT EDIT. 
// versions: -// protoc-gen-go v1.36.10 +// protoc-gen-go v1.36.11 // protoc v6.32.0 // source: operation.proto package substrait import ( - protoreflect "google.golang.org/protobuf/reflect/protoreflect" - protoimpl "google.golang.org/protobuf/runtime/protoimpl" reflect "reflect" sync "sync" unsafe "unsafe" + + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" ) const ( @@ -84,13 +85,12 @@ func (ReturnTypes) EnumDescriptor() ([]byte, []int) { // The request message containing the operation details. type QueryExecutionRequest struct { - state protoimpl.MessageState `protogen:"open.v1"` - SubstraitLogical []byte `protobuf:"bytes,1,opt,name=substrait_logical,json=substraitLogical,proto3" json:"substrait_logical,omitempty"` //SS logical plan - SqlStatement string `protobuf:"bytes,2,opt,name=sql_statement,json=sqlStatement,proto3" json:"sql_statement,omitempty"` // original sql statement - Id string `protobuf:"bytes,3,opt,name=id,proto3" json:"id,omitempty"` // unique id for this client - Source *SourceType `protobuf:"bytes,4,opt,name=source,proto3" json:"source,omitempty"` // (s3 link| base64 data) - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + LogicalPlan string `protobuf:"bytes,1,opt,name=logical_plan,json=logicalPlan,proto3" json:"logical_plan,omitempty"` // Substrait logical plan: serialized representation of the query execution (contains s3 link to the source data) + SqlStatement string `protobuf:"bytes,2,opt,name=sql_statement,json=sqlStatement,proto3" json:"sql_statement,omitempty"` // original sql statement + Id string `protobuf:"bytes,3,opt,name=id,proto3" json:"id,omitempty"` // unique id for this client + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *QueryExecutionRequest) Reset() { @@ -123,11 +123,11 @@ func (*QueryExecutionRequest) Descriptor() ([]byte, []int) { return 
file_operation_proto_rawDescGZIP(), []int{0} } -func (x *QueryExecutionRequest) GetSubstraitLogical() []byte { +func (x *QueryExecutionRequest) GetLogicalPlan() string { if x != nil { - return x.SubstraitLogical + return x.LogicalPlan } - return nil + return "" } func (x *QueryExecutionRequest) GetSqlStatement() string { @@ -144,13 +144,6 @@ func (x *QueryExecutionRequest) GetId() string { return "" } -func (x *QueryExecutionRequest) GetSource() *SourceType { - if x != nil { - return x.Source - } - return nil -} - // The response message containing the result. type QueryExecutionResponse struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -204,58 +197,6 @@ func (x *QueryExecutionResponse) GetErrorType() *ErrorDetails { return nil } -type SourceType struct { - state protoimpl.MessageState `protogen:"open.v1"` - S3Source string `protobuf:"bytes,1,opt,name=s3_source,json=s3Source,proto3" json:"s3_source,omitempty"` // s3 link to the source data - Mime string `protobuf:"bytes,2,opt,name=mime,proto3" json:"mime,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *SourceType) Reset() { - *x = SourceType{} - mi := &file_operation_proto_msgTypes[2] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *SourceType) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*SourceType) ProtoMessage() {} - -func (x *SourceType) ProtoReflect() protoreflect.Message { - mi := &file_operation_proto_msgTypes[2] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use SourceType.ProtoReflect.Descriptor instead. 
-func (*SourceType) Descriptor() ([]byte, []int) { - return file_operation_proto_rawDescGZIP(), []int{2} -} - -func (x *SourceType) GetS3Source() string { - if x != nil { - return x.S3Source - } - return "" -} - -func (x *SourceType) GetMime() string { - if x != nil { - return x.Mime - } - return "" -} - type ErrorDetails struct { state protoimpl.MessageState `protogen:"open.v1"` ErrorType ReturnTypes `protobuf:"varint,1,opt,name=error_type,json=errorType,proto3,enum=contract.ReturnTypes" json:"error_type,omitempty"` @@ -266,7 +207,7 @@ type ErrorDetails struct { func (x *ErrorDetails) Reset() { *x = ErrorDetails{} - mi := &file_operation_proto_msgTypes[3] + mi := &file_operation_proto_msgTypes[2] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -278,7 +219,7 @@ func (x *ErrorDetails) String() string { func (*ErrorDetails) ProtoMessage() {} func (x *ErrorDetails) ProtoReflect() protoreflect.Message { - mi := &file_operation_proto_msgTypes[3] + mi := &file_operation_proto_msgTypes[2] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -291,7 +232,7 @@ func (x *ErrorDetails) ProtoReflect() protoreflect.Message { // Deprecated: Use ErrorDetails.ProtoReflect.Descriptor instead. 
func (*ErrorDetails) Descriptor() ([]byte, []int) { - return file_operation_proto_rawDescGZIP(), []int{3} + return file_operation_proto_rawDescGZIP(), []int{2} } func (x *ErrorDetails) GetErrorType() ReturnTypes { @@ -312,20 +253,15 @@ var File_operation_proto protoreflect.FileDescriptor const file_operation_proto_rawDesc = "" + "\n" + - "\x0foperation.proto\x12\bcontract\"\xa7\x01\n" + - "\x15QueryExecutionRequest\x12+\n" + - "\x11substrait_logical\x18\x01 \x01(\fR\x10substraitLogical\x12#\n" + + "\x0foperation.proto\x12\bcontract\"o\n" + + "\x15QueryExecutionRequest\x12!\n" + + "\flogical_plan\x18\x01 \x01(\tR\vlogicalPlan\x12#\n" + "\rsql_statement\x18\x02 \x01(\tR\fsqlStatement\x12\x0e\n" + - "\x02id\x18\x03 \x01(\tR\x02id\x12,\n" + - "\x06source\x18\x04 \x01(\v2\x14.contract.SourceTypeR\x06source\"u\n" + + "\x02id\x18\x03 \x01(\tR\x02id\"u\n" + "\x16QueryExecutionResponse\x12$\n" + "\x0es3_result_link\x18\x01 \x01(\tR\fs3ResultLink\x125\n" + "\n" + - "error_type\x18\x02 \x01(\v2\x16.contract.ErrorDetailsR\terrorType\"=\n" + - "\n" + - "SourceType\x12\x1b\n" + - "\ts3_source\x18\x01 \x01(\tR\bs3Source\x12\x12\n" + - "\x04mime\x18\x02 \x01(\tR\x04mime\"^\n" + + "error_type\x18\x02 \x01(\v2\x16.contract.ErrorDetailsR\terrorType\"^\n" + "\fErrorDetails\x124\n" + "\n" + "error_type\x18\x01 \x01(\x0e2\x15.contract.returnTypesR\terrorType\x12\x18\n" + @@ -354,25 +290,23 @@ func file_operation_proto_rawDescGZIP() []byte { } var file_operation_proto_enumTypes = make([]protoimpl.EnumInfo, 1) -var file_operation_proto_msgTypes = make([]protoimpl.MessageInfo, 4) +var file_operation_proto_msgTypes = make([]protoimpl.MessageInfo, 3) var file_operation_proto_goTypes = []any{ (ReturnTypes)(0), // 0: contract.returnTypes (*QueryExecutionRequest)(nil), // 1: contract.QueryExecutionRequest (*QueryExecutionResponse)(nil), // 2: contract.QueryExecutionResponse - (*SourceType)(nil), // 3: contract.SourceType - (*ErrorDetails)(nil), // 4: contract.ErrorDetails + 
(*ErrorDetails)(nil), // 3: contract.ErrorDetails } var file_operation_proto_depIdxs = []int32{ - 3, // 0: contract.QueryExecutionRequest.source:type_name -> contract.SourceType - 4, // 1: contract.QueryExecutionResponse.error_type:type_name -> contract.ErrorDetails - 0, // 2: contract.ErrorDetails.error_type:type_name -> contract.returnTypes - 1, // 3: contract.SSOperation.ExecuteQuery:input_type -> contract.QueryExecutionRequest - 2, // 4: contract.SSOperation.ExecuteQuery:output_type -> contract.QueryExecutionResponse - 4, // [4:5] is the sub-list for method output_type - 3, // [3:4] is the sub-list for method input_type - 3, // [3:3] is the sub-list for extension type_name - 3, // [3:3] is the sub-list for extension extendee - 0, // [0:3] is the sub-list for field type_name + 3, // 0: contract.QueryExecutionResponse.error_type:type_name -> contract.ErrorDetails + 0, // 1: contract.ErrorDetails.error_type:type_name -> contract.returnTypes + 1, // 2: contract.SSOperation.ExecuteQuery:input_type -> contract.QueryExecutionRequest + 2, // 3: contract.SSOperation.ExecuteQuery:output_type -> contract.QueryExecutionResponse + 3, // [3:4] is the sub-list for method output_type + 2, // [2:3] is the sub-list for method input_type + 2, // [2:2] is the sub-list for extension type_name + 2, // [2:2] is the sub-list for extension extendee + 0, // [0:2] is the sub-list for field type_name } func init() { file_operation_proto_init() } @@ -386,7 +320,7 @@ func file_operation_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_operation_proto_rawDesc), len(file_operation_proto_rawDesc)), NumEnums: 1, - NumMessages: 4, + NumMessages: 3, NumExtensions: 0, NumServices: 1, }, diff --git a/src/Backend/opti-sql-go/substrait/operation_grpc.pb.go b/src/Backend/opti-sql-go/substrait/operation_grpc.pb.go index 3b87fab..cbe80d5 100644 --- a/src/Backend/opti-sql-go/substrait/operation_grpc.pb.go +++ 
b/src/Backend/opti-sql-go/substrait/operation_grpc.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: -// - protoc-gen-go-grpc v1.5.1 +// - protoc-gen-go-grpc v1.6.0 // - protoc v6.32.0 // source: operation.proto @@ -68,7 +68,7 @@ type SSOperationServer interface { type UnimplementedSSOperationServer struct{} func (UnimplementedSSOperationServer) ExecuteQuery(context.Context, *QueryExecutionRequest) (*QueryExecutionResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method ExecuteQuery not implemented") + return nil, status.Error(codes.Unimplemented, "method ExecuteQuery not implemented") } func (UnimplementedSSOperationServer) mustEmbedUnimplementedSSOperationServer() {} func (UnimplementedSSOperationServer) testEmbeddedByValue() {} @@ -81,7 +81,7 @@ type UnsafeSSOperationServer interface { } func RegisterSSOperationServer(s grpc.ServiceRegistrar, srv SSOperationServer) { - // If the following call pancis, it indicates UnimplementedSSOperationServer was + // If the following call panics, it indicates UnimplementedSSOperationServer was // embedded by pointer and is nil. This will cause panics if an // unimplemented method is ever invoked, so we test this at initialization // time to prevent it from happening at runtime later due to I/O. 
diff --git a/src/Backend/opti-sql-go/substrait/server.go b/src/Backend/opti-sql-go/substrait/server.go index 5fe5107..159154d 100644 --- a/src/Backend/opti-sql-go/substrait/server.go +++ b/src/Backend/opti-sql-go/substrait/server.go @@ -2,14 +2,19 @@ package substrait import ( "context" + "encoding/base64" "fmt" "log" + "math/rand/v2" "net" "opti-sql-go/config" + "opti-sql-go/operators/project" "os" "os/signal" + "strings" "syscall" + "go.uber.org/zap" "google.golang.org/grpc" ) @@ -26,12 +31,136 @@ func newSubstraitServer(l *net.Listener) *SubstraitServer { } // ExecuteQuery implements the gRPC service method -func (s *SubstraitServer) ExecuteQuery(ctx context.Context, req *QueryExecutionRequest) (*QueryExecutionResponse, error) { - fmt.Printf("Received query request: logical_plan:%v\n sql:%s\n id:%v\n source: %v\n", req.SubstraitLogical, req.SqlStatement, req.Id, req.Source) +func (s *SubstraitServer) ExecuteQuery(ctx context.Context, req *QueryExecutionRequest) (resp *QueryExecutionResponse, err error) { + logger := config.GetLogger() + + // Panic recovery to prevent one failing query from taking down the entire server + defer func() { + if r := recover(); r != nil { + logger.Error("Panic recovered in ExecuteQuery", + zap.Any("panic", r), + zap.String("query_id", req.Id), + zap.String("sql", req.SqlStatement), + zap.Stack("stack_trace"), + ) + resp = &QueryExecutionResponse{ + S3ResultLink: "NAN", + ErrorType: &ErrorDetails{ + ErrorType: ReturnTypes_EXECUTION_ERROR, + Message: fmt.Sprintf("Internal server error: panic recovered: %v", r), + }, + } + err = nil // Return error as part of response, not as gRPC error + } + }() + + logger.Info("Received query execution request", + zap.String("query_id", req.Id), + zap.String("sql", req.SqlStatement), + zap.Int("plan_size_bytes", len(req.LogicalPlan)), + ) + + decodedPlan, err := base64.StdEncoding.DecodeString(req.LogicalPlan) + if err != nil { + logger.Error("Failed to decode base64 plan", + zap.Error(err), + 
zap.String("query_id", req.Id), + ) + return nil, fmt.Errorf("failed to base64 decode logical plan: %w", err) + } + logger.Debug("Plan decoded successfully", zap.Int("decoded_size_bytes", len(decodedPlan))) + logger.Debug("Received query request details", zap.String("logical_plan", string(decodedPlan)), zap.String("sql", req.SqlStatement), zap.String("id", req.Id)) + planM := newPlanMetaData(req.Id) + source := strings.NewReader(string(decodedPlan)) + + logger.Info("Parsing logical plan", zap.String("query_id", req.Id)) + results, err := consumePlan(source, planM) + if err != nil { + logger.Error("Failed to parse logical plan", + zap.Error(err), + zap.String("query_id", req.Id), + ) + return &QueryExecutionResponse{ + S3ResultLink: "NAN", + ErrorType: &ErrorDetails{ + ErrorType: ReturnTypes_PARSE_ERROR, + Message: err.Error(), + }, + }, nil + } + + logger.Info("Executing query plan", zap.String("query_id", req.Id)) + rc, err := results.consumeAll() + if err != nil { + logger.Error("Failed to execute query plan", + zap.Error(err), + zap.String("query_id", req.Id), + ) + return &QueryExecutionResponse{ + S3ResultLink: "NAN", + ErrorType: &ErrorDetails{ + ErrorType: ReturnTypes_EXECUTION_ERROR, + Message: err.Error(), + }, + }, nil + + } + + logger.Info("Converting results to CSV", + zap.String("query_id", req.Id), + zap.Uint64("row_count", rc.RowCount), + ) + csv, err := rc.ToCSV() + if err != nil { + logger.Error("Failed to convert results to CSV", + zap.Error(err), + zap.String("query_id", req.Id), + ) + return &QueryExecutionResponse{ + S3ResultLink: "NAN", + ErrorType: &ErrorDetails{ + ErrorType: ReturnTypes_UPLOAD_ERROR, + Message: err.Error(), + }, + }, nil + + } + // include random number for the sake of avoiding conflicts, should resolve this at the + // logical processing step but for now this works + fName := fmt.Sprintf("%s-%s-%d", strings.ReplaceAll(req.SqlStatement, " ", "-"), req.Id, rand.IntN(1000)) + // ! 
todo: finish debugging + logger.Debug("CSV file produced", zap.String("file_name", fName), zap.String("csv_content", string(csv))) + + logger.Info("Uploading results to S3", + zap.String("query_id", req.Id), + zap.String("file_name", fName), + zap.Int("csv_size_bytes", len(csv)), + ) + if err = project.UploadResults(fName, csv); err != nil { + logger.Error("Failed to upload results to S3", + zap.Error(err), + zap.String("query_id", req.Id), + zap.String("file_name", fName), + ) + return &QueryExecutionResponse{ + S3ResultLink: "NAN", + ErrorType: &ErrorDetails{ + ErrorType: ReturnTypes_UPLOAD_ERROR, + Message: err.Error(), + }, + }, nil + + } + + logger.Info("Query executed successfully", + zap.String("query_id", req.Id), + zap.String("s3_link", fName), + zap.Uint64("result_rows", rc.RowCount), + ) // Placeholder response return &QueryExecutionResponse{ - S3ResultLink: "", + S3ResultLink: fName, ErrorType: &ErrorDetails{ ErrorType: ReturnTypes_SUCCESS, Message: "Query executed successfully", @@ -45,15 +174,16 @@ func Start() chan struct{} { if err != nil { log.Fatalf("Failed to listen on port %d: %v", c.Server.Port, err) } - + logger := config.GetLogger() + logger.Info("Execution server is running", zap.String("host", c.Server.Host), zap.Int("port", c.Server.Port)) grpcServer := grpc.NewServer() ss := newSubstraitServer(&listener) RegisterSSOperationServer(grpcServer, ss) stopChan := make(chan struct{}) - log.Printf("Substrait server listening on port %d", c.Server.Port) go unifiedShutdownHandler(ss, grpcServer, stopChan) + go garbageCollection() go func() { if err := grpcServer.Serve(*ss.listener); err != nil { log.Fatalf("Failed to serve: %v", err) @@ -62,14 +192,15 @@ func Start() chan struct{} { return stopChan } func unifiedShutdownHandler(s *SubstraitServer, grpcServer *grpc.Server, stopChan chan struct{}) { + logger := config.GetLogger() sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) select { case <-stopChan: - 
fmt.Println("Shutdown requested by caller.") + logger.Info("Shutdown requested by caller") case sig := <-sigChan: - fmt.Printf("Received signal: %v\n", sig) + logger.Info("Received signal", zap.String("signal", sig.String())) } l := *s.listener @@ -77,5 +208,6 @@ func unifiedShutdownHandler(s *SubstraitServer, grpcServer *grpc.Server, stopCha grpcServer.GracefulStop() - fmt.Println("Server shutdown complete") + logger.Info("Server shutdown complete") + os.Exit(1) } diff --git a/src/Backend/opti-sql-go/substrait/substrait.go b/src/Backend/opti-sql-go/substrait/substrait.go index a809ba7..4aa10d1 100644 --- a/src/Backend/opti-sql-go/substrait/substrait.go +++ b/src/Backend/opti-sql-go/substrait/substrait.go @@ -1 +1,1176 @@ package substrait + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "math" + "opti-sql-go/Expr" + "opti-sql-go/config" + "opti-sql-go/operators" + "opti-sql-go/operators/aggr" + "opti-sql-go/operators/filter" + "opti-sql-go/operators/join" + "opti-sql-go/operators/project" + "os" + "reflect" + "strings" + + "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" + "github.com/apache/arrow/go/v17/arrow/memory" + "go.uber.org/zap" +) + +var ( + ErrInvalidSubstraitPlan = func(e error) error { + return fmt.Errorf("invalid JSON from frontend: %s", e.Error()) + } + ErrMalformedEmitBody = fmt.Errorf("malformed logical plan: multiple root operators found, expected exactly one") + + ErrMissingEmitOperator = fmt.Errorf("malformed logical plan: missing 'Emit' operator") + + ErrInvalidEmitChildren = fmt.Errorf("malformed logical plan: 'Emit' input must be a key to a JSON object") + + ErrInvalidOperator = func(operator string) error { + return fmt.Errorf("invalid operator '%s': cannot be called directly before 'Emit'", operator) + } + ErrBuildTreeFailed = func(operator string, context string) error { + return fmt.Errorf("failed to build operator tree for '%s': %s", operator, context) + } +) + +type jsonOBJ = 
map[string]interface{} + +type Emiter struct { + emitOperator operators.Operator + p *planMetaData +} + +func (e *Emiter) consumeAll() (*operators.RecordBatch, error) { + logger := config.GetLogger() + var results *operators.RecordBatch + logger.Info("Starting consumeAll", zap.String("operator", e.emitOperator.Name()), zap.String("plan_id", e.p.id)) + logger.Debug("Inner operator name", zap.String("operator", e.emitOperator.Name())) + mem := memory.NewGoAllocator() + iterationCount := 0 + for { + intermediate, err := e.emitOperator.Next(math.MaxInt16) + if err != nil { + if errors.Is(err, io.EOF) { + logger.Info("Reached EOF", zap.Int("iterations", iterationCount)) + break + } + logger.Error("Error fetching next batch", zap.Error(err), zap.Int("iteration", iterationCount)) + return nil, err + } + iterationCount++ + // first iteration set results to the intermediate results + if results == nil { + logger.Info("First batch received", zap.Uint64("rows", intermediate.RowCount), zap.Int("columns", len(intermediate.Columns))) + results = intermediate + continue + } + // otherwise just append for each idx + logger.Debug("Concatenating batch", zap.Uint64("new_rows", intermediate.RowCount), zap.Uint64("total_rows", results.RowCount)) + for i := range intermediate.Columns { + oldArr := results.Columns[i] + newArr := intermediate.Columns[i] + joinArr, err := array.Concatenate([]arrow.Array{oldArr, newArr}, mem) + if err != nil { + logger.Error("Failed to concatenate arrays", zap.Error(err), zap.Int("column_index", i)) + return nil, err + } + results.Columns[i] = joinArr + } + results.RowCount += intermediate.RowCount + + } + // delete source files + logger.Info("Cleaning up local files", zap.Int("file_count", len(e.p.localFileNames))) + for _, file := range e.p.localFileNames { + if err := os.Remove(file); err != nil { + logger.Error("Failed to delete local file", zap.Error(err), zap.String("file", file)) + return nil, err + } + logger.Debug("Deleted local file", 
zap.String("file", file)) + } + _ = e.emitOperator.Close() + logger.Info("consumeAll completed", zap.Uint64("total_rows", results.RowCount), zap.Int("total_columns", len(results.Columns))) + return results, nil +} + +// post-order: children first, then your name. +// NOTE: This assumes every operator you care about has exactly ONE input child in a field named `Input` +// (Join is the common exception: Left/Right). + +type planMetaData struct { + id string + localFileNames []string // check if empty before deleting the file +} + +func newPlanMetaData(id string) *planMetaData { + return &planMetaData{id: id} + +} + +// first turn into json. The plan should fit into ram to consume it all +func consumePlan(r io.Reader, p *planMetaData) (*Emiter, error) { + logger := config.GetLogger() + logger.Info("Starting plan consumption", zap.String("plan_id", p.id)) + + contents, err := io.ReadAll(r) + if err != nil { + logger.Error("Failed to read plan", zap.Error(err)) + return nil, err + } + logger.Debug("Plan read successfully", zap.Int("bytes", len(contents))) + + inMemoryRepr := make(jsonOBJ) + err = json.Unmarshal(contents, &inMemoryRepr) + if err != nil { + logger.Error("Failed to unmarshal JSON plan", zap.Error(err)) + return nil, ErrInvalidSubstraitPlan(err) + } + if len(inMemoryRepr) != 1 { + logger.Error("Malformed plan body", zap.Int("root_count", len(inMemoryRepr))) + return nil, ErrMalformedEmitBody + } + _, exist := inMemoryRepr["Emit"] // TODO! 
standerdize the spelling and casing of this or else everythign else will break + if !exist { + logger.Error("Missing Emit operator") + return nil, ErrMissingEmitOperator + } + tree, ok := inMemoryRepr["Emit"].(map[string]any) + if !ok { + logger.Error("Invalid Emit children type") + return nil, ErrInvalidEmitChildren + } + logger.Info("Plan structure validated, building operator tree") + return buildTree(tree, p) +} + +func buildTree(m jsonOBJ, plan *planMetaData) (*Emiter, error) { + logger := config.GetLogger() + //key=Operator , value=arguments to that operator + + // the tree needs to be built from the bottom up. Recurse all the way down until you reach a leaf node || key == "Source" + const operator = "Operator" + + if err := containsFields([]string{operator}, m); err != nil { + logger.Error("Missing operator field in tree node", zap.Error(err)) + return nil, err + } + if err := correctFieldTypes([]string{operator}, []string{"string"}, m); err != nil { + logger.Error("Invalid operator field type", zap.Error(err)) + return nil, err + } + + operatorNode := m[operator].(string) + logger.Info("Building operator", zap.String("operator_type", operatorNode)) + body := m[operatorNode].(map[string]any) + var op operators.Operator + switch strings.ToLower(operatorNode) { + case "filter": + filterOP, err := parseFilter(body, plan) + if err != nil { + logger.Error("Failed to build filter operator", zap.Error(err)) + return nil, ErrBuildTreeFailed("filter", err.Error()) + } + op = filterOP + logger.Info("Filter operator built successfully") + return &Emiter{op, plan}, nil + case "project": + projectOP, err := parseProject(body, plan) + if err != nil { + logger.Error("Failed to build project operator", zap.Error(err)) + return nil, ErrBuildTreeFailed("project", err.Error()) + } + op = projectOP + logger.Info("Project operator built successfully") + return &Emiter{op, plan}, nil + case "sort": + sortOP, err := parseSort(body, plan) + if err != nil { + logger.Error("Failed to 
build sort operator", zap.Error(err)) + return nil, ErrBuildTreeFailed("sort", err.Error()) + } + op = sortOP + logger.Info("Sort operator built successfully") + return &Emiter{op, plan}, nil + + case "distinct": + distinctOP, err := parseDistinct(body, plan) + if err != nil { + logger.Error("Failed to build distinct operator", zap.Error(err)) + return nil, ErrBuildTreeFailed("distinct", err.Error()) + } + op = distinctOP + logger.Info("Distinct operator built successfully") + return &Emiter{op, plan}, nil + case "limit": + limitOP, err := parseLimit(body, plan) + if err != nil { + logger.Error("Failed to build limit operator", zap.Error(err)) + return nil, ErrBuildTreeFailed("limit", err.Error()) + } + op = limitOP + logger.Info("Limit operator built successfully") + return &Emiter{op, plan}, nil + case "aggregate": + aggrOP, err := parseSingleAggr(body, plan) + if err != nil { + logger.Error("Failed to build aggregate operator", zap.Error(err)) + return nil, ErrBuildTreeFailed("single-aggr", err.Error()) + } + op = aggrOP + logger.Info("Aggregate operator built successfully") + return &Emiter{op, plan}, nil + case "groupby": + groupByOP, err := parseGroupBy(body, plan) + if err != nil { + logger.Error("Failed to build groupby operator", zap.Error(err)) + return nil, ErrBuildTreeFailed("group-by", err.Error()) + } + op = groupByOP + logger.Info("GroupBy operator built successfully") + return &Emiter{op, plan}, err + + case "join": + joinOP, err := parseJoin(body, plan) + if err != nil { + logger.Error("Failed to build join operator", zap.Error(err)) + return nil, ErrBuildTreeFailed("join", err.Error()) + } + op = joinOP + logger.Info("Join operator built successfully") + return &Emiter{op, plan}, err + + case "source", "expression": // invalid branch + //(1) Source:cannot directy return from source + //(2) expressions: cannot directy return expressions, need to call project on top + logger.Error("Invalid operator cannot be called before Emit", 
zap.String("operator", operatorNode)) + return nil, ErrInvalidOperator(operatorNode) + } + logger.Error("Unknown operator type", zap.String("operator", operatorNode)) + return nil, ErrBuildTreeFailed("unknown", "no valid operator found in logical plan") +} +func parseSource(sourceOBJ jsonOBJ, plan *planMetaData) (operators.Operator, error) { + logger := config.GetLogger() + fields := []string{"file-name", "local"} + err := containsFields(fields, sourceOBJ) + if err != nil { + logger.Error("Missing required fields in source", zap.Error(err)) + return nil, err + } + err = correctFieldTypes(fields, []string{"string", "boolean"}, sourceOBJ) + if err != nil { + logger.Error("Invalid field types in source", zap.Error(err)) + return nil, err + } + name := sourceOBJ["file-name"].(string) + logger.Info("Parsing source", zap.String("file_name", name)) + pieces := strings.Split(name, ".") + if len(pieces) < 1 { + return nil, fmt.Errorf("invalid file name used as source, must end in .csv or .parquet") + } + var kind string + switch strings.ToLower(pieces[len(pieces)-1]) { + case "csv": + kind = "csv" + case "parquet": + kind = "parquet" + default: + return nil, fmt.Errorf("invalid file mime was used in source operator") + } + local := sourceOBJ["local"].(bool) + ntwResource, err := project.NewStreamReader(name) + if err != nil { + return nil, err + } + if !local && kind == "parquet" { + parquetRootNode, err := project.NewParquetSource(ntwResource) + if err != nil { + return nil, err + } + return parquetRootNode, nil + } + localFile, err := ntwResource.DownloadLocally(plan.id) + if err != nil { + return nil, err + } + curDir, _ := os.Getwd() + plan.localFileNames = append(plan.localFileNames, fmt.Sprintf("%s/%s", curDir, localFile.Name())) + switch kind { + case "csv": + csvRootNode, err := project.NewProjectCSVLeaf(localFile) + if err != nil { + logger.Error("Failed to create CSV source", zap.Error(err)) + return nil, err + } + logger.Info("CSV source created successfully", 
zap.String("file", name)) + return csvRootNode, nil + case "parquet": + parquetRootNode, err := project.NewParquetSource(localFile) + if err != nil { + logger.Error("Failed to create Parquet source", zap.Error(err)) + return nil, err + } + logger.Info("Parquet source created successfully", zap.String("file", name)) + return parquetRootNode, nil + } + return nil, nil + +} +func parseFilter(filterOBJ jsonOBJ, plan *planMetaData) (*filter.FilterExec, error) { + fields := []string{"input", "expression"} + err := containsFields(fields, filterOBJ) + if err != nil { + return nil, err + } + err = correctFieldTypes(fields, []string{"object", "object"}, filterOBJ) + if err != nil { + return nil, err + } + exprsVal, ok := filterOBJ["expression"].(map[string]any) + if !ok { + return nil, fmt.Errorf("expression field has invalid type, expected map[string]any") + } + expression, err := parseExpression(exprsVal) + if err != nil { + return nil, err + } + var validExpr func(e Expr.Expression) bool // only here so we can call validExpr recusivly + validExpr = func(e Expr.Expression) bool { + be, ok := e.(*Expr.BinaryExpr) + if !ok { + return false + } + switch be.Op { + case Expr.Equal, + Expr.NotEqual, + Expr.LessThan, + Expr.LessThanOrEqual, + Expr.GreaterThan, + Expr.GreaterThanOrEqual: + return true + case Expr.And, Expr.Or: + return validExpr(be.Left) && validExpr(be.Right) + default: + return false + } + } + if !validExpr(expression) { + return nil, fmt.Errorf("%s is not a valid filter/having expression, must evaluate to boolean mask", expression) + } + inp := filterOBJ["input"].(map[string]any) + input, err := resolveInput(inp, plan) + if err != nil { + return nil, err + } + return filter.NewFilterExec(input, expression) +} +func parseProject(projectOBJ jsonOBJ, plan *planMetaData) (*project.ProjectExec, error) { + fields := []string{"input", "expressions"} + err := containsFields(fields, projectOBJ) + if err != nil { + return nil, err + } + err = correctFieldTypes(fields, 
[]string{"object", "array"}, projectOBJ) + if err != nil { + return nil, err + } + var expres []Expr.Expression + switch exprs := projectOBJ["expressions"].(type) { + case []any: + for i, raw := range exprs { + m, ok := raw.(map[string]any) + if !ok { + return nil, fmt.Errorf("expressions[%d] invalid type, expected object but got %T", i, raw) + } + e, err := parseExpression(m) + if err != nil { + return nil, err + } + expres = append(expres, e) + } + // (tests) + case []map[string]any: + for i, m := range exprs { + e, err := parseExpression(m) + if err != nil { + return nil, err + } + expres = append(expres, e) + _ = i + } + + default: + return nil, fmt.Errorf("expressions field has invalid type, expected array but got %T", projectOBJ["expressions"]) + } + if len(expres) == 0 { + return nil, fmt.Errorf("project operator needs at least one expressions") + } + sourceInput, err := resolveInput(projectOBJ["input"].(map[string]any), plan) + if err != nil { + return nil, err + } + return project.NewProjectExec(sourceInput, expres) +} +func parseSort(sortOBJ jsonOBJ, plan *planMetaData) (*aggr.SortExec, error) { + fields := []string{"input", "by"} + err := containsFields(fields, sortOBJ) + if err != nil { + return nil, err + } + err = correctFieldTypes(fields, []string{"object", "array"}, sortOBJ) + if err != nil { + return nil, err + } + parseBy := func(obj []map[string]any) ([]aggr.SortKey, error) { + var outputKeys []aggr.SortKey + for _, byexpr := range obj { + byFields := []string{"expr", "asc"} + err := containsFields(byFields, byexpr) + if err != nil { + return nil, err + } + err = correctFieldTypes(byFields, []string{"object", "boolean"}, byexpr) + if err != nil { + return nil, err + } + expr, err := parseExpression(byexpr["expr"].(map[string]any)) + if err != nil { + return nil, err + } + asc := byexpr["asc"].(bool) + outputKeys = append(outputKeys, aggr.SortKey{ + Expr: expr, + Ascending: asc, + }) + + } + return outputKeys, nil + } + input, err := 
resolveInput(sortOBJ["input"].(map[string]any), plan) + if err != nil { + return nil, err + } + var byField []map[string]any + + switch v := sortOBJ["by"].(type) { + case []any: + byField = make([]map[string]any, 0, len(v)) + for i, item := range v { + m, ok := item.(map[string]any) + if !ok { + return nil, fmt.Errorf("sort::by[%d] is malformed, expected object but got %T", i, item) + } + byField = append(byField, m) + } + + case []map[string]any: + // Go-literal tests may already have the correct type + byField = v + + default: + return nil, fmt.Errorf("sort::by field is malformed, should be an array of objects, got %T", sortOBJ["by"]) + } + sortKeys, err := parseBy(byField) + if err != nil { + return nil, err + } + if len(sortKeys) < 1 { + return nil, fmt.Errorf("sort keys must be present for Sort operator") + } + return aggr.NewSortExec(input, sortKeys) +} +func parseDistinct(distinctOBJ jsonOBJ, plan *planMetaData) (*filter.DistinctExec, error) { + fields := []string{"input", "expressions"} + err := containsFields(fields, distinctOBJ) + if err != nil { + return nil, err + } + err = correctFieldTypes(fields, []string{"object", "array"}, distinctOBJ) + if err != nil { + return nil, err + } + var expres []Expr.Expression + switch exprs := distinctOBJ["expressions"].(type) { + case []any: + for i, raw := range exprs { + m, ok := raw.(map[string]any) + if !ok { + return nil, fmt.Errorf("expressions[%d] invalid type, expected object but got %T", i, raw) + } + e, err := parseExpression(m) + if err != nil { + return nil, err + } + expres = append(expres, e) + } + + case []map[string]any: + for i, m := range exprs { + e, err := parseExpression(m) + if err != nil { + return nil, err + } + expres = append(expres, e) + _ = i + } + default: + return nil, fmt.Errorf("expressions field has invalid type, expected array but got %T", distinctOBJ["expressions"]) + } + if len(expres) == 0 { + return nil, fmt.Errorf("distinct operator needs at least one expressions") + } + 
sourceInput, err := resolveInput(distinctOBJ["input"].(map[string]any), plan) + if err != nil { + return nil, err + } + return filter.NewDistinctExec(sourceInput, expres) +} + +func parseLimit(limitOBJ jsonOBJ, plan *planMetaData) (*filter.LimitExec, error) { + fields := []string{"input", "limit"} + err := containsFields(fields, limitOBJ) + if err != nil { + return nil, err + } + err = correctFieldTypes(fields, []string{"object", "int"}, limitOBJ) + if err != nil { + return nil, err + } + limit, ok := limitOBJ["limit"].(int) + if !ok { + // try to parse as float + l, ok1 := limitOBJ["limit"].(float64) + if !ok1 { + return nil, fmt.Errorf("limit field is not the correct type: true Type %T", limitOBJ["limit"]) + } + //workeds so cast to int + limit = int(l) + } + // must be a valid uint16 value 1-2^16 + if limit <= 0 || limit > math.MaxUint16 { + return nil, fmt.Errorf("limit field cannot be less than 1 or greater than %v, but %v was passed in", math.MaxUint16, limit) + } + sourceInput, err := resolveInput(limitOBJ["input"].(map[string]any), plan) + if err != nil { + return nil, err + } + + return filter.NewLimitExec(sourceInput, uint16(limit)) +} + +func parseSingleAggr(aggrOBJ jsonOBJ, plan *planMetaData) (*aggr.AggrExec, error) { + fields := []string{"input", "aggrs"} + err := containsFields(fields, aggrOBJ) + if err != nil { + return nil, err + } + err = correctFieldTypes(fields, []string{"object", "array"}, aggrOBJ) + if err != nil { + return nil, err + } + + input, err := resolveInput(aggrOBJ["input"].(map[string]any), plan) + if err != nil { + return nil, err + } + var res []map[string]any + + switch agVal := aggrOBJ["aggrs"].(type) { + case []any: + res = make([]map[string]any, 0, len(agVal)) + for i, item := range agVal { + v, ok := item.(map[string]any) + if !ok { + return nil, fmt.Errorf("aggrs[%d] malformed, expected object but got %T", i, item) + } + res = append(res, v) + } + case []map[string]any: + res = agVal + + default: + return nil, 
fmt.Errorf("aggrs malformed, should be an array of aggregations: got %T", aggrOBJ["aggrs"]) + } + globalAggrs, err := generateAggrs(res) + if err != nil { + return nil, err + } + if len(globalAggrs) < 1 { + return nil, fmt.Errorf("there must be atleast one aggregation") + } + return aggr.NewGlobalAggrExec(input, globalAggrs) +} +func parseGroupBy(groupbyOBJ jsonOBJ, plan *planMetaData) (*aggr.GroupByExec, error) { + fields := []string{"input", "group_by", "aggrs"} + err := containsFields(fields, groupbyOBJ) + if err != nil { + return nil, err + } + err = correctFieldTypes(fields, []string{"object", "array", "array"}, groupbyOBJ) + if err != nil { + return nil, err + } + input, err := resolveInput(groupbyOBJ["input"].(map[string]any), plan) + if err != nil { + return nil, err + } + var groupByStatments []Expr.Expression + + // ---- group_by: accept []any (json) OR []map[string]any (go literals) + switch gbVal := groupbyOBJ["group_by"].(type) { + case []any: + for i, gb := range gbVal { + m, ok := gb.(map[string]any) + if !ok { + return nil, fmt.Errorf("group_by[%d] malformed, expected object but got %T", i, gb) + } + e, err := parseExpression(m) + if err != nil { + return nil, err + } + groupByStatments = append(groupByStatments, e) + } + + case []map[string]any: + for i, m := range gbVal { + e, err := parseExpression(m) + if err != nil { + return nil, err + } + groupByStatments = append(groupByStatments, e) + _ = i + } + + default: + return nil, fmt.Errorf("group by statements are malformed, should be an array of expressions: got %T", groupbyOBJ["group_by"]) + } + // ---- aggrs: accept []any (json) OR []map[string]any (go literals) + var res []map[string]any + + switch agVal := groupbyOBJ["aggrs"].(type) { + case []any: + res = make([]map[string]any, 0, len(agVal)) + for i, item := range agVal { + v, ok := item.(map[string]any) + if !ok { + return nil, fmt.Errorf("aggrs[%d] malformed, expected object but got %T", i, item) + } + res = append(res, v) + } + case 
[]map[string]any: + res = agVal + + default: + return nil, fmt.Errorf("aggrs malformed, should be an array of aggregations: got %T", groupbyOBJ["aggrs"]) + } + aggrs, err := generateAggrs(res) + if err != nil { + return nil, err + } + if len(groupByStatments) == 0 { + return nil, fmt.Errorf("invalid GROUP BY: must have at least one group_by key") + } + if len(aggrs) == 0 { + return nil, fmt.Errorf("invalid GROUP BY: must have at least one aggregation") + } + return aggr.NewGroupByExec(input, aggrs, groupByStatments) +} +func parseJoin(joinOBJ jsonOBJ, plan *planMetaData) (*join.HashJoinExec, error) { + fields := []string{"left", "right", "join_type", "on"} + err := containsFields(fields, joinOBJ) + if err != nil { + return nil, err + } + err = correctFieldTypes(fields, []string{"object", "object", "string", "array"}, joinOBJ) + if err != nil { + return nil, err + } + leftObj, ok := joinOBJ["left"].(map[string]any) + if !ok { + return nil, fmt.Errorf("malformed join body, `left` field must be an operator/object") + } + left, err := resolveInput(leftObj, plan) + if err != nil { + return nil, err + } + rightObj, ok := joinOBJ["right"].(map[string]any) + if !ok { + return nil, fmt.Errorf("malformed join body ,`right` field must be an operator/object") + } + right, err := resolveInput(rightObj, plan) + if err != nil { + return nil, err + } + joinType := strings.ToLower(joinOBJ["join_type"].(string)) + if joinType != "inner" { + return nil, fmt.Errorf("invalid join type provided %s only inner is supported", joinType) + } + clauseParer := func(clause []map[string]any) (join.JoinClause, error) { + var jc join.JoinClause + for _, c := range clause { + f := []string{"left", "right"} + err := containsFields(f, c) + if err != nil { + return jc, err + } + err = correctFieldTypes(f, []string{"object", "object"}, c) + if err != nil { + return jc, err + } + leftExpr, err := parseExpression(c["left"].(map[string]any)) + if err != nil { + return jc, err + } + rightExpr, err := 
parseExpression(c["right"].(map[string]any)) + if err != nil { + return jc, err + } + jc.LeftS = append(jc.LeftS, leftExpr) + jc.RightS = append(jc.RightS, rightExpr) + } + if len(jc.LeftS) < 1 { + return jc, fmt.Errorf("join clause cannot be empyy") + } + return jc, nil + } + rawOn, ok := joinOBJ["on"] + if !ok { + return nil, fmt.Errorf("join::on field is missing") + } + + var onClauses []map[string]any + + switch v := rawOn.(type) { + case []any: + onClauses = make([]map[string]any, 0, len(v)) + for i, item := range v { + m, ok := item.(map[string]any) + if !ok { + return nil, fmt.Errorf("join::on[%d] malformed, expected object but got %T", i, item) + } + onClauses = append(onClauses, m) + } + case []map[string]any: + onClauses = v + default: + return nil, fmt.Errorf("join::on field is malformed, expected array of objects, got %T", rawOn) + } + jc, err := clauseParer(onClauses) + if err != nil { + return nil, err + } + + return join.NewHashJoinExec(left, right, jc, join.InnerJoin, nil) +} + +// carbon clone of +func parseHaving(havingOBJ jsonOBJ, plan *planMetaData) (operators.Operator, error) { + return parseFilter(havingOBJ, plan) +} + +// expressions need to be handled in a special way since they contain serveral keys +func parseExpression(m jsonOBJ) (Expr.Expression, error) { + logger := config.GetLogger() + // grab tje expr_type and then parse based on that + logger.Debug("JSON object passed in for expression parsing", zap.Any("json_object", m)) + err := containsFields([]string{"expr_type"}, m) + if err != nil { + logger.Error("Malformed expression: missing expr_type", zap.Error(err)) + return nil, fmt.Errorf("malformed expression body. 
Doesnt contain expr_type field") + } + exprType := m["expr_type"].(string) + logger.Debug("Parsing expression", zap.String("expr_type", exprType)) + switch exprType { + case "ColumnResolve": + neededFields := []string{"name"} + fieldTypes := []string{"string"} + err := containsFields(neededFields, m) + if err != nil { + return nil, fmt.Errorf("malformed expression body: %v", err) + } + err = correctFieldTypes(neededFields, fieldTypes, m) + if err != nil { + return nil, fmt.Errorf("malformed expression body: %v", err) + } + if m["name"] == "" { + return nil, fmt.Errorf("column resolve name cannot be empty") + } + cr := Expr.NewColumnResolve(m["name"].(string)) + return cr, nil + case "LiteralResolve": + neededFields := []string{"value", "lit_type"} + err := containsFields(neededFields, m) + if err != nil { + return nil, fmt.Errorf("malformed expression body: %v", err) + } + fieldTypes := []string{m["lit_type"].(string), "string"} + err = correctFieldTypes(neededFields, fieldTypes, m) + if err != nil { + return nil, fmt.Errorf("malformed expression body (Types): %v", err) + } + var value any + var arrowType arrow.DataType + logger.Debug("Parsing LiteralResolve", zap.Any("raw_value", m["value"]), zap.Any("lit_type", m["lit_type"])) + switch m["lit_type"].(string) { + case "int": + arrowType = arrow.PrimitiveTypes.Int64 + switch val := m["value"].(type) { + case int: + value = int64(val) + case float64: + value = int64(val) + } + case "string": + arrowType = arrow.BinaryTypes.String + v, _ := m["value"].(string) + value = string(v) + case "boolean": + arrowType = arrow.FixedWidthTypes.Boolean + v, _ := m["value"].(bool) + value = bool(v) + case "float64": + arrowType = arrow.PrimitiveTypes.Float64 + v, _ := m["value"].(float64) + value = float64(v) + default: + return nil, fmt.Errorf("invalid Literal Type was passed to Literal Resolve") + } + lr := Expr.NewLiteralResolve(arrowType, value) + return lr, nil + case "BinaryExpr": + neededFields := []string{"op", "left", 
"right"} + fieldTypes := []string{"string", "object", "object"} + err := containsFields(neededFields, m) + if err != nil { + return nil, fmt.Errorf("malformed expression body: %v", err) + } + err = correctFieldTypes(neededFields, fieldTypes, m) + if err != nil { + return nil, fmt.Errorf("malformed expression body: %v", err) + } + left, err := parseExpression(m["left"].(map[string]any)) + if err != nil { + return nil, err + } + right, err := parseExpression(m["right"].(map[string]any)) + if err != nil { + return nil, err + } + op := m["op"].(string) + operator, err := validBinaryOp(op) + if err != nil { + return nil, err + } + binaryExpression := Expr.NewBinaryExpr(left, operator, right) + logger.Debug("BinaryExpr created", zap.String("expression", fmt.Sprintf("%v", binaryExpression))) + return binaryExpression, nil + case "ScalarFunction": + neededFields := []string{"func", "expr"} + fieldTypes := []string{"string", "object"} + err := containsFields(neededFields, m) + if err != nil { + return nil, fmt.Errorf("malformed expression body: %v", err) + } + err = correctFieldTypes(neededFields, fieldTypes, m) + if err != nil { + return nil, fmt.Errorf("malformed expression body: %v", err) + } + function := m["func"].(string) + fn := Expr.SupportedFunctions(-1) + switch function { + case "Upper", "Lower", "Abs", "Round": + fn = Expr.FnToScalarFunction(function) + + } + if fn == Expr.SupportedFunctions(-1) { + return nil, fmt.Errorf("invalid scalr function provided %s", function) + + } + expr, err := parseExpression(m["expr"].(map[string]any)) + if err != nil { + return nil, err + } + sf := Expr.NewScalarFunction(Expr.FnToScalarFunction(function), expr) + return sf, nil + case "Alias": + neededFields := []string{"name", "expr"} + fieldTypes := []string{"string", "object"} + err := containsFields(neededFields, m) + if err != nil { + return nil, fmt.Errorf("malformed expression body: %v", err) + } + err = correctFieldTypes(neededFields, fieldTypes, m) + if err != nil { + 
return nil, fmt.Errorf("malformed expression body: %v", err) + } + expr, err := parseExpression(m["expr"].(map[string]any)) + if err != nil { + return nil, err + } + name := m["name"].(string) + alias := Expr.NewAlias(expr, name) + return alias, nil + case "CastExpr": + neededFields := []string{"expr", "to_type"} + fieldTypes := []string{"object", "string"} + err := containsFields(neededFields, m) + if err != nil { + return nil, fmt.Errorf("malformed expression body: %v", err) + } + err = correctFieldTypes(neededFields, fieldTypes, m) + if err != nil { + return nil, fmt.Errorf("malformed expression body: %v", err) + } + expr, err := parseExpression(m["expr"].(map[string]any)) + if err != nil { + return nil, err + } + var T arrow.DataType + switch m["to_type"].(string) { + case "int": + T = arrow.PrimitiveTypes.Int64 + case "string": + T = arrow.BinaryTypes.String + case "boolean": + T = arrow.FixedWidthTypes.Boolean + case "float64": + T = arrow.PrimitiveTypes.Float64 + default: + return nil, fmt.Errorf("invalid type provided.%v", m["to_type"]) + } + cast := Expr.NewCastExpr(expr, T) + return cast, nil + default: + return nil, fmt.Errorf("invalid expression: %v", m["expr_type"]) + } + +} +func resolveInput(m jsonOBJ, plan *planMetaData) (operators.Operator, error) { + const OperatorStr = "Operator" + fields := []string{OperatorStr} + if err := containsFields(fields, m); err != nil { + return nil, err + } + opName := m[OperatorStr].(string) + _, ok := m[opName] + if !ok { + return nil, fmt.Errorf("malformed json body.operator body does not contain %s's body", opName) + } + + if err := correctFieldTypes([]string{OperatorStr, opName}, []string{"string", "object"}, m); err != nil { + return nil, err + } + newOBJ := m[opName].(map[string]any) + switch strings.ToLower(opName) { + // base case, we hit a leaf node (source node) + case "source": // return concrete base case here + return parseSource(newOBJ, plan) + case "project": + return parseProject(newOBJ, plan) + case 
"filter": + return parseFilter(newOBJ, plan) + case "distinct": + return parseDistinct(newOBJ, plan) + case "limit": + return parseLimit(newOBJ, plan) + case "sort": + return parseSort(newOBJ, plan) + case "aggregate": + return parseSingleAggr(newOBJ, plan) + case "having": + return parseHaving(newOBJ, plan) + case "join": + return parseJoin(newOBJ, plan) + case "groupby": + return parseGroupBy(newOBJ, plan) + } + + return nil, nil +} + +// check that all the fileds exist, if any are missing return and error indicating which fields are missing +// ignore any extra fields that may be present for now +func containsFields(fields []string, obj map[string]any) error { + var missing []string + + for _, f := range fields { + if _, ok := obj[f]; !ok { + missing = append(missing, f) + } + } + + if len(missing) > 0 { + return fmt.Errorf( + "missing required fields: %s", + strings.Join(missing, ", "), + ) + } + + return nil +} + +type misMatchTypes struct { + idx uint8 + fieldName string + value any // from "%v" formating + recievedType string + expectedDataType string +} + +func correctFieldTypes(fields []string, fieldTypes []string, obj jsonOBJ) error { + if len(fields) != len(fieldTypes) { + return fmt.Errorf("fields and fieldTypes must have the same number of elements") + } + var misMatches []misMatchTypes + + // dont need to do _,ok pattern here because we can assume contains fields is called before this one + for i, field := range fields { + value := obj[field] + expected := fieldTypes[i] + if !matchesExpectedType(value, expected) { + misMatches = append(misMatches, misMatchTypes{ + idx: uint8(i), + fieldName: field, + value: value, + recievedType: fmt.Sprintf("%T", value), + expectedDataType: expected, + }) + } + + } + if len(misMatches) > 0 { + return fmt.Errorf("all fields did not match their expected data types \t%#v", misMatches) + + } + return nil // mismatch in field and their expected types, field1 is not of expected type T1 +} + +func matchesExpectedType(value 
any, expected string) bool { + switch expected { + case "string": + _, ok := value.(string) + return ok + case "boolean": + _, ok := value.(bool) + return ok + case "int": + switch value.(type) { + case float64, float32, int: + return true + default: + return false + } + case "float64": + _, ok := value.(float64) + return ok + case "object": + _, ok := value.(map[string]any) + return ok + case "array": + // Use reflection to check if it's any kind of slice/array + return reflect.TypeOf(value).Kind() == reflect.Slice + default: + return false + } +} + +func validBinaryOp(s string) (Expr.BinaryOperator, error) { + switch s { + // arithmetic + case "Addition": + return Expr.BinaryOperator(Expr.Addition), nil + case "Subtraction": + return Expr.BinaryOperator(Expr.Subtraction), nil + case "Multiplication": + return Expr.BinaryOperator(Expr.Multiplication), nil + case "Division": + return Expr.BinaryOperator(Expr.Division), nil + + // comparison + case "Equal": + return Expr.BinaryOperator(Expr.Equal), nil + case "NotEqual": + return Expr.BinaryOperator(Expr.NotEqual), nil + case "LessThan": + return Expr.BinaryOperator(Expr.LessThan), nil + case "LessThanOrEqual": + return Expr.BinaryOperator(Expr.LessThanOrEqual), nil + case "GreaterThan": + return Expr.BinaryOperator(Expr.GreaterThan), nil + case "GreaterThanOrEqual": + return Expr.BinaryOperator(Expr.GreaterThanOrEqual), nil + + // logical + case "And": + return Expr.BinaryOperator(Expr.And), nil + case "Or": + return Expr.BinaryOperator(Expr.Or), nil + + // regex + case "Like": + return Expr.BinaryOperator(Expr.Like), nil + + default: + return Expr.BinaryOperator(-1), fmt.Errorf("invalid binary operator: %s", s) + } +} + +// call strings.toLower before invoking this method +func validFN(s string) bool { + switch s { + case "sum", "count", "avg", "min", "max": + return true + default: + return false + } +} + +func toAggrFn(s string) aggr.AggrFunc { + switch s { + case "sum": + return aggr.Sum + case "count": + 
return aggr.Count + case "avg": + return aggr.Avg + case "min": + return aggr.Min + case "max": + return aggr.Max + } + logger := config.GetLogger() + logger.Error("Unsupported aggregation function", zap.String("function", s)) + return -1 +} + +func generateAggrs(aggrs []map[string]any) ([]aggr.AggregateFunctions, error) { + var globalAggrs []aggr.AggregateFunctions + for _, a := range aggrs { + fields := []string{"function", "expr"} + err := containsFields(fields, a) + if err != nil { + return nil, err + } + err = correctFieldTypes(fields, []string{"string", "object"}, a) + if err != nil { + return nil, err + } + fn := strings.ToLower(a["function"].(string)) + if !validFN(fn) { + return nil, fmt.Errorf("%s is not a valid aggregation method", fn) + } + expr, err := parseExpression(a["expr"].(map[string]any)) + if err != nil { + return nil, err + } + globalAggrs = append(globalAggrs, aggr.NewAggregateFunctions(toAggrFn(fn), expr)) + } + return globalAggrs, nil +} diff --git a/src/Backend/opti-sql-go/substrait/substrait_integration_test.go b/src/Backend/opti-sql-go/substrait/substrait_integration_test.go new file mode 100644 index 0000000..87cdf65 --- /dev/null +++ b/src/Backend/opti-sql-go/substrait/substrait_integration_test.go @@ -0,0 +1,1423 @@ +package substrait + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "io" + "os" + "path/filepath" + "strings" + "testing" +) + +//! to delete .csv files generated -> find . 
-type f -name '*\.csv*' -delete + +// IntegrationTest defines a single integration test case using buildTree +type IntegrationTest struct { + name string + shouldError bool + logicalPlan jsonOBJ + sqlEquiv string // SQL equivalent for documentation +} + +// FileIntegrationTest defines a test case that reads from a Substrait JSON file +type FileIntegrationTest struct { + name string + shouldError bool + filePath string + sqlEquiv string // Plan ID = SQL equivalent +} + +// TestOperatorsIntegration tests each operator using buildTree with struct-based test cases +func TestOperatorsIntegration(t *testing.T) { + defer func() { + curDir, err := os.Getwd() + if err != nil { + fmt.Printf("Failed to get current directory: %v\n", err) + } + + // Read directory contents + entries, err := os.ReadDir(curDir) + if err != nil { + fmt.Printf("Failed to read directory: %v\n", err) + } + + // Delete all files containing .csv in their name + for _, entry := range entries { + if !entry.IsDir() && strings.Contains(entry.Name(), "country_full.csv") { + filePath := fmt.Sprintf("%s/%s", curDir, entry.Name()) + err := os.Remove(filePath) + if err != nil { + fmt.Printf("error removing %s: %v\n", entry.Name(), err) + } else { + fmt.Printf("deleted: %s\n", entry.Name()) + } + } + } + }() + t.Run("Filter Operator Integration", func(t *testing.T) { + + sourceInput := map[string]any{ + "Operator": "Source", + "Source": map[string]any{ + "file-name": "country_full.csv", + "local": false, + }, + } + + filterTests := []IntegrationTest{ + { + name: "SELECT * FROM country WHERE region > 'Africa'", + shouldError: false, + sqlEquiv: "SELECT * FROM country WHERE region > 'Africa'", + logicalPlan: map[string]any{ + "Operator": "Filter", + "Filter": map[string]any{ + "input": sourceInput, + "expression": map[string]any{ + "expr_type": "BinaryExpr", + "op": "GreaterThan", + "left": map[string]any{ + "expr_type": "ColumnResolve", + "name": "region", + }, + "right": map[string]any{ + "expr_type": 
"LiteralResolve", + "value": "Africa", + "lit_type": "string", + }, + }, + }, + }, + }, + { + name: "SELECT * FROM country WHERE country_code < 500", + shouldError: false, + sqlEquiv: "SELECT * FROM country WHERE country_code < 500", + logicalPlan: map[string]any{ + "Operator": "Filter", + "Filter": map[string]any{ + "input": sourceInput, + "expression": map[string]any{ + "expr_type": "BinaryExpr", + "op": "LessThan", + "left": map[string]any{ + "expr_type": "ColumnResolve", + "name": "country-code", + }, + "right": map[string]any{ + "expr_type": "LiteralResolve", + "value": 500, + "lit_type": "int", + }, + }, + }, + }, + }, + { + name: "SELECT * FROM country WHERE name = 'Canada' AND region = 'Americas'", + shouldError: false, + sqlEquiv: "SELECT * FROM country WHERE name = 'Canada' AND region = 'Americas'", + logicalPlan: map[string]any{ + "Operator": "Filter", + "Filter": map[string]any{ + "input": sourceInput, + "expression": map[string]any{ + "expr_type": "BinaryExpr", + "op": "And", + "left": map[string]any{ + "expr_type": "BinaryExpr", + "op": "Equal", + "left": map[string]any{ + "expr_type": "ColumnResolve", + "name": "name", + }, + "right": map[string]any{ + "expr_type": "LiteralResolve", + "value": "Canada", + "lit_type": "string", + }, + }, + "right": map[string]any{ + "expr_type": "BinaryExpr", + "op": "Equal", + "left": map[string]any{ + "expr_type": "ColumnResolve", + "name": "region", + }, + "right": map[string]any{ + "expr_type": "LiteralResolve", + "value": "Americas", + "lit_type": "string", + }, + }, + }, + }, + }, + }, + { + name: "Filter with missing expression field - should fail", + shouldError: true, + sqlEquiv: "Filter with missing expression field - should fail", + logicalPlan: map[string]any{ + "Operator": "Filter", + "Filter": map[string]any{ + "input": sourceInput, + }, + }, + }, + { + name: "Filter with missing input field - should fail", + shouldError: true, + sqlEquiv: "Filter with missing input field - should fail", + logicalPlan: 
map[string]any{ + "Operator": "Filter", + "Filter": map[string]any{ + "expression": map[string]any{ + "expr_type": "ColumnResolve", + "name": "region", + }, + }, + }, + }, + } + + for _, test := range filterTests { + t.Run(test.sqlEquiv, func(t *testing.T) { + planMetaData := newPlanMetaData(test.sqlEquiv) + emitter, err := buildTree(test.logicalPlan, planMetaData) + + if (err != nil) != test.shouldError { + t.Errorf("buildTree error = %v, shouldError = %v", err, test.shouldError) + return + } + + if !test.shouldError && emitter != nil { + rb, err := emitter.emitOperator.Next(5) + if err != nil { + t.Errorf("Next() error = %v", err) + return + } + if rb != nil { + t.Logf("[%s]\n%s\n", test.sqlEquiv, rb.PrettyPrint()) + } + } + }) + } + }) + + t.Run("Project Operator Integration", func(t *testing.T) { + sourceInput := map[string]any{ + "Operator": "Source", + "Source": map[string]any{ + "file-name": "country_full.csv", + "local": false, + }, + } + + projectTests := []IntegrationTest{ + { + name: "SELECT name, region FROM country", + shouldError: false, + sqlEquiv: "SELECT name, region FROM country", + logicalPlan: map[string]any{ + "Operator": "Project", + "Project": map[string]any{ + "input": sourceInput, + "expressions": []map[string]any{ + { + "expr_type": "ColumnResolve", + "name": "name", + }, + { + "expr_type": "ColumnResolve", + "name": "region", + }, + }, + }, + }, + }, + { + name: "SELECT country_code, name, region FROM country", + shouldError: false, + sqlEquiv: "SELECT country_code, name, region FROM country", + logicalPlan: map[string]any{ + "Operator": "Project", + "Project": map[string]any{ + "input": sourceInput, + "expressions": []map[string]any{ + { + "expr_type": "ColumnResolve", + "name": "country-code", + }, + { + "expr_type": "ColumnResolve", + "name": "name", + }, + { + "expr_type": "ColumnResolve", + "name": "region", + }, + }, + }, + }, + }, + { + name: "SELECT * - all columns using single column projection", + shouldError: false, + sqlEquiv: 
"SELECT * - all columns using single column projection", + logicalPlan: map[string]any{ + "Operator": "Project", + "Project": map[string]any{ + "input": sourceInput, + "expressions": []map[string]any{ + { + "expr_type": "ColumnResolve", + "name": "name", + }, + }, + }, + }, + }, + { + name: "Project with missing expressions field - should fail", + shouldError: true, + sqlEquiv: "Project with missing expressions field - should fail", + logicalPlan: map[string]any{ + "Operator": "Project", + "Project": map[string]any{ + "input": sourceInput, + }, + }, + }, + { + name: "Project with empty expressions array - should fail", + shouldError: true, + sqlEquiv: "Project with empty expressions array - should fail", + logicalPlan: map[string]any{ + "Operator": "Project", + "Project": map[string]any{ + "input": sourceInput, + "expressions": []map[string]any{}, + }, + }, + }, + } + + for _, test := range projectTests { + t.Run(test.sqlEquiv, func(t *testing.T) { + planMetaData := newPlanMetaData(test.sqlEquiv) + emitter, err := buildTree(test.logicalPlan, planMetaData) + + if (err != nil) != test.shouldError { + t.Errorf("buildTree error = %v, shouldError = %v", err, test.shouldError) + return + } + + if !test.shouldError && emitter != nil { + rb, err := emitter.emitOperator.Next(5) + if err != nil { + t.Errorf("Next() error = %v", err) + return + } + if rb != nil { + t.Logf("[%s]\n%s\n", test.sqlEquiv, rb.PrettyPrint()) + } + } + }) + } + }) + + t.Run("Sort Operator Integration", func(t *testing.T) { + sourceInput := map[string]any{ + "Operator": "Source", + "Source": map[string]any{ + "file-name": "country_full.csv", + "local": false, + }, + } + + sortTests := []IntegrationTest{ + { + name: "SELECT * FROM country ORDER BY name ASC", + shouldError: false, + sqlEquiv: "SELECT * FROM country ORDER BY name ASC", + logicalPlan: map[string]any{ + "Operator": "Sort", + "Sort": map[string]any{ + "input": sourceInput, + "by": []map[string]any{ + { + "expr": map[string]any{ + 
"expr_type": "ColumnResolve", + "name": "name", + }, + "asc": true, + }, + }, + }, + }, + }, + { + name: "SELECT * FROM country ORDER BY country_code DESC", + shouldError: false, + sqlEquiv: "SELECT * FROM country ORDER BY country_code DESC", + logicalPlan: map[string]any{ + "Operator": "Sort", + "Sort": map[string]any{ + "input": sourceInput, + "by": []map[string]any{ + { + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "country-code", + }, + "asc": false, + }, + }, + }, + }, + }, + { + name: "SELECT * FROM country ORDER BY region ASC, name DESC", + shouldError: false, + sqlEquiv: "SELECT * FROM country ORDER BY region ASC, name DESC", + logicalPlan: map[string]any{ + "Operator": "Sort", + "Sort": map[string]any{ + "input": sourceInput, + "by": []map[string]any{ + { + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "region", + }, + "asc": true, + }, + { + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "name", + }, + "asc": false, + }, + }, + }, + }, + }, + { + name: "Sort with missing by field - should fail", + shouldError: true, + sqlEquiv: "Sort with missing by field - should fail", + logicalPlan: map[string]any{ + "Operator": "Sort", + "Sort": map[string]any{ + "input": sourceInput, + }, + }, + }, + { + name: "Sort with empty by array - should fail", + shouldError: true, + sqlEquiv: "Sort with empty by array - should fail", + logicalPlan: map[string]any{ + "Operator": "Sort", + "Sort": map[string]any{ + "input": sourceInput, + "by": []map[string]any{}, + }, + }, + }, + } + + for _, test := range sortTests { + t.Run(test.sqlEquiv, func(t *testing.T) { + planMetaData := newPlanMetaData(test.sqlEquiv) + emitter, err := buildTree(test.logicalPlan, planMetaData) + + if (err != nil) != test.shouldError { + t.Errorf("buildTree error = %v, shouldError = %v", err, test.shouldError) + return + } + + if !test.shouldError && emitter != nil { + rb, err := emitter.emitOperator.Next(5) + if err != nil { + 
t.Errorf("Next() error = %v", err) + return + } + if rb != nil { + t.Logf("[%s]\n%s\n", test.sqlEquiv, rb.PrettyPrint()) + } + } + }) + } + }) + + t.Run("Distinct Operator Integration", func(t *testing.T) { + sourceInput := map[string]any{ + "Operator": "Source", + "Source": map[string]any{ + "file-name": "country_full.csv", + "local": false, + }, + } + + distinctTests := []IntegrationTest{ + { + name: "SELECT DISTINCT region FROM country", + shouldError: false, + sqlEquiv: "SELECT DISTINCT region FROM country", + logicalPlan: map[string]any{ + "Operator": "Distinct", + "Distinct": map[string]any{ + "input": sourceInput, + "expressions": []map[string]any{ + { + "expr_type": "ColumnResolve", + "name": "region", + }, + }, + }, + }, + }, + { + name: "SELECT DISTINCT region, sub_region FROM country", + shouldError: false, + sqlEquiv: "SELECT DISTINCT region, sub_region FROM country", + logicalPlan: map[string]any{ + "Operator": "Distinct", + "Distinct": map[string]any{ + "input": sourceInput, + "expressions": []map[string]any{ + { + "expr_type": "ColumnResolve", + "name": "region", + }, + { + "expr_type": "ColumnResolve", + "name": "sub-region", + }, + }, + }, + }, + }, + { + name: "SELECT DISTINCT * FROM country - all columns", + shouldError: false, + sqlEquiv: "SELECT DISTINCT * FROM country - all columns", + logicalPlan: map[string]any{ + "Operator": "Distinct", + "Distinct": map[string]any{ + "input": sourceInput, + "expressions": []map[string]any{ + { + "expr_type": "ColumnResolve", + "name": "country-code", + }, + { + "expr_type": "ColumnResolve", + "name": "name", + }, + { + "expr_type": "ColumnResolve", + "name": "region", + }, + { + "expr_type": "ColumnResolve", + "name": "sub-region", + }, + }, + }, + }, + }, + { + name: "Distinct with missing input field - should fail", + shouldError: true, + sqlEquiv: "Distinct with missing input field - should fail", + logicalPlan: map[string]any{ + "Operator": "Distinct", + "Distinct": map[string]any{ + "expressions": 
[]map[string]any{ + { + "expr_type": "ColumnResolve", + "name": "region", + }, + }, + }, + }, + }, + } + + for _, test := range distinctTests { + t.Run(test.sqlEquiv, func(t *testing.T) { + planMetaData := newPlanMetaData(test.sqlEquiv) + emitter, err := buildTree(test.logicalPlan, planMetaData) + + if (err != nil) != test.shouldError { + t.Errorf("buildTree error = %v, shouldError = %v", err, test.shouldError) + return + } + + if !test.shouldError && emitter != nil { + rb, err := emitter.emitOperator.Next(5) + if err != nil { + t.Errorf("Next() error = %v", err) + return + } + if rb != nil { + t.Logf("[%s]\n%s\n", test.sqlEquiv, rb.PrettyPrint()) + } + } + }) + } + }) + + t.Run("Limit Operator Integration", func(t *testing.T) { + sourceInput := map[string]any{ + "Operator": "Source", + "Source": map[string]any{ + "file-name": "country_full.csv", + "local": false, + }, + } + + limitTests := []IntegrationTest{ + { + name: "SELECT * FROM country LIMIT 10", + shouldError: false, + sqlEquiv: "SELECT * FROM country LIMIT 10", + logicalPlan: map[string]any{ + "Operator": "Limit", + "Limit": map[string]any{ + "input": sourceInput, + "limit": 10, + }, + }, + }, + { + name: "SELECT * FROM country LIMIT 1000000", + shouldError: true, + sqlEquiv: "SELECT * FROM country LIMIT 1000000", + logicalPlan: map[string]any{ + "Operator": "Limit", + "Limit": map[string]any{ + "input": sourceInput, + "limit": 1000000, + }, + }, + }, + { + name: "SELECT * FROM country LIMIT 1 - edge case minimum", + shouldError: false, + sqlEquiv: "SELECT * FROM country LIMIT 1 - edge case minimum", + logicalPlan: map[string]any{ + "Operator": "Limit", + "Limit": map[string]any{ + "input": sourceInput, + "limit": 1, + }, + }, + }, + { + name: "Limit with missing limit field - should fail", + shouldError: true, + sqlEquiv: "Limit with missing limit field - should fail", + logicalPlan: map[string]any{ + "Operator": "Limit", + "Limit": map[string]any{ + "input": sourceInput, + }, + }, + }, + { + name: 
"Limit with zero value - should fail", + shouldError: true, + sqlEquiv: "Limit with zero value - should fail", + logicalPlan: map[string]any{ + "Operator": "Limit", + "Limit": map[string]any{ + "input": sourceInput, + "limit": 0, + }, + }, + }, + } + + for _, test := range limitTests { + t.Run(test.sqlEquiv, func(t *testing.T) { + planMetaData := newPlanMetaData(test.sqlEquiv) + emitter, err := buildTree(test.logicalPlan, planMetaData) + + if (err != nil) != test.shouldError { + t.Errorf("buildTree error = %v, shouldError = %v", err, test.shouldError) + return + } + + if !test.shouldError && emitter != nil { + rb, err := emitter.emitOperator.Next(5) + if err != nil { + t.Errorf("Next() error = %v", err) + return + } + if rb != nil { + t.Logf("[%s]\n%s\n", test.sqlEquiv, rb.PrettyPrint()) + } + } + }) + } + }) + + t.Run("GroupBy Operator Integration", func(t *testing.T) { + sourceInput := map[string]any{ + "Operator": "Source", + "Source": map[string]any{ + "file-name": "country_full.csv", + "local": false, + }, + } + + groupByTests := []IntegrationTest{ + { + name: "SELECT region, COUNT(*) FROM country GROUP BY region", + shouldError: false, + sqlEquiv: "SELECT region, COUNT(*) FROM country GROUP BY region", + logicalPlan: map[string]any{ + "Operator": "GroupBy", + "GroupBy": map[string]any{ + "input": sourceInput, + "group_by": []map[string]any{ + { + "expr_type": "ColumnResolve", + "name": "region", + }, + }, + "aggrs": []map[string]any{ + { + "function": "Count", + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "name", + }, + }, + }, + }, + }, + }, + { + name: "SELECT region, sub_region, SUM(country_code) FROM country GROUP BY region, sub_region", + shouldError: false, + sqlEquiv: "SELECT region, sub_region, SUM(country_code) FROM country GROUP BY region, sub_region", + logicalPlan: map[string]any{ + "Operator": "GroupBy", + "GroupBy": map[string]any{ + "input": sourceInput, + "group_by": []map[string]any{ + { + "expr_type": "ColumnResolve", + 
"name": "region", + }, + { + "expr_type": "ColumnResolve", + "name": "sub-region", + }, + }, + "aggrs": []map[string]any{ + { + "function": "Sum", + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "country-code", + }, + }, + }, + }, + }, + }, + { + name: "SELECT region, COUNT(*), AVG(country_code), MIN(country_code) FROM country GROUP BY region", + shouldError: false, + sqlEquiv: "SELECT region, COUNT(*), AVG(country_code), MIN(country_code) FROM country GROUP BY region", + logicalPlan: map[string]any{ + "Operator": "GroupBy", + "GroupBy": map[string]any{ + "input": sourceInput, + "group_by": []map[string]any{ + { + "expr_type": "ColumnResolve", + "name": "region", + }, + }, + "aggrs": []map[string]any{ + { + "function": "Count", + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "name", + }, + }, + { + "function": "Avg", + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "country-code", + }, + }, + { + "function": "Min", + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "country-code", + }, + }, + }, + }, + }, + }, + { + name: "GroupBy with missing group_by field - should fail", + shouldError: true, + sqlEquiv: "GroupBy with missing group_by field - should fail", + logicalPlan: map[string]any{ + "Operator": "GroupBy", + "GroupBy": map[string]any{ + "input": sourceInput, + "aggrs": []map[string]any{ + { + "function": "Count", + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "name", + }, + }, + }, + }, + }, + }, + { + name: "GroupBy with empty aggrs array - should fail", + shouldError: true, + sqlEquiv: "GroupBy with empty aggrs array - should fail", + logicalPlan: map[string]any{ + "Operator": "GroupBy", + "GroupBy": map[string]any{ + "input": sourceInput, + "group_by": []map[string]any{ + { + "expr_type": "ColumnResolve", + "name": "region", + }, + }, + "aggrs": []map[string]any{}, + }, + }, + }, + } + + for _, test := range groupByTests { + t.Run(test.sqlEquiv, func(t 
*testing.T) { + planMetaData := newPlanMetaData(test.sqlEquiv) + emitter, err := buildTree(test.logicalPlan, planMetaData) + + if test.shouldError { + if err == nil { + t.Errorf("%s should have errored but recieved nil", test.name) + } + return + } + if err != nil { + t.Errorf("%s failed with error %v", test.name, err) + return + } + + rb, err := emitter.emitOperator.Next(5) + if err != nil { + t.Errorf("Next() error = %v", err) + return + } + if rb != nil { + t.Logf("[%s]\n%s\n", test.sqlEquiv, rb.PrettyPrint()) + } + }) + } + }) + + t.Run("Aggregate Operator Integration (Global)", func(t *testing.T) { + sourceInput := map[string]any{ + "Operator": "Source", + "Source": map[string]any{ + "file-name": "country_full.csv", + "local": false, + }, + } + + aggregateTests := []IntegrationTest{ + { + name: "SELECT SUM(country_code) FROM country", + shouldError: false, + sqlEquiv: "SELECT SUM(country_code) FROM country", + logicalPlan: map[string]any{ + "Operator": "Aggregate", + "Aggregate": map[string]any{ + "input": sourceInput, + "aggrs": []map[string]any{ + { + "function": "Sum", + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "country-code", + }, + }, + }, + }, + }, + }, + { + name: "SELECT COUNT(name) FROM country", + shouldError: false, + sqlEquiv: "SELECT COUNT(name) FROM country", + logicalPlan: map[string]any{ + "Operator": "Aggregate", + "Aggregate": map[string]any{ + "input": sourceInput, + "aggrs": []map[string]any{ + { + "function": "Count", + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "name", + }, + }, + }, + }, + }, + }, + { + name: "SELECT AVG(country_code) FROM country", + shouldError: false, + sqlEquiv: "SELECT AVG(country_code) FROM country", + logicalPlan: map[string]any{ + "Operator": "Aggregate", + "Aggregate": map[string]any{ + "input": sourceInput, + "aggrs": []map[string]any{ + { + "function": "Avg", + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "country-code", + }, + }, + }, + }, + 
}, + }, + { + name: "SELECT MIN(country_code), MAX(country_code) FROM country", + shouldError: false, + sqlEquiv: "SELECT MIN(country_code), MAX(country_code) FROM country", + logicalPlan: map[string]any{ + "Operator": "Aggregate", + "Aggregate": map[string]any{ + "input": sourceInput, + "aggrs": []map[string]any{ + { + "function": "Min", + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "country-code", + }, + }, + { + "function": "Max", + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "country-code", + }, + }, + }, + }, + }, + }, + { + name: "Aggregate with missing aggrs field - should fail", + shouldError: true, + sqlEquiv: "Aggregate with missing aggrs field - should fail", + logicalPlan: map[string]any{ + "Operator": "Aggregate", + "Aggregate": map[string]any{ + "input": sourceInput, + }, + }, + }, + { + name: "Aggregate with empty aggrs array - should fail", + shouldError: true, + sqlEquiv: "Aggregate with empty aggrs array - should fail", + logicalPlan: map[string]any{ + "Operator": "Aggregate", + "Aggregate": map[string]any{ + "input": sourceInput, + "aggrs": []map[string]any{}, + }, + }, + }, + } + + for _, test := range aggregateTests { + t.Run(test.sqlEquiv, func(t *testing.T) { + planMetaData := newPlanMetaData(test.sqlEquiv) + emitter, err := buildTree(test.logicalPlan, planMetaData) + + if test.shouldError { + if err == nil { + t.Errorf("%s should have errored but received nil", test.name) + } + return + } + if err != nil { + t.Errorf("%s failed with error %v", test.name, err) + return + } + + rb, err := emitter.emitOperator.Next(5) + if err != nil { + t.Errorf("Next() error = %v", err) + return + } + if rb != nil { + t.Logf("[%s]\n%s\n", test.sqlEquiv, rb.PrettyPrint()) + } + }) + } + }) + + t.Run("Join Operator Integration", func(t *testing.T) { + leftSource := map[string]any{ + "Operator": "Source", + "Source": map[string]any{ + "file-name": "user_test_data.csv", + "local": false, + }, + } + + rightSource := 
map[string]any{ + "Operator": "Source", + "Source": map[string]any{ + "file-name": "company_test_data.csv", + "local": false, + }, + } + + joinTests := []IntegrationTest{ + { + name: "SELECT * FROM users JOIN companies ON users.id = companies.id", + shouldError: false, + sqlEquiv: "SELECT * FROM users JOIN companies ON users.id = companies.id", + logicalPlan: map[string]any{ + "Operator": "Join", + "Join": map[string]any{ + "left": leftSource, + "right": rightSource, + "join_type": "Inner", + "on": []map[string]any{ + { + "left": map[string]any{ + "expr_type": "ColumnResolve", + "name": "id", + }, + "right": map[string]any{ + "expr_type": "ColumnResolve", + "name": "id", + }, + }, + }, + }, + }, + }, + { + name: "Join with Filter on users.age_years > 25", + shouldError: false, + sqlEquiv: "SELECT * FROM users JOIN companies ON users.id = companies.id WHERE age_years > 25", + logicalPlan: map[string]any{ + "Operator": "Filter", + "Filter": map[string]any{ + "input": map[string]any{ + "Operator": "Join", + "Join": map[string]any{ + "left": leftSource, // user data + "right": rightSource, // company data + "join_type": "Inner", + "on": []map[string]any{ + { + "left": map[string]any{ + "expr_type": "ColumnResolve", + "name": "id", + }, + "right": map[string]any{ + "expr_type": "ColumnResolve", + "name": "id", + }, + }, + }, + }, + }, + "expression": map[string]any{ + "expr_type": "BinaryExpr", + "op": "GreaterThan", + "left": map[string]any{ + "expr_type": "ColumnResolve", + "name": "age_years", + }, + "right": map[string]any{ + "expr_type": "LiteralResolve", + "value": 25, + "lit_type": "int", + }, + }, + }, + }, + }, + { + name: "Join with Sort on username", + shouldError: false, + sqlEquiv: "SELECT * FROM users JOIN companies ON users.id = companies.id ORDER BY username", + logicalPlan: map[string]any{ + "Operator": "Sort", + "Sort": map[string]any{ + "input": map[string]any{ + "Operator": "Join", + "Join": map[string]any{ + "left": leftSource, + "right": 
rightSource, + "join_type": "Inner", + "on": []map[string]any{ + { + "left": map[string]any{ + "expr_type": "ColumnResolve", + "name": "id", + }, + "right": map[string]any{ + "expr_type": "ColumnResolve", + "name": "id", + }, + }, + }, + }, + }, + "by": []map[string]any{ + { + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "username", + }, + "asc": true, + }, + }, + }, + }, + }, + { + name: "Join with missing left field - should fail", + shouldError: true, + sqlEquiv: "Join with missing left field - should fail", + logicalPlan: map[string]any{ + "Operator": "Join", + "Join": map[string]any{ + "right": rightSource, + "join_type": "Inner", + "on": []map[string]any{ + { + "left": map[string]any{ + "expr_type": "ColumnResolve", + "name": "id", + }, + "right": map[string]any{ + "expr_type": "ColumnResolve", + "name": "id", + }, + }, + }, + }, + }, + }, + { + name: "Join with missing on field - should fail", + shouldError: true, + sqlEquiv: "Join with missing on field - should fail", + logicalPlan: map[string]any{ + "Operator": "Join", + "Join": map[string]any{ + "left": leftSource, + "right": rightSource, + "join_type": "Inner", + }, + }, + }, + } + + for _, test := range joinTests { + t.Run(test.sqlEquiv, func(t *testing.T) { + planMetaData := newPlanMetaData(test.sqlEquiv) + emitter, err := buildTree(test.logicalPlan, planMetaData) + + if test.shouldError { + if err == nil { + t.Errorf("%s should have errored but received nil", test.name) + } + return + } + if err != nil { + t.Errorf("%s failed with error %v", test.name, err) + return + } + + rb, err := emitter.emitOperator.Next(5) + if err != nil { + t.Errorf("Next() error = %v", err) + return + } + if rb != nil { + t.Logf("[%s]\n%s\n", test.sqlEquiv, rb.PrettyPrint()) + } + }) + } + }) +} +func TestSubstraitFilesBasic(t *testing.T) { + basePath := filepath.Join("..", "..", "test_data", "substrait_plans", "basic") + + basicFileTests := []FileIntegrationTest{ + { + name: "basic__00_test.json", + 
shouldError: false, + filePath: filepath.Join(basePath, "basic_00_test.json"), + sqlEquiv: "tbd", + }, + + { + name: "basic_01_filter_project_sort.json", + shouldError: false, + filePath: filepath.Join(basePath, "basic_01_source_filter.json"), + sqlEquiv: "tbd", + }, + { + name: "basic_02_project.json", + shouldError: false, + filePath: filepath.Join(basePath, "basic_02_project.json"), + sqlEquiv: "", + }, + { + name: "basic_03_sort.json", + shouldError: false, + filePath: filepath.Join(basePath, "basic_03_sort.json"), + sqlEquiv: "tbd", + }, + { + name: "basic_04_distinct.json", + shouldError: false, + filePath: filepath.Join(basePath, "basic_04_distinct.json"), + sqlEquiv: "tbd", + }, + { + name: "basic_05_limit.json", + shouldError: false, + filePath: filepath.Join(basePath, "basic_05_limit.json"), + sqlEquiv: "tbd", + }, + { + name: "basic_06_aggr.json", + shouldError: false, + filePath: filepath.Join(basePath, "basic_06_aggr.json"), + sqlEquiv: "tbd", + }, + } + for _, test := range basicFileTests { + t.Run(test.name, func(t *testing.T) { + file, err := os.Open(test.filePath) + if err != nil { + t.Logf("Skipping %s: file not found err :%v \n", test.name, err) + return + } + defer func() { + if err := file.Close(); err != nil { + t.Logf("error closing file:\t%v\n", err) + } + }() + + emitter, err := ConsumeSubstraitPlan(file) + + if (err != nil) != test.shouldError { + t.Errorf("ConsumeSubstraitPlan error = %v, shouldError = %v", err, test.shouldError) + return + } + + if !test.shouldError && emitter != nil { + rb, err := emitter.emitOperator.Next(5) + if err != nil { + t.Errorf("Next() error = %v", err) + return + } + if rb != nil { + t.Logf("[%s - %s]\n%s\n", test.sqlEquiv, test.name, rb.PrettyPrint()) + } + } + }) + } +} + +// TestSubstraitFilesMedium tests reading and executing medium-complexity Substrait plans from JSON files +func TestSubstraitFilesMedium(t *testing.T) { + basePath := filepath.Join("..", "..", "test_data", "substrait_plans", "medium") + + 
mediumFileTests := []FileIntegrationTest{ + { + name: "mid_01_filter_project_sort.json", + shouldError: false, + filePath: filepath.Join(basePath, "mid_01_filter_project_sort.json"), + sqlEquiv: "tbd", + }, + { + name: "mid_02_group_by_aggregate.json", + shouldError: false, + filePath: filepath.Join(basePath, "mid_02_group_by_aggregate.json"), + sqlEquiv: "tbd", + }, + { + name: "mid_03_join_filter.json", + shouldError: false, + filePath: filepath.Join(basePath, "mid_03_join_filter.json"), + sqlEquiv: "tbd", + }, + { + name: "mid_04_join_sort_limit.json", + shouldError: false, + filePath: filepath.Join(basePath, "mid_04_join_sort_limit.json"), + sqlEquiv: "tbd", + }, + } + + for _, test := range mediumFileTests { + t.Run(test.name, func(t *testing.T) { + file, err := os.Open(test.filePath) + if err != nil { + t.Logf("Skipping %s: file not found err :%v \n", test.name, err) + return + } + defer func() { + if err := file.Close(); err != nil { + t.Logf("error closing file:\t%v\n", err) + } + }() + + emitter, err := ConsumeSubstraitPlan(file) + + if (err != nil) != test.shouldError { + t.Errorf("ConsumeSubstraitPlan error = %v, shouldError = %v", err, test.shouldError) + return + } + + if !test.shouldError && emitter != nil { + rb, err := emitter.emitOperator.Next(5) + if err != nil { + t.Errorf("Next() error = %v", err) + return + } + if rb != nil { + t.Logf("[%s - %s]\n%s\n", test.sqlEquiv, test.name, rb.PrettyPrint()) + } + } + }) + } +} + +func TestSubstraitRegression(t *testing.T) { + basePath := filepath.Join("..", "..", "test_data", "base64-encoding") + tests := []struct { + id int + testName string + fileName string + }{ + /*{ + id: 1, + testName: "select id,name,age from employees where id > 5", + fileName: "select-filter.txt", + },*/ + { + id: 2, + testName: "SELECT name, age, salary FROM employees WHERE age > 30", + fileName: "select-filter-2.txt", + }, + } + for _, testObj := range tests { + t.Run(testObj.testName, func(t *testing.T) { + testID := 
testObj.id + fmt.Printf("%v\n", testID) + f, err := os.Open(filepath.Join(basePath, testObj.fileName)) + if err != nil { + t.Fatalf("failed to open %s, recieved this error: %v", testObj.fileName, err) + } + base64Content, err := io.ReadAll(f) + if err != nil { + t.Fatalf("failed to read file contents, recieved this error: %v", err) + } + fmt.Println("size of b64 content:", len(base64Content)) + decodedPlan, err := base64.StdEncoding.DecodeString(string(base64Content)) + if err != nil { + t.Fatalf("failed to base64 decode logical plan: %v", err) + } + source := strings.NewReader(string(decodedPlan)) + results, err := consumePlan(source, &planMetaData{id: testObj.testName}) + if err != nil { + t.Fatalf("error consuming logical plan: %v", err) + } + rc, err := results.consumeAll() + if err != nil { + t.Fatalf("error consuming all results: %v", err) + } + csv, err := rc.ToCSV() + fmt.Printf("csv content:\n%s", csv) + }) + } +} + +// ConsumeSubstraitPlan reads a Substrait plan from an io.Reader and returns an Emitter +func ConsumeSubstraitPlan(reader io.Reader) (*Emiter, error) { + // Read the JSON from the reader + data, err := io.ReadAll(reader) + if err != nil { + return nil, fmt.Errorf("failed to read plan: %w", err) + } + + // Unmarshal into a map + var planMap jsonOBJ + err = json.Unmarshal(data, &planMap) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal plan: %w", err) + } + + // The plan should have an "Emit" key containing the operator tree + if _, ok := planMap["Emit"]; !ok { + return nil, fmt.Errorf("plan missing 'Emit' key") + } + + emitObj := planMap["Emit"].(map[string]any) + + // Create plan metadata - use the first (and typically only) operator as ID + var planID string + for key := range emitObj { + planID = key + break + } + + planMetaData := newPlanMetaData(planID) + + // Build the tree starting from the Emit operator + emitter, err := buildTree(emitObj, planMetaData) + if err != nil { + return nil, fmt.Errorf("failed to build tree: 
%w", err) + } + + return emitter, nil +} diff --git a/src/Backend/opti-sql-go/substrait/substrait_test.go b/src/Backend/opti-sql-go/substrait/substrait_test.go index fe23790..fb0c40b 100644 --- a/src/Backend/opti-sql-go/substrait/substrait_test.go +++ b/src/Backend/opti-sql-go/substrait/substrait_test.go @@ -2,10 +2,47 @@ package substrait import ( "context" + "fmt" + "math" "net" + "opti-sql-go/Expr" + "os" + "path/filepath" + "strings" "testing" + "time" + + "github.com/apache/arrow/go/v17/arrow" ) +func testCleanUp() { + + // Get current directory + curDir, err := os.Getwd() + if err != nil { + fmt.Printf("Failed to get current directory: %v\n", err) + } + + // Read directory contents + entries, err := os.ReadDir(curDir) + if err != nil { + fmt.Printf("Failed to read directory: %v\n", err) + } + + // Delete all files containing .csv in their name + for _, entry := range entries { + if !entry.IsDir() && strings.Contains(entry.Name(), ".csv") { + filePath := fmt.Sprintf("%s/%s", curDir, entry.Name()) + err := os.Remove(filePath) + if err != nil { + fmt.Printf("error removing %s: %v\n", entry.Name(), err) + } else { + fmt.Printf("deleted: %s\n", entry.Name()) + } + } + } +} + func TestInitServer(t *testing.T) { // Simple passing test l, err := net.Listen("tcp", "0.0.0.0:1212") @@ -29,13 +66,9 @@ func TestDummyInput(t *testing.T) { t.Errorf("Expected non-nil Substrait server") } dummyRequest := &QueryExecutionRequest{ - SqlStatement: "SELECT * FROM table", - SubstraitLogical: []byte("CgJTUxIMCgpTZWxlY3QgKiBGUk9NIHRhYmxl"), - Id: "GenerateDTMoneyOHaasdavdasvasdvada", - Source: &SourceType{ - S3Source: "s3://my-bucket/data/table.parquet", - Mime: "application/vnd.apache.parquet", - }, + SqlStatement: "select * from table1 , asc ", + LogicalPlan: 
"ewogICAgIkVtaXQiOiAKICAgIHsKICAgICAgICAiT3BlcmF0b3IiOiAiU29ydCIsCiAgICAgICAgIlNvcnQiOiAKICAgICAgICB7CiAgICAgICAgICAgICJpbnB1dCI6IAogICAgICAgICAgICB7CiAgICAgICAgICAgICAgICAiT3BlcmF0b3IiOiAiU291cmNlIiwKICAgICAgICAgICAgICAgICJTb3VyY2UiOiAKICAgICAgICAgICAgICAgIHsKICAgICAgICAgICAgICAgICAgICAiZmlsZS1uYW1lIjogInVzZXJfdGVzdF9kYXRhLmNzdiIsCiAgICAgICAgICAgICAgICAgICAgImxvY2FsIjogZmFsc2UKICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgfSwKICAgICAgICAgICAgImJ5IjogCiAgICAgICAgICAgIFsKICAgICAgICAgICAgICAgIHsKICAgICAgICAgICAgICAgICAgICAiZXhwciI6IAogICAgICAgICAgICAgICAgICAgIHsKICAgICAgICAgICAgICAgICAgICAgICAgImV4cHJfdHlwZSI6ICJDb2x1bW5SZXNvbHZlIiwKICAgICAgICAgICAgICAgICAgICAgICAgIm5hbWUiOiAidXNlcm5hbWUiCiAgICAgICAgICAgICAgICAgICAgfSwKICAgICAgICAgICAgICAgICAgICAiYXNjIjogdHJ1ZQogICAgICAgICAgICAgICAgfQogICAgICAgICAgICBdCiAgICAgICAgfQogICAgfQp9", + Id: "97b61a8f-ffe1-4e4a-b6d7-73619698dc7a", } resp, err := ss.ExecuteQuery(context.Background(), dummyRequest) if err != nil { @@ -53,3 +86,2934 @@ func TestStartServer(t *testing.T) { } } + +// Plan parsing +const customIRPath = "../../test_data/substrait_plans/basic" + +func TestSubstraitPlanExist(t *testing.T) { + e, err := os.ReadDir(customIRPath) + if err != nil { + t.Fatalf("failed to open dir with error: %v\n", e) + } + for entries, name := range e { + t.Logf("entrie[%v]:\t%v\n", entries, name) + } + +} + +func TestSubstraitSourceParse(t *testing.T) { + t.Run("source parse test", func(t *testing.T) { + tests := []struct { + name string + fileName string + local bool + wantError bool + }{ + { + name: "csv file", + fileName: "country_full.csv", + local: true, + wantError: false, + }, + { + name: "parquet file with local true", + fileName: "userdata.parquet", + local: true, + wantError: false, + }, + { + name: "parquet file with local false", + fileName: "userdata.parquet", + local: false, + wantError: false, + }, + } + curDir, err := os.Getwd() + if err != nil { + t.Fatalf("failed to observe current working directory") + } + id := 
"richards-test-substrait-22" + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sourceObj := jsonOBJ{ + "file-name": tt.fileName, + "local": tt.local, + } + + op, err := parseSource(sourceObj, newPlanMetaData(id)) + if (err != nil) != tt.wantError { + t.Errorf("parseSource() error = %v, wantError %v", err, tt.wantError) + return + } + if !tt.wantError && op == nil { + t.Errorf("parseSource() returned nil operator when error was nil") + } + if tt.local { + path := fmt.Sprintf("%s/%s-%s", curDir, tt.fileName, id) + t.Logf("attempting to remove %s from path\n", path) + if err := os.Remove(path); err != nil { + t.Errorf("test:%s\n failed to delete %s from file system \n", tt.name, tt.fileName) + } + + } + }) + } + }) +} +func TestExpressionsParse(t *testing.T) { + // (1) all required fields exist + // (2) fields contain valid set of values (important for scalar functions and binary expr) + correctExpr := func(e Expr.Expression, wantedExpr string) bool { + switch e.(type) { + case *Expr.Alias: + return wantedExpr == "Alias" + case *Expr.ColumnResolve: + return wantedExpr == "ColumnResolve" + case *Expr.LiteralResolve: + return wantedExpr == "LiteralResolve" + case *Expr.BinaryExpr: + return wantedExpr == "BinaryExpr" + case *Expr.ScalarFunction: + return wantedExpr == "ScalarFunction" + case *Expr.CastExpr: + return wantedExpr == "CastExpr" + case *Expr.NullCheckExpr: + return wantedExpr == "NullCheckExpr" + default: + return false + } + } + t.Run("Column Resolve Test", func(t *testing.T) { + test := []struct { + testName string + jsonBody jsonOBJ + expectedColumn string + wantedExpreStr string + expectedError bool + }{ + { + testName: "basic column resolve", + jsonBody: map[string]any{ + "expr_type": "ColumnResolve", + "name": "a", + }, + expectedColumn: "a", + wantedExpreStr: "ColumnResolve", + expectedError: false, + }, + { + testName: "column resolve with extra fields (ignored)", + jsonBody: map[string]any{ + "expr_type": "ColumnResolve", + 
"name": "user_id", + "junk": "should be ignored", + "num": 123, + }, + expectedColumn: "user_id", + wantedExpreStr: "ColumnResolve", + expectedError: false, + }, + { + testName: "missing name field", + jsonBody: map[string]any{ + "expr_type": "ColumnResolve", + }, + expectedColumn: "", + wantedExpreStr: "ColumnResolve", + expectedError: true, + }, + { + testName: "name is wrong type (number)", + jsonBody: map[string]any{ + "expr_type": "ColumnResolve", + "name": 123, + }, + expectedColumn: "", + wantedExpreStr: "ColumnResolve", + expectedError: true, + }, + { + testName: "name is empty string", + jsonBody: map[string]any{ + "expr_type": "ColumnResolve", + "name": "", + }, + expectedColumn: "", + wantedExpreStr: "ColumnResolve", + expectedError: true, + }, + { + testName: "expr_type wrong / missing (should fail)", + jsonBody: map[string]any{ + // "expr_type": "ColumnResolve", + "name": "a", + }, + expectedColumn: "", + wantedExpreStr: "ColumnResolve", + expectedError: true, + }, + } + for _, tt := range test { + t.Run(tt.testName, func(t *testing.T) { + expr, err := parseExpression(tt.jsonBody) + if tt.expectedError { + if err == nil { + t.Fatalf("%s did not fail when expected to do so", tt.testName) + } + // Expected to fail and it failed -> test passes for this case. 
+ return + } + if err != nil { + t.Fatalf("%s failed with error %v\n", tt.testName, err) + } + if !correctExpr(expr, tt.wantedExpreStr) { + t.Errorf("%s recieved the incorrect expression, expected %v but recieved %v\n", tt.testName, tt.wantedExpreStr, expr) + } + cr, _ := expr.(*Expr.ColumnResolve) + if cr.Name != tt.expectedColumn { + t.Errorf("%s has incorrect column resolve name, expected %s but recieved %v\n", tt.testName, tt.expectedColumn, cr.Name) + } + }) + } + // one for each type of accepted expression + }) + t.Run("Literal Resolve Test", func(t *testing.T) { + const exprName = "LiteralResolve" + test := []struct { + testName string + jsonBody jsonOBJ + expectedValue any + expectedType arrow.DataType + wantedExpreStr string + expectedError bool + }{ + { + testName: "basic Literal Resolve", + jsonBody: map[string]any{ + "expr_type": "LiteralResolve", + "value": 10, + "lit_type": "int", + }, + expectedValue: int64(10), + expectedType: arrow.PrimitiveTypes.Int64, + wantedExpreStr: exprName, + expectedError: false, + }, + { + testName: "string literal", + jsonBody: map[string]any{ + "expr_type": "LiteralResolve", + "value": "hello", + "lit_type": "string", + }, + expectedValue: "hello", + expectedType: arrow.BinaryTypes.String, + wantedExpreStr: exprName, + expectedError: false, + }, + { + testName: "boolean literal true", + jsonBody: map[string]any{ + "expr_type": "LiteralResolve", + "value": true, + "lit_type": "boolean", + }, + expectedValue: true, + expectedType: arrow.FixedWidthTypes.Boolean, + wantedExpreStr: exprName, + expectedError: false, + }, + { + testName: "float64 literal", + jsonBody: map[string]any{ + "expr_type": "LiteralResolve", + "value": 3.14159, + "lit_type": "float64", + }, + expectedValue: 3.14159, + expectedType: arrow.PrimitiveTypes.Float64, + wantedExpreStr: exprName, + expectedError: false, + }, + { + testName: "missing required field (lit_type)", + jsonBody: map[string]any{ + "expr_type": "LiteralResolve", + "value": 10, + }, + 
expectedValue: nil, + expectedType: nil, + wantedExpreStr: exprName, + expectedError: true, + }, + { + testName: "invalid lit_type value", + jsonBody: map[string]any{ + "expr_type": "LiteralResolve", + "value": 10, + "lit_type": "int64", // not supported by your switch + }, + expectedValue: nil, + expectedType: nil, + wantedExpreStr: exprName, + expectedError: true, + }, + } + for _, tt := range test { + t.Run(tt.testName, func(t *testing.T) { + expr, err := parseExpression(tt.jsonBody) + if tt.expectedError { + if err == nil { + t.Fatalf("%s did not fail when expected to do so", tt.testName) + + } + return + } + if err != nil { + t.Fatalf("%s failed with unexpected error %v\n", tt.testName, err) + } + if !correctExpr(expr, "LiteralResolve") { + t.Errorf("%s recieved the incorrect expression, expected %v but recieved %v\n", tt.testName, tt.wantedExpreStr, expr) + + } + lr, _ := expr.(*Expr.LiteralResolve) + if lr.Value != tt.expectedValue { + t.Fatalf("%s received incorrect value: expected (%T) %v, got (%T) %v", + tt.testName, tt.expectedValue, tt.expectedValue, lr.Value, lr.Value, + ) + } + }) + } + + // one for each type of accepted expression + }) + t.Run("BinaryExpr Test", func(t *testing.T) { + const exprName = "ScalarFunction" + + validVariants := []string{ + "Addition", + "Subtraction", + "Multiplication", + "Division", + "Equal", + "NotEqual", + "LessThan", + "LessThanOrEqual", + "GreaterThan", + "GreaterThanOrEqual", + "And", + "Or", + "Like", + } + + // Helper to keep JSON bodies consistent and small. 
+	mkBinary := func(op string) jsonOBJ {
+		return map[string]any{
+			"expr_type": "BinaryExpr",
+			"op":        op,
+			"left": map[string]any{
+				"expr_type": "ColumnResolve",
+				"name":      "a",
+			},
+			"right": map[string]any{
+				"expr_type": "ColumnResolve",
+				"name":      "b",
+			},
+		}
+	}
+
+	test := []struct {
+		testName      string
+		jsonBody      jsonOBJ
+		operator      string
+		expectedError bool
+	}{}
+
+	// --- Generate one passing test per valid operator variant ---
+	for _, op := range validVariants {
+		test = append(test, struct {
+			testName      string
+			jsonBody      jsonOBJ
+			operator      string
+			expectedError bool
+		}{
+			testName:      "operator propagates: " + op,
+			jsonBody:      mkBinary(op),
+			operator:      op,
+			expectedError: false,
+		})
+	}
+	test = append(test, struct {
+		testName      string
+		jsonBody      jsonOBJ
+		operator      string
+		expectedError bool
+	}{
+		testName:      "Empty Operator",
+		jsonBody:      mkBinary(""),
+		operator:      "",
+		expectedError: true,
+	},
+	)
+	test = append(test, struct {
+		testName      string
+		jsonBody      jsonOBJ
+		operator      string
+		expectedError bool
+	}{
+		testName:      "non-existent Operator",
+		jsonBody:      mkBinary("matrixMultiply"),
+		operator:      "matrixMultiply",
+		expectedError: true,
+	},
+	)
+
+	for _, tt := range test {
+		t.Run(tt.testName, func(t *testing.T) {
+			expr, err := parseExpression(tt.jsonBody)
+			if tt.expectedError {
+				if err == nil {
+					t.Fatalf("%s did not fail when expected to do so", tt.testName)
+				}
+				return
+			}
+			if err != nil {
+				t.Fatalf("%s failed with unexpected error %v\n", tt.testName, err)
+			}
+			if !correctExpr(expr, "BinaryExpr") {
+				t.Errorf("%s received the incorrect expression, expected %v but received %v of type %T\n", tt.testName, exprName, expr, expr)
+
+			}
+			BinaryExpr, _ := expr.(*Expr.BinaryExpr)
+			if !Expr.MatchesBinaryOperator(tt.operator, int(BinaryExpr.Op)) {
+				t.Errorf("%s mismatch between expected operator (%s) and the received operator (%v)", tt.testName, tt.operator, BinaryExpr.Op)
+			}
+
+		})
+	}
+	// one for each type of accepted expression
+	})
+
+	
t.Run("Scalar Function Test", func(t *testing.T) { + const exprName = "ScalarFunction" + + test := []struct { + testName string + jsonBody jsonOBJ + expectedFunc string + expectedError bool + }{ + // ---- VALID ---- + { + testName: "Upper is valid", + jsonBody: map[string]any{ + "expr_type": "ScalarFunction", + "func": "Upper", + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "name", + }, + }, + expectedFunc: "Upper", + expectedError: false, + }, + { + testName: "Lower is valid", + jsonBody: map[string]any{ + "expr_type": "ScalarFunction", + "func": "Lower", + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "name", + }, + }, + expectedFunc: "Lower", + expectedError: false, + }, + { + testName: "Abs is valid", + jsonBody: map[string]any{ + "expr_type": "ScalarFunction", + "func": "Abs", + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "name", + }, + }, + expectedFunc: "Abs", + expectedError: false, + }, + { + testName: "Round is valid", + jsonBody: map[string]any{ + "expr_type": "ScalarFunction", + "func": "Round", + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "name", + }, + }, + expectedFunc: "Round", + expectedError: false, + }, + + // ---- INVALID ---- + { + testName: "invalid scalar function name", + jsonBody: map[string]any{ + "expr_type": "ScalarFunction", + "func": "NotARealFunc", + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "name", + }, + }, + expectedFunc: "", + expectedError: true, + }, + { + testName: "missing func field", + jsonBody: map[string]any{ + "expr_type": "ScalarFunction", + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "name", + }, + }, + expectedFunc: "", + expectedError: true, + }, + { + testName: "func wrong type", + jsonBody: map[string]any{ + "expr_type": "ScalarFunction", + "func": 123, + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "name", + }, + }, + expectedFunc: "", + expectedError: true, + }, + 
{
+			testName: "missing expr field",
+			jsonBody: map[string]any{
+				"expr_type": "ScalarFunction",
+				"func":      "Upper",
+			},
+			expectedFunc:  "",
+			expectedError: true,
+		},
+	}
+
+	for _, tt := range test {
+		t.Run(tt.testName, func(t *testing.T) {
+			expr, err := parseExpression(tt.jsonBody)
+
+			if tt.expectedError {
+				if err == nil {
+					t.Fatalf("%s did not fail when expected to do so", tt.testName)
+				}
+				return
+			}
+
+			if err != nil {
+				t.Fatalf("%s failed with unexpected error %v", tt.testName, err)
+			}
+
+			if !correctExpr(expr, exprName) {
+				t.Fatalf("%s received incorrect expression, expected %s but received %T",
+					tt.testName, exprName, expr,
+				)
+			}
+
+			sf, ok := expr.(*Expr.ScalarFunction)
+			if !ok {
+				t.Fatalf("%s expected *Expr.ScalarFunction but received %T", tt.testName, expr)
+			}
+
+			// The parsed function is stored in sf.Function; compare against the expected variant.
+			if sf.Function != Expr.FnToScalarFunction(tt.expectedFunc) {
+				t.Fatalf("%s received incorrect scalar function, expected %q but received %q",
+					tt.testName, tt.expectedFunc, sf.Function,
+				)
+			}
+
+		})
+	}
+	})
+
+	t.Run("Alias Test", func(t *testing.T) {
+		const exprName = "Alias"
+		tests := []struct {
+			testName    string
+			jsonBody    jsonOBJ
+			aliasName   string
+			expectError bool
+		}{
+			{testName: "basic alias",
+				jsonBody: map[string]any{
+					"expr_type": "Alias",
+					"name":      "new_name",
+					"expr": map[string]any{
+						"expr_type": "ColumnResolve",
+						"name":      "first_column",
+					},
+				},
+				aliasName:   "new_name",
+				expectError: false,
+			},
+			{
+				testName: "alias with different name",
+				jsonBody: map[string]any{
+					"expr_type": "Alias",
+					"name":      "alias_1",
+					"expr": map[string]any{
+						"expr_type": "ColumnResolve",
+						"name":      "col_a",
+					},
+				},
+				aliasName:   "alias_1",
+				expectError: false,
+			},
+			{
+				testName: "missing alias name field",
+				jsonBody: map[string]any{
+					"expr_type": "Alias",
+					"expr": map[string]any{
+						"expr_type": "ColumnResolve",
+						"name":      "first_column",
+					},
+				},
+				aliasName:   "",
+				expectError: true,
+			
}, + { + testName: "alias name wrong type", + jsonBody: map[string]any{ + "expr_type": "Alias", + "name": 123, + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "first_column", + }, + }, + aliasName: "", + expectError: true, + }, + { + testName: "missing expr field", + jsonBody: map[string]any{ + "expr_type": "Alias", + "name": "new_name", + }, + aliasName: "", + expectError: true, + }, + } + for _, tt := range tests { + t.Run(tt.testName, func(t *testing.T) { + expr, err := parseExpression(tt.jsonBody) + if tt.expectError { + if err == nil { + t.Fatalf("%s did not fail when expected to do so", tt.testName) + } + return + } + if err != nil { + t.Fatalf("%s failed with unexpected error %v", tt.testName, err) + } + if !correctExpr(expr, exprName) { + t.Fatalf("%s received incorrect expression, expected %s but received %T", + tt.testName, exprName, expr, + ) + + } + + alias, ok := expr.(*Expr.Alias) + if !ok { + t.Fatalf("%s expected *Expr.Alias but received %T", tt.testName, expr) + } + if alias.Name != tt.aliasName { + t.Fatalf("%s recieved incorrect alias name, expected %s but recieved %s\n", tt.testName, tt.aliasName, alias.Name) + } + + }) + } + + }) + t.Run("CastExpr Test", func(t *testing.T) { + const exprName = "CastExpr" + + tests := []struct { + testName string + jsonBody jsonOBJ + expectedToType string + expectedError bool + }{ + // ---- VALID to_type ---- + { + testName: "cast to int is valid", + jsonBody: map[string]any{ + "expr_type": "CastExpr", + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "a", + }, + "to_type": "int", + }, + expectedToType: "int", + expectedError: false, + }, + { + testName: "cast to string is valid", + jsonBody: map[string]any{ + "expr_type": "CastExpr", + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "a", + }, + "to_type": "string", + }, + expectedToType: "string", + expectedError: false, + }, + { + testName: "cast to boolean is valid", + jsonBody: map[string]any{ + 
"expr_type": "CastExpr", + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "a", + }, + "to_type": "boolean", + }, + expectedToType: "boolean", + expectedError: false, + }, + { + testName: "cast to float64 is valid", + jsonBody: map[string]any{ + "expr_type": "CastExpr", + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "a", + }, + "to_type": "float64", + }, + expectedToType: "float64", + expectedError: false, + }, + + // ---- INVALID to_type / malformed ---- + { + testName: "invalid to_type value", + jsonBody: map[string]any{ + "expr_type": "CastExpr", + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "a", + }, + "to_type": "int64", + }, + expectedToType: "", + expectedError: true, + }, + { + testName: "missing to_type field", + jsonBody: map[string]any{ + "expr_type": "CastExpr", + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "a", + }, + }, + expectedToType: "", + expectedError: true, + }, + { + testName: "to_type wrong type", + jsonBody: map[string]any{ + "expr_type": "CastExpr", + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "a", + }, + "to_type": 123, + }, + expectedToType: "", + expectedError: true, + }, + { + testName: "missing expr field", + jsonBody: map[string]any{ + "expr_type": "CastExpr", + "to_type": "float64", + }, + expectedToType: "", + expectedError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.testName, func(t *testing.T) { + expr, err := parseExpression(tt.jsonBody) + + if tt.expectedError { + if err == nil { + t.Fatalf("%s did not fail when expected to do so", tt.testName) + } + return + } + + if err != nil { + t.Fatalf("%s failed with unexpected error %v", tt.testName, err) + } + + if !correctExpr(expr, exprName) { + t.Fatalf("%s received incorrect expression, expected %s but received %T", + tt.testName, exprName, expr, + ) + } + + _, ok := expr.(*Expr.CastExpr) + if !ok { + t.Fatalf("%s expected *Expr.CastExpr but received %T", 
tt.testName, expr)
+			}
+		})
+	}
+	})
+}
+
+// TODO: update JSON bodies
+func TestSubstraitProjectParse(t *testing.T) {
+	source1 := map[string]any{
+		"Operator": "Source",
+		"Source": map[string]any{
+			"file-name": "country_full.csv",
+			"local":     false,
+		},
+	}
+	source2 := map[string]any{
+		"Operator": "Source",
+		"Source": map[string]any{
+			"file-name": "fortune1000_2024.csv",
+			"local":     false,
+		},
+	}
+	projectTestID := "project parse test"
+
+	t.Run("basic project operations", func(t *testing.T) {
+		lpMetaData := newPlanMetaData(projectTestID)
+
+		tests := []struct {
+			id          int
+			testName    string
+			logicalPlan jsonOBJ
+			expectError bool
+		}{
+			{
+				testName: "project all columns",
+				id:       1,
+				logicalPlan: map[string]any{
+					"input": source1,
+					"expressions": []map[string]any{
+						{
+							"expr_type": "ColumnResolve",
+							"name":      "name",
+						},
+						{
+							"expr_type": "ColumnResolve",
+							"name":      "country-code",
+						},
+					},
+				},
+				expectError: false,
+			},
+			{
+				testName: "project some columns",
+				id:       2,
+				logicalPlan: map[string]any{
+					"input": source2,
+					"expressions": []map[string]any{
+						{
+							"expr_type": "ColumnResolve",
+							"name":      "Company",
+						},
+					},
+				},
+				expectError: false,
+			},
+			{
+				testName: "project zero columns (should fail)",
+				id:       1,
+				logicalPlan: map[string]any{
+					"input":       source1,
+					"expressions": []map[string]any{},
+				},
+				expectError: true,
+			},
+		}
+
+		for _, tt := range tests {
+			t.Run(tt.testName, func(t *testing.T) {
+				_, err := parseProject(tt.logicalPlan, lpMetaData)
+
+				if tt.expectError {
+					if err == nil {
+						t.Fatalf("%s did not fail when expected to do so", tt.testName)
+					}
+					return
+				}
+				if err != nil {
+					t.Fatalf("unexpected error %v", err)
+				}
+			})
+		}
+	})
+
+	t.Run("parsing alias in project", func(t *testing.T) {
+		lpMetaData := newPlanMetaData(projectTestID)
+
+		tests := []struct {
+			testName    string
+			logicalPlan jsonOBJ
+			expectError bool
+		}{
+			{
+				testName: "provide alias for all columns",
+				logicalPlan: map[string]any{
+					"input": source1,
+					"expressions": 
[]map[string]any{ + { + "expr_type": "Alias", + "name": "country_name", + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "name", + }, + }, + { + "expr_type": "Alias", + "name": "cc", + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "country-code", + }, + }, + }, + }, + expectError: false, + }, + { + testName: "provide alias no columns", + logicalPlan: map[string]any{ + "input": source2, + "expressions": []map[string]any{}, + }, + expectError: true, + }, + { + testName: "provide alias for some columns", + logicalPlan: map[string]any{ + "input": source2, + "expressions": []map[string]any{ + { + "expr_type": "Alias", + "name": "company_name", + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "Company", + }, + }, + { + // no alias on this one (mix alias + plain column) + "expr_type": "ColumnResolve", + "name": "Country", + }, + }, + }, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.testName, func(t *testing.T) { + proj, err := parseProject(tt.logicalPlan, lpMetaData) + + if tt.expectError { + if err == nil { + t.Fatalf("%s did not fail when expected to do so", tt.testName) + } + return + } + if err != nil { + t.Fatalf("unexpected error %v", err) + } + + // optional: sanity check the operator runs + basicBatch, err := proj.Next(5) + if err != nil { + t.Fatalf("unexpected Next() error %v", err) + } + t.Logf("%v\n", basicBatch.PrettyPrint()) + }) + } + }) +} + +func TestFilterParse(t *testing.T) { + // Reusable input operators + sourceInput := map[string]any{ + "Operator": "Source", + "Source": map[string]any{ + "file-name": "country_full.csv", + "local": false, + }, + } + + projectInput := map[string]any{ + "Operator": "Project", + "Project": map[string]any{ + "input": sourceInput, + "expressions": []map[string]any{ + { + "expr_type": "ColumnResolve", + "name": "name", + }, + { + "expr_type": "ColumnResolve", + "name": "country-code", + }, + }, + }, + } + + // Cleanup functions for source 
files + cleanupSource1 := func() { + defer func() { + fname := "country_full.csv-filter-with-source-test" + if err := os.Remove(fname); err != nil { + t.Logf("error removing file (%s) file:\t%v\n", fname, err) + } + }() + + } + + t.Run("filter with source input", func(t *testing.T) { + filterTestID := "filter with source test" + lpMetaData := newPlanMetaData(filterTestID) + + tests := []struct { + id int + testName string + logicalPlan jsonOBJ + expectError bool + }{ + { + id: 1, + testName: "basic filter with binary expression (column > literal)", + logicalPlan: map[string]any{ + "input": sourceInput, + "expression": map[string]any{ + "expr_type": "BinaryExpr", + "op": "GreaterThan", + "left": map[string]any{ + "expr_type": "ColumnResolve", + "name": "region", + }, + "right": map[string]any{ + "expr_type": "LiteralResolve", + "value": "Africa", + "lit_type": "string", + }, + }, + }, + expectError: false, + }, + { + id: 1, + testName: "filter with column resolve expression", + logicalPlan: map[string]any{ + "input": sourceInput, + "expression": map[string]any{ + "expr_type": "ColumnResolve", + "name": "name", + }, + }, + expectError: true, + }, + { + id: 1, + testName: "filter missing expression field (should fail)", + logicalPlan: map[string]any{ + "input": sourceInput, + }, + expectError: true, + }, + { + id: 1, + testName: "filter missing input field (should fail)", + logicalPlan: map[string]any{ + "expression": map[string]any{ + "expr_type": "ColumnResolve", + "name": "name", + }, + }, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.testName, func(t *testing.T) { + filter, err := parseFilter(tt.logicalPlan, lpMetaData) + if (err != nil) != tt.expectError { + t.Errorf("parseFilter() error = %v, expectError = %v", err, tt.expectError) + return + } + if !tt.expectError && filter == nil { + t.Errorf("parseFilter() returned nil filter when error was nil") + } + }) + } + }) + + t.Run("filter with project input", func(t *testing.T) { + 
filterTestID := "filter with project test" + lpMetaData := newPlanMetaData(filterTestID) + + tests := []struct { + id int + testName string + logicalPlan jsonOBJ + expectError bool + }{ + { + id: 1, + testName: "filter projected columns with binary expression", + logicalPlan: map[string]any{ + "input": projectInput, + "expression": map[string]any{ + "expr_type": "BinaryExpr", + "op": "GreaterThan", + "left": map[string]any{ + "expr_type": "ColumnResolve", + "name": "country-code", + }, + "right": map[string]any{ + "expr_type": "LiteralResolve", + "value": 50, + "lit_type": "int", + }, + }, + }, + expectError: false, + }, + { + id: 1, + testName: "filter with complex nested expression", + logicalPlan: map[string]any{ + "input": projectInput, + "expression": map[string]any{ + "expr_type": "BinaryExpr", + "op": "And", + "left": map[string]any{ + "expr_type": "BinaryExpr", + "op": "Equal", + "left": map[string]any{ + "expr_type": "ColumnResolve", + "name": "name", + }, + "right": map[string]any{ + "expr_type": "LiteralResolve", + "value": "Canada", + "lit_type": "string", + }, + }, + "right": map[string]any{ + "expr_type": "BinaryExpr", + "op": "NotEqual", + "left": map[string]any{ + "expr_type": "ColumnResolve", + "name": "name", + }, + "right": map[string]any{ + "expr_type": "LiteralResolve", + "value": "", + "lit_type": "string", + }, + }, + }, + }, + expectError: false, + }, + { + id: 1, + testName: "filter with invalid expression type (should fail)", + logicalPlan: map[string]any{ + "input": projectInput, + "expression": map[string]any{ + "expr_type": "UnknownType", + "value": "invalid", + }, + }, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.testName, func(t *testing.T) { + defer func() { + if tt.id == 1 { + cleanupSource1() + } + }() + filter, err := parseFilter(tt.logicalPlan, lpMetaData) + if (err != nil) != tt.expectError { + t.Errorf("parseFilter() error = %v, expectError = %v", err, tt.expectError) + return + } + if !tt.expectError 
&& filter == nil { + t.Errorf("parseFilter() returned nil filter when error was nil") + } + }) + } + }) +} +func TestDistinctParse(t *testing.T) { + // Reusable input operators + sourceInput := map[string]any{ + "Operator": "Source", + "Source": map[string]any{ + "file-name": "country_full.csv", + "local": false, + }, + } + + projectInput := map[string]any{ + "Operator": "Project", + "Project": map[string]any{ + "input": sourceInput, + "expressions": []map[string]any{ + { + "expr_type": "ColumnResolve", + "name": "name", + }, + { + "expr_type": "ColumnResolve", + "name": "region", + }, + }, + }, + } + + // Cleanup functions for source files + cleanupSource1 := func() { + defer func() { + fname := "country_full.csv-distinct-test.csv" + if err := os.Remove(fname); err != nil { + t.Logf("error removing file (%s) file:\t%v\n", fname, err) + } + }() + } + + distinctTestID := "distinct test" + lpMetaData := newPlanMetaData(distinctTestID) + + tests := []struct { + id int + testName string + logicalPlan jsonOBJ + expectError bool + }{ + { + id: 1, + testName: "distinct with single column", + logicalPlan: map[string]any{ + "input": sourceInput, + "expressions": []map[string]any{ + { + "expr_type": "ColumnResolve", + "name": "name", + }, + }, + }, + expectError: false, + }, + { + id: 1, + testName: "distinct with multiple columns", + logicalPlan: map[string]any{ + "input": sourceInput, + "expressions": []map[string]any{ + { + "expr_type": "ColumnResolve", + "name": "name", + }, + { + "expr_type": "ColumnResolve", + "name": "region", + }, + }, + }, + expectError: false, + }, + { + id: 1, + testName: "distinct on project input", + logicalPlan: map[string]any{ + "input": projectInput, + "expressions": []map[string]any{ + { + "expr_type": "ColumnResolve", + "name": "name", + }, + }, + }, + expectError: false, + }, + { + id: 1, + testName: "distinct missing expressions field (should fail)", + logicalPlan: map[string]any{ + "input": sourceInput, + }, + expectError: true, + }, + { 
+ id: 1, + testName: "distinct with empty expressions (should fail)", + logicalPlan: map[string]any{ + "input": sourceInput, + "expressions": []map[string]any{}, + }, + expectError: true, + }, + { + id: 1, + testName: "distinct missing input field (should fail)", + logicalPlan: map[string]any{ + "expressions": []map[string]any{ + { + "expr_type": "ColumnResolve", + "name": "name", + }, + }, + }, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.testName, func(t *testing.T) { + defer func() { + if tt.id == 1 { + cleanupSource1() + } + }() + distinct, err := parseDistinct(tt.logicalPlan, lpMetaData) + if (err != nil) != tt.expectError { + t.Errorf("parseDistinct() error = %v, expectError = %v", err, tt.expectError) + return + } + if !tt.expectError && distinct == nil { + t.Errorf("parseDistinct() returned nil when error was nil") + } + }) + } +} + +func TestLimitParse(t *testing.T) { + // Reusable input operators + sourceInput := map[string]any{ + "Operator": "Source", + "Source": map[string]any{ + "file-name": "country_full.csv", + "local": false, + }, + } + + projectInput := map[string]any{ + "Operator": "Project", + "Project": map[string]any{ + "input": sourceInput, + "expressions": []map[string]any{ + { + "expr_type": "ColumnResolve", + "name": "name", + }, + }, + }, + } + + limitTestID := "limit test" + lpMetaData := newPlanMetaData(limitTestID) + + tests := []struct { + id int + testName string + logicalPlan jsonOBJ + expectedLimit int64 + expectError bool + }{ + { + id: 1, + testName: "limit with small value", + logicalPlan: map[string]any{ + "input": sourceInput, + "limit": 10, + }, + expectedLimit: 10, + expectError: false, + }, + { + id: 1, + testName: "limit with large value", + logicalPlan: map[string]any{ + "input": sourceInput, + "limit": 10000, + }, + expectedLimit: 10000, + expectError: false, + }, + { + id: 1, + testName: "limit with value thats too large", + logicalPlan: map[string]any{ + "input": sourceInput, + "limit": 
math.MaxUint16 + 100, + }, + expectedLimit: 1000000, + expectError: true, + }, + { + id: 1, + testName: "limit on projected input", + logicalPlan: map[string]any{ + "input": projectInput, + "limit": 5, + }, + expectedLimit: 5, + expectError: false, + }, + { + id: 1, + testName: "limit missing limit field (should fail)", + logicalPlan: map[string]any{ + "input": sourceInput, + }, + expectError: true, + }, + { + id: 1, + testName: "limit missing input field (should fail)", + logicalPlan: map[string]any{ + "limit": 10, + }, + expectError: true, + }, + { + id: 1, + testName: "limit with zero value (should fail)", + logicalPlan: map[string]any{ + "input": sourceInput, + "limit": 0, + }, + expectError: true, + }, + } + + cleanupSource1 := func() { + defer func() { + fname := "country_full.csv-limit-test.csv" + if err := os.Remove(fname); err != nil { + t.Logf("error removing file (%s) file:\t%v\n", fname, err) + } + }() + } + + for _, tt := range tests { + t.Run(tt.testName, func(t *testing.T) { + defer func() { + if tt.id == 1 { + cleanupSource1() + } + }() + limit, err := parseLimit(tt.logicalPlan, lpMetaData) + if (err != nil) != tt.expectError { + t.Errorf("parseLimit() error = %v, expectError = %v", err, tt.expectError) + return + } + if !tt.expectError { + if limit == nil { + t.Errorf("parseLimit() returned nil when error was nil") + return + } + + // Verify limit value is set correctly + if int64(limit.Remaining) != tt.expectedLimit { + t.Errorf("parseLimit() limit value = %d, expected %d", limit.Remaining, tt.expectedLimit) + } + } + }) + } +} + +func TestSortParse(t *testing.T) { + // Reusable input operators + sourceInput := map[string]any{ + "Operator": "Source", + "Source": map[string]any{ + "file-name": "country_full.csv", + "local": false, + }, + } + + projectInput := map[string]any{ + "Operator": "Project", + "Project": map[string]any{ + "input": sourceInput, + "expressions": []map[string]any{ + { + "expr_type": "ColumnResolve", + "name": "name", + }, + }, 
+ }, + } + + sortTestID := "sort test" + lpMetaData := newPlanMetaData(sortTestID) + + tests := []struct { + id int + testName string + logicalPlan jsonOBJ + expectError bool + }{ + { + id: 1, + testName: "sort single column ascending", + logicalPlan: map[string]any{ + "input": sourceInput, + "by": []map[string]any{ + { + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "name", + }, + "asc": true, + }, + }, + }, + expectError: false, + }, + { + id: 1, + testName: "sort single column descending", + logicalPlan: map[string]any{ + "input": sourceInput, + "by": []map[string]any{ + { + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "name", + }, + "asc": false, + }, + }, + }, + expectError: false, + }, + { + id: 1, + testName: "sort multiple columns", + logicalPlan: map[string]any{ + "input": projectInput, + "by": []map[string]any{ + { + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "name", + }, + "asc": true, + }, + }, + }, + expectError: false, + }, + { + id: 1, + testName: "sort missing by field (should fail)", + logicalPlan: map[string]any{ + "input": sourceInput, + }, + expectError: true, + }, + { + id: 1, + testName: "sort missing input field (should fail)", + logicalPlan: map[string]any{ + "by": []map[string]any{ + { + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "name", + }, + "asc": true, + }, + }, + }, + expectError: true, + }, + { + id: 1, + testName: "sort with empty by array (should fail)", + logicalPlan: map[string]any{ + "input": sourceInput, + "by": []map[string]any{}, + }, + expectError: true, + }, + } + + cleanupSource1 := func() { + err := os.Remove("country_full.csv-sort-test") + if err != nil { + t.Logf("error closing file: %v\n", err) + } + } + + for _, tt := range tests { + t.Run(tt.testName, func(t *testing.T) { + defer func() { + if tt.id == 1 { + cleanupSource1() + } + }() + sort, err := parseSort(tt.logicalPlan, lpMetaData) + if tt.expectError { + if err == nil { + 
t.Errorf("%s expected error but received nil", tt.testName) + } + return + } + if err != nil { + t.Errorf("%s recieved error %v", tt.testName, err) + } + if !tt.expectError && sort == nil { + t.Errorf("parseSort() returned nil when error was nil") + } + }) + } +} + +func TestAggregateParse(t *testing.T) { + // Reusable input operators + sourceInput := map[string]any{ + "Operator": "Source", + "Source": map[string]any{ + "file-name": "country_full.csv", + "local": false, + }, + } + + projectNumericInput := map[string]any{ + "Operator": "Project", + "Project": map[string]any{ + "input": sourceInput, + "expressions": []map[string]any{ + { + "expr_type": "ColumnResolve", + "name": "country-code", + }, + { + "expr_type": "ColumnResolve", + "name": "region-code", + }, + }, + }, + } + + projectStringInput := map[string]any{ + "Operator": "Project", + "Project": map[string]any{ + "input": sourceInput, + "expressions": []map[string]any{ + { + "expr_type": "ColumnResolve", + "name": "name", + }, + { + "expr_type": "ColumnResolve", + "name": "region", + }, + }, + }, + } + + aggregateTestID := "aggregate test" + lpMetaData := newPlanMetaData(aggregateTestID) + + tests := []struct { + id int + testName string + logicalPlan jsonOBJ + expectError bool + }{ + { + id: 1, + testName: "aggregate Sum on numeric column", + logicalPlan: map[string]any{ + "input": sourceInput, + "aggrs": []map[string]any{ + { + "function": "Sum", + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "country-code", + }, + }, + }, + }, + expectError: false, + }, + { + id: 1, + testName: "aggregate Count on string column", + logicalPlan: map[string]any{ + "input": sourceInput, + "aggrs": []map[string]any{ + { + "function": "Count", + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "name", + }, + }, + }, + }, + expectError: false, + }, + { + testName: "aggregate Avg on numeric column", + logicalPlan: map[string]any{ + "input": sourceInput, + "aggrs": []map[string]any{ + { + 
"function": "Avg", + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "region-code", + }, + }, + }, + }, + expectError: false, + }, + { + testName: "aggregate Min on numeric column", + logicalPlan: map[string]any{ + "input": projectNumericInput, + "aggrs": []map[string]any{ + { + "function": "Min", + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "country-code", + }, + }, + }, + }, + expectError: false, + }, + { + testName: "aggregate Max on string column", + logicalPlan: map[string]any{ + "input": projectStringInput, + "aggrs": []map[string]any{ + { + "function": "Max", + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "region", + }, + }, + }, + }, + expectError: true, + }, + { + testName: "aggregate with multiple aggregate functions", + logicalPlan: map[string]any{ + "input": sourceInput, + "aggrs": []map[string]any{ + { + "function": "Sum", + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "country-code", + }, + }, + { + "function": "Count", + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "name", + }, + }, + }, + }, + expectError: false, + }, + { + testName: "aggregate missing aggrs field (should fail)", + logicalPlan: map[string]any{ + "input": sourceInput, + }, + expectError: true, + }, + { + testName: "aggregate missing input field (should fail)", + logicalPlan: map[string]any{ + "aggrs": []map[string]any{ + { + "function": "Sum", + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "country-code", + }, + }, + }, + }, + expectError: true, + }, + { + testName: "aggregate with empty aggrs array (should fail)", + logicalPlan: map[string]any{ + "input": sourceInput, + "aggrs": []map[string]any{}, + }, + expectError: true, + }, + { + testName: "aggregate missing function in aggr (should fail)", + logicalPlan: map[string]any{ + "input": sourceInput, + "aggrs": []map[string]any{ + { + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": 
"country-code", + }, + }, + }, + }, + expectError: true, + }, + } + + cleanupSource1 := func() { + err := os.Remove("country_full.csv-aggregate-test") + if err != nil { + t.Logf("error closing file: %v\n", err) + } + } + + for _, tt := range tests { + t.Run(tt.testName, func(t *testing.T) { + defer func() { + if tt.id == 1 { + cleanupSource1() + } + }() + _, err := parseSingleAggr(tt.logicalPlan, lpMetaData) + if tt.expectError { + if err == nil { + t.Errorf("%s expected error but received nil", tt.testName) + } + return + } + if err != nil { + t.Errorf("%s recieved error %v", tt.testName, err) + } + + }) + } +} + +func TestHavingParse(t *testing.T) { + // Reusable input operators + sourceInput := map[string]any{ + "Operator": "Source", + "Source": map[string]any{ + "file-name": "country_full.csv", + "local": false, + }, + } + + projectInput := map[string]any{ + "Operator": "Project", + "Project": map[string]any{ + "input": sourceInput, + "expressions": []map[string]any{ + { + "expr_type": "ColumnResolve", + "name": "name", + }, + }, + }, + } + + havingTestID := "having test" + lpMetaData := newPlanMetaData(havingTestID) + + tests := []struct { + testName string + logicalPlan jsonOBJ + expectError bool + }{ + { + testName: "having with simple equality expression", + logicalPlan: map[string]any{ + "input": sourceInput, + "expression": map[string]any{ + "expr_type": "BinaryExpr", + "op": "Equal", + "left": map[string]any{ + "expr_type": "ColumnResolve", + "name": "name", + }, + "right": map[string]any{ + "expr_type": "LiteralResolve", + "value": "Canada", + "lit_type": "string", + }, + }, + }, + expectError: false, + }, + { + testName: "having with complex AND expression", + logicalPlan: map[string]any{ + "input": projectInput, + "expression": map[string]any{ + "expr_type": "BinaryExpr", + "op": "And", + "left": map[string]any{ + "expr_type": "BinaryExpr", + "op": "Equal", + "left": map[string]any{ + "expr_type": "ColumnResolve", + "name": "name", + }, + "right": 
map[string]any{ + "expr_type": "LiteralResolve", + "value": "Canada", + "lit_type": "string", + }, + }, + "right": map[string]any{ + "expr_type": "BinaryExpr", + "op": "NotEqual", + "left": map[string]any{ + "expr_type": "ColumnResolve", + "name": "name", + }, + "right": map[string]any{ + "expr_type": "LiteralResolve", + "value": "", + "lit_type": "string", + }, + }, + }, + }, + expectError: false, + }, + { + testName: "having missing expression field (should fail)", + logicalPlan: map[string]any{ + "input": sourceInput, + }, + expectError: true, + }, + { + testName: "having missing input field (should fail)", + logicalPlan: map[string]any{ + "expression": map[string]any{ + "expr_type": "BinaryExpr", + "op": "Equal", + "left": map[string]any{ + "expr_type": "ColumnResolve", + "name": "name", + }, + "right": map[string]any{ + "expr_type": "LiteralResolve", + "value": "Canada", + "lit_type": "string", + }, + }, + }, + expectError: true, + }, + { + testName: "having with OR expression", + logicalPlan: map[string]any{ + "input": sourceInput, + "expression": map[string]any{ + "expr_type": "BinaryExpr", + "op": "Or", + "left": map[string]any{ + "expr_type": "BinaryExpr", + "op": "Equal", + "left": map[string]any{ + "expr_type": "ColumnResolve", + "name": "name", + }, + "right": map[string]any{ + "expr_type": "LiteralResolve", + "value": "USA", + "lit_type": "string", + }, + }, + "right": map[string]any{ + "expr_type": "BinaryExpr", + "op": "Equal", + "left": map[string]any{ + "expr_type": "ColumnResolve", + "name": "name", + }, + "right": map[string]any{ + "expr_type": "LiteralResolve", + "value": "Canada", + "lit_type": "string", + }, + }, + }, + }, + expectError: false, + }, + { + testName: "having with literal only expression (should fail)", + logicalPlan: map[string]any{ + "input": sourceInput, + "expression": map[string]any{ + "expr_type": "LiteralResolve", + "value": "Canada", + "lit_type": "string", + }, + }, + expectError: true, + }, + } + + for _, tt := range 
tests { + t.Run(tt.testName, func(t *testing.T) { + having, err := parseHaving(tt.logicalPlan, lpMetaData) + if (err != nil) != tt.expectError { + t.Errorf("parseHaving() error = %v, expectError = %v", err, tt.expectError) + return + } + if !tt.expectError && having == nil { + t.Errorf("parseHaving() returned nil when error was nil") + } + }) + } +} + +func TestGroupByParse(t *testing.T) { + // Reusable input operators + sourceInput := map[string]any{ + "Operator": "Source", + "Source": map[string]any{ + "file-name": "country_full.csv", + "local": false, + }, + } + + groupByTestID := "group by test" + lpMetaData := newPlanMetaData(groupByTestID) + + tests := []struct { + id int + testName string + logicalPlan jsonOBJ + expectError bool + }{ + { + id: 1, + testName: "group by single column with single aggregate", + logicalPlan: map[string]any{ + "input": sourceInput, + "group_by": []map[string]any{ + { + "expr_type": "ColumnResolve", + "name": "region", + }, + }, + "aggrs": []map[string]any{ + { + "function": "Count", + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "name", + }, + }, + }, + }, + expectError: false, + }, + { + id: 1, + testName: "group by multiple columns with multiple aggregates", + logicalPlan: map[string]any{ + "input": sourceInput, + "group_by": []map[string]any{ + { + "expr_type": "ColumnResolve", + "name": "region", + }, + { + "expr_type": "ColumnResolve", + "name": "sub-region", + }, + }, + "aggrs": []map[string]any{ + { + "function": "Count", + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "name", + }, + }, + { + "function": "Sum", + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "country-code", + }, + }, + }, + }, + expectError: false, + }, + { + id: 1, + testName: "group by with avg aggregate", + logicalPlan: map[string]any{ + "input": sourceInput, + "group_by": []map[string]any{ + { + "expr_type": "ColumnResolve", + "name": "region", + }, + }, + "aggrs": []map[string]any{ + { + 
"function": "Avg", + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "region-code", + }, + }, + }, + }, + expectError: false, + }, + { + id: 1, + testName: "group by missing input field (should fail)", + logicalPlan: map[string]any{ + "group_by": []map[string]any{ + { + "expr_type": "ColumnResolve", + "name": "region", + }, + }, + "aggrs": []map[string]any{ + { + "function": "Count", + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "name", + }, + }, + }, + }, + expectError: true, + }, + { + id: 1, + testName: "group by missing group_by field (should fail)", + logicalPlan: map[string]any{ + "input": sourceInput, + "aggrs": []map[string]any{ + { + "function": "Count", + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "name", + }, + }, + }, + }, + expectError: true, + }, + { + id: 1, + testName: "group by missing aggrs field (should fail)", + logicalPlan: map[string]any{ + "input": sourceInput, + "group_by": []map[string]any{ + { + "expr_type": "ColumnResolve", + "name": "region", + }, + }, + }, + expectError: true, + }, + { + id: 1, + testName: "group by with empty group_by array (should fail)", + logicalPlan: map[string]any{ + "input": sourceInput, + "group_by": []map[string]any{}, + "aggrs": []map[string]any{ + { + "function": "Count", + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "name", + }, + }, + }, + }, + expectError: true, + }, + { + id: 1, + testName: "group by with empty aggrs array (should fail)", + logicalPlan: map[string]any{ + "input": sourceInput, + "group_by": []map[string]any{ + { + "expr_type": "ColumnResolve", + "name": "region", + }, + }, + "aggrs": []map[string]any{}, + }, + expectError: true, + }, + { + id: 1, + testName: "group by with misspelled group_by field (should fail)", + logicalPlan: map[string]any{ + "input": sourceInput, + "groupBy": []map[string]any{ + { + "expr_type": "ColumnResolve", + "name": "region", + }, + }, + "aggrs": []map[string]any{ + { + 
"function": "Count", + "expr": map[string]any{ + "expr_type": "ColumnResolve", + "name": "name", + }, + }, + }, + }, + expectError: true, + }, + } + + cleanupSource1 := func() { + if err := os.Remove("country_full.csv-group-by-test"); err != nil { + t.Logf("error occurred removing file %v", err) + } + } + + for _, tt := range tests { + t.Run(tt.testName, func(t *testing.T) { + defer func() { + if tt.id == 1 { + cleanupSource1() + } + }() + groupBy, err := parseGroupBy(tt.logicalPlan, lpMetaData) + if tt.expectError { + if err == nil { + t.Errorf("%s expected error but received nil", tt.testName) + } + return + } + if err != nil { + t.Errorf("%s received error %v", tt.testName, err) + return + } + if groupBy == nil { + t.Errorf("%s returned nil when error was nil", tt.testName) + } + }) + } +} + +func TestJoinParse(t *testing.T) { + // Reusable input operators using actual test data + // company_test_data.csv: id, department_name, manager_name, manager_email + // user_test_data.csv: id, username, email_address, is_active, age_years, account_balance_usd, average_session_minutes, favorite_color + userInput := map[string]any{ + "Operator": "Source", + "Source": map[string]any{ + "file-name": "user_test_data.csv", + "local": false, + }, + } + + companyInput := map[string]any{ + "Operator": "Source", + "Source": map[string]any{ + "file-name": "company_test_data.csv", + "local": false, + }, + } + + joinTestID := "join test" + lpMetaData := newPlanMetaData(joinTestID) + + tests := []struct { + id int + testName string + logicalPlan jsonOBJ + expectError bool + }{ + { + id: 3, + testName: "join users and company on id with inner join", + logicalPlan: map[string]any{ + "left": userInput, + "right": companyInput, + "join_type": "Inner", + "on": []map[string]any{ + { + "left": map[string]any{ + "expr_type": "ColumnResolve", + "name": "id", + }, + "right": map[string]any{ + "expr_type": "ColumnResolve", + "name": "id", + }, + }, + }, + }, + expectError: false, + }, + { + id: 3, + 
testName: "join with company left and users right on id", + logicalPlan: map[string]any{ + "left": companyInput, + "right": userInput, + "join_type": "Inner", + "on": []map[string]any{ + { + "left": map[string]any{ + "expr_type": "ColumnResolve", + "name": "id", + }, + "right": map[string]any{ + "expr_type": "ColumnResolve", + "name": "id", + }, + }, + }, + }, + expectError: false, + }, + { + id: 3, + testName: "join with unsupported join type left (should fail)", + logicalPlan: map[string]any{ + "left": userInput, + "right": companyInput, + "join_type": "Left", + "on": []map[string]any{ + { + "left": map[string]any{ + "expr_type": "ColumnResolve", + "name": "id", + }, + "right": map[string]any{ + "expr_type": "ColumnResolve", + "name": "id", + }, + }, + }, + }, + expectError: true, + }, + { + id: 3, + testName: "join with unsupported join type right (should fail)", + logicalPlan: map[string]any{ + "left": userInput, + "right": companyInput, + "join_type": "Right", + "on": []map[string]any{ + { + "left": map[string]any{ + "expr_type": "ColumnResolve", + "name": "id", + }, + "right": map[string]any{ + "expr_type": "ColumnResolve", + "name": "id", + }, + }, + }, + }, + expectError: true, + }, + { + id: 3, + testName: "join with unsupported join type outer (should fail)", + logicalPlan: map[string]any{ + "left": userInput, + "right": companyInput, + "join_type": "Outer", + "on": []map[string]any{ + { + "left": map[string]any{ + "expr_type": "ColumnResolve", + "name": "id", + }, + "right": map[string]any{ + "expr_type": "ColumnResolve", + "name": "id", + }, + }, + }, + }, + expectError: true, + }, + { + id: 3, + testName: "join missing on field (should fail)", + logicalPlan: map[string]any{ + "left": userInput, + "right": companyInput, + "join_type": "Inner", + }, + expectError: true, + }, + { + id: 3, + testName: "join missing left field (should fail)", + logicalPlan: map[string]any{ + "right": companyInput, + "join_type": "Inner", + "on": []map[string]any{ + { + 
"left": map[string]any{ + "expr_type": "ColumnResolve", + "name": "id", + }, + "right": map[string]any{ + "expr_type": "ColumnResolve", + "name": "id", + }, + }, + }, + }, + expectError: true, + }, + { + id: 3, + testName: "join missing right field (should fail)", + logicalPlan: map[string]any{ + "left": userInput, + "join_type": "Inner", + "on": []map[string]any{ + { + "left": map[string]any{ + "expr_type": "ColumnResolve", + "name": "id", + }, + "right": map[string]any{ + "expr_type": "ColumnResolve", + "name": "id", + }, + }, + }, + }, + expectError: true, + }, + { + id: 3, + testName: "join missing join_type field (should fail)", + logicalPlan: map[string]any{ + "left": userInput, + "right": companyInput, + "on": []map[string]any{ + { + "left": map[string]any{ + "expr_type": "ColumnResolve", + "name": "id", + }, + "right": map[string]any{ + "expr_type": "ColumnResolve", + "name": "id", + }, + }, + }, + }, + expectError: true, + }, + { + id: 3, + testName: "join with empty on array (should fail)", + logicalPlan: map[string]any{ + "left": userInput, + "right": companyInput, + "join_type": "Inner", + "on": []map[string]any{}, + }, + expectError: true, + }, + { + id: 3, + testName: "join with too many on conditions (should fail)", + logicalPlan: map[string]any{ + "left": userInput, + "right": companyInput, + "join_type": "Inner", + "on": []map[string]any{ + { + "left": map[string]any{ + "expr_type": "ColumnResolve", + "name": "id", + }, + "right": map[string]any{ + "expr_type": "ColumnResolve", + "name": "id", + }, + }, + { + "left": map[string]any{ + "expr_type": "ColumnResolve", + "name": "username", + }, + "right": map[string]any{ + "expr_type": "ColumnResolve", + "name": "manager_name", + }, + }, + { + "left": map[string]any{ + "expr_type": "ColumnResolve", + "name": "email_address", + }, + "right": map[string]any{ + "expr_type": "ColumnResolve", + "name": "manager_email", + }, + }, + }, + }, + expectError: false, + }, + { + id: 3, + testName: "join with 
missing left in on condition", + logicalPlan: map[string]any{ + "left": userInput, + "right": companyInput, + "join_type": "Inner", + "on": []map[string]any{ + { + "right": map[string]any{ + "expr_type": "ColumnResolve", + "name": "id", + }, + }, + }, + }, + expectError: true, + }, + { + id: 3, + testName: "join with missing right in on condition (should fail)", + logicalPlan: map[string]any{ + "left": userInput, + "right": companyInput, + "join_type": "Inner", + "on": []map[string]any{ + { + "left": map[string]any{ + "expr_type": "ColumnResolve", + "name": "id", + }, + }, + }, + }, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.testName, func(t *testing.T) { + join, err := parseJoin(tt.logicalPlan, lpMetaData) + if tt.expectError { + if err == nil { + t.Errorf("%s expected error but received nil", tt.testName) + } + return + } + if err != nil { + t.Errorf("%s received error %v", tt.testName, err) + return + } + if join == nil { + t.Errorf("%s returned nil when error was nil", tt.testName) + } + }) + } +} + +func TestSourceParse(t *testing.T) { + t.Run("source with local CSV", func(t *testing.T) { + sourceTestID := "source local csv test" + lpMetaData := newPlanMetaData(sourceTestID) + + tests := []struct { + testName string + logicalPlan jsonOBJ + expectError bool + }{ + { + testName: "local CSV file", + logicalPlan: map[string]any{ + "file-name": "country_full.csv", + "local": true, + }, + expectError: false, + }, + { + testName: "local CSV with various extension", + logicalPlan: map[string]any{ + "file-name": "data.csv", + "local": true, + }, + expectError: true, + }, + { + testName: "missing file-name field (should fail)", + logicalPlan: map[string]any{ + "local": true, + }, + expectError: true, + }, + { + testName: "invalid file extension (should fail)", + logicalPlan: map[string]any{ + "file-name": "data.txt", + "local": true, + }, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.testName, func(t *testing.T) { + 
source, err := parseSource(tt.logicalPlan, lpMetaData) + if (err != nil) != tt.expectError { + t.Errorf("parseSource() error = %v, expectError = %v", err, tt.expectError) + return + } + if !tt.expectError && source == nil { + t.Errorf("parseSource() returned nil when error was nil") + } + }) + } + }) + + t.Run("source with remote files", func(t *testing.T) { + sourceTestID := "source remote test" + lpMetaData := newPlanMetaData(sourceTestID) + + tests := []struct { + testName string + logicalPlan jsonOBJ + expectError bool + }{ + { + testName: "remote CSV file", + logicalPlan: map[string]any{ + "file-name": "country_full.csv", + "local": false, + }, + expectError: false, + }, + { + testName: "remote parquet file", + logicalPlan: map[string]any{ + "file-name": "userdata.parquet", + "local": false, + }, + expectError: false, + }, + { + testName: "remote file with unsupported extension (should fail)", + logicalPlan: map[string]any{ + "file-name": "s3://bucket/data.json", + "local": false, + }, + expectError: true, + }, + { + testName: "missing local field (should fail)", + logicalPlan: map[string]any{ + "file-name": "data.csv", + }, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.testName, func(t *testing.T) { + source, err := parseSource(tt.logicalPlan, lpMetaData) + if (err != nil) != tt.expectError { + t.Errorf("parseSource() error = %v, expectError = %v", err, tt.expectError) + return + } + if !tt.expectError && source == nil { + t.Errorf("parseSource() returned nil when error was nil") + } + }) + } + }) +} + +func TestContainsFields(t *testing.T) { + tests := []struct { + name string + fields []string + obj jsonOBJ + wantError bool + }{ + { + name: "all fields present", + fields: []string{"file-name", "local"}, + obj: jsonOBJ{"file-name": "test.csv", "local": true}, + wantError: false, + }, + { + name: "missing single field", + fields: []string{"file-name", "local"}, + obj: jsonOBJ{"file-name": "test.csv"}, + wantError: true, + }, + { + 
name: "missing multiple fields", + fields: []string{"file-name", "local", "format"}, + obj: jsonOBJ{"file-name": "test.csv"}, + wantError: true, + }, + { + name: "extra fields present", + fields: []string{"file-name"}, + obj: jsonOBJ{"file-name": "test.csv", "local": true, "extra": "field"}, + wantError: false, + }, + } + + for _, tt := range tests { + tt := tt // rebind for subtest safety + + t.Run(tt.name, func(t *testing.T) { + err := containsFields(tt.fields, tt.obj) + + if tt.wantError { + if err == nil { + t.Fatalf("expected error, got nil") + } + if !strings.Contains(err.Error(), "missing required fields") { + t.Fatalf("unexpected error message: %q", err.Error()) + } + return + } + + // want no error + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + }) + } +} +func TestCorrectFieldTypes(t *testing.T) { + tests := []struct { + name string + fields []string + fieldTypes []string + obj jsonOBJ + wantError bool + }{ + { + name: "all types correct", + fields: []string{"file-name", "local"}, + fieldTypes: []string{"string", "boolean"}, + obj: jsonOBJ{"file-name": "test.csv", "local": true}, + wantError: false, + }, + { + name: "single string type mismatch", + fields: []string{"file-name"}, + fieldTypes: []string{"string"}, + obj: jsonOBJ{"file-name": 123}, + wantError: true, + }, + { + name: "single boolean type mismatch", + fields: []string{"local"}, + fieldTypes: []string{"boolean"}, + obj: jsonOBJ{"local": "true"}, + wantError: true, + }, + { + name: "int type correct", + fields: []string{"count"}, + fieldTypes: []string{"int"}, + obj: jsonOBJ{"count": 10}, + wantError: false, + }, + { + name: "object type correct", + fields: []string{"meta"}, + fieldTypes: []string{"object"}, + obj: jsonOBJ{"meta": jsonOBJ{"a": 1}}, + wantError: false, + }, + { + name: "array type correct", + fields: []string{"items"}, + fieldTypes: []string{"array"}, + obj: jsonOBJ{"items": []any{1, 2, 3}}, + wantError: false, + }, + { + name: "mixed correct and incorrect 
types", + fields: []string{"file-name", "local"}, + fieldTypes: []string{"string", "boolean"}, + obj: jsonOBJ{"file-name": "ok.csv", "local": "yes"}, + wantError: true, + }, + { + name: "multiple mismatches", + fields: []string{"file-name", "local"}, + fieldTypes: []string{"string", "boolean"}, + obj: jsonOBJ{"file-name": 10, "local": "false"}, + wantError: true, + }, + { + name: "extra fields ignored", + fields: []string{"file-name"}, + fieldTypes: []string{"string"}, + obj: jsonOBJ{"file-name": "test.csv", "extra": true}, + wantError: false, + }, + { + name: "field and type count mismatch", + fields: []string{"file-name", "local"}, + fieldTypes: []string{"string"}, + obj: jsonOBJ{"file-name": "test.csv", "local": true}, + wantError: true, + }, + { + name: "empty fields and types", + fields: []string{}, + fieldTypes: []string{}, + obj: jsonOBJ{}, + wantError: false, + }, + } + + for _, tt := range tests { + tt := tt // rebind for subtest safety + + t.Run(tt.name, func(t *testing.T) { + err := correctFieldTypes(tt.fields, tt.fieldTypes, tt.obj) + + if tt.wantError && err == nil { + t.Fatalf("expected error, got nil") + } + + if !tt.wantError && err != nil { + t.Fatalf("expected no error, got: %v", err) + } + }) + } +} +func TestConsumePlan(t *testing.T) { + + basePath := filepath.Join("..", "..", "test_data", "substrait_plans", "medium") + example := []FileIntegrationTest{ + { + name: "mid_01_filter_project_sort.json", + filePath: filepath.Join(basePath, "mid_01_filter_project_sort.json"), + sqlEquiv: "select id , username from user_data where age_years > 25 order by username asc", + }, + { + name: "mid_02_group_by_aggregate.json", + filePath: filepath.Join(basePath, "mid_02_group_by_aggregate.json"), + sqlEquiv: "tbd", + }, + { + name: "mid_03_join_filter.json", + filePath: filepath.Join(basePath, "mid_03_join_filter.json"), + sqlEquiv: "tbd", + }, + } + for _, test := range example { + t.Run(test.name, func(t *testing.T) { + file, err := os.Open(test.filePath) + 
if err != nil { + t.Logf("Skipping %s: file not found err :%v \n", test.name, err) + return + } + results, err := consumePlan(file, newPlanMetaData("Test trial")) + if err != nil { + t.Errorf("%s failed with unexpected error %v\n", test.name, err) + } + /* rc, err := results.emitOperator.Next(50) + if err != nil { + t.Errorf("%s failed with unexpected error %v\n", test.name, err) + + } + t.Logf("record batch of %s \n%v\n", test.sqlEquiv, rc.PrettyPrint()) + */ + fmt.Printf("plan: %v\n", results.p) + _, err = results.consumeAll() + if err != nil { + t.Errorf("test failed with error:\t %v\n", err) + } + for _, f := range results.p.localFileNames { + if _, err := os.Open(f); err == nil || !strings.Contains(err.Error(), "no such file or directory") { + t.Errorf("%s was found when it should have been cleaned up by consumeAll: %v", f, err) + } + } + + }) + } +} + +func TestCleanU(t *testing.T) { + time.Sleep(time.Second * 3) + testCleanUp() +} diff --git a/src/Backend/test_data/base64-encoding/ex b/src/Backend/test_data/base64-encoding/ex new file mode 100644 index 0000000..6d161d2 --- /dev/null +++ b/src/Backend/test_data/base64-encoding/ex @@ -0,0 +1,65 @@ +{ + "Emit": + { + "Operator": "Project", + "Project": + { + "input": + { + "Operator": "Filter", + "Filter": + { + "input": + { + "Operator": "Source", + "Source": + { + "source-node": + { + "file-name": "employees.parquet", + "local": false + }, + "file-name": "employees.parquet", + "local": false + } + }, + "expression": + { + "expr_type": "BinaryExpr", + "op": "GreaterThan", + "left": + { + "expr_type": "ColumnResolve", + "name": "age" + }, + "right": + { + "expr_type": "LiteralResolve", + "value": 30, + "lit_type": "int" + } + }, + "file-name": "employees.parquet", + "local": false + } + }, + "expressions": + [ + { + "expr_type": "ColumnResolve", + "name": "name" + }, + { + "expr_type": "ColumnResolve", + "name": "age" + }, + { + "expr_type": "ColumnResolve", + "name": "salary" + } + ], + "file-name": "employees.parquet", + 
"local": false + } + } +} \ No newline at end of file diff --git a/src/Backend/test_data/base64-encoding/select-filter-2.txt b/src/Backend/test_data/base64-encoding/select-filter-2.txt new file mode 100644 index 0000000..981f60d --- /dev/null +++ b/src/Backend/test_data/base64-encoding/select-filter-2.txt @@ -0,0 +1 @@ +ewogICJFbWl0IjogewogICAgIk9wZXJhdG9yIjogIlByb2plY3QiLAogICAgIlByb2plY3QiOiB7CiAgICAgICJpbnB1dCI6IHsKICAgICAgICAiT3BlcmF0b3IiOiAiRmlsdGVyIiwKICAgICAgICAiRmlsdGVyIjogewogICAgICAgICAgImlucHV0IjogewogICAgICAgICAgICAiT3BlcmF0b3IiOiAiU291cmNlIiwKICAgICAgICAgICAgIlNvdXJjZSI6IHsKICAgICAgICAgICAgICAic291cmNlLW5vZGUiOiB7CiAgICAgICAgICAgICAgICAiZmlsZS1uYW1lIjogImVtcGxveWVlcy5wYXJxdWV0IiwKICAgICAgICAgICAgICAgICJsb2NhbCI6IGZhbHNlCiAgICAgICAgICAgICAgfSwKICAgICAgICAgICAgICAiZmlsZS1uYW1lIjogImVtcGxveWVlcy5wYXJxdWV0IiwKICAgICAgICAgICAgICAibG9jYWwiOiBmYWxzZQogICAgICAgICAgICB9CiAgICAgICAgICB9LAogICAgICAgICAgImV4cHJlc3Npb24iOiB7CiAgICAgICAgICAgICJleHByX3R5cGUiOiAiQmluYXJ5RXhwciIsCiAgICAgICAgICAgICJvcCI6ICJHcmVhdGVyVGhhbiIsCiAgICAgICAgICAgICJsZWZ0IjogeyJleHByX3R5cGUiOiAiQ29sdW1uUmVzb2x2ZSIsICJuYW1lIjogImFnZSJ9LAogICAgICAgICAgICAicmlnaHQiOiB7ImV4cHJfdHlwZSI6ICJMaXRlcmFsUmVzb2x2ZSIsICJ2YWx1ZSI6IDMwLCAibGl0X3R5cGUiOiAiaW50In0KICAgICAgICAgIH0sCiAgICAgICAgICAiZmlsZS1uYW1lIjogImVtcGxveWVlcy5wYXJxdWV0IiwKICAgICAgICAgICJsb2NhbCI6IGZhbHNlCiAgICAgICAgfQogICAgICB9LAogICAgICAiZXhwcmVzc2lvbnMiOiBbCiAgICAgICAgeyJleHByX3R5cGUiOiAiQ29sdW1uUmVzb2x2ZSIsICJuYW1lIjogIm5hbWUifSwKICAgICAgICB7ImV4cHJfdHlwZSI6ICJDb2x1bW5SZXNvbHZlIiwgIm5hbWUiOiAiYWdlIn0sCiAgICAgICAgeyJleHByX3R5cGUiOiAiQ29sdW1uUmVzb2x2ZSIsICJuYW1lIjogInNhbGFyeSJ9CiAgICAgIF0sCiAgICAgICJmaWxlLW5hbWUiOiAiZW1wbG95ZWVzLnBhcnF1ZXQiLAogICAgICAibG9jYWwiOiBmYWxzZQogICAgfQogIH0KfQ== \ No newline at end of file diff --git a/src/Backend/test_data/base64-encoding/select-filter.txt b/src/Backend/test_data/base64-encoding/select-filter.txt new file mode 100644 index 0000000..4471ad0 --- /dev/null +++ 
b/src/Backend/test_data/base64-encoding/select-filter.txt @@ -0,0 +1 @@ +ewogICAgIkVtaXQiOiAKICAgIHsKICAgICAgICAiT3BlcmF0b3IiOiAiUHJvamVjdCIsCiAgICAgICAgIlByb2plY3QiOiAKICAgICAgICB7CiAgICAgICAgICAgICJpbnB1dCI6IAogICAgICAgICAgICB7CiAgICAgICAgICAgICAgICAiT3BlcmF0b3IiOiAiRmlsdGVyIiwKICAgICAgICAgICAgICAgICJGaWx0ZXIiOiAKICAgICAgICAgICAgICAgIHsKICAgICAgICAgICAgICAgICAgICAiaW5wdXQiOiAKICAgICAgICAgICAgICAgICAgICB7CiAgICAgICAgICAgICAgICAgICAgICAgICJPcGVyYXRvciI6ICJTb3VyY2UiLAogICAgICAgICAgICAgICAgICAgICAgICAiU291cmNlIjogCiAgICAgICAgICAgICAgICAgICAgICAgIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgICJzb3VyY2Utbm9kZSI6IAogICAgICAgICAgICAgICAgICAgICAgICAgICAgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICJmaWxlLW5hbWUiOiAidGVzdF9lbXBsb3llZXMuY3N2IiwKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAibG9jYWwiOiBmYWxzZQogICAgICAgICAgICAgICAgICAgICAgICAgICAgfSwKICAgICAgICAgICAgICAgICAgICAgICAgICAgICJmaWxlLW5hbWUiOiAidGVzdF9lbXBsb3llZXMuY3N2IiwKICAgICAgICAgICAgICAgICAgICAgICAgICAgICJsb2NhbCI6IGZhbHNlCiAgICAgICAgICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgICAgICB9LAogICAgICAgICAgICAgICAgICAgICJleHByZXNzaW9uIjogCiAgICAgICAgICAgICAgICAgICAgewogICAgICAgICAgICAgICAgICAgICAgICAiZXhwcl90eXBlIjogIkJpbmFyeUV4cHIiLAogICAgICAgICAgICAgICAgICAgICAgICAib3AiOiAiR3JlYXRlclRoYW4iLAogICAgICAgICAgICAgICAgICAgICAgICAibGVmdCI6IAogICAgICAgICAgICAgICAgICAgICAgICB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICAiZXhwcl90eXBlIjogIkNvbHVtblJlc29sdmUiLAogICAgICAgICAgICAgICAgICAgICAgICAgICAgIm5hbWUiOiAiaWQiCiAgICAgICAgICAgICAgICAgICAgICAgIH0sCiAgICAgICAgICAgICAgICAgICAgICAgICJyaWdodCI6IAogICAgICAgICAgICAgICAgICAgICAgICB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICAiZXhwcl90eXBlIjogIkxpdGVyYWxSZXNvbHZlIiwKICAgICAgICAgICAgICAgICAgICAgICAgICAgICJ2YWx1ZSI6IDUsCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAibGl0X3R5cGUiOiAiaW50IgogICAgICAgICAgICAgICAgICAgICAgICB9CiAgICAgICAgICAgICAgICAgICAgfSwKICAgICAgICAgICAgICAgICAgICAiZmlsZS1uYW1lIjogInRlc3RfZW1wbG95ZWVzLmNzdiIsCiAgICAgICAgICAgICAgICAgICAgImxvY2FsIjogZmFsc2UKICAgICAgICAgICAgICAgIH0KICAgICAgICAgICA
gfSwKICAgICAgICAgICAgImV4cHJlc3Npb25zIjogCiAgICAgICAgICAgIFsKICAgICAgICAgICAgICAgIHsKICAgICAgICAgICAgICAgICAgICAiZXhwcl90eXBlIjogIkNvbHVtblJlc29sdmUiLAogICAgICAgICAgICAgICAgICAgICJuYW1lIjogImlkIgogICAgICAgICAgICAgICAgfSwKICAgICAgICAgICAgICAgIHsKICAgICAgICAgICAgICAgICAgICAiZXhwcl90eXBlIjogIkNvbHVtblJlc29sdmUiLAogICAgICAgICAgICAgICAgICAgICJuYW1lIjogIm5hbWUiCiAgICAgICAgICAgICAgICB9LAogICAgICAgICAgICAgICAgewogICAgICAgICAgICAgICAgICAgICJleHByX3R5cGUiOiAiQ29sdW1uUmVzb2x2ZSIsCiAgICAgICAgICAgICAgICAgICAgIm5hbWUiOiAiYWdlIgogICAgICAgICAgICAgICAgfQogICAgICAgICAgICBdLAogICAgICAgICAgICAiZmlsZS1uYW1lIjogInRlc3RfZW1wbG95ZWVzLmNzdiIsCiAgICAgICAgICAgICJsb2NhbCI6IGZhbHNlCiAgICAgICAgfQogICAgfQp9 \ No newline at end of file diff --git a/src/Backend/test_data/s3_source/source.json b/src/Backend/test_data/s3_source/source.json index c7cd269..ff36fbb 100644 --- a/src/Backend/test_data/s3_source/source.json +++ b/src/Backend/test_data/s3_source/source.json @@ -1,16 +1,14 @@ { - "meta_data":"names of s3 files", - "csv_files":[ - "s3://my-bucket/data/file1.csv", - "s3://my-bucket/data/file2.csv", - "s3://my-bucket/data/file3.csv" + "meta_data": "names of s3 files", + "csv_files": + [ + "company_test_data.csv", + "county_full.csv", + "user_test_data.csv" ], - "json_files":[ - "s3://my-bucket/data/file1.json", - "s3://my-bucket/data/file2.json" - ], - "parquet_files":[ - "s3://my-bucket/data/file1.parquet", + "parquet_files": + [ + "userdata.parquet", "s3://my-bucket/data/file2.parquet" ] } \ No newline at end of file diff --git a/src/Backend/test_data/substrait_plans/basic/basic_00_test.json b/src/Backend/test_data/substrait_plans/basic/basic_00_test.json new file mode 100644 index 0000000..b649889 --- /dev/null +++ b/src/Backend/test_data/substrait_plans/basic/basic_00_test.json @@ -0,0 +1,34 @@ +{ + "Emit": + { + "Operator": "Filter", + "Filter": + { + "input": + { + "Operator": "Source", + "Source": + { + "file-name": "user_test_data.csv", + "local": false + } + }, + "expression": + { + 
"expr_type": "BinaryExpr", + "op": "GreaterThan", + "left": + { + "expr_type": "ColumnResolve", + "name": "id" + }, + "right": + { + "expr_type": "LiteralResolve", + "value": 10, + "lit_type": "int" + } + } + } + } +} \ No newline at end of file diff --git a/src/Backend/test_data/substrait_plans/basic/basic_01_source_filter.json b/src/Backend/test_data/substrait_plans/basic/basic_01_source_filter.json new file mode 100644 index 0000000..4104e9e --- /dev/null +++ b/src/Backend/test_data/substrait_plans/basic/basic_01_source_filter.json @@ -0,0 +1,34 @@ +{ + "Emit": + { + "Operator": "Filter", + "Filter": + { + "input": + { + "Operator": "Source", + "Source": + { + "file-name": "user_test_data.csv", + "local": false + } + }, + "expression": + { + "expr_type": "BinaryExpr", + "op": "GreaterThan", + "left": + { + "expr_type": "ColumnResolve", + "name": "age_years" + }, + "right": + { + "expr_type": "LiteralResolve", + "value": 10, + "lit_type": "int" + } + } + } + } +} \ No newline at end of file diff --git a/src/Backend/test_data/substrait_plans/basic/basic_02_project.json b/src/Backend/test_data/substrait_plans/basic/basic_02_project.json new file mode 100644 index 0000000..c378ec8 --- /dev/null +++ b/src/Backend/test_data/substrait_plans/basic/basic_02_project.json @@ -0,0 +1,29 @@ +{ + "Emit": + { + "Operator": "Project", + "Project": + { + "input": + { + "Operator": "Source", + "Source": + { + "file-name": "user_test_data.csv", + "local": false + } + }, + "expressions": + [ + { + "expr_type": "ColumnResolve", + "name": "email_address" + }, + { + "expr_type": "ColumnResolve", + "name": "account_balance_usd" + } + ] + } + } +} \ No newline at end of file diff --git a/src/Backend/test_data/substrait_plans/basic/basic_03_sort.json b/src/Backend/test_data/substrait_plans/basic/basic_03_sort.json new file mode 100644 index 0000000..8de2b57 --- /dev/null +++ b/src/Backend/test_data/substrait_plans/basic/basic_03_sort.json @@ -0,0 +1,29 @@ +{ + "Emit": + { + "Operator": 
"Sort", + "Sort": + { + "input": + { + "Operator": "Source", + "Source": + { + "file-name": "user_test_data.csv", + "local": false + } + }, + "by": + [ + { + "expr": + { + "expr_type": "ColumnResolve", + "name": "username" + }, + "asc": true + } + ] + } + } +} \ No newline at end of file diff --git a/src/Backend/test_data/substrait_plans/basic/basic_04_distinct.json b/src/Backend/test_data/substrait_plans/basic/basic_04_distinct.json new file mode 100644 index 0000000..15fa4f2 --- /dev/null +++ b/src/Backend/test_data/substrait_plans/basic/basic_04_distinct.json @@ -0,0 +1,25 @@ +{ + "Emit": + { + "Operator": "Distinct", + "Distinct": + { + "input": + { + "Operator": "Source", + "Source": + { + "file-name": "company_test_data.csv", + "local": false + } + }, + "expressions": + [ + { + "expr_type": "ColumnResolve", + "name": "department_name" + } + ] + } + } +} \ No newline at end of file diff --git a/src/Backend/test_data/substrait_plans/basic/basic_05_limit.json b/src/Backend/test_data/substrait_plans/basic/basic_05_limit.json new file mode 100644 index 0000000..fbfab34 --- /dev/null +++ b/src/Backend/test_data/substrait_plans/basic/basic_05_limit.json @@ -0,0 +1,19 @@ +{ + "Emit": + { + "Operator": "Limit", + "Limit": + { + "input": + { + "Operator": "Source", + "Source": + { + "file-name": "user_test_data.csv", + "local": false + } + }, + "limit": 142 + } + } +} \ No newline at end of file diff --git a/src/Backend/test_data/substrait_plans/basic/basic_06_aggr.json b/src/Backend/test_data/substrait_plans/basic/basic_06_aggr.json new file mode 100644 index 0000000..92feedf --- /dev/null +++ b/src/Backend/test_data/substrait_plans/basic/basic_06_aggr.json @@ -0,0 +1,37 @@ +{ + "Emit": + { + "Operator": "Aggregate", + "Aggregate": + { + "input": + { + "Operator": "Source", + "Source": + { + "file-name": "company_test_data.csv", + "local": false + } + }, + "aggrs": + [ + { + "function": "count", + "expr": + { + "expr_type": "ColumnResolve", + "name": "manager_name" + 
} + }, + { + "function": "avg", + "expr": + { + "expr_type": "ColumnResolve", + "name": "id" + } + } + ] + } + } +} \ No newline at end of file diff --git a/src/Backend/test_data/substrait_plans/medium/mid_01_filter_project_sort.json b/src/Backend/test_data/substrait_plans/medium/mid_01_filter_project_sort.json new file mode 100644 index 0000000..dddc554 --- /dev/null +++ b/src/Backend/test_data/substrait_plans/medium/mid_01_filter_project_sort.json @@ -0,0 +1,70 @@ +{ + "Emit": + { + "Operator": "Sort", + "Sort": + { + "input": + { + "Operator": "Project", + "Project": + { + "input": + { + "Operator": "Filter", + "Filter": + { + "input": + { + "Operator": "Source", + "Source": + { + "file-name": "user_test_data.csv", + "local": false + } + }, + "expression": + { + "expr_type": "BinaryExpr", + "op": "GreaterThan", + "left": + { + "expr_type": "ColumnResolve", + "name": "age_years" + }, + "right": + { + "expr_type": "LiteralResolve", + "value": 25, + "lit_type": "int" + } + } + } + }, + "expressions": + [ + { + "expr_type": "ColumnResolve", + "name": "id" + }, + { + "expr_type": "ColumnResolve", + "name": "username" + } + ] + } + }, + "by": + [ + { + "expr": + { + "expr_type": "ColumnResolve", + "name": "username" + }, + "asc": true + } + ] + } + } +} \ No newline at end of file diff --git a/src/Backend/test_data/substrait_plans/medium/mid_02_group_by_aggregate.json b/src/Backend/test_data/substrait_plans/medium/mid_02_group_by_aggregate.json new file mode 100644 index 0000000..9360546 --- /dev/null +++ b/src/Backend/test_data/substrait_plans/medium/mid_02_group_by_aggregate.json @@ -0,0 +1,36 @@ +{ + "Emit": + { + "Operator": "GroupBy", + "GroupBy": + { + "input": + { + "Operator": "Source", + "Source": + { + "file-name": "company_test_data.csv", + "local": false + } + }, + "group_by": + [ + { + "expr_type": "ColumnResolve", + "name": "department_name" + } + ], + "aggrs": + [ + { + "function": "count", + "expr": + { + "expr_type": "ColumnResolve", + "name": 
"manager_name" + } + } + ] + } + } +} \ No newline at end of file diff --git a/src/Backend/test_data/substrait_plans/medium/mid_03_join_filter.json b/src/Backend/test_data/substrait_plans/medium/mid_03_join_filter.json new file mode 100644 index 0000000..5db2aab --- /dev/null +++ b/src/Backend/test_data/substrait_plans/medium/mid_03_join_filter.json @@ -0,0 +1,66 @@ +{ + "Emit": + { + "Operator": "Filter", + "Filter": + { + "input": + { + "Operator": "Join", + "Join": + { + "left": + { + "Operator": "Source", + "Source": + { + "file-name": "user_test_data.csv", + "local": false + } + }, + "right": + { + "Operator": "Source", + "Source": + { + "file-name": "company_test_data.csv", + "local": false + } + }, + "join_type": "Inner", + "on": + [ + { + "left": + { + "expr_type": "ColumnResolve", + "name": "id" + }, + "right": + { + "expr_type": "ColumnResolve", + "name": "id" + } + } + ] + } + }, + "expression": + { + "expr_type": "BinaryExpr", + "op": "GreaterThan", + "left": + { + "expr_type": "ColumnResolve", + "name": "account_balance_usd" + }, + "right": + { + "expr_type": "LiteralResolve", + "value": 5000, + "lit_type": "float64" + } + } + } + } +} \ No newline at end of file diff --git a/src/Backend/test_data/substrait_plans/medium/mid_04_join_sort_limit.json b/src/Backend/test_data/substrait_plans/medium/mid_04_join_sort_limit.json new file mode 100644 index 0000000..6a95f7c --- /dev/null +++ b/src/Backend/test_data/substrait_plans/medium/mid_04_join_sort_limit.json @@ -0,0 +1,69 @@ +{ + "Emit": + { + "Operator": "Limit", + "Limit": + { + "input": + { + "Operator": "Sort", + "Sort": + { + "input": + { + "Operator": "Join", + "Join": + { + "left": + { + "Operator": "Source", + "Source": + { + "file-name": "user_test_data.csv", + "local": false + } + }, + "right": + { + "Operator": "Source", + "Source": + { + "file-name": "company_test_data.csv", + "local": false + } + }, + "join_type": "Inner", + "on": + [ + { + "left": + { + "expr_type": "ColumnResolve", + 
"name": "id" + }, + "right": + { + "expr_type": "ColumnResolve", + "name": "id" + } + } + ] + } + }, + "by": + [ + { + "expr": + { + "expr_type": "ColumnResolve", + "name": "average_session_minutes" + }, + "asc": false + } + ] + } + }, + "limit": 10320 + } + } +} \ No newline at end of file diff --git a/src/Contract/operation.proto b/src/Contract/operation.proto index 598386b..fdecf23 100644 --- a/src/Contract/operation.proto +++ b/src/Contract/operation.proto @@ -10,10 +10,10 @@ service SSOperation { // The request message containing the operation details. message QueryExecutionRequest { - bytes substrait_logical = 1; // Substrait logical plan: serialized representation of the query execution + // base64 encoded string of the logical plan (custom IR json format) + string logical_plan = 1; // Substrait logical plan: serialized representation of the query execution (contains s3 link to the source data) string sql_statement = 2; // original sql statement string id = 3; // unique id for this client - SourceType source = 4; // (s3 link| base64 data) } // The response message containing the result. @@ -22,10 +22,6 @@ message QueryExecutionResponse { ErrorDetails error_type = 2; // error type if any } -message SourceType{ - string s3_source = 1; // s3 link to the source data - string mime = 2; -} enum returnTypes{ SUCCESS = 0;