Skip to content

Commit

Permalink
feat: [#280] Add some methods for Schema (#747)
Browse files Browse the repository at this point in the history
* feat: [#280] Add some methods for Schema

* Add test

* Add test cases

* Update database/schema/sqlite_schema.go

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
  • Loading branch information
hwbrzzl and coderabbitai[bot] authored Dec 6, 2024
1 parent 814ad23 commit 62b24a9
Show file tree
Hide file tree
Showing 17 changed files with 684 additions and 112 deletions.
2 changes: 2 additions & 0 deletions contracts/database/schema/blueprint.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ type Blueprint interface {
MediumText(column string) ColumnDefinition
// Primary Specify the primary key(s) for the table.
Primary(column ...string)
// Rename the table to a given name.
Rename(to string)
// RenameIndex Indicate that the given indexes should be renamed.
RenameIndex(from, to string)
// SetTable Set the table that the blueprint operates on.
Expand Down
2 changes: 2 additions & 0 deletions contracts/database/schema/grammar.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ type Grammar interface {
CompileIndexes(schema, table string) string
// CompilePrimary Compile a primary key command.
CompilePrimary(blueprint Blueprint, command *Command) string
// CompileRename Compile a rename table command.
CompileRename(blueprint Blueprint, command *Command) string
// CompileRenameIndex Compile a rename index command.
CompileRenameIndex(schema Schema, blueprint Blueprint, command *Command) []string
// CompileTables Compile the query to determine the tables.
Expand Down
16 changes: 14 additions & 2 deletions contracts/database/schema/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ type Schema interface {
Connection(name string) Schema
// Create a new table on the schema.
Create(table string, callback func(table Blueprint)) error
// Drop a table from the schema.
Drop(table string) error
// DropColumns Drop columns from a table on the schema.
DropColumns(table string, columns []string) error
// DropIfExists Drop a table from the schema if exists.
DropIfExists(table string) error
// GetColumnListing Get the column listing for a given table.
Expand All @@ -21,24 +25,32 @@ type Schema interface {
GetForeignKeys(table string) ([]ForeignKey, error)
// GetIndexListing Get the names of the indexes for a given table.
GetIndexListing(table string) []string
// GetTableListing Get the table listing for the database.
GetTableListing() []string
// HasColumn Determine if the given table has a given column.
HasColumn(table, column string) bool
// HasColumns Determine if the given table has given columns.
HasColumns(table string, columns []string) bool
// HasIndex Determine if the given table has a given index.
HasIndex(table, index string) bool
// HasTable Determine if the given table exists.
HasTable(table string) bool
HasTable(name string) bool
// HasType Determine if the given type exists.
HasType(name string) bool
// HasView Determine if the given view exists.
HasView(name string) bool
// Migrations Get the migrations.
Migrations() []Migration
// Orm Get the orm instance.
Orm() orm.Orm
// Register migrations.
Register([]Migration)
// Rename a table on the schema.
Rename(from, to string) error
// SetConnection Set the connection of the schema.
SetConnection(name string)
// Sql Execute a sql directly.
Sql(sql string)
Sql(sql string) error
// Table Modify a table on the schema.
Table(table string, callback func(table Blueprint)) error
}
Expand Down
11 changes: 11 additions & 0 deletions database/schema/blueprint.go
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,15 @@ func (r *Blueprint) Primary(column ...string) {
r.indexCommand(constants.CommandPrimary, column)
}

func (r *Blueprint) Rename(to string) {
command := &schema.Command{
Name: constants.CommandRename,
To: to,
}

r.addCommand(command)
}

func (r *Blueprint) RenameIndex(from, to string) {
command := &schema.Command{
Name: constants.CommandRenameIndex,
Expand Down Expand Up @@ -447,6 +456,8 @@ func (r *Blueprint) ToSql(grammar schema.Grammar) []string {
statements = append(statements, grammar.CompileIndex(r, command))
case constants.CommandPrimary:
statements = append(statements, grammar.CompilePrimary(r, command))
case constants.CommandRename:
statements = append(statements, grammar.CompileRename(r, command))
case constants.CommandRenameIndex:
statements = append(statements, grammar.CompileRenameIndex(r.schema, r, command)...)
case constants.CommandUnique:
Expand Down
1 change: 1 addition & 0 deletions database/schema/constants/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ const (
CommandFullText = "fullText"
CommandIndex = "index"
CommandPrimary = "primary"
CommandRename = "rename"
CommandRenameIndex = "renameIndex"
CommandUnique = "unique"
DefaultStringLength = 255
Expand Down
4 changes: 4 additions & 0 deletions database/schema/grammars/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,10 @@ func (r *Mysql) CompilePrimary(blueprint schema.Blueprint, command *schema.Comma
return fmt.Sprintf("alter table %s add primary key %s(%s)", r.wrap.Table(blueprint.GetTableName()), algorithm, r.wrap.Columnize(command.Columns))
}

func (r *Mysql) CompileRename(blueprint schema.Blueprint, command *schema.Command) string {
return fmt.Sprintf("rename table %s to %s", r.wrap.Table(blueprint.GetTableName()), r.wrap.Table(command.To))
}

func (r *Mysql) CompileRenameIndex(_ schema.Schema, blueprint schema.Blueprint, command *schema.Command) []string {
return []string{
fmt.Sprintf("alter table %s rename index %s to %s", r.wrap.Table(blueprint.GetTableName()), r.wrap.Column(command.From), r.wrap.Column(command.To)),
Expand Down
4 changes: 4 additions & 0 deletions database/schema/grammars/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,10 @@ func (r *Postgres) CompilePrimary(blueprint schema.Blueprint, command *schema.Co
return fmt.Sprintf("alter table %s add primary key (%s)", r.wrap.Table(blueprint.GetTableName()), r.wrap.Columnize(command.Columns))
}

func (r *Postgres) CompileRename(blueprint schema.Blueprint, command *schema.Command) string {
return fmt.Sprintf("alter table %s rename to %s", r.wrap.Table(blueprint.GetTableName()), r.wrap.Table(command.To))
}

func (r *Postgres) CompileRenameIndex(_ schema.Schema, _ schema.Blueprint, command *schema.Command) []string {
return []string{
fmt.Sprintf("alter index %s rename to %s", r.wrap.Column(command.From), r.wrap.Column(command.To)),
Expand Down
4 changes: 4 additions & 0 deletions database/schema/grammars/sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,10 @@ func (r *Sqlite) CompileRebuild() string {
return "vacuum"
}

func (r *Sqlite) CompileRename(blueprint schema.Blueprint, command *schema.Command) string {
return fmt.Sprintf("alter table %s rename to %s", r.wrap.Table(blueprint.GetTableName()), r.wrap.Table(command.To))
}

func (r *Sqlite) CompileRenameIndex(s schema.Schema, blueprint schema.Blueprint, command *schema.Command) []string {
indexes, err := s.GetIndexes(blueprint.GetTableName())
if err != nil {
Expand Down
50 changes: 27 additions & 23 deletions database/schema/grammars/sqlserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func (r *Sqlserver) CompileColumns(schema, table string) string {
"order by col.column_id", r.wrap.Quote(table), newSchema)
}

func (r *Sqlserver) CompileComment(blueprint schema.Blueprint, command *schema.Command) string {
func (r *Sqlserver) CompileComment(_ schema.Blueprint, _ *schema.Command) string {
return ""
}

Expand All @@ -72,7 +72,7 @@ func (r *Sqlserver) CompileDrop(blueprint schema.Blueprint) string {
return fmt.Sprintf("drop table %s", r.wrap.Table(blueprint.GetTableName()))
}

func (r *Sqlserver) CompileDropAllDomains(domains []string) string {
func (r *Sqlserver) CompileDropAllDomains(_ []string) string {
return ""
}

Expand All @@ -86,15 +86,15 @@ func (r *Sqlserver) CompileDropAllForeignKeys() string {
EXEC sp_executesql @sql;`
}

func (r *Sqlserver) CompileDropAllTables(tables []string) string {
func (r *Sqlserver) CompileDropAllTables(_ []string) string {
return "EXEC sp_msforeachtable 'DROP TABLE ?'"
}

func (r *Sqlserver) CompileDropAllTypes(types []string) string {
func (r *Sqlserver) CompileDropAllTypes(_ []string) string {
return ""
}

func (r *Sqlserver) CompileDropAllViews(views []string) string {
func (r *Sqlserver) CompileDropAllViews(_ []string) string {
return `DECLARE @sql NVARCHAR(MAX) = N'';
SELECT @sql += 'DROP VIEW ' + QUOTENAME(OBJECT_SCHEMA_NAME(object_id)) + '.' + QUOTENAME(name) + ';'
FROM sys.views;
Expand Down Expand Up @@ -129,7 +129,7 @@ func (r *Sqlserver) CompileDropForeign(blueprint schema.Blueprint, command *sche
return fmt.Sprintf("alter table %s drop constraint %s", r.wrap.Table(blueprint.GetTableName()), r.wrap.Column(command.Index))
}

func (r *Sqlserver) CompileDropFullText(blueprint schema.Blueprint, command *schema.Command) string {
func (r *Sqlserver) CompileDropFullText(_ schema.Blueprint, _ *schema.Command) string {
return ""
}

Expand Down Expand Up @@ -238,13 +238,17 @@ func (r *Sqlserver) CompilePrimary(blueprint schema.Blueprint, command *schema.C
r.wrap.Columnize(command.Columns))
}

func (r *Sqlserver) CompileRename(blueprint schema.Blueprint, command *schema.Command) string {
return fmt.Sprintf("sp_rename %s, %s", r.wrap.Quote(r.wrap.Table(blueprint.GetTableName())), r.wrap.Table(command.To))
}

func (r *Sqlserver) CompileRenameIndex(_ schema.Schema, blueprint schema.Blueprint, command *schema.Command) []string {
return []string{
fmt.Sprintf("sp_rename %s, %s, N'INDEX'", r.wrap.Quote(r.wrap.Table(blueprint.GetTableName())+"."+r.wrap.Column(command.From)), r.wrap.Column(command.To)),
}
}

func (r *Sqlserver) CompileTables(database string) string {
func (r *Sqlserver) CompileTables(_ string) string {
return "select t.name as name, schema_name(t.schema_id) as [schema], sum(u.total_pages) * 8 * 1024 as size " +
"from sys.tables as t " +
"join sys.partitions as p on p.object_id = t.object_id " +
Expand All @@ -264,7 +268,7 @@ func (r *Sqlserver) CompileUnique(blueprint schema.Blueprint, command *schema.Co
r.wrap.Columnize(command.Columns))
}

func (r *Sqlserver) CompileViews(database string) string {
func (r *Sqlserver) CompileViews(_ string) string {
return "select name, schema_name(v.schema_id) as [schema], definition from sys.views as v " +
"inner join sys.sql_modules as m on v.object_id = m.object_id " +
"order by name"
Expand All @@ -274,15 +278,15 @@ func (r *Sqlserver) GetAttributeCommands() []string {
return r.attributeCommands
}

func (r *Sqlserver) ModifyDefault(blueprint schema.Blueprint, column schema.ColumnDefinition) string {
func (r *Sqlserver) ModifyDefault(_ schema.Blueprint, column schema.ColumnDefinition) string {
if column.GetDefault() != nil {
return fmt.Sprintf(" default %s", getDefaultValue(column.GetDefault()))
}

return ""
}

func (r *Sqlserver) ModifyNullable(blueprint schema.Blueprint, column schema.ColumnDefinition) string {
func (r *Sqlserver) ModifyNullable(_ schema.Blueprint, column schema.ColumnDefinition) string {
if column.GetNullable() {
return " null"
} else {
Expand All @@ -301,15 +305,15 @@ func (r *Sqlserver) ModifyIncrement(blueprint schema.Blueprint, column schema.Co
return ""
}

func (r *Sqlserver) TypeBigInteger(column schema.ColumnDefinition) string {
func (r *Sqlserver) TypeBigInteger(_ schema.ColumnDefinition) string {
return "bigint"
}

func (r *Sqlserver) TypeChar(column schema.ColumnDefinition) string {
return fmt.Sprintf("nchar(%d)", column.GetLength())
}

func (r *Sqlserver) TypeDate(column schema.ColumnDefinition) string {
func (r *Sqlserver) TypeDate(_ schema.ColumnDefinition) string {
return "date"
}

Expand All @@ -325,7 +329,7 @@ func (r *Sqlserver) TypeDecimal(column schema.ColumnDefinition) string {
return fmt.Sprintf("decimal(%d, %d)", column.GetTotal(), column.GetPlaces())
}

func (r *Sqlserver) TypeDouble(column schema.ColumnDefinition) string {
func (r *Sqlserver) TypeDouble(_ schema.ColumnDefinition) string {
return "double precision"
}

Expand All @@ -342,31 +346,31 @@ func (r *Sqlserver) TypeFloat(column schema.ColumnDefinition) string {
return "float"
}

func (r *Sqlserver) TypeInteger(column schema.ColumnDefinition) string {
func (r *Sqlserver) TypeInteger(_ schema.ColumnDefinition) string {
return "int"
}

func (r *Sqlserver) TypeJson(column schema.ColumnDefinition) string {
func (r *Sqlserver) TypeJson(_ schema.ColumnDefinition) string {
return "nvarchar(max)"
}

func (r *Sqlserver) TypeJsonb(column schema.ColumnDefinition) string {
func (r *Sqlserver) TypeJsonb(_ schema.ColumnDefinition) string {
return "nvarchar(max)"
}

func (r *Sqlserver) TypeLongText(column schema.ColumnDefinition) string {
func (r *Sqlserver) TypeLongText(_ schema.ColumnDefinition) string {
return "nvarchar(max)"
}

func (r *Sqlserver) TypeMediumInteger(column schema.ColumnDefinition) string {
func (r *Sqlserver) TypeMediumInteger(_ schema.ColumnDefinition) string {
return "int"
}

func (r *Sqlserver) TypeMediumText(column schema.ColumnDefinition) string {
func (r *Sqlserver) TypeMediumText(_ schema.ColumnDefinition) string {
return "nvarchar(max)"
}

func (r *Sqlserver) TypeSmallInteger(column schema.ColumnDefinition) string {
func (r *Sqlserver) TypeSmallInteger(_ schema.ColumnDefinition) string {
return "smallint"
}

Expand All @@ -379,7 +383,7 @@ func (r *Sqlserver) TypeString(column schema.ColumnDefinition) string {
return "nvarchar(255)"
}

func (r *Sqlserver) TypeText(column schema.ColumnDefinition) string {
func (r *Sqlserver) TypeText(_ schema.ColumnDefinition) string {
return "nvarchar(max)"
}

Expand Down Expand Up @@ -419,11 +423,11 @@ func (r *Sqlserver) TypeTimestampTz(column schema.ColumnDefinition) string {
}
}

func (r *Sqlserver) TypeTinyInteger(column schema.ColumnDefinition) string {
func (r *Sqlserver) TypeTinyInteger(_ schema.ColumnDefinition) string {
return "tinyint"
}

func (r *Sqlserver) TypeTinyText(column schema.ColumnDefinition) string {
func (r *Sqlserver) TypeTinyText(_ schema.ColumnDefinition) string {
return "nvarchar(255)"
}

Expand Down
8 changes: 2 additions & 6 deletions database/schema/postgres_schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ func (r *PostgresSchema) DropAllTables() error {
}

func (r *PostgresSchema) DropAllTypes() error {
schema := r.grammar.EscapeNames([]string{r.schema})[0]
types, err := r.GetTypes()
if err != nil {
return err
Expand All @@ -71,7 +70,7 @@ func (r *PostgresSchema) DropAllTypes() error {
var dropTypes, dropDomains []string

for _, t := range types {
if !t.Implicit && schema == t.Schema {
if !t.Implicit && r.schema == t.Schema {
if t.Type == "domain" {
dropDomains = append(dropDomains, fmt.Sprintf("%s.%s", t.Schema, t.Name))
} else {
Expand All @@ -98,20 +97,17 @@ func (r *PostgresSchema) DropAllTypes() error {
}

func (r *PostgresSchema) DropAllViews() error {
schema := r.grammar.EscapeNames([]string{r.schema})[0]

views, err := r.GetViews()
if err != nil {
return err
}

var dropViews []string
for _, view := range views {
if schema == view.Schema {
if r.schema == view.Schema {
dropViews = append(dropViews, fmt.Sprintf("%s.%s", view.Schema, view.Name))
}
}

if len(dropViews) == 0 {
return nil
}
Expand Down
Loading

0 comments on commit 62b24a9

Please sign in to comment.