Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 54 additions & 82 deletions src/Renci.SshNet/Messages/Message.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
#nullable enable
using System;
using System.Diagnostics;
using System.Globalization;
using System.IO;

using Renci.SshNet.Abstractions;
Expand Down Expand Up @@ -38,116 +41,85 @@ protected override void WriteBytes(SshDataStream stream)
base.WriteBytes(stream);
}

/// <returns>[4 bytes] || packet_len || padding_len || payload || padding || [macLength bytes].</returns>
internal byte[] GetPacket(byte paddingMultiplier, Compressor? compressor, bool excludePacketLengthFieldWhenPadding = false, int macLength = 0)
/// <returns>The number of bytes occupied by the packet in <paramref name="buffer"/>.</returns>
/// <remarks>
/// [4 bytes] || packet_len || padding_len || payload || padding || [macLength bytes].
/// </remarks>
internal int GetPacket(ref byte[] buffer, byte paddingMultiplier, Compressor? compressor, bool excludePacketLengthFieldWhenPadding, int macLength)
{
const int outboundPacketSequenceSize = 4;

var messageLength = BufferCapacity;

ArraySegment<byte> payload = default;

if (messageLength == -1 || compressor != null)
{
using (var sshDataStream = new SshDataStream(DefaultCapacity))
using (var sshDataStream = new SshDataStream(messageLength != -1 ? messageLength : DefaultCapacity))
{
// skip:
// * 4 bytes for the outbound packet sequence
// * 4 bytes for the packet data length
// * one byte for the packet padding length
_ = sshDataStream.Seek(outboundPacketSequenceSize + 4 + 1, SeekOrigin.Begin);

if (compressor != null)
{
// obtain uncompressed message payload
using (var uncompressedDataStream = new SshDataStream(messageLength != -1 ? messageLength : DefaultCapacity))
{
WriteBytes(uncompressedDataStream);

// compress message payload
var compressedMessageData = compressor.Compress(uncompressedDataStream.ToArray());

// add compressed message payload
sshDataStream.Write(compressedMessageData, 0, compressedMessageData.Length);
}
}
else
{
// add message payload
WriteBytes(sshDataStream);
}

messageLength = (int)sshDataStream.Length - (outboundPacketSequenceSize + 4 + 1);

var packetLength = messageLength + 4 + 1;

// determine the padding length
// in Encrypt-then-MAC mode or AEAD, the length field is not encrypted, so we should keep it out of the
// padding length calculation
var paddingLength = GetPaddingLength(paddingMultiplier, excludePacketLengthFieldWhenPadding ? packetLength - 4 : packetLength);
WriteBytes(sshDataStream);

var packetDataLength = GetPacketDataLength(messageLength, paddingLength);
var success = sshDataStream.TryGetBuffer(out payload);

// skip bytes for outbound packet sequence
_ = sshDataStream.Seek(outboundPacketSequenceSize, SeekOrigin.Begin);
Debug.Assert(success);
}

// add packet data length
sshDataStream.Write(packetDataLength);
if (compressor != null)
{
payload = new(compressor.Compress(payload.Array, payload.Offset, payload.Count));
}

// add packet padding length
sshDataStream.WriteByte(paddingLength);
messageLength = payload.Count;
}

_ = sshDataStream.Seek(0, SeekOrigin.End);
// determine the padding length
// in Encrypt-then-MAC mode or AEAD, the length field is not encrypted, so we should keep it out of the
// padding length calculation
var paddingLength = GetPaddingLength(
paddingMultiplier, (excludePacketLengthFieldWhenPadding ? 0 : 4) + 1 + messageLength);

sshDataStream.SetLength(sshDataStream.Length + paddingLength + macLength);
var packetLength = 1 + messageLength + paddingLength;

var buffer = sshDataStream.ToArray();
var bytesRequired = 4 + 4 + packetLength + macLength;

// add padding bytes
CryptoAbstraction.Randomizer.GetBytes(buffer, (int)sshDataStream.Position, paddingLength);
if ((uint)bytesRequired > (uint)Session.MaximumSshPacketSize)
{
throw new InvalidOperationException(string.Format(CultureInfo.CurrentCulture, "Packet is too big. Maximum packet size is {0} bytes.", Session.MaximumSshPacketSize));
}

return buffer;
}
if (buffer.Length < bytesRequired)
{
Array.Resize(ref buffer, Math.Max(bytesRequired, 2 * buffer.Length));
}
else

