/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.database.rdbms.driver;

import com.healthmarketscience.sqlbuilder.CreateTableQuery;
import com.zaxxer.hikari.HikariConfig;
import com.zaxxer.hikari.HikariDataSource;
import java.io.FileInputStream;
import java.io.IOException;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import org.linqs.psl.database.Partition;
import org.linqs.psl.database.rdbms.PredicateInfo;
import org.linqs.psl.database.rdbms.driver.DatabaseDriver;
import org.linqs.psl.model.term.ConstantType;
import org.linqs.psl.util.Parallel;
import org.linqs.psl.util.StringUtils;
import org.postgresql.PGConnection;

/**
 * A {@link DatabaseDriver} for PostgreSQL.
 *
 * <p>All connections are drawn from a HikariCP pool that is opened in the constructor and
 * released by {@link #close()}.
 */
public class PostgreSQLDriver implements DatabaseDriver {
    /** Host used by the constructor that takes only a database name. */
    public static final String DEFAULT_HOST = "localhost";

    /** Port used by the constructor that takes only a database name. */
    public static final String DEFAULT_PORT = "5432";

    /** Upper bound on pooled connections, applied regardless of the configured thread count. */
    private static final int MAX_POOL_SIZE = 8;

    /** Pool that every connection handed out by this driver comes from. */
    private final HikariDataSource dataSource;

    /**
     * Connects to a database on {@link #DEFAULT_HOST}:{@link #DEFAULT_PORT}.
     *
     * @param databaseName the name of the database to connect to
     * @param clearDatabase if true, drop and recreate the {@code public} schema on startup
     */
    public PostgreSQLDriver(String databaseName, boolean clearDatabase) {
        this(DEFAULT_HOST, DEFAULT_PORT, databaseName, clearDatabase);
    }

    /**
     * Connects to a database on the given host and port.
     * The generated JDBC URL passes {@code loggerLevel=OFF} to the PostgreSQL JDBC driver.
     *
     * @param host the server host name
     * @param port the server port
     * @param databaseName the name of the database to connect to
     * @param clearDatabase if true, drop and recreate the {@code public} schema on startup
     */
    public PostgreSQLDriver(String host, String port, String databaseName, boolean clearDatabase) {
        this(String.format("jdbc:postgresql://%s:%s/%s?loggerLevel=OFF", host, port, databaseName),
                databaseName, clearDatabase);
    }

    /**
     * Connects using a complete JDBC connection string.
     *
     * @param connectionString the JDBC URL handed to the connection pool
     * @param databaseName not read by this constructor; the database is taken from
     *     {@code connectionString}
     * @param clearDatabase if true, drop and recreate the {@code public} schema (removing every
     *     object in it) before returning
     * @throws RuntimeException if the PostgreSQL JDBC driver is not on the classpath, or if
     *     clearing the schema fails (the pool is closed before the exception propagates)
     */
    public PostgreSQLDriver(String connectionString, String databaseName, boolean clearDatabase) {
        try {
            Class.forName("org.postgresql.Driver");
        } catch (ClassNotFoundException ex) {
            throw new RuntimeException("Could not find postgres driver. Please check classpath.", ex);
        }

        HikariConfig config = new HikariConfig();
        config.setJdbcUrl(connectionString);
        config.setMaximumPoolSize(Math.min(MAX_POOL_SIZE, Parallel.NUM_THREADS * 2));
        this.dataSource = new HikariDataSource(config);

        if (clearDatabase) {
            try {
                executeUpdate("DROP SCHEMA public CASCADE");
                executeUpdate("CREATE SCHEMA public");
                executeUpdate("GRANT ALL ON SCHEMA public TO public");
            } catch (RuntimeException ex) {
                // The caller never receives this instance, so nobody else could close the pool.
                this.dataSource.close();
                throw ex;
            }
        }
    }

    /** Closes the connection pool. */
    @Override
    public void close() {
        dataSource.close();
    }

    /**
     * Borrows a connection from the pool. The caller must close it to return it to the pool.
     *
     * @return a pooled connection
     * @throws RuntimeException if the pool cannot supply a connection
     */
    @Override
    public Connection getConnection() {
        try {
            return dataSource.getConnection();
        } catch (SQLException ex) {
            throw new RuntimeException("Failed to get connection from pool.", ex);
        }
    }

