From 80ae592c6125096f6ada34672beafc3dfdf881cb Mon Sep 17 00:00:00 2001 From: Robert Hague Date: Tue, 28 Apr 2026 21:04:49 +0200 Subject: [PATCH] Serialise packets into a buffer A byte array is allocated to hold each plaintext packet. This removes that by adding a buffer for that purpose. --- src/Renci.SshNet/Messages/Message.cs | 136 +++++++----------- src/Renci.SshNet/Session.cs | 39 ++--- .../Classes/SessionTest_Connected.cs | 1 + ...SendsDebugMessageAfterKexInit_StrictKex.cs | 1 + ...DisconnectMessageAfterKexInit_StrictKex.cs | 1 + ...dsIgnoreMessageAfterKexInit_NoStrictKex.cs | 1 + ...endsIgnoreMessageAfterKexInit_StrictKex.cs | 1 + ...sIgnoreMessageBeforeKexInit_NoStrictKex.cs | 1 + ...ndsIgnoreMessageBeforeKexInit_StrictKex.cs | 1 + test/Renci.SshNet.Tests/Common/Extensions.cs | 23 ++- 10 files changed, 103 insertions(+), 102 deletions(-) diff --git a/src/Renci.SshNet/Messages/Message.cs b/src/Renci.SshNet/Messages/Message.cs index daf77a65a..ebb8bc730 100644 --- a/src/Renci.SshNet/Messages/Message.cs +++ b/src/Renci.SshNet/Messages/Message.cs @@ -1,4 +1,7 @@ #nullable enable +using System; +using System.Diagnostics; +using System.Globalization; using System.IO; using Renci.SshNet.Abstractions; @@ -38,116 +41,85 @@ protected override void WriteBytes(SshDataStream stream) base.WriteBytes(stream); } - /// [4 bytes] || packet_len || padding_len || payload || padding || [macLength bytes]. - internal byte[] GetPacket(byte paddingMultiplier, Compressor? compressor, bool excludePacketLengthFieldWhenPadding = false, int macLength = 0) + /// The number of bytes occupied by the packet in . + /// + /// [4 bytes] || packet_len || padding_len || payload || padding || [macLength bytes]. + /// + internal int GetPacket(ref byte[] buffer, byte paddingMultiplier, Compressor? compressor, bool excludePacketLengthFieldWhenPadding, int macLength) { const int outboundPacketSequenceSize = 4; var messageLength = BufferCapacity; + ArraySegment payload = default; + if (messageLength == -1 || compressor != null) { - using (var sshDataStream = new SshDataStream(DefaultCapacity)) + using (var sshDataStream = new SshDataStream(messageLength != -1 ? messageLength : DefaultCapacity)) { - // skip: - // * 4 bytes for the outbound packet sequence - // * 4 bytes for the packet data length - // * one byte for the packet padding length - _ = sshDataStream.Seek(outboundPacketSequenceSize + 4 + 1, SeekOrigin.Begin); - - if (compressor != null) - { - // obtain uncompressed message payload - using (var uncompressedDataStream = new SshDataStream(messageLength != -1 ? messageLength : DefaultCapacity)) - { - WriteBytes(uncompressedDataStream); - - // compress message payload - var compressedMessageData = compressor.Compress(uncompressedDataStream.ToArray()); - - // add compressed message payload - sshDataStream.Write(compressedMessageData, 0, compressedMessageData.Length); - } - } - else - { - // add message payload - WriteBytes(sshDataStream); - } - - messageLength = (int)sshDataStream.Length - (outboundPacketSequenceSize + 4 + 1); - - var packetLength = messageLength + 4 + 1; - - // determine the padding length - // in Encrypt-then-MAC mode or AEAD, the length field is not encrypted, so we should keep it out of the - // padding length calculation - var paddingLength = GetPaddingLength(paddingMultiplier, excludePacketLengthFieldWhenPadding ? packetLength - 4 : packetLength); + WriteBytes(sshDataStream); - var packetDataLength = GetPacketDataLength(messageLength, paddingLength); + var success = sshDataStream.TryGetBuffer(out payload); - // skip bytes for outbound packet sequence - _ = sshDataStream.Seek(outboundPacketSequenceSize, SeekOrigin.Begin); + Debug.Assert(success); + } - // add packet data length - sshDataStream.Write(packetDataLength); + if (compressor != null) + { + payload = new(compressor.Compress(payload.Array, payload.Offset, payload.Count)); + } - // add packet padding length - sshDataStream.WriteByte(paddingLength); + messageLength = payload.Count; + } - _ = sshDataStream.Seek(0, SeekOrigin.End); + // determine the padding length + // in Encrypt-then-MAC mode or AEAD, the length field is not encrypted, so we should keep it out of the + // padding length calculation + var paddingLength = GetPaddingLength( + paddingMultiplier, (excludePacketLengthFieldWhenPadding ? 0 : 4) + 1 + messageLength); - sshDataStream.SetLength(sshDataStream.Length + paddingLength + macLength); + var packetLength = 1 + messageLength + paddingLength; - var buffer = sshDataStream.ToArray(); + var bytesRequired = 4 + 4 + packetLength + macLength; - // add padding bytes - CryptoAbstraction.Randomizer.GetBytes(buffer, (int)sshDataStream.Position, paddingLength); + if ((uint)bytesRequired > (uint)Session.MaximumSshPacketSize) + { + throw new InvalidOperationException(string.Format(CultureInfo.CurrentCulture, "Packet is too big. Maximum packet size is {0} bytes.", Session.MaximumSshPacketSize)); + } - return buffer; - } + if (buffer.Length < bytesRequired) + { + Array.Resize(ref buffer, Math.Max(bytesRequired, 2 * buffer.Length)); } - else + + using (var sshDataStream = new SshDataStream(buffer)) { - var packetLength = messageLength + 4 + 1; + // skip bytes for outbound packet sequenceSize + _ = sshDataStream.Seek(outboundPacketSequenceSize, SeekOrigin.Begin); - // determine the padding length - // in Encrypt-then-MAC mode or AEAD, the length field is not encrypted, so we should keep it out of the - // padding length calculation - var paddingLength = GetPaddingLength(paddingMultiplier, excludePacketLengthFieldWhenPadding ? packetLength - 4 : packetLength); + // add packet length + sshDataStream.Write((uint)packetLength); - var packetDataLength = GetPacketDataLength(messageLength, paddingLength); + // add padding length + sshDataStream.WriteByte(paddingLength); - // lets construct an SSH data stream of the exact size required - using (var sshDataStream = new SshDataStream(packetLength + paddingLength + outboundPacketSequenceSize + macLength)) + // add message payload + if (payload != default) + { + sshDataStream.Write(payload.Array!, payload.Offset, payload.Count); + } + else { - // skip bytes for outbound packet sequenceSize - _ = sshDataStream.Seek(outboundPacketSequenceSize, SeekOrigin.Begin); - - // add packet data length - sshDataStream.Write(packetDataLength); - - // add packet padding length - sshDataStream.WriteByte(paddingLength); - - // add message payload WriteBytes(sshDataStream); + } - sshDataStream.SetLength(sshDataStream.Length + paddingLength + macLength); - - var buffer = sshDataStream.ToArray(); - - // add padding bytes - CryptoAbstraction.Randomizer.GetBytes(buffer, (int)sshDataStream.Position, paddingLength); + Debug.Assert(sshDataStream.Position == bytesRequired - macLength - paddingLength); - return buffer; - } + // add padding bytes + CryptoAbstraction.Randomizer.GetBytes(buffer, (int)sshDataStream.Position, paddingLength); } - } - private static uint GetPacketDataLength(int messageLength, byte paddingLength) - { - return (uint)(messageLength + paddingLength + 1); + return bytesRequired; } private static byte GetPaddingLength(byte paddingMultiplier, long packetLength) diff --git a/src/Renci.SshNet/Session.cs b/src/Renci.SshNet/Session.cs index c73c3db26..5468b53c8 100644 --- a/src/Renci.SshNet/Session.cs +++ b/src/Renci.SshNet/Session.cs @@ -212,6 +212,7 @@ private uint InboundPacketSequence private Socket _socket; private ArrayBuffer _receiveBuffer = new(4 * 1024); + private byte[] _sendBuffer = new byte[4 * 1024]; /// /// Gets the session semaphore that controls session channels. @@ -1073,29 +1074,29 @@ internal void SendMessage(Message message) macLength = _clientMac.HashSize / 8; } - var packetData = message.GetPacket(paddingMultiplier, _clientCompression, _clientEtm || _clientAead, macLength); - - if (packetData.Length > MaximumSshPacketSize) - { - throw new InvalidOperationException(string.Format(CultureInfo.CurrentCulture, "Packet is too big. Maximum packet size is {0} bytes.", MaximumSshPacketSize)); - } - // take a write lock to ensure the outbound packet sequence number is incremented // atomically, and only after the packet has actually been sent lock (_socketWriteLock) { + var activeBufferLength = message.GetPacket( + ref _sendBuffer, + paddingMultiplier, + _clientCompression, + _clientEtm || _clientAead, + macLength); + // write outbound packet sequence to start of packet data - BinaryPrimitives.WriteUInt32BigEndian(packetData, _outboundPacketSequence); + BinaryPrimitives.WriteUInt32BigEndian(_sendBuffer, _outboundPacketSequence); if (_clientMac != null && !_clientEtm) { // non-ETM mac = MAC(key, sequence_number || unencrypted_packet) var hashSuccess = _clientMac.TryComputeHash( - buffer: packetData, + buffer: _sendBuffer, offset: 0, - count: packetData.Length - macLength, - destination: packetData.AsSpan(packetData.Length - macLength), + count: activeBufferLength - macLength, + destination: _sendBuffer.AsSpan(activeBufferLength - macLength), bytesWritten: out var bytesWritten); Debug.Assert(hashSuccess && bytesWritten == macLength); @@ -1110,13 +1111,13 @@ internal void SendMessage(Message message) var offset = _clientEtm ? 8 : 4; var numberOfBytesEncrypted = _clientCipher.Encrypt( - input: packetData, + input: _sendBuffer, offset, - length: packetData.Length - offset - macLength, - output: packetData, + length: activeBufferLength - offset - macLength, + output: _sendBuffer, outputOffset: offset); - Debug.Assert(numberOfBytesEncrypted == packetData.Length - offset - macLength + (_clientAead ? macLength : 0)); + Debug.Assert(numberOfBytesEncrypted == activeBufferLength - offset - macLength + (_clientAead ? macLength : 0)); } if (_clientMac != null && _clientEtm) @@ -1124,16 +1125,16 @@ internal void SendMessage(Message message) // ETM mac = MAC(key, sequence_number || packet_length || encrypted_packet) var hashSuccess = _clientMac.TryComputeHash( - buffer: packetData, + buffer: _sendBuffer, offset: 0, - count: packetData.Length - macLength, - destination: packetData.AsSpan(packetData.Length - macLength), + count: activeBufferLength - macLength, + destination: _sendBuffer.AsSpan(activeBufferLength - macLength), bytesWritten: out var bytesWritten); Debug.Assert(hashSuccess && bytesWritten == macLength); } - SendPacket(packetData, 4, packetData.Length - 4); + SendPacket(_sendBuffer, 4, activeBufferLength - 4); if (_isStrictKex && message is NewKeysMessage) { diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_Connected.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_Connected.cs index d13919b05..f3dc36e3c 100644 --- a/test/Renci.SshNet.Tests/Classes/SessionTest_Connected.cs +++ b/test/Renci.SshNet.Tests/Classes/SessionTest_Connected.cs @@ -11,6 +11,7 @@ using Renci.SshNet.Messages.Connection; using Renci.SshNet.Messages.Transport; +using Renci.SshNet.Tests.Common; namespace Renci.SshNet.Tests.Classes { diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsDebugMessageAfterKexInit_StrictKex.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsDebugMessageAfterKexInit_StrictKex.cs index 257def53c..22d2f71ae 100644 --- a/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsDebugMessageAfterKexInit_StrictKex.cs +++ b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsDebugMessageAfterKexInit_StrictKex.cs @@ -6,6 +6,7 @@ using Renci.SshNet.Common; using Renci.SshNet.Messages.Transport; +using Renci.SshNet.Tests.Common; namespace Renci.SshNet.Tests.Classes { diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsDisconnectMessageAfterKexInit_StrictKex.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsDisconnectMessageAfterKexInit_StrictKex.cs index 54b6f90b8..a281fbe54 100644 --- a/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsDisconnectMessageAfterKexInit_StrictKex.cs +++ b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsDisconnectMessageAfterKexInit_StrictKex.cs @@ -4,6 +4,7 @@ using Renci.SshNet.Common; using Renci.SshNet.Messages.Transport; +using Renci.SshNet.Tests.Common; namespace Renci.SshNet.Tests.Classes { diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageAfterKexInit_NoStrictKex.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageAfterKexInit_NoStrictKex.cs index f20d81d8a..50fe58e6e 100644 --- a/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageAfterKexInit_NoStrictKex.cs +++ b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageAfterKexInit_NoStrictKex.cs @@ -3,6 +3,7 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using Renci.SshNet.Messages.Transport; +using Renci.SshNet.Tests.Common; namespace Renci.SshNet.Tests.Classes { diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageAfterKexInit_StrictKex.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageAfterKexInit_StrictKex.cs index de1cc2741..1d4c2381d 100644 --- a/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageAfterKexInit_StrictKex.cs +++ b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageAfterKexInit_StrictKex.cs @@ -4,6 +4,7 @@ using Renci.SshNet.Common; using Renci.SshNet.Messages.Transport; +using Renci.SshNet.Tests.Common; namespace Renci.SshNet.Tests.Classes { diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageBeforeKexInit_NoStrictKex.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageBeforeKexInit_NoStrictKex.cs index c85d925b7..b517a1cd8 100644 --- a/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageBeforeKexInit_NoStrictKex.cs +++ b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageBeforeKexInit_NoStrictKex.cs @@ -3,6 +3,7 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using Renci.SshNet.Messages.Transport; +using Renci.SshNet.Tests.Common; namespace Renci.SshNet.Tests.Classes { diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageBeforeKexInit_StrictKex.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageBeforeKexInit_StrictKex.cs index 41f00b735..a18e46e76 100644 --- a/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageBeforeKexInit_StrictKex.cs +++ b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageBeforeKexInit_StrictKex.cs @@ -4,6 +4,7 @@ using Renci.SshNet.Common; using Renci.SshNet.Messages.Transport; +using Renci.SshNet.Tests.Common; namespace Renci.SshNet.Tests.Classes { diff --git a/test/Renci.SshNet.Tests/Common/Extensions.cs b/test/Renci.SshNet.Tests/Common/Extensions.cs index 977c182dc..ce8053435 100644 --- a/test/Renci.SshNet.Tests/Common/Extensions.cs +++ b/test/Renci.SshNet.Tests/Common/Extensions.cs @@ -1,6 +1,10 @@ -using System.Collections.Generic; +#nullable enable +using System; +using System.Collections.Generic; using Renci.SshNet.Common; +using Renci.SshNet.Compression; +using Renci.SshNet.Messages; namespace Renci.SshNet.Tests.Common { @@ -21,5 +25,22 @@ public static string AsString(this IList exceptionEvents) return reportedExceptions; } + + /// [4 bytes] || packet_len || padding_len || payload || padding. + public static byte[] GetPacket(this Message message, byte paddingMultiplier, Compressor? compressor) + { + var buffer = Array.Empty(); + + var byteCount = message.GetPacket( + ref buffer, + paddingMultiplier, + compressor, + excludePacketLengthFieldWhenPadding: false, + macLength: 0); + + Array.Resize(ref buffer, byteCount); + + return buffer; + } } }