diff --git a/nas/listener.go b/nas/listener.go index 07c92a4..b367254 100644 --- a/nas/listener.go +++ b/nas/listener.go @@ -14,16 +14,34 @@ import ( "github.com/logrusorgru/aurora/v3" ) -type nasListener struct { +type httpListener struct { net.Listener } -type nasConn struct { +type nasInConn struct { net.Conn - reader io.Reader + in io.Reader } -func (l *nasListener) Accept() (net.Conn, error) { +type nasIOConn struct { + net.Conn + in io.Reader + out io.Writer +} + +func (c nasInConn) Read(b []byte) (int, error) { + return c.in.Read(b) +} + +func (c nasIOConn) Read(b []byte) (int, error) { + return c.in.Read(b) +} + +func (c nasIOConn) Write(b []byte) (int, error) { + return c.out.Write(b) +} + +func (l *httpListener) Accept() (net.Conn, error) { conn, err := l.Listener.Accept() if err != nil { return nil, err @@ -40,16 +58,12 @@ func (l *nasListener) Accept() (net.Conn, error) { _, err = io.Copy(pw, r) _ = pw.CloseWithError(err) }() - return &nasConn{ - Conn: conn, - reader: pr, + return &nasInConn{ + Conn: conn, + in: pr, }, nil } -func (c *nasConn) Read(b []byte) (int, error) { - return c.reader.Read(b) -} - // filterDuplicateHost wraps a net.Conn and filters out duplicate Host headers from the HTTP request, // making the invalid requests sent by DWC acceptable to the standard library's HTTP server. func filterDuplicateHost(r *bufio.Reader) []byte { @@ -84,12 +98,12 @@ func filterDuplicateHost(r *bufio.Reader) []byte { return []byte(line + headers.String()) } -func listenAndServe(address string) { - logging.Notice("NAS", "Starting HTTP server on", aurora.BrightCyan(address)) +func listenAndServe() { + logging.Notice("NAS", "Starting HTTP server on", aurora.BrightCyan(server.Addr)) - l, err := net.Listen("tcp", address) + l, err := net.Listen("tcp", server.Addr) common.ShouldNotError(err) - listener := &nasListener{Listener: l} + listener := &httpListener{Listener: l} defer func() { common.ShouldNotError(listener.Close()) @@ -100,3 +114,40 @@ func listenAndServe(address string) { panic(err) } } + +type tlsListener struct { + net.Listener +} + +func (l *tlsListener) Accept() (net.Conn, error) { + conn, err := l.Listener.Accept() + if err != nil { + return nil, err + } + + readFromServer, writeFromServer := io.Pipe() + readToServer, writeToServer := io.Pipe() + go handleIncomingTLS(conn, readFromServer, writeToServer) + return &nasIOConn{ + Conn: conn, + in: readToServer, + out: writeFromServer, + }, nil +} + +func listenAndServeTLS() { + logging.Notice("NAS", "Starting HTTPS server on", aurora.BrightCyan(tlsServer.Addr)) + + l, err := net.Listen("tcp", tlsServer.Addr) + common.ShouldNotError(err) + listener := &tlsListener{Listener: l} + + defer func() { + common.ShouldNotError(listener.Close()) + }() + + err = tlsServer.Serve(listener) + if err != nil && !errors.Is(err, http.ErrServerClosed) { + panic(err) + } +} diff --git a/nas/main.go b/nas/main.go index 9b8c1e9..14134c1 100644 --- a/nas/main.go +++ b/nas/main.go @@ -18,7 +18,7 @@ import ( var ( serverName string - server *http.Server + server, tlsServer *http.Server payloadServerAddress string ) @@ -37,7 +37,7 @@ func StartServer(reload bool) { payloadServerAddress = config.PayloadServerAddress if config.EnableHTTPS { - go startHTTPSProxy(config) + go setupTLS(config) } err := CacheProfanityFile() @@ -81,7 +81,20 @@ func StartServer(reload bool) { sakeMux.HandleFunc("/", handleUnknown) raceMux.HandleFunc("/", handleUnknown) - go listenAndServe(address) + go listenAndServe() + if config.EnableHTTPS { + tlsServer = &http.Server{ + Addr: *config.NASAddressHTTPS + ":" + config.NASPortHTTPS, + Handler: server.Handler, + IdleTimeout: server.IdleTimeout, + ReadTimeout: server.ReadTimeout, + } + + go func() { + setupTLS(config) + listenAndServeTLS() + }() + } } func Shutdown() { diff --git a/nas/https.go b/nas/tls.go similarity index 75% rename from nas/https.go rename to nas/tls.go index 5006760..84641a2 100644 --- a/nas/https.go +++ b/nas/tls.go @@ -27,24 +27,6 @@ import ( "github.com/logrusorgru/aurora/v3" ) -// Buffered conn for passing to regular TLS after peeking the client hello -type bufferedConn struct { - r *bufio.Reader - net.Conn -} - -func newBufferedConn(c net.Conn) bufferedConn { - return bufferedConn{bufio.NewReader(c), c} -} - -func (b bufferedConn) Peek(n int) ([]byte, error) { - return b.r.Peek(n) -} - -func (b bufferedConn) Read(p []byte) (int, error) { - return b.r.Read(p) -} - // Bare minimum TLS 1.0 server implementation for the Wii's /dev/net/ssl client // Use this with a certificate that exploits the Wii's SSL certificate bug to impersonate naswii.nintendowifi.net // See here: https://github.com/shutterbug2000/wii-ssl-bug @@ -52,20 +34,20 @@ func (b bufferedConn) Read(p []byte) (int, error) { // Don't use this for anything else, it's not secure -func startHTTPSProxy(config common.Config) { - address := *config.NASAddressHTTPS + ":" + config.NASPortHTTPS - nasAddr := *config.NASAddress + ":" + config.NASPort +var ( + rsaKeyWii *rsa.PrivateKey + serverCertsRecordWii []byte + rsaKeyDS *rsa.PrivateKey + serverCertsRecordDS []byte + realTLSConfig *tls.Config +) + +func setupTLS(config common.Config) { privKeyPath := config.KeyPath certsPath := config.CertPath exploitWii := *config.EnableHTTPSExploitWii exploitDS := *config.EnableHTTPSExploitDS - logging.Notice("NAS-TLS", "Starting HTTPS server on", aurora.BrightCyan(address)) - l, err := net.Listen("tcp", address) - if err != nil { - panic(err) - } - setupRealTLS(privKeyPath, certsPath) // Reread the private key and certs on a regular interval go func() { @@ -75,181 +57,210 @@ func startHTTPSProxy(config common.Config) { } }() - if !exploitWii && !exploitDS { - // Only handle real TLS requests - for { - conn, err := l.Accept() - if err != nil { - panic(err) - } - - go func() { - moduleName := "NAS-TLS:" + conn.RemoteAddr().String() - - _ = conn.SetDeadline(time.Now().UTC().Add(25 * time.Second)) - - handleRealTLS(moduleName, conn, nasAddr) - }() - } - } - // Handle requests from Wii, DS and regular TLS - var rsaKeyWii *rsa.PrivateKey - var serverCertsRecordWii []byte + if exploitWii { - certWii, err := os.ReadFile(config.CertPathWii) - if err != nil { - panic(err) - } - - rsaDataWii, err := os.ReadFile(config.KeyPathWii) - if err != nil { - panic(err) - } - - rsaBlockWii, _ := pem.Decode(rsaDataWii) - parsedKeyWii, err := x509.ParsePKCS8PrivateKey(rsaBlockWii.Bytes) - if err != nil { - panic(err) - } - - var ok bool - rsaKeyWii, ok = parsedKeyWii.(*rsa.PrivateKey) - if !ok { - panic("unexpected key type") - } - - serverCertsRecordWii = []byte{0x16, 0x03, 0x01} - - // Length of the record - certLenWii := uint32(len(certWii)) - serverCertsRecordWii = append(serverCertsRecordWii, []byte{ - byte((certLenWii + 10) >> 8), - byte(certLenWii + 10), - }...) - - serverCertsRecordWii = append(serverCertsRecordWii, 0xB) - - serverCertsRecordWii = append(serverCertsRecordWii, []byte{ - byte((certLenWii + 6) >> 16), - byte((certLenWii + 6) >> 8), - byte(certLenWii + 6), - }...) - - serverCertsRecordWii = append(serverCertsRecordWii, []byte{ - byte((certLenWii + 3) >> 16), - byte((certLenWii + 3) >> 8), - byte(certLenWii + 3), - }...) - - serverCertsRecordWii = append(serverCertsRecordWii, []byte{ - byte(certLenWii >> 16), - byte(certLenWii >> 8), - byte(certLenWii), - }...) - - serverCertsRecordWii = append(serverCertsRecordWii, certWii...) - - serverCertsRecordWii = append(serverCertsRecordWii, []byte{ - 0x16, 0x03, 0x01, 0x00, 0x04, 0x0E, 0x00, 0x00, 0x00, - }...) + setupExploitWii(config) } - var rsaKeyDS *rsa.PrivateKey - var serverCertsRecordDS []byte if exploitDS { - certDS, err := os.ReadFile(config.CertPathDS) - if err != nil { - panic(err) - } - - rsaDataDS, err := os.ReadFile(config.KeyPathDS) - if err != nil { - panic(err) - } - - rsaBlockDS, _ := pem.Decode(rsaDataDS) - parsedKeyDS, err := x509.ParsePKCS8PrivateKey(rsaBlockDS.Bytes) - if err != nil { - panic(err) - } - - var ok bool - rsaKeyDS, ok = parsedKeyDS.(*rsa.PrivateKey) - if !ok { - panic("unexpected key type") - } - - wiiCertDS, err := os.ReadFile(config.WiiCertPathDS) - if err != nil { - panic(err) - } - - serverCertsRecordDS = []byte{0x16, 0x03, 0x00} - - // Length of the record - certLenDS := uint32(len(certDS)) - wiiCertLenDS := uint32(len(wiiCertDS)) - serverCertsRecordDS = append(serverCertsRecordDS, []byte{ - byte((certLenDS + wiiCertLenDS + 13) >> 8), - byte(certLenDS + wiiCertLenDS + 13), - }...) - - serverCertsRecordDS = append(serverCertsRecordDS, 0xB) - - serverCertsRecordDS = append(serverCertsRecordDS, []byte{ - byte((certLenDS + wiiCertLenDS + 9) >> 16), - byte((certLenDS + wiiCertLenDS + 9) >> 8), - byte(certLenDS + wiiCertLenDS + 9), - }...) - - serverCertsRecordDS = append(serverCertsRecordDS, []byte{ - byte((certLenDS + wiiCertLenDS + 6) >> 16), - byte((certLenDS + wiiCertLenDS + 6) >> 8), - byte(certLenDS + wiiCertLenDS + 6), - }...) - - serverCertsRecordDS = append(serverCertsRecordDS, []byte{ - byte(certLenDS >> 16), - byte(certLenDS >> 8), - byte(certLenDS), - }...) - - serverCertsRecordDS = append(serverCertsRecordDS, certDS...) - - serverCertsRecordDS = append(serverCertsRecordDS, []byte{ - byte(wiiCertLenDS >> 16), - byte(wiiCertLenDS >> 8), - byte(wiiCertLenDS), - }...) - - serverCertsRecordDS = append(serverCertsRecordDS, wiiCertDS...) - - serverCertsRecordDS = append(serverCertsRecordDS, []byte{ - 0x16, 0x03, 0x00, 0x00, 0x04, 0x0E, 0x00, 0x00, 0x00, - }...) - } - - for { - conn, err := l.Accept() - if err != nil { - panic(err) - } - - go func() { - // logging.Info("NAS-TLS", "Receiving HTTPS request from", aurora.BrightCyan(conn.RemoteAddr())) - - moduleName := "NAS-TLS:" + conn.RemoteAddr().String() - - _ = conn.SetDeadline(time.Now().UTC().Add(5 * time.Second)) - - handleTLS(moduleName, conn, nasAddr, serverCertsRecordWii, rsaKeyWii, serverCertsRecordDS, rsaKeyDS) - }() + setupExploitDS(config) } } +func setupRealTLS(privKeyPath string, certsPath string) { + // Read server key and certs + + serverKey, err := os.ReadFile(privKeyPath) + if err != nil { + logging.Error("NAS-TLS", "Failed to read server key:", err) + return + } + + serverCerts, err := os.ReadFile(certsPath) + if err != nil { + logging.Error("NAS-TLS", "Failed to read server certs:", err) + return + } + + cert, err := tls.X509KeyPair(serverCerts, serverKey) + if err != nil { + logging.Error("NAS-TLS", "Failed to parse server certs:", err) + return + } + + config := tls.Config{ + Certificates: []tls.Certificate{cert}, + } + + realTLSConfig = &config +} + +func setupExploitWii(config common.Config) { + certWii, err := os.ReadFile(config.CertPathWii) + if err != nil { + panic(err) + } + + rsaDataWii, err := os.ReadFile(config.KeyPathWii) + if err != nil { + panic(err) + } + + rsaBlockWii, _ := pem.Decode(rsaDataWii) + parsedKeyWii, err := x509.ParsePKCS8PrivateKey(rsaBlockWii.Bytes) + if err != nil { + panic(err) + } + + var ok bool + rsaKeyWii, ok = parsedKeyWii.(*rsa.PrivateKey) + if !ok { + panic("unexpected key type") + } + + serverCertsRecordWii = []byte{0x16, 0x03, 0x01} + + // Length of the record + certLenWii := uint32(len(certWii)) + serverCertsRecordWii = append(serverCertsRecordWii, []byte{ + byte((certLenWii + 10) >> 8), + byte(certLenWii + 10), + }...) + + serverCertsRecordWii = append(serverCertsRecordWii, 0xB) + + serverCertsRecordWii = append(serverCertsRecordWii, []byte{ + byte((certLenWii + 6) >> 16), + byte((certLenWii + 6) >> 8), + byte(certLenWii + 6), + }...) + + serverCertsRecordWii = append(serverCertsRecordWii, []byte{ + byte((certLenWii + 3) >> 16), + byte((certLenWii + 3) >> 8), + byte(certLenWii + 3), + }...) + + serverCertsRecordWii = append(serverCertsRecordWii, []byte{ + byte(certLenWii >> 16), + byte(certLenWii >> 8), + byte(certLenWii), + }...) + + serverCertsRecordWii = append(serverCertsRecordWii, certWii...) + + serverCertsRecordWii = append(serverCertsRecordWii, []byte{ + 0x16, 0x03, 0x01, 0x00, 0x04, 0x0E, 0x00, 0x00, 0x00, + }...) +} + +func setupExploitDS(config common.Config) { + certDS, err := os.ReadFile(config.CertPathDS) + if err != nil { + panic(err) + } + + rsaDataDS, err := os.ReadFile(config.KeyPathDS) + if err != nil { + panic(err) + } + + rsaBlockDS, _ := pem.Decode(rsaDataDS) + parsedKeyDS, err := x509.ParsePKCS8PrivateKey(rsaBlockDS.Bytes) + if err != nil { + panic(err) + } + + var ok bool + rsaKeyDS, ok = parsedKeyDS.(*rsa.PrivateKey) + if !ok { + panic("unexpected key type") + } + + wiiCertDS, err := os.ReadFile(config.WiiCertPathDS) + if err != nil { + panic(err) + } + + serverCertsRecordDS = []byte{0x16, 0x03, 0x00} + + // Length of the record + certLenDS := uint32(len(certDS)) + wiiCertLenDS := uint32(len(wiiCertDS)) + serverCertsRecordDS = append(serverCertsRecordDS, []byte{ + byte((certLenDS + wiiCertLenDS + 13) >> 8), + byte(certLenDS + wiiCertLenDS + 13), + }...) + + serverCertsRecordDS = append(serverCertsRecordDS, 0xB) + + serverCertsRecordDS = append(serverCertsRecordDS, []byte{ + byte((certLenDS + wiiCertLenDS + 9) >> 16), + byte((certLenDS + wiiCertLenDS + 9) >> 8), + byte(certLenDS + wiiCertLenDS + 9), + }...) + + serverCertsRecordDS = append(serverCertsRecordDS, []byte{ + byte((certLenDS + wiiCertLenDS + 6) >> 16), + byte((certLenDS + wiiCertLenDS + 6) >> 8), + byte(certLenDS + wiiCertLenDS + 6), + }...) + + serverCertsRecordDS = append(serverCertsRecordDS, []byte{ + byte(certLenDS >> 16), + byte(certLenDS >> 8), + byte(certLenDS), + }...) + + serverCertsRecordDS = append(serverCertsRecordDS, certDS...) + + serverCertsRecordDS = append(serverCertsRecordDS, []byte{ + byte(wiiCertLenDS >> 16), + byte(wiiCertLenDS >> 8), + byte(wiiCertLenDS), + }...) + + serverCertsRecordDS = append(serverCertsRecordDS, wiiCertDS...) + + serverCertsRecordDS = append(serverCertsRecordDS, []byte{ + 0x16, 0x03, 0x00, 0x00, 0x04, 0x0E, 0x00, 0x00, 0x00, + }...) +} + +// handleIncomingTLS handles incoming requests from the listener. +func handleIncomingTLS(rawConn net.Conn, fromServer *io.PipeReader, toServer *io.PipeWriter) { + moduleName := "NAS-TLS:" + rawConn.RemoteAddr().String() + + var err error + defer func() { + _ = rawConn.Close() + toServer.CloseWithError(err) + }() + + if rsaKeyWii == nil && rsaKeyDS == nil { + // Only handle real TLS requests + err = handleRealTLS(moduleName, rawConn, fromServer, toServer) + } else { + err = handleTLS(moduleName, rawConn, fromServer, toServer) + } + +} + +func peekBytes(r *bufio.Reader, expected []byte) (bool, error) { + buf := make([]byte, len(expected)) + _, err := r.Peek(len(expected)) + if err != nil { + if err == io.EOF { + return false, nil + } + return false, err + } + return bytes.Equal(buf, expected), nil +} + // handleTLS handles the TLS request from the Wii or the DS. It may call handleRealTLS if the request is from a modern web browser. -func handleTLS(moduleName string, rawConn net.Conn, nasAddr string, serverCertsRecordWii []byte, rsaKeyWii *rsa.PrivateKey, serverCertsRecordDS []byte, rsaKeyDS *rsa.PrivateKey) { +func handleTLS(moduleName string, rawConn net.Conn, fromServer *io.PipeReader, toServer *io.PipeWriter) error { // Recover from panics defer func() { if r := recover(); r != nil { @@ -257,78 +268,115 @@ func handleTLS(moduleName string, rawConn net.Conn, nasAddr string, serverCertsR } }() - conn := newBufferedConn(rawConn) - - defer func() { - _ = conn.Close() - }() + r := bufio.NewReader(rawConn) + conn := nasInConn{ + Conn: rawConn, + in: r, + } // Read client hello - // fmt.Printf("Client Hello:\n") - var helloBytes []byte - if rsaKeyWii != nil { - index := 0 - for index = 0; index < 0x1D; index++ { - var err error - helloBytes, err = conn.Peek(index + 1) - if err != nil { - logging.Error(moduleName, "Failed to peek from client:", err) - return - } - - if helloBytes[index] != []byte{ + if rsaKeyWii != nil || rsaKeyDS != nil { + var peekMatchWii, peekMatchDS bool + var err error + if rsaKeyWii != nil { + peekMatchWii, err = peekBytes(r, []byte{ 0x80, 0x2B, 0x01, 0x03, 0x01, 0x00, 0x12, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x35, 0x00, 0x00, 0x2F, 0x00, 0x00, 0x0A, 0x00, 0x00, 0x09, 0x00, 0x00, 0x05, 0x00, 0x00, 0x04, - }[index] { - break - } + }) } - if index == 0x1D { - macFn, cipher, clientCipher, err := handleWiiTLSHandshake(moduleName, conn, serverCertsRecordWii, rsaKeyWii) - if err == nil && macFn != nil && cipher != nil && clientCipher != nil { - proxyConsoleTLS(moduleName, conn, nasAddr, VersionTLS10, macFn, cipher, clientCipher) - } - return - } - } - - if rsaKeyDS != nil { - index := 0 - for index = 0; index < 0x0B; index++ { - var err error - helloBytes, err = conn.Peek(index + 1) - if err != nil { - logging.Error(moduleName, "Failed to peek from client:", err) - return - } - - if helloBytes[index] != []byte{ + if rsaKeyDS != nil && !peekMatchWii && err == nil { + peekMatchDS, err = peekBytes(r, []byte{ 0x16, 0x03, 0x00, 0x00, 0x2F, 0x01, 0x00, 0x00, 0x2B, 0x03, 0x00, - }[index] { - break - } + }) } - if index == 0x0B { - macFn, cipher, clientCipher, err := handleDSSSLHandshake(moduleName, conn, serverCertsRecordDS, rsaKeyDS) + + switch { + case err != nil: + return err + case peekMatchWii: + macFn, cipher, clientCipher, err := handleWiiTLSHandshake(moduleName, conn) if err == nil && macFn != nil && cipher != nil && clientCipher != nil { - proxyConsoleTLS(moduleName, conn, nasAddr, VersionSSL30, macFn, cipher, clientCipher) + err = proxyConsoleTLS(moduleName, conn, fromServer, toServer, VersionTLS10, macFn, cipher, clientCipher) } - return + return err + case peekMatchDS: + macFn, cipher, clientCipher, err := handleDSSSLHandshake(moduleName, conn) + if err == nil && macFn != nil && cipher != nil && clientCipher != nil { + err = proxyConsoleTLS(moduleName, conn, fromServer, toServer, VersionSSL30, macFn, cipher, clientCipher) + } + return err } } _ = conn.SetDeadline(time.Now().UTC().Add(25 * time.Second)) - // logging.Info(moduleName, "Forwarding client hello:", aurora.Cyan(fmt.Sprintf("% X ", helloBytes))) - handleRealTLS(moduleName, conn, nasAddr) + return handleRealTLS(moduleName, conn, fromServer, toServer) } -func handleWiiTLSHandshake(moduleName string, conn bufferedConn, serverCertsRecord []byte, rsaKey *rsa.PrivateKey) (macFn macFunction, cipher *rc4.Cipher, clientCipher *rc4.Cipher, err error) { +// handleRealTLS handles the TLS request legitimately using crypto/tls +func handleRealTLS(moduleName string, rawConn net.Conn, fromServer *io.PipeReader, toServer *io.PipeWriter) error { + // Recover from panics + defer func() { + if r := recover(); r != nil { + logging.Error(moduleName, "Panic:", r) + if r.(error) != nil { + toServer.CloseWithError(r.(error)) + } + } + }() + + if realTLSConfig == nil { + return errors.New("realTLSConfig is not set") + } + + tlsConn := tls.Server(rawConn, realTLSConfig) + + err := tlsConn.Handshake() + if err != nil { + return err + } + + // Read bytes from the HTTP server and forward them through the TLS connection + go func() { + recvBuf := make([]byte, 0x100) + + for { + n, err := fromServer.Read(recvBuf) + if err != nil { + toServer.CloseWithError(err) + return + } + + _, err = tlsConn.Write(recvBuf[:n]) + if err != nil { + toServer.CloseWithError(err) + return + } + } + }() + + // Read encrypted content from the client and forward it to the HTTP server + buf := make([]byte, 0x1000) + for { + n, err := tlsConn.Read(buf) + if err != nil { + return err + } + + _, err = toServer.Write(buf[:n]) + if err != nil { + toServer.CloseWithError(err) + return err + } + } +} + +func handleWiiTLSHandshake(moduleName string, conn nasInConn) (macFn macFunction, cipher *rc4.Cipher, clientCipher *rc4.Cipher, err error) { // fmt.Printf("\n") clientHello := make([]byte, 0x2D) - _, err = io.ReadFull(conn.r, clientHello) + _, err = io.ReadFull(conn.in, clientHello) if err != nil { logging.Error(moduleName, "Failed to read from client:", err) return @@ -360,13 +408,13 @@ func handleWiiTLSHandshake(moduleName string, conn bufferedConn, serverCertsReco }...) // Append the certs record to the server hello buffer - serverHello = append(serverHello, serverCertsRecord...) + serverHello = append(serverHello, serverCertsRecordWii...) // fmt.Printf("Server Hello:\n% X\n", serverHello) finishHash.Write(serverHello[0x5:0x2F]) - finishHash.Write(serverHello[0x34 : 0x34+(len(serverCertsRecord)-14)]) - finishHash.Write(serverHello[0x34+(len(serverCertsRecord)-14)+5 : 0x34+(len(serverCertsRecord)-14)+5+4]) + finishHash.Write(serverHello[0x34 : 0x34+(len(serverCertsRecordWii)-14)]) + finishHash.Write(serverHello[0x34+(len(serverCertsRecordWii)-14)+5 : 0x34+(len(serverCertsRecordWii)-14)+5+4]) _, err = conn.Write(serverHello) if err != nil { @@ -428,7 +476,8 @@ func handleWiiTLSHandshake(moduleName string, conn bufferedConn, serverCertsReco finishHash.Write(buf[0x5 : 0x5+0x86]) // Decrypt the pre master secret using our RSA key - preMasterSecret, err := rsa.DecryptPKCS1v15(rand.Reader, rsaKey, encryptedPreMasterSecret) + var preMasterSecret []byte + preMasterSecret, err = rsa.DecryptPKCS1v15(rand.Reader, rsaKeyWii, encryptedPreMasterSecret) if err != nil { logging.Error(moduleName, "Failed to decrypt pre master secret:", err) return @@ -488,7 +537,7 @@ func handleWiiTLSHandshake(moduleName string, conn bufferedConn, serverCertsReco // Send ChangeCipherSpec _, err = conn.Write([]byte{0x14, 0x03, 0x01, 0x00, 0x01, 0x01}) if err != nil { - panic(err) + return } finishedRecord := []byte{0x16, 0x03, 0x01, 0x00, 0x10} @@ -503,9 +552,9 @@ func handleWiiTLSHandshake(moduleName string, conn bufferedConn, serverCertsReco return } -func handleDSSSLHandshake(moduleName string, conn bufferedConn, serverCertsRecord []byte, rsaKey *rsa.PrivateKey) (macFn macFunction, cipher *rc4.Cipher, clientCipher *rc4.Cipher, err error) { +func handleDSSSLHandshake(moduleName string, conn nasInConn) (macFn macFunction, cipher *rc4.Cipher, clientCipher *rc4.Cipher, err error) { clientHello := make([]byte, 0x34) - _, err = io.ReadFull(conn.r, clientHello) + _, err = io.ReadFull(conn.in, clientHello) if err != nil { logging.Error(moduleName, "Failed to read from client:", err) return @@ -536,13 +585,13 @@ func handleDSSSLHandshake(moduleName string, conn bufferedConn, serverCertsRecor }...) // Append the certs record to the server hello buffer - serverHello = append(serverHello, serverCertsRecord...) + serverHello = append(serverHello, serverCertsRecordDS...) // fmt.Printf("Server Hello:\n% X\n", serverHello) finishHash.Write(serverHello[0x5:0x2F]) - finishHash.Write(serverHello[0x34 : 0x34+(len(serverCertsRecord)-14)]) - finishHash.Write(serverHello[0x34+(len(serverCertsRecord)-14)+5 : 0x34+(len(serverCertsRecord)-14)+5+4]) + finishHash.Write(serverHello[0x34 : 0x34+(len(serverCertsRecordDS)-14)]) + finishHash.Write(serverHello[0x34+(len(serverCertsRecordDS)-14)+5 : 0x34+(len(serverCertsRecordDS)-14)+5+4]) _, err = conn.Write(serverHello) if err != nil { @@ -604,7 +653,8 @@ func handleDSSSLHandshake(moduleName string, conn bufferedConn, serverCertsRecor finishHash.Write(buf[0x5 : 0x5+0x84]) // Decrypt the pre master secret using our RSA key - preMasterSecret, err := rsa.DecryptPKCS1v15(rand.Reader, rsaKey, encryptedPreMasterSecret) + var preMasterSecret []byte + preMasterSecret, err = rsa.DecryptPKCS1v15(rand.Reader, rsaKeyDS, encryptedPreMasterSecret) if err != nil { logging.Error(moduleName, "Failed to decrypt pre master secret:", err) return @@ -664,7 +714,7 @@ func handleDSSSLHandshake(moduleName string, conn bufferedConn, serverCertsRecor // Send ChangeCipherSpec _, err = conn.Write([]byte{0x14, 0x03, 0x00, 0x00, 0x01, 0x01}) if err != nil { - panic(err) + return } finishedRecord := []byte{0x16, 0x03, 0x00, 0x00, 0x28} @@ -679,30 +729,17 @@ func handleDSSSLHandshake(moduleName string, conn bufferedConn, serverCertsRecor return } -func proxyConsoleTLS(moduleName string, conn bufferedConn, nasAddr string, version uint16, macFn macFunction, cipher *rc4.Cipher, clientCipher *rc4.Cipher) { - // Open a connection to NAS - newConn, err := net.Dial("tcp", nasAddr) - if err != nil { - panic(err) - } - - defer func() { - _ = newConn.Close() - }() - +func proxyConsoleTLS(moduleName string, conn nasInConn, fromServer *io.PipeReader, toServer *io.PipeWriter, version uint16, macFn macFunction, cipher *rc4.Cipher, clientCipher *rc4.Cipher) error { // Read bytes from the HTTP server and forward them through the TLS connection go func() { recvBuf := make([]byte, 0x100) seq := uint64(1) for { - n, err := newConn.Read(recvBuf) + n, err := fromServer.Read(recvBuf) if err != nil { - if errors.Is(err, io.EOF) || strings.Contains(err.Error(), "use of closed network connection") { - return - } - logging.Error(moduleName, "Failed to read from HTTP server:", err) + toServer.CloseWithError(err) return } @@ -713,6 +750,7 @@ func proxyConsoleTLS(moduleName string, conn bufferedConn, nasAddr string, versi _, err = conn.Write(record) if err != nil { logging.Error(moduleName, "Failed to write to client:", err) + toServer.CloseWithError(err) return } } @@ -727,11 +765,11 @@ func proxyConsoleTLS(moduleName string, conn bufferedConn, nasAddr string, versi if err != nil { if errors.Is(err, io.EOF) || strings.Contains(err.Error(), "use of closed network connection") { logging.Info(moduleName, "Connection closed by client after", aurora.BrightCyan(total), "bytes") - return + return err } logging.Error(moduleName, "Failed to read from client:", err) - return + return err } // fmt.Printf("Received:\n% X ", buf[index:index+n]) @@ -741,18 +779,18 @@ func proxyConsoleTLS(moduleName string, conn bufferedConn, nasAddr string, versi for index < 5 { if buf[0] < 0x15 || buf[0] > 0x17 { logging.Error(moduleName, "Invalid record type") - return + return errors.New("invalid record type") } if buf[1] != 0x03 || (version == VersionTLS10 && buf[2] != 0x01) || (version == VersionSSL30 && buf[2] != 0x00) { logging.Error(moduleName, "Invalid TLS version") - return + return errors.New("invalid TLS version") } recordLength := binary.BigEndian.Uint16(buf[3:5]) if recordLength < 17 || (recordLength+5) > 0x1000 { logging.Error(moduleName, "Invalid record length") - return + return errors.New("invalid record length") } if index < int(recordLength)+5 { @@ -766,17 +804,17 @@ func proxyConsoleTLS(moduleName string, conn bufferedConn, nasAddr string, versi if buf[0] != 0x17 { if buf[0] == 0x15 || buf[5] == 0x01 || buf[6] == 0x00 { logging.Info(moduleName, "Alert connection close by client after", aurora.BrightCyan(total), "bytes") - return + return nil } logging.Error(moduleName, "Non-application data received:", aurora.Cyan(fmt.Sprintf("% X ", buf[:5+recordLength]))) - return + return errors.New("non-application data received") } else { // Send the decrypted content to the HTTP server - _, err = newConn.Write(buf[5 : 5+recordLength-16]) + _, err = toServer.Write(buf[5 : 5+recordLength-16]) if err != nil { logging.Error(moduleName, "Failed to write to HTTP server:", err) - return + return err } } @@ -787,99 +825,6 @@ func proxyConsoleTLS(moduleName string, conn bufferedConn, nasAddr string, versi } } -var realTLSConfig *tls.Config - -func setupRealTLS(privKeyPath string, certsPath string) { - // Read server key and certs - - serverKey, err := os.ReadFile(privKeyPath) - if err != nil { - logging.Error("NAS-TLS", "Failed to read server key:", err) - return - } - - serverCerts, err := os.ReadFile(certsPath) - if err != nil { - logging.Error("NAS-TLS", "Failed to read server certs:", err) - return - } - - cert, err := tls.X509KeyPair(serverCerts, serverKey) - if err != nil { - logging.Error("NAS-TLS", "Failed to parse server certs:", err) - return - } - - config := tls.Config{ - Certificates: []tls.Certificate{cert}, - } - - realTLSConfig = &config -} - -// handleRealTLS handles the TLS request legitimately using crypto/tls -func handleRealTLS(moduleName string, conn net.Conn, nasAddr string) { - // Recover from panics - defer func() { - if r := recover(); r != nil { - logging.Error(moduleName, "Panic:", r) - } - }() - - if realTLSConfig == nil { - return - } - - tlsConn := tls.Server(conn, realTLSConfig) - - err := tlsConn.Handshake() - if err != nil { - return - } - - newConn, err := net.Dial("tcp", nasAddr) - if err != nil { - panic(err) - } - - defer func() { - _ = newConn.Close() - }() - - // Read bytes from the HTTP server and forward them through the TLS connection - go func() { - recvBuf := make([]byte, 0x100) - - for { - n, err := newConn.Read(recvBuf) - if err != nil { - return - } - - _, err = tlsConn.Write(recvBuf[:n]) - if err != nil { - logging.Error(moduleName, "Failed to write to client:", err) - return - } - } - }() - - // Read encrypted content from the client and forward it to the HTTP server - buf := make([]byte, 0x1000) - for { - n, err := tlsConn.Read(buf) - if err != nil { - return - } - - _, err = newConn.Write(buf[:n]) - if err != nil { - logging.Error(moduleName, "Failed to write to HTTP server:", err) - return - } - } -} - // The following functions are modified from the crypto standard library // // Copyright (c) 2009 The Go Authors. All rights reserved.