changeset 2656:83fb61fa476e

mod_storage_memory: Serialize data functions that return the data (prevents mutation of stored data)
author Kim Alvefur <zash@zash.se>
date Thu, 30 Mar 2017 23:40:29 +0200
parents f4353f959460
children 6f5c99c9f6cc
files mod_storage_memory/mod_storage_memory.lua
diffstat 1 files changed, 17 insertions(+), 4 deletions(-) [+]
line wrap: on
line diff
--- a/mod_storage_memory/mod_storage_memory.lua	Thu Mar 30 23:38:40 2017 +0200
+++ b/mod_storage_memory/mod_storage_memory.lua	Thu Mar 30 23:40:29 2017 +0200
@@ -1,3 +1,8 @@
+local serialize = require "util.serialization".serialize;
+local envload = require "util.envload".envload;
+local st = require "util.stanza";
+local is_stanza = st.is_stanza or function (s) return getmetatable(s) == st.stanza_mt end
+
 local auto_purge_enabled = module:get_option_boolean("storage_memory_temporary", false);
 local auto_purge_stores = module:get_option_set("storage_memory_temporary_stores", {});
 
@@ -9,7 +14,7 @@
 	end
 });
 
-local NULL = {};
+local function NULL() return nil end
 
 local function _purge_store(self, username)
 	self.store[username or NULL] = nil;
@@ -20,11 +25,11 @@
 keyval_store.__index = keyval_store;
 
 function keyval_store:get(username)
-	return self.store[username or NULL];
+	return (self.store[username or NULL] or NULL)();
 end
 
 function keyval_store:set(username, data)
-	self.store[username or NULL] = data;
+	self.store[username or NULL] = envload(serialize(data), "@data", {});
 	return true;
 end
 
@@ -37,6 +42,14 @@
 	if type(when) ~= "number" then
 		when, with, value = value, when, with;
 	end
+	if is_stanza(value) then
+		value = st.preserialize(value);
+		value = function ()
+			return st.deserialize(envload(serialize(data), "@stanza", {}));
+		end
+	else
+		value = envload(serialize(data), "@data", {});
+	end
 	local a = self.store[username or NULL];
 	if not a then
 		a = {};
@@ -61,7 +74,7 @@
 		item = a[i];
 		when, with = item.when, item.with;
 		if when >= when_start and when_end >= when and (not match_with or match_with == with) then
-			coroutine.yield(item.key, item.value, when, with);
+			coroutine.yield(item.key, item.value(), when, with);
 			count = count + 1;
 			if limit and count >= limit then return end
 		end