Skip to content

Commit

Permalink
feat: complete postgres client compatibility tests (#111)
Browse files Browse the repository at this point in the history
  • Loading branch information
NoyException committed Nov 27, 2024
1 parent 37a0adf commit c39a356
Show file tree
Hide file tree
Showing 16 changed files with 731 additions and 62 deletions.
5 changes: 1 addition & 4 deletions compatibility/pg/c/pg_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,7 @@ int runTests(PGTest *pgTest) {

size_t removeNewline(char *line) {
size_t len = strlen(line);
if (len > 0 && (line[len - 1] == '\n' || line[len - 1] == '\r')) {
line[--len] = '\0';
}
if (len > 0 && line[len - 1] == '\r') {
while (len > 0 && (line[len - 1] == '\n' || line[len - 1] == '\r')) {
line[--len] = '\0';
}
return len;
Expand Down
7 changes: 7 additions & 0 deletions compatibility/pg/clean.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#!/bin/bash

rm -rf ./c/pg_test \
./csharp/bin ./csharp/obj \
./go/pg \
./java/*.class \
./rust/target ./rust/Cargo.lock
179 changes: 179 additions & 0 deletions compatibility/pg/csharp/PGTest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
using System;
using System.Collections.Generic;
using System.Data;
using Microsoft.Data.SqlClient;
using System.IO;

public class PGTest
{
public class Tests
{
private SqlConnection conn;
private SqlCommand cmd;
private List<Test> tests = new List<Test>();

public void Connect(string ip, int port, string user, string password)
{
try
{
string connectionString = $"Server={ip},{port};User Id={user};Password={password};";
conn = new SqlConnection(connectionString);
conn.Open();
cmd = conn.CreateCommand();
cmd.CommandType = CommandType.Text;
}
catch (SqlException e)
{
throw new Exception(e.Message);
}
}

public void Disconnect()
{
try
{
cmd.Dispose();
conn.Close();
}
catch (SqlException e)
{
throw new Exception(e.Message);
}
}

public void AddTest(string query, string[][] expectedResults)
{
tests.Add(new Test(query, expectedResults));
}

public bool RunTests()
{
foreach (var test in tests)
{
if (!test.Run(cmd))
{
return false;
}
}
return true;
}

public void ReadTestsFromFile(string filename)
{
try
{
using (var reader = new StreamReader(filename))
{
string line;
while ((line = reader.ReadLine()) != null)
{
if (string.IsNullOrWhiteSpace(line)) continue;
string query = line;
var results = new List<string[]>();
while ((line = reader.ReadLine()) != null && !string.IsNullOrWhiteSpace(line))
{
results.Add(line.Split(','));
}
string[][] expectedResults = results.ToArray();
AddTest(query, expectedResults);
}
}
}
catch (IOException e)
{
Console.Error.WriteLine(e.Message);
Environment.Exit(1);
}
}

public class Test
{
private string query;
private string[][] expectedResults;

public Test(string query, string[][] expectedResults)
{
this.query = query;
this.expectedResults = expectedResults;
}

public bool Run(SqlCommand cmd)
{
try
{
Console.WriteLine("Running test: " + query);
cmd.CommandText = query;
using (var reader = cmd.ExecuteReader())
{
if (!reader.HasRows)
{
if (expectedResults.Length != 0)
{
Console.Error.WriteLine($"Expected {expectedResults.Length} rows, got 0");
return false;
}
Console.WriteLine("Returns 0 rows");
return true;
}
if (reader.FieldCount != expectedResults[0].Length)
{
Console.Error.WriteLine($"Expected {expectedResults[0].Length} columns, got {reader.FieldCount}");
return false;
}
int rows = 0;
while (reader.Read())
{
for (int col = 0; col < expectedResults[rows].Length; col++)
{
string result = reader.GetString(col);
if (expectedResults[rows][col] != result)
{
Console.Error.WriteLine($"Expected:\n'{expectedResults[rows][col]}'");
Console.Error.WriteLine($"Result:\n'{result}'\nRest of the results:");
while (reader.Read())
{
Console.Error.WriteLine(reader.GetString(0));
}
return false;
}
}
rows++;
}
Console.WriteLine("Returns " + rows + " rows");
if (rows != expectedResults.Length)
{
Console.Error.WriteLine($"Expected {expectedResults.Length} rows");
return false;
}
return true;
}
}
catch (SqlException e)
{
Console.Error.WriteLine(e.Message);
return false;
}
}
}
}

public static void Main(string[] args)
{
if (args.Length < 5)
{
Console.Error.WriteLine("Usage: PGTest <ip> <port> <user> <password> <testFile>");
Environment.Exit(1);
}

var tests = new Tests();
tests.Connect(args[0], int.Parse(args[1]), args[2], args[3]);
tests.ReadTestsFromFile(args[4]);

if (!tests.RunTests())
{
tests.Disconnect();
Environment.Exit(1);
}
tests.Disconnect();
}
}
12 changes: 12 additions & 0 deletions compatibility/pg/csharp/PGTest.csproj
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFramework>net8.0</TargetFramework>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Microsoft.Data.SqlClient" Version="5.2" />
</ItemGroup>

</Project>
179 changes: 179 additions & 0 deletions compatibility/pg/go/pg.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
package main

import (
"bufio"
"database/sql"
"fmt"
"os"
"strconv"
"strings"

_ "github.com/lib/pq"
)

type Test struct {
query string
expectedResults [][]string
}

type Tests struct {
conn *sql.DB
tests []Test
}

func (t *Tests) connect(ip string, port int, user, password string) {
connStr := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=postgres sslmode=disable", ip, port, user, password)
var err error
t.conn, err = sql.Open("postgres", connStr)
if err != nil {
panic(err)
}
}

func (t *Tests) disconnect() {
err := t.conn.Close()
if err != nil {
panic(err)
}
}

func (t *Tests) addTest(query string, expectedResults [][]string) {
t.tests = append(t.tests, Test{query, expectedResults})
}

func (t *Tests) readTestsFromFile(filename string) {
file, err := os.Open(filename)
if err != nil {
panic(err)
}
defer func(file *os.File) {
err := file.Close()
if err != nil {
panic(err)
}
}(file)

scanner := bufio.NewScanner(file)
var query string
var results [][]string
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" {
if query != "" {
t.addTest(query, results)
query = ""
results = nil
}
} else if query == "" {
query = line
} else {
results = append(results, strings.Split(line, ","))
}
}
if query != "" {
t.addTest(query, results)
}
if err := scanner.Err(); err != nil {
panic(err)
}
}

func (t *Tests) runTests() bool {
for _, test := range t.tests {
if !t.runTest(test) {
return false
}
}
return true
}

func (t *Tests) runTest(test Test) bool {
fmt.Println("Running test:", test.query)
rows, err := t.conn.Query(test.query)
if err != nil {
fmt.Println("Error executing query:", err)
return false
}
defer func(rows *sql.Rows) {
err := rows.Close()
if err != nil {
panic(err)
}
}(rows)

columns, err := rows.Columns()
if err != nil {
fmt.Println("Error getting columns:", err)
return false
}

if len(test.expectedResults) == 0 {
if len(columns) != 0 {
fmt.Printf("Expected 0 columns, got %d\n", len(columns))
return false
}
fmt.Println("Returns 0 rows")
return true
}

if len(columns) != len(test.expectedResults[0]) {
fmt.Printf("Expected %d columns, got %d\n", len(test.expectedResults[0]), len(columns))
return false
}

var rowCount int
for rows.Next() {
row := make([]string, len(columns))
rowPointers := make([]interface{}, len(columns))
for i := range row {
rowPointers[i] = &row[i]
}
if err := rows.Scan(rowPointers...); err != nil {
fmt.Println("Error scanning row:", err)
return false
}

for i, expected := range test.expectedResults[rowCount] {
if row[i] != expected {
fmt.Printf("Expected: '%s', got: '%s'\n", expected, row[i])
return false
}
}
rowCount++
}

if rowCount != len(test.expectedResults) {
fmt.Printf("Expected %d rows, got %d\n", len(test.expectedResults), rowCount)
return false
}

fmt.Printf("Returns %d rows\n", rowCount)
return true
}

func main() {
if len(os.Args) < 6 {
fmt.Println("Usage: pg_test <ip> <port> <user> <password> <testFile>")
os.Exit(1)
}

ip := os.Args[1]
port, ok := strconv.Atoi(os.Args[2])
if ok != nil {
fmt.Println("Invalid port:", os.Args[2])
os.Exit(1)
}
user := os.Args[3]
password := os.Args[4]
testFile := os.Args[5]

tests := &Tests{}
tests.connect(ip, port, user, password)
tests.readTestsFromFile(testFile)

if !tests.runTests() {
tests.disconnect()
os.Exit(1)
}
tests.disconnect()
}
Loading

0 comments on commit c39a356

Please sign in to comment.