C
C#4mo ago
SilverShade

ThreadPool implementation with WSQ

I'm trying to implement a simple version of ThreadPool with work stealing queues. It seems like i'm missing something, but can't see what. My implementation with WSQs is actually slower than without them. The implementation of a work stealing queue (WorkStealingQueue class) was provided for me, so you can assume it's not the issue. I'll provide my code in the messages below.
1 Reply
SilverShade
SilverShade4mo ago
using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Threading;
using CustomThreadPool.Collections;

namespace CustomThreadPool.ThreadPools;

public class MyThreadPool : IThreadPool
{
private long _processedTask;
private readonly Queue<Action> _globalQueue = new();
private readonly ReadOnlyDictionary<int, WorkStealingQueue<Action>> _localQueues;

private const int WorkerCount = 10;

public MyThreadPool()
{
var workers = new Thread[WorkerCount];

var queues = new Dictionary<int, WorkStealingQueue<Action>>();
for (var i = 0; i < WorkerCount; i++)
{
var worker = new Thread(Worker) {IsBackground = true};
workers[i] = worker;
queues.Add(worker.ManagedThreadId, new WorkStealingQueue<Action>());
}

_localQueues = new ReadOnlyDictionary<int, WorkStealingQueue<Action>>(queues);

foreach (var worker in workers)
worker.Start();
}

public void EnqueueAction(Action action)
{
if (_localQueues.TryGetValue(Environment.CurrentManagedThreadId, out var localQueue))
{
localQueue.LocalPush(action);
return;
}

lock (_globalQueue)
{
_globalQueue.Enqueue(action);
Monitor.Pulse(_globalQueue);
}
}

public long GetTasksProcessedCount() => _processedTask;

private void Worker()
{
while (true)
{
if (TryGetFromLocalQueue(out var work)
|| TryGetFromGlobalQueue(out work)
|| TrySteal(out work))
{
work();
Interlocked.Increment(ref _processedTask);
continue;
}

lock (_globalQueue)
Monitor.Wait(_globalQueue);
}

// ReSharper disable once FunctionNeverReturns
}
using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Threading;
using CustomThreadPool.Collections;

namespace CustomThreadPool.ThreadPools;

public class MyThreadPool : IThreadPool
{
private long _processedTask;
private readonly Queue<Action> _globalQueue = new();
private readonly ReadOnlyDictionary<int, WorkStealingQueue<Action>> _localQueues;

private const int WorkerCount = 10;

public MyThreadPool()
{
var workers = new Thread[WorkerCount];

var queues = new Dictionary<int, WorkStealingQueue<Action>>();
for (var i = 0; i < WorkerCount; i++)
{
var worker = new Thread(Worker) {IsBackground = true};
workers[i] = worker;
queues.Add(worker.ManagedThreadId, new WorkStealingQueue<Action>());
}

_localQueues = new ReadOnlyDictionary<int, WorkStealingQueue<Action>>(queues);

foreach (var worker in workers)
worker.Start();
}

public void EnqueueAction(Action action)
{
if (_localQueues.TryGetValue(Environment.CurrentManagedThreadId, out var localQueue))
{
localQueue.LocalPush(action);
return;
}

lock (_globalQueue)
{
_globalQueue.Enqueue(action);
Monitor.Pulse(_globalQueue);
}
}

public long GetTasksProcessedCount() => _processedTask;

private void Worker()
{
while (true)
{
if (TryGetFromLocalQueue(out var work)
|| TryGetFromGlobalQueue(out work)
|| TrySteal(out work))
{
work();
Interlocked.Increment(ref _processedTask);
continue;
}

lock (_globalQueue)
Monitor.Wait(_globalQueue);
}

// ReSharper disable once FunctionNeverReturns
}
private bool TryGetFromLocalQueue(out Action task)
{
task = null;
return _localQueues[Environment.CurrentManagedThreadId].LocalPop(ref task);
}

private bool TryGetFromGlobalQueue(out Action task)
{
lock (_globalQueue)
return _globalQueue.TryDequeue(out task);
}

private bool TrySteal(out Action task)
{
task = null;
var currentThreadId = Environment.CurrentManagedThreadId;
foreach (var (threadId, queue) in _localQueues)
{
if (threadId == currentThreadId || !queue.TrySteal(ref task))
continue;

return true;
}

return false;
}
private bool TryGetFromLocalQueue(out Action task)
{
task = null;
return _localQueues[Environment.CurrentManagedThreadId].LocalPop(ref task);
}

private bool TryGetFromGlobalQueue(out Action task)
{
lock (_globalQueue)
return _globalQueue.TryDequeue(out task);
}

private bool TrySteal(out Action task)
{
task = null;
var currentThreadId = Environment.CurrentManagedThreadId;
foreach (var (threadId, queue) in _localQueues)
{
if (threadId == currentThreadId || !queue.TrySteal(ref task))
continue;

return true;
}

return false;
}