    /** @return always {@code true}; bulk loading goes through {@code COPY ... FROM STDIN}. */
    @Override
    public boolean supportsBulkCopy() {
        return true;
    }

    /**
     * Loads a delimited file into a predicate's table using PostgreSQL's {@code COPY ... FROM STDIN}.
     *
     * <p>The file supplies only the argument columns (plus {@code value} when {@code hasTruth}),
     * so {@code partition_id} is filled from a column default. That default is set for the
     * duration of the copy and dropped afterwards, even if the copy fails.
     *
     * @param path local path of the file streamed to the server
     * @param delimiter the column delimiter; single quotes are escaped for the SQL literal
     * @param hasTruth whether each row ends with a truth value for the {@code value} column
     * @param predicateInfo describes the target table and its argument columns
     * @param partition the partition every copied row is assigned to
     * @throws RuntimeException if the copy fails or the file cannot be read
     */
    @Override
    public void bulkCopy(String path, String delimiter, boolean hasTruth, PredicateInfo predicateInfo, Partition partition) {
        String tableName = predicateInfo.tableName();
        String sql = String.format("COPY %s(%s%s) FROM STDIN WITH DELIMITER '%s'",
                tableName,
                StringUtils.join(predicateInfo.argumentColumns(), ", "),
                hasTruth ? ", value" : "",
                // Double any single quote so the delimiter cannot terminate the SQL literal early.
                delimiter.replace("'", "''"));

        // NOTE(review): this is table-level DDL, so a concurrent insert into the same table that
        // omits partition_id would also receive this default while the copy runs.
        setColumnDefault(tableName, "partition_id", "'" + partition.getID() + "'");
        try (Connection connection = getConnection();
                FileInputStream inFile = new FileInputStream(path)) {
            PGConnection pgConnection = connection.unwrap(PGConnection.class);
            pgConnection.getCopyAPI().copyIn(sql, inFile);
        } catch (SQLException ex) {
            throw new RuntimeException("Could not perform bulk insert on " + predicateInfo.predicate(), ex);
        } catch (IOException ex) {
            throw new RuntimeException("Error bulk copying file: " + path, ex);
        } finally {
            dropColumnDefault(tableName, "partition_id");
        }
    }

    /**
     * Runs {@code ALTER TABLE ... ALTER COLUMN ... SET DEFAULT ...}.
     *
     * @param tableName the table to alter (interpolated verbatim into the SQL)
     * @param columnName the column to alter (interpolated verbatim into the SQL)
     * @param defaultValue a SQL expression, already quoted if it is a literal
     * @throws RuntimeException if the statement fails
     */
    public void setColumnDefault(String tableName, String columnName, String defaultValue) {
        String sql = String.format("ALTER TABLE %s ALTER COLUMN %s SET DEFAULT %s", tableName, columnName, defaultValue);
        try (Connection connection = getConnection();
                PreparedStatement statement = connection.prepareStatement(sql)) {
            statement.executeUpdate();
        } catch (SQLException ex) {
            throw new RuntimeException(String.format("Could not set the column default of %s for %s.%s.", defaultValue, tableName, columnName), ex);
        }
    }

    /**
     * Runs {@code ALTER TABLE ... ALTER COLUMN ... DROP DEFAULT}.
     *
     * @param tableName the table to alter (interpolated verbatim into the SQL)
     * @param columnName the column to alter (interpolated verbatim into the SQL)
     * @throws RuntimeException if the statement fails
     */
    public void dropColumnDefault(String tableName, String columnName) {
        String sql = String.format("ALTER TABLE %s ALTER COLUMN %s DROP DEFAULT", tableName, columnName);
        try (Connection connection = getConnection();
                PreparedStatement statement = connection.prepareStatement(sql)) {
            statement.executeUpdate();
        } catch (SQLException ex) {
            throw new RuntimeException(String.format("Could not drop the column default for %s.%s.", tableName, columnName), ex);
        }
    }

