diff --git a/src/Renci.SshNet/Messages/Message.cs b/src/Renci.SshNet/Messages/Message.cs
index daf77a65a..ebb8bc730 100644
--- a/src/Renci.SshNet/Messages/Message.cs
+++ b/src/Renci.SshNet/Messages/Message.cs
@@ -1,4 +1,7 @@
#nullable enable
+using System;
+using System.Diagnostics;
+using System.Globalization;
using System.IO;
using Renci.SshNet.Abstractions;
@@ -38,116 +41,85 @@ protected override void WriteBytes(SshDataStream stream)
base.WriteBytes(stream);
}
- /// [4 bytes] || packet_len || padding_len || payload || padding || [macLength bytes].
- internal byte[] GetPacket(byte paddingMultiplier, Compressor? compressor, bool excludePacketLengthFieldWhenPadding = false, int macLength = 0)
+ /// The number of bytes occupied by the packet in .
+ ///
+ /// [4 bytes] || packet_len || padding_len || payload || padding || [macLength bytes].
+ ///
+ internal int GetPacket(ref byte[] buffer, byte paddingMultiplier, Compressor? compressor, bool excludePacketLengthFieldWhenPadding, int macLength)
{
const int outboundPacketSequenceSize = 4;
var messageLength = BufferCapacity;
+ ArraySegment payload = default;
+
if (messageLength == -1 || compressor != null)
{
- using (var sshDataStream = new SshDataStream(DefaultCapacity))
+ using (var sshDataStream = new SshDataStream(messageLength != -1 ? messageLength : DefaultCapacity))
{
- // skip:
- // * 4 bytes for the outbound packet sequence
- // * 4 bytes for the packet data length
- // * one byte for the packet padding length
- _ = sshDataStream.Seek(outboundPacketSequenceSize + 4 + 1, SeekOrigin.Begin);
-
- if (compressor != null)
- {
- // obtain uncompressed message payload
- using (var uncompressedDataStream = new SshDataStream(messageLength != -1 ? messageLength : DefaultCapacity))
- {
- WriteBytes(uncompressedDataStream);
-
- // compress message payload
- var compressedMessageData = compressor.Compress(uncompressedDataStream.ToArray());
-
- // add compressed message payload
- sshDataStream.Write(compressedMessageData, 0, compressedMessageData.Length);
- }
- }
- else
- {
- // add message payload
- WriteBytes(sshDataStream);
- }
-
- messageLength = (int)sshDataStream.Length - (outboundPacketSequenceSize + 4 + 1);
-
- var packetLength = messageLength + 4 + 1;
-
- // determine the padding length
- // in Encrypt-then-MAC mode or AEAD, the length field is not encrypted, so we should keep it out of the
- // padding length calculation
- var paddingLength = GetPaddingLength(paddingMultiplier, excludePacketLengthFieldWhenPadding ? packetLength - 4 : packetLength);
+ WriteBytes(sshDataStream);
- var packetDataLength = GetPacketDataLength(messageLength, paddingLength);
+ var success = sshDataStream.TryGetBuffer(out payload);
- // skip bytes for outbound packet sequence
- _ = sshDataStream.Seek(outboundPacketSequenceSize, SeekOrigin.Begin);
+ Debug.Assert(success);
+ }
- // add packet data length
- sshDataStream.Write(packetDataLength);
+ if (compressor != null)
+ {
+ payload = new(compressor.Compress(payload.Array, payload.Offset, payload.Count));
+ }
- // add packet padding length
- sshDataStream.WriteByte(paddingLength);
+ messageLength = payload.Count;
+ }
- _ = sshDataStream.Seek(0, SeekOrigin.End);
+ // determine the padding length
+ // in Encrypt-then-MAC mode or AEAD, the length field is not encrypted, so we should keep it out of the
+ // padding length calculation
+ var paddingLength = GetPaddingLength(
+ paddingMultiplier, (excludePacketLengthFieldWhenPadding ? 0 : 4) + 1 + messageLength);
- sshDataStream.SetLength(sshDataStream.Length + paddingLength + macLength);
+ var packetLength = 1 + messageLength + paddingLength;
- var buffer = sshDataStream.ToArray();
+ var bytesRequired = 4 + 4 + packetLength + macLength;
- // add padding bytes
- CryptoAbstraction.Randomizer.GetBytes(buffer, (int)sshDataStream.Position, paddingLength);
+ if ((uint)bytesRequired > (uint)Session.MaximumSshPacketSize)
+ {
+ throw new InvalidOperationException(string.Format(CultureInfo.CurrentCulture, "Packet is too big. Maximum packet size is {0} bytes.", Session.MaximumSshPacketSize));
+ }
- return buffer;
- }
+ if (buffer.Length < bytesRequired)
+ {
+ Array.Resize(ref buffer, Math.Max(bytesRequired, 2 * buffer.Length));
}
- else
+
+ using (var sshDataStream = new SshDataStream(buffer))
{
- var packetLength = messageLength + 4 + 1;
+ // skip bytes for outbound packet sequenceSize
+ _ = sshDataStream.Seek(outboundPacketSequenceSize, SeekOrigin.Begin);
- // determine the padding length
- // in Encrypt-then-MAC mode or AEAD, the length field is not encrypted, so we should keep it out of the
- // padding length calculation
- var paddingLength = GetPaddingLength(paddingMultiplier, excludePacketLengthFieldWhenPadding ? packetLength - 4 : packetLength);
+ // add packet length
+ sshDataStream.Write((uint)packetLength);
- var packetDataLength = GetPacketDataLength(messageLength, paddingLength);
+ // add padding length
+ sshDataStream.WriteByte(paddingLength);
- // lets construct an SSH data stream of the exact size required
- using (var sshDataStream = new SshDataStream(packetLength + paddingLength + outboundPacketSequenceSize + macLength))
+ // add message payload
+ if (payload != default)
+ {
+ sshDataStream.Write(payload.Array!, payload.Offset, payload.Count);
+ }
+ else
{
- // skip bytes for outbound packet sequenceSize
- _ = sshDataStream.Seek(outboundPacketSequenceSize, SeekOrigin.Begin);
-
- // add packet data length
- sshDataStream.Write(packetDataLength);
-
- // add packet padding length
- sshDataStream.WriteByte(paddingLength);
-
- // add message payload
WriteBytes(sshDataStream);
+ }
- sshDataStream.SetLength(sshDataStream.Length + paddingLength + macLength);
-
- var buffer = sshDataStream.ToArray();
-
- // add padding bytes
- CryptoAbstraction.Randomizer.GetBytes(buffer, (int)sshDataStream.Position, paddingLength);
+ Debug.Assert(sshDataStream.Position == bytesRequired - macLength - paddingLength);
- return buffer;
- }
+ // add padding bytes
+ CryptoAbstraction.Randomizer.GetBytes(buffer, (int)sshDataStream.Position, paddingLength);
}
- }
- private static uint GetPacketDataLength(int messageLength, byte paddingLength)
- {
- return (uint)(messageLength + paddingLength + 1);
+ return bytesRequired;
}
private static byte GetPaddingLength(byte paddingMultiplier, long packetLength)
diff --git a/src/Renci.SshNet/Session.cs b/src/Renci.SshNet/Session.cs
index c73c3db26..5468b53c8 100644
--- a/src/Renci.SshNet/Session.cs
+++ b/src/Renci.SshNet/Session.cs
@@ -212,6 +212,7 @@ private uint InboundPacketSequence
private Socket _socket;
private ArrayBuffer _receiveBuffer = new(4 * 1024);
+ private byte[] _sendBuffer = new byte[4 * 1024];
///
/// Gets the session semaphore that controls session channels.
@@ -1073,29 +1074,29 @@ internal void SendMessage(Message message)
macLength = _clientMac.HashSize / 8;
}
- var packetData = message.GetPacket(paddingMultiplier, _clientCompression, _clientEtm || _clientAead, macLength);
-
- if (packetData.Length > MaximumSshPacketSize)
- {
- throw new InvalidOperationException(string.Format(CultureInfo.CurrentCulture, "Packet is too big. Maximum packet size is {0} bytes.", MaximumSshPacketSize));
- }
-
// take a write lock to ensure the outbound packet sequence number is incremented
// atomically, and only after the packet has actually been sent
lock (_socketWriteLock)
{
+ var activeBufferLength = message.GetPacket(
+ ref _sendBuffer,
+ paddingMultiplier,
+ _clientCompression,
+ _clientEtm || _clientAead,
+ macLength);
+
// write outbound packet sequence to start of packet data
- BinaryPrimitives.WriteUInt32BigEndian(packetData, _outboundPacketSequence);
+ BinaryPrimitives.WriteUInt32BigEndian(_sendBuffer, _outboundPacketSequence);
if (_clientMac != null && !_clientEtm)
{
// non-ETM mac = MAC(key, sequence_number || unencrypted_packet)
var hashSuccess = _clientMac.TryComputeHash(
- buffer: packetData,
+ buffer: _sendBuffer,
offset: 0,
- count: packetData.Length - macLength,
- destination: packetData.AsSpan(packetData.Length - macLength),
+ count: activeBufferLength - macLength,
+ destination: _sendBuffer.AsSpan(activeBufferLength - macLength),
bytesWritten: out var bytesWritten);
Debug.Assert(hashSuccess && bytesWritten == macLength);
@@ -1110,13 +1111,13 @@ internal void SendMessage(Message message)
var offset = _clientEtm ? 8 : 4;
var numberOfBytesEncrypted = _clientCipher.Encrypt(
- input: packetData,
+ input: _sendBuffer,
offset,
- length: packetData.Length - offset - macLength,
- output: packetData,
+ length: activeBufferLength - offset - macLength,
+ output: _sendBuffer,
outputOffset: offset);
- Debug.Assert(numberOfBytesEncrypted == packetData.Length - offset - macLength + (_clientAead ? macLength : 0));
+ Debug.Assert(numberOfBytesEncrypted == activeBufferLength - offset - macLength + (_clientAead ? macLength : 0));
}
if (_clientMac != null && _clientEtm)
@@ -1124,16 +1125,16 @@ internal void SendMessage(Message message)
// ETM mac = MAC(key, sequence_number || packet_length || encrypted_packet)
var hashSuccess = _clientMac.TryComputeHash(
- buffer: packetData,
+ buffer: _sendBuffer,
offset: 0,
- count: packetData.Length - macLength,
- destination: packetData.AsSpan(packetData.Length - macLength),
+ count: activeBufferLength - macLength,
+ destination: _sendBuffer.AsSpan(activeBufferLength - macLength),
bytesWritten: out var bytesWritten);
Debug.Assert(hashSuccess && bytesWritten == macLength);
}
- SendPacket(packetData, 4, packetData.Length - 4);
+ SendPacket(_sendBuffer, 4, activeBufferLength - 4);
if (_isStrictKex && message is NewKeysMessage)
{
diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_Connected.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_Connected.cs
index d13919b05..f3dc36e3c 100644
--- a/test/Renci.SshNet.Tests/Classes/SessionTest_Connected.cs
+++ b/test/Renci.SshNet.Tests/Classes/SessionTest_Connected.cs
@@ -11,6 +11,7 @@
using Renci.SshNet.Messages.Connection;
using Renci.SshNet.Messages.Transport;
+using Renci.SshNet.Tests.Common;
namespace Renci.SshNet.Tests.Classes
{
diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsDebugMessageAfterKexInit_StrictKex.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsDebugMessageAfterKexInit_StrictKex.cs
index 257def53c..22d2f71ae 100644
--- a/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsDebugMessageAfterKexInit_StrictKex.cs
+++ b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsDebugMessageAfterKexInit_StrictKex.cs
@@ -6,6 +6,7 @@
using Renci.SshNet.Common;
using Renci.SshNet.Messages.Transport;
+using Renci.SshNet.Tests.Common;
namespace Renci.SshNet.Tests.Classes
{
diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsDisconnectMessageAfterKexInit_StrictKex.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsDisconnectMessageAfterKexInit_StrictKex.cs
index 54b6f90b8..a281fbe54 100644
--- a/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsDisconnectMessageAfterKexInit_StrictKex.cs
+++ b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsDisconnectMessageAfterKexInit_StrictKex.cs
@@ -4,6 +4,7 @@
using Renci.SshNet.Common;
using Renci.SshNet.Messages.Transport;
+using Renci.SshNet.Tests.Common;
namespace Renci.SshNet.Tests.Classes
{
diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageAfterKexInit_NoStrictKex.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageAfterKexInit_NoStrictKex.cs
index f20d81d8a..50fe58e6e 100644
--- a/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageAfterKexInit_NoStrictKex.cs
+++ b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageAfterKexInit_NoStrictKex.cs
@@ -3,6 +3,7 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Renci.SshNet.Messages.Transport;
+using Renci.SshNet.Tests.Common;
namespace Renci.SshNet.Tests.Classes
{
diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageAfterKexInit_StrictKex.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageAfterKexInit_StrictKex.cs
index de1cc2741..1d4c2381d 100644
--- a/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageAfterKexInit_StrictKex.cs
+++ b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageAfterKexInit_StrictKex.cs
@@ -4,6 +4,7 @@
using Renci.SshNet.Common;
using Renci.SshNet.Messages.Transport;
+using Renci.SshNet.Tests.Common;
namespace Renci.SshNet.Tests.Classes
{
diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageBeforeKexInit_NoStrictKex.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageBeforeKexInit_NoStrictKex.cs
index c85d925b7..b517a1cd8 100644
--- a/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageBeforeKexInit_NoStrictKex.cs
+++ b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageBeforeKexInit_NoStrictKex.cs
@@ -3,6 +3,7 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Renci.SshNet.Messages.Transport;
+using Renci.SshNet.Tests.Common;
namespace Renci.SshNet.Tests.Classes
{
diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageBeforeKexInit_StrictKex.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageBeforeKexInit_StrictKex.cs
index 41f00b735..a18e46e76 100644
--- a/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageBeforeKexInit_StrictKex.cs
+++ b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageBeforeKexInit_StrictKex.cs
@@ -4,6 +4,7 @@
using Renci.SshNet.Common;
using Renci.SshNet.Messages.Transport;
+using Renci.SshNet.Tests.Common;
namespace Renci.SshNet.Tests.Classes
{
diff --git a/test/Renci.SshNet.Tests/Common/Extensions.cs b/test/Renci.SshNet.Tests/Common/Extensions.cs
index 977c182dc..ce8053435 100644
--- a/test/Renci.SshNet.Tests/Common/Extensions.cs
+++ b/test/Renci.SshNet.Tests/Common/Extensions.cs
@@ -1,6 +1,10 @@
-using System.Collections.Generic;
+#nullable enable
+using System;
+using System.Collections.Generic;
using Renci.SshNet.Common;
+using Renci.SshNet.Compression;
+using Renci.SshNet.Messages;
namespace Renci.SshNet.Tests.Common
{
@@ -21,5 +25,22 @@ public static string AsString(this IList exceptionEvents)
return reportedExceptions;
}
+
+ /// [4 bytes] || packet_len || padding_len || payload || padding.
+ public static byte[] GetPacket(this Message message, byte paddingMultiplier, Compressor? compressor)
+ {
+ var buffer = Array.Empty();
+
+ var byteCount = message.GetPacket(
+ ref buffer,
+ paddingMultiplier,
+ compressor,
+ excludePacketLengthFieldWhenPadding: false,
+ macLength: 0);
+
+ Array.Resize(ref buffer, byteCount);
+
+ return buffer;
+ }
}
}