Can a Neural Network Predict Poisson Distributed Target Data?

A few days ago I looked at predicting Poisson distributed target data. There is a dedicated technique for this called Poisson regression. But I wondered whether a standard neural network would be effective when applied to data that has an (approximately) Poisson distribution.
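For reference, here is a minimal sketch of what Poisson regression does, assuming the standard log-link formulation. The predicted mean count is lambda = exp(w . x + b), and the weights are fitted by minimizing the Poisson negative log-likelihood. For each item the gradient of that loss with respect to the linear term is simply (lambda - y). This is hand-rolled NumPy gradient descent, not the specific technique or library used for the comparison in this post (scikit-learn, for example, ships a `PoissonRegressor` class). The function names and the "true" coefficients are hypothetical, chosen only for illustration.

```python
import numpy as np

def fit_poisson_regression(X, y, lr=0.1, epochs=2000):
    """Hypothetical helper: fit log(lambda) = X @ w + b by gradient descent on the Poisson NLL."""
    n, d = X.shape
    w = np.zeros(d)
    b = 0.0
    for _ in range(epochs):
        lam = np.exp(X @ w + b)   # predicted mean count for each row
        resid = lam - y           # d(NLL)/d(linear term) for each row
        w -= lr * (X.T @ resid) / n
        b -= lr * resid.mean()
    return w, b

def predict_counts(X, w, b):
    """Predicted mean counts; round them to get integer predictions."""
    return np.exp(X @ w + b)

# Toy data shaped like the post's data: four predictors in [-1, +1].
# The coefficients below are made up; they are not from the post.
rng = np.random.default_rng(0)
X = rng.uniform(-1.0, 1.0, size=(2000, 4))
true_w = np.array([-0.8, 0.5, -0.6, 0.3])
true_b = 1.0
y = rng.poisson(np.exp(X @ true_w + true_b))

w, b = fit_poisson_regression(X, y)
print(np.round(w, 2), round(b, 2))
```

The log link guarantees a positive predicted mean, which is the main structural difference from a neural network with an identity output node.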

The short answer, based on a few informal experiments with PyTorch, is that a standard neural network seems to work just fine on data that is Poisson distributed.

Poisson distributed data typically arises when “things arrive”. For example, you could count the number of cars that arrive at a fast food drive-through window every day from 11:00 to 11:05 AM. The counts will be skewed towards 0, but large counts are possible. I started by generating some synthetic Poisson data that looks like:
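The Poisson probability of seeing exactly k arrivals when the average is lambda is P(k) = lambda^k * e^(-lambda) / k!. The sketch below evaluates that formula and then simulates many days of arrivals. The value lambda = 2.0 is a made-up average for the drive-through example, not a number from the post. A handy property of the Poisson distribution is that its mean and variance are both equal to lambda, which the simulation should roughly reproduce.

```python
import math
import numpy as np

def poisson_pmf(k, lam):
    """Probability of exactly k arrivals when the average number of arrivals is lam."""
    return lam ** k * math.exp(-lam) / math.factorial(k)

lam = 2.0  # hypothetical: on average 2 cars arrive between 11:00 and 11:05 AM
for k in range(8):
    print(k, round(poisson_pmf(k, lam), 4))  # long tail to the right of the peak

# Simulate 10,000 days of arrivals; mean and variance should both be close to lam
rng = np.random.default_rng(1)
counts = rng.poisson(lam, size=10_000)
print(counts.mean(), counts.var())
```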

 0.3409, -0.1654,  0.1174, -0.7192, 1
-0.3690,  0.3730,  0.6693, -0.9634, 0
 0.7892, -0.8299, -0.9219, -0.6603, 1
-0.8153, -0.6275, -0.3089, -0.2065, 7
-0.8033, -0.1578,  0.9158,  0.0663, 2

. . .

There are four predictor variables, each with a value between -1.0 and +1.0, and a target variable that is an integer count between 0 and 9. I made 200 training items and 40 test items.
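The post doesn't say exactly how the synthetic data was generated. One plausible approach, purely a guess and not the author's method, is to draw the predictors uniformly from [-1, +1], compute a positive rate with a made-up log-linear formula, draw a Poisson count from that rate, and cap it at 9 to match the 0 to 9 range. The weights, bias, and function names below are all hypothetical.

```python
import numpy as np

def make_poisson_data(n_rows, seed, max_count=9):
    """Hypothetical generator: one plausible way to make data like the post's files."""
    rng = np.random.default_rng(seed)
    X = rng.uniform(-1.0, 1.0, size=(n_rows, 4))
    # Made-up log-linear rate, so lam = exp(w . x + b) is always positive
    w = np.array([-0.5, 0.4, -0.9, 0.6])
    b = 0.8
    lam = np.exp(X @ w + b)
    y = np.minimum(rng.poisson(lam), max_count)  # cap counts (an assumption)
    return X, y

def to_lines(X, y):
    """Format rows like the data files: four predictors, then the integer count."""
    return ["%0.4f, %0.4f, %0.4f, %0.4f, %d" % (*row, t) for row, t in zip(X, y)]

X_train, y_train = make_poisson_data(200, seed=0)
X_test, y_test = make_poisson_data(40, seed=1)
for line in to_lines(X_train, y_train)[:3]:
    print(line)
```

Note that capping at 9 slightly distorts the upper tail of a true Poisson distribution, which is consistent with the post describing the data as only approximately Poisson.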

I created a 4-(10-10)-1 PyTorch neural network regression model with tanh() hidden node activation and identity() output activation.
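To make the architecture concrete, here is a NumPy sketch of the same 4-(10-10)-1 forward pass: tanh on both hidden layers and no activation (identity) on the output node, with Xavier uniform weights and zero biases as in the demo's `Net` class. It is an illustration only; the demo itself uses PyTorch, and the helper names here are hypothetical. The parameter count works out to 171 weights and biases, roughly as many as there are training items.

```python
import numpy as np

def init_layer(rng, n_in, n_out):
    """Xavier/Glorot uniform weights and zero biases, mirroring the demo's initialization."""
    limit = np.sqrt(6.0 / (n_in + n_out))
    return rng.uniform(-limit, limit, size=(n_in, n_out)), np.zeros(n_out)

def forward(params, x):
    (w1, b1), (w2, b2), (w3, b3) = params
    z = np.tanh(x @ w1 + b1)  # hidden layer 1, tanh activation
    z = np.tanh(z @ w2 + b2)  # hidden layer 2, tanh activation
    return z @ w3 + b3        # output layer, identity (no activation)

rng = np.random.default_rng(0)
sizes = [4, 10, 10, 1]
params = [init_layer(rng, n_in, n_out) for n_in, n_out in zip(sizes[:-1], sizes[1:])]
n_params = sum(w.size + b.size for w, b in params)
print(n_params)  # 171

x = np.array([[0.3409, -0.1654, 0.1174, -0.7192]])
print(forward(params, x))  # untrained weights, so this value is meaningless
```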

To measure model accuracy, I computed a predicted output as a float32 value, for example 3.4567, rounded it to the nearest integer, and then checked whether the predicted integer was the same as the target integer. With minimal tuning, the model scored 0.9950 accuracy on the training data (199 of 200 correct) and 0.9000 accuracy on the test data (36 of 40 correct). These scores were comparable to the results I got using dedicated Poisson regression techniques.
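The rounding-based accuracy metric can also be computed for all items at once. Here is a vectorized NumPy sketch; the function name and the raw outputs are made up for illustration. Note that `np.round` rounds exact halves to the nearest even integer (2.5 becomes 2.0), and `torch.round`, used in the demo's `accuracy()` function, does the same.

```python
import numpy as np

def rounded_accuracy(raw_preds, targets):
    """Fraction of items whose prediction, rounded to the nearest integer, equals the target."""
    raw_preds = np.asarray(raw_preds, dtype=np.float64).ravel()
    targets = np.asarray(targets).ravel()
    pred_counts = np.round(raw_preds)  # e.g. 3.4567 -> 3.0; exact halves go to even
    return float(np.mean(pred_counts == targets))

# Made-up raw network outputs and targets, for illustration only
raw = [3.4567, 0.2, 6.6, -0.3, 1.51]
trg = [3, 0, 7, 0, 1]
print(rounded_accuracy(raw, trg))  # 0.8 (1.51 rounds to 2, not 1)
```

Because the output node has no activation, the network can emit negative values. A raw output of -0.3 rounds to -0.0, which still compares equal to 0, but anything below -0.5 rounds to -1, a count that can never occur.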

I suspect that I’m missing a few nuances, but I’ll explore some more when I get a chance.



The Poisson distribution is named after mathematician Simeon Poisson. The word “poisson” means “fish” in French, which I’ve always felt was a bit odd. There are dozens of movies that feature fish-women (mermaids), but most aren’t very good. Here are three that have evil, rather than benign, mermaids:

In “Mermaid’s Song” (2015), set in 1930s Oklahoma, young Charlotte discovers she is the daughter of a mermaid. When gangsters show up and take over the family entertainment club, Charlotte discovers she can control people using her voice. This movie had potential but wasn’t executed very well. My grade = C.

In “Peter Pan” (2003), the depiction of the mermaids in Neverland as rather menacing is much closer to the original book than the cute depiction in the classic 1953 animated Disney film. The 2003 version is not a bad movie at all, and it is dramatically better than the 2023 Disney live action version, which is just awful. My grade for the 2003 version is B; my grade for the 1953 Disney animated version is a solid A; my grade for the 2023 Disney version is a solid D.

In “Killer Mermaid” (2014), the movie title pretty much tells you what you need to know. Two young American women go on vacation to the Mediterranean and discover . . . a killer mermaid. Not a terrible movie, but too slow for my taste. My grade = C.


Demo code.

# poisson_regression.py
# PyTorch 2.0-CPU Anaconda3-2022.10  Python 3.9.13
# Windows 10/11 

import numpy as np
import torch as T

device = T.device('cpu')  # apply to Tensor or Module

# -----------------------------------------------------------

class SyntheticDataset(T.utils.data.Dataset):
  def __init__(self, src_file):
    # -0.8153, -0.6275, -0.3089, -0.2065, 7
    #  0.3409, -0.1654,  0.1174, -0.7192, 1
    #  0.7892, -0.8299, -0.9219, -0.6603, 1
    # -0.8033, -0.1578,  0.9158,  0.0663, 2
    # -0.3690,  0.3730,  0.6693, -0.9634, 0

    # two-reads approach (memory efficient)
    tmp_x = np.loadtxt(src_file, usecols=[0,1,2,3],
      delimiter=",", comments="#", dtype=np.float32)
    tmp_y = np.loadtxt(src_file, usecols=4, delimiter=",",
      comments="#", dtype=np.float32)
    tmp_y = tmp_y.reshape(-1,1)  # 2D required

    self.x_data = T.tensor(tmp_x, dtype=T.float32).to(device)
    self.y_data = T.tensor(tmp_y, dtype=T.float32).to(device)

  def __len__(self):
    return len(self.x_data)

  def __getitem__(self, idx):
    preds = self.x_data[idx]
    trgts = self.y_data[idx] 
    return (preds, trgts)  # as a tuple

# -----------------------------------------------------------

class Net(T.nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.hid1 = T.nn.Linear(4, 10)  # 4-(10-10)-1
    self.hid2 = T.nn.Linear(10, 10)
    self.oupt = T.nn.Linear(10, 1)

    T.nn.init.xavier_uniform_(self.hid1.weight)
    T.nn.init.zeros_(self.hid1.bias)
    T.nn.init.xavier_uniform_(self.hid2.weight)
    T.nn.init.zeros_(self.hid2.bias)
    T.nn.init.xavier_uniform_(self.oupt.weight)
    T.nn.init.zeros_(self.oupt.bias)

  def forward(self, x):
    z = T.tanh(self.hid1(x))
    z = T.tanh(self.hid2(z))
    z = self.oupt(z)  # no activation
    return z

# -----------------------------------------------------------

def accuracy(model, ds):
  # assumes model.eval()
  n_correct = 0; n_wrong = 0

  for i in range(len(ds)):
    X = ds[i][0]   # 1-d predictors, shape [4]
    Y = ds[i][1]   # 1-d target count as float, shape [1]
    target_y = T.round(Y)       # rounded, still a float tensor
    with T.no_grad():
      raw_y = model(X)     # predicted y as float
    pred_y = T.round(raw_y)  # nearest integer, still a float tensor

    if pred_y == target_y:
      n_correct += 1
    else:
      n_wrong += 1
  acc = (n_correct * 1.0) / (n_correct + n_wrong)
  return acc

# -----------------------------------------------------------

def train(model, ds, bs, lr, me, le):
  # dataset, bat_size, lrn_rate, max_epochs, log interval
  train_ldr = T.utils.data.DataLoader(ds, batch_size=bs,
    shuffle=True)
  loss_func = T.nn.MSELoss()
  optimizer = T.optim.Adam(model.parameters(), lr=lr)

  for epoch in range(0, me):
    epoch_loss = 0.0  # for one full epoch
    for (b_idx, batch) in enumerate(train_ldr):
      X = batch[0]  # predictors
      y = batch[1]  # target count
      optimizer.zero_grad()
      oupt = model(X)
      loss_val = loss_func(oupt, y)  # a tensor
      epoch_loss += loss_val.item()  # accumulate
      loss_val.backward()  # compute gradients
      optimizer.step()     # update weights

    if epoch % le == 0:
      print("epoch = %4d  |  loss = %0.4f" % \
        (epoch, epoch_loss)) 

# -----------------------------------------------------------

def main():
  # 0. get started
  print("\nBegin predict Poisson data ")
  T.manual_seed(0)
  np.random.seed(0)
  
  # 1. create Dataset objects
  print("\nCreating Dataset objects ")
  train_file = ".\\Data\\train_poisson_200.txt"
  train_ds = SyntheticDataset(train_file)  # 200 rows

  test_file = ".\\Data\\test_poisson_40.txt"
  test_ds = SyntheticDataset(test_file)  # 40 rows

  # 2. create network
  print("\nCreating 4-(10-10)-1 neural network ")
  net = Net().to(device)

# -----------------------------------------------------------

  # 3. train model
  print("\nbat_size = 10 ")
  print("loss = MSELoss() ")
  print("optimizer = Adam ")
  print("lrn_rate = 0.01 ")

  print("\nStarting training")
  net.train()
  train(net, train_ds, bs=10, lr=0.01, me=500, le=100)
  print("Done ")

# -----------------------------------------------------------

  # 4. evaluate model accuracy
  print("\nComputing model accuracy ")
  net.eval()
  acc_train = accuracy(net, train_ds)  # item-by-item
  print("Accuracy on train data = %0.4f" % acc_train)

  acc_test = accuracy(net, test_ds) 
  print("Accuracy on test data = %0.4f" % acc_test)

# -----------------------------------------------------------

  # 5. make a prediction
  print("\nPredicting for 0.5, -0.5, 0.5, -0.5: ")
  x = np.array([[0.5, -0.5, 0.5, -0.5]],
    dtype=np.float32)
  x = T.tensor(x, dtype=T.float32).to(device) 

  with T.no_grad():
    pred_y = net(x)
  pred_y = pred_y.item()  # scalar
  print("%0.4f" % pred_y)
  pred_y_int = np.round(pred_y).astype(int)
  print(pred_y_int)

# -----------------------------------------------------------

  # 6. save model (state_dict approach)
  print("\nSaving trained model state")
  fn = ".\\Models\\poisson_model.pt"
  T.save(net.state_dict(), fn)

  # model = Net()
  # model.load_state_dict(T.load(fn))
  # use model to make prediction(s)

  print("\nEnd Poisson data demo ")

if __name__ == "__main__":
  main()
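The `SyntheticDataset` class reads each data file twice, once for the predictor columns and once for the target column. A possible alternative, sketched below, reads the file once and slices the resulting array; the helper name is hypothetical, and the two sample rows are copied from the training data. With `ndmin=2`, a file containing a single row still produces 2-D arrays.

```python
import io
import numpy as np

# Two rows in the same format as the data files, including the "#" comment lines
text = """# train_poisson_200.txt
#
-0.8153, -0.6275, -0.3089, -0.2065, 7
0.3409, -0.1654, 0.1174, -0.7192, 1
"""

def load_poisson_file(src):
    """Hypothetical helper: read the file once, then split into predictors and 2-D targets."""
    data = np.loadtxt(src, delimiter=",", comments="#", dtype=np.float32, ndmin=2)
    x = data[:, 0:4]  # four predictor columns
    y = data[:, 4:5]  # target count column, kept 2-D as shape (n, 1)
    return x, y

x, y = load_poisson_file(io.StringIO(text))
print(x.shape, y.shape)  # (2, 4) (2, 1)
```

Inside `__init__`, the two arrays could then be passed to `T.tensor(..., dtype=T.float32)` exactly as the demo does with `tmp_x` and `tmp_y`.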

Training data:

# train_poisson_200.txt
#
-0.8153, -0.6275, -0.3089, -0.2065, 7
0.3409, -0.1654, 0.1174, -0.7192, 1
0.7892, -0.8299, -0.9219, -0.6603, 1
-0.8033, -0.1578, 0.9158, 0.0663, 2
-0.3690, 0.3730, 0.6693, -0.9634, 0
-0.7953, -0.1719, 0.3888, -0.1716, 4
0.7666, 0.2473, 0.5019, -0.3022, 1
0.7918, -0.1438, 0.9297, 0.3269, 1
-0.5259, 0.8068, 0.1474, -0.9943, 3
-0.3467, 0.0541, 0.7719, -0.2855, 1
0.2467, -0.9684, 0.8589, 0.3818, 1
-0.6553, -0.7257, 0.8652, 0.3936, 2
0.5109, 0.5078, 0.8460, 0.4230, 5
-0.9602, -0.9476, -0.9434, -0.5076, 9
0.0777, 0.1056, 0.6841, -0.7517, 0
0.1715, 0.9392, 0.1221, -0.9627, 2
0.1125, -0.7271, -0.8802, -0.7573, 2
-0.7850, -0.5486, 0.4260, 0.1194, 4
0.4907, 0.3385, -0.4702, -0.8673, 2
0.2594, -0.5797, 0.5055, -0.8669, 0
0.6095, -0.6131, 0.2789, 0.0493, 1
-0.4734, -0.8681, 0.4701, 0.5444, 5
0.8639, -0.9721, -0.5313, 0.2336, 4
0.9004, 0.1133, 0.8312, 0.2831, 1
0.0113, -0.9570, 0.8959, 0.6542, 1
0.8802, 0.1640, 0.7577, 0.6895, 5
-0.0802, 0.0927, 0.5972, -0.4286, 1
0.1982, -0.9689, 0.1870, -0.1326, 1
-0.3695, 0.7858, 0.1557, -0.6320, 5
0.2241, -0.8922, -0.1596, 0.3581, 5
0.6577, 0.1494, 0.2562, -0.4288, 1
-0.1970, -0.3652, 0.2438, -0.1395, 2
0.3556, -0.6029, -0.1466, -0.3133, 1
0.7600, 0.8077, 0.3254, -0.4596, 2
0.7098, 0.0554, 0.6043, 0.1450, 2
0.6237, 0.7499, 0.3768, 0.1390, 8
0.8326, 0.8193, -0.4858, -0.7782, 4
-0.1683, 0.2334, -0.5327, -0.7961, 5
-0.0457, -0.6947, 0.2436, 0.0880, 2
0.5405, 0.4635, -0.4806, -0.4859, 7
-0.3995, -0.7140, 0.8026, 0.0831, 1
-0.1162, 0.1632, 0.9795, -0.5922, 0
-0.4757, 0.5003, -0.0860, -0.8861, 3
-0.5761, 0.5972, -0.4053, -0.9448, 7
0.6877, -0.2380, 0.4997, 0.0223, 1
0.8951, -0.5571, -0.4659, -0.8371, 0
-0.7132, -0.8432, -0.9633, -0.8666, 3
-0.7733, -0.9444, 0.5097, -0.2103, 1
-0.0952, -0.0998, -0.0439, -0.0520, 7
0.2370, -0.9793, 0.0773, -0.9940, 0
0.8108, 0.5919, 0.8305, -0.7089, 0
0.4636, 0.8186, -0.1983, -0.5003, 7
0.8215, -0.2669, -0.1328, 0.0246, 3
-0.9381, 0.4338, 0.7820, -0.9454, 1
0.0345, 0.8328, -0.1471, -0.5052, 9
-0.8250, -0.5454, -0.3712, -0.6505, 3
0.1484, -0.3020, -0.8861, -0.5424, 6
0.6337, 0.1887, 0.9520, 0.8031, 6
0.9507, -0.6640, 0.9456, 0.5349, 1
0.2652, 0.3375, -0.0462, -0.9737, 1
-0.3226, 0.0478, 0.5098, -0.0723, 3
-0.9542, 0.0382, 0.6200, -0.9748, 1
0.3736, -0.1015, 0.8296, 0.2887, 2
0.1570, -0.4518, 0.1211, 0.3435, 6
0.7117, -0.6099, 0.4946, -0.4208, 0
-0.1445, 0.6154, -0.2929, -0.5726, 9
-0.3827, 0.4665, 0.4889, -0.5572, 2
-0.6021, -0.7150, -0.2458, -0.9467, 1
0.8347, 0.4226, 0.1078, -0.3910, 2
0.9521, -0.6803, -0.5948, -0.1376, 2
-0.4090, 0.4632, 0.8906, -0.1489, 3
-0.2859, -0.7839, 0.5751, -0.7868, 0
0.4577, 0.0334, 0.4139, 0.5611, 9
0.5406, 0.5012, 0.2264, -0.1963, 4
-0.9938, 0.5498, 0.7928, -0.5214, 3
-0.5594, -0.3958, 0.7661, 0.0863, 2
-0.7233, -0.4197, 0.2277, -0.3517, 2
0.0357, -0.6111, 0.6959, -0.4967, 0
0.6455, 0.7224, -0.1203, -0.4885, 4
-0.0443, -0.7313, 0.8557, 0.7919, 3
0.7134, -0.1628, 0.3669, -0.2040, 1
0.5836, 0.3915, 0.5557, -0.1870, 2
-0.2498, 0.7150, 0.2392, -0.4959, 5
0.6166, -0.4094, 0.0882, -0.0242, 2
0.7768, -0.6312, 0.1707, 0.7964, 7
0.8437, -0.4420, 0.2177, 0.3649, 3
-0.9725, -0.1666, 0.8770, -0.3139, 1
0.3770, -0.4932, 0.3847, -0.5454, 0
0.2272, 0.2966, -0.6601, -0.7011, 6
0.7507, -0.6321, -0.0743, -0.1421, 1
0.7193, 0.3432, 0.2669, -0.7505, 1
0.9731, 0.8966, 0.2902, -0.6966, 1
0.5022, 0.1587, 0.8494, -0.8705, 0
-0.8940, -0.6010, -0.1545, -0.7850, 1
-0.2428, -0.6236, 0.4940, -0.3192, 1
-0.0963, 0.4169, 0.5549, -0.0103, 5
0.9319, -0.7812, 0.3461, -0.0001, 0
0.5356, -0.4194, -0.5662, -0.9666, 1
-0.4524, -0.3433, 0.0951, -0.5597, 1
-0.7144, -0.8118, 0.7404, -0.5263, 0
0.6250, -0.4324, 0.0557, -0.3212, 1
0.9488, -0.3766, 0.3376, -0.3481, 0
-0.5785, -0.9170, -0.3563, -0.9258, 1
0.3407, -0.1391, 0.5356, 0.0720, 2
-0.7304, -0.6132, -0.3287, -0.8954, 1
-0.5284, 0.8817, 0.3684, -0.8702, 3
0.4028, 0.2099, 0.4647, -0.4931, 1
-0.4357, 0.7675, 0.1354, -0.7698, 4
0.1920, -0.5211, -0.7372, -0.6763, 2
0.2314, -0.8816, 0.5006, 0.8964, 5
-0.3580, -0.7541, 0.4426, -0.1193, 1
-0.3933, -0.9572, 0.9950, 0.1641, 1
0.8579, 0.0142, -0.0906, 0.1757, 6
0.4086, 0.3633, 0.3943, 0.2372, 7
0.5216, 0.5621, 0.8082, -0.5325, 1
-0.2178, -0.3589, 0.6310, 0.2271, 3
-0.1447, -0.8011, -0.7699, -0.2532, 6
0.6415, 0.1993, 0.3777, -0.0178, 2
-0.5298, -0.0768, -0.6028, -0.9490, 3
0.4498, -0.3392, 0.6870, -0.1431, 1
-0.3900, 0.7419, 0.8175, -0.3403, 3
0.7984, -0.8486, 0.7572, -0.6183, 0
0.6437, 0.2565, 0.9126, 0.1798, 1
-0.1413, -0.3265, 0.9839, -0.2395, 0
0.0376, -0.6554, -0.8509, -0.2594, 8
0.2690, -0.1722, 0.9818, 0.8599, 5
-0.1600, -0.4760, 0.8216, -0.9555, 0
0.4006, -0.0590, 0.6543, -0.0083, 1
0.2465, 0.2767, -0.3449, -0.8650, 2
-0.2358, -0.7466, -0.5115, -0.8413, 1
0.4834, 0.2300, 0.3448, -0.9832, 0
0.0064, -0.5382, -0.6502, -0.6300, 2
-0.2141, 0.5813, 0.2902, -0.2122, 7
-0.1920, -0.7278, -0.0987, -0.3312, 2
0.4011, 0.8611, 0.7252, -0.6651, 1
-0.3216, 0.1118, 0.0735, -0.2188, 6
0.3570, 0.3746, 0.1230, -0.2838, 3
0.8715, 0.1938, 0.9592, -0.1180, 0
-0.9248, 0.5295, 0.0366, -0.9894, 3
0.9970, -0.7207, -0.8589, -0.8531, 1
0.9436, -0.8105, 0.6835, 0.3703, 1
-0.9481, -0.0770, -0.4374, -0.9421, 4
0.5420, -0.3405, 0.5931, -0.3507, 0
0.0361, -0.2545, 0.4207, -0.0887, 2
0.9808, 0.5478, -0.3314, -0.8220, 2
0.7201, 0.9148, 0.9189, -0.9243, 0
0.9074, -0.0461, -0.4435, 0.0060, 7
0.7594, 0.2640, -0.5787, -0.3098, 8
-0.1113, -0.8325, -0.6694, -0.6056, 2
0.0324, 0.7265, 0.9683, -0.9803, 0
-0.4259, -0.7336, 0.8742, 0.6097, 3
-0.6292, 0.8663, 0.8715, -0.4329, 3
0.1029, -0.6294, -0.1158, -0.6294, 1
-0.7136, 0.2647, 0.3238, -0.1323, 8
-0.0146, -0.0697, 0.6135, -0.4867, 1
-0.5197, 0.3729, 0.9798, -0.6451, 1
-0.4215, 0.8955, 0.6999, -0.1307, 8
0.2597, -0.6839, -0.9704, -0.4690, 4
-0.5102, -0.4154, -0.6081, -0.8241, 3
0.8142, 0.7209, -0.3231, -0.9457, 2
-0.6430, 0.9397, 0.4839, -0.4804, 7
0.9105, -0.8385, -0.8329, 0.2383, 8
0.5304, 0.1363, 0.3324, -0.7844, 0
0.7123, -0.2713, 0.7845, -0.9446, 0
0.9628, 0.2190, -0.1647, -0.6616, 1
0.9332, -0.6918, 0.7902, -0.3780, 0
0.3641, -0.5271, -0.6645, 0.0170, 9
0.3848, -0.7621, 0.8015, -0.0405, 0
0.7899, -0.3417, 0.0560, 0.3008, 4
0.2271, -0.5711, 0.8788, 0.5009, 2
0.4848, -0.2195, 0.5197, 0.8059, 9
-0.8192, -0.7420, -0.7895, -0.6545, 5
0.5979, 0.6213, 0.7200, -0.0829, 2
0.4503, 0.1311, -0.0152, -0.4816, 2
-0.5011, -0.5615, 0.5993, 0.0048, 2
0.5540, 0.0673, 0.4788, 0.0308, 2
0.9725, -0.9435, 0.8655, 0.8617, 1
-0.4184, 0.8318, 0.8058, 0.0708, 9
0.9794, -0.0702, 0.4692, 0.2816, 2
0.6387, -0.8604, -0.9162, -0.9012, 1
-0.2935, -0.6036, -0.5588, -0.9124, 1
0.3077, -0.1125, 0.4379, -0.7800, 0
0.6953, 0.3181, -0.2423, -0.1669, 6
0.8985, 0.2784, 0.4970, -0.2968, 1
0.6735, 0.1041, 0.7353, 0.6997, 6
0.8007, -0.9768, -0.2405, -0.5290, 0
0.4760, 0.4482, -0.0764, -0.2695, 5
0.0674, 0.1012, 0.2310, -0.2087, 3
-0.4976, 0.3115, 0.9208, -0.9929, 0
-0.7820, 0.0876, 0.2538, -0.5141, 3
0.2446, -0.5366, 0.5403, -0.7890, 0
-0.0978, -0.1841, 0.7495, 0.4059, 4
0.5348, -0.8376, 0.9968, -0.2208, 0
-0.6228, -0.1183, 0.6896, 0.1905, 5
0.0253, 0.1478, -0.1194, -0.6129, 2
-0.8494, -0.1486, 0.2040, -0.1492, 6
0.8838, 0.7398, 0.4080, -0.0594, 4
0.6003, -0.4621, -0.5617, -0.1424, 4
0.2869, -0.9090, -0.0729, -0.3305, 1
-0.4305, -0.7555, 0.3366, 0.3627, 5

Test data:

# test_poisson_40.txt
#
-0.4772, -0.1630, 0.0836, -0.0451, 7
0.9281, 0.2131, 0.2079, -0.3614, 1
0.7283, -0.1042, 0.1236, 0.4734, 8
-0.1050, -0.6317, 0.6575, -0.9380, 0
0.1540, 0.7508, 0.2171, -0.4967, 4
0.5395, 0.6057, -0.1509, -0.5918, 3
0.0632, 0.3584, 0.0258, -0.4018, 4
-0.5358, 0.2559, -0.3920, -0.9409, 4
0.7079, 0.1917, -0.5197, -0.6270, 3
-0.5256, 0.0790, 0.4986, -0.5435, 1
-0.9205, -0.7049, 0.7029, -0.7907, 0
-0.2456, -0.1511, -0.0541, -0.1543, 6
0.5452, -0.7256, -0.1514, -0.4561, 1
0.4090, -0.5478, 0.6900, 0.0525, 1
-0.4331, 0.6013, 0.3296, -0.4425, 5
-0.5260, 0.8514, 0.5624, -0.3852, 6
0.7623, -0.7478, 0.2669, -0.4446, 0
-0.2060, -0.6769, 0.2980, 0.5067, 6
0.4678, -0.8251, -0.2005, -0.4408, 1
-0.1433, -0.9900, -0.6110, -0.3596, 3
0.1973, 0.2230, 0.0002, -0.3583, 3
-0.6194, -0.9882, 0.4699, -0.1130, 1
0.2492, -0.2253, 0.3402, -0.5929, 0
-0.6557, -0.9626, -0.0261, -0.6316, 1
-0.0424, 0.4989, 0.1454, -0.2751, 6
0.0534, -0.9415, 0.1102, -0.4140, 0
0.4655, -0.1193, 0.3263, -0.5222, 1
0.7903, -0.7789, 0.4623, 0.0909, 1
-0.9404, -0.7065, -0.4320, -0.9903, 1
0.7405, 0.0939, 0.0863, 0.2315, 6
-0.2665, 0.1499, 0.6184, -0.5968, 1
0.4858, -0.3182, -0.9048, -0.6873, 3
0.5068, -0.4351, -0.2628, -0.4145, 1
-0.3925, 0.0433, 0.4246, -0.3894, 2
0.8460, 0.8042, -0.2085, -0.8320, 2
0.0779, 0.0842, 0.6504, -0.9360, 0
-0.4491, 0.0188, 0.7716, 0.1306, 4
-0.1172, 0.9443, 0.4850, -0.2902, 6
0.0852, 0.7681, -0.4499, -0.7839, 7
-0.1839, -0.8993, 0.4897, -0.4085, 0