    /**
     * Maps a PSL constant type to the PostgreSQL column type used to store it.
     *
     * @param type the constant type
     * @return the PostgreSQL type name
     * @throws IllegalStateException if {@code type} has no mapping
     */
    @Override
    public String getTypeName(ConstantType type) {
        switch (type) {
            case Double:
                return "DOUBLE PRECISION";
            case Integer:
            case UniqueIntID:
                return "INT";
            case String:
            case UniqueStringID:
                return "TEXT";
            case Long:
                return "BIGINT";
            case Date:
                return "DATE";
            default:
                throw new IllegalStateException("Unknown ConstantType: " + type);
        }
    }

    /**
     * @param columnName the key column name
     * @return an auto-incrementing ({@code SERIAL}) primary key column definition
     */
    @Override
    public String getSurrogateKeyColumnDefinition(String columnName) {
        return columnName + " SERIAL PRIMARY KEY";
    }

    /** @return the PostgreSQL name of the double-precision floating point type */
    @Override
    public String getDoubleTypeName() {
        return "DOUBLE PRECISION";
    }

    /**
     * Builds a parameterized {@code INSERT ... ON CONFLICT (...) DO UPDATE SET ...} statement.
     *
     * <p>The statement has one {@code ?} placeholder per entry in {@code columns}, in order.
     * On a key conflict, every listed column is overwritten with the incoming value.
     *
     * @param tableName the target table
     * @param columns the columns to insert, and to update on conflict
     * @param keyColumns the columns forming the conflict target
     * @return the SQL text, one clause per line
     */
    @Override
    public String getUpsert(String tableName, String[] columns, String[] keyColumns) {
        ArrayList<String> updateValues = new ArrayList<String>(columns.length);
        for (String column : columns) {
            updateValues.add(String.format("%s = EXCLUDED.%s", column, column));
        }

        ArrayList<String> sql = new ArrayList<String>();
        sql.add("INSERT INTO " + tableName);
        sql.add("\t(" + StringUtils.join((Object[])columns, ", ") + ")");
        sql.add("VALUES");
        sql.add("\t(" + StringUtils.repeat("?", ", ", columns.length) + ")");
        sql.add("ON CONFLICT");
        sql.add("\t(" + StringUtils.join((Object[])keyColumns, ", ") + ")");
        sql.add("DO UPDATE SET");
        sql.add("\t" + StringUtils.join(updateValues, ", "));

        return StringUtils.join(sql, "\n");
    }

    /**
     * Executes a single update statement on a pooled connection.
     *
     * @param sql the statement to run
     * @throws RuntimeException wrapping any {@link SQLException}
     */
    private void executeUpdate(String sql) {
        try (Connection connection = getConnection();
                Statement stmt = connection.createStatement()) {
            stmt.executeUpdate(sql);
        } catch (SQLException ex) {
            throw new RuntimeException("Failed to execute a general update: [" + sql + "].", ex);
        }
    }

    /**
     * Validates and renders a CREATE TABLE query, emitting it as {@code CREATE UNLOGGED TABLE}.
     * Unlogged tables skip PostgreSQL's write-ahead log (presumably chosen here for load speed)
     * and are not crash-safe.
     *
     * @param createTable the query to render
     * @return the SQL text
     */
    @Override
    public String finalizeCreateTable(CreateTableQuery createTable) {
        return createTable.validate().toString().replace("CREATE TABLE", "CREATE UNLOGGED TABLE");
    }

    /**
     * Builds a {@code STRING_AGG} expression over a column cast to text.
     *
     * @param columnName the column (or expression) to aggregate
     * @param delimiter the separator placed between values; must not contain a single quote
     * @param distinct whether duplicate values should be aggregated only once
     * @return the SQL aggregate expression
     * @throws IllegalArgumentException if {@code delimiter} contains a single quote
     */
    @Override
    public String getStringAggregate(String columnName, String delimiter, boolean distinct) {
        if (delimiter.contains("'")) {
            throw new IllegalArgumentException("Delimiter (" + delimiter + ") may not contain a single quote.");
        }

        // Honor the caller's choice; previously DISTINCT was emitted unconditionally.
        return String.format("STRING_AGG(%sCAST(%s AS TEXT), '%s')",
                distinct ? "DISTINCT " : "", columnName, delimiter);
    }
}

