forked from knusbaum/go9p
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathserver.go
323 lines (299 loc) · 8.74 KB
/
server.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
// Package go9p contains an interface definition for a 9p2000 server, `Srv`,
// along with a few functions that will serve the 9p2000 protocol using a `Srv`.
//
// Most people wanting to implement a 9p filesystem should start in the subpackage
// github.com/knusbaum/go9p/fs, which contains tools for constructing a file system
// which can be served using the functions in this package.
//
// The subpackage github.com/knusbaum/go9p/proto contains the protocol implementation.
// It is used by the other packages to send and receive 9p2000 messages. It may be
// useful to someone who wants to investigate 9p2000 at the protocol level.
package go9p
import (
"bufio"
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"log"
"net"
"reflect"
"sync"
"github.com/knusbaum/go9p/proto"
)
// Verbose switches protocol tracing on and off. While it is true, incoming
// and outgoing 9p messages are printed through the standard logger, which
// writes to stderr unless it has been redirected.
var Verbose bool

// verboseLog formats msg with args and hands the result to log.Printf, but
// only while Verbose is set; otherwise it does nothing.
func verboseLog(msg string, args ...interface{}) {
	if !Verbose {
		return
	}
	log.Printf(msg, args...)
}
// The Srv interface is used to handle 9p2000 messages.
// Each function handles a specific type of message, and
// should return a response. If some expected error occurs,
// for example a TOpen message for a file with the wrong
// permissions, a proto.RError message should be returned
// rather than a go error. Returning a go error indicates that
// something has gone wrong with the server; it is intended to
// cause the connection to be terminated or the file descriptor
// to be closed when used with Serve and PostSrv.
// NOTE(review): handleIO stops and returns such an error, whereas
// handleIOAsync only logs it and sends no response for that call;
// confirm which behavior is the intended contract.
//
// There is no method for TFlush: handleCall answers flush
// requests itself with an RFlush.
type Srv interface {
// NewConn returns the Conn that is passed to every other method.
// The serving loops in this file call it once per connection.
NewConn() Conn
// Version handles a TRVersion message.
Version(Conn, *proto.TRVersion) (proto.FCall, error)
// Auth handles a TAuth message.
Auth(Conn, *proto.TAuth) (proto.FCall, error)
// Attach handles a TAttach message.
Attach(Conn, *proto.TAttach) (proto.FCall, error)
// Walk handles a TWalk message.
Walk(Conn, *proto.TWalk) (proto.FCall, error)
// Open handles a TOpen message.
Open(Conn, *proto.TOpen) (proto.FCall, error)
// Create handles a TCreate message.
Create(Conn, *proto.TCreate) (proto.FCall, error)
// Read handles a TRead message.
Read(Conn, *proto.TRead) (proto.FCall, error)
// Write handles a TWrite message.
Write(Conn, *proto.TWrite) (proto.FCall, error)
// Clunk handles a TClunk message.
Clunk(Conn, *proto.TClunk) (proto.FCall, error)
// Remove handles a TRemove message.
Remove(Conn, *proto.TRemove) (proto.FCall, error)
// Stat handles a TStat message.
Stat(Conn, *proto.TStat) (proto.FCall, error)
// Wstat handles a TWstat message.
Wstat(Conn, *proto.TWstat) (proto.FCall, error)
}
// Conn represents an individual connection to a 9p server.
// In the case of a server listening on a network, there
// may be many clients connected to a given server at once.
type Conn interface {
// TagContext returns the context associated with the given message tag.
// handleCall fetches it before dispatching a call and, if its Err() is
// non-nil once the Srv method has returned, discards the response.
TagContext(uint16) context.Context
// DropContext is called by handleCall with a call's tag once that call
// has been handled and its response is about to be returned.
// NOTE(review): presumably this cancels and forgets the tag's context;
// confirm against the implementations.
DropContext(uint16)
}
// handleConnection serves srv over one network connection. Reads go through
// a buffered reader and responses are written straight to nc. No user name is
// passed to handleIOAsync, so attach requests are not checked against one.
// The connection is closed when serving ends, and the error that ended it is
// logged.
func handleConnection(nc net.Conn, srv Srv) {
	defer nc.Close()

	if err := handleIOAsync(bufio.NewReader(nc), nc, "", srv); err != nil {
		log.Printf("%v\n", err)
	}
}
// handleIO serves srv synchronously: it parses one call from r, handles it,
// writes the response to w, and only then reads the next call. It returns
// the first error produced by parsing, handling, or writing; it never returns
// nil, because the loop only ends on an error.
//
// handleIO seems to be about 10x faster than handleIOAsync
// in my experiments. It would be nice to be able to keep some
// performance without making the reading, handling, and
// writing of calls synchronous.
func handleIO(r io.Reader, w io.Writer, srv Srv) error {
	conn := srv.NewConn()
	for {
		call, err := proto.ParseCall(r)
		if err != nil {
			return err
		}
		verboseLog("=in=> %s\n", call)

		resp, err := handleCall(call, srv, conn)
		if err != nil {
			return err
		}
		if resp == nil {
			// This case happens when an active tag is
			// flushed: there is nothing to send back.
			continue
		}
		verboseLog("<=out= %s\n", resp)

		if _, err := w.Write(resp.Compose()); err != nil {
			return err
		}
	}
}
// handleIOAsync serves srv over r and w with a three-stage pipeline: this
// goroutine parses calls from r, a pool of workers runs them through
// handleCall, and a single writer goroutine composes the responses onto w.
//
// If uname is non-empty, every TAttach must carry exactly that user name;
// a mismatch is answered with an RError and ends the connection.
//
// The function returns the error that stopped the read loop (a parse error
// or the attach mismatch). Before it returns, the pipeline is drained in
// order: workers finish the calls already queued, and the writer flushes
// every response they produced.
func handleIOAsync(r io.Reader, w io.Writer, uname string, srv Srv) error {
	const (
		queueDepth  = 100 // buffer size of the incoming and outgoing queues
		workerCount = 100 // number of goroutines handling calls concurrently
	)
	incoming := make(chan proto.FCall, queueDepth)
	outgoing := make(chan proto.FCall, queueDepth)
	conn := srv.NewConn()

	// Deferred calls run last-in first-out, so shutdown proceeds in pipeline
	// order: close(incoming), wait for the workers, close(outgoing), and
	// finally wait for the writer.

	// Write the outgoing
	var outgoingWG sync.WaitGroup
	defer outgoingWG.Wait()
	outgoingWG.Add(1)
	go func() {
		// Done must be deferred: signalling it up front would let the
		// deferred Wait return before queued responses reach w.
		defer outgoingWG.Done()
		for call := range outgoing {
			verboseLog("<=out= %s\n", call)
			if _, err := w.Write(call.Compose()); err != nil {
				log.Printf("Protocol error: %v\n", err)
			}
		}
	}()

	var workerWG sync.WaitGroup
	defer func() {
		workerWG.Wait()
		close(outgoing)
	}()
	for i := 0; i < workerCount; i++ {
		workerWG.Add(1)
		go func() {
			defer workerWG.Done()
			for call := range incoming {
				resp, err := handleCall(call, srv, conn)
				if err != nil {
					// Keep the worker alive. Exiting here would shrink
					// the pool by one for every failed call until no
					// worker is left to drain incoming.
					log.Printf("Protocol error: %v\n", err)
					continue
				}
				if resp == nil {
					// This case happens when an active tag is
					// flushed.
					continue
				}
				outgoing <- resp
			}
		}()
	}

	// Read incoming
	defer close(incoming)
	for {
		call, err := proto.ParseCall(r)
		if err != nil {
			log.Printf("Protocol error: %v\n", err)
			return err
		}
		verboseLog("=in=> %s\n", call)

		if ta, ok := call.(*proto.TAttach); ok && uname != "" {
			// TODO: it would be nice to move this down into Srv so that we
			// can respond with RError instead of just killing the connection.
			if ta.Uname != uname {
				outgoing <- &proto.RError{proto.Header{proto.Rerror, ta.Tag}, fmt.Sprintf("invalid user %s", ta.Uname)}
				return fmt.Errorf("Protocol error: client connected with cert for %s, but attached with user name %s", uname, ta.Uname)
			}
			// ta.Uname already equals uname here, so there is nothing to rewrite.
			verboseLog("UNAME %s -> %s\n", ta.Uname, uname)
		}

		// Block when the queue is full so a client that pipelines more
		// calls than the workers can absorb is throttled rather than
		// crashing the whole process.
		incoming <- call
	}
}
// handleCall routes one parsed 9p message to the Srv method for its type and
// returns that method's response and error.
//
// TFlush is answered here with an RFlush and never reaches srv. A message of
// any other unrecognised type yields an "Invalid call" error.
//
// The context for the call's tag is fetched from conn before dispatch. If
// that context reports an error once the handler has returned, the result is
// discarded and (nil, nil) is returned; callers treat this as "the tag was
// flushed, send nothing". Otherwise the tag's context is dropped and the
// handler's result is passed through.
func handleCall(call proto.FCall, srv Srv, conn Conn) (proto.FCall, error) {
	ctx := conn.TagContext(call.GetTag())

	var (
		resp proto.FCall
		err  error
	)
	switch c := call.(type) {
	case *proto.TRVersion:
		resp, err = srv.Version(conn, c)
	case *proto.TAuth:
		resp, err = srv.Auth(conn, c)
	case *proto.TAttach:
		resp, err = srv.Attach(conn, c)
	case *proto.TFlush:
		// The context of c.Oldtag is left untouched; a DropContext(Oldtag)
		// call used to sit here but is disabled.
		resp = &proto.RFlush{proto.Header{proto.Rflush, c.Tag}}
	case *proto.TWalk:
		resp, err = srv.Walk(conn, c)
	case *proto.TOpen:
		resp, err = srv.Open(conn, c)
	case *proto.TCreate:
		resp, err = srv.Create(conn, c)
	case *proto.TRead:
		resp, err = srv.Read(conn, c)
	case *proto.TWrite:
		resp, err = srv.Write(conn, c)
	case *proto.TClunk:
		resp, err = srv.Clunk(conn, c)
	case *proto.TRemove:
		resp, err = srv.Remove(conn, c)
	case *proto.TStat:
		resp, err = srv.Stat(conn, c)
	case *proto.TWstat:
		resp, err = srv.Wstat(conn, c)
	default:
		return nil, fmt.Errorf("Invalid call: %s", reflect.TypeOf(call))
	}

	// A non-nil context error means the response must not be sent.
	if ctx.Err() != nil {
		return nil, nil
	}
	conn.DropContext(call.GetTag())
	return resp, err
}
// ServeReadWriter accepts an io.Reader, an io.Writer, and an Srv.
// It reads 9p2000 messages from r, handles them with srv, and
// writes the responses to w.
//
// It delegates to handleIOAsync with an empty user name, so attach
// requests are not checked against any expected user. The returned
// error is whatever handleIOAsync returns.
func ServeReadWriter(r io.Reader, w io.Writer, srv Srv) error {
return handleIOAsync(r, w, "", srv)
}
// Serve serves srv on the given TCP address, addr.
//
// Each accepted connection is handled on its own goroutine by
// handleConnection, which closes the connection and logs the error that
// ended it. Serve blocks until listening or accepting fails and returns that
// error; the listener is closed before it returns.
func Serve(addr string, srv Srv) error {
	l, err := net.Listen("tcp", addr)
	if err != nil {
		return err
	}
	// Release the listening socket once Accept fails and Serve returns.
	defer l.Close()

	for {
		nc, err := l.Accept()
		if err != nil {
			return err
		}
		go handleConnection(nc, srv)
	}
}
// ServeTLS serves srv on the given TCP address, addr, over TLS.
//
// srvcert is the certificate presented to clients. ca, if non-nil, is the
// only certificate placed in the pool used to verify client certificates.
// When withauth is true, clients must present a certificate that verifies
// against that pool, and the certificate's subject common name becomes the
// user name that every TAttach on the connection has to match. When withauth
// is false, no client certificate is requested and attaches are not checked.
//
// ServeTLS blocks until listening or accepting fails and returns that error;
// the listener is closed before it returns. Per-connection failures are
// logged and only end that connection.
func ServeTLS(addr string, srvcert tls.Certificate, ca *x509.Certificate, withauth bool, srv Srv) error {
	l, err := net.Listen("tcp", addr)
	if err != nil {
		return err
	}

	certpool := x509.NewCertPool()
	if ca != nil {
		// CertPool.AddCert panics when handed a nil certificate.
		certpool.AddCert(ca)
	}

	clientAuth := tls.NoClientCert
	if withauth {
		clientAuth = tls.RequireAndVerifyClientCert
	}

	tlsl := tls.NewListener(l, &tls.Config{
		Certificates: []tls.Certificate{srvcert},
		ClientCAs:    certpool,
		ClientAuth:   clientAuth,
	})
	// tlsl wraps l, so closing it releases the underlying TCP listener too.
	defer tlsl.Close()

	for {
		a, err := tlsl.Accept()
		if err != nil {
			return err
		}
		go func(nc net.Conn, srv Srv) {
			// Registered first so the connection is also closed when the
			// handshake below fails.
			defer nc.Close()

			var uname string
			if tc, ok := nc.(*tls.Conn); ok && withauth {
				// Force the handshake now; the peer certificate is needed
				// before any 9p traffic is read.
				if err := tc.Handshake(); err != nil {
					log.Printf("TLS Error: %v\n", err)
					return
				}
				state := tc.ConnectionState()
				if len(state.PeerCertificates) == 0 {
					log.Printf("TLS Error: client presented no certificate\n")
					return
				}
				verboseLog("CLIENT: %#v\n", state)
				// NOTE(review): a certificate with an empty common name
				// yields an empty uname, which disables the attach user
				// check in handleIOAsync. Confirm whether such
				// certificates should be rejected.
				uname = state.PeerCertificates[0].Subject.CommonName
				log.Printf("Client connected as %v\n", uname)
			}

			if err := handleIOAsync(bufio.NewReader(nc), nc, uname, srv); err != nil {
				log.Printf("%v\n", err)
			}
		}(a, srv)
	}
}
// PostSrv serves srv from a file descriptor named name.
// The fd is posted and can subsequently be mounted. On Unix, the
// descriptor is posted in the current namespace, which is
// determined by 9fans.net/go/plan9/client Namespace. On Plan9 it
// is posted in the usual place, /srv.
//
// PostSrv blocks while serving and returns the error from posting the
// descriptor or, after that, the error that ended ServeReadWriter. The
// posted file, and the handle returned by postfd if it is non-nil, are
// closed on return.
func PostSrv(name string, srv Srv) error {
	f, handle, err := postfd(name)
	if err != nil {
		return err
	}
	defer f.Close()
	if handle != nil {
		defer handle.Close()
	}

	// The posted file is both the source of calls and the sink for responses.
	return ServeReadWriter(f, f, srv)
}