diff --git a/source/aws-s3/README.md b/source/aws-s3/README.md index e69de29..3a59cfe 100644 --- a/source/aws-s3/README.md +++ b/source/aws-s3/README.md @@ -0,0 +1,3 @@ +# aws-s3 + +`s3:///` diff --git a/source/aws-s3/s3.go b/source/aws-s3/s3.go index 32c097a..8b58140 100644 --- a/source/aws-s3/s3.go +++ b/source/aws-s3/s3.go @@ -1 +1,125 @@ package awss3 + +import ( + "fmt" + "io" + "net/url" + "os" + "path" + "strings" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/s3" + "github.com/aws/aws-sdk-go/service/s3/s3iface" + "github.com/mattes/migrate/source" +) + +func init() { + source.Register("s3", &s3Driver{}) +} + +type s3Driver struct { + s3client s3iface.S3API + bucket string + prefix string + migrations *source.Migrations +} + +func (s *s3Driver) Open(folder string) (source.Driver, error) { + u, err := url.Parse(folder) + if err != nil { + return nil, err + } + sess, err := session.NewSession() + if err != nil { + return nil, err + } + driver := s3Driver{ + bucket: u.Host, + prefix: strings.Trim(u.Path, "/") + "/", + s3client: s3.New(sess), + migrations: source.NewMigrations(), + } + err = driver.loadMigrations() + if err != nil { + return nil, err + } + return &driver, nil +} + +func (s *s3Driver) loadMigrations() error { + output, err := s.s3client.ListObjects(&s3.ListObjectsInput{ + Bucket: aws.String(s.bucket), + Prefix: aws.String(s.prefix), + Delimiter: aws.String("/"), + }) + if err != nil { + return err + } + for _, object := range output.Contents { + _, fileName := path.Split(aws.StringValue(object.Key)) + m, err := source.DefaultParse(fileName) + if err != nil { + continue + } + if !s.migrations.Append(m) { + return fmt.Errorf("unable to parse file %v", aws.StringValue(object.Key)) + } + } + return nil +} + +func (s *s3Driver) Close() error { + return nil +} + +func (s *s3Driver) First() (uint, error) { + v, ok := s.migrations.First() + if !ok { + return 0, os.ErrNotExist + } + return v, nil +} + +func (s *s3Driver) Prev(version uint) (uint, error) { + v, ok := s.migrations.Prev(version) + if !ok { + return 0, os.ErrNotExist + } + return v, nil +} + +func (s *s3Driver) Next(version uint) (uint, error) { + v, ok := s.migrations.Next(version) + if !ok { + return 0, os.ErrNotExist + } + return v, nil +} + +func (s *s3Driver) ReadUp(version uint) (io.ReadCloser, string, error) { + if m, ok := s.migrations.Up(version); ok { + return s.open(m) + } + return nil, "", os.ErrNotExist +} + +func (s *s3Driver) ReadDown(version uint) (io.ReadCloser, string, error) { + if m, ok := s.migrations.Down(version); ok { + return s.open(m) + } + return nil, "", os.ErrNotExist +} + +func (s *s3Driver) open(m *source.Migration) (io.ReadCloser, string, error) { + key := path.Join(s.prefix, m.Raw) + object, err := s.s3client.GetObject(&s3.GetObjectInput{ + Bucket: aws.String(s.bucket), + Key: aws.String(key), + }) + if err != nil { + return nil, "", err + } + return object.Body, m.Identifier, nil +} diff --git a/source/aws-s3/s3_test.go b/source/aws-s3/s3_test.go index 32c097a..f07d7ff 100644 --- a/source/aws-s3/s3_test.go +++ b/source/aws-s3/s3_test.go @@ -1 +1,82 @@ package awss3 + +import ( + "errors" + "io/ioutil" + "strings" + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/s3" + "github.com/mattes/migrate/source" + st "github.com/mattes/migrate/source/testing" +) + +func Test(t *testing.T) { + s3Client := fakeS3{ + bucket: "some-bucket", + objects: map[string]string{ + "staging/migrations/1_foobar.up.sql": "1 up", + "staging/migrations/1_foobar.down.sql": "1 down", + "prod/migrations/1_foobar.up.sql": "1 up", + "prod/migrations/1_foobar.down.sql": "1 down", + "prod/migrations/3_foobar.up.sql": "3 up", + "prod/migrations/4_foobar.up.sql": "4 up", + "prod/migrations/4_foobar.down.sql": "4 down", + "prod/migrations/5_foobar.down.sql": "5 down", + "prod/migrations/7_foobar.up.sql": "7 up", + "prod/migrations/7_foobar.down.sql": "7 down", + "prod/migrations/not-a-migration.txt": "", + "prod/migrations/0-random-stuff/whatever.txt": "", + }, + } + driver := s3Driver{ + bucket: "some-bucket", + prefix: "prod/migrations/", + migrations: source.NewMigrations(), + s3client: &s3Client, + } + err := driver.loadMigrations() + if err != nil { + t.Fatal(err) + } + st.Test(t, &driver) +} + +type fakeS3 struct { + s3.S3 + bucket string + objects map[string]string +} + +func (s *fakeS3) ListObjects(input *s3.ListObjectsInput) (*s3.ListObjectsOutput, error) { + bucket := aws.StringValue(input.Bucket) + if bucket != s.bucket { + return nil, errors.New("bucket not found") + } + prefix := aws.StringValue(input.Prefix) + delimiter := aws.StringValue(input.Delimiter) + var output s3.ListObjectsOutput + for name := range s.objects { + if strings.HasPrefix(name, prefix) { + if delimiter == "" || !strings.Contains(strings.Replace(name, prefix, "", 1), delimiter) { + output.Contents = append(output.Contents, &s3.Object{ + Key: aws.String(name), + }) + } + } + } + return &output, nil +} + +func (s *fakeS3) GetObject(input *s3.GetObjectInput) (*s3.GetObjectOutput, error) { + bucket := aws.StringValue(input.Bucket) + if bucket != s.bucket { + return nil, errors.New("bucket not found") + } + if data, ok := s.objects[aws.StringValue(input.Key)]; ok { + body := ioutil.NopCloser(strings.NewReader(data)) + return &s3.GetObjectOutput{Body: body}, nil + } + return nil, errors.New("object not found") +}