using (var sshDataStream = new SshDataStream(buffer))
{
var packetLength = messageLength + 4 + 1;
// skip bytes for outbound packet sequenceSize
_ = sshDataStream.Seek(outboundPacketSequenceSize, SeekOrigin.Begin);

// determine the padding length
// in Encrypt-then-MAC mode or AEAD, the length field is not encrypted, so we should keep it out of the
// padding length calculation
var paddingLength = GetPaddingLength(paddingMultiplier, excludePacketLengthFieldWhenPadding ? packetLength - 4 : packetLength);
// add packet length
sshDataStream.Write((uint)packetLength);

var packetDataLength = GetPacketDataLength(messageLength, paddingLength);
// add padding length
sshDataStream.WriteByte(paddingLength);

// lets construct an SSH data stream of the exact size required
using (var sshDataStream = new SshDataStream(packetLength + paddingLength + outboundPacketSequenceSize + macLength))
// add message payload
if (payload != default)
{
sshDataStream.Write(payload.Array!, payload.Offset, payload.Count);
}
else
{
// skip bytes for outbound packet sequenceSize
_ = sshDataStream.Seek(outboundPacketSequenceSize, SeekOrigin.Begin);

// add packet data length
sshDataStream.Write(packetDataLength);

// add packet padding length
sshDataStream.WriteByte(paddingLength);

// add message payload
WriteBytes(sshDataStream);
}

sshDataStream.SetLength(sshDataStream.Length + paddingLength + macLength);

var buffer = sshDataStream.ToArray();

// add padding bytes
CryptoAbstraction.Randomizer.GetBytes(buffer, (int)sshDataStream.Position, paddingLength);
Debug.Assert(sshDataStream.Position == bytesRequired - macLength - paddingLength);

return buffer;
}
// add padding bytes
CryptoAbstraction.Randomizer.GetBytes(buffer, (int)sshDataStream.Position, paddingLength);
}
}

private static uint GetPacketDataLength(int messageLength, byte paddingLength)
{
return (uint)(messageLength + paddingLength + 1);
return bytesRequired;
}

private static byte GetPaddingLength(byte paddingMultiplier, long packetLength)
Expand Down
39 changes: 20 additions & 19 deletions src/Renci.SshNet/Session.cs
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ private uint InboundPacketSequence
private Socket _socket;

private ArrayBuffer _receiveBuffer = new(4 * 1024);
private byte[] _sendBuffer = new byte[4 * 1024];

/// <summary>
/// Gets the session semaphore that controls session channels.
Expand Down Expand Up @@ -1073,29 +1074,29 @@ internal void SendMessage(Message message)
macLength = _clientMac.HashSize / 8;
}

var packetData = message.GetPacket(paddingMultiplier, _clientCompression, _clientEtm || _clientAead, macLength);

if (packetData.Length > MaximumSshPacketSize)
{
throw new InvalidOperationException(string.Format(CultureInfo.CurrentCulture, "Packet is too big. Maximum packet size is {0} bytes.", MaximumSshPacketSize));
}

