Skip to content

Commit

Permalink
Add deployments subcommand (#95)
Browse files Browse the repository at this point in the history
* Update to replicate-go v0.20.0

* Add deployments subcommand
  • Loading branch information
mattt committed Jun 20, 2024
1 parent dd5e7c4 commit 2a8875a
Show file tree
Hide file tree
Showing 10 changed files with 614 additions and 4 deletions.
2 changes: 2 additions & 0 deletions cmd/replicate/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/replicate/cli/internal/cmd"
"github.com/replicate/cli/internal/cmd/account"
"github.com/replicate/cli/internal/cmd/auth"
"github.com/replicate/cli/internal/cmd/deployment"
"github.com/replicate/cli/internal/cmd/hardware"
"github.com/replicate/cli/internal/cmd/model"
"github.com/replicate/cli/internal/cmd/prediction"
Expand Down Expand Up @@ -41,6 +42,7 @@ func init() {
model.RootCmd,
prediction.RootCmd,
training.RootCmd,
deployment.RootCmd,
hardware.RootCmd,
cmd.ScaffoldCmd,
} {
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ require (
github.com/cli/browser v1.3.0
github.com/getkin/kin-openapi v0.125.0
github.com/mattn/go-isatty v0.0.20
github.com/replicate/replicate-go v0.18.1
github.com/replicate/replicate-go v0.20.0
github.com/schollz/progressbar/v3 v3.14.4
github.com/spf13/cobra v1.8.0
github.com/stretchr/testify v1.9.0
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ github.com/perimeterx/marshmallow v1.1.5 h1:a2LALqQ1BlHM8PZblsDdidgv1mWi1DgC2UmX
github.com/perimeterx/marshmallow v1.1.5/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/replicate/replicate-go v0.18.1 h1:4zduLVJxdQAoyl7zKj1e2nxwJVMcT6O/sXe6/eUEtns=
github.com/replicate/replicate-go v0.18.1/go.mod h1:D2x8SztjeUKcaYnSgVu3H2DechufLJWZJB4+TLA3Rag=
github.com/replicate/replicate-go v0.20.0 h1:ujksgJCyJMuRdXjtoRe6wA08NmCTS16LP/x7UtvSLRE=
github.com/replicate/replicate-go v0.20.0/go.mod h1:D2x8SztjeUKcaYnSgVu3H2DechufLJWZJB4+TLA3Rag=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
Expand Down
114 changes: 114 additions & 0 deletions internal/cmd/deployment/create.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package deployment

import (
"encoding/json"
"fmt"

"github.com/cli/browser"
"github.com/replicate/replicate-go"
"github.com/spf13/cobra"

"github.com/replicate/cli/internal/client"
"github.com/replicate/cli/internal/identifier"
"github.com/replicate/cli/internal/util"
)

// createCmd represents the create command
var createCmd = &cobra.Command{
Use: "create <[owner/]name> [flags]",
Short: "Create a new deployment",
Example: `replicate deployment create text-to-image --model=stability-ai/sdxl --hardware=gpu-a100-large`,
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
r8, err := client.NewClient()
if err != nil {
return err
}

opts := &replicate.CreateDeploymentOptions{}

opts.Name = args[0]

flags := cmd.Flags()

modelFlag, _ := flags.GetString("model")
id, err := identifier.ParseIdentifier(modelFlag)
if err != nil {
return fmt.Errorf("expected <owner>/<name>[:version] but got %s", args[0])
}
opts.Model = fmt.Sprintf("%s/%s", id.Owner, id.Name)
if id.Version != "" {
opts.Version = id.Version
} else {
model, err := r8.GetModel(cmd.Context(), id.Owner, id.Name)
if err != nil {
return fmt.Errorf("failed to get model: %w", err)
}
opts.Version = model.LatestVersion.ID
}

opts.Hardware, _ = flags.GetString("hardware")

flagMap := map[string]*int{
"min-instances": &opts.MinInstances,
"max-instances": &opts.MaxInstances,
}
for flagName, optPtr := range flagMap {
if flags.Changed(flagName) {
value, _ := flags.GetInt(flagName)
*optPtr = value
}
}

deployment, err := r8.CreateDeployment(cmd.Context(), *opts)
if err != nil {
return fmt.Errorf("failed to create deployment: %w", err)
}

if flags.Changed("json") || !util.IsTTY() {
bytes, err := json.MarshalIndent(deployment, "", " ")
if err != nil {
return fmt.Errorf("failed to serialize model: %w", err)
}
fmt.Println(string(bytes))
return nil
}

url := fmt.Sprintf("https://replicate.com/deployments/%s/%s", deployment.Owner, deployment.Name)
if flags.Changed("web") {
if util.IsTTY() {
fmt.Println("Opening in browser...")
}

err := browser.OpenURL(url)
if err != nil {
return fmt.Errorf("failed to open browser: %w", err)
}

return nil
}

fmt.Printf("Deployment created: %s\n", url)

return nil
},
}

func init() {
addCreateFlags(createCmd)
}

func addCreateFlags(cmd *cobra.Command) {
cmd.Flags().String("model", "", "Model to deploy")
_ = cmd.MarkFlagRequired("model")

cmd.Flags().String("hardware", "", "SKU of the hardware to run the model")
_ = cmd.MarkFlagRequired("hardware")

cmd.Flags().Int("min-instances", 0, "Minimum number of instances to run the model")
cmd.Flags().Int("max-instances", 0, "Maximum number of instances to run the model")

cmd.Flags().Bool("json", false, "Emit JSON")
cmd.Flags().Bool("web", false, "View on web")
cmd.MarkFlagsMutuallyExclusive("json", "web")
}
135 changes: 135 additions & 0 deletions internal/cmd/deployment/list.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
package deployment

import (
"encoding/json"
"fmt"
"os/exec"
"strconv"

"github.com/spf13/cobra"

"github.com/replicate/cli/internal/client"
"github.com/replicate/cli/internal/util"

"github.com/charmbracelet/bubbles/table"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)

var baseStyle = lipgloss.NewStyle().
BorderStyle(lipgloss.NormalBorder()).
BorderForeground(lipgloss.Color("240"))

type model struct {
table table.Model
}

func (m model) Init() tea.Cmd { return nil }

func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmd tea.Cmd
switch msg := msg.(type) { //nolint:gocritic
case tea.KeyMsg:
switch msg.String() {
case "esc":
if m.table.Focused() {
m.table.Blur()
} else {
m.table.Focus()
}
case "q", "ctrl+c":
return m, tea.Quit
case "enter":
selected := m.table.SelectedRow()
if len(selected) == 0 {
return m, nil
}
url := fmt.Sprintf("https://replicate.com/deployments/%s", selected[0])
return m, tea.ExecProcess(exec.Command("open", url), nil)
}
}
m.table, cmd = m.table.Update(msg)
return m, cmd
}

func (m model) View() string {
return baseStyle.Render(m.table.View()) + "\n"
}

var listCmd = &cobra.Command{
Use: "list",
Short: "List deployments",
Example: "replicate deployment list",
RunE: func(cmd *cobra.Command, _ []string) error {
ctx := cmd.Context()

r8, err := client.NewClient()
if err != nil {
return err
}

deployments, err := r8.ListDeployments(ctx)
if err != nil {
return fmt.Errorf("failed to get deployments: %w", err)
}

if cmd.Flags().Changed("json") || !util.IsTTY() {
bytes, err := json.MarshalIndent(deployments, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal deployments: %w", err)
}
fmt.Println(string(bytes))
return nil
}

columns := []table.Column{
{Title: "Name", Width: 20},
{Title: "Release #", Width: 10},
{Title: "Model Version", Width: 60},
}

rows := []table.Row{}

for _, deployment := range deployments.Results {
rows = append(rows, table.Row{
deployment.Owner + "/" + deployment.Name,
strconv.Itoa(deployment.CurrentRelease.Number),
fmt.Sprintf("%s:%s", deployment.CurrentRelease.Model, deployment.CurrentRelease.Version),
})
}

t := table.New(
table.WithColumns(columns),
table.WithRows(rows),
table.WithFocused(true),
table.WithHeight(30),
)

s := table.DefaultStyles()
s.Header = s.Header.
BorderStyle(lipgloss.NormalBorder()).
BorderForeground(lipgloss.Color("240")).
BorderBottom(true).
Bold(false)
s.Selected = s.Selected.
Foreground(lipgloss.Color("229")).
Background(lipgloss.Color("57")).
Bold(false)
t.SetStyles(s)

m := model{t}
if _, err := tea.NewProgram(m).Run(); err != nil {
return err
}

return nil
},
}

func init() {
addListFlags(listCmd)
}

func addListFlags(cmd *cobra.Command) {
cmd.Flags().Bool("json", false, "Emit JSON")
}
39 changes: 39 additions & 0 deletions internal/cmd/deployment/root.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package deployment

import (
"github.com/spf13/cobra"
)

var RootCmd = &cobra.Command{
Use: "deployments [subcommand]",
Short: "Interact with deployments",
Aliases: []string{"deployments", "d"},
}

func init() {
RootCmd.AddGroup(&cobra.Group{
ID: "subcommand",
Title: "Subcommands:",
})
for _, cmd := range []*cobra.Command{
listCmd,
showCmd,
schemaCmd,
createCmd,
updateCmd,
} {
RootCmd.AddCommand(cmd)
cmd.GroupID = "subcommand"
}

// RootCmd.AddGroup(&cobra.Group{
// ID: "alias",
// Title: "Alias commands:",
// })
// for _, cmd := range []*cobra.Command{
// runCmd,
// } {
// RootCmd.AddCommand(cmd)
// cmd.GroupID = "alias"
// }
}
Loading

0 comments on commit 2a8875a

Please sign in to comment.