diff --git a/.vscode/settings.json b/.vscode/settings.json index 5bbe8c61..da190865 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -3,5 +3,10 @@ "i18n-ally.keystyle": "nested", "i18n-ally.sourceLanguage": "en", "editor.defaultFormatter": "esbenp.prettier-vscode", - "editor.formatOnSave": true + "editor.formatOnSave": true, + "omnisharp.enableRoslynAnalyzers": true, + "omnisharp.useModernNet": false, + "[csharp]": { + "editor.defaultFormatter": "ms-dotnettools.csharp" + } } diff --git a/AppApi.cs b/AppApi.cs index 9cbb33f8..51d24f5f 100644 --- a/AppApi.cs +++ b/AppApi.cs @@ -51,6 +51,11 @@ namespace VRCX CheckGameRunning(); } + /// + /// Computes the MD5 hash of the file represented by the specified base64-encoded string. + /// + /// The base64-encoded string representing the file. + /// The MD5 hash of the file as a base64-encoded string. public string MD5File(string Blob) { var fileData = Convert.FromBase64CharArray(Blob.ToCharArray(), 0, Blob.Length); @@ -58,6 +63,11 @@ namespace VRCX return Convert.ToBase64String(md5); } + /// + /// Computes the signature of the file represented by the specified base64-encoded string using the librsync library. + /// + /// The base64-encoded string representing the file. + /// The signature of the file as a base64-encoded string. public string SignFile(string Blob) { var fileData = Convert.FromBase64CharArray(Blob.ToCharArray(), 0, Blob.Length); @@ -68,12 +78,21 @@ namespace VRCX return Convert.ToBase64String(sigBytes); } + /// + /// Returns the length of the file represented by the specified base64-encoded string. + /// + /// The base64-encoded string representing the file. + /// The length of the file in bytes. public string FileLength(string Blob) { var fileData = Convert.FromBase64CharArray(Blob.ToCharArray(), 0, Blob.Length); return fileData.Length.ToString(); } + /// + /// Reads the VRChat config file and returns its contents as a string. + /// + /// The contents of the VRChat config file as a string, or an empty string if the file does not exist. public string ReadConfigFile() { var logPath = Environment.GetFolderPath(Environment.SpecialFolder.LocalApplicationData) + @"Low\VRChat\VRChat\"; @@ -87,6 +106,10 @@ namespace VRCX return json; } + /// + /// Writes the specified JSON string to the VRChat config file. + /// + /// The JSON string to write to the config file. public void WriteConfigFile(string json) { var logPath = Environment.GetFolderPath(Environment.SpecialFolder.LocalApplicationData) + @"Low\VRChat\VRChat\"; @@ -94,6 +117,11 @@ namespace VRCX File.WriteAllText(configFile, json); } + /// + /// Gets the VRChat application data location by reading the config file and checking the cache directory. + /// If the cache directory is not found in the config file, it returns the default cache path. + /// + /// The VRChat application data location. public string GetVRChatAppDataLocation() { var json = ReadConfigFile(); @@ -114,21 +142,34 @@ namespace VRCX return cachePath; } + /// + /// Gets the VRChat cache location by combining the VRChat application data location with the cache directory name. + /// + /// The VRChat cache location. public string GetVRChatCacheLocation() { return Path.Combine(GetVRChatAppDataLocation(), "Cache-WindowsPlayer"); } + /// + /// Shows the developer tools for the main browser window. + /// public void ShowDevTools() { MainForm.Instance.Browser.ShowDevTools(); } + /// + /// Deletes all cookies from the global cef cookie manager. + /// public void DeleteAllCookies() { Cef.GetGlobalCookieManager().DeleteCookies(); } + /// + /// Checks if the VRChat game and SteamVR are currently running and updates the browser's JavaScript function $app.updateIsGameRunning with the results. + /// public void CheckGameRunning() { var isGameRunning = false; @@ -144,10 +185,16 @@ namespace VRCX isSteamVRRunning = true; } + // TODO: fix this throwing an exception for being called before the browser is ready. somehow it gets past the checks if (MainForm.Instance?.Browser != null && !MainForm.Instance.Browser.IsLoading) MainForm.Instance.Browser.ExecuteScriptAsync("$app.updateIsGameRunning", isGameRunning, isSteamVRRunning); } + + /// + /// Kills the VRChat process if it is currently running. + /// + /// The number of processes that were killed (0 or 1). public int QuitGame() { var processes = Process.GetProcessesByName("vrchat"); @@ -157,6 +204,10 @@ namespace VRCX return processes.Length; } + /// + /// Starts the VRChat game process with the specified command-line arguments. + /// + /// The command-line arguments to pass to the VRChat game. public void StartGame(string arguments) { // try stream first @@ -204,6 +255,12 @@ namespace VRCX } } + /// + /// Starts the VRChat game process with the specified command-line arguments from the given path. + /// + /// The path to the VRChat game executable. + /// The command-line arguments to pass to the VRChat game. + /// True if the game was started successfully, false otherwise. public bool StartGameFromPath(string path, string arguments) { if (!path.EndsWith(".exe")) @@ -222,6 +279,11 @@ namespace VRCX return true; } + + /// + /// Opens the specified URL in the default browser. + /// + /// The URL to open. public void OpenLink(string url) { if (url.StartsWith("http://") || @@ -264,21 +326,43 @@ namespace VRCX VRCXVR.Instance.Restart(); } + /// + /// Returns an array of arrays containing information about the connected VR devices. + /// Each sub-array contains the type of device and its current state + /// + /// An array of arrays containing information about the connected VR devices. public string[][] GetVRDevices() { return VRCXVR.Instance.GetDevices(); } + /// + /// Returns the current CPU usage as a percentage. + /// + /// The current CPU usage as a percentage. public float CpuUsage() { return CpuMonitor.Instance.CpuUsage; } + /// + /// Retrieves an image from the VRChat API and caches it for future use. The function will return the cached image if it already exists. + /// + /// The URL of the image to retrieve. + /// The ID of the file associated with the image. + /// The version of the file associated with the image. + /// A string representing the file location of the cached image. public string GetImage(string url, string fileId, string version) { return ImageCache.GetImage(url, fileId, version); } + /// + /// Displays a desktop notification with the specified bold text, optional text, and optional image. + /// + /// The bold text to display in the notification. + /// The optional text to display in the notification. + /// The optional image to display in the notification. public void DesktopNotification(string BoldText, string Text = "", string Image = "") { var toastXml = ToastNotificationManager.GetTemplateContent(ToastTemplateType.ToastImageAndText02); @@ -297,6 +381,13 @@ namespace VRCX ToastNotificationManager.CreateToastNotifier("VRCX").Show(toast); } + /// + /// Displays an XSOverlay notification with the specified title, content, and optional image. + /// + /// The title of the notification. + /// The content of the notification. + /// The duration of the notification in milliseconds. + /// The optional image to display in the notification. public void XSNotification(string Title, string Content, int Timeout, string Image = "") { bool UseBase64Icon; @@ -332,6 +423,10 @@ namespace VRCX broadcastSocket.SendTo(byteBuffer, endPoint); } + /// + /// Downloads the VRCX update executable from the specified URL and saves it to the AppData directory. + /// + /// The URL of the VRCX update to download. public void DownloadVRCXUpdate(string url) { var Location = Path.Combine(Program.AppDataDirectory, "update.exe"); @@ -340,6 +435,9 @@ namespace VRCX client.DownloadFile(new Uri(url), Location); } + /// + /// Restarts the VRCX application for an update by launching a new process with the "/Upgrade" argument and exiting the current process. + /// public void RestartApplication() { var VRCXProcess = new Process(); @@ -350,6 +448,10 @@ namespace VRCX Environment.Exit(0); } + /// + /// Checks if the VRCX update executable exists in the AppData directory. + /// + /// True if the update executable exists, false otherwise. public bool CheckForUpdateExe() { if (File.Exists(Path.Combine(Program.AppDataDirectory, "update.exe"))) @@ -357,6 +459,9 @@ namespace VRCX return false; } + /// + /// Sends an IPC packet to announce the start of VRCX. + /// public void IPCAnnounceStart() { IPCServer.Send(new IPCPacket @@ -365,6 +470,11 @@ namespace VRCX }); } + /// + /// Sends an IPC packet with a specified message type and data. + /// + /// The message type to send. + /// The data to send. public void SendIpc(string type, string data) { IPCServer.Send(new IPCPacket @@ -397,6 +507,10 @@ namespace VRCX VRCXVR._browser2.ExecuteScriptAsync($"$app.{function}", json); } + /// + /// Gets the launch command from the startup arguments and clears the launch command. + /// + /// The launch command. public string GetLaunchCommand() { var command = StartupArgs.LaunchCommand; @@ -404,11 +518,18 @@ namespace VRCX return command; } + /// + /// Focuses the main window of the VRCX application. + /// public void FocusWindow() { MainForm.Instance.Invoke(new Action(() => { MainForm.Instance.Focus_Window(); })); } + /// + /// Returns the file path of the custom user CSS file, if it exists. + /// + /// The file path of the custom user CSS file, or an empty string if it doesn't exist. public string CustomCssPath() { var output = string.Empty; @@ -418,6 +539,10 @@ namespace VRCX return output; } + /// + /// Returns the file path of the custom user js file, if it exists. + /// + /// The file path of the custom user js file, or an empty string if it doesn't exist. public string CustomScriptPath() { var output = string.Empty; @@ -442,6 +567,10 @@ namespace VRCX return Program.Version; } + /// + /// Returns whether or not the VRChat client was last closed gracefully. According to the log file, anyway. + /// + /// True if the VRChat client was last closed gracefully, false otherwise. public bool VrcClosedGracefully() { return LogWatcher.Instance.VrcClosedGracefully; @@ -457,6 +586,10 @@ namespace VRCX WinformThemer.DoFunny(); } + /// + /// Returns the number of milliseconds that the system has been running. + /// + /// The number of milliseconds that the system has been running. public double GetUptime() { using (var uptime = new PerformanceCounter("System", "System Up Time")) @@ -466,12 +599,23 @@ namespace VRCX } } + /// + /// Returns a color value derived from the given user ID. + /// This is, essentially, and is used for, random colors. + /// + /// The user ID to derive the color value from. + /// A color value derived from the given user ID. public int GetColourFromUserID(string userId) { var hash = _hasher.ComputeHash(Encoding.UTF8.GetBytes(userId)); return (hash[3] << 8) | hash[4]; } + /// + /// Returns a dictionary of color values derived from the given list of user IDs. + /// + /// The list of user IDs to derive the color values from. + /// A dictionary of color values derived from the given list of user IDs. public Dictionary GetColourBulk(List userIds) { var output = new Dictionary(); @@ -483,6 +627,10 @@ namespace VRCX return output; } + /// + /// Retrieves the current text from the clipboard. + /// + /// The current text from the clipboard. public string GetClipboard() { var clipboard = string.Empty; @@ -493,6 +641,11 @@ namespace VRCX return clipboard; } + /// + /// Retrieves the value of the specified key from the VRChat group in the windows registry. + /// + /// The name of the key to retrieve. + /// The value of the specified key, or null if the key does not exist. public object GetVRChatRegistryKey(string key) { // https://answers.unity.com/questions/177945/playerprefs-changing-the-name-of-keys.html?childToView=208076#answer-208076 @@ -528,6 +681,12 @@ namespace VRCX return null; } + /// + /// Sets the value of the specified key in the VRChat group in the windows registry. + /// + /// The name of the key to set. + /// The value to set for the specified key. + /// True if the key was successfully set, false otherwise. public bool SetVRChatRegistryKey(string key, string value) { uint hash = 5381; @@ -562,6 +721,11 @@ namespace VRCX return true; } + /// + /// Retrieves a dictionary of moderations for the specified user from the VRChat LocalPlayerModerations folder. + /// + /// The ID of the current user. + /// A dictionary of moderations for the specified user, or null if the file does not exist. public Dictionary GetVRChatModerations(string currentUserId) { var filePath = Environment.GetFolderPath(Environment.SpecialFolder.LocalApplicationData) + $@"Low\VRChat\VRChat\LocalPlayerModerations\{currentUserId}-show-hide-user.vrcset"; @@ -590,6 +754,12 @@ namespace VRCX return output; } + /// + /// Retrieves the moderation type for the specified user from the VRChat LocalPlayerModerations folder. + /// + /// The ID of the current user. + /// The ID of the user to retrieve the moderation type for. + /// The moderation type for the specified user, or 0 if the file does not exist or the user is not found. public short GetVRChatUserModeration(string currentUserId, string userId) { var filePath = Environment.GetFolderPath(Environment.SpecialFolder.LocalApplicationData) + $@"Low\VRChat\VRChat\LocalPlayerModerations\{currentUserId}-show-hide-user.vrcset"; @@ -615,6 +785,13 @@ namespace VRCX return 0; } + /// + /// Sets the moderation type for the specified user in the VRChat LocalPlayerModerations folder. + /// + /// The ID of the current user. + /// The ID of the user to set the moderation type for. + /// The moderation type to set for the specified user. + /// True if the operation was successful, false otherwise. public bool SetVRChatUserModeration(string currentUserId, string userId, int type) { var filePath = Environment.GetFolderPath(Environment.SpecialFolder.LocalApplicationData) + $@"Low\VRChat\VRChat\LocalPlayerModerations\{currentUserId}-show-hide-user.vrcset"; @@ -648,6 +825,10 @@ namespace VRCX return true; } + /// + /// Sets whether or not the application should start up automatically with Windows. + /// + /// True to enable automatic startup, false to disable it. public void SetStartup(bool enabled) { try @@ -679,6 +860,13 @@ namespace VRCX AutoAppLaunchManager.Instance.KillChildrenOnExit = killOnExit; } + /// + /// Adds metadata to a PNG screenshot file and optionally renames the file to include the specified world ID. + /// + /// The path to the PNG screenshot file. + /// The metadata to add to the screenshot file. + /// The ID of the world to associate with the screenshot. + /// Whether or not to rename the screenshot file to include the world ID. public void AddScreenshotMetadata(string path, string metadataString, string worldId, bool changeFilename = false) { var fileName = Path.GetFileNameWithoutExtension(path); @@ -696,7 +884,10 @@ namespace VRCX ScreenshotHelper.WritePNGDescription(path, metadataString); } - // Create a function that opens a file dialog so a user can choose a .png file. Print the name of the file after it is chosen + /// + /// Opens a file dialog to select a PNG screenshot file. + /// The resulting file path is passed to . + /// public void OpenScreenshotFileDialog() { if (dialogOpen) return; @@ -737,6 +928,10 @@ namespace VRCX thread.Start(); } + /// + /// Retrieves metadata from a PNG screenshot file and send the result to displayScreenshotMetadata in app.js + /// + /// The path to the PNG screenshot file. public void GetScreenshotMetadata(string path) { if (string.IsNullOrEmpty(path)) @@ -818,6 +1013,9 @@ namespace VRCX ExecuteAppFunction("displayScreenshotMetadata", metadata.ToString(Formatting.Indented)); } + /// + /// Gets the last screenshot taken by VRChat and retrieves its metadata. + /// public void GetLastScreenshot() { // Get the last screenshot taken by VRChat @@ -836,6 +1034,10 @@ namespace VRCX GetScreenshotMetadata(lastScreenshot); } + /// + /// Copies an image file to the clipboard if it exists and is of a supported image file type. + /// + /// The path to the image file to copy to the clipboard. public void CopyImageToClipboard(string path) { // check if the file exists and is any image file type @@ -849,6 +1051,9 @@ namespace VRCX } } + /// + /// Opens the folder containing user-defined shortcuts, if it exists. + /// public void OpenShortcutFolder() { var path = AutoAppLaunchManager.Instance.AppShortcutDirectory; @@ -858,6 +1063,11 @@ namespace VRCX OpenFolderAndSelectItem(path, true); } + /// + /// Opens the folder containing the specified file or folder path and selects the item in the folder. + /// + /// The path to the file or folder to select in the folder. + /// Whether the specified path is a folder or not. Defaults to false. public void OpenFolderAndSelectItem(string path, bool isFolder = false) { // I don't think it's quite meant for it, but SHOpenFolderAndSelectItems can open folders by passing the folder path as the item to select, as a child to itself, somehow. So we'll check to see if 'path' is a folder as well. @@ -901,11 +1111,17 @@ namespace VRCX } } + /// + /// Flashes the window of the main form. + /// public void FlashWindow() { MainForm.Instance.BeginInvoke(new MethodInvoker(() => { WinformThemer.Flash(MainForm.Instance); })); } + /// + /// Sets the user agent string for the browser. + /// public void SetUserAgent() { using (var client = MainForm.Instance.Browser.GetDevToolsClient()) diff --git a/AssetBundleCacher.cs b/AssetBundleCacher.cs index e92377ca..6d845b79 100644 --- a/AssetBundleCacher.cs +++ b/AssetBundleCacher.cs @@ -59,6 +59,12 @@ namespace VRCX return AppApi.Instance.GetVRChatCacheLocation(); } + /// + /// Gets the full location of the VRChat cache for a specific asset bundle. + /// + /// The ID of the asset bundle. + /// The version of the asset bundle. + /// The full location of the VRChat cache for the specified asset bundle. public string GetVRChatCacheFullLocation(string id, int version) { var cachePath = GetVRChatCacheLocation(); @@ -67,6 +73,12 @@ namespace VRCX return Path.Combine(cachePath, idHash, versionLocation); } + /// + /// Checks the VRChat cache for a specific asset bundle. + /// + /// The ID of the asset bundle. + /// The version of the asset bundle. + /// An array containing the file size and lock status of the asset bundle. public long[] CheckVRChatCache(string id, int version) { long FileSize = -1; @@ -155,6 +167,11 @@ namespace VRCX DownloadProgress = -16; } + /// + /// Deletes the cache directory for a specific asset bundle. + /// + /// The ID of the asset bundle to delete. + /// The version of the asset bundle to delete. public void DeleteCache(string id, int version) { var FullLocation = GetVRChatCacheFullLocation(id, version); @@ -162,6 +179,9 @@ namespace VRCX Directory.Delete(FullLocation, true); } + /// + /// Deletes the entire VRChat cache directory. + /// public void DeleteAllCache() { var cachePath = GetVRChatCacheLocation(); @@ -172,6 +192,9 @@ namespace VRCX } } + /// + /// Removes empty directories from the VRChat cache directory and deletes old versions of cached asset bundles. + /// public void SweepCache() { var cachePath = GetVRChatCacheLocation(); @@ -201,6 +224,10 @@ namespace VRCX } } + /// + /// Returns the size of the VRChat cache directory in bytes. + /// + /// The size of the VRChat cache directory in bytes. public long GetCacheSize() { var cachePath = GetVRChatCacheLocation(); @@ -214,6 +241,12 @@ namespace VRCX } } + + /// + /// Recursively calculates the size of a directory and all its subdirectories. + /// + /// The directory to calculate the size of. + /// The size of the directory and all its subdirectories in bytes. public long DirSize(DirectoryInfo d) { long size = 0; diff --git a/AutoAppLaunchManager.cs b/AutoAppLaunchManager.cs index 5dcb0191..c6d410a9 100644 --- a/AutoAppLaunchManager.cs +++ b/AutoAppLaunchManager.cs @@ -22,7 +22,7 @@ namespace VRCX private DateTime startTime = DateTime.Now; private Dictionary startedProcesses = new Dictionary(); - private static readonly byte[] shortcutSignatureBytes = { 0x4C, 0x00, 0x00, 0x00 }; // signature for ShellLinkHeader\ + private static readonly byte[] shortcutSignatureBytes = { 0x4C, 0x00, 0x00, 0x00 }; // signature for ShellLinkHeader private const uint TH32CS_SNAPPROCESS = 2; @@ -123,6 +123,11 @@ namespace VRCX // This is a recursive function that kills a process and all of its children. // It uses the CreateToolhelp32Snapshot winapi func to get a snapshot of all running processes, loops through them with Process32First/Process32Next, and kills any processes that have the given pid as their parent. + + /// + /// Kills a process and all of its child processes. + /// + /// The process ID of the parent process. public static void KillProcessTree(int pid) { IntPtr snapshot = CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0); diff --git a/LogWatcher.cs b/LogWatcher.cs index 67b49cb9..3103de06 100644 --- a/LogWatcher.cs +++ b/LogWatcher.cs @@ -14,6 +14,9 @@ using CefSharp; namespace VRCX { + /// + /// Monitors the VRChat log files for changes and provides access to the log data. + /// public class LogWatcher { public static readonly LogWatcher Instance; @@ -88,6 +91,9 @@ namespace VRCX } } + /// + /// Updates the log watcher by checking for new log files and updating the log list. + /// private void Update() { if (m_ResetLog) @@ -157,6 +163,11 @@ namespace VRCX m_FirstRun = false; } + /// + /// Parses the log file starting from the current position and updates the log context. + /// + /// The file information of the log file to parse. + /// The log context to update. private void ParseLog(FileInfo fileInfo, LogContext logContext) { try @@ -224,6 +235,7 @@ namespace VRCX ParseLogUsharpVideoPlay(fileInfo, logContext, line, offset) || ParseLogUsharpVideoSync(fileInfo, logContext, line, offset) || ParseLogWorldVRCX(fileInfo, logContext, line, offset) || + ParseLogWorldDataVRCX(fileInfo, logContext, line, offset) || ParseLogOnAudioConfigurationChanged(fileInfo, logContext, line, offset) || ParseLogScreenshot(fileInfo, logContext, line, offset) || ParseLogStringDownload(fileInfo, logContext, line, offset) || @@ -593,6 +605,19 @@ namespace VRCX return true; } + private bool ParseLogWorldDataVRCX(FileInfo fileInfo, LogContext logContext, string line, int offset) + { + // [VRCX-World] store:test:testvalue + + if (string.Compare(line, offset, "[VRCX-World] ", 0, 13, StringComparison.Ordinal) != 0) + return false; + + var data = line.Substring(offset + 13); + + WorldDBManager.Instance.ProcessLogWorldDataRequest(data); + return true; + } + private bool ParseLogVideoChange(FileInfo fileInfo, LogContext logContext, string line, int offset) { // 2021.04.20 13:37:69 Log - [Video Playback] Attempting to resolve URL 'https://www.youtube.com/watch?v=dQw4w9WgXcQ' @@ -927,7 +952,7 @@ namespace VRCX private bool ParseOpenVRInit(FileInfo fileInfo, LogContext logContext, string line, int offset) { // 2022.07.29 02:52:14 Log - OpenVR initialized! - + // 2023.04.22 16:52:28 Log - Initializing VRSDK. // 2023.04.22 16:52:29 Log - StartVRSDK: Open VR Loader @@ -944,7 +969,7 @@ namespace VRCX return true; } - + private bool ParseDesktopMode(FileInfo fileInfo, LogContext logContext, string line, int offset) { // 2023.04.22 16:54:18 Log - VR Disabled diff --git a/Program.cs b/Program.cs index c784256f..47741085 100644 --- a/Program.cs +++ b/Program.cs @@ -6,6 +6,7 @@ using System; using System.IO; +using System.Threading.Tasks; using System.Windows.Forms; namespace VRCX @@ -78,8 +79,12 @@ namespace VRCX Application.EnableVisualStyles(); Application.SetCompatibleTextRenderingDefault(false); + // I'll re-do this whole function eventually I swear + var worldDBServer = new WorldDBManager("http://127.0.0.1:22500/"); + Task.Run(worldDBServer.Start); + ProcessMonitor.Instance.Init(); - SQLite.Instance.Init(); + SQLiteLegacy.Instance.Init(); VRCXStorage.Load(); LoadFromConfig(); CpuMonitor.Instance.Init(); @@ -99,14 +104,18 @@ namespace VRCX AutoAppLaunchManager.Instance.Exit(); LogWatcher.Instance.Exit(); WebApi.Instance.Exit(); + worldDBServer.Stop(); Discord.Instance.Exit(); CpuMonitor.Instance.Exit(); VRCXStorage.Save(); - SQLite.Instance.Exit(); + SQLiteLegacy.Instance.Exit(); ProcessMonitor.Instance.Exit(); } + /// + /// Sets GPUFix to true if it is not already set and the VRCX_GPUFix key in the database is true. + /// private static void LoadFromConfig() { if (!GPUFix) diff --git a/SQLite.cs b/SQLite.cs index 5bd26aa8..3bfd1fd7 100644 --- a/SQLite.cs +++ b/SQLite.cs @@ -1,151 +1,4956 @@ -using CefSharp; +// +// Copyright (c) 2009-2021 Krueger Systems, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. +// +#if WINDOWS_PHONE && !USE_WP8_NATIVE_SQLITE +#define USE_CSHARP_SQLITE +#endif + using System; +using System.Collections; +using System.Diagnostics; +#if !USE_SQLITEPCL_RAW +using System.Runtime.InteropServices; +#endif using System.Collections.Generic; -using System.Data.SQLite; -using System.IO; +using System.Reflection; +using System.Linq; +using System.Linq.Expressions; +using System.Text; using System.Threading; -namespace VRCX +#if USE_CSHARP_SQLITE +using Sqlite3 = Community.CsharpSqlite.Sqlite3; +using Sqlite3DatabaseHandle = Community.CsharpSqlite.Sqlite3.sqlite3; +using Sqlite3Statement = Community.CsharpSqlite.Sqlite3.Vdbe; +#elif USE_WP8_NATIVE_SQLITE +using Sqlite3 = Sqlite.Sqlite3; +using Sqlite3DatabaseHandle = Sqlite.Database; +using Sqlite3Statement = Sqlite.Statement; +#elif USE_SQLITEPCL_RAW +using Sqlite3DatabaseHandle = SQLitePCL.sqlite3; +using Sqlite3BackupHandle = SQLitePCL.sqlite3_backup; +using Sqlite3Statement = SQLitePCL.sqlite3_stmt; +using Sqlite3 = SQLitePCL.raw; +#else +using Sqlite3DatabaseHandle = System.IntPtr; +using Sqlite3BackupHandle = System.IntPtr; +using Sqlite3Statement = System.IntPtr; +#endif + +#pragma warning disable 1591 // XML Doc Comments + +namespace SQLite { - public class SQLite - { - public static readonly SQLite Instance; - private readonly ReaderWriterLockSlim m_ConnectionLock; - private readonly SQLiteConnection m_Connection; + public class SQLiteException : Exception + { + public SQLite3.Result Result { get; private set; } - static SQLite() + protected SQLiteException (SQLite3.Result r, string message) : base (message) + { + Result = r; + } + + public static SQLiteException New (SQLite3.Result r, string message) + { + return new SQLiteException (r, message); + } + } + + public class NotNullConstraintViolationException : SQLiteException + { + public IEnumerable Columns { get; protected set; } + + protected NotNullConstraintViolationException (SQLite3.Result r, string message) + : this (r, message, null, null) + { + + } + + protected NotNullConstraintViolationException (SQLite3.Result r, string message, TableMapping mapping, object obj) + : base (r, message) + { + if (mapping != null && obj != null) { + this.Columns = from c in mapping.Columns + where c.IsNullable == false && c.GetValue (obj) == null + select c; + } + } + + public static new NotNullConstraintViolationException New (SQLite3.Result r, string message) + { + return new NotNullConstraintViolationException (r, message); + } + + public static NotNullConstraintViolationException New (SQLite3.Result r, string message, TableMapping mapping, object obj) + { + return new NotNullConstraintViolationException (r, message, mapping, obj); + } + + public static NotNullConstraintViolationException New (SQLiteException exception, TableMapping mapping, object obj) + { + return new NotNullConstraintViolationException (exception.Result, exception.Message, mapping, obj); + } + } + + [Flags] + public enum SQLiteOpenFlags + { + ReadOnly = 1, ReadWrite = 2, Create = 4, + NoMutex = 0x8000, FullMutex = 0x10000, + SharedCache = 0x20000, PrivateCache = 0x40000, + ProtectionComplete = 0x00100000, + ProtectionCompleteUnlessOpen = 0x00200000, + ProtectionCompleteUntilFirstUserAuthentication = 0x00300000, + ProtectionNone = 0x00400000 + } + + [Flags] + public enum CreateFlags + { + /// + /// Use the default creation options + /// + None = 0x000, + /// + /// Create a primary key index for a property called 'Id' (case-insensitive). + /// This avoids the need for the [PrimaryKey] attribute. + /// + ImplicitPK = 0x001, + /// + /// Create indices for properties ending in 'Id' (case-insensitive). + /// + ImplicitIndex = 0x002, + /// + /// Create a primary key for a property called 'Id' and + /// create an indices for properties ending in 'Id' (case-insensitive). + /// + AllImplicit = 0x003, + /// + /// Force the primary key property to be auto incrementing. + /// This avoids the need for the [AutoIncrement] attribute. + /// The primary key property on the class should have type int or long. + /// + AutoIncPK = 0x004, + /// + /// Create virtual table using FTS3 + /// + FullTextSearch3 = 0x100, + /// + /// Create virtual table using FTS4 + /// + FullTextSearch4 = 0x200 + } + + public interface ISQLiteConnection + { + Sqlite3DatabaseHandle Handle { get; } + string DatabasePath { get; } + int LibVersionNumber { get; } + bool TimeExecution { get; set; } + bool Trace { get; set; } + Action Tracer { get; set; } + bool StoreDateTimeAsTicks { get; } + bool StoreTimeSpanAsTicks { get; } + string DateTimeStringFormat { get; } + TimeSpan BusyTimeout { get; set; } + IEnumerable TableMappings { get; } + bool IsInTransaction { get; } + + event EventHandler TableChanged; + + void Backup (string destinationDatabasePath, string databaseName = "main"); + void BeginTransaction (); + void Close (); + void Commit (); + SQLiteCommand CreateCommand (string cmdText, params object[] ps); + SQLiteCommand CreateCommand (string cmdText, Dictionary args); + int CreateIndex (string indexName, string tableName, string[] columnNames, bool unique = false); + int CreateIndex (string indexName, string tableName, string columnName, bool unique = false); + int CreateIndex (string tableName, string columnName, bool unique = false); + int CreateIndex (string tableName, string[] columnNames, bool unique = false); + int CreateIndex (Expression> property, bool unique = false); + CreateTableResult CreateTable (CreateFlags createFlags = CreateFlags.None); + CreateTableResult CreateTable (Type ty, CreateFlags createFlags = CreateFlags.None); + CreateTablesResult CreateTables (CreateFlags createFlags = CreateFlags.None) + where T : new() + where T2 : new(); + CreateTablesResult CreateTables (CreateFlags createFlags = CreateFlags.None) + where T : new() + where T2 : new() + where T3 : new(); + CreateTablesResult CreateTables (CreateFlags createFlags = CreateFlags.None) + where T : new() + where T2 : new() + where T3 : new() + where T4 : new(); + CreateTablesResult CreateTables (CreateFlags createFlags = CreateFlags.None) + where T : new() + where T2 : new() + where T3 : new() + where T4 : new() + where T5 : new(); + CreateTablesResult CreateTables (CreateFlags createFlags = CreateFlags.None, params Type[] types); + IEnumerable DeferredQuery (string query, params object[] args) where T : new(); + IEnumerable DeferredQuery (TableMapping map, string query, params object[] args); + int Delete (object objectToDelete); + int Delete (object primaryKey); + int Delete (object primaryKey, TableMapping map); + int DeleteAll (); + int DeleteAll (TableMapping map); + void Dispose (); + int DropTable (); + int DropTable (TableMapping map); + void EnableLoadExtension (bool enabled); + void EnableWriteAheadLogging (); + int Execute (string query, params object[] args); + T ExecuteScalar (string query, params object[] args); + T Find (object pk) where T : new(); + object Find (object pk, TableMapping map); + T Find (Expression> predicate) where T : new(); + T FindWithQuery (string query, params object[] args) where T : new(); + object FindWithQuery (TableMapping map, string query, params object[] args); + T Get (object pk) where T : new(); + object Get (object pk, TableMapping map); + T Get (Expression> predicate) where T : new(); + TableMapping GetMapping (Type type, CreateFlags createFlags = CreateFlags.None); + TableMapping GetMapping (CreateFlags createFlags = CreateFlags.None); + List GetTableInfo (string tableName); + int Insert (object obj); + int Insert (object obj, Type objType); + int Insert (object obj, string extra); + int Insert (object obj, string extra, Type objType); + int InsertAll (IEnumerable objects, bool runInTransaction = true); + int InsertAll (IEnumerable objects, string extra, bool runInTransaction = true); + int InsertAll (IEnumerable objects, Type objType, bool runInTransaction = true); + int InsertOrReplace (object obj); + int InsertOrReplace (object obj, Type objType); + List Query (string query, params object[] args) where T : new(); + List Query (TableMapping map, string query, params object[] args); + List QueryScalars (string query, params object[] args); + void Release (string savepoint); + void Rollback (); + void RollbackTo (string savepoint); + void RunInTransaction (Action action); + string SaveTransactionPoint (); + TableQuery Table () where T : new(); + int Update (object obj); + int Update (object obj, Type objType); + int UpdateAll (IEnumerable objects, bool runInTransaction = true); + } + + /// + /// An open connection to a SQLite database. + /// + [Preserve (AllMembers = true)] + public partial class SQLiteConnection : IDisposable + , ISQLiteConnection + { + private bool _open; + private TimeSpan _busyTimeout; + readonly static Dictionary _mappings = new Dictionary (); + private System.Diagnostics.Stopwatch _sw; + private long _elapsedMilliseconds = 0; + + private int _transactionDepth = 0; + private Random _rand = new Random (); + + public Sqlite3DatabaseHandle Handle { get; private set; } + static readonly Sqlite3DatabaseHandle NullHandle = default (Sqlite3DatabaseHandle); + static readonly Sqlite3BackupHandle NullBackupHandle = default (Sqlite3BackupHandle); + + /// + /// Gets the database path used by this connection. + /// + public string DatabasePath { get; private set; } + + /// + /// Gets the SQLite library version number. 3007014 would be v3.7.14 + /// + public int LibVersionNumber { get; private set; } + + /// + /// Whether Trace lines should be written that show the execution time of queries. + /// + public bool TimeExecution { get; set; } + + /// + /// Whether to write queries to during execution. + /// + public bool Trace { get; set; } + + /// + /// The delegate responsible for writing trace lines. + /// + /// The tracer. + public Action Tracer { get; set; } + + /// + /// Whether to store DateTime properties as ticks (true) or strings (false). + /// + public bool StoreDateTimeAsTicks { get; private set; } + + /// + /// Whether to store TimeSpan properties as ticks (true) or strings (false). + /// + public bool StoreTimeSpanAsTicks { get; private set; } + + /// + /// The format to use when storing DateTime properties as strings. Ignored if StoreDateTimeAsTicks is true. + /// + /// The date time string format. + public string DateTimeStringFormat { get; private set; } + + /// + /// The DateTimeStyles value to use when parsing a DateTime property string. + /// + /// The date time style. + internal System.Globalization.DateTimeStyles DateTimeStyle { get; private set; } + +#if USE_SQLITEPCL_RAW && !NO_SQLITEPCL_RAW_BATTERIES + static SQLiteConnection () + { + SQLitePCL.Batteries_V2.Init (); + } +#endif + + /// + /// Constructs a new SQLiteConnection and opens a SQLite database specified by databasePath. + /// + /// + /// Specifies the path to the database file. + /// + /// + /// Specifies whether to store DateTime properties as ticks (true) or strings (false). You + /// absolutely do want to store them as Ticks in all new projects. The value of false is + /// only here for backwards compatibility. There is a *significant* speed advantage, with no + /// down sides, when setting storeDateTimeAsTicks = true. + /// If you use DateTimeOffset properties, it will be always stored as ticks regardingless + /// the storeDateTimeAsTicks parameter. + /// + public SQLiteConnection (string databasePath, bool storeDateTimeAsTicks = true) + : this (new SQLiteConnectionString (databasePath, SQLiteOpenFlags.ReadWrite | SQLiteOpenFlags.Create, storeDateTimeAsTicks)) + { + } + + /// + /// Constructs a new SQLiteConnection and opens a SQLite database specified by databasePath. + /// + /// + /// Specifies the path to the database file. + /// + /// + /// Flags controlling how the connection should be opened. + /// + /// + /// Specifies whether to store DateTime properties as ticks (true) or strings (false). You + /// absolutely do want to store them as Ticks in all new projects. The value of false is + /// only here for backwards compatibility. There is a *significant* speed advantage, with no + /// down sides, when setting storeDateTimeAsTicks = true. + /// If you use DateTimeOffset properties, it will be always stored as ticks regardingless + /// the storeDateTimeAsTicks parameter. + /// + public SQLiteConnection (string databasePath, SQLiteOpenFlags openFlags, bool storeDateTimeAsTicks = true) + : this (new SQLiteConnectionString (databasePath, openFlags, storeDateTimeAsTicks)) + { + } + + /// + /// Constructs a new SQLiteConnection and opens a SQLite database specified by databasePath. + /// + /// + /// Details on how to find and open the database. + /// + public SQLiteConnection (SQLiteConnectionString connectionString) + { + if (connectionString == null) + throw new ArgumentNullException (nameof (connectionString)); + if (connectionString.DatabasePath == null) + throw new InvalidOperationException ("DatabasePath must be specified"); + + DatabasePath = connectionString.DatabasePath; + + LibVersionNumber = SQLite3.LibVersionNumber (); + +#if NETFX_CORE + SQLite3.SetDirectory(/*temp directory type*/2, Windows.Storage.ApplicationData.Current.TemporaryFolder.Path); +#endif + + Sqlite3DatabaseHandle handle; + +#if SILVERLIGHT || USE_CSHARP_SQLITE || USE_SQLITEPCL_RAW + var r = SQLite3.Open (connectionString.DatabasePath, out handle, (int)connectionString.OpenFlags, connectionString.VfsName); +#else + // open using the byte[] + // in the case where the path may include Unicode + // force open to using UTF-8 using sqlite3_open_v2 + var databasePathAsBytes = GetNullTerminatedUtf8 (connectionString.DatabasePath); + var r = SQLite3.Open (databasePathAsBytes, out handle, (int)connectionString.OpenFlags, connectionString.VfsName); +#endif + + Handle = handle; + if (r != SQLite3.Result.OK) { + throw SQLiteException.New (r, String.Format ("Could not open database file: {0} ({1})", DatabasePath, r)); + } + _open = true; + + StoreDateTimeAsTicks = connectionString.StoreDateTimeAsTicks; + StoreTimeSpanAsTicks = connectionString.StoreTimeSpanAsTicks; + DateTimeStringFormat = connectionString.DateTimeStringFormat; + DateTimeStyle = connectionString.DateTimeStyle; + + BusyTimeout = TimeSpan.FromSeconds (1.0); + Tracer = line => Debug.WriteLine (line); + + connectionString.PreKeyAction?.Invoke (this); + if (connectionString.Key is string stringKey) { + SetKey (stringKey); + } + else if (connectionString.Key is byte[] bytesKey) { + SetKey (bytesKey); + } + else if (connectionString.Key != null) { + throw new InvalidOperationException ("Encryption keys must be strings or byte arrays"); + } + connectionString.PostKeyAction?.Invoke (this); + } + + /// + /// Enables the write ahead logging. WAL is significantly faster in most scenarios + /// by providing better concurrency and better disk IO performance than the normal + /// journal mode. You only need to call this function once in the lifetime of the database. + /// + public void EnableWriteAheadLogging () + { + ExecuteScalar ("PRAGMA journal_mode=WAL"); + } + + /// + /// Convert an input string to a quoted SQL string that can be safely used in queries. + /// + /// The quoted string. + /// The unsafe string to quote. + static string Quote (string unsafeString) + { + // TODO: Doesn't call sqlite3_mprintf("%Q", u) because we're waiting on https://github.com/ericsink/SQLitePCL.raw/issues/153 + if (unsafeString == null) + return "NULL"; + var safe = unsafeString.Replace ("'", "''"); + return "'" + safe + "'"; + } + + /// + /// Sets the key used to encrypt/decrypt the database with "pragma key = ...". + /// This must be the first thing you call before doing anything else with this connection + /// if your database is encrypted. + /// This only has an effect if you are using the SQLCipher nuget package. + /// + /// Ecryption key plain text that is converted to the real encryption key using PBKDF2 key derivation + void SetKey (string key) + { + if (key == null) + throw new ArgumentNullException (nameof (key)); + var q = Quote (key); + ExecuteScalar ("pragma key = " + q); + } + + /// + /// Sets the key used to encrypt/decrypt the database. + /// This must be the first thing you call before doing anything else with this connection + /// if your database is encrypted. + /// This only has an effect if you are using the SQLCipher nuget package. + /// + /// 256-bit (32 byte) ecryption key data + void SetKey (byte[] key) + { + if (key == null) + throw new ArgumentNullException (nameof (key)); + if (key.Length != 32 && key.Length != 48) + throw new ArgumentException ("Key must be 32 bytes (256-bit) or 48 bytes (384-bit)", nameof (key)); + var s = String.Join ("", key.Select (x => x.ToString ("X2"))); + ExecuteScalar ("pragma key = \"x'" + s + "'\""); + } + + /// + /// Enable or disable extension loading. + /// + public void EnableLoadExtension (bool enabled) + { + SQLite3.Result r = SQLite3.EnableLoadExtension (Handle, enabled ? 1 : 0); + if (r != SQLite3.Result.OK) { + string msg = SQLite3.GetErrmsg (Handle); + throw SQLiteException.New (r, msg); + } + } + +#if !USE_SQLITEPCL_RAW + static byte[] GetNullTerminatedUtf8 (string s) + { + var utf8Length = System.Text.Encoding.UTF8.GetByteCount (s); + var bytes = new byte [utf8Length + 1]; + utf8Length = System.Text.Encoding.UTF8.GetBytes(s, 0, s.Length, bytes, 0); + return bytes; + } +#endif + + /// + /// Sets a busy handler to sleep the specified amount of time when a table is locked. + /// The handler will sleep multiple times until a total time of has accumulated. + /// + public TimeSpan BusyTimeout { + get { return _busyTimeout; } + set { + _busyTimeout = value; + if (Handle != NullHandle) { + SQLite3.BusyTimeout (Handle, (int)_busyTimeout.TotalMilliseconds); + } + } + } + + /// + /// Returns the mappings from types to tables that the connection + /// currently understands. + /// + public IEnumerable TableMappings { + get { + lock (_mappings) { + return new List (_mappings.Values); + } + } + } + + /// + /// Retrieves the mapping that is automatically generated for the given type. + /// + /// + /// The type whose mapping to the database is returned. + /// + /// + /// Optional flags allowing implicit PK and indexes based on naming conventions + /// + /// + /// The mapping represents the schema of the columns of the database and contains + /// methods to set and get properties of objects. + /// + public TableMapping GetMapping (Type type, CreateFlags createFlags = CreateFlags.None) + { + TableMapping map; + var key = type.FullName; + lock (_mappings) { + if (_mappings.TryGetValue (key, out map)) { + if (createFlags != CreateFlags.None && createFlags != map.CreateFlags) { + map = new TableMapping (type, createFlags); + _mappings[key] = map; + } + } + else { + map = new TableMapping (type, createFlags); + _mappings.Add (key, map); + } + } + return map; + } + + /// + /// Retrieves the mapping that is automatically generated for the given type. + /// + /// + /// Optional flags allowing implicit PK and indexes based on naming conventions + /// + /// + /// The mapping represents the schema of the columns of the database and contains + /// methods to set and get properties of objects. + /// + public TableMapping GetMapping (CreateFlags createFlags = CreateFlags.None) + { + return GetMapping (typeof (T), createFlags); + } + + private struct IndexedColumn + { + public int Order; + public string ColumnName; + } + + private struct IndexInfo + { + public string IndexName; + public string TableName; + public bool Unique; + public List Columns; + } + + /// + /// Executes a "drop table" on the database. This is non-recoverable. + /// + public int DropTable () + { + return DropTable (GetMapping (typeof (T))); + } + + /// + /// Executes a "drop table" on the database. This is non-recoverable. + /// + /// + /// The TableMapping used to identify the table. + /// + public int DropTable (TableMapping map) + { + var query = string.Format ("drop table if exists \"{0}\"", map.TableName); + return Execute (query); + } + + /// + /// Executes a "create table if not exists" on the database. It also + /// creates any specified indexes on the columns of the table. It uses + /// a schema automatically generated from the specified type. You can + /// later access this schema by calling GetMapping. + /// + /// + /// Whether the table was created or migrated. + /// + public CreateTableResult CreateTable (CreateFlags createFlags = CreateFlags.None) + { + return CreateTable (typeof (T), createFlags); + } + + /// + /// Executes a "create table if not exists" on the database. It also + /// creates any specified indexes on the columns of the table. It uses + /// a schema automatically generated from the specified type. You can + /// later access this schema by calling GetMapping. + /// + /// Type to reflect to a database table. + /// Optional flags allowing implicit PK and indexes based on naming conventions. + /// + /// Whether the table was created or migrated. + /// + public CreateTableResult CreateTable (Type ty, CreateFlags createFlags = CreateFlags.None) + { + var map = GetMapping (ty, createFlags); + + // Present a nice error if no columns specified + if (map.Columns.Length == 0) { + throw new Exception (string.Format ("Cannot create a table without columns (does '{0}' have public properties?)", ty.FullName)); + } + + // Check if the table exists + var result = CreateTableResult.Created; + var existingCols = GetTableInfo (map.TableName); + + // Create or migrate it + if (existingCols.Count == 0) { + + // Facilitate virtual tables a.k.a. full-text search. + bool fts3 = (createFlags & CreateFlags.FullTextSearch3) != 0; + bool fts4 = (createFlags & CreateFlags.FullTextSearch4) != 0; + bool fts = fts3 || fts4; + var @virtual = fts ? "virtual " : string.Empty; + var @using = fts3 ? "using fts3 " : fts4 ? "using fts4 " : string.Empty; + + // Build query. + var query = "create " + @virtual + "table if not exists \"" + map.TableName + "\" " + @using + "(\n"; + var decls = map.Columns.Select (p => Orm.SqlDecl (p, StoreDateTimeAsTicks, StoreTimeSpanAsTicks)); + var decl = string.Join (",\n", decls.ToArray ()); + query += decl; + query += ")"; + if (map.WithoutRowId) { + query += " without rowid"; + } + + Execute (query); + } + else { + result = CreateTableResult.Migrated; + MigrateTable (map, existingCols); + } + + var indexes = new Dictionary (); + foreach (var c in map.Columns) { + foreach (var i in c.Indices) { + var iname = i.Name ?? map.TableName + "_" + c.Name; + IndexInfo iinfo; + if (!indexes.TryGetValue (iname, out iinfo)) { + iinfo = new IndexInfo { + IndexName = iname, + TableName = map.TableName, + Unique = i.Unique, + Columns = new List () + }; + indexes.Add (iname, iinfo); + } + + if (i.Unique != iinfo.Unique) + throw new Exception ("All the columns in an index must have the same value for their Unique property"); + + iinfo.Columns.Add (new IndexedColumn { + Order = i.Order, + ColumnName = c.Name + }); + } + } + + foreach (var indexName in indexes.Keys) { + var index = indexes[indexName]; + var columns = index.Columns.OrderBy (i => i.Order).Select (i => i.ColumnName).ToArray (); + CreateIndex (indexName, index.TableName, columns, index.Unique); + } + + return result; + } + + /// + /// Executes a "create table if not exists" on the database for each type. It also + /// creates any specified indexes on the columns of the table. It uses + /// a schema automatically generated from the specified type. You can + /// later access this schema by calling GetMapping. + /// + /// + /// Whether the table was created or migrated for each type. + /// + public CreateTablesResult CreateTables (CreateFlags createFlags = CreateFlags.None) + where T : new() + where T2 : new() + { + return CreateTables (createFlags, typeof (T), typeof (T2)); + } + + /// + /// Executes a "create table if not exists" on the database for each type. It also + /// creates any specified indexes on the columns of the table. It uses + /// a schema automatically generated from the specified type. You can + /// later access this schema by calling GetMapping. + /// + /// + /// Whether the table was created or migrated for each type. + /// + public CreateTablesResult CreateTables (CreateFlags createFlags = CreateFlags.None) + where T : new() + where T2 : new() + where T3 : new() + { + return CreateTables (createFlags, typeof (T), typeof (T2), typeof (T3)); + } + + /// + /// Executes a "create table if not exists" on the database for each type. It also + /// creates any specified indexes on the columns of the table. It uses + /// a schema automatically generated from the specified type. You can + /// later access this schema by calling GetMapping. + /// + /// + /// Whether the table was created or migrated for each type. + /// + public CreateTablesResult CreateTables (CreateFlags createFlags = CreateFlags.None) + where T : new() + where T2 : new() + where T3 : new() + where T4 : new() + { + return CreateTables (createFlags, typeof (T), typeof (T2), typeof (T3), typeof (T4)); + } + + /// + /// Executes a "create table if not exists" on the database for each type. It also + /// creates any specified indexes on the columns of the table. It uses + /// a schema automatically generated from the specified type. You can + /// later access this schema by calling GetMapping. + /// + /// + /// Whether the table was created or migrated for each type. + /// + public CreateTablesResult CreateTables (CreateFlags createFlags = CreateFlags.None) + where T : new() + where T2 : new() + where T3 : new() + where T4 : new() + where T5 : new() + { + return CreateTables (createFlags, typeof (T), typeof (T2), typeof (T3), typeof (T4), typeof (T5)); + } + + /// + /// Executes a "create table if not exists" on the database for each type. It also + /// creates any specified indexes on the columns of the table. It uses + /// a schema automatically generated from the specified type. You can + /// later access this schema by calling GetMapping. + /// + /// + /// Whether the table was created or migrated for each type. + /// + public CreateTablesResult CreateTables (CreateFlags createFlags = CreateFlags.None, params Type[] types) + { + var result = new CreateTablesResult (); + foreach (Type type in types) { + var aResult = CreateTable (type, createFlags); + result.Results[type] = aResult; + } + return result; + } + + /// + /// Creates an index for the specified table and columns. + /// + /// Name of the index to create + /// Name of the database table + /// An array of column names to index + /// Whether the index should be unique + /// Zero on success. + public int CreateIndex (string indexName, string tableName, string[] columnNames, bool unique = false) + { + const string sqlFormat = "create {2} index if not exists \"{3}\" on \"{0}\"(\"{1}\")"; + var sql = String.Format (sqlFormat, tableName, string.Join ("\", \"", columnNames), unique ? "unique" : "", indexName); + return Execute (sql); + } + + /// + /// Creates an index for the specified table and column. + /// + /// Name of the index to create + /// Name of the database table + /// Name of the column to index + /// Whether the index should be unique + /// Zero on success. + public int CreateIndex (string indexName, string tableName, string columnName, bool unique = false) + { + return CreateIndex (indexName, tableName, new string[] { columnName }, unique); + } + + /// + /// Creates an index for the specified table and column. + /// + /// Name of the database table + /// Name of the column to index + /// Whether the index should be unique + /// Zero on success. + public int CreateIndex (string tableName, string columnName, bool unique = false) + { + return CreateIndex (tableName + "_" + columnName, tableName, columnName, unique); + } + + /// + /// Creates an index for the specified table and columns. + /// + /// Name of the database table + /// An array of column names to index + /// Whether the index should be unique + /// Zero on success. + public int CreateIndex (string tableName, string[] columnNames, bool unique = false) + { + return CreateIndex (tableName + "_" + string.Join ("_", columnNames), tableName, columnNames, unique); + } + + /// + /// Creates an index for the specified object property. + /// e.g. CreateIndex<Client>(c => c.Name); + /// + /// Type to reflect to a database table. + /// Property to index + /// Whether the index should be unique + /// Zero on success. + public int CreateIndex (Expression> property, bool unique = false) + { + MemberExpression mx; + if (property.Body.NodeType == ExpressionType.Convert) { + mx = ((UnaryExpression)property.Body).Operand as MemberExpression; + } + else { + mx = (property.Body as MemberExpression); + } + var propertyInfo = mx.Member as PropertyInfo; + if (propertyInfo == null) { + throw new ArgumentException ("The lambda expression 'property' should point to a valid Property"); + } + + var propName = propertyInfo.Name; + + var map = GetMapping (); + var colName = map.FindColumnWithPropertyName (propName).Name; + + return CreateIndex (map.TableName, colName, unique); + } + + [Preserve (AllMembers = true)] + public class ColumnInfo + { + // public int cid { get; set; } + + [Column ("name")] + public string Name { get; set; } + + // [Column ("type")] + // public string ColumnType { get; set; } + + public int notnull { get; set; } + + // public string dflt_value { get; set; } + + // public int pk { get; set; } + + public override string ToString () + { + return Name; + } + } + + /// + /// Query the built-in sqlite table_info table for a specific tables columns. + /// + /// The columns contains in the table. + /// Table name. + public List GetTableInfo (string tableName) + { + var query = "pragma table_info(\"" + tableName + "\")"; + return Query (query); + } + + void MigrateTable (TableMapping map, List existingCols) + { + var toBeAdded = new List (); + + foreach (var p in map.Columns) { + var found = false; + foreach (var c in existingCols) { + found = (string.Compare (p.Name, c.Name, StringComparison.OrdinalIgnoreCase) == 0); + if (found) + break; + } + if (!found) { + toBeAdded.Add (p); + } + } + + foreach (var p in toBeAdded) { + var addCol = "alter table \"" + map.TableName + "\" add column " + Orm.SqlDecl (p, StoreDateTimeAsTicks, StoreTimeSpanAsTicks); + Execute (addCol); + } + } + + /// + /// Creates a new SQLiteCommand. Can be overridden to provide a sub-class. + /// + /// + protected virtual SQLiteCommand NewCommand () + { + return new SQLiteCommand (this); + } + + /// + /// Creates a new SQLiteCommand given the command text with arguments. Place a '?' + /// in the command text for each of the arguments. + /// + /// + /// The fully escaped SQL. + /// + /// + /// Arguments to substitute for the occurences of '?' in the command text. + /// + /// + /// A + /// + public SQLiteCommand CreateCommand (string cmdText, params object[] ps) + { + if (!_open) + throw SQLiteException.New (SQLite3.Result.Error, "Cannot create commands from unopened database"); + + var cmd = NewCommand (); + cmd.CommandText = cmdText; + foreach (var o in ps) { + cmd.Bind (o); + } + return cmd; + } + + /// + /// Creates a new SQLiteCommand given the command text with named arguments. Place a "[@:$]VVV" + /// in the command text for each of the arguments. VVV represents an alphanumeric identifier. + /// For example, @name :name and $name can all be used in the query. + /// + /// + /// The fully escaped SQL. + /// + /// + /// Arguments to substitute for the occurences of "[@:$]VVV" in the command text. + /// + /// + /// A + /// + public SQLiteCommand CreateCommand (string cmdText, Dictionary args) + { + if (!_open) + throw SQLiteException.New (SQLite3.Result.Error, "Cannot create commands from unopened database"); + + SQLiteCommand cmd = NewCommand (); + cmd.CommandText = cmdText; + foreach (var kv in args) { + cmd.Bind (kv.Key, kv.Value); + } + return cmd; + } + + /// + /// Creates a SQLiteCommand given the command text (SQL) with arguments. Place a '?' + /// in the command text for each of the arguments and then executes that command. + /// Use this method instead of Query when you don't expect rows back. Such cases include + /// INSERTs, UPDATEs, and DELETEs. + /// You can set the Trace or TimeExecution properties of the connection + /// to profile execution. + /// + /// + /// The fully escaped SQL. + /// + /// + /// Arguments to substitute for the occurences of '?' in the query. + /// + /// + /// The number of rows modified in the database as a result of this execution. + /// + public int Execute (string query, params object[] args) + { + var cmd = CreateCommand (query, args); + + if (TimeExecution) { + if (_sw == null) { + _sw = new Stopwatch (); + } + _sw.Reset (); + _sw.Start (); + } + + var r = cmd.ExecuteNonQuery (); + + if (TimeExecution) { + _sw.Stop (); + _elapsedMilliseconds += _sw.ElapsedMilliseconds; + Tracer?.Invoke (string.Format ("Finished in {0} ms ({1:0.0} s total)", _sw.ElapsedMilliseconds, _elapsedMilliseconds / 1000.0)); + } + + return r; + } + + /// + /// Creates a SQLiteCommand given the command text (SQL) with arguments. Place a '?' + /// in the command text for each of the arguments and then executes that command. + /// Use this method when return primitive values. + /// You can set the Trace or TimeExecution properties of the connection + /// to profile execution. + /// + /// + /// The fully escaped SQL. + /// + /// + /// Arguments to substitute for the occurences of '?' in the query. + /// + /// + /// The number of rows modified in the database as a result of this execution. + /// + public T ExecuteScalar (string query, params object[] args) + { + var cmd = CreateCommand (query, args); + + if (TimeExecution) { + if (_sw == null) { + _sw = new Stopwatch (); + } + _sw.Reset (); + _sw.Start (); + } + + var r = cmd.ExecuteScalar (); + + if (TimeExecution) { + _sw.Stop (); + _elapsedMilliseconds += _sw.ElapsedMilliseconds; + Tracer?.Invoke (string.Format ("Finished in {0} ms ({1:0.0} s total)", _sw.ElapsedMilliseconds, _elapsedMilliseconds / 1000.0)); + } + + return r; + } + + /// + /// Creates a SQLiteCommand given the command text (SQL) with arguments. Place a '?' + /// in the command text for each of the arguments and then executes that command. + /// It returns each row of the result using the mapping automatically generated for + /// the given type. + /// + /// + /// The fully escaped SQL. + /// + /// + /// Arguments to substitute for the occurences of '?' in the query. + /// + /// + /// An enumerable with one result for each row returned by the query. + /// + public List Query (string query, params object[] args) where T : new() + { + var cmd = CreateCommand (query, args); + return cmd.ExecuteQuery (); + } + + /// + /// Creates a SQLiteCommand given the command text (SQL) with arguments. Place a '?' + /// in the command text for each of the arguments and then executes that command. + /// It returns the first column of each row of the result. + /// + /// + /// The fully escaped SQL. + /// + /// + /// Arguments to substitute for the occurences of '?' in the query. + /// + /// + /// An enumerable with one result for the first column of each row returned by the query. + /// + public List QueryScalars (string query, params object[] args) + { + var cmd = CreateCommand (query, args); + return cmd.ExecuteQueryScalars ().ToList (); + } + + /// + /// Creates a SQLiteCommand given the command text (SQL) with arguments. Place a '?' + /// in the command text for each of the arguments and then executes that command. + /// It returns each row of the result using the mapping automatically generated for + /// the given type. + /// + /// + /// The fully escaped SQL. + /// + /// + /// Arguments to substitute for the occurences of '?' in the query. + /// + /// + /// An enumerable with one result for each row returned by the query. + /// The enumerator (retrieved by calling GetEnumerator() on the result of this method) + /// will call sqlite3_step on each call to MoveNext, so the database + /// connection must remain open for the lifetime of the enumerator. + /// + public IEnumerable DeferredQuery (string query, params object[] args) where T : new() + { + var cmd = CreateCommand (query, args); + return cmd.ExecuteDeferredQuery (); + } + + /// + /// Creates a SQLiteCommand given the command text (SQL) with arguments. Place a '?' + /// in the command text for each of the arguments and then executes that command. + /// It returns each row of the result using the specified mapping. This function is + /// only used by libraries in order to query the database via introspection. It is + /// normally not used. + /// + /// + /// A to use to convert the resulting rows + /// into objects. + /// + /// + /// The fully escaped SQL. + /// + /// + /// Arguments to substitute for the occurences of '?' in the query. + /// + /// + /// An enumerable with one result for each row returned by the query. + /// + public List Query (TableMapping map, string query, params object[] args) + { + var cmd = CreateCommand (query, args); + return cmd.ExecuteQuery (map); + } + + /// + /// Creates a SQLiteCommand given the command text (SQL) with arguments. Place a '?' + /// in the command text for each of the arguments and then executes that command. + /// It returns each row of the result using the specified mapping. This function is + /// only used by libraries in order to query the database via introspection. It is + /// normally not used. + /// + /// + /// A to use to convert the resulting rows + /// into objects. + /// + /// + /// The fully escaped SQL. + /// + /// + /// Arguments to substitute for the occurences of '?' in the query. + /// + /// + /// An enumerable with one result for each row returned by the query. + /// The enumerator (retrieved by calling GetEnumerator() on the result of this method) + /// will call sqlite3_step on each call to MoveNext, so the database + /// connection must remain open for the lifetime of the enumerator. + /// + public IEnumerable DeferredQuery (TableMapping map, string query, params object[] args) + { + var cmd = CreateCommand (query, args); + return cmd.ExecuteDeferredQuery (map); + } + + /// + /// Returns a queryable interface to the table represented by the given type. + /// + /// + /// A queryable object that is able to translate Where, OrderBy, and Take + /// queries into native SQL. + /// + public TableQuery Table () where T : new() + { + return new TableQuery (this); + } + + /// + /// Attempts to retrieve an object with the given primary key from the table + /// associated with the specified type. Use of this method requires that + /// the given type have a designated PrimaryKey (using the PrimaryKeyAttribute). + /// + /// + /// The primary key. + /// + /// + /// The object with the given primary key. Throws a not found exception + /// if the object is not found. + /// + public T Get (object pk) where T : new() + { + var map = GetMapping (typeof (T)); + return Query (map.GetByPrimaryKeySql, pk).First (); + } + + /// + /// Attempts to retrieve an object with the given primary key from the table + /// associated with the specified type. Use of this method requires that + /// the given type have a designated PrimaryKey (using the PrimaryKeyAttribute). + /// + /// + /// The primary key. + /// + /// + /// The TableMapping used to identify the table. + /// + /// + /// The object with the given primary key. Throws a not found exception + /// if the object is not found. + /// + public object Get (object pk, TableMapping map) + { + return Query (map, map.GetByPrimaryKeySql, pk).First (); + } + + /// + /// Attempts to retrieve the first object that matches the predicate from the table + /// associated with the specified type. + /// + /// + /// A predicate for which object to find. + /// + /// + /// The object that matches the given predicate. Throws a not found exception + /// if the object is not found. + /// + public T Get (Expression> predicate) where T : new() + { + return Table ().Where (predicate).First (); + } + + /// + /// Attempts to retrieve an object with the given primary key from the table + /// associated with the specified type. Use of this method requires that + /// the given type have a designated PrimaryKey (using the PrimaryKeyAttribute). + /// + /// + /// The primary key. + /// + /// + /// The object with the given primary key or null + /// if the object is not found. + /// + public T Find (object pk) where T : new() + { + var map = GetMapping (typeof (T)); + return Query (map.GetByPrimaryKeySql, pk).FirstOrDefault (); + } + + /// + /// Attempts to retrieve an object with the given primary key from the table + /// associated with the specified type. Use of this method requires that + /// the given type have a designated PrimaryKey (using the PrimaryKeyAttribute). + /// + /// + /// The primary key. + /// + /// + /// The TableMapping used to identify the table. + /// + /// + /// The object with the given primary key or null + /// if the object is not found. + /// + public object Find (object pk, TableMapping map) + { + return Query (map, map.GetByPrimaryKeySql, pk).FirstOrDefault (); + } + + /// + /// Attempts to retrieve the first object that matches the predicate from the table + /// associated with the specified type. + /// + /// + /// A predicate for which object to find. + /// + /// + /// The object that matches the given predicate or null + /// if the object is not found. + /// + public T Find (Expression> predicate) where T : new() + { + return Table ().Where (predicate).FirstOrDefault (); + } + + /// + /// Attempts to retrieve the first object that matches the query from the table + /// associated with the specified type. + /// + /// + /// The fully escaped SQL. + /// + /// + /// Arguments to substitute for the occurences of '?' in the query. + /// + /// + /// The object that matches the given predicate or null + /// if the object is not found. + /// + public T FindWithQuery (string query, params object[] args) where T : new() + { + return Query (query, args).FirstOrDefault (); + } + + /// + /// Attempts to retrieve the first object that matches the query from the table + /// associated with the specified type. + /// + /// + /// The TableMapping used to identify the table. + /// + /// + /// The fully escaped SQL. + /// + /// + /// Arguments to substitute for the occurences of '?' in the query. + /// + /// + /// The object that matches the given predicate or null + /// if the object is not found. + /// + public object FindWithQuery (TableMapping map, string query, params object[] args) + { + return Query (map, query, args).FirstOrDefault (); + } + + /// + /// Whether has been called and the database is waiting for a . + /// + public bool IsInTransaction { + get { return _transactionDepth > 0; } + } + + /// + /// Begins a new transaction. Call to end the transaction. + /// + /// Throws if a transaction has already begun. + public void BeginTransaction () + { + // The BEGIN command only works if the transaction stack is empty, + // or in other words if there are no pending transactions. + // If the transaction stack is not empty when the BEGIN command is invoked, + // then the command fails with an error. + // Rather than crash with an error, we will just ignore calls to BeginTransaction + // that would result in an error. + if (Interlocked.CompareExchange (ref _transactionDepth, 1, 0) == 0) { + try { + Execute ("begin transaction"); + } + catch (Exception ex) { + var sqlExp = ex as SQLiteException; + if (sqlExp != null) { + // It is recommended that applications respond to the errors listed below + // by explicitly issuing a ROLLBACK command. + // TODO: This rollback failsafe should be localized to all throw sites. + switch (sqlExp.Result) { + case SQLite3.Result.IOError: + case SQLite3.Result.Full: + case SQLite3.Result.Busy: + case SQLite3.Result.NoMem: + case SQLite3.Result.Interrupt: + RollbackTo (null, true); + break; + } + } + else { + // Call decrement and not VolatileWrite in case we've already + // created a transaction point in SaveTransactionPoint since the catch. + Interlocked.Decrement (ref _transactionDepth); + } + + throw; + } + } + else { + // Calling BeginTransaction on an already open transaction is invalid + throw new InvalidOperationException ("Cannot begin a transaction while already in a transaction."); + } + } + + /// + /// Creates a savepoint in the database at the current point in the transaction timeline. + /// Begins a new transaction if one is not in progress. + /// + /// Call to undo transactions since the returned savepoint. + /// Call to commit transactions after the savepoint returned here. + /// Call to end the transaction, committing all changes. + /// + /// A string naming the savepoint. + public string SaveTransactionPoint () + { + int depth = Interlocked.Increment (ref _transactionDepth) - 1; + string retVal = "S" + _rand.Next (short.MaxValue) + "D" + depth; + + try { + Execute ("savepoint " + retVal); + } + catch (Exception ex) { + var sqlExp = ex as SQLiteException; + if (sqlExp != null) { + // It is recommended that applications respond to the errors listed below + // by explicitly issuing a ROLLBACK command. + // TODO: This rollback failsafe should be localized to all throw sites. + switch (sqlExp.Result) { + case SQLite3.Result.IOError: + case SQLite3.Result.Full: + case SQLite3.Result.Busy: + case SQLite3.Result.NoMem: + case SQLite3.Result.Interrupt: + RollbackTo (null, true); + break; + } + } + else { + Interlocked.Decrement (ref _transactionDepth); + } + + throw; + } + + return retVal; + } + + /// + /// Rolls back the transaction that was begun by or . + /// + public void Rollback () + { + RollbackTo (null, false); + } + + /// + /// Rolls back the savepoint created by or SaveTransactionPoint. + /// + /// The name of the savepoint to roll back to, as returned by . If savepoint is null or empty, this method is equivalent to a call to + public void RollbackTo (string savepoint) + { + RollbackTo (savepoint, false); + } + + /// + /// Rolls back the transaction that was begun by . + /// + /// The name of the savepoint to roll back to, as returned by . If savepoint is null or empty, this method is equivalent to a call to + /// true to avoid throwing exceptions, false otherwise + void RollbackTo (string savepoint, bool noThrow) + { + // Rolling back without a TO clause rolls backs all transactions + // and leaves the transaction stack empty. + try { + if (String.IsNullOrEmpty (savepoint)) { + if (Interlocked.Exchange (ref _transactionDepth, 0) > 0) { + Execute ("rollback"); + } + } + else { + DoSavePointExecute (savepoint, "rollback to "); + } + } + catch (SQLiteException) { + if (!noThrow) + throw; + + } + // No need to rollback if there are no transactions open. + } + + /// + /// Releases a savepoint returned from . Releasing a savepoint + /// makes changes since that savepoint permanent if the savepoint began the transaction, + /// or otherwise the changes are permanent pending a call to . + /// + /// The RELEASE command is like a COMMIT for a SAVEPOINT. + /// + /// The name of the savepoint to release. The string should be the result of a call to + public void Release (string savepoint) + { + try { + DoSavePointExecute (savepoint, "release "); + } + catch (SQLiteException ex) { + if (ex.Result == SQLite3.Result.Busy) { + // Force a rollback since most people don't know this function can fail + // Don't call Rollback() since the _transactionDepth is 0 and it won't try + // Calling rollback makes our _transactionDepth variable correct. + // Writes to the database only happen at depth=0, so this failure will only happen then. + try { + Execute ("rollback"); + } + catch { + // rollback can fail in all sorts of wonderful version-dependent ways. Let's just hope for the best + } + } + throw; + } + } + + void DoSavePointExecute (string savepoint, string cmd) + { + // Validate the savepoint + int firstLen = savepoint.IndexOf ('D'); + if (firstLen >= 2 && savepoint.Length > firstLen + 1) { + int depth; + if (Int32.TryParse (savepoint.Substring (firstLen + 1), out depth)) { + // TODO: Mild race here, but inescapable without locking almost everywhere. + if (0 <= depth && depth < _transactionDepth) { +#if NETFX_CORE || USE_SQLITEPCL_RAW || NETCORE + Volatile.Write (ref _transactionDepth, depth); +#elif SILVERLIGHT + _transactionDepth = depth; +#else + Thread.VolatileWrite (ref _transactionDepth, depth); +#endif + Execute (cmd + savepoint); + return; + } + } + } + + throw new ArgumentException ("savePoint is not valid, and should be the result of a call to SaveTransactionPoint.", "savePoint"); + } + + /// + /// Commits the transaction that was begun by . + /// + public void Commit () + { + if (Interlocked.Exchange (ref _transactionDepth, 0) != 0) { + try { + Execute ("commit"); + } + catch { + // Force a rollback since most people don't know this function can fail + // Don't call Rollback() since the _transactionDepth is 0 and it won't try + // Calling rollback makes our _transactionDepth variable correct. + try { + Execute ("rollback"); + } + catch { + // rollback can fail in all sorts of wonderful version-dependent ways. Let's just hope for the best + } + throw; + } + } + // Do nothing on a commit with no open transaction + } + + /// + /// Executes within a (possibly nested) transaction by wrapping it in a SAVEPOINT. If an + /// exception occurs the whole transaction is rolled back, not just the current savepoint. The exception + /// is rethrown. + /// + /// + /// The to perform within a transaction. can contain any number + /// of operations on the connection but should never call or + /// . + /// + public void RunInTransaction (Action action) + { + try { + var savePoint = SaveTransactionPoint (); + action (); + Release (savePoint); + } + catch (Exception) { + Rollback (); + throw; + } + } + + /// + /// Inserts all specified objects. + /// + /// + /// An of the objects to insert. + /// + /// A boolean indicating if the inserts should be wrapped in a transaction. + /// + /// + /// The number of rows added to the table. + /// + public int InsertAll (System.Collections.IEnumerable objects, bool runInTransaction = true) + { + var c = 0; + if (runInTransaction) { + RunInTransaction (() => { + foreach (var r in objects) { + c += Insert (r); + } + }); + } + else { + foreach (var r in objects) { + c += Insert (r); + } + } + return c; + } + + /// + /// Inserts all specified objects. + /// + /// + /// An of the objects to insert. + /// + /// + /// Literal SQL code that gets placed into the command. INSERT {extra} INTO ... + /// + /// + /// A boolean indicating if the inserts should be wrapped in a transaction. + /// + /// + /// The number of rows added to the table. + /// + public int InsertAll (System.Collections.IEnumerable objects, string extra, bool runInTransaction = true) + { + var c = 0; + if (runInTransaction) { + RunInTransaction (() => { + foreach (var r in objects) { + c += Insert (r, extra); + } + }); + } + else { + foreach (var r in objects) { + c += Insert (r, extra); + } + } + return c; + } + + /// + /// Inserts all specified objects. + /// + /// + /// An of the objects to insert. + /// + /// + /// The type of object to insert. + /// + /// + /// A boolean indicating if the inserts should be wrapped in a transaction. + /// + /// + /// The number of rows added to the table. + /// + public int InsertAll (System.Collections.IEnumerable objects, Type objType, bool runInTransaction = true) + { + var c = 0; + if (runInTransaction) { + RunInTransaction (() => { + foreach (var r in objects) { + c += Insert (r, objType); + } + }); + } + else { + foreach (var r in objects) { + c += Insert (r, objType); + } + } + return c; + } + + /// + /// Inserts the given object (and updates its + /// auto incremented primary key if it has one). + /// The return value is the number of rows added to the table. + /// + /// + /// The object to insert. + /// + /// + /// The number of rows added to the table. + /// + public int Insert (object obj) + { + if (obj == null) { + return 0; + } + return Insert (obj, "", Orm.GetType (obj)); + } + + /// + /// Inserts the given object (and updates its + /// auto incremented primary key if it has one). + /// The return value is the number of rows added to the table. + /// If a UNIQUE constraint violation occurs with + /// some pre-existing object, this function deletes + /// the old object. + /// + /// + /// The object to insert. + /// + /// + /// The number of rows modified. + /// + public int InsertOrReplace (object obj) + { + if (obj == null) { + return 0; + } + return Insert (obj, "OR REPLACE", Orm.GetType (obj)); + } + + /// + /// Inserts the given object (and updates its + /// auto incremented primary key if it has one). + /// The return value is the number of rows added to the table. + /// + /// + /// The object to insert. + /// + /// + /// The type of object to insert. + /// + /// + /// The number of rows added to the table. + /// + public int Insert (object obj, Type objType) + { + return Insert (obj, "", objType); + } + + /// + /// Inserts the given object (and updates its + /// auto incremented primary key if it has one). + /// The return value is the number of rows added to the table. + /// If a UNIQUE constraint violation occurs with + /// some pre-existing object, this function deletes + /// the old object. + /// + /// + /// The object to insert. + /// + /// + /// The type of object to insert. + /// + /// + /// The number of rows modified. + /// + public int InsertOrReplace (object obj, Type objType) + { + return Insert (obj, "OR REPLACE", objType); + } + + /// + /// Inserts the given object (and updates its + /// auto incremented primary key if it has one). + /// The return value is the number of rows added to the table. + /// + /// + /// The object to insert. + /// + /// + /// Literal SQL code that gets placed into the command. INSERT {extra} INTO ... + /// + /// + /// The number of rows added to the table. + /// + public int Insert (object obj, string extra) + { + if (obj == null) { + return 0; + } + return Insert (obj, extra, Orm.GetType (obj)); + } + + /// + /// Inserts the given object (and updates its + /// auto incremented primary key if it has one). + /// The return value is the number of rows added to the table. + /// + /// + /// The object to insert. + /// + /// + /// Literal SQL code that gets placed into the command. INSERT {extra} INTO ... + /// + /// + /// The type of object to insert. + /// + /// + /// The number of rows added to the table. + /// + public int Insert (object obj, string extra, Type objType) + { + if (obj == null || objType == null) { + return 0; + } + + var map = GetMapping (objType); + + if (map.PK != null && map.PK.IsAutoGuid) { + if (map.PK.GetValue (obj).Equals (Guid.Empty)) { + map.PK.SetValue (obj, Guid.NewGuid ()); + } + } + + var replacing = string.Compare (extra, "OR REPLACE", StringComparison.OrdinalIgnoreCase) == 0; + + var cols = replacing ? map.InsertOrReplaceColumns : map.InsertColumns; + var vals = new object[cols.Length]; + for (var i = 0; i < vals.Length; i++) { + vals[i] = cols[i].GetValue (obj); + } + + var insertCmd = GetInsertCommand (map, extra); + int count; + + lock (insertCmd) { + // We lock here to protect the prepared statement returned via GetInsertCommand. + // A SQLite prepared statement can be bound for only one operation at a time. + try { + count = insertCmd.ExecuteNonQuery (vals); + } + catch (SQLiteException ex) { + if (SQLite3.ExtendedErrCode (this.Handle) == SQLite3.ExtendedResult.ConstraintNotNull) { + throw NotNullConstraintViolationException.New (ex.Result, ex.Message, map, obj); + } + throw; + } + + if (map.HasAutoIncPK) { + var id = SQLite3.LastInsertRowid (Handle); + map.SetAutoIncPK (obj, id); + } + } + if (count > 0) + OnTableChanged (map, NotifyTableChangedAction.Insert); + + return count; + } + + readonly Dictionary, PreparedSqlLiteInsertCommand> _insertCommandMap = new Dictionary, PreparedSqlLiteInsertCommand> (); + + PreparedSqlLiteInsertCommand GetInsertCommand (TableMapping map, string extra) + { + PreparedSqlLiteInsertCommand prepCmd; + + var key = Tuple.Create (map.MappedType.FullName, extra); + + lock (_insertCommandMap) { + if (_insertCommandMap.TryGetValue (key, out prepCmd)) { + return prepCmd; + } + } + + prepCmd = CreateInsertCommand (map, extra); + + lock (_insertCommandMap) { + if (_insertCommandMap.TryGetValue (key, out var existing)) { + prepCmd.Dispose (); + return existing; + } + + _insertCommandMap.Add (key, prepCmd); + } + + return prepCmd; + } + + PreparedSqlLiteInsertCommand CreateInsertCommand (TableMapping map, string extra) + { + var cols = map.InsertColumns; + string insertSql; + if (cols.Length == 0 && map.Columns.Length == 1 && map.Columns[0].IsAutoInc) { + insertSql = string.Format ("insert {1} into \"{0}\" default values", map.TableName, extra); + } + else { + var replacing = string.Compare (extra, "OR REPLACE", StringComparison.OrdinalIgnoreCase) == 0; + + if (replacing) { + cols = map.InsertOrReplaceColumns; + } + + insertSql = string.Format ("insert {3} into \"{0}\"({1}) values ({2})", map.TableName, + string.Join (",", (from c in cols + select "\"" + c.Name + "\"").ToArray ()), + string.Join (",", (from c in cols + select "?").ToArray ()), extra); + + } + + var insertCommand = new PreparedSqlLiteInsertCommand (this, insertSql); + return insertCommand; + } + + /// + /// Updates all of the columns of a table using the specified object + /// except for its primary key. + /// The object is required to have a primary key. + /// + /// + /// The object to update. It must have a primary key designated using the PrimaryKeyAttribute. + /// + /// + /// The number of rows updated. + /// + public int Update (object obj) + { + if (obj == null) { + return 0; + } + return Update (obj, Orm.GetType (obj)); + } + + /// + /// Updates all of the columns of a table using the specified object + /// except for its primary key. + /// The object is required to have a primary key. + /// + /// + /// The object to update. It must have a primary key designated using the PrimaryKeyAttribute. + /// + /// + /// The type of object to insert. + /// + /// + /// The number of rows updated. + /// + public int Update (object obj, Type objType) + { + int rowsAffected = 0; + if (obj == null || objType == null) { + return 0; + } + + var map = GetMapping (objType); + + var pk = map.PK; + + if (pk == null) { + throw new NotSupportedException ("Cannot update " + map.TableName + ": it has no PK"); + } + + var cols = from p in map.Columns + where p != pk + select p; + var vals = from c in cols + select c.GetValue (obj); + var ps = new List (vals); + if (ps.Count == 0) { + // There is a PK but no accompanying data, + // so reset the PK to make the UPDATE work. + cols = map.Columns; + vals = from c in cols + select c.GetValue (obj); + ps = new List (vals); + } + ps.Add (pk.GetValue (obj)); + var q = string.Format ("update \"{0}\" set {1} where \"{2}\" = ? ", map.TableName, string.Join (",", (from c in cols + select "\"" + c.Name + "\" = ? ").ToArray ()), pk.Name); + + try { + rowsAffected = Execute (q, ps.ToArray ()); + } + catch (SQLiteException ex) { + + if (ex.Result == SQLite3.Result.Constraint && SQLite3.ExtendedErrCode (this.Handle) == SQLite3.ExtendedResult.ConstraintNotNull) { + throw NotNullConstraintViolationException.New (ex, map, obj); + } + + throw ex; + } + + if (rowsAffected > 0) + OnTableChanged (map, NotifyTableChangedAction.Update); + + return rowsAffected; + } + + /// + /// Updates all specified objects. + /// + /// + /// An of the objects to insert. + /// + /// + /// A boolean indicating if the inserts should be wrapped in a transaction + /// + /// + /// The number of rows modified. + /// + public int UpdateAll (System.Collections.IEnumerable objects, bool runInTransaction = true) + { + var c = 0; + if (runInTransaction) { + RunInTransaction (() => { + foreach (var r in objects) { + c += Update (r); + } + }); + } + else { + foreach (var r in objects) { + c += Update (r); + } + } + return c; + } + + /// + /// Deletes the given object from the database using its primary key. + /// + /// + /// The object to delete. It must have a primary key designated using the PrimaryKeyAttribute. + /// + /// + /// The number of rows deleted. + /// + public int Delete (object objectToDelete) + { + var map = GetMapping (Orm.GetType (objectToDelete)); + var pk = map.PK; + if (pk == null) { + throw new NotSupportedException ("Cannot delete " + map.TableName + ": it has no PK"); + } + var q = string.Format ("delete from \"{0}\" where \"{1}\" = ?", map.TableName, pk.Name); + var count = Execute (q, pk.GetValue (objectToDelete)); + if (count > 0) + OnTableChanged (map, NotifyTableChangedAction.Delete); + return count; + } + + /// + /// Deletes the object with the specified primary key. + /// + /// + /// The primary key of the object to delete. + /// + /// + /// The number of objects deleted. + /// + /// + /// The type of object. + /// + public int Delete (object primaryKey) + { + return Delete (primaryKey, GetMapping (typeof (T))); + } + + /// + /// Deletes the object with the specified primary key. + /// + /// + /// The primary key of the object to delete. + /// + /// + /// The TableMapping used to identify the table. + /// + /// + /// The number of objects deleted. + /// + public int Delete (object primaryKey, TableMapping map) + { + var pk = map.PK; + if (pk == null) { + throw new NotSupportedException ("Cannot delete " + map.TableName + ": it has no PK"); + } + var q = string.Format ("delete from \"{0}\" where \"{1}\" = ?", map.TableName, pk.Name); + var count = Execute (q, primaryKey); + if (count > 0) + OnTableChanged (map, NotifyTableChangedAction.Delete); + return count; + } + + /// + /// Deletes all the objects from the specified table. + /// WARNING WARNING: Let me repeat. It deletes ALL the objects from the + /// specified table. Do you really want to do that? + /// + /// + /// The number of objects deleted. + /// + /// + /// The type of objects to delete. + /// + public int DeleteAll () + { + var map = GetMapping (typeof (T)); + return DeleteAll (map); + } + + /// + /// Deletes all the objects from the specified table. + /// WARNING WARNING: Let me repeat. It deletes ALL the objects from the + /// specified table. Do you really want to do that? + /// + /// + /// The TableMapping used to identify the table. + /// + /// + /// The number of objects deleted. + /// + public int DeleteAll (TableMapping map) + { + var query = string.Format ("delete from \"{0}\"", map.TableName); + var count = Execute (query); + if (count > 0) + OnTableChanged (map, NotifyTableChangedAction.Delete); + return count; + } + + /// + /// Backup the entire database to the specified path. + /// + /// Path to backup file. + /// The name of the database to backup (usually "main"). + public void Backup (string destinationDatabasePath, string databaseName = "main") + { + // Open the destination + var r = SQLite3.Open (destinationDatabasePath, out var destHandle); + if (r != SQLite3.Result.OK) { + throw SQLiteException.New (r, "Failed to open destination database"); + } + + // Init the backup + var backup = SQLite3.BackupInit (destHandle, databaseName, Handle, databaseName); + if (backup == NullBackupHandle) { + SQLite3.Close (destHandle); + throw new Exception ("Failed to create backup"); + } + + // Perform it + SQLite3.BackupStep (backup, -1); + SQLite3.BackupFinish (backup); + + // Check for errors + r = SQLite3.GetResult (destHandle); + string msg = ""; + if (r != SQLite3.Result.OK) { + msg = SQLite3.GetErrmsg (destHandle); + } + + // Close everything and report errors + SQLite3.Close (destHandle); + if (r != SQLite3.Result.OK) { + throw SQLiteException.New (r, msg); + } + } + + ~SQLiteConnection () + { + Dispose (false); + } + + public void Dispose () + { + Dispose (true); + GC.SuppressFinalize (this); + } + + public void Close () + { + Dispose (true); + } + + protected virtual void Dispose (bool disposing) + { + var useClose2 = LibVersionNumber >= 3007014; + + if (_open && Handle != NullHandle) { + try { + if (disposing) { + lock (_insertCommandMap) { + foreach (var sqlInsertCommand in _insertCommandMap.Values) { + sqlInsertCommand.Dispose (); + } + _insertCommandMap.Clear (); + } + + var r = useClose2 ? SQLite3.Close2 (Handle) : SQLite3.Close (Handle); + if (r != SQLite3.Result.OK) { + string msg = SQLite3.GetErrmsg (Handle); + throw SQLiteException.New (r, msg); + } + } + else { + var r = useClose2 ? SQLite3.Close2 (Handle) : SQLite3.Close (Handle); + } + } + finally { + Handle = NullHandle; + _open = false; + } + } + } + + void OnTableChanged (TableMapping table, NotifyTableChangedAction action) + { + var ev = TableChanged; + if (ev != null) + ev (this, new NotifyTableChangedEventArgs (table, action)); + } + + public event EventHandler TableChanged; + } + + public class NotifyTableChangedEventArgs : EventArgs + { + public TableMapping Table { get; private set; } + public NotifyTableChangedAction Action { get; private set; } + + public NotifyTableChangedEventArgs (TableMapping table, NotifyTableChangedAction action) + { + Table = table; + Action = action; + } + } + + public enum NotifyTableChangedAction + { + Insert, + Update, + Delete, + } + + /// + /// Represents a parsed connection string. + /// + public class SQLiteConnectionString + { + const string DateTimeSqliteDefaultFormat = "yyyy'-'MM'-'dd'T'HH':'mm':'ss'.'fff"; + + public string UniqueKey { get; } + public string DatabasePath { get; } + public bool StoreDateTimeAsTicks { get; } + public bool StoreTimeSpanAsTicks { get; } + public string DateTimeStringFormat { get; } + public System.Globalization.DateTimeStyles DateTimeStyle { get; } + public object Key { get; } + public SQLiteOpenFlags OpenFlags { get; } + public Action PreKeyAction { get; } + public Action PostKeyAction { get; } + public string VfsName { get; } + +#if NETFX_CORE + static readonly string MetroStyleDataPath = Windows.Storage.ApplicationData.Current.LocalFolder.Path; + + public static readonly string[] InMemoryDbPaths = new[] { - Instance = new SQLite(); + ":memory:", + "file::memory:" + }; + + public static bool IsInMemoryPath(string databasePath) + { + return InMemoryDbPaths.Any(i => i.Equals(databasePath, StringComparison.OrdinalIgnoreCase)); } - public SQLite() - { - m_ConnectionLock = new ReaderWriterLockSlim(); +#endif - var dataSource = Program.ConfigLocation; - m_Connection = new SQLiteConnection($"Data Source=\"{dataSource}\";Version=3;PRAGMA locking_mode=NORMAL;PRAGMA busy_timeout=5000", true); - } + /// + /// Constructs a new SQLiteConnectionString with all the data needed to open an SQLiteConnection. + /// + /// + /// Specifies the path to the database file. + /// + /// + /// Specifies whether to store DateTime properties as ticks (true) or strings (false). You + /// absolutely do want to store them as Ticks in all new projects. The value of false is + /// only here for backwards compatibility. There is a *significant* speed advantage, with no + /// down sides, when setting storeDateTimeAsTicks = true. + /// If you use DateTimeOffset properties, it will be always stored as ticks regardingless + /// the storeDateTimeAsTicks parameter. + /// + public SQLiteConnectionString (string databasePath, bool storeDateTimeAsTicks = true) + : this (databasePath, SQLiteOpenFlags.Create | SQLiteOpenFlags.ReadWrite, storeDateTimeAsTicks) + { + } - internal void Init() - { - m_Connection.Open(); - } + /// + /// Constructs a new SQLiteConnectionString with all the data needed to open an SQLiteConnection. + /// + /// + /// Specifies the path to the database file. + /// + /// + /// Specifies whether to store DateTime properties as ticks (true) or strings (false). You + /// absolutely do want to store them as Ticks in all new projects. The value of false is + /// only here for backwards compatibility. There is a *significant* speed advantage, with no + /// down sides, when setting storeDateTimeAsTicks = true. + /// If you use DateTimeOffset properties, it will be always stored as ticks regardingless + /// the storeDateTimeAsTicks parameter. + /// + /// + /// Specifies the encryption key to use on the database. Should be a string or a byte[]. + /// + /// + /// Executes prior to setting key for SQLCipher databases + /// + /// + /// Executes after setting key for SQLCipher databases + /// + /// + /// Specifies the Virtual File System to use on the database. + /// + public SQLiteConnectionString (string databasePath, bool storeDateTimeAsTicks, object key = null, Action preKeyAction = null, Action postKeyAction = null, string vfsName = null) + : this (databasePath, SQLiteOpenFlags.Create | SQLiteOpenFlags.ReadWrite, storeDateTimeAsTicks, key, preKeyAction, postKeyAction, vfsName) + { + } - internal void Exit() - { - m_Connection.Close(); - m_Connection.Dispose(); - } + /// + /// Constructs a new SQLiteConnectionString with all the data needed to open an SQLiteConnection. + /// + /// + /// Specifies the path to the database file. + /// + /// + /// Flags controlling how the connection should be opened. + /// + /// + /// Specifies whether to store DateTime properties as ticks (true) or strings (false). You + /// absolutely do want to store them as Ticks in all new projects. The value of false is + /// only here for backwards compatibility. There is a *significant* speed advantage, with no + /// down sides, when setting storeDateTimeAsTicks = true. + /// If you use DateTimeOffset properties, it will be always stored as ticks regardingless + /// the storeDateTimeAsTicks parameter. + /// + /// + /// Specifies the encryption key to use on the database. Should be a string or a byte[]. + /// + /// + /// Executes prior to setting key for SQLCipher databases + /// + /// + /// Executes after setting key for SQLCipher databases + /// + /// + /// Specifies the Virtual File System to use on the database. + /// + /// + /// Specifies the format to use when storing DateTime properties as strings. + /// + /// + /// Specifies whether to store TimeSpan properties as ticks (true) or strings (false). You + /// absolutely do want to store them as Ticks in all new projects. The value of false is + /// only here for backwards compatibility. There is a *significant* speed advantage, with no + /// down sides, when setting storeTimeSpanAsTicks = true. + /// + public SQLiteConnectionString (string databasePath, SQLiteOpenFlags openFlags, bool storeDateTimeAsTicks, object key = null, Action preKeyAction = null, Action postKeyAction = null, string vfsName = null, string dateTimeStringFormat = DateTimeSqliteDefaultFormat, bool storeTimeSpanAsTicks = true) + { + if (key != null && !((key is byte[]) || (key is string))) + throw new ArgumentException ("Encryption keys must be strings or byte arrays", nameof (key)); - public void Execute(IJavascriptCallback callback, string sql, IDictionary args = null) - { - try - { - m_ConnectionLock.EnterReadLock(); - try - { - using (var command = new SQLiteCommand(sql, m_Connection)) - { - if (args != null) - { - foreach (var arg in args) - { - command.Parameters.Add(new SQLiteParameter(arg.Key, arg.Value)); - } - } - using (var reader = command.ExecuteReader()) - { - while (reader.Read() == true) - { - var values = new object[reader.FieldCount]; - reader.GetValues(values); - if (callback.CanExecute == true) - { - callback.ExecuteAsync(null, values); - } - } - } - } - if (callback.CanExecute == true) - { - callback.ExecuteAsync(null, null); - } - } - finally - { - m_ConnectionLock.ExitReadLock(); - } - } - catch (Exception e) - { - if (callback.CanExecute == true) - { - callback.ExecuteAsync(e.Message, null); - } - } + UniqueKey = string.Format ("{0}_{1:X8}", databasePath, (uint)openFlags); + StoreDateTimeAsTicks = storeDateTimeAsTicks; + StoreTimeSpanAsTicks = storeTimeSpanAsTicks; + DateTimeStringFormat = dateTimeStringFormat; + DateTimeStyle = "o".Equals (DateTimeStringFormat, StringComparison.OrdinalIgnoreCase) || "r".Equals (DateTimeStringFormat, StringComparison.OrdinalIgnoreCase) ? System.Globalization.DateTimeStyles.RoundtripKind : System.Globalization.DateTimeStyles.None; + Key = key; + PreKeyAction = preKeyAction; + PostKeyAction = postKeyAction; + OpenFlags = openFlags; + VfsName = vfsName; - callback.Dispose(); - } +#if NETFX_CORE + DatabasePath = IsInMemoryPath(databasePath) + ? databasePath + : System.IO.Path.Combine(MetroStyleDataPath, databasePath); - public void Execute(Action callback, string sql, IDictionary args = null) - { - m_ConnectionLock.EnterReadLock(); - try - { - using (var command = new SQLiteCommand(sql, m_Connection)) - { - if (args != null) - { - foreach (var arg in args) - { - command.Parameters.Add(new SQLiteParameter(arg.Key, arg.Value)); - } - } - using (var reader = command.ExecuteReader()) - { - while (reader.Read() == true) - { - var values = new object[reader.FieldCount]; - reader.GetValues(values); - callback(values); - } - } - } - } - catch - { - } - finally - { - m_ConnectionLock.ExitReadLock(); - } - } +#else + DatabasePath = databasePath; +#endif + } + } - public int ExecuteNonQuery(string sql, IDictionary args = null) - { - int result = -1; + [AttributeUsage (AttributeTargets.Class)] + public class TableAttribute : Attribute + { + public string Name { get; set; } - m_ConnectionLock.EnterWriteLock(); - try - { - using (var command = new SQLiteCommand(sql, m_Connection)) - { - if (args != null) - { - foreach (var arg in args) - { - command.Parameters.Add(new SQLiteParameter(arg.Key, arg.Value)); - } - } - result = command.ExecuteNonQuery(); - } - } - finally - { - m_ConnectionLock.ExitWriteLock(); - } + /// + /// Flag whether to create the table without rowid (see https://sqlite.org/withoutrowid.html) + /// + /// The default is false so that sqlite adds an implicit rowid to every table created. + /// + public bool WithoutRowId { get; set; } - return result; - } - } + public TableAttribute (string name) + { + Name = name; + } + } + + [AttributeUsage (AttributeTargets.Property)] + public class ColumnAttribute : Attribute + { + public string Name { get; set; } + + public ColumnAttribute (string name) + { + Name = name; + } + } + + [AttributeUsage (AttributeTargets.Property)] + public class PrimaryKeyAttribute : Attribute + { + } + + [AttributeUsage (AttributeTargets.Property)] + public class AutoIncrementAttribute : Attribute + { + } + + [AttributeUsage (AttributeTargets.Property)] + public class IndexedAttribute : Attribute + { + public string Name { get; set; } + public int Order { get; set; } + public virtual bool Unique { get; set; } + + public IndexedAttribute () + { + } + + public IndexedAttribute (string name, int order) + { + Name = name; + Order = order; + } + } + + [AttributeUsage (AttributeTargets.Property)] + public class IgnoreAttribute : Attribute + { + } + + [AttributeUsage (AttributeTargets.Property)] + public class UniqueAttribute : IndexedAttribute + { + public override bool Unique { + get { return true; } + set { /* throw? */ } + } + } + + [AttributeUsage (AttributeTargets.Property)] + public class MaxLengthAttribute : Attribute + { + public int Value { get; private set; } + + public MaxLengthAttribute (int length) + { + Value = length; + } + } + + public sealed class PreserveAttribute : System.Attribute + { + public bool AllMembers; + public bool Conditional; + } + + /// + /// Select the collating sequence to use on a column. + /// "BINARY", "NOCASE", and "RTRIM" are supported. + /// "BINARY" is the default. + /// + [AttributeUsage (AttributeTargets.Property)] + public class CollationAttribute : Attribute + { + public string Value { get; private set; } + + public CollationAttribute (string collation) + { + Value = collation; + } + } + + [AttributeUsage (AttributeTargets.Property)] + public class NotNullAttribute : Attribute + { + } + + [AttributeUsage (AttributeTargets.Enum)] + public class StoreAsTextAttribute : Attribute + { + } + + public class TableMapping + { + public Type MappedType { get; private set; } + + public string TableName { get; private set; } + + public bool WithoutRowId { get; private set; } + + public Column[] Columns { get; private set; } + + public Column PK { get; private set; } + + public string GetByPrimaryKeySql { get; private set; } + + public CreateFlags CreateFlags { get; private set; } + + internal MapMethod Method { get; private set; } = MapMethod.ByName; + + readonly Column _autoPk; + readonly Column[] _insertColumns; + readonly Column[] _insertOrReplaceColumns; + + public TableMapping (Type type, CreateFlags createFlags = CreateFlags.None) + { + MappedType = type; + CreateFlags = createFlags; + + var typeInfo = type.GetTypeInfo (); +#if ENABLE_IL2CPP + var tableAttr = typeInfo.GetCustomAttribute (); +#else + var tableAttr = + typeInfo.CustomAttributes + .Where (x => x.AttributeType == typeof (TableAttribute)) + .Select (x => (TableAttribute)Orm.InflateAttribute (x)) + .FirstOrDefault (); +#endif + + TableName = (tableAttr != null && !string.IsNullOrEmpty (tableAttr.Name)) ? tableAttr.Name : MappedType.Name; + WithoutRowId = tableAttr != null ? tableAttr.WithoutRowId : false; + + var members = GetPublicMembers(type); + var cols = new List(members.Count); + foreach(var m in members) + { + var ignore = m.IsDefined(typeof(IgnoreAttribute), true); + if(!ignore) + cols.Add(new Column(m, createFlags)); + } + Columns = cols.ToArray (); + foreach (var c in Columns) { + if (c.IsAutoInc && c.IsPK) { + _autoPk = c; + } + if (c.IsPK) { + PK = c; + } + } + + HasAutoIncPK = _autoPk != null; + + if (PK != null) { + GetByPrimaryKeySql = string.Format ("select * from \"{0}\" where \"{1}\" = ?", TableName, PK.Name); + } + else { + // People should not be calling Get/Find without a PK + GetByPrimaryKeySql = string.Format ("select * from \"{0}\" limit 1", TableName); + } + + _insertColumns = Columns.Where (c => !c.IsAutoInc).ToArray (); + _insertOrReplaceColumns = Columns.ToArray (); + } + + private IReadOnlyCollection GetPublicMembers(Type type) + { + if(type.Name.StartsWith("ValueTuple`")) + return GetFieldsFromValueTuple(type); + + var members = new List(); + var memberNames = new HashSet(); + var newMembers = new List(); + do + { + var ti = type.GetTypeInfo(); + newMembers.Clear(); + + newMembers.AddRange( + from p in ti.DeclaredProperties + where !memberNames.Contains(p.Name) && + p.CanRead && p.CanWrite && + p.GetMethod != null && p.SetMethod != null && + p.GetMethod.IsPublic && p.SetMethod.IsPublic && + !p.GetMethod.IsStatic && !p.SetMethod.IsStatic + select p); + + members.AddRange(newMembers); + foreach(var m in newMembers) + memberNames.Add(m.Name); + + type = ti.BaseType; + } + while(type != typeof(object)); + + return members; + } + + private IReadOnlyCollection GetFieldsFromValueTuple(Type type) + { + Method = MapMethod.ByPosition; + var fields = type.GetFields(); + + // https://docs.microsoft.com/en-us/dotnet/api/system.valuetuple-8.rest + if(fields.Length >= 8) + throw new NotSupportedException("ValueTuple with more than 7 members not supported due to nesting; see https://docs.microsoft.com/en-us/dotnet/api/system.valuetuple-8.rest"); + + return fields; + } + + public bool HasAutoIncPK { get; private set; } + + public void SetAutoIncPK (object obj, long id) + { + if (_autoPk != null) { + _autoPk.SetValue (obj, Convert.ChangeType (id, _autoPk.ColumnType, null)); + } + } + + public Column[] InsertColumns { + get { + return _insertColumns; + } + } + + public Column[] InsertOrReplaceColumns { + get { + return _insertOrReplaceColumns; + } + } + + public Column FindColumnWithPropertyName (string propertyName) + { + var exact = Columns.FirstOrDefault (c => c.PropertyName == propertyName); + return exact; + } + + public Column FindColumn (string columnName) + { + if(Method != MapMethod.ByName) + throw new InvalidOperationException($"This {nameof(TableMapping)} is not mapped by name, but {Method}."); + + var exact = Columns.FirstOrDefault (c => c.Name.ToLower () == columnName.ToLower ()); + return exact; + } + + public class Column + { + MemberInfo _member; + + public string Name { get; private set; } + + public PropertyInfo PropertyInfo => _member as PropertyInfo; + + public string PropertyName { get { return _member.Name; } } + + public Type ColumnType { get; private set; } + + public string Collation { get; private set; } + + public bool IsAutoInc { get; private set; } + public bool IsAutoGuid { get; private set; } + + public bool IsPK { get; private set; } + + public IEnumerable Indices { get; set; } + + public bool IsNullable { get; private set; } + + public int? MaxStringLength { get; private set; } + + public bool StoreAsText { get; private set; } + + public Column (MemberInfo member, CreateFlags createFlags = CreateFlags.None) + { + _member = member; + var memberType = GetMemberType(member); + + var colAttr = member.CustomAttributes.FirstOrDefault (x => x.AttributeType == typeof (ColumnAttribute)); +#if ENABLE_IL2CPP + var ca = member.GetCustomAttribute(typeof(ColumnAttribute)) as ColumnAttribute; + Name = ca == null ? member.Name : ca.Name; +#else + Name = (colAttr != null && colAttr.ConstructorArguments.Count > 0) ? + colAttr.ConstructorArguments[0].Value?.ToString () : + member.Name; +#endif + //If this type is Nullable then Nullable.GetUnderlyingType returns the T, otherwise it returns null, so get the actual type instead + ColumnType = Nullable.GetUnderlyingType (memberType) ?? memberType; + Collation = Orm.Collation (member); + + IsPK = Orm.IsPK (member) || + (((createFlags & CreateFlags.ImplicitPK) == CreateFlags.ImplicitPK) && + string.Compare (member.Name, Orm.ImplicitPkName, StringComparison.OrdinalIgnoreCase) == 0); + + var isAuto = Orm.IsAutoInc (member) || (IsPK && ((createFlags & CreateFlags.AutoIncPK) == CreateFlags.AutoIncPK)); + IsAutoGuid = isAuto && ColumnType == typeof (Guid); + IsAutoInc = isAuto && !IsAutoGuid; + + Indices = Orm.GetIndices (member); + if (!Indices.Any () + && !IsPK + && ((createFlags & CreateFlags.ImplicitIndex) == CreateFlags.ImplicitIndex) + && Name.EndsWith (Orm.ImplicitIndexSuffix, StringComparison.OrdinalIgnoreCase) + ) { + Indices = new IndexedAttribute[] { new IndexedAttribute () }; + } + IsNullable = !(IsPK || Orm.IsMarkedNotNull (member)); + MaxStringLength = Orm.MaxStringLength (member); + + StoreAsText = memberType.GetTypeInfo ().CustomAttributes.Any (x => x.AttributeType == typeof (StoreAsTextAttribute)); + } + + public Column (PropertyInfo member, CreateFlags createFlags = CreateFlags.None) + : this((MemberInfo)member, createFlags) + { } + + public void SetValue (object obj, object val) + { + if(_member is PropertyInfo propy) + { + if (val != null && ColumnType.GetTypeInfo ().IsEnum) + propy.SetValue (obj, Enum.ToObject (ColumnType, val)); + else + propy.SetValue (obj, val); + } + else if(_member is FieldInfo field) + { + if (val != null && ColumnType.GetTypeInfo ().IsEnum) + field.SetValue (obj, Enum.ToObject (ColumnType, val)); + else + field.SetValue (obj, val); + } + else + throw new InvalidProgramException("unreachable condition"); + } + + public object GetValue (object obj) + { + if(_member is PropertyInfo propy) + return propy.GetValue(obj); + else if(_member is FieldInfo field) + return field.GetValue(obj); + else + throw new InvalidProgramException("unreachable condition"); + } + + private static Type GetMemberType(MemberInfo m) + { + switch(m.MemberType) + { + case MemberTypes.Property: return ((PropertyInfo)m).PropertyType; + case MemberTypes.Field: return ((FieldInfo)m).FieldType; + default: throw new InvalidProgramException($"{nameof(TableMapping)} supports properties or fields only."); + } + } + } + + internal enum MapMethod + { + ByName, + ByPosition + } + } + + class EnumCacheInfo + { + public EnumCacheInfo (Type type) + { + var typeInfo = type.GetTypeInfo (); + + IsEnum = typeInfo.IsEnum; + + if (IsEnum) { + StoreAsText = typeInfo.CustomAttributes.Any (x => x.AttributeType == typeof (StoreAsTextAttribute)); + + if (StoreAsText) { + EnumValues = new Dictionary (); + foreach (object e in Enum.GetValues (type)) { + EnumValues[Convert.ToInt32 (e)] = e.ToString (); + } + } + } + } + + public bool IsEnum { get; private set; } + + public bool StoreAsText { get; private set; } + + public Dictionary EnumValues { get; private set; } + } + + static class EnumCache + { + static readonly Dictionary Cache = new Dictionary (); + + public static EnumCacheInfo GetInfo () + { + return GetInfo (typeof (T)); + } + + public static EnumCacheInfo GetInfo (Type type) + { + lock (Cache) { + EnumCacheInfo info = null; + if (!Cache.TryGetValue (type, out info)) { + info = new EnumCacheInfo (type); + Cache[type] = info; + } + + return info; + } + } + } + + public static class Orm + { + public const int DefaultMaxStringLength = 140; + public const string ImplicitPkName = "Id"; + public const string ImplicitIndexSuffix = "Id"; + + public static Type GetType (object obj) + { + if (obj == null) + return typeof (object); + var rt = obj as IReflectableType; + if (rt != null) + return rt.GetTypeInfo ().AsType (); + return obj.GetType (); + } + + public static string SqlDecl (TableMapping.Column p, bool storeDateTimeAsTicks, bool storeTimeSpanAsTicks) + { + string decl = "\"" + p.Name + "\" " + SqlType (p, storeDateTimeAsTicks, storeTimeSpanAsTicks) + " "; + + if (p.IsPK) { + decl += "primary key "; + } + if (p.IsAutoInc) { + decl += "autoincrement "; + } + if (!p.IsNullable) { + decl += "not null "; + } + if (!string.IsNullOrEmpty (p.Collation)) { + decl += "collate " + p.Collation + " "; + } + + return decl; + } + + public static string SqlType (TableMapping.Column p, bool storeDateTimeAsTicks, bool storeTimeSpanAsTicks) + { + var clrType = p.ColumnType; + if (clrType == typeof (Boolean) || clrType == typeof (Byte) || clrType == typeof (UInt16) || clrType == typeof (SByte) || clrType == typeof (Int16) || clrType == typeof (Int32) || clrType == typeof (UInt32) || clrType == typeof (Int64)) { + return "integer"; + } + else if (clrType == typeof (Single) || clrType == typeof (Double) || clrType == typeof (Decimal)) { + return "float"; + } + else if (clrType == typeof (String) || clrType == typeof (StringBuilder) || clrType == typeof (Uri) || clrType == typeof (UriBuilder)) { + int? len = p.MaxStringLength; + + if (len.HasValue) + return "varchar(" + len.Value + ")"; + + return "varchar"; + } + else if (clrType == typeof (TimeSpan)) { + return storeTimeSpanAsTicks ? "bigint" : "time"; + } + else if (clrType == typeof (DateTime)) { + return storeDateTimeAsTicks ? "bigint" : "datetime"; + } + else if (clrType == typeof (DateTimeOffset)) { + return "bigint"; + } + else if (clrType.GetTypeInfo ().IsEnum) { + if (p.StoreAsText) + return "varchar"; + else + return "integer"; + } + else if (clrType == typeof (byte[])) { + return "blob"; + } + else if (clrType == typeof (Guid)) { + return "varchar(36)"; + } + else { + throw new NotSupportedException ("Don't know about " + clrType); + } + } + + public static bool IsPK (MemberInfo p) + { + return p.CustomAttributes.Any (x => x.AttributeType == typeof (PrimaryKeyAttribute)); + } + + public static string Collation (MemberInfo p) + { +#if ENABLE_IL2CPP + return (p.GetCustomAttribute ()?.Value) ?? ""; +#else + return + (p.CustomAttributes + .Where (x => typeof (CollationAttribute) == x.AttributeType) + .Select (x => { + var args = x.ConstructorArguments; + return args.Count > 0 ? ((args[0].Value as string) ?? "") : ""; + }) + .FirstOrDefault ()) ?? ""; +#endif + } + + public static bool IsAutoInc (MemberInfo p) + { + return p.CustomAttributes.Any (x => x.AttributeType == typeof (AutoIncrementAttribute)); + } + + public static FieldInfo GetField (TypeInfo t, string name) + { + var f = t.GetDeclaredField (name); + if (f != null) + return f; + return GetField (t.BaseType.GetTypeInfo (), name); + } + + public static PropertyInfo GetProperty (TypeInfo t, string name) + { + var f = t.GetDeclaredProperty (name); + if (f != null) + return f; + return GetProperty (t.BaseType.GetTypeInfo (), name); + } + + public static object InflateAttribute (CustomAttributeData x) + { + var atype = x.AttributeType; + var typeInfo = atype.GetTypeInfo (); +#if ENABLE_IL2CPP + var r = Activator.CreateInstance (x.AttributeType); +#else + var args = x.ConstructorArguments.Select (a => a.Value).ToArray (); + var r = Activator.CreateInstance (x.AttributeType, args); + foreach (var arg in x.NamedArguments) { + if (arg.IsField) { + GetField (typeInfo, arg.MemberName).SetValue (r, arg.TypedValue.Value); + } + else { + GetProperty (typeInfo, arg.MemberName).SetValue (r, arg.TypedValue.Value); + } + } +#endif + return r; + } + + public static IEnumerable GetIndices (MemberInfo p) + { +#if ENABLE_IL2CPP + return p.GetCustomAttributes (); +#else + var indexedInfo = typeof (IndexedAttribute).GetTypeInfo (); + return + p.CustomAttributes + .Where (x => indexedInfo.IsAssignableFrom (x.AttributeType.GetTypeInfo ())) + .Select (x => (IndexedAttribute)InflateAttribute (x)); +#endif + } + + public static int? MaxStringLength (MemberInfo p) + { +#if ENABLE_IL2CPP + return p.GetCustomAttribute ()?.Value; +#else + var attr = p.CustomAttributes.FirstOrDefault (x => x.AttributeType == typeof (MaxLengthAttribute)); + if (attr != null) { + var attrv = (MaxLengthAttribute)InflateAttribute (attr); + return attrv.Value; + } + return null; +#endif + } + + public static int? MaxStringLength (PropertyInfo p) => MaxStringLength((MemberInfo)p); + + public static bool IsMarkedNotNull (MemberInfo p) + { + return p.CustomAttributes.Any (x => x.AttributeType == typeof (NotNullAttribute)); + } + } + + public partial class SQLiteCommand + { + SQLiteConnection _conn; + private List _bindings; + + public string CommandText { get; set; } + + public SQLiteCommand (SQLiteConnection conn) + { + _conn = conn; + _bindings = new List (); + CommandText = ""; + } + + public int ExecuteNonQuery () + { + if (_conn.Trace) { + _conn.Tracer?.Invoke ("Executing: " + this); + } + + var r = SQLite3.Result.OK; + var stmt = Prepare (); + r = SQLite3.Step (stmt); + Finalize (stmt); + if (r == SQLite3.Result.Done) { + int rowsAffected = SQLite3.Changes (_conn.Handle); + return rowsAffected; + } + else if (r == SQLite3.Result.Error) { + string msg = SQLite3.GetErrmsg (_conn.Handle); + throw SQLiteException.New (r, msg); + } + else if (r == SQLite3.Result.Constraint) { + if (SQLite3.ExtendedErrCode (_conn.Handle) == SQLite3.ExtendedResult.ConstraintNotNull) { + throw NotNullConstraintViolationException.New (r, SQLite3.GetErrmsg (_conn.Handle)); + } + } + + throw SQLiteException.New (r, SQLite3.GetErrmsg (_conn.Handle)); + } + + public IEnumerable ExecuteDeferredQuery () + { + return ExecuteDeferredQuery (_conn.GetMapping (typeof (T))); + } + + public List ExecuteQuery () + { + return ExecuteDeferredQuery (_conn.GetMapping (typeof (T))).ToList (); + } + + public List ExecuteQuery (TableMapping map) + { + return ExecuteDeferredQuery (map).ToList (); + } + + /// + /// Invoked every time an instance is loaded from the database. + /// + /// + /// The newly created object. + /// + /// + /// This can be overridden in combination with the + /// method to hook into the life-cycle of objects. + /// + protected virtual void OnInstanceCreated (object obj) + { + // Can be overridden. + } + + public IEnumerable ExecuteDeferredQuery (TableMapping map) + { + if (_conn.Trace) { + _conn.Tracer?.Invoke ("Executing Query: " + this); + } + + var stmt = Prepare (); + try { + var cols = new TableMapping.Column[SQLite3.ColumnCount (stmt)]; + var fastColumnSetters = new Action[SQLite3.ColumnCount (stmt)]; + + if (map.Method == TableMapping.MapMethod.ByPosition) + { + Array.Copy(map.Columns, cols, Math.Min(cols.Length, map.Columns.Length)); + } + else if (map.Method == TableMapping.MapMethod.ByName) { + MethodInfo getSetter = null; + if (typeof(T) != map.MappedType) { + getSetter = typeof(FastColumnSetter) + .GetMethod (nameof(FastColumnSetter.GetFastSetter), + BindingFlags.NonPublic | BindingFlags.Static).MakeGenericMethod (map.MappedType); + } + + for (int i = 0; i < cols.Length; i++) { + var name = SQLite3.ColumnName16 (stmt, i); + cols[i] = map.FindColumn (name); + if (cols[i] != null) + if (getSetter != null) { + fastColumnSetters[i] = (Action)getSetter.Invoke(null, new object[]{ _conn, cols[i]}); + } + else { + fastColumnSetters[i] = FastColumnSetter.GetFastSetter(_conn, cols[i]); + } + } + } + + while (SQLite3.Step (stmt) == SQLite3.Result.Row) { + var obj = Activator.CreateInstance (map.MappedType); + for (int i = 0; i < cols.Length; i++) { + if (cols[i] == null) + continue; + + if (fastColumnSetters[i] != null) { + fastColumnSetters[i].Invoke (obj, stmt, i); + } + else { + var colType = SQLite3.ColumnType (stmt, i); + var val = ReadCol (stmt, i, colType, cols[i].ColumnType); + cols[i].SetValue (obj, val); + } + } + OnInstanceCreated (obj); + yield return (T)obj; + } + } + finally { + SQLite3.Finalize (stmt); + } + } + + public T ExecuteScalar () + { + if (_conn.Trace) { + _conn.Tracer?.Invoke ("Executing Query: " + this); + } + + T val = default (T); + + var stmt = Prepare (); + + try { + var r = SQLite3.Step (stmt); + if (r == SQLite3.Result.Row) { + var colType = SQLite3.ColumnType (stmt, 0); + var colval = ReadCol (stmt, 0, colType, typeof (T)); + if (colval != null) { + val = (T)colval; + } + } + else if (r == SQLite3.Result.Done) { + } + else { + throw SQLiteException.New (r, SQLite3.GetErrmsg (_conn.Handle)); + } + } + finally { + Finalize (stmt); + } + + return val; + } + + public IEnumerable ExecuteQueryScalars () + { + if (_conn.Trace) { + _conn.Tracer?.Invoke ("Executing Query: " + this); + } + var stmt = Prepare (); + try { + if (SQLite3.ColumnCount (stmt) < 1) { + throw new InvalidOperationException ("QueryScalars should return at least one column"); + } + while (SQLite3.Step (stmt) == SQLite3.Result.Row) { + var colType = SQLite3.ColumnType (stmt, 0); + var val = ReadCol (stmt, 0, colType, typeof (T)); + if (val == null) { + yield return default (T); + } + else { + yield return (T)val; + } + } + } + finally { + Finalize (stmt); + } + } + + public void Bind (string name, object val) + { + _bindings.Add (new Binding { + Name = name, + Value = val + }); + } + + public void Bind (object val) + { + Bind (null, val); + } + + public override string ToString () + { + var parts = new string[1 + _bindings.Count]; + parts[0] = CommandText; + var i = 1; + foreach (var b in _bindings) { + parts[i] = string.Format (" {0}: {1}", i - 1, b.Value); + i++; + } + return string.Join (Environment.NewLine, parts); + } + + Sqlite3Statement Prepare () + { + var stmt = SQLite3.Prepare2 (_conn.Handle, CommandText); + BindAll (stmt); + return stmt; + } + + void Finalize (Sqlite3Statement stmt) + { + SQLite3.Finalize (stmt); + } + + void BindAll (Sqlite3Statement stmt) + { + int nextIdx = 1; + foreach (var b in _bindings) { + if (b.Name != null) { + b.Index = SQLite3.BindParameterIndex (stmt, b.Name); + } + else { + b.Index = nextIdx++; + } + + BindParameter (stmt, b.Index, b.Value, _conn.StoreDateTimeAsTicks, _conn.DateTimeStringFormat, _conn.StoreTimeSpanAsTicks); + } + } + + static IntPtr NegativePointer = new IntPtr (-1); + + internal static void BindParameter (Sqlite3Statement stmt, int index, object value, bool storeDateTimeAsTicks, string dateTimeStringFormat, bool storeTimeSpanAsTicks) + { + if (value == null) { + SQLite3.BindNull (stmt, index); + } + else { + if (value is Int32) { + SQLite3.BindInt (stmt, index, (int)value); + } + else if (value is String) { + SQLite3.BindText (stmt, index, (string)value, -1, NegativePointer); + } + else if (value is Byte || value is UInt16 || value is SByte || value is Int16) { + SQLite3.BindInt (stmt, index, Convert.ToInt32 (value)); + } + else if (value is Boolean) { + SQLite3.BindInt (stmt, index, (bool)value ? 1 : 0); + } + else if (value is UInt32 || value is Int64) { + SQLite3.BindInt64 (stmt, index, Convert.ToInt64 (value)); + } + else if (value is Single || value is Double || value is Decimal) { + SQLite3.BindDouble (stmt, index, Convert.ToDouble (value)); + } + else if (value is TimeSpan) { + if (storeTimeSpanAsTicks) { + SQLite3.BindInt64 (stmt, index, ((TimeSpan)value).Ticks); + } + else { + SQLite3.BindText (stmt, index, ((TimeSpan)value).ToString (), -1, NegativePointer); + } + } + else if (value is DateTime) { + if (storeDateTimeAsTicks) { + SQLite3.BindInt64 (stmt, index, ((DateTime)value).Ticks); + } + else { + SQLite3.BindText (stmt, index, ((DateTime)value).ToString (dateTimeStringFormat, System.Globalization.CultureInfo.InvariantCulture), -1, NegativePointer); + } + } + else if (value is DateTimeOffset) { + SQLite3.BindInt64 (stmt, index, ((DateTimeOffset)value).UtcTicks); + } + else if (value is byte[]) { + SQLite3.BindBlob (stmt, index, (byte[])value, ((byte[])value).Length, NegativePointer); + } + else if (value is Guid) { + SQLite3.BindText (stmt, index, ((Guid)value).ToString (), 72, NegativePointer); + } + else if (value is Uri) { + SQLite3.BindText (stmt, index, ((Uri)value).ToString (), -1, NegativePointer); + } + else if (value is StringBuilder) { + SQLite3.BindText (stmt, index, ((StringBuilder)value).ToString (), -1, NegativePointer); + } + else if (value is UriBuilder) { + SQLite3.BindText (stmt, index, ((UriBuilder)value).ToString (), -1, NegativePointer); + } + else { + // Now we could possibly get an enum, retrieve cached info + var valueType = value.GetType (); + var enumInfo = EnumCache.GetInfo (valueType); + if (enumInfo.IsEnum) { + var enumIntValue = Convert.ToInt32 (value); + if (enumInfo.StoreAsText) + SQLite3.BindText (stmt, index, enumInfo.EnumValues[enumIntValue], -1, NegativePointer); + else + SQLite3.BindInt (stmt, index, enumIntValue); + } + else { + throw new NotSupportedException ("Cannot store type: " + Orm.GetType (value)); + } + } + } + } + + class Binding + { + public string Name { get; set; } + + public object Value { get; set; } + + public int Index { get; set; } + } + + object ReadCol (Sqlite3Statement stmt, int index, SQLite3.ColType type, Type clrType) + { + if (type == SQLite3.ColType.Null) { + return null; + } + else { + var clrTypeInfo = clrType.GetTypeInfo (); + if (clrTypeInfo.IsGenericType && clrTypeInfo.GetGenericTypeDefinition () == typeof (Nullable<>)) { + clrType = clrTypeInfo.GenericTypeArguments[0]; + clrTypeInfo = clrType.GetTypeInfo (); + } + + if (clrType == typeof (String)) { + return SQLite3.ColumnString (stmt, index); + } + else if (clrType == typeof (Int32)) { + return (int)SQLite3.ColumnInt (stmt, index); + } + else if (clrType == typeof (Boolean)) { + return SQLite3.ColumnInt (stmt, index) == 1; + } + else if (clrType == typeof (double)) { + return SQLite3.ColumnDouble (stmt, index); + } + else if (clrType == typeof (float)) { + return (float)SQLite3.ColumnDouble (stmt, index); + } + else if (clrType == typeof (TimeSpan)) { + if (_conn.StoreTimeSpanAsTicks) { + return new TimeSpan (SQLite3.ColumnInt64 (stmt, index)); + } + else { + var text = SQLite3.ColumnString (stmt, index); + TimeSpan resultTime; + if (!TimeSpan.TryParseExact (text, "c", System.Globalization.CultureInfo.InvariantCulture, System.Globalization.TimeSpanStyles.None, out resultTime)) { + resultTime = TimeSpan.Parse (text); + } + return resultTime; + } + } + else if (clrType == typeof (DateTime)) { + if (_conn.StoreDateTimeAsTicks) { + return new DateTime (SQLite3.ColumnInt64 (stmt, index)); + } + else { + var text = SQLite3.ColumnString (stmt, index); + DateTime resultDate; + if (!DateTime.TryParseExact (text, _conn.DateTimeStringFormat, System.Globalization.CultureInfo.InvariantCulture, _conn.DateTimeStyle, out resultDate)) { + resultDate = DateTime.Parse (text); + } + return resultDate; + } + } + else if (clrType == typeof (DateTimeOffset)) { + return new DateTimeOffset (SQLite3.ColumnInt64 (stmt, index), TimeSpan.Zero); + } + else if (clrTypeInfo.IsEnum) { + if (type == SQLite3.ColType.Text) { + var value = SQLite3.ColumnString (stmt, index); + return Enum.Parse (clrType, value.ToString (), true); + } + else + return SQLite3.ColumnInt (stmt, index); + } + else if (clrType == typeof (Int64)) { + return SQLite3.ColumnInt64 (stmt, index); + } + else if (clrType == typeof (UInt32)) { + return (uint)SQLite3.ColumnInt64 (stmt, index); + } + else if (clrType == typeof (decimal)) { + return (decimal)SQLite3.ColumnDouble (stmt, index); + } + else if (clrType == typeof (Byte)) { + return (byte)SQLite3.ColumnInt (stmt, index); + } + else if (clrType == typeof (UInt16)) { + return (ushort)SQLite3.ColumnInt (stmt, index); + } + else if (clrType == typeof (Int16)) { + return (short)SQLite3.ColumnInt (stmt, index); + } + else if (clrType == typeof (sbyte)) { + return (sbyte)SQLite3.ColumnInt (stmt, index); + } + else if (clrType == typeof (byte[])) { + return SQLite3.ColumnByteArray (stmt, index); + } + else if (clrType == typeof (Guid)) { + var text = SQLite3.ColumnString (stmt, index); + return new Guid (text); + } + else if (clrType == typeof (Uri)) { + var text = SQLite3.ColumnString (stmt, index); + return new Uri (text); + } + else if (clrType == typeof (StringBuilder)) { + var text = SQLite3.ColumnString (stmt, index); + return new StringBuilder (text); + } + else if (clrType == typeof (UriBuilder)) { + var text = SQLite3.ColumnString (stmt, index); + return new UriBuilder (text); + } + else { + throw new NotSupportedException ("Don't know how to read " + clrType); + } + } + } + } + + internal class FastColumnSetter + { + /// + /// Creates a delegate that can be used to quickly set object members from query columns. + /// + /// Note that this frontloads the slow reflection-based type checking for columns to only happen once at the beginning of a query, + /// and then afterwards each row of the query can invoke the delegate returned by this function to get much better performance (up to 10x speed boost, depending on query size and platform). + /// + /// The type of the destination object that the query will read into + /// The active connection. Note that this is primarily needed in order to read preferences regarding how certain data types (such as TimeSpan / DateTime) should be encoded in the database. + /// The table mapping used to map the statement column to a member of the destination object type + /// + /// A delegate for fast-setting of object members from statement columns. + /// + /// If no fast setter is available for the requested column (enums in particular cause headache), then this function returns null. + /// + internal static Action GetFastSetter (SQLiteConnection conn, TableMapping.Column column) + { + Action fastSetter = null; + + Type clrType = column.PropertyInfo.PropertyType; + + var clrTypeInfo = clrType.GetTypeInfo (); + if (clrTypeInfo.IsGenericType && clrTypeInfo.GetGenericTypeDefinition () == typeof (Nullable<>)) { + clrType = clrTypeInfo.GenericTypeArguments[0]; + clrTypeInfo = clrType.GetTypeInfo (); + } + + if (clrType == typeof (String)) { + fastSetter = CreateTypedSetterDelegate (column, (stmt, index) => { + return SQLite3.ColumnString (stmt, index); + }); + } + else if (clrType == typeof (Int32)) { + fastSetter = CreateNullableTypedSetterDelegate (column, (stmt, index)=>{ + return SQLite3.ColumnInt (stmt, index); + }); + } + else if (clrType == typeof (Boolean)) { + fastSetter = CreateNullableTypedSetterDelegate (column, (stmt, index) => { + return SQLite3.ColumnInt (stmt, index) == 1; + }); + } + else if (clrType == typeof (double)) { + fastSetter = CreateNullableTypedSetterDelegate (column, (stmt, index) => { + return SQLite3.ColumnDouble (stmt, index); + }); + } + else if (clrType == typeof (float)) { + fastSetter = CreateNullableTypedSetterDelegate (column, (stmt, index) => { + return (float) SQLite3.ColumnDouble (stmt, index); + }); + } + else if (clrType == typeof (TimeSpan)) { + if (conn.StoreTimeSpanAsTicks) { + fastSetter = CreateNullableTypedSetterDelegate (column, (stmt, index) => { + return new TimeSpan (SQLite3.ColumnInt64 (stmt, index)); + }); + } + else { + fastSetter = CreateNullableTypedSetterDelegate (column, (stmt, index) => { + var text = SQLite3.ColumnString (stmt, index); + TimeSpan resultTime; + if (!TimeSpan.TryParseExact (text, "c", System.Globalization.CultureInfo.InvariantCulture, System.Globalization.TimeSpanStyles.None, out resultTime)) { + resultTime = TimeSpan.Parse (text); + } + return resultTime; + }); + } + } + else if (clrType == typeof (DateTime)) { + if (conn.StoreDateTimeAsTicks) { + fastSetter = CreateNullableTypedSetterDelegate (column, (stmt, index) => { + return new DateTime (SQLite3.ColumnInt64 (stmt, index)); + }); + } + else { + fastSetter = CreateNullableTypedSetterDelegate (column, (stmt, index) => { + var text = SQLite3.ColumnString (stmt, index); + DateTime resultDate; + if (!DateTime.TryParseExact (text, conn.DateTimeStringFormat, System.Globalization.CultureInfo.InvariantCulture, conn.DateTimeStyle, out resultDate)) { + resultDate = DateTime.Parse (text); + } + return resultDate; + }); + } + } + else if (clrType == typeof (DateTimeOffset)) { + fastSetter = CreateNullableTypedSetterDelegate (column, (stmt, index) => { + return new DateTimeOffset (SQLite3.ColumnInt64 (stmt, index), TimeSpan.Zero); + }); + } + else if (clrTypeInfo.IsEnum) { + // NOTE: Not sure of a good way (if any?) to do a strongly-typed fast setter like this for enumerated types -- for now, return null and column sets will revert back to the safe (but slow) Reflection-based method of column prop.Set() + } + else if (clrType == typeof (Int64)) { + fastSetter = CreateNullableTypedSetterDelegate (column, (stmt, index) => { + return SQLite3.ColumnInt64 (stmt, index); + }); + } + else if (clrType == typeof (UInt32)) { + fastSetter = CreateNullableTypedSetterDelegate (column, (stmt, index) => { + return (uint)SQLite3.ColumnInt64 (stmt, index); + }); + } + else if (clrType == typeof (decimal)) { + fastSetter = CreateNullableTypedSetterDelegate (column, (stmt, index) => { + return (decimal)SQLite3.ColumnDouble (stmt, index); + }); + } + else if (clrType == typeof (Byte)) { + fastSetter = CreateNullableTypedSetterDelegate (column, (stmt, index) => { + return (byte)SQLite3.ColumnInt (stmt, index); + }); + } + else if (clrType == typeof (UInt16)) { + fastSetter = CreateNullableTypedSetterDelegate (column, (stmt, index) => { + return (ushort)SQLite3.ColumnInt (stmt, index); + }); + } + else if (clrType == typeof (Int16)) { + fastSetter = CreateNullableTypedSetterDelegate (column, (stmt, index) => { + return (short)SQLite3.ColumnInt (stmt, index); + }); + } + else if (clrType == typeof (sbyte)) { + fastSetter = CreateNullableTypedSetterDelegate (column, (stmt, index) => { + return (sbyte)SQLite3.ColumnInt (stmt, index); + }); + } + else if (clrType == typeof (byte[])) { + fastSetter = CreateTypedSetterDelegate (column, (stmt, index) => { + return SQLite3.ColumnByteArray (stmt, index); + }); + } + else if (clrType == typeof (Guid)) { + fastSetter = CreateNullableTypedSetterDelegate (column, (stmt, index) => { + var text = SQLite3.ColumnString (stmt, index); + return new Guid (text); + }); + } + else if (clrType == typeof (Uri)) { + fastSetter = CreateTypedSetterDelegate (column, (stmt, index) => { + var text = SQLite3.ColumnString (stmt, index); + return new Uri (text); + }); + } + else if (clrType == typeof (StringBuilder)) { + fastSetter = CreateTypedSetterDelegate (column, (stmt, index) => { + var text = SQLite3.ColumnString (stmt, index); + return new StringBuilder (text); + }); + } + else if (clrType == typeof (UriBuilder)) { + fastSetter = CreateTypedSetterDelegate (column, (stmt, index) => { + var text = SQLite3.ColumnString (stmt, index); + return new UriBuilder (text); + }); + } + else { + // NOTE: Will fall back to the slow setter method in the event that we are unable to create a fast setter delegate for a particular column type + } + return fastSetter; + } + + /// + /// This creates a strongly typed delegate that will permit fast setting of column values given a Sqlite3Statement and a column index. + /// + /// Note that this is identical to CreateTypedSetterDelegate(), but has an extra check to see if it should create a nullable version of the delegate. + /// + /// The type of the object whose member column is being set + /// The CLR type of the member in the object which corresponds to the given SQLite columnn + /// The column mapping that identifies the target member of the destination object + /// A lambda that can be used to retrieve the column value at query-time + /// A strongly-typed delegate + private static Action CreateNullableTypedSetterDelegate (TableMapping.Column column, Func getColumnValue) where ColumnMemberType : struct + { + var clrTypeInfo = column.PropertyInfo.PropertyType.GetTypeInfo(); + bool isNullable = false; + + if (clrTypeInfo.IsGenericType && clrTypeInfo.GetGenericTypeDefinition () == typeof (Nullable<>)) { + isNullable = true; + } + + if (isNullable) { + var setProperty = (Action)Delegate.CreateDelegate ( + typeof (Action), null, + column.PropertyInfo.GetSetMethod ()); + + return (o, stmt, i) => { + var colType = SQLite3.ColumnType (stmt, i); + if (colType != SQLite3.ColType.Null) + setProperty.Invoke ((ObjectType)o, getColumnValue.Invoke (stmt, i)); + }; + } + + return CreateTypedSetterDelegate (column, getColumnValue); + } + + /// + /// This creates a strongly typed delegate that will permit fast setting of column values given a Sqlite3Statement and a column index. + /// + /// The type of the object whose member column is being set + /// The CLR type of the member in the object which corresponds to the given SQLite columnn + /// The column mapping that identifies the target member of the destination object + /// A lambda that can be used to retrieve the column value at query-time + /// A strongly-typed delegate + private static Action CreateTypedSetterDelegate (TableMapping.Column column, Func getColumnValue) + { + var setProperty = (Action)Delegate.CreateDelegate ( + typeof (Action), null, + column.PropertyInfo.GetSetMethod ()); + + return (o, stmt, i) => { + var colType = SQLite3.ColumnType (stmt, i); + if (colType != SQLite3.ColType.Null) + setProperty.Invoke ((ObjectType)o, getColumnValue.Invoke (stmt, i)); + }; + } + } + + /// + /// Since the insert never changed, we only need to prepare once. + /// + class PreparedSqlLiteInsertCommand : IDisposable + { + bool Initialized; + + SQLiteConnection Connection; + + string CommandText; + + Sqlite3Statement Statement; + static readonly Sqlite3Statement NullStatement = default (Sqlite3Statement); + + public PreparedSqlLiteInsertCommand (SQLiteConnection conn, string commandText) + { + Connection = conn; + CommandText = commandText; + } + + public int ExecuteNonQuery (object[] source) + { + if (Initialized && Statement == NullStatement) { + throw new ObjectDisposedException (nameof (PreparedSqlLiteInsertCommand)); + } + + if (Connection.Trace) { + Connection.Tracer?.Invoke ("Executing: " + CommandText); + } + + var r = SQLite3.Result.OK; + + if (!Initialized) { + Statement = SQLite3.Prepare2 (Connection.Handle, CommandText); + Initialized = true; + } + + //bind the values. + if (source != null) { + for (int i = 0; i < source.Length; i++) { + SQLiteCommand.BindParameter (Statement, i + 1, source[i], Connection.StoreDateTimeAsTicks, Connection.DateTimeStringFormat, Connection.StoreTimeSpanAsTicks); + } + } + r = SQLite3.Step (Statement); + + if (r == SQLite3.Result.Done) { + int rowsAffected = SQLite3.Changes (Connection.Handle); + SQLite3.Reset (Statement); + return rowsAffected; + } + else if (r == SQLite3.Result.Error) { + string msg = SQLite3.GetErrmsg (Connection.Handle); + SQLite3.Reset (Statement); + throw SQLiteException.New (r, msg); + } + else if (r == SQLite3.Result.Constraint && SQLite3.ExtendedErrCode (Connection.Handle) == SQLite3.ExtendedResult.ConstraintNotNull) { + SQLite3.Reset (Statement); + throw NotNullConstraintViolationException.New (r, SQLite3.GetErrmsg (Connection.Handle)); + } + else { + SQLite3.Reset (Statement); + throw SQLiteException.New (r, SQLite3.GetErrmsg (Connection.Handle)); + } + } + + public void Dispose () + { + Dispose (true); + GC.SuppressFinalize (this); + } + + void Dispose (bool disposing) + { + var s = Statement; + Statement = NullStatement; + Connection = null; + if (s != NullStatement) { + SQLite3.Finalize (s); + } + } + + ~PreparedSqlLiteInsertCommand () + { + Dispose (false); + } + } + + public enum CreateTableResult + { + Created, + Migrated, + } + + public class CreateTablesResult + { + public Dictionary Results { get; private set; } + + public CreateTablesResult () + { + Results = new Dictionary (); + } + } + + public abstract class BaseTableQuery + { + protected class Ordering + { + public string ColumnName { get; set; } + public bool Ascending { get; set; } + } + } + + public class TableQuery : BaseTableQuery, IEnumerable + { + public SQLiteConnection Connection { get; private set; } + + public TableMapping Table { get; private set; } + + Expression _where; + List _orderBys; + int? _limit; + int? _offset; + + BaseTableQuery _joinInner; + Expression _joinInnerKeySelector; + BaseTableQuery _joinOuter; + Expression _joinOuterKeySelector; + Expression _joinSelector; + + Expression _selector; + + TableQuery (SQLiteConnection conn, TableMapping table) + { + Connection = conn; + Table = table; + } + + public TableQuery (SQLiteConnection conn) + { + Connection = conn; + Table = Connection.GetMapping (typeof (T)); + } + + public TableQuery Clone () + { + var q = new TableQuery (Connection, Table); + q._where = _where; + q._deferred = _deferred; + if (_orderBys != null) { + q._orderBys = new List (_orderBys); + } + q._limit = _limit; + q._offset = _offset; + q._joinInner = _joinInner; + q._joinInnerKeySelector = _joinInnerKeySelector; + q._joinOuter = _joinOuter; + q._joinOuterKeySelector = _joinOuterKeySelector; + q._joinSelector = _joinSelector; + q._selector = _selector; + return q; + } + + /// + /// Filters the query based on a predicate. + /// + public TableQuery Where (Expression> predExpr) + { + if (predExpr.NodeType == ExpressionType.Lambda) { + var lambda = (LambdaExpression)predExpr; + var pred = lambda.Body; + var q = Clone (); + q.AddWhere (pred); + return q; + } + else { + throw new NotSupportedException ("Must be a predicate"); + } + } + + /// + /// Delete all the rows that match this query. + /// + public int Delete () + { + return Delete (null); + } + + /// + /// Delete all the rows that match this query and the given predicate. + /// + public int Delete (Expression> predExpr) + { + if (_limit.HasValue || _offset.HasValue) + throw new InvalidOperationException ("Cannot delete with limits or offsets"); + + if (_where == null && predExpr == null) + throw new InvalidOperationException ("No condition specified"); + + var pred = _where; + + if (predExpr != null && predExpr.NodeType == ExpressionType.Lambda) { + var lambda = (LambdaExpression)predExpr; + pred = pred != null ? Expression.AndAlso (pred, lambda.Body) : lambda.Body; + } + + var args = new List (); + var cmdText = "delete from \"" + Table.TableName + "\""; + var w = CompileExpr (pred, args); + cmdText += " where " + w.CommandText; + + var command = Connection.CreateCommand (cmdText, args.ToArray ()); + + int result = command.ExecuteNonQuery (); + return result; + } + + /// + /// Yields a given number of elements from the query and then skips the remainder. + /// + public TableQuery Take (int n) + { + var q = Clone (); + q._limit = n; + return q; + } + + /// + /// Skips a given number of elements from the query and then yields the remainder. + /// + public TableQuery Skip (int n) + { + var q = Clone (); + q._offset = n; + return q; + } + + /// + /// Returns the element at a given index + /// + public T ElementAt (int index) + { + return Skip (index).Take (1).First (); + } + + bool _deferred; + public TableQuery Deferred () + { + var q = Clone (); + q._deferred = true; + return q; + } + + /// + /// Order the query results according to a key. + /// + public TableQuery OrderBy (Expression> orderExpr) + { + return AddOrderBy (orderExpr, true); + } + + /// + /// Order the query results according to a key. + /// + public TableQuery OrderByDescending (Expression> orderExpr) + { + return AddOrderBy (orderExpr, false); + } + + /// + /// Order the query results according to a key. + /// + public TableQuery ThenBy (Expression> orderExpr) + { + return AddOrderBy (orderExpr, true); + } + + /// + /// Order the query results according to a key. + /// + public TableQuery ThenByDescending (Expression> orderExpr) + { + return AddOrderBy (orderExpr, false); + } + + TableQuery AddOrderBy (Expression> orderExpr, bool asc) + { + if (orderExpr.NodeType == ExpressionType.Lambda) { + var lambda = (LambdaExpression)orderExpr; + + MemberExpression mem = null; + + var unary = lambda.Body as UnaryExpression; + if (unary != null && unary.NodeType == ExpressionType.Convert) { + mem = unary.Operand as MemberExpression; + } + else { + mem = lambda.Body as MemberExpression; + } + + if (mem != null && (mem.Expression.NodeType == ExpressionType.Parameter)) { + var q = Clone (); + if (q._orderBys == null) { + q._orderBys = new List (); + } + q._orderBys.Add (new Ordering { + ColumnName = Table.FindColumnWithPropertyName (mem.Member.Name).Name, + Ascending = asc + }); + return q; + } + else { + throw new NotSupportedException ("Order By does not support: " + orderExpr); + } + } + else { + throw new NotSupportedException ("Must be a predicate"); + } + } + + private void AddWhere (Expression pred) + { + if (_where == null) { + _where = pred; + } + else { + _where = Expression.AndAlso (_where, pred); + } + } + + ///// + ///// Performs an inner join of two queries based on matching keys extracted from the elements. + ///// + //public TableQuery Join ( + // TableQuery inner, + // Expression> outerKeySelector, + // Expression> innerKeySelector, + // Expression> resultSelector) + //{ + // var q = new TableQuery (Connection, Connection.GetMapping (typeof (TResult))) { + // _joinOuter = this, + // _joinOuterKeySelector = outerKeySelector, + // _joinInner = inner, + // _joinInnerKeySelector = innerKeySelector, + // _joinSelector = resultSelector, + // }; + // return q; + //} + + // Not needed until Joins are supported + // Keeping this commented out forces the default Linq to objects processor to run + //public TableQuery Select (Expression> selector) + //{ + // var q = Clone (); + // q._selector = selector; + // return q; + //} + + private SQLiteCommand GenerateCommand (string selectionList) + { + if (_joinInner != null && _joinOuter != null) { + throw new NotSupportedException ("Joins are not supported."); + } + else { + var cmdText = "select " + selectionList + " from \"" + Table.TableName + "\""; + var args = new List (); + if (_where != null) { + var w = CompileExpr (_where, args); + cmdText += " where " + w.CommandText; + } + if ((_orderBys != null) && (_orderBys.Count > 0)) { + var t = string.Join (", ", _orderBys.Select (o => "\"" + o.ColumnName + "\"" + (o.Ascending ? "" : " desc")).ToArray ()); + cmdText += " order by " + t; + } + if (_limit.HasValue) { + cmdText += " limit " + _limit.Value; + } + if (_offset.HasValue) { + if (!_limit.HasValue) { + cmdText += " limit -1 "; + } + cmdText += " offset " + _offset.Value; + } + return Connection.CreateCommand (cmdText, args.ToArray ()); + } + } + + class CompileResult + { + public string CommandText { get; set; } + + public object Value { get; set; } + } + + private CompileResult CompileExpr (Expression expr, List queryArgs) + { + if (expr == null) { + throw new NotSupportedException ("Expression is NULL"); + } + else if (expr is BinaryExpression) { + var bin = (BinaryExpression)expr; + + // VB turns 'x=="foo"' into 'CompareString(x,"foo",true/false)==0', so we need to unwrap it + // http://blogs.msdn.com/b/vbteam/archive/2007/09/18/vb-expression-trees-string-comparisons.aspx + if (bin.Left.NodeType == ExpressionType.Call) { + var call = (MethodCallExpression)bin.Left; + if (call.Method.DeclaringType.FullName == "Microsoft.VisualBasic.CompilerServices.Operators" + && call.Method.Name == "CompareString") + bin = Expression.MakeBinary (bin.NodeType, call.Arguments[0], call.Arguments[1]); + } + + + var leftr = CompileExpr (bin.Left, queryArgs); + var rightr = CompileExpr (bin.Right, queryArgs); + + //If either side is a parameter and is null, then handle the other side specially (for "is null"/"is not null") + string text; + if (leftr.CommandText == "?" && leftr.Value == null) + text = CompileNullBinaryExpression (bin, rightr); + else if (rightr.CommandText == "?" && rightr.Value == null) + text = CompileNullBinaryExpression (bin, leftr); + else + text = "(" + leftr.CommandText + " " + GetSqlName (bin) + " " + rightr.CommandText + ")"; + return new CompileResult { CommandText = text }; + } + else if (expr.NodeType == ExpressionType.Not) { + var operandExpr = ((UnaryExpression)expr).Operand; + var opr = CompileExpr (operandExpr, queryArgs); + object val = opr.Value; + if (val is bool) + val = !((bool)val); + return new CompileResult { + CommandText = "NOT(" + opr.CommandText + ")", + Value = val + }; + } + else if (expr.NodeType == ExpressionType.Call) { + + var call = (MethodCallExpression)expr; + var args = new CompileResult[call.Arguments.Count]; + var obj = call.Object != null ? CompileExpr (call.Object, queryArgs) : null; + + for (var i = 0; i < args.Length; i++) { + args[i] = CompileExpr (call.Arguments[i], queryArgs); + } + + var sqlCall = ""; + + if (call.Method.Name == "Like" && args.Length == 2) { + sqlCall = "(" + args[0].CommandText + " like " + args[1].CommandText + ")"; + } + else if (call.Method.Name == "Contains" && args.Length == 2) { + sqlCall = "(" + args[1].CommandText + " in " + args[0].CommandText + ")"; + } + else if (call.Method.Name == "Contains" && args.Length == 1) { + if (call.Object != null && call.Object.Type == typeof (string)) { + sqlCall = "( instr(" + obj.CommandText + "," + args[0].CommandText + ") >0 )"; + } + else { + sqlCall = "(" + args[0].CommandText + " in " + obj.CommandText + ")"; + } + } + else if (call.Method.Name == "StartsWith" && args.Length >= 1) { + var startsWithCmpOp = StringComparison.CurrentCulture; + if (args.Length == 2) { + startsWithCmpOp = (StringComparison)args[1].Value; + } + switch (startsWithCmpOp) { + case StringComparison.Ordinal: + case StringComparison.CurrentCulture: + sqlCall = "( substr(" + obj.CommandText + ", 1, " + args[0].Value.ToString ().Length + ") = " + args[0].CommandText + ")"; + break; + case StringComparison.OrdinalIgnoreCase: + case StringComparison.CurrentCultureIgnoreCase: + sqlCall = "(" + obj.CommandText + " like (" + args[0].CommandText + " || '%'))"; + break; + } + + } + else if (call.Method.Name == "EndsWith" && args.Length >= 1) { + var endsWithCmpOp = StringComparison.CurrentCulture; + if (args.Length == 2) { + endsWithCmpOp = (StringComparison)args[1].Value; + } + switch (endsWithCmpOp) { + case StringComparison.Ordinal: + case StringComparison.CurrentCulture: + sqlCall = "( substr(" + obj.CommandText + ", length(" + obj.CommandText + ") - " + args[0].Value.ToString ().Length + "+1, " + args[0].Value.ToString ().Length + ") = " + args[0].CommandText + ")"; + break; + case StringComparison.OrdinalIgnoreCase: + case StringComparison.CurrentCultureIgnoreCase: + sqlCall = "(" + obj.CommandText + " like ('%' || " + args[0].CommandText + "))"; + break; + } + } + else if (call.Method.Name == "Equals" && args.Length == 1) { + sqlCall = "(" + obj.CommandText + " = (" + args[0].CommandText + "))"; + } + else if (call.Method.Name == "ToLower") { + sqlCall = "(lower(" + obj.CommandText + "))"; + } + else if (call.Method.Name == "ToUpper") { + sqlCall = "(upper(" + obj.CommandText + "))"; + } + else if (call.Method.Name == "Replace" && args.Length == 2) { + sqlCall = "(replace(" + obj.CommandText + "," + args[0].CommandText + "," + args[1].CommandText + "))"; + } + else if (call.Method.Name == "IsNullOrEmpty" && args.Length == 1) { + sqlCall = "(" + args[0].CommandText + " is null or" + args[0].CommandText + " ='' )"; + } + else { + sqlCall = call.Method.Name.ToLower () + "(" + string.Join (",", args.Select (a => a.CommandText).ToArray ()) + ")"; + } + return new CompileResult { CommandText = sqlCall }; + + } + else if (expr.NodeType == ExpressionType.Constant) { + var c = (ConstantExpression)expr; + queryArgs.Add (c.Value); + return new CompileResult { + CommandText = "?", + Value = c.Value + }; + } + else if (expr.NodeType == ExpressionType.Convert) { + var u = (UnaryExpression)expr; + var ty = u.Type; + var valr = CompileExpr (u.Operand, queryArgs); + return new CompileResult { + CommandText = valr.CommandText, + Value = valr.Value != null ? ConvertTo (valr.Value, ty) : null + }; + } + else if (expr.NodeType == ExpressionType.MemberAccess) { + var mem = (MemberExpression)expr; + + var paramExpr = mem.Expression as ParameterExpression; + if (paramExpr == null) { + var convert = mem.Expression as UnaryExpression; + if (convert != null && convert.NodeType == ExpressionType.Convert) { + paramExpr = convert.Operand as ParameterExpression; + } + } + + if (paramExpr != null) { + // + // This is a column of our table, output just the column name + // Need to translate it if that column name is mapped + // + var columnName = Table.FindColumnWithPropertyName (mem.Member.Name).Name; + return new CompileResult { CommandText = "\"" + columnName + "\"" }; + } + else { + object obj = null; + if (mem.Expression != null) { + var r = CompileExpr (mem.Expression, queryArgs); + if (r.Value == null) { + throw new NotSupportedException ("Member access failed to compile expression"); + } + if (r.CommandText == "?") { + queryArgs.RemoveAt (queryArgs.Count - 1); + } + obj = r.Value; + } + + // + // Get the member value + // + object val = null; + + if (mem.Member is PropertyInfo) { + var m = (PropertyInfo)mem.Member; + val = m.GetValue (obj, null); + } + else if (mem.Member is FieldInfo) { + var m = (FieldInfo)mem.Member; + val = m.GetValue (obj); + } + else { + throw new NotSupportedException ("MemberExpr: " + mem.Member.GetType ()); + } + + // + // Work special magic for enumerables + // + if (val != null && val is System.Collections.IEnumerable && !(val is string) && !(val is System.Collections.Generic.IEnumerable)) { + var sb = new System.Text.StringBuilder (); + sb.Append ("("); + var head = ""; + foreach (var a in (System.Collections.IEnumerable)val) { + queryArgs.Add (a); + sb.Append (head); + sb.Append ("?"); + head = ","; + } + sb.Append (")"); + return new CompileResult { + CommandText = sb.ToString (), + Value = val + }; + } + else { + queryArgs.Add (val); + return new CompileResult { + CommandText = "?", + Value = val + }; + } + } + } + throw new NotSupportedException ("Cannot compile: " + expr.NodeType.ToString ()); + } + + static object ConvertTo (object obj, Type t) + { + Type nut = Nullable.GetUnderlyingType (t); + + if (nut != null) { + if (obj == null) + return null; + return Convert.ChangeType (obj, nut); + } + else { + return Convert.ChangeType (obj, t); + } + } + + /// + /// Compiles a BinaryExpression where one of the parameters is null. + /// + /// The expression to compile + /// The non-null parameter + private string CompileNullBinaryExpression (BinaryExpression expression, CompileResult parameter) + { + if (expression.NodeType == ExpressionType.Equal) + return "(" + parameter.CommandText + " is ?)"; + else if (expression.NodeType == ExpressionType.NotEqual) + return "(" + parameter.CommandText + " is not ?)"; + else if (expression.NodeType == ExpressionType.GreaterThan + || expression.NodeType == ExpressionType.GreaterThanOrEqual + || expression.NodeType == ExpressionType.LessThan + || expression.NodeType == ExpressionType.LessThanOrEqual) + return "(" + parameter.CommandText + " < ?)"; // always false + else + throw new NotSupportedException ("Cannot compile Null-BinaryExpression with type " + expression.NodeType.ToString ()); + } + + string GetSqlName (Expression expr) + { + var n = expr.NodeType; + if (n == ExpressionType.GreaterThan) + return ">"; + else if (n == ExpressionType.GreaterThanOrEqual) { + return ">="; + } + else if (n == ExpressionType.LessThan) { + return "<"; + } + else if (n == ExpressionType.LessThanOrEqual) { + return "<="; + } + else if (n == ExpressionType.And) { + return "&"; + } + else if (n == ExpressionType.AndAlso) { + return "and"; + } + else if (n == ExpressionType.Or) { + return "|"; + } + else if (n == ExpressionType.OrElse) { + return "or"; + } + else if (n == ExpressionType.Equal) { + return "="; + } + else if (n == ExpressionType.NotEqual) { + return "!="; + } + else { + throw new NotSupportedException ("Cannot get SQL for: " + n); + } + } + + /// + /// Execute SELECT COUNT(*) on the query + /// + public int Count () + { + return GenerateCommand ("count(*)").ExecuteScalar (); + } + + /// + /// Execute SELECT COUNT(*) on the query with an additional WHERE clause. + /// + public int Count (Expression> predExpr) + { + return Where (predExpr).Count (); + } + + public IEnumerator GetEnumerator () + { + if (!_deferred) + return GenerateCommand ("*").ExecuteQuery ().GetEnumerator (); + + return GenerateCommand ("*").ExecuteDeferredQuery ().GetEnumerator (); + } + + System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator () + { + return GetEnumerator (); + } + + /// + /// Queries the database and returns the results as a List. + /// + public List ToList () + { + return GenerateCommand ("*").ExecuteQuery (); + } + + /// + /// Queries the database and returns the results as an array. + /// + public T[] ToArray () + { + return GenerateCommand ("*").ExecuteQuery ().ToArray (); + } + + /// + /// Returns the first element of this query. + /// + public T First () + { + var query = Take (1); + return query.ToList ().First (); + } + + /// + /// Returns the first element of this query, or null if no element is found. + /// + public T FirstOrDefault () + { + var query = Take (1); + return query.ToList ().FirstOrDefault (); + } + + /// + /// Returns the first element of this query that matches the predicate. + /// + public T First (Expression> predExpr) + { + return Where (predExpr).First (); + } + + /// + /// Returns the first element of this query that matches the predicate, or null + /// if no element is found. + /// + public T FirstOrDefault (Expression> predExpr) + { + return Where (predExpr).FirstOrDefault (); + } + } + + public static class SQLite3 + { + public enum Result : int + { + OK = 0, + Error = 1, + Internal = 2, + Perm = 3, + Abort = 4, + Busy = 5, + Locked = 6, + NoMem = 7, + ReadOnly = 8, + Interrupt = 9, + IOError = 10, + Corrupt = 11, + NotFound = 12, + Full = 13, + CannotOpen = 14, + LockErr = 15, + Empty = 16, + SchemaChngd = 17, + TooBig = 18, + Constraint = 19, + Mismatch = 20, + Misuse = 21, + NotImplementedLFS = 22, + AccessDenied = 23, + Format = 24, + Range = 25, + NonDBFile = 26, + Notice = 27, + Warning = 28, + Row = 100, + Done = 101 + } + + public enum ExtendedResult : int + { + IOErrorRead = (Result.IOError | (1 << 8)), + IOErrorShortRead = (Result.IOError | (2 << 8)), + IOErrorWrite = (Result.IOError | (3 << 8)), + IOErrorFsync = (Result.IOError | (4 << 8)), + IOErrorDirFSync = (Result.IOError | (5 << 8)), + IOErrorTruncate = (Result.IOError | (6 << 8)), + IOErrorFStat = (Result.IOError | (7 << 8)), + IOErrorUnlock = (Result.IOError | (8 << 8)), + IOErrorRdlock = (Result.IOError | (9 << 8)), + IOErrorDelete = (Result.IOError | (10 << 8)), + IOErrorBlocked = (Result.IOError | (11 << 8)), + IOErrorNoMem = (Result.IOError | (12 << 8)), + IOErrorAccess = (Result.IOError | (13 << 8)), + IOErrorCheckReservedLock = (Result.IOError | (14 << 8)), + IOErrorLock = (Result.IOError | (15 << 8)), + IOErrorClose = (Result.IOError | (16 << 8)), + IOErrorDirClose = (Result.IOError | (17 << 8)), + IOErrorSHMOpen = (Result.IOError | (18 << 8)), + IOErrorSHMSize = (Result.IOError | (19 << 8)), + IOErrorSHMLock = (Result.IOError | (20 << 8)), + IOErrorSHMMap = (Result.IOError | (21 << 8)), + IOErrorSeek = (Result.IOError | (22 << 8)), + IOErrorDeleteNoEnt = (Result.IOError | (23 << 8)), + IOErrorMMap = (Result.IOError | (24 << 8)), + LockedSharedcache = (Result.Locked | (1 << 8)), + BusyRecovery = (Result.Busy | (1 << 8)), + CannottOpenNoTempDir = (Result.CannotOpen | (1 << 8)), + CannotOpenIsDir = (Result.CannotOpen | (2 << 8)), + CannotOpenFullPath = (Result.CannotOpen | (3 << 8)), + CorruptVTab = (Result.Corrupt | (1 << 8)), + ReadonlyRecovery = (Result.ReadOnly | (1 << 8)), + ReadonlyCannotLock = (Result.ReadOnly | (2 << 8)), + ReadonlyRollback = (Result.ReadOnly | (3 << 8)), + AbortRollback = (Result.Abort | (2 << 8)), + ConstraintCheck = (Result.Constraint | (1 << 8)), + ConstraintCommitHook = (Result.Constraint | (2 << 8)), + ConstraintForeignKey = (Result.Constraint | (3 << 8)), + ConstraintFunction = (Result.Constraint | (4 << 8)), + ConstraintNotNull = (Result.Constraint | (5 << 8)), + ConstraintPrimaryKey = (Result.Constraint | (6 << 8)), + ConstraintTrigger = (Result.Constraint | (7 << 8)), + ConstraintUnique = (Result.Constraint | (8 << 8)), + ConstraintVTab = (Result.Constraint | (9 << 8)), + NoticeRecoverWAL = (Result.Notice | (1 << 8)), + NoticeRecoverRollback = (Result.Notice | (2 << 8)) + } + + + public enum ConfigOption : int + { + SingleThread = 1, + MultiThread = 2, + Serialized = 3 + } + + const string LibraryPath = "sqlite3"; + +#if !USE_CSHARP_SQLITE && !USE_WP8_NATIVE_SQLITE && !USE_SQLITEPCL_RAW + [DllImport(LibraryPath, EntryPoint = "sqlite3_threadsafe", CallingConvention=CallingConvention.Cdecl)] + public static extern int Threadsafe (); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_open", CallingConvention=CallingConvention.Cdecl)] + public static extern Result Open ([MarshalAs(UnmanagedType.LPStr)] string filename, out IntPtr db); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_open_v2", CallingConvention=CallingConvention.Cdecl)] + public static extern Result Open ([MarshalAs(UnmanagedType.LPStr)] string filename, out IntPtr db, int flags, [MarshalAs (UnmanagedType.LPStr)] string zvfs); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_open_v2", CallingConvention = CallingConvention.Cdecl)] + public static extern Result Open(byte[] filename, out IntPtr db, int flags, [MarshalAs (UnmanagedType.LPStr)] string zvfs); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_open16", CallingConvention = CallingConvention.Cdecl)] + public static extern Result Open16([MarshalAs(UnmanagedType.LPWStr)] string filename, out IntPtr db); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_enable_load_extension", CallingConvention=CallingConvention.Cdecl)] + public static extern Result EnableLoadExtension (IntPtr db, int onoff); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_close", CallingConvention=CallingConvention.Cdecl)] + public static extern Result Close (IntPtr db); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_close_v2", CallingConvention = CallingConvention.Cdecl)] + public static extern Result Close2(IntPtr db); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_initialize", CallingConvention=CallingConvention.Cdecl)] + public static extern Result Initialize(); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_shutdown", CallingConvention=CallingConvention.Cdecl)] + public static extern Result Shutdown(); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_config", CallingConvention=CallingConvention.Cdecl)] + public static extern Result Config (ConfigOption option); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_win32_set_directory", CallingConvention=CallingConvention.Cdecl, CharSet=CharSet.Unicode)] + public static extern int SetDirectory (uint directoryType, string directoryPath); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_busy_timeout", CallingConvention=CallingConvention.Cdecl)] + public static extern Result BusyTimeout (IntPtr db, int milliseconds); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_changes", CallingConvention=CallingConvention.Cdecl)] + public static extern int Changes (IntPtr db); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_prepare_v2", CallingConvention=CallingConvention.Cdecl)] + public static extern Result Prepare2 (IntPtr db, [MarshalAs(UnmanagedType.LPStr)] string sql, int numBytes, out IntPtr stmt, IntPtr pzTail); + +#if NETFX_CORE + [DllImport (LibraryPath, EntryPoint = "sqlite3_prepare_v2", CallingConvention = CallingConvention.Cdecl)] + public static extern Result Prepare2 (IntPtr db, byte[] queryBytes, int numBytes, out IntPtr stmt, IntPtr pzTail); +#endif + + public static IntPtr Prepare2 (IntPtr db, string query) + { + IntPtr stmt; +#if NETFX_CORE + byte[] queryBytes = System.Text.UTF8Encoding.UTF8.GetBytes (query); + var r = Prepare2 (db, queryBytes, queryBytes.Length, out stmt, IntPtr.Zero); +#else + var r = Prepare2 (db, query, System.Text.UTF8Encoding.UTF8.GetByteCount (query), out stmt, IntPtr.Zero); +#endif + if (r != Result.OK) { + throw SQLiteException.New (r, GetErrmsg (db)); + } + return stmt; + } + + [DllImport(LibraryPath, EntryPoint = "sqlite3_step", CallingConvention=CallingConvention.Cdecl)] + public static extern Result Step (IntPtr stmt); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_reset", CallingConvention=CallingConvention.Cdecl)] + public static extern Result Reset (IntPtr stmt); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_finalize", CallingConvention=CallingConvention.Cdecl)] + public static extern Result Finalize (IntPtr stmt); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_last_insert_rowid", CallingConvention=CallingConvention.Cdecl)] + public static extern long LastInsertRowid (IntPtr db); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_errmsg16", CallingConvention=CallingConvention.Cdecl)] + public static extern IntPtr Errmsg (IntPtr db); + + public static string GetErrmsg (IntPtr db) + { + return Marshal.PtrToStringUni (Errmsg (db)); + } + + [DllImport(LibraryPath, EntryPoint = "sqlite3_bind_parameter_index", CallingConvention=CallingConvention.Cdecl)] + public static extern int BindParameterIndex (IntPtr stmt, [MarshalAs(UnmanagedType.LPStr)] string name); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_bind_null", CallingConvention=CallingConvention.Cdecl)] + public static extern int BindNull (IntPtr stmt, int index); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_bind_int", CallingConvention=CallingConvention.Cdecl)] + public static extern int BindInt (IntPtr stmt, int index, int val); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_bind_int64", CallingConvention=CallingConvention.Cdecl)] + public static extern int BindInt64 (IntPtr stmt, int index, long val); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_bind_double", CallingConvention=CallingConvention.Cdecl)] + public static extern int BindDouble (IntPtr stmt, int index, double val); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_bind_text16", CallingConvention=CallingConvention.Cdecl, CharSet = CharSet.Unicode)] + public static extern int BindText (IntPtr stmt, int index, [MarshalAs(UnmanagedType.LPWStr)] string val, int n, IntPtr free); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_bind_blob", CallingConvention=CallingConvention.Cdecl)] + public static extern int BindBlob (IntPtr stmt, int index, byte[] val, int n, IntPtr free); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_column_count", CallingConvention=CallingConvention.Cdecl)] + public static extern int ColumnCount (IntPtr stmt); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_column_name", CallingConvention=CallingConvention.Cdecl)] + public static extern IntPtr ColumnName (IntPtr stmt, int index); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_column_name16", CallingConvention=CallingConvention.Cdecl)] + static extern IntPtr ColumnName16Internal (IntPtr stmt, int index); + public static string ColumnName16(IntPtr stmt, int index) + { + return Marshal.PtrToStringUni(ColumnName16Internal(stmt, index)); + } + + [DllImport(LibraryPath, EntryPoint = "sqlite3_column_type", CallingConvention=CallingConvention.Cdecl)] + public static extern ColType ColumnType (IntPtr stmt, int index); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_column_int", CallingConvention=CallingConvention.Cdecl)] + public static extern int ColumnInt (IntPtr stmt, int index); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_column_int64", CallingConvention=CallingConvention.Cdecl)] + public static extern long ColumnInt64 (IntPtr stmt, int index); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_column_double", CallingConvention=CallingConvention.Cdecl)] + public static extern double ColumnDouble (IntPtr stmt, int index); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_column_text", CallingConvention=CallingConvention.Cdecl)] + public static extern IntPtr ColumnText (IntPtr stmt, int index); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_column_text16", CallingConvention=CallingConvention.Cdecl)] + public static extern IntPtr ColumnText16 (IntPtr stmt, int index); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_column_blob", CallingConvention=CallingConvention.Cdecl)] + public static extern IntPtr ColumnBlob (IntPtr stmt, int index); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_column_bytes", CallingConvention=CallingConvention.Cdecl)] + public static extern int ColumnBytes (IntPtr stmt, int index); + + public static string ColumnString (IntPtr stmt, int index) + { + return Marshal.PtrToStringUni (SQLite3.ColumnText16 (stmt, index)); + } + + public static byte[] ColumnByteArray (IntPtr stmt, int index) + { + int length = ColumnBytes (stmt, index); + var result = new byte[length]; + if (length > 0) + Marshal.Copy (ColumnBlob (stmt, index), result, 0, length); + return result; + } + + [DllImport (LibraryPath, EntryPoint = "sqlite3_errcode", CallingConvention = CallingConvention.Cdecl)] + public static extern Result GetResult (Sqlite3DatabaseHandle db); + + [DllImport (LibraryPath, EntryPoint = "sqlite3_extended_errcode", CallingConvention = CallingConvention.Cdecl)] + public static extern ExtendedResult ExtendedErrCode (IntPtr db); + + [DllImport (LibraryPath, EntryPoint = "sqlite3_libversion_number", CallingConvention = CallingConvention.Cdecl)] + public static extern int LibVersionNumber (); + + [DllImport (LibraryPath, EntryPoint = "sqlite3_backup_init", CallingConvention = CallingConvention.Cdecl)] + public static extern Sqlite3BackupHandle BackupInit (Sqlite3DatabaseHandle destDb, [MarshalAs (UnmanagedType.LPStr)] string destName, Sqlite3DatabaseHandle sourceDb, [MarshalAs (UnmanagedType.LPStr)] string sourceName); + + [DllImport (LibraryPath, EntryPoint = "sqlite3_backup_step", CallingConvention = CallingConvention.Cdecl)] + public static extern Result BackupStep (Sqlite3BackupHandle backup, int numPages); + + [DllImport (LibraryPath, EntryPoint = "sqlite3_backup_finish", CallingConvention = CallingConvention.Cdecl)] + public static extern Result BackupFinish (Sqlite3BackupHandle backup); +#else + public static Result Open (string filename, out Sqlite3DatabaseHandle db) + { + return (Result)Sqlite3.sqlite3_open (filename, out db); + } + + public static Result Open (string filename, out Sqlite3DatabaseHandle db, int flags, string vfsName) + { +#if USE_WP8_NATIVE_SQLITE + return (Result)Sqlite3.sqlite3_open_v2(filename, out db, flags, vfsName ?? ""); +#else + return (Result)Sqlite3.sqlite3_open_v2 (filename, out db, flags, vfsName); +#endif + } + + public static Result Close (Sqlite3DatabaseHandle db) + { + return (Result)Sqlite3.sqlite3_close (db); + } + + public static Result Close2 (Sqlite3DatabaseHandle db) + { + return (Result)Sqlite3.sqlite3_close_v2 (db); + } + + public static Result BusyTimeout (Sqlite3DatabaseHandle db, int milliseconds) + { + return (Result)Sqlite3.sqlite3_busy_timeout (db, milliseconds); + } + + public static int Changes (Sqlite3DatabaseHandle db) + { + return Sqlite3.sqlite3_changes (db); + } + + public static Sqlite3Statement Prepare2 (Sqlite3DatabaseHandle db, string query) + { + Sqlite3Statement stmt = default (Sqlite3Statement); +#if USE_WP8_NATIVE_SQLITE || USE_SQLITEPCL_RAW + var r = Sqlite3.sqlite3_prepare_v2 (db, query, out stmt); +#else + stmt = new Sqlite3Statement(); + var r = Sqlite3.sqlite3_prepare_v2(db, query, -1, ref stmt, 0); +#endif + if (r != 0) { + throw SQLiteException.New ((Result)r, GetErrmsg (db)); + } + return stmt; + } + + public static Result Step (Sqlite3Statement stmt) + { + return (Result)Sqlite3.sqlite3_step (stmt); + } + + public static Result Reset (Sqlite3Statement stmt) + { + return (Result)Sqlite3.sqlite3_reset (stmt); + } + + public static Result Finalize (Sqlite3Statement stmt) + { + return (Result)Sqlite3.sqlite3_finalize (stmt); + } + + public static long LastInsertRowid (Sqlite3DatabaseHandle db) + { + return Sqlite3.sqlite3_last_insert_rowid (db); + } + + public static string GetErrmsg (Sqlite3DatabaseHandle db) + { + return Sqlite3.sqlite3_errmsg (db).utf8_to_string (); + } + + public static int BindParameterIndex (Sqlite3Statement stmt, string name) + { + return Sqlite3.sqlite3_bind_parameter_index (stmt, name); + } + + public static int BindNull (Sqlite3Statement stmt, int index) + { + return Sqlite3.sqlite3_bind_null (stmt, index); + } + + public static int BindInt (Sqlite3Statement stmt, int index, int val) + { + return Sqlite3.sqlite3_bind_int (stmt, index, val); + } + + public static int BindInt64 (Sqlite3Statement stmt, int index, long val) + { + return Sqlite3.sqlite3_bind_int64 (stmt, index, val); + } + + public static int BindDouble (Sqlite3Statement stmt, int index, double val) + { + return Sqlite3.sqlite3_bind_double (stmt, index, val); + } + + public static int BindText (Sqlite3Statement stmt, int index, string val, int n, IntPtr free) + { +#if USE_WP8_NATIVE_SQLITE + return Sqlite3.sqlite3_bind_text(stmt, index, val, n); +#elif USE_SQLITEPCL_RAW + return Sqlite3.sqlite3_bind_text (stmt, index, val); +#else + return Sqlite3.sqlite3_bind_text(stmt, index, val, n, null); +#endif + } + + public static int BindBlob (Sqlite3Statement stmt, int index, byte[] val, int n, IntPtr free) + { +#if USE_WP8_NATIVE_SQLITE + return Sqlite3.sqlite3_bind_blob(stmt, index, val, n); +#elif USE_SQLITEPCL_RAW + return Sqlite3.sqlite3_bind_blob (stmt, index, val); +#else + return Sqlite3.sqlite3_bind_blob(stmt, index, val, n, null); +#endif + } + + public static int ColumnCount (Sqlite3Statement stmt) + { + return Sqlite3.sqlite3_column_count (stmt); + } + + public static string ColumnName (Sqlite3Statement stmt, int index) + { + return Sqlite3.sqlite3_column_name (stmt, index).utf8_to_string (); + } + + public static string ColumnName16 (Sqlite3Statement stmt, int index) + { + return Sqlite3.sqlite3_column_name (stmt, index).utf8_to_string (); + } + + public static ColType ColumnType (Sqlite3Statement stmt, int index) + { + return (ColType)Sqlite3.sqlite3_column_type (stmt, index); + } + + public static int ColumnInt (Sqlite3Statement stmt, int index) + { + return Sqlite3.sqlite3_column_int (stmt, index); + } + + public static long ColumnInt64 (Sqlite3Statement stmt, int index) + { + return Sqlite3.sqlite3_column_int64 (stmt, index); + } + + public static double ColumnDouble (Sqlite3Statement stmt, int index) + { + return Sqlite3.sqlite3_column_double (stmt, index); + } + + public static string ColumnText (Sqlite3Statement stmt, int index) + { + return Sqlite3.sqlite3_column_text (stmt, index).utf8_to_string (); + } + + public static string ColumnText16 (Sqlite3Statement stmt, int index) + { + return Sqlite3.sqlite3_column_text (stmt, index).utf8_to_string (); + } + + public static byte[] ColumnBlob (Sqlite3Statement stmt, int index) + { + return Sqlite3.sqlite3_column_blob (stmt, index).ToArray (); + } + + public static int ColumnBytes (Sqlite3Statement stmt, int index) + { + return Sqlite3.sqlite3_column_bytes (stmt, index); + } + + public static string ColumnString (Sqlite3Statement stmt, int index) + { + return Sqlite3.sqlite3_column_text (stmt, index).utf8_to_string (); + } + + public static byte[] ColumnByteArray (Sqlite3Statement stmt, int index) + { + int length = ColumnBytes (stmt, index); + if (length > 0) { + return ColumnBlob (stmt, index); + } + return new byte[0]; + } + + public static Result EnableLoadExtension (Sqlite3DatabaseHandle db, int onoff) + { + return (Result)Sqlite3.sqlite3_enable_load_extension (db, onoff); + } + + public static int LibVersionNumber () + { + return Sqlite3.sqlite3_libversion_number (); + } + + public static Result GetResult (Sqlite3DatabaseHandle db) + { + return (Result)Sqlite3.sqlite3_errcode (db); + } + + public static ExtendedResult ExtendedErrCode (Sqlite3DatabaseHandle db) + { + return (ExtendedResult)Sqlite3.sqlite3_extended_errcode (db); + } + + public static Sqlite3BackupHandle BackupInit (Sqlite3DatabaseHandle destDb, string destName, Sqlite3DatabaseHandle sourceDb, string sourceName) + { + return Sqlite3.sqlite3_backup_init (destDb, destName, sourceDb, sourceName); + } + + public static Result BackupStep (Sqlite3BackupHandle backup, int numPages) + { + return (Result)Sqlite3.sqlite3_backup_step (backup, numPages); + } + + public static Result BackupFinish (Sqlite3BackupHandle backup) + { + return (Result)Sqlite3.sqlite3_backup_finish (backup); + } +#endif + + public enum ColType : int + { + Integer = 1, + Float = 2, + Text = 3, + Blob = 4, + Null = 5 + } + } } diff --git a/SQLiteLegacy.cs b/SQLiteLegacy.cs new file mode 100644 index 00000000..4d0874b8 --- /dev/null +++ b/SQLiteLegacy.cs @@ -0,0 +1,151 @@ +using CefSharp; +using System; +using System.Collections.Generic; +using System.Data.SQLite; +using System.IO; +using System.Threading; + +namespace VRCX +{ + public class SQLiteLegacy + { + public static readonly SQLiteLegacy Instance; + private readonly ReaderWriterLockSlim m_ConnectionLock; + private readonly SQLiteConnection m_Connection; + + static SQLiteLegacy() + { + Instance = new SQLiteLegacy(); + } + + public SQLiteLegacy() + { + m_ConnectionLock = new ReaderWriterLockSlim(); + + var dataSource = Program.ConfigLocation; + m_Connection = new SQLiteConnection($"Data Source=\"{dataSource}\";Version=3;PRAGMA locking_mode=NORMAL;PRAGMA busy_timeout=5000", true); + } + + internal void Init() + { + m_Connection.Open(); + } + + internal void Exit() + { + m_Connection.Close(); + m_Connection.Dispose(); + } + + public void Execute(IJavascriptCallback callback, string sql, IDictionary args = null) + { + try + { + m_ConnectionLock.EnterReadLock(); + try + { + using (var command = new SQLiteCommand(sql, m_Connection)) + { + if (args != null) + { + foreach (var arg in args) + { + command.Parameters.Add(new SQLiteParameter(arg.Key, arg.Value)); + } + } + using (var reader = command.ExecuteReader()) + { + while (reader.Read() == true) + { + var values = new object[reader.FieldCount]; + reader.GetValues(values); + if (callback.CanExecute == true) + { + callback.ExecuteAsync(null, values); + } + } + } + } + if (callback.CanExecute == true) + { + callback.ExecuteAsync(null, null); + } + } + finally + { + m_ConnectionLock.ExitReadLock(); + } + } + catch (Exception e) + { + if (callback.CanExecute == true) + { + callback.ExecuteAsync(e.Message, null); + } + } + + callback.Dispose(); + } + + public void Execute(Action callback, string sql, IDictionary args = null) + { + m_ConnectionLock.EnterReadLock(); + try + { + using (var command = new SQLiteCommand(sql, m_Connection)) + { + if (args != null) + { + foreach (var arg in args) + { + command.Parameters.Add(new SQLiteParameter(arg.Key, arg.Value)); + } + } + using (var reader = command.ExecuteReader()) + { + while (reader.Read() == true) + { + var values = new object[reader.FieldCount]; + reader.GetValues(values); + callback(values); + } + } + } + } + catch + { + } + finally + { + m_ConnectionLock.ExitReadLock(); + } + } + + public int ExecuteNonQuery(string sql, IDictionary args = null) + { + int result = -1; + + m_ConnectionLock.EnterWriteLock(); + try + { + using (var command = new SQLiteCommand(sql, m_Connection)) + { + if (args != null) + { + foreach (var arg in args) + { + command.Parameters.Add(new SQLiteParameter(arg.Key, arg.Value)); + } + } + result = command.ExecuteNonQuery(); + } + } + finally + { + m_ConnectionLock.ExitWriteLock(); + } + + return result; + } + } +} diff --git a/ScreenshotHelper.cs b/ScreenshotHelper.cs index b2b26d8e..6868d949 100644 --- a/ScreenshotHelper.cs +++ b/ScreenshotHelper.cs @@ -181,7 +181,11 @@ namespace VRCX return new PNGChunk(type, chunkData); } - // parse LFS screenshot PNG metadata + /// + /// Parses the metadata string of a vrchat screenshot with taken with LFS and returns a JObject containing the parsed data. + /// + /// The metadata string to parse. + /// A JObject containing the parsed data. public static JObject ParseLfsPicture(string metadataString) { var metadata = new JObject(); diff --git a/Util.cs b/Util.cs index 7151ce6a..18689fb3 100644 --- a/Util.cs +++ b/Util.cs @@ -1,3 +1,4 @@ +using System; using CefSharp; namespace VRCX @@ -11,7 +12,7 @@ namespace VRCX repository.Register("SharedVariable", SharedVariable.Instance, false); repository.Register("WebApi", WebApi.Instance, true); repository.Register("VRCXStorage", VRCXStorage.Instance, true); - repository.Register("SQLite", SQLite.Instance, true); + repository.Register("SQLite", SQLiteLegacy.Instance, true); repository.Register("LogWatcher", LogWatcher.Instance, true); repository.Register("Discord", Discord.Instance, true); repository.Register("AssetBundleCacher", AssetBundleCacher.Instance, true); diff --git a/VRCX.csproj b/VRCX.csproj index 6ac0542d..1d0526d3 100644 --- a/VRCX.csproj +++ b/VRCX.csproj @@ -1,4 +1,4 @@ - + @@ -87,11 +87,13 @@ + + @@ -101,7 +103,7 @@ - + @@ -123,6 +125,8 @@ + + Form @@ -222,6 +226,7 @@ - xcopy /y "$(ProjectDir)OpenVR\win64\openvr_api.dll" + xcopy /y "$(ProjectDir)OpenVR\win64\openvr_api.dll" +xcopy /y "$(ProjectDir)lib\sqlite3.dll" \ No newline at end of file diff --git a/WebApi.cs b/WebApi.cs index ac823abc..12e658fd 100644 --- a/WebApi.cs +++ b/WebApi.cs @@ -60,8 +60,8 @@ namespace VRCX internal void LoadCookies() { - SQLite.Instance.ExecuteNonQuery("CREATE TABLE IF NOT EXISTS `cookies` (`key` TEXT PRIMARY KEY, `value` TEXT)"); - SQLite.Instance.Execute((values) => + SQLiteLegacy.Instance.ExecuteNonQuery("CREATE TABLE IF NOT EXISTS `cookies` (`key` TEXT PRIMARY KEY, `value` TEXT)"); + SQLiteLegacy.Instance.Execute((values) => { try { @@ -92,7 +92,7 @@ namespace VRCX using (var memoryStream = new MemoryStream()) { new BinaryFormatter().Serialize(memoryStream, _cookieContainer); - SQLite.Instance.ExecuteNonQuery( + SQLiteLegacy.Instance.ExecuteNonQuery( "INSERT OR REPLACE INTO `cookies` (`key`, `value`) VALUES (@key, @value)", new Dictionary() { {"@key", "default"}, diff --git a/WorldDBManager.cs b/WorldDBManager.cs new file mode 100644 index 00000000..768e963d --- /dev/null +++ b/WorldDBManager.cs @@ -0,0 +1,423 @@ +using System.Linq; +using System.Text; +using System; +using System.Collections.Generic; +using System.IO; +using System.Net; +using System.Runtime.Serialization.Formatters.Binary; +using System.Threading.Tasks; +using CefSharp; +using Newtonsoft.Json; + +namespace VRCX +{ + public class WorldDBManager + { + public static WorldDBManager Instance; + private readonly HttpListener listener; + private readonly WorldDatabase worldDB; + + private string currentWorldId = null; + private string lastError = null; + + public WorldDBManager(string url) + { + Instance = this; + // http://localhost:22500 + listener = new HttpListener(); + listener.Prefixes.Add(url); + + worldDB = new WorldDatabase(Path.Combine(Program.AppDataDirectory, "VRCX-WorldData.db")); + } + + public async Task Start() + { + listener.Start(); + + while (true) + { + var context = await listener.GetContextAsync(); + var request = context.Request; + var responseData = new WorldDataRequestResponse(false, null, null); + + if (MainForm.Instance?.Browser == null || MainForm.Instance.Browser.IsLoading || !MainForm.Instance.Browser.CanExecuteJavascriptInMainFrame) + { + responseData.Error = "VRCX not yet initialized. Try again in a moment."; + responseData.StatusCode = 503; + SendJsonResponse(context.Response, responseData); + continue; + }; + + switch (request.Url.LocalPath) + { + case "/vrcx/data/init": + responseData = await HandleInitRequest(context); + SendJsonResponse(context.Response, responseData); + break; + case "/vrcx/data/get": + responseData = await HandleDataRequest(context); + SendJsonResponse(context.Response, responseData); + break; + case "/vrcx/data/lasterror": + responseData.OK = lastError == null; + responseData.Data = lastError; + lastError = null; + SendJsonResponse(context.Response, responseData); + break; + case "/vrcx/data/getbulk": + responseData = await HandleBulkDataRequest(context); + SendJsonResponse(context.Response, responseData); + break; + case "/vrcx/status": + context.Response.StatusCode = 200; + context.Response.Close(); + break; + default: + responseData.Error = "Invalid VRCX endpoint."; + responseData.StatusCode = 404; + SendJsonResponse(context.Response, responseData); + break; + } + } + + } + + /// + /// Handles an HTTP listener request to initialize a connection to the world db manager. + /// + /// The HTTP listener context object. + /// A object containing the response data. + private async Task HandleInitRequest(HttpListenerContext context) + { + var request = context.Request; + var responseData = new WorldDataRequestResponse(false, null, null); + + if (request.QueryString["debug"] == "true") + { + if (!worldDB.DoesWorldExist("wrld_12345")) + { + worldDB.AddWorld("wrld_12345", "12345"); + worldDB.AddDataEntry("wrld_12345", "test", "testvalue"); + } + + currentWorldId = "wrld_12345"; + responseData.OK = true; + responseData.StatusCode = 200; + responseData.Data = "12345"; + return responseData; + } + + string worldId = await GetCurrentWorldID(); + + if (String.IsNullOrEmpty(worldId)) + { + responseData.Error = "Failed to get/verify current world ID."; + responseData.StatusCode = 500; + return responseData; + } + + currentWorldId = worldId; + + var existsInDB = worldDB.DoesWorldExist(currentWorldId); + string connectionKey; + + if (!existsInDB) + { + connectionKey = GenerateWorldConnectionKey(); + worldDB.AddWorld(currentWorldId, connectionKey); + } + else + { + connectionKey = worldDB.GetWorldConnectionKey(currentWorldId); + } + + responseData.OK = true; + responseData.StatusCode = 200; + responseData.Data = connectionKey; + return responseData; + } + + /// + /// Handles an HTTP listener request for data from the world database. + /// + /// The HTTP listener context object. + /// A object containing the response data. + private async Task HandleDataRequest(HttpListenerContext context) + { + var request = context.Request; + var responseData = new WorldDataRequestResponse(false, null, null); + + var key = request.QueryString["key"]; + if (key == null) + { + responseData.Error = "Missing key parameter."; + responseData.StatusCode = 400; + return responseData; + } + + var worldIdOverride = request.QueryString["world"]; + + if (worldIdOverride != null) + { + var world = worldDB.GetWorld(worldIdOverride); + + if (world == null) + { + responseData.OK = false; + responseData.Error = $"World ID '{worldIdOverride}' not initialized in this user's database."; + responseData.StatusCode = 200; + responseData.Data = null; + return responseData; + } + + if (!world.AllowExternalRead) + { + responseData.OK = false; + responseData.Error = $"World ID '{worldIdOverride}' does not allow external reads."; + responseData.StatusCode = 200; + responseData.Data = null; + return responseData; + } + } + + if (currentWorldId == "wrld_12345" && worldIdOverride == null) + worldIdOverride = "wrld_12345"; + + var worldId = worldIdOverride ?? await GetCurrentWorldID(); + + if (worldIdOverride == null && (String.IsNullOrEmpty(currentWorldId) || worldId != currentWorldId)) + { + responseData.Error = "World ID not initialized."; + responseData.StatusCode = 400; + return responseData; + } + + var value = worldDB.GetDataEntry(worldId, key); + + responseData.OK = true; + responseData.StatusCode = 200; + responseData.Error = null; + responseData.Data = value?.Value; + return responseData; + } + + /// + /// Handles an HTTP listener request for bulk data from the world database. + /// + /// The HTTP listener context object. + /// A object containing the response data. + private async Task HandleBulkDataRequest(HttpListenerContext context) + { + var request = context.Request; + var responseData = new WorldDataRequestResponse(false, null, null); + + var keys = request.QueryString["keys"]; + if (keys == null) + { + responseData.Error = "Missing/invalid keys parameter."; + responseData.StatusCode = 400; + return responseData; + } + + var keyArray = keys.Split(','); + + var worldId = await GetCurrentWorldID(); + + if (String.IsNullOrEmpty(currentWorldId) || (worldId != currentWorldId && currentWorldId != "wrld_12345")) + { + responseData.Error = "World ID not initialized."; + responseData.StatusCode = 400; + return responseData; + } + + var values = worldDB.GetDataEntries(currentWorldId, keyArray).ToList(); + + /*if (values == null) + { + responseData.Error = $"No data found for keys '{keys}' under world id '{currentWorldId}'."; + responseData.StatusCode = 404; + return responseData; + }*/ + + // Build a dictionary of key/value pairs to send back. If a key doesn't exist in the database, the key will be included in the response as requested but with a null value. + var data = new Dictionary(); + for (int i = 0; i < keyArray.Length; i++) + { + string dataKey = keyArray[i]; + string dataValue = values?.Where(x => x.Key == dataKey).FirstOrDefault()?.Value; // Get the value from the list of data entries, if it exists, otherwise null + + data.Add(dataKey, dataValue); + } + + responseData.OK = true; + responseData.StatusCode = 200; + responseData.Error = null; + responseData.Data = JsonConvert.SerializeObject(data); + return responseData; + } + + /// + /// Generates a unique identifier for a world connection request. + /// + /// A string representation of a GUID that can be used to identify the world on requests. + private string GenerateWorldConnectionKey() + { + // Ditched the old method of generating a short key, since we're just going with json anyway who cares about a longer identifier + // Since we can rely on this GUID being unique, we can use it to identify the world on requests instead of trying to keep track of the user's current world. + // I uhh, should probably make sure this is actually unique though. Just in case. I'll do that later. + return Guid.NewGuid().ToString(); + } + + /// + /// Gets the ID of the current world by evaluating a JavaScript function in the main browser instance. + /// + /// The ID of the current world as a string, or null if it could not be retrieved. + private async Task GetCurrentWorldID() + { + JavascriptResponse funcResult = await MainForm.Instance.Browser.EvaluateScriptAsync("$app.API.actuallyGetCurrentLocation();", TimeSpan.FromSeconds(5)); + + try + { + funcResult = await MainForm.Instance.Browser.EvaluateScriptAsync("$app.API.actuallyGetCurrentLocation();", TimeSpan.FromSeconds(5)); + } + catch (Exception ex) + { + return null; + } + + string worldId = funcResult?.Result?.ToString(); + + if (String.IsNullOrEmpty(worldId)) + { + // implement + // wait what was i going to do here again + // seriously i forgot, hope it wasn't important + return null; + } + + return worldId; + } + + /// + /// Sends a JSON response to an HTTP listener request with the specified response data and status code. + /// + /// The HTTP listener response object. + /// The response data to be serialized to JSON. + /// The HTTP status code to be returned. + /// The HTTP listener response object. + private HttpListenerResponse SendJsonResponse(HttpListenerResponse response, WorldDataRequestResponse responseData) + { + response.ContentType = "application/json"; + response.StatusCode = responseData.StatusCode; + response.AddHeader("Cache-Control", "no-cache"); + + // Use newtonsoft.json to serialize WorldDataRequestResponse to json + var json = JsonConvert.SerializeObject(responseData); + var buffer = System.Text.Encoding.UTF8.GetBytes(json); + response.ContentLength64 = buffer.Length; + response.OutputStream.Write(buffer, 0, buffer.Length); + response.Close(); + return response; + } + + /// + /// Processes a JSON request containing world data and logs it to the world database. + /// + /// The JSON request containing the world data. + public async void ProcessLogWorldDataRequest(string json) + { + // Current format: + // { + // "requestType": "store", + // "connectionKey": "abc123", + // "key": "example_key", + // "value": "example_value" + // } + + // * I could rate limit the processing of this, but I don't think it's necessary. + // * At the amount of data you'd need to be spitting out to lag vrcx, you'd fill up the log file and lag out VRChat far before VRCX would have any issues; at least in my testing. + // As long as malicious worlds can't permanently *store* stupid amounts of unculled data, this is pretty safe with the 10MB cap. If a world wants to just fill up a users HDD with logs, they can do that already anyway. + + WorldDataRequest request; + + try // try to deserialize the json into a WorldDataRequest object + { + request = JsonConvert.DeserializeObject(json); + } + catch (JsonReaderException ex) + { + this.lastError = ex.Message; + // invalid json + return; + } + catch (Exception ex) + { + this.lastError = ex.Message; + // something else happened lol + return; + } + + if (String.IsNullOrEmpty(request.Key)) + { + this.lastError = "`key` is missing or null"; + return; + } + + if (String.IsNullOrEmpty(request.Value)) + { + this.lastError = "`value` is missing or null"; + return; + } + + if (String.IsNullOrEmpty(request.ConnectionKey)) + { + this.lastError = "`connectionKey` is missing or null"; + return; + } + + // Make sure the connection key is a valid GUID. No point in doing anything else if it's not. + if (!Guid.TryParse(request.ConnectionKey, out Guid _)) + { + this.lastError = "Invalid GUID provided as connection key"; + // invalid guid + return; + } + + // Get the world ID from the connection key + string worldId = worldDB.GetWorldByConnectionKey(request.ConnectionKey); + if (worldId == null) + { + this.lastError = "Invalid connection key"; + // invalid connection key + return; + } + + // Get/calculate the old and new data sizes for this key/the world + int oldTotalDataSize = worldDB.GetWorldDataSize(worldId); + int oldDataSize = worldDB.GetDataEntrySize(worldId, request.Key); + int newDataSize = Encoding.UTF8.GetByteCount(request.Value); + int newTotalDataSize = oldTotalDataSize + newDataSize - oldDataSize; + + // Make sure we don't exceed 10MB total size for this world + // This works, I tested it. Hopefully this prevents/limits any possible abuse. + if (newTotalDataSize > 1024 * 1024 * 10) + { + this.lastError = $"You have hit the 10MB total data cap. The previous data entry was *not* stored. Your request was {newDataSize} bytes, your current shared byte total is {oldTotalDataSize} and you went over the table limit by {newTotalDataSize - (1024 * 1024 * 10)} bytes."; + // too much data + //throw new Exception("Too much data"); + return; + } + + + worldDB.AddDataEntry(worldId, request.Key, request.Value, newDataSize); + worldDB.UpdateWorldDataSize(worldId, newTotalDataSize); + } + + public void Stop() + { + listener.Stop(); + listener.Close(); + worldDB.Close(); + } + } +} \ No newline at end of file diff --git a/WorldDataRequestResponse.cs b/WorldDataRequestResponse.cs new file mode 100644 index 00000000..0523e3f9 --- /dev/null +++ b/WorldDataRequestResponse.cs @@ -0,0 +1,53 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Newtonsoft.Json; + +namespace VRCX +{ + public class WorldDataRequestResponse + { + /// + /// Gets or sets a value indicating whether the request was successful. + /// + [JsonProperty("ok")] + public bool OK { get; set; } + /// + /// Gets or sets the error message if the request was not successful. + /// + [JsonProperty("error")] + public string Error { get; set; } + /// + /// Gets or sets the data returned by the request. + /// + [JsonProperty("data")] + public string Data { get; set; } + /// + /// Gets or sets the response code. + /// + /// + [JsonProperty("statusCode")] + public int StatusCode { get; set; } + + public WorldDataRequestResponse(bool ok, string error, string data) + { + OK = ok; + Error = error; + Data = data; + } + } + + public class WorldDataRequest + { + [JsonProperty("requestType")] + public string RequestType; + [JsonProperty("connectionKey")] + public string ConnectionKey; + [JsonProperty("key")] + public string Key; + [JsonProperty("value")] + public string Value; + } +} diff --git a/WorldDatabase.cs b/WorldDatabase.cs new file mode 100644 index 00000000..34030371 --- /dev/null +++ b/WorldDatabase.cs @@ -0,0 +1,270 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using SQLite; + +namespace VRCX +{ + [Table("data")] + public class WorldData + { + [PrimaryKey, AutoIncrement] + [Column("id")] + public int Id { get; set; } + [Column("world_id"), NotNull] + public string WorldId { get; set; } + [Column("key"), NotNull] + public string Key { get; set; } + [Column("value"), NotNull] + public string Value { get; set; } + [Column("value_size"), NotNull] + public int ValueSize { get; set; } + [Column("last_accessed")] + public DateTimeOffset LastAccessed { get; set; } + [Column("last_modified")] + public DateTimeOffset LastModified { get; set; } + } + + [Table("worlds")] + public class World + { + [PrimaryKey, AutoIncrement] + [Column("id")] + public int Id { get; set; } + [Column("world_id"), NotNull] + public string WorldId { get; set; } + [Column("connection_key"), NotNull] + public string ConnectionKey { get; set; } + [Column("total_data_size"), NotNull] + public int TotalDataSize { get; set; } + [Column("allow_external_read")] + public bool AllowExternalRead { get; set; } + } + + internal class WorldDatabase + { + private static SQLiteConnection sqlite; + private readonly static string dbInitQuery = @" +CREATE TABLE IF NOT EXISTS worlds ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + world_id TEXT NOT NULL UNIQUE, + connection_key TEXT NOT NULL, + total_data_size INTEGER DEFAULT 0, + allow_external_read INTEGER DEFAULT 0 +); +\ +CREATE TABLE IF NOT EXISTS data ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + world_id TEXT NOT NULL, + key TEXT NOT NULL, + value TEXT NOT NULL, + value_size INTEGER NOT NULL DEFAULT 0, + last_accessed INTEGER DEFAULT (strftime('%s', 'now')), + last_modified INTEGER DEFAULT (strftime('%s', 'now')), + FOREIGN KEY (world_id) REFERENCES worlds(world_id) ON DELETE CASCADE, + UNIQUE (world_id, key) +); +\ +CREATE TRIGGER IF NOT EXISTS data_update_trigger +AFTER UPDATE ON data +FOR EACH ROW +BEGIN + UPDATE data SET last_modified = (strftime('%s', 'now')) WHERE id = OLD.id; +END; +\ +CREATE TRIGGER IF NOT EXISTS data_insert_trigger +AFTER INSERT ON data +FOR EACH ROW +BEGIN + UPDATE data SET last_accessed = (strftime('%s', 'now')), last_modified = (strftime('%s', 'now')) WHERE id = NEW.id; +END;"; + public WorldDatabase(string databaseLocation) + { + var options = new SQLiteConnectionString(databaseLocation, true); + sqlite = new SQLiteConnection(options); + sqlite.Execute(dbInitQuery); + + // TODO: Split these init queries into their own functions so we can call/update them individually. + var queries = dbInitQuery.Split('\\'); + sqlite.BeginTransaction(); + foreach (var query in queries) + { + sqlite.Execute(query); + } + sqlite.Commit(); + } + + /// + /// Checks if a world with the specified ID exists in the database. + /// + /// The ID of the world to check for. + /// True if the world exists in the database, false otherwise. + public bool DoesWorldExist(string worldId) + { + var query = sqlite.Table().Where(w => w.WorldId == worldId).Select(w => w.WorldId); + + return query.Any(); + } + + /// + /// Gets the ID of the world with the specified connection key from the database. + /// + /// The connection key of the world to get the ID for. + /// The ID of the world with the specified connection key, or null if no such world exists in the database. + public string GetWorldByConnectionKey(string connectionKey) + { + var query = sqlite.Table().Where(w => w.ConnectionKey == connectionKey).Select(w => w.WorldId); + + return query.FirstOrDefault(); + } + + /// + /// Gets the connection key for a world from the database. + /// + /// The ID of the world to get the connection key for. + /// The connection key for the specified world, or null if the world does not exist in the database. + public string GetWorldConnectionKey(string worldId) + { + var query = sqlite.Table().Where(w => w.WorldId == worldId).Select(w => w.ConnectionKey); + + return query.FirstOrDefault(); + } + + /// + /// Sets the connection key for a world in the database. If the world already exists in the database, the connection key is updated. Otherwise, a new world is added to the database with the specified connection key. + /// + /// The ID of the world to set the connection key for. + /// The connection key to set for the world. + /// The connection key that was set. + public string SetWorldConnectionKey(string worldId, string connectionKey) + { + var query = sqlite.Table().Where(w => w.WorldId == worldId).Select(w => w.ConnectionKey); + + if (query.Any()) + { + sqlite.Execute("UPDATE worlds SET connection_key = ? WHERE world_id = ?", connectionKey, worldId); + } + else + { + sqlite.Insert(new World() { WorldId = worldId, ConnectionKey = connectionKey }); + } + + return connectionKey; + } + + /// + /// Adds a new world to the database. + /// + /// The ID of the world to add. + /// The connection key of the world to add. + /// Thrown if a world with the specified ID already exists in the database. + public void AddWorld(string worldId, string connectionKey) + { + // * This will throw an error if the world already exists.. so don't do that + sqlite.Insert(new World() { WorldId = worldId, ConnectionKey = connectionKey }); + } + + /// + /// Gets the world with the specified ID from the database. + /// + /// The ID of the world to get. + /// The world with the specified ID, or null if no such world exists in the database. + public World GetWorld(string worldId) + { + var query = sqlite.Table().Where(w => w.WorldId == worldId); + return query.FirstOrDefault(); + } + + /// + /// Gets the total data size shared across all rows, in bytes, for the world with the specified ID from the database. + /// + /// The ID of the world to get the total data size for. + /// The total data size for the world, in bytes. + public int GetWorldDataSize(string worldId) + { + var query = sqlite.Table().Where(w => w.WorldId == worldId).Select(w => w.TotalDataSize); + + return query.FirstOrDefault(); + } + + /// + /// Updates the total data size, in bytes for the world with the specified ID in the database. + /// + /// The ID of the world to update the total data size for. + /// The new total data size for the world, in bytes. + public void UpdateWorldDataSize(string worldId, int size) + { + sqlite.Execute("UPDATE worlds SET total_data_size = ? WHERE world_id = ?", size, worldId); + } + + /// + /// Adds or updates a data entry in the database with the specified world ID, key, and value. + /// + /// The ID of the world to add the data entry for. + /// The key of the data entry to add or replace. + /// The value of the data entry to add or replace. + /// The size of the data entry to add or replace, in bytes. If null, the size is calculated from the value automatically. + public void AddDataEntry(string worldId, string key, string value, int? dataSize = null) + { + int byteSize = dataSize ?? Encoding.UTF8.GetByteCount(value); + + // check if entry already exists; + // INSERT OR REPLACE(InsertOrReplace method) deletes the old row and creates a new one, incrementing the id, which I don't want + var query = sqlite.Table().Where(w => w.WorldId == worldId && w.Key == key); + if (query.Any()) + { + sqlite.Execute("UPDATE data SET value = ?, value_size = ? WHERE world_id = ? AND key = ?", value, byteSize, worldId, key); + } + else + { + sqlite.Insert(new WorldData() { WorldId = worldId, Key = key, Value = value, ValueSize = byteSize }); + } + } + + /// + /// Gets the data entry with the specified world ID and key from the database. + /// + /// The ID of the world to get the data entry for. + /// The key of the data entry to get. + /// The data entry with the specified world ID and key, or null if no such data entry exists in the database. + public WorldData GetDataEntry(string worldId, string key) + { + var query = sqlite.Table().Where(w => w.WorldId == worldId && w.Key == key); + + return query.FirstOrDefault(); + } + + /// + /// Gets the data entries with the specified world ID and keys from the database. + /// + /// The ID of the world to get the data entries for. + /// The keys of the data entries to get. + /// An enumerable collection of the data entries with the specified world ID and keys. + public IEnumerable GetDataEntries(string worldId, string[] keys) + { + var query = sqlite.Table().Where(w => w.WorldId == worldId && keys.Contains(w.Key)); + + return query.ToList(); + } + + /// + /// Gets the size of the data entry, in bytes, with the specified world ID and key from the database. + /// + /// The ID of the world to get the data entry size for. + /// The key of the data entry to get the size for. + /// The size of the data entry with the specified world ID and key, or 0 if no such data entry exists in the database. + public int GetDataEntrySize(string worldId, string key) + { + var query = sqlite.Table().Where(w => w.WorldId == worldId && w.Key == key).Select(w => w.ValueSize); + + return query.FirstOrDefault(); + } + + public void Close() + { + sqlite.Close(); + } + } +} diff --git a/html/src/app.js b/html/src/app.js index 9a9aeecd..ac9d5d37 100644 --- a/html/src/app.js +++ b/html/src/app.js @@ -2092,6 +2092,47 @@ speechSynthesis.getVoices(); }); }); + API.getUserApiCurrentLocation = function () { + return this.currentUser?.presence?.world; + }; + + // TODO: traveling to world checks + API.actuallyGetCurrentLocation = async function () { + const gameLogLocation = $app.lastLocation.location; + console.log('gameLog Location', gameLogLocation); + + const presence = this.currentUser.presence; + let presenceLocation = this.currentUser.$locationTag; + if (presenceLocation === 'traveling') { + console.log("User is traveling, using $travelingToLocation", this.currentUser.$travelingToLocation); + presenceLocation = this.currentUser.$travelingToLocation; + } + console.log('presence Location', presenceLocation); + + // We want to use presence if it's valid to avoid extra API calls, but its prone to being outdated when this function is called. + // So we check if the presence location is the same as the gameLog location; If it is, the presence is (probably) valid and we can use it. + // If it's not, we need to get the user manually to get the correct location. + // If the user happens to be offline or the api is just being dumb, we assume that the user logged into VRCX is different than the one in-game and return the gameLog location. + // This is really dumb. + if (presenceLocation === gameLogLocation) { + console.log('ok presence return'); + return presence.world; + } + let args = await this.getUser({ userId: this.currentUser.id }); + let user = args.json + + console.log('presence bad, got user', user); + if (!$app.isRealInstance(user.location)) { + console.warn( + 'presence invalid, user offline and/or instance invalid. returning gamelog location: ', + gameLogLocation + ); + return gameLogLocation; + } + console.warn('presence outdated, got user api location instead: ', user.location); + return this.parseLocation(user.location).worldId; + }; + API.applyWorld = function (json) { var ref = this.cachedWorlds.get(json.id); if (typeof ref === 'undefined') { diff --git a/lib/sqlite3.dll b/lib/sqlite3.dll new file mode 100644 index 00000000..1544c2db Binary files /dev/null and b/lib/sqlite3.dll differ