// take a write lock to ensure the outbound packet sequence number is incremented
// atomically, and only after the packet has actually been sent
lock (_socketWriteLock)
{
var activeBufferLength = message.GetPacket(
Comment on lines 1079 to +1081
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for posterity: I measured and did not see any increased lock contention from moving GetPacket inside the lock. Most parallel scenarios are already made serial by the lock in Channel.SendData

ref _sendBuffer,
paddingMultiplier,
_clientCompression,
_clientEtm || _clientAead,
macLength);

// write outbound packet sequence to start of packet data
BinaryPrimitives.WriteUInt32BigEndian(packetData, _outboundPacketSequence);
BinaryPrimitives.WriteUInt32BigEndian(_sendBuffer, _outboundPacketSequence);

if (_clientMac != null && !_clientEtm)
{
// non-ETM mac = MAC(key, sequence_number || unencrypted_packet)

var hashSuccess = _clientMac.TryComputeHash(
buffer: packetData,
buffer: _sendBuffer,
offset: 0,
count: packetData.Length - macLength,
destination: packetData.AsSpan(packetData.Length - macLength),
count: activeBufferLength - macLength,
destination: _sendBuffer.AsSpan(activeBufferLength - macLength),
bytesWritten: out var bytesWritten);

Debug.Assert(hashSuccess && bytesWritten == macLength);
Expand All @@ -1110,30 +1111,30 @@ internal void SendMessage(Message message)
var offset = _clientEtm ? 8 : 4;

var numberOfBytesEncrypted = _clientCipher.Encrypt(
input: packetData,
input: _sendBuffer,
offset,
length: packetData.Length - offset - macLength,
output: packetData,
length: activeBufferLength - offset - macLength,
output: _sendBuffer,
outputOffset: offset);

Debug.Assert(numberOfBytesEncrypted == packetData.Length - offset - macLength + (_clientAead ? macLength : 0));
Debug.Assert(numberOfBytesEncrypted == activeBufferLength - offset - macLength + (_clientAead ? macLength : 0));
}

if (_clientMac != null && _clientEtm)
{
// ETM mac = MAC(key, sequence_number || packet_length || encrypted_packet)

var hashSuccess = _clientMac.TryComputeHash(
buffer: packetData,
buffer: _sendBuffer,
offset: 0,
count: packetData.Length - macLength,
destination: packetData.AsSpan(packetData.Length - macLength),
count: activeBufferLength - macLength,
destination: _sendBuffer.AsSpan(activeBufferLength - macLength),
bytesWritten: out var bytesWritten);

Debug.Assert(hashSuccess && bytesWritten == macLength);
}

SendPacket(packetData, 4, packetData.Length - 4);
SendPacket(_sendBuffer, 4, activeBufferLength - 4);

if (_isStrictKex && message is NewKeysMessage)
{
Expand Down
1 change: 1 addition & 0 deletions test/Renci.SshNet.Tests/Classes/SessionTest_Connected.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

using Renci.SshNet.Messages.Connection;
using Renci.SshNet.Messages.Transport;
using Renci.SshNet.Tests.Common;

namespace Renci.SshNet.Tests.Classes
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

using Renci.SshNet.Common;
using Renci.SshNet.Messages.Transport;
using Renci.SshNet.Tests.Common;

namespace Renci.SshNet.Tests.Classes
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using Renci.SshNet.Common;
using Renci.SshNet.Messages.Transport;
using Renci.SshNet.Tests.Common;

namespace Renci.SshNet.Tests.Classes
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;

using Renci.SshNet.Messages.Transport;
using Renci.SshNet.Tests.Common;

namespace Renci.SshNet.Tests.Classes
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using Renci.SshNet.Common;
using Renci.SshNet.Messages.Transport;
using Renci.SshNet.Tests.Common;

namespace Renci.SshNet.Tests.Classes
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;

using Renci.SshNet.Messages.Transport;
using Renci.SshNet.Tests.Common;

namespace Renci.SshNet.Tests.Classes
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using Renci.SshNet.Common;
using Renci.SshNet.Messages.Transport;
using Renci.SshNet.Tests.Common;

namespace Renci.SshNet.Tests.Classes
{
Expand Down
23 changes: 22 additions & 1 deletion test/Renci.SshNet.Tests/Common/Extensions.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
using System.Collections.Generic;
#nullable enable
using System;
using System.Collections.Generic;

using Renci.SshNet.Common;
using Renci.SshNet.Compression;
using Renci.SshNet.Messages;

namespace Renci.SshNet.Tests.Common
{
Expand All @@ -21,5 +25,22 @@ public static string AsString(this IList<ExceptionEventArgs> exceptionEvents)

return reportedExceptions;
}

/// <returns>[4 bytes] || packet_len || padding_len || payload || padding.</returns>
public static byte[] GetPacket(this Message message, byte paddingMultiplier, Compressor? compressor)
{
var buffer = Array.Empty<byte>();

var byteCount = message.GetPacket(
ref buffer,
paddingMultiplier,
compressor,
excludePacketLengthFieldWhenPadding: false,
macLength: 0);

Array.Resize(ref buffer, byteCount);

return buffer;
}
}
}