Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 35 additions & 4 deletions cmd/okdev-sshd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,24 @@ func detectShell() string {
return "/bin/sh"
}

func resolveInteractiveShell(serverShell string) string {
if env := strings.TrimSpace(os.Getenv("OKDEV_SHELL")); env != "" {
if _, err := os.Stat(env); err == nil {
return env
}
}
for _, sh := range []string{"/bin/bash", "/bin/zsh", "/bin/sh"} {
if _, err := os.Stat(sh); err == nil {
return sh
}
}
return serverShell
}

func isZshShell(shell string) bool {
return strings.HasSuffix(shell, "/zsh")
}

func loadAuthorizedKeys(path string) ([]ssh.PublicKey, error) {
data, err := os.ReadFile(path)
if err != nil {
Expand Down Expand Up @@ -135,10 +153,11 @@ func sessionHandler(shell string) ssh.Handler {
func buildCmd(s ssh.Session, shell string, extraEnv []string) *exec.Cmd {
var cmd *exec.Cmd
if len(s.RawCommand()) == 0 {
if script := interactiveLoginScript(s, shell); script != "" {
interactiveShell := resolveInteractiveShell(shell)
if script := interactiveLoginScript(s, interactiveShell); script != "" {
cmd = exec.Command(shell, "-lc", script)
} else {
cmd = exec.Command(shell, "-l")
cmd = exec.Command(interactiveShell, "-l")
}
} else {
cmd = exec.Command(shell, "-lc", s.RawCommand())
Expand All @@ -147,10 +166,10 @@ func buildCmd(s ssh.Session, shell string, extraEnv []string) *exec.Cmd {
return cmd
}

func interactiveLoginScript(s ssh.Session, shell string) string {
func interactiveLoginScript(s ssh.Session, interactiveShell string) string {
return buildInteractiveLoginScript(
sessionEnvMap(s),
shell,
interactiveShell,
strings.TrimSpace(os.Getenv("OKDEV_WORKSPACE")),
strings.TrimSpace(os.Getenv("OKDEV_TMUX")),
)
Expand All @@ -165,6 +184,10 @@ func buildInteractiveLoginScript(sessionEnv map[string]string, shell, workspace,
parts = append(parts, fmt.Sprintf("if [ -x %s ]; then %s 2>&1 || echo 'warning: postAttach script failed' >&2; fi", postAttach, postAttach))
}

if zshScript := zshBootstrapScript(workspace, shell); zshScript != "" {
parts = append(parts, zshScript)
}

parts = append(parts, terminalBootstrapScript())

if tmuxFlag == "1" && sessionEnv["OKDEV_NO_TMUX"] != "1" {
Expand All @@ -175,6 +198,14 @@ func buildInteractiveLoginScript(sessionEnv map[string]string, shell, workspace,
return strings.Join(parts, "; ")
}

func zshBootstrapScript(workspace, shell string) string {
if !isZshShell(shell) || workspace == "" {
return ""
}
zshrc := shellQuote(strings.TrimRight(workspace, "/") + "/.okdev/zshrc")
return fmt.Sprintf("if [ -f %s ] && [ ! -e ~/.zshrc ]; then printf '%%s\\n' 'if [ -f %s ]; then' ' source %s' 'fi' > ~/.zshrc; fi", zshrc, zshrc, zshrc)
}

func terminalBootstrapScript() string {
return `if [ "${TERM:-}" = "xterm-ghostty" ]; then export TERM=xterm-256color; fi`
}
Expand Down
84 changes: 81 additions & 3 deletions cmd/okdev-sshd/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,46 @@ func TestDetectShellReturnsExistingShell(t *testing.T) {
}
}

func TestDetectShellIgnoresOKDEVShellEnv(t *testing.T) {
t.Setenv("OKDEV_SHELL", "/bin/sh")
got := detectShell()
if got != "/bin/bash" && got != "/bin/sh" {
t.Fatalf("expected command shell fallback to ignore OKDEV_SHELL, got %q", got)
}
}

func TestDetectShellIgnoresNonexistentOKDEVShell(t *testing.T) {
t.Setenv("OKDEV_SHELL", "/definitely/missing/zsh")
got := detectShell()
if got != "/bin/bash" && got != "/bin/sh" {
t.Fatalf("expected command shell fallback to ignore nonexistent OKDEV_SHELL, got %q", got)
}
}

func TestResolveInteractiveShellUsesOKDEVShell(t *testing.T) {
t.Setenv("OKDEV_SHELL", "/bin/sh")
got := resolveInteractiveShell("/bin/bash")
if got != "/bin/sh" {
t.Fatalf("expected /bin/sh from OKDEV_SHELL, got %q", got)
}
}

func TestResolveInteractiveShellIgnoresNonexistentOKDEVShell(t *testing.T) {
t.Setenv("OKDEV_SHELL", "/definitely/missing/zsh")
got := resolveInteractiveShell("/bin/sh")
if got == "/definitely/missing/zsh" {
t.Fatal("expected resolveInteractiveShell to ignore nonexistent OKDEV_SHELL path")
}
}

func TestResolveInteractiveShellFallsBackToDetection(t *testing.T) {
t.Setenv("OKDEV_SHELL", "")
got := resolveInteractiveShell("/bin/sh")
if got != "/bin/bash" && got != "/bin/zsh" && got != "/bin/sh" {
t.Fatalf("expected a valid shell from fallback detection, got %q", got)
}
}

func TestLoadAuthorizedKeysMissingFileReturnsNil(t *testing.T) {
keys, err := loadAuthorizedKeys("/definitely/missing/authorized_keys")
if err != nil {
Expand Down Expand Up @@ -118,6 +158,33 @@ func TestBuildInteractiveLoginScript(t *testing.T) {
}
}

func TestBuildInteractiveLoginScriptWithZshSourcesZshrc(t *testing.T) {
script := buildInteractiveLoginScript(
map[string]string{},
"/bin/zsh",
"/workspace/demo",
"1",
)
if !strings.Contains(script, ".okdev/zshrc") {
t.Fatalf("expected zsh bootstrap to source .okdev/zshrc: %s", script)
}
if !strings.Contains(script, "exec '/bin/zsh' -l") {
t.Fatalf("expected login shell exec with zsh: %s", script)
}
}

func TestBuildInteractiveLoginScriptWithBashDoesNotSourceZshrc(t *testing.T) {
script := buildInteractiveLoginScript(
map[string]string{},
"/bin/bash",
"/workspace/demo",
"1",
)
if strings.Contains(script, ".okdev/zshrc") {
t.Fatalf("expected bash bootstrap to not source .okdev/zshrc: %s", script)
}
}

func TestBuildInteractiveLoginScriptSkipsTmuxWhenDisabled(t *testing.T) {
script := buildInteractiveLoginScript(
map[string]string{"OKDEV_NO_TMUX": "1"},
Expand Down Expand Up @@ -219,10 +286,21 @@ func TestSessionEnvMap(t *testing.T) {
func TestBuildCmdInteractiveShell(t *testing.T) {
t.Setenv("OKDEV_WORKSPACE", "")
t.Setenv("OKDEV_TMUX", "")
t.Setenv("OKDEV_SHELL", "")
cmd := buildCmd(fakeSessionCmd{}, "/bin/sh", nil)
want := `/bin/sh -lc if [ "${TERM:-}" = "xterm-ghostty" ]; then export TERM=xterm-256color; fi; exec '/bin/sh' -l`
if got := strings.Join(cmd.Args, " "); got != want {
t.Fatalf("unexpected interactive args: %q", got)
got := strings.Join(cmd.Args, " ")
// The bootstrap script runs via the server shell (/bin/sh -lc ...),
// but the final exec uses the resolved interactive shell.
if !strings.HasPrefix(got, "/bin/sh -lc") {
t.Fatalf("expected server shell /bin/sh to run the bootstrap script: %q", got)
}
if !strings.Contains(got, "xterm-ghostty") {
t.Fatalf("expected terminal bootstrap in script: %q", got)
}
// The final exec should use whatever resolveInteractiveShell returns.
resolved := resolveInteractiveShell("/bin/sh")
if !strings.Contains(got, "exec '"+resolved+"' -l") {
t.Fatalf("expected exec with resolved interactive shell %q: %q", resolved, got)
}
}

Expand Down
55 changes: 55 additions & 0 deletions internal/cli/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ func newInitCmd(opts *Options) *cobra.Command {
var syncLocalOverride string
var syncRemoteOverride string
var sshUserOverride string
var shellOverride string
var stignorePreset string
var setFlags []string

Expand Down Expand Up @@ -70,6 +71,7 @@ func newInitCmd(opts *Options) *cobra.Command {
SyncLocal: syncLocalOverride,
SyncRemote: syncRemoteOverride,
SSHUser: sshUserOverride,
Shell: shellOverride,
}
applyOverrides(vars, overrides)
applyWorkloadDefaults(vars)
Expand Down Expand Up @@ -159,6 +161,12 @@ func newInitCmd(opts *Options) *cobra.Command {
}
}

zshFiles, err := scaffoldZshFiles(abs, vars, force, cmd.OutOrStdout())
if err != nil {
return err
}
scaffolded = append(scaffolded, zshFiles...)

fmt.Fprintf(cmd.OutOrStdout(), "Wrote %s\n", abs)
if resolvedPreset != "" {
fmt.Fprintf(cmd.OutOrStdout(), "Using .stignore preset: %s\n", resolvedPreset)
Expand Down Expand Up @@ -188,6 +196,7 @@ func newInitCmd(opts *Options) *cobra.Command {
cmd.Flags().StringVar(&syncLocalOverride, "sync-local", "", "Local sync path")
cmd.Flags().StringVar(&syncRemoteOverride, "sync-remote", "", "Remote sync path")
cmd.Flags().StringVar(&sshUserOverride, "ssh-user", "", "SSH user")
cmd.Flags().StringVar(&shellOverride, "shell", "", "Shell for interactive SSH sessions (e.g., /bin/zsh)")
cmd.Flags().StringVar(&stignorePreset, "stignore-preset", "", "Local .stignore preset: default|python|node|go|rust")
cmd.Flags().StringArrayVar(&setFlags, "set", nil, "Set a template variable (repeatable: --set key=value)")
return cmd
Expand Down Expand Up @@ -644,6 +653,52 @@ func detectSTIgnorePreset(dir string) string {
return ""
}

func scaffoldZshFiles(configPath string, vars *config.TemplateVars, force bool, w io.Writer) ([]string, error) {
if !isZshShellPath(vars.Shell) {
return nil, nil
}
var wrote []string

zshrcPath := resolveInitScaffoldFilePath(configPath, ".okdev/zshrc")
if _, err := os.Stat(zshrcPath); err != nil || force {
content, err := config.RenderEmbeddedTemplate("templates/zshrc.tmpl", vars)
if err != nil {
return nil, fmt.Errorf("render zshrc template: %w", err)
}
if err := os.MkdirAll(filepath.Dir(zshrcPath), 0o755); err != nil {
return nil, fmt.Errorf("create zshrc directory: %w", err)
}
if err := os.WriteFile(zshrcPath, []byte(content), 0o644); err != nil {
return nil, fmt.Errorf("write zshrc: %w", err)
}
wrote = append(wrote, zshrcPath)
}

examplePath := resolveInitScaffoldFilePath(configPath, ".okdev/zsh-setup.example.sh")
if _, err := os.Stat(examplePath); err != nil || force {
content, err := config.RenderEmbeddedTemplate("templates/zsh-setup.example.sh.tmpl", vars)
if err != nil {
return nil, fmt.Errorf("render zsh-setup example template: %w", err)
}
if err := os.WriteFile(examplePath, []byte(content), 0o644); err != nil {
return nil, fmt.Errorf("write zsh-setup example: %w", err)
}
wrote = append(wrote, examplePath)
}

if len(wrote) > 0 {
fmt.Fprintln(w, "Note: spec.ssh.shell affects interactive SSH sessions only.")
fmt.Fprintln(w, " zsh must exist in the image or be installed by your lifecycle hook.")
fmt.Fprintln(w, " Review .okdev/zsh-setup.example.sh for oh-my-zsh/plugin setup recipes.")
}

return wrote, nil
}

func isZshShellPath(shell string) bool {
return strings.HasSuffix(strings.TrimSpace(shell), "/zsh")
}

func writeInitSTIgnore(configPath string, rendered []byte, templateRef string, stignorePreset string, force bool, projectDirs ...string) (string, bool, error) {
var cfg config.DevEnvironment
if err := yaml.Unmarshal(rendered, &cfg); err != nil {
Expand Down
42 changes: 42 additions & 0 deletions internal/cli/init_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1072,3 +1072,45 @@ spec:
t.Fatalf("expected shadowed basic template ref to be persisted, got:\n%s", string(raw))
}
}

func TestInitWithZshShellScaffoldsZshFiles(t *testing.T) {
tmp := t.TempDir()
oldwd, _ := os.Getwd()
t.Cleanup(func() { _ = os.Chdir(oldwd) })
if err := os.Chdir(tmp); err != nil {
t.Fatal(err)
}

cmd := newInitCmd(&Options{})
cmd.SetArgs([]string{"--yes", "--shell", "/bin/zsh"})
cmd.SetIn(strings.NewReader(""))

var out bytes.Buffer
cmd.SetOut(&out)
cmd.SetErr(&out)

if err := cmd.Execute(); err != nil {
t.Fatalf("init: %v", err)
}

zshrcPath := filepath.Join(tmp, ".okdev", "zshrc")
if _, err := os.Stat(zshrcPath); err != nil {
t.Fatalf("expected .okdev/zshrc to be scaffolded: %v", err)
}
examplePath := filepath.Join(tmp, ".okdev", "zsh-setup.example.sh")
if _, err := os.Stat(examplePath); err != nil {
t.Fatalf("expected .okdev/zsh-setup.example.sh to be scaffolded: %v", err)
}

cfgRaw, err := os.ReadFile(filepath.Join(tmp, ".okdev.yaml"))
if err != nil {
t.Fatalf("read config: %v", err)
}
if !strings.Contains(string(cfgRaw), "shell: /bin/zsh") {
t.Fatalf("expected config to contain shell: /bin/zsh, got:\n%s", cfgRaw)
}

if !strings.Contains(out.String(), "zsh-setup.example.sh") {
t.Fatalf("expected guidance message in output, got %q", out.String())
}
}
Loading
Loading