Skip to content

Commit

Permalink
Fix checking out when user says no
Browse files Browse the repository at this point in the history
  • Loading branch information
bfirsh committed Mar 11, 2021
1 parent 673ea65 commit fe817e2
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 14 deletions.
99 changes: 98 additions & 1 deletion end-to-end-test/end_to_end_test/test_checkout.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import subprocess
import pytest # type: ignore
import shlex
from .utils import get_env


Expand Down Expand Up @@ -88,10 +89,11 @@ def main():
# checking out experiment
output_dir = str(tmpdir_factory.mktemp("output"))
subprocess.run(
["keepsake", "checkout", "-o", output_dir, exp["id"]],
f"keepsake checkout -o {shlex.quote(output_dir)} {exp['id']}",
cwd=tmpdir,
env=env,
check=True,
shell=True,
)

with open(os.path.join(output_dir, rand, rand)) as f:
Expand Down Expand Up @@ -227,3 +229,98 @@ def main():
]
expected_paths = ["foo.txt"]
assert set(actual_paths) == set(expected_paths)


def test_checkout_when_files_exist(tmpdir, tmpdir_factory):
tmpdir = str(tmpdir)
repository = "file://" + str(tmpdir_factory.mktemp("repository"))

rand = str(random.randint(0, 100000))
os.mkdir(os.path.join(tmpdir, rand))
with open(os.path.join(tmpdir, rand, rand), "w") as f:
f.write(rand)

with open(os.path.join(tmpdir, "foo.txt"), "w") as f:
f.write("original")

with open(os.path.join(tmpdir, "keepsake.yaml"), "w") as f:
f.write(
"""
repository: {repository}
""".format(
repository=repository
)
)
with open(os.path.join(tmpdir, "train.py"), "w") as f:
f.write(
"""
import os
import keepsake
def main():
experiment = keepsake.init()
experiment.checkpoint(path="foo.txt")
if __name__ == "__main__":
main()
"""
)

env = get_env()
cmd = ["python", "train.py"]
subprocess.run(cmd, cwd=tmpdir, env=env, check=True)

experiments = json.loads(
subprocess.run(
["keepsake", "ls", "--json"],
cwd=tmpdir,
env=env,
capture_output=True,
check=True,
).stdout
)
assert len(experiments) == 1

exp = experiments[0]

with open(os.path.join(tmpdir, "foo.txt"), "w") as f:
f.write("new")

# stdin is closed
result = subprocess.run(
f"keepsake checkout {exp['id']}", cwd=tmpdir, env=env, check=False, shell=True,
)
assert result.returncode > 0

# Checkout does not overwrite
with open(os.path.join(tmpdir, "foo.txt")) as f:
assert f.read() == "new"

# don't overwrite
result = subprocess.run(
f"keepsake checkout {exp['id']}",
cwd=tmpdir,
env=env,
shell=True,
check=False,
input=b"n\n",
)
assert result.returncode > 0

# Checkout does not overwrite
with open(os.path.join(tmpdir, "foo.txt")) as f:
assert f.read() == "new"

# do overwrite
subprocess.run(
f"keepsake checkout {exp['id']}",
cwd=tmpdir,
env=env,
check=True,
shell=True,
input=b"y\n",
)

# Checkout does overwrite!
with open(os.path.join(tmpdir, "foo.txt")) as f:
assert f.read() == "original"
16 changes: 8 additions & 8 deletions go/pkg/cli/checkout.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,15 +119,15 @@ func overwriteDisplayPathPrompt(displayPath string, force bool) error {
fmt.Println()
// This is scary! See https://github.com/replicate/keepsake/issues/300
doOverwrite, err := console.InteractiveBool{
Prompt: "Do you want to continue?",
Default: false,
Prompt: "Do you want to continue?",
Default: false,
NonDefaultFlag: "-f",
}.Read()
if err != nil {
return err
}
if !doOverwrite {
console.Info("Aborting.")
return nil
return fmt.Errorf("Aborting.")
}
}
} else if !force {
Expand All @@ -137,15 +137,15 @@ func overwriteDisplayPathPrompt(displayPath string, force bool) error {
fmt.Println()
// This is scary! See https://github.com/replicate/keepsake/issues/300
doOverwrite, err := console.InteractiveBool{
Prompt: "Do you want to continue?",
Default: false,
Prompt: "Do you want to continue?",
Default: false,
NonDefaultFlag: "-f",
}.Read()
if err != nil {
return err
}
if !doOverwrite {
console.Info("Aborting.")
return nil
return fmt.Errorf("Aborting.")
}
}
}
Expand Down
7 changes: 4 additions & 3 deletions go/pkg/cli/rm.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,15 @@ func removeExperimentOrCheckpoint(cmd *cobra.Command, prefixes []string) error {
}
}
continueDelete, err := console.InteractiveBool{
Prompt: "\nDo you want to continue?",
Default: false,
Prompt: "\nDo you want to continue?",
Default: false,
NonDefaultFlag: "-f",
}.Read()
if err != nil {
return err
}
if !continueDelete {
return nil
return fmt.Errorf("Aborting.")
}
}

Expand Down
5 changes: 3 additions & 2 deletions go/pkg/console/global.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ func DebugOutput(line string) {
ConsoleInstance.DebugOutput(line)
}

func IsTTY() bool {
return isatty.IsTerminal(os.Stdout.Fd())
// IsTTY checks if a file is a TTY or not. E.g. IsTTY(os.Stdin)
func IsTTY(f *os.File) bool {
return isatty.IsTerminal(f.Fd())
}
6 changes: 6 additions & 0 deletions go/pkg/console/interactive.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package console
import (
"bufio"
"fmt"
"io"
"os"
"strings"

Expand Down Expand Up @@ -76,6 +77,8 @@ func (i Interactive) Read() (string, error) {
type InteractiveBool struct {
Prompt string
Default bool
// NonDefaultFlag is the flag to suggest passing to do the thing which isn't default when running inside a script
NonDefaultFlag string
}

func (i InteractiveBool) Read() (bool, error) {
Expand All @@ -88,6 +91,9 @@ func (i InteractiveBool) Read() (bool, error) {
reader := bufio.NewReader(os.Stdin)
text, err := reader.ReadString('\n')
if err != nil {
if err == io.EOF {
return false, fmt.Errorf("stdin is closed. If you're running in a script, you need to pass the '%s' option.", i.NonDefaultFlag)
}
return false, err
}
text = strings.ToLower(strings.TrimSpace(text))
Expand Down

0 comments on commit fe817e2

Please sign in